| 1 | //===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements a pass to expand affine index ops into one or more more |
| 10 | // fundamental operations. |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Affine/LoopUtils.h" |
| 14 | #include "mlir/Dialect/Affine/Passes.h" |
| 15 | |
| 16 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 17 | #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
| 18 | #include "mlir/Dialect/Affine/Utils.h" |
| 19 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 21 | |
| 22 | namespace mlir { |
| 23 | namespace affine { |
| 24 | #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS |
| 25 | #include "mlir/Dialect/Affine/Passes.h.inc" |
| 26 | } // namespace affine |
| 27 | } // namespace mlir |
| 28 | |
| 29 | using namespace mlir; |
| 30 | using namespace mlir::affine; |
| 31 | |
| 32 | /// Given a basis (in static and dynamic components), return the sequence of |
| 33 | /// suffix products of the basis, including the product of the entire basis, |
| 34 | /// which must **not** contain an outer bound. |
| 35 | /// |
| 36 | /// If excess dynamic values are provided, the values at the beginning |
| 37 | /// will be ignored. This allows for dropping the outer bound without |
/// needing to manipulate the dynamic value array. `knownNonNegative`
/// indicates that the values being used to compute the strides are known
/// to be non-negative.
| 41 | static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter, |
| 42 | ValueRange dynamicBasis, |
| 43 | ArrayRef<int64_t> staticBasis, |
| 44 | bool knownNonNegative) { |
| 45 | if (staticBasis.empty()) |
| 46 | return {}; |
| 47 | |
| 48 | SmallVector<Value> result; |
| 49 | result.reserve(staticBasis.size()); |
| 50 | size_t dynamicIndex = dynamicBasis.size(); |
| 51 | Value dynamicPart = nullptr; |
| 52 | int64_t staticPart = 1; |
| 53 | // The products of the strides can't have overflow by definition of |
| 54 | // affine.*_index. |
| 55 | arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw; |
| 56 | if (knownNonNegative) |
| 57 | ovflags = ovflags | arith::IntegerOverflowFlags::nuw; |
| 58 | for (int64_t elem : llvm::reverse(staticBasis)) { |
| 59 | if (ShapedType::isDynamic(elem)) { |
| 60 | // Note: basis elements and their products are, definitionally, |
| 61 | // non-negative, so `nuw` is justified. |
| 62 | if (dynamicPart) |
| 63 | dynamicPart = rewriter.create<arith::MulIOp>( |
| 64 | loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags); |
| 65 | else |
| 66 | dynamicPart = dynamicBasis[dynamicIndex - 1]; |
| 67 | --dynamicIndex; |
| 68 | } else { |
| 69 | staticPart *= elem; |
| 70 | } |
| 71 | |
| 72 | if (dynamicPart && staticPart == 1) { |
| 73 | result.push_back(dynamicPart); |
| 74 | } else { |
| 75 | Value stride = |
| 76 | rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart); |
| 77 | if (dynamicPart) |
| 78 | stride = |
| 79 | rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags); |
| 80 | result.push_back(stride); |
| 81 | } |
| 82 | } |
| 83 | std::reverse(result.begin(), result.end()); |
| 84 | return result; |
| 85 | } |
| 86 | |
| 87 | namespace { |
| 88 | /// Lowers `affine.delinearize_index` into a sequence of division and remainder |
| 89 | /// operations. |
struct LowerDelinearizeIndexOps
    : public OpRewritePattern<AffineDelinearizeIndexOp> {
  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value linearIdx = op.getLinearIndex();
    unsigned numResults = op.getNumResults();
    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
    // If the basis carries an outer bound (one entry per result), drop it:
    // the outermost result is produced by division alone, so that bound is
    // not needed for stride computation.
    if (numResults == staticBasis.size())
      staticBasis = staticBasis.drop_front();

    // A one-result delinearization is the identity on the linear index.
    if (numResults == 1) {
      rewriter.replaceOp(op, linearIdx);
      return success();
    }

    SmallVector<Value> results;
    results.reserve(numResults);
    // Basis elements of affine.delinearize_index are non-negative by
    // definition, so the stride products may carry `nuw`.
    SmallVector<Value> strides =
        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                       /*knownNonNegative=*/true);

    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);

    // Outermost result: floor-divide by the largest stride. The linear index
    // may be negative, hence floordiv semantics rather than plain division.
    Value initialPart =
        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
    results.push_back(initialPart);

    // Emit `linearIdx mod stride` with mathematical (non-negative) modulus
    // semantics: `rems` can be negative, so add the stride back when it is.
    auto emitModTerm = [&](Value stride) -> Value {
      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
      Value remainderNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, remainder, zero);
      // If the correction is relevant, this term is <= stride, which is known
      // to be positive in `index`. Otherwise, while 2 * stride might overflow,
      // this branch won't be taken, so the risk of `poison` is fine.
      Value corrected = rewriter.create<arith::AddIOp>(
          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
                                                   corrected, remainder);
      return mod;
    };

    // Generate all the intermediate parts: (linearIdx mod stride_i) divided
    // by the next-smaller stride.
    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
      Value thisStride = strides[i];
      Value nextStride = strides[i + 1];
      Value modulus = emitModTerm(thisStride);
      // We know both inputs are positive, so floorDiv == div.
      // This could potentially be a divui, but it's not clear if that would
      // cause issues.
      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
      results.push_back(divided);
    }

    // Innermost result: remainder with respect to the smallest stride.
    results.push_back(emitModTerm(strides.back()));

    rewriter.replaceOp(op, results);
    return success();
  }
};
| 151 | |
| 152 | /// Lowers `affine.linearize_index` into a sequence of multiplications and |
| 153 | /// additions. Make a best effort to sort the input indices so that |
| 154 | /// the most loop-invariant terms are at the left of the additions |
| 155 | /// to enable loop-invariant code motion. |
| 156 | struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { |
| 157 | using OpRewritePattern::OpRewritePattern; |
| 158 | LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, |
| 159 | PatternRewriter &rewriter) const override { |
| 160 | // Should be folded away, included here for safety. |
| 161 | if (op.getMultiIndex().empty()) { |
| 162 | rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0); |
| 163 | return success(); |
| 164 | } |
| 165 | |
| 166 | Location loc = op.getLoc(); |
| 167 | ValueRange multiIndex = op.getMultiIndex(); |
| 168 | size_t numIndexes = multiIndex.size(); |
| 169 | ArrayRef<int64_t> staticBasis = op.getStaticBasis(); |
| 170 | if (numIndexes == staticBasis.size()) |
| 171 | staticBasis = staticBasis.drop_front(); |
| 172 | |
| 173 | SmallVector<Value> strides = |
| 174 | computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis, |
| 175 | /*knownNonNegative=*/op.getDisjoint()); |
| 176 | SmallVector<std::pair<Value, int64_t>> scaledValues; |
| 177 | scaledValues.reserve(N: numIndexes); |
| 178 | |
| 179 | // Note: strides doesn't contain a value for the final element (stride 1) |
| 180 | // and everything else lines up. We use the "mutable" accessor so we can get |
| 181 | // our hands on an `OpOperand&` for the loop invariant counting function. |
| 182 | for (auto [stride, idxOp] : |
| 183 | llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) { |
| 184 | Value scaledIdx = rewriter.create<arith::MulIOp>( |
| 185 | loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw); |
| 186 | int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp); |
| 187 | scaledValues.emplace_back(scaledIdx, numHoistableLoops); |
| 188 | } |
| 189 | scaledValues.emplace_back( |
| 190 | multiIndex.back(), |
| 191 | numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1])); |
| 192 | |
| 193 | // Sort by how many enclosing loops there are, ties implicitly broken by |
| 194 | // size of the stride. |
| 195 | llvm::stable_sort(Range&: scaledValues, |
| 196 | C: [&](auto l, auto r) { return l.second > r.second; }); |
| 197 | |
| 198 | Value result = scaledValues.front().first; |
| 199 | for (auto [scaledValue, numHoistableLoops] : |
| 200 | llvm::drop_begin(RangeOrContainer&: scaledValues)) { |
| 201 | std::ignore = numHoistableLoops; |
| 202 | result = rewriter.create<arith::AddIOp>(loc, result, scaledValue, |
| 203 | arith::IntegerOverflowFlags::nsw); |
| 204 | } |
| 205 | rewriter.replaceOp(op, result); |
| 206 | return success(); |
| 207 | } |
| 208 | }; |
| 209 | |
| 210 | class ExpandAffineIndexOpsPass |
| 211 | : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> { |
| 212 | public: |
| 213 | ExpandAffineIndexOpsPass() = default; |
| 214 | |
| 215 | void runOnOperation() override { |
| 216 | MLIRContext *context = &getContext(); |
| 217 | RewritePatternSet patterns(context); |
| 218 | populateAffineExpandIndexOpsPatterns(patterns); |
| 219 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
| 220 | return signalPassFailure(); |
| 221 | } |
| 222 | }; |
| 223 | |
| 224 | } // namespace |
| 225 | |
| 226 | void mlir::affine::populateAffineExpandIndexOpsPatterns( |
| 227 | RewritePatternSet &patterns) { |
| 228 | patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>( |
| 229 | arg: patterns.getContext()); |
| 230 | } |
| 231 | |
| 232 | std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() { |
| 233 | return std::make_unique<ExpandAffineIndexOpsPass>(); |
| 234 | } |
| 235 | |