1 | //===- MemRefUtils.h - Memref helpers to invoke MLIR JIT code ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Utils for MLIR ABI interfacing with frameworks. |
10 | // |
11 | // The templated free functions below make it possible to allocate dense |
12 | // contiguous buffers with shapes that interoperate properly with the MLIR |
13 | // codegen ABI. |
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/ExecutionEngine/CRunnerUtils.h" |
18 | #include "mlir/Support/LLVM.h" |
19 | #include "llvm/ADT/ArrayRef.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | |
22 | #include "llvm/Support/raw_ostream.h" |
23 | |
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <memory>
#include <optional>
32 | |
33 | #ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ |
34 | #define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ |
35 | |
namespace mlir {
/// Allocator callback type: takes a byte count and returns a raw buffer
/// pointer (malloc-compatible; `&::malloc` is used as the default below).
using AllocFunType = llvm::function_ref<void *(size_t)>;
38 | |
39 | namespace detail { |
40 | |
41 | /// Given a shape with sizes greater than 0 along all dimensions, returns the |
42 | /// distance, in number of elements, between a slice in a dimension and the next |
43 | /// slice in the same dimension. |
44 | /// e.g. shape[3, 4, 5] -> strides[20, 5, 1] |
45 | template <size_t N> |
46 | inline std::array<int64_t, N> makeStrides(ArrayRef<int64_t> shape) { |
47 | assert(shape.size() == N && "expect shape specification to match rank" ); |
48 | std::array<int64_t, N> res; |
49 | int64_t running = 1; |
50 | for (int64_t idx = N - 1; idx >= 0; --idx) { |
51 | assert(shape[idx] && "size must be non-negative for all shape dimensions" ); |
52 | res[idx] = running; |
53 | running *= shape[idx]; |
54 | } |
55 | return res; |
56 | } |
57 | |
58 | /// Build a `StridedMemRefDescriptor<T, N>` that matches the MLIR ABI. |
59 | /// This is an implementation detail that is kept in sync with MLIR codegen |
60 | /// conventions. Additionally takes a `shapeAlloc` array which |
61 | /// is used instead of `shape` to allocate "more aligned" data and compute the |
62 | /// corresponding strides. |
63 | template <int N, typename T> |
64 | typename std::enable_if<(N >= 1), StridedMemRefType<T, N>>::type |
65 | makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape, |
66 | ArrayRef<int64_t> shapeAlloc) { |
67 | assert(shape.size() == N); |
68 | assert(shapeAlloc.size() == N); |
69 | StridedMemRefType<T, N> descriptor; |
70 | descriptor.basePtr = static_cast<T *>(ptr); |
71 | descriptor.data = static_cast<T *>(alignedPtr); |
72 | descriptor.offset = 0; |
73 | std::copy(shape.begin(), shape.end(), descriptor.sizes); |
74 | auto strides = makeStrides<N>(shapeAlloc); |
75 | std::copy(strides.begin(), strides.end(), descriptor.strides); |
76 | return descriptor; |
77 | } |
78 | |
79 | /// Build a `StridedMemRefDescriptor<T, 0>` that matches the MLIR ABI. |
80 | /// This is an implementation detail that is kept in sync with MLIR codegen |
81 | /// conventions. Additionally takes a `shapeAlloc` array which |
82 | /// is used instead of `shape` to allocate "more aligned" data and compute the |
83 | /// corresponding strides. |
84 | template <int N, typename T> |
85 | typename std::enable_if<(N == 0), StridedMemRefType<T, 0>>::type |
86 | makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape = {}, |
87 | ArrayRef<int64_t> shapeAlloc = {}) { |
88 | assert(shape.size() == N); |
89 | assert(shapeAlloc.size() == N); |
90 | StridedMemRefType<T, 0> descriptor; |
91 | descriptor.basePtr = static_cast<T *>(ptr); |
92 | descriptor.data = static_cast<T *>(alignedPtr); |
93 | descriptor.offset = 0; |
94 | return descriptor; |
95 | } |
96 | |
97 | /// Align `nElements` of type T with an optional `alignment`. |
98 | /// This replaces a portable `posix_memalign`. |
99 | /// `alignment` must be a power of 2 and greater than the size of T. By default |
100 | /// the alignment is sizeof(T). |
101 | template <typename T> |
102 | std::pair<T *, T *> |
103 | allocAligned(size_t nElements, AllocFunType allocFun = &::malloc, |
104 | std::optional<uint64_t> alignment = std::optional<uint64_t>()) { |
105 | assert(sizeof(T) <= UINT_MAX && "Elemental type overflows" ); |
106 | auto size = nElements * sizeof(T); |
107 | auto desiredAlignment = alignment.value_or(u: nextPowerOf2(n: sizeof(T))); |
108 | assert((desiredAlignment & (desiredAlignment - 1)) == 0); |
109 | assert(desiredAlignment >= sizeof(T)); |
110 | T *data = reinterpret_cast<T *>(allocFun(size + desiredAlignment)); |
111 | uintptr_t addr = reinterpret_cast<uintptr_t>(data); |
112 | uintptr_t rem = addr % desiredAlignment; |
113 | T *alignedData = (rem == 0) |
114 | ? data |
115 | : reinterpret_cast<T *>(addr + (desiredAlignment - rem)); |
116 | assert(reinterpret_cast<uintptr_t>(alignedData) % desiredAlignment == 0); |
117 | return std::make_pair(data, alignedData); |
118 | } |
119 | |
120 | } // namespace detail |
121 | |
122 | //===----------------------------------------------------------------------===// |
123 | // Public API |
124 | //===----------------------------------------------------------------------===// |
125 | |
/// Convenient callback to "visit" a memref element by element.
/// This takes a reference to an individual element as well as the coordinates.
/// It can be used in conjunction with a StridedMemrefIterator.
template <typename T>
using ElementWiseVisitor = llvm::function_ref<void(T &ptr, ArrayRef<int64_t>)>;
131 | |
132 | /// Owning MemRef type that abstracts over the runtime type for ranked strided |
133 | /// memref. |
134 | template <typename T, unsigned Rank> |
135 | class OwningMemRef { |
136 | public: |
137 | using DescriptorType = StridedMemRefType<T, Rank>; |
138 | using FreeFunType = std::function<void(DescriptorType)>; |
139 | |
140 | /// Allocate a new dense StridedMemrefRef with a given `shape`. An optional |
141 | /// `shapeAlloc` array can be supplied to "pad" every dimension individually. |
142 | /// If an ElementWiseVisitor is provided, it will be used to initialize the |
143 | /// data, else the memory will be zero-initialized. The alloc and free method |
144 | /// used to manage the data allocation can be optionally provided, and default |
145 | /// to malloc/free. |
146 | OwningMemRef( |
147 | ArrayRef<int64_t> shape, ArrayRef<int64_t> shapeAlloc = {}, |
148 | ElementWiseVisitor<T> init = {}, |
149 | std::optional<uint64_t> alignment = std::optional<uint64_t>(), |
150 | AllocFunType allocFun = &::malloc, |
151 | std::function<void(StridedMemRefType<T, Rank>)> freeFun = |
152 | [](StridedMemRefType<T, Rank> descriptor) { |
153 | ::free(ptr: descriptor.data); |
154 | }) |
155 | : freeFunc(freeFun) { |
156 | if (shapeAlloc.empty()) |
157 | shapeAlloc = shape; |
158 | assert(shape.size() == Rank); |
159 | assert(shapeAlloc.size() == Rank); |
160 | for (unsigned i = 0; i < Rank; ++i) |
161 | assert(shape[i] <= shapeAlloc[i] && |
162 | "shapeAlloc must be greater than or equal to shape" ); |
163 | int64_t nElements = 1; |
164 | for (int64_t s : shapeAlloc) |
165 | nElements *= s; |
166 | auto [data, alignedData] = |
167 | detail::allocAligned<T>(nElements, allocFun, alignment); |
168 | descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData, |
169 | shape, shapeAlloc); |
170 | if (init) { |
171 | for (StridedMemrefIterator<T, Rank> it = descriptor.begin(), |
172 | end = descriptor.end(); |
173 | it != end; ++it) |
174 | init(*it, it.getIndices()); |
175 | } else { |
176 | memset(descriptor.data, 0, |
177 | nElements * sizeof(T) + |
178 | alignment.value_or(u: detail::nextPowerOf2(n: sizeof(T)))); |
179 | } |
180 | } |
181 | /// Take ownership of an existing descriptor with a custom deleter. |
182 | OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc) |
183 | : freeFunc(freeFunc), descriptor(descriptor) {} |
184 | ~OwningMemRef() { |
185 | if (freeFunc) |
186 | freeFunc(descriptor); |
187 | } |
188 | OwningMemRef(const OwningMemRef &) = delete; |
189 | OwningMemRef &operator=(const OwningMemRef &) = delete; |
190 | OwningMemRef &operator=(const OwningMemRef &&other) { |
191 | freeFunc = other.freeFunc; |
192 | descriptor = other.descriptor; |
193 | other.freeFunc = nullptr; |
194 | memset(&other.descriptor, 0, sizeof(other.descriptor)); |
195 | } |
196 | OwningMemRef(OwningMemRef &&other) { *this = std::move(other); } |
197 | |
198 | DescriptorType &operator*() { return descriptor; } |
199 | DescriptorType *operator->() { return &descriptor; } |
200 | T &operator[](std::initializer_list<int64_t> indices) { |
201 | return descriptor[indices]; |
202 | } |
203 | |
204 | private: |
205 | /// Custom deleter used to release the data buffer manager with the descriptor |
206 | /// below. |
207 | FreeFunType freeFunc; |
208 | /// The descriptor is an instance of StridedMemRefType<T, rank>. |
209 | DescriptorType descriptor; |
210 | }; |
211 | |
212 | } // namespace mlir |
213 | |
214 | #endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_ |
215 | |