//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;

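/// Compute the trailing (suffix) products of `sizes`, i.e. the row-major
/// strides of a shape with those sizes; the last stride is `unit`.
/// For example, sizes = {2, 3, 4} with unit = 1 yields strides = {12, 4, 1}.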
template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = strides.size() - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}

template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty, let zip_equal fail if only one is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

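/// Linearize `offsets` with respect to `basis` (typically a vector of
/// strides): linearIndex = zero + sum_i offsets[i] * basis[i].
/// For example, offsets = {1, 2} and basis = {4, 1} linearize to 6.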
template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}

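/// Invert a linearization: peel off one coordinate per stride, from the
/// outermost stride to the innermost one, using `divOp` for the division.
/// For example, linearIndex = 6 and strides = {4, 1} delinearize to {1, 2}.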
template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}

//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), 0,
                         std::plus<int64_t>());
}

int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), 1,
                         std::multiplies<int64_t>());
}

int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}

SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be strictly positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}

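// For example, computeShapeRatio({8, 16, 32}, {4, 8}) returns {8, 4, 4}: the
// trailing dimensions are divided elementwise and the leading, untiled
// dimensions of `shape` are kept as-is.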
std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be strictly positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be strictly positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If integral division does not occur, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }
  // At this point we computed the ratio (in reverse) for the common size.
  // Fill with the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));
  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}

//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//

SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

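// For example, delinearizing an expression d0 by strides {4, 1} produces
// {d0 floordiv 4, d0 mod 4}.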
SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}

//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//

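// Entry i of the result is the position of value i within `permutation`.
// For example, inverting the permutation {2, 0, 1} yields {1, 2, 0}.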
SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}

bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
  for (auto i : llvm::seq<int64_t>(0, permutation.size()))
    if (permutation[i] != i)
      return false;
  return true;
}

bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}

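// Builds a full permutation of size `permSize` that places `positions[i]` at
// `desiredPositions[i]` and fills the remaining entries with the unused
// values in increasing order. For example,
// computePermutationVector(5, {1, 3}, {0, 2}) returns {1, 0, 3, 2, 4}.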
SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}

SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  }
  return cast<Value>(val).getContext();
}

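// The returned expression is built over symbols s0, s1, ..., s_{2n} as
//   s0 + s1 * s2 + s3 * s4 + ...
// where s0 binds to `sourceOffset` and each pair (s_{2i+1}, s_{2i+2}) binds
// to (indices[i], strides[i]); the accompanying vector lists the operands in
// that same order.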
std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                         ArrayRef<Value> indices) {
  return computeLinearIndex(
      sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides),
      getAsOpFoldResult(ValueRange(indices)));
}

//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
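/// For example, padding the tile shape {2, 3} to a padded size of 4 yields
/// {1, 1, 2, 3}.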
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}

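// The constructor precomputes everything needed to enumerate tile offsets:
// the per-dimension number of tiles (the shape divided by the tile shape),
// permuted into `loopOrder`, its suffix-product strides, and the total tile
// count.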
mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}

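// For example, with shape = {4, 8}, tileShape = {2, 4} and the identity loop
// order, linearIndex = 3 delinearizes to tile coordinates {1, 1}, which scale
// by the tile shape to the offsets {2, 4}.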
SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}

SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}