1//===- MemRefUtils.h - Memref helpers to invoke MLIR JIT code ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utils for MLIR ABI interfacing with frameworks.
10//
11// The templated free functions below make it possible to allocate dense
12// contiguous buffers with shapes that interoperate properly with the MLIR
13// codegen ABI.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/ExecutionEngine/CRunnerUtils.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/ArrayRef.h"
20#include "llvm/ADT/STLExtras.h"
21
22#include "llvm/Support/raw_ostream.h"
23
24#include <algorithm>
25#include <array>
26#include <cassert>
27#include <climits>
28#include <functional>
29#include <initializer_list>
30#include <memory>
31#include <optional>
32
33#ifndef MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
34#define MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
35
36namespace mlir {
/// Allocation callback: takes a size in bytes, returns the raw buffer.
using AllocFunType = llvm::function_ref<void *(size_t)>;
38
39namespace detail {
40
41/// Given a shape with sizes greater than 0 along all dimensions, returns the
42/// distance, in number of elements, between a slice in a dimension and the next
43/// slice in the same dimension.
44/// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
45template <size_t N>
46inline std::array<int64_t, N> makeStrides(ArrayRef<int64_t> shape) {
47 assert(shape.size() == N && "expect shape specification to match rank");
48 std::array<int64_t, N> res;
49 int64_t running = 1;
50 for (int64_t idx = N - 1; idx >= 0; --idx) {
51 assert(shape[idx] && "size must be non-negative for all shape dimensions");
52 res[idx] = running;
53 running *= shape[idx];
54 }
55 return res;
56}
57
58/// Build a `StridedMemRefDescriptor<T, N>` that matches the MLIR ABI.
59/// This is an implementation detail that is kept in sync with MLIR codegen
60/// conventions. Additionally takes a `shapeAlloc` array which
61/// is used instead of `shape` to allocate "more aligned" data and compute the
62/// corresponding strides.
63template <int N, typename T>
64typename std::enable_if<(N >= 1), StridedMemRefType<T, N>>::type
65makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape,
66 ArrayRef<int64_t> shapeAlloc) {
67 assert(shape.size() == N);
68 assert(shapeAlloc.size() == N);
69 StridedMemRefType<T, N> descriptor;
70 descriptor.basePtr = static_cast<T *>(ptr);
71 descriptor.data = static_cast<T *>(alignedPtr);
72 descriptor.offset = 0;
73 std::copy(shape.begin(), shape.end(), descriptor.sizes);
74 auto strides = makeStrides<N>(shapeAlloc);
75 std::copy(strides.begin(), strides.end(), descriptor.strides);
76 return descriptor;
77}
78
79/// Build a `StridedMemRefDescriptor<T, 0>` that matches the MLIR ABI.
80/// This is an implementation detail that is kept in sync with MLIR codegen
81/// conventions. Additionally takes a `shapeAlloc` array which
82/// is used instead of `shape` to allocate "more aligned" data and compute the
83/// corresponding strides.
84template <int N, typename T>
85typename std::enable_if<(N == 0), StridedMemRefType<T, 0>>::type
86makeStridedMemRefDescriptor(T *ptr, T *alignedPtr, ArrayRef<int64_t> shape = {},
87 ArrayRef<int64_t> shapeAlloc = {}) {
88 assert(shape.size() == N);
89 assert(shapeAlloc.size() == N);
90 StridedMemRefType<T, 0> descriptor;
91 descriptor.basePtr = static_cast<T *>(ptr);
92 descriptor.data = static_cast<T *>(alignedPtr);
93 descriptor.offset = 0;
94 return descriptor;
95}
96
97/// Align `nElements` of type T with an optional `alignment`.
98/// This replaces a portable `posix_memalign`.
99/// `alignment` must be a power of 2 and greater than the size of T. By default
100/// the alignment is sizeof(T).
101template <typename T>
102std::pair<T *, T *>
103allocAligned(size_t nElements, AllocFunType allocFun = &::malloc,
104 std::optional<uint64_t> alignment = std::optional<uint64_t>()) {
105 assert(sizeof(T) <= UINT_MAX && "Elemental type overflows");
106 auto size = nElements * sizeof(T);
107 auto desiredAlignment = alignment.value_or(u: nextPowerOf2(n: sizeof(T)));
108 assert((desiredAlignment & (desiredAlignment - 1)) == 0);
109 assert(desiredAlignment >= sizeof(T));
110 T *data = reinterpret_cast<T *>(allocFun(size + desiredAlignment));
111 uintptr_t addr = reinterpret_cast<uintptr_t>(data);
112 uintptr_t rem = addr % desiredAlignment;
113 T *alignedData = (rem == 0)
114 ? data
115 : reinterpret_cast<T *>(addr + (desiredAlignment - rem));
116 assert(reinterpret_cast<uintptr_t>(alignedData) % desiredAlignment == 0);
117 return std::make_pair(data, alignedData);
118}
119
120} // namespace detail
121
122//===----------------------------------------------------------------------===//
123// Public API
124//===----------------------------------------------------------------------===//
125
/// Convenient callback to "visit" a memref element by element.
/// This takes a reference to an individual element as well as the coordinates.
/// It can be used in conjunction with a StridedMemrefIterator.
template <typename T>
using ElementWiseVisitor = llvm::function_ref<void(T &ptr, ArrayRef<int64_t>)>;
131
132/// Owning MemRef type that abstracts over the runtime type for ranked strided
133/// memref.
134template <typename T, unsigned Rank>
135class OwningMemRef {
136public:
137 using DescriptorType = StridedMemRefType<T, Rank>;
138 using FreeFunType = std::function<void(DescriptorType)>;
139
140 /// Allocate a new dense StridedMemrefRef with a given `shape`. An optional
141 /// `shapeAlloc` array can be supplied to "pad" every dimension individually.
142 /// If an ElementWiseVisitor is provided, it will be used to initialize the
143 /// data, else the memory will be zero-initialized. The alloc and free method
144 /// used to manage the data allocation can be optionally provided, and default
145 /// to malloc/free.
146 OwningMemRef(
147 ArrayRef<int64_t> shape, ArrayRef<int64_t> shapeAlloc = {},
148 ElementWiseVisitor<T> init = {},
149 std::optional<uint64_t> alignment = std::optional<uint64_t>(),
150 AllocFunType allocFun = &::malloc,
151 std::function<void(StridedMemRefType<T, Rank>)> freeFun =
152 [](StridedMemRefType<T, Rank> descriptor) {
153 ::free(ptr: descriptor.data);
154 })
155 : freeFunc(freeFun) {
156 if (shapeAlloc.empty())
157 shapeAlloc = shape;
158 assert(shape.size() == Rank);
159 assert(shapeAlloc.size() == Rank);
160 for (unsigned i = 0; i < Rank; ++i)
161 assert(shape[i] <= shapeAlloc[i] &&
162 "shapeAlloc must be greater than or equal to shape");
163 int64_t nElements = 1;
164 for (int64_t s : shapeAlloc)
165 nElements *= s;
166 auto [data, alignedData] =
167 detail::allocAligned<T>(nElements, allocFun, alignment);
168 descriptor = detail::makeStridedMemRefDescriptor<Rank>(data, alignedData,
169 shape, shapeAlloc);
170 if (init) {
171 for (StridedMemrefIterator<T, Rank> it = descriptor.begin(),
172 end = descriptor.end();
173 it != end; ++it)
174 init(*it, it.getIndices());
175 } else {
176 memset(descriptor.data, 0,
177 nElements * sizeof(T) +
178 alignment.value_or(u: detail::nextPowerOf2(n: sizeof(T))));
179 }
180 }
181 /// Take ownership of an existing descriptor with a custom deleter.
182 OwningMemRef(DescriptorType descriptor, FreeFunType freeFunc)
183 : freeFunc(freeFunc), descriptor(descriptor) {}
184 ~OwningMemRef() {
185 if (freeFunc)
186 freeFunc(descriptor);
187 }
188 OwningMemRef(const OwningMemRef &) = delete;
189 OwningMemRef &operator=(const OwningMemRef &) = delete;
190 OwningMemRef &operator=(const OwningMemRef &&other) {
191 freeFunc = other.freeFunc;
192 descriptor = other.descriptor;
193 other.freeFunc = nullptr;
194 memset(&other.descriptor, 0, sizeof(other.descriptor));
195 }
196 OwningMemRef(OwningMemRef &&other) { *this = std::move(other); }
197
198 DescriptorType &operator*() { return descriptor; }
199 DescriptorType *operator->() { return &descriptor; }
200 T &operator[](std::initializer_list<int64_t> indices) {
201 return descriptor[indices];
202 }
203
204private:
205 /// Custom deleter used to release the data buffer manager with the descriptor
206 /// below.
207 FreeFunType freeFunc;
208 /// The descriptor is an instance of StridedMemRefType<T, rank>.
209 DescriptorType descriptor;
210};
211
212} // namespace mlir
213
214#endif // MLIR_EXECUTIONENGINE_MEMREFUTILS_H_
215

source code of mlir/include/mlir/ExecutionEngine/MemRefUtils.h