//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>

using namespace mlir;

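/// Compute the suffix product of `sizes` (i.e. the row-major strides of a
/// shape with these sizes), with the last stride equal to `unit`. For example,
/// sizes = [2, 3, 4] with unit = 1 yields [12, 4, 1].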
template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = static_cast<int64_t>(strides.size()) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}

template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty, let zip_equal fail if only 1 is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

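/// Linearize `offsets` with respect to `basis`, i.e. compute
/// zero + sum_i offsets[i] * basis[i]. For example, offsets = [1, 2] with
/// basis = [4, 1] gives 1 * 4 + 2 * 1 = 6.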
template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}

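/// Delinearize `linearIndex` into one offset per stride by repeated division:
/// offsets[r] = divOp(linearIndex, strides[r]), then linearIndex %= strides[r].
/// For example, linearIndex = 6 with strides = [4, 1] gives offsets [1, 2].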
template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}

//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), 0, std::plus<int64_t>());
}

int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  if (basis.empty())
    return 1;
  return std::accumulate(basis.begin(), basis.end(), 1,
                         std::multiplies<int64_t>());
}

int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}

SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}

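// Example: shape = [8, 16, 32] with subShape = [4, 8] returns [8, 4, 4]
// (trailing dimensions are divided elementwise, leading ones are kept as-is);
// shape = [8, 15] with subShape = [4] returns std::nullopt since 15 % 4 != 0.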
std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If integral division does not occur, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }
  // At this point we computed the ratio (in reverse) for the common size.
  // Fill with the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));
  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}

//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//

SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}

//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//

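// Example: the inverse of permutation [1, 2, 0] is [2, 0, 1]; entry i of the
// result records where i appears in the input, i.e. inversion[perm[i]] = i.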
SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}

bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
  for (auto i : llvm::seq<int64_t>(0, permutation.size()))
    if (permutation[i] != i)
      return false;
  return true;
}

bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (val < 0 || static_cast<uint64_t>(val) >= interchange.size())
      return false;
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}

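// Example: permSize = 5, positions = [2, 4], desiredPositions = [0, 3]
// produces [2, 0, 1, 4, 3]: dimension 2 moves to position 0, dimension 4
// moves to position 3, and the remaining dimensions fill the free slots in
// ascending order.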
SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}

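// Example: inputPerm = [2, 4, 0, 1, 3] with dropPositions = [1, 2] gives
// [2, 0, 1]: entries equal to a dropped position are removed and the remaining
// entries are renumbered so the result stays a valid permutation.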
SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
                                    ArrayRef<int64_t> dropPositions) {
  assert(inputPerm.size() >= dropPositions.size() &&
         "expected inputPerm size to be at least the number of positions to "
         "drop");
  SmallVector<int64_t> res;
  unsigned permSize = inputPerm.size();
  for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
    int64_t targetIndex = inputPerm[inputIndex];
    bool shouldDrop = false;
    unsigned dropSize = dropPositions.size();
    for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
        shouldDrop = true;
        break;
      }
      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
        targetIndex--;
      }
    }
    if (!shouldDrop) {
      res.push_back(targetIndex);
    }
  }
  return res;
}

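// Example: for an ArrayAttr holding [0, 1, 2, 3] with dropFront = 1 and
// dropBack = 1, the result is [1, 2].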
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val)) {
    return attr.getContext();
  }
  return cast<Value>(val).getContext();
}

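// The result pairs an affine expression over 2 * rank + 1 symbols with the
// operands to bind to them: s0 is the source offset, and each (s_{2i+1},
// s_{2i+2}) pair multiplies an index with its stride, i.e. the expression is
// s0 + sum_i indices[i] * strides[i].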
std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
                         ArrayRef<Value> indices) {
  return computeLinearIndex(
      sourceOffset, getAsIndexOpFoldResult(getContext(sourceOffset), strides),
      getAsOpFoldResult(ValueRange(indices)));
}

//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
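/// For example, tileShape = [2, 3] padded to size 4 becomes [1, 1, 2, 3].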
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}

mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}

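// Map a linear tile index to the multi-dimensional offsets of that tile:
// delinearize in the (permuted) tile-count space, undo the loop-order
// permutation, then scale by the tile shape. For example, with shape = [4, 8],
// tileShape = [2, 4], and the identity loop order, linearIndex = 3 yields
// offsets [2, 4].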
SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}

SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}
| 394 | |