//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;
20
21template <typename ExprType>
22SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
23 ExprType unit) {
24 if (sizes.empty())
25 return {};
26 SmallVector<ExprType> strides(sizes.size(), unit);
27 for (int64_t r = static_cast<int64_t>(strides.size()) - 2; r >= 0; --r)
28 strides[r] = strides[r + 1] * sizes[r + 1];
29 return strides;
30}
31
32template <typename ExprType>
33SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
34 ArrayRef<ExprType> v2) {
35 // Early exit if both are empty, let zip_equal fail if only 1 is empty.
36 if (v1.empty() && v2.empty())
37 return {};
38 SmallVector<ExprType> result;
39 for (auto it : llvm::zip_equal(v1, v2))
40 result.push_back(std::get<0>(it) * std::get<1>(it));
41 return result;
42}
43
44template <typename ExprType>
45ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
46 ExprType zero) {
47 assert(offsets.size() == basis.size());
48 ExprType linearIndex = zero;
49 for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
50 linearIndex = linearIndex + offsets[idx] * basis[idx];
51 return linearIndex;
52}
53
54template <typename ExprType, typename DivOpTy>
55SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
56 ArrayRef<ExprType> strides,
57 DivOpTy divOp) {
58 int64_t rank = strides.size();
59 SmallVector<ExprType> offsets(rank);
60 for (int64_t r = 0; r < rank; ++r) {
61 offsets[r] = divOp(linearIndex, strides[r]);
62 linearIndex = linearIndex % strides[r];
63 }
64 return offsets;
65}
66
//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//
70
71SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
72 assert((sizes.empty() ||
73 llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
74 "sizes must be nonnegative");
75 int64_t unit = 1;
76 return ::computeSuffixProductImpl(sizes, unit);
77}
78
79SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
80 ArrayRef<int64_t> v2) {
81 return computeElementwiseMulImpl(v1, v2);
82}
83
84int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
85 assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
86 "basis must be nonnegative");
87 if (basis.empty())
88 return 0;
89 return std::accumulate(first: basis.begin(), last: basis.end(), init: 1, binary_op: std::plus<int64_t>());
90}
91
92int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
93 assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
94 "basis must be nonnegative");
95 if (basis.empty())
96 return 1;
97 return std::accumulate(first: basis.begin(), last: basis.end(), init: 1,
98 binary_op: std::multiplies<int64_t>());
99}
100
101int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
102 assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
103 "basis must be nonnegative");
104 int64_t zero = 0;
105 return linearizeImpl(offsets, basis, zero);
106}
107
108SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
109 ArrayRef<int64_t> strides) {
110 assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
111 "strides must be nonnegative");
112 return delinearizeImpl(linearIndex, strides,
113 divOp: [](int64_t e1, int64_t e2) { return e1 / e2; });
114}
115
116std::optional<SmallVector<int64_t>>
117mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
118 if (shape.size() < subShape.size())
119 return std::nullopt;
120 assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
121 "shape must be nonnegative");
122 assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
123 "subShape must be nonnegative");
124
125 // Starting from the end, compute the integer divisors.
126 std::vector<int64_t> result;
127 result.reserve(n: shape.size());
128 for (auto [size, subSize] :
129 llvm::zip(t: llvm::reverse(C&: shape), u: llvm::reverse(C&: subShape))) {
130 // If integral division does not occur, return and let the caller decide.
131 if (size % subSize != 0)
132 return std::nullopt;
133 result.push_back(x: size / subSize);
134 }
135 // At this point we computed the ratio (in reverse) for the common size.
136 // Fill with the remaining entries from the shape (still in reverse).
137 int commonSize = subShape.size();
138 std::copy(first: shape.rbegin() + commonSize, last: shape.rend(),
139 result: std::back_inserter(x&: result));
140 // Reverse again to get it back in the proper order and return.
141 return SmallVector<int64_t>{result.rbegin(), result.rend()};
142}
143
//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//
147
148SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
149 if (sizes.empty())
150 return {};
151 AffineExpr unit = getAffineConstantExpr(constant: 1, context: sizes.front().getContext());
152 return ::computeSuffixProductImpl(sizes, unit);
153}
154
155SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
156 ArrayRef<AffineExpr> v2) {
157 return computeElementwiseMulImpl(v1, v2);
158}
159
160AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
161 if (basis.empty())
162 return getAffineConstantExpr(constant: 0, context: ctx);
163 return std::accumulate(first: basis.begin(), last: basis.end(),
164 init: getAffineConstantExpr(constant: 0, context: ctx),
165 binary_op: std::plus<AffineExpr>());
166}
167
168AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
169 if (basis.empty())
170 return getAffineConstantExpr(constant: 1, context: ctx);
171 return std::accumulate(first: basis.begin(), last: basis.end(),
172 init: getAffineConstantExpr(constant: 1, context: ctx),
173 binary_op: std::multiplies<AffineExpr>());
174}
175
176AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
177 ArrayRef<AffineExpr> basis) {
178 AffineExpr zero = getAffineConstantExpr(constant: 0, context: ctx);
179 return linearizeImpl(offsets, basis, zero);
180}
181
182AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
183 ArrayRef<int64_t> basis) {
184
185 return linearize(ctx, offsets, basis: getAffineConstantExprs(constants: basis, context: ctx));
186}
187
188SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
189 ArrayRef<AffineExpr> strides) {
190 return delinearizeImpl(
191 linearIndex, strides,
192 divOp: [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(other: e2); });
193}
194
195SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
196 ArrayRef<int64_t> strides) {
197 MLIRContext *ctx = linearIndex.getContext();
198 return delinearize(linearIndex, strides: getAffineConstantExprs(constants: strides, context: ctx));
199}
200
//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//
204
205SmallVector<int64_t>
206mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
207 assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
208 "permutation must be non-negative");
209 SmallVector<int64_t> inversion(permutation.size());
210 for (const auto &pos : llvm::enumerate(First&: permutation)) {
211 inversion[pos.value()] = pos.index();
212 }
213 return inversion;
214}
215
216bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
217 for (auto i : llvm::seq<int64_t>(Begin: 0, End: permutation.size()))
218 if (permutation[i] != i)
219 return false;
220 return true;
221}
222
223bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
224 llvm::SmallDenseSet<int64_t, 4> seenVals;
225 for (auto val : interchange) {
226 if (val < 0 || static_cast<uint64_t>(val) >= interchange.size())
227 return false;
228 if (seenVals.count(V: val))
229 return false;
230 seenVals.insert(V: val);
231 }
232 return seenVals.size() == interchange.size();
233}
234
235SmallVector<int64_t>
236mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
237 ArrayRef<int64_t> desiredPositions) {
238 SmallVector<int64_t> res(permSize, -1);
239 DenseSet<int64_t> seen;
240 for (auto [pos, desiredPos] : llvm::zip_equal(t&: positions, u&: desiredPositions)) {
241 res[desiredPos] = pos;
242 seen.insert(V: pos);
243 }
244 int64_t nextPos = 0;
245 for (int64_t &entry : res) {
246 if (entry != -1)
247 continue;
248 while (seen.contains(V: nextPos))
249 ++nextPos;
250 entry = nextPos;
251 ++nextPos;
252 }
253 return res;
254}
255
256SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
257 ArrayRef<int64_t> dropPositions) {
258 assert(inputPerm.size() >= dropPositions.size() &&
259 "expect inputPerm size large than position to drop");
260 SmallVector<int64_t> res;
261 unsigned permSize = inputPerm.size();
262 for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
263 int64_t targetIndex = inputPerm[inputIndex];
264 bool shouldDrop = false;
265 unsigned dropSize = dropPositions.size();
266 for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
267 if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
268 shouldDrop = true;
269 break;
270 }
271 if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
272 targetIndex--;
273 }
274 }
275 if (!shouldDrop) {
276 res.push_back(Elt: targetIndex);
277 }
278 }
279 return res;
280}
281
282SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
283 unsigned dropFront,
284 unsigned dropBack) {
285 assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
286 auto range = arrayAttr.getAsRange<IntegerAttr>();
287 SmallVector<int64_t> res;
288 res.reserve(N: arrayAttr.size() - dropFront - dropBack);
289 for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
290 it != eit; ++it)
291 res.push_back(Elt: (*it).getValue().getSExtValue());
292 return res;
293}
294
295// TODO: do we have any common utily for this?
296static MLIRContext *getContext(OpFoldResult val) {
297 assert(val && "Invalid value");
298 if (auto attr = dyn_cast<Attribute>(Val&: val)) {
299 return attr.getContext();
300 }
301 return cast<Value>(Val&: val).getContext();
302}
303
304std::pair<AffineExpr, SmallVector<OpFoldResult>>
305mlir::computeLinearIndex(OpFoldResult sourceOffset,
306 ArrayRef<OpFoldResult> strides,
307 ArrayRef<OpFoldResult> indices) {
308 assert(strides.size() == indices.size());
309 auto sourceRank = static_cast<unsigned>(strides.size());
310
311 // Hold the affine symbols and values for the computation of the offset.
312 SmallVector<OpFoldResult> values(2 * sourceRank + 1);
313 SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
314
315 bindSymbolsList(ctx: getContext(val: sourceOffset), exprs: MutableArrayRef{symbols});
316 AffineExpr expr = symbols.front();
317 values[0] = sourceOffset;
318
319 for (unsigned i = 0; i < sourceRank; ++i) {
320 // Compute the stride.
321 OpFoldResult origStride = strides[i];
322
323 // Build up the computation of the offset.
324 unsigned baseIdxForDim = 1 + 2 * i;
325 unsigned subOffsetForDim = baseIdxForDim;
326 unsigned origStrideForDim = baseIdxForDim + 1;
327 expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
328 values[subOffsetForDim] = indices[i];
329 values[origStrideForDim] = origStride;
330 }
331
332 return {expr, values};
333}
334
335std::pair<AffineExpr, SmallVector<OpFoldResult>>
336mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
337 ArrayRef<Value> indices) {
338 return computeLinearIndex(
339 sourceOffset, strides: getAsIndexOpFoldResult(ctx: sourceOffset.getContext(), values: strides),
340 indices: getAsOpFoldResult(values: ValueRange(indices)));
341}
342
//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//
346
347/// Apply left-padding by 1 to the tile shape if required.
348static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
349 unsigned paddedSize) {
350 assert(tileShape.size() <= paddedSize &&
351 "expected tileShape to <= paddedSize");
352 if (tileShape.size() == paddedSize)
353 return to_vector(Range&: tileShape);
354 SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
355 llvm::append_range(C&: result, R&: tileShape);
356 return result;
357}
358
359mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
360 ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
361 ArrayRef<int64_t> loopOrder)
362 : tileShape(padTileShapeToSize(tileShape, paddedSize: shape.size())),
363 inverseLoopOrder(invertPermutationVector(permutation: loopOrder)),
364 sliceStrides(shape.size()) {
365 // Divide the shape by the tile shape.
366 std::optional<SmallVector<int64_t>> shapeRatio =
367 mlir::computeShapeRatio(shape, subShape: tileShape);
368 assert(shapeRatio && shapeRatio->size() == shape.size() &&
369 "target shape does not evenly divide the original shape");
370 assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
371 "expected loop order to be a permutation of rank equal to outer "
372 "shape");
373
374 maxLinearIndex = mlir::computeMaxLinearIndex(basis: *shapeRatio);
375 mlir::applyPermutationToVector(inVec&: *shapeRatio, permutation: loopOrder);
376 sliceStrides = mlir::computeStrides(sizes: *shapeRatio);
377}
378
379SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
380 int64_t linearIndex) const {
381 SmallVector<int64_t> tileCoords = applyPermutation(
382 input: delinearize(linearIndex, strides: sliceStrides), permutation: inverseLoopOrder);
383 return computeElementwiseMul(v1: tileCoords, v2: tileShape);
384}
385
386SmallVector<AffineExpr>
387mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
388 AffineExpr linearIndex) const {
389 MLIRContext *ctx = linearIndex.getContext();
390 SmallVector<AffineExpr> tileCoords = applyPermutation(
391 input: delinearize(linearIndex, strides: sliceStrides), permutation: inverseLoopOrder);
392 return mlir::computeElementwiseMul(v1: tileCoords,
393 v2: getAffineConstantExprs(constants: tileShape, context: ctx));
394}
395

// Source: mlir/lib/Dialect/Utils/IndexingUtils.cpp