//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to expand affine index ops into one or more
// more fundamental operations.
//===----------------------------------------------------------------------===//
| 12 | |
| 13 | #include "mlir/Dialect/Affine/LoopUtils.h" |
| 14 | #include "mlir/Dialect/Affine/Passes.h" |
| 15 | |
| 16 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 17 | #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
| 18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 19 | |
| 20 | namespace mlir { |
| 21 | namespace affine { |
| 22 | #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS |
| 23 | #include "mlir/Dialect/Affine/Passes.h.inc" |
| 24 | } // namespace affine |
| 25 | } // namespace mlir |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace mlir::affine; |
| 29 | |
/// Given a basis (in static and dynamic components), return the sequence of
/// suffix products of the basis, including the product of the entire basis,
/// which must **not** contain an outer bound.
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
/// needing to manipulate the dynamic value array. `knownNonNegative`
/// indicates that the values being used to compute the strides are known
/// to be non-negative.
| 39 | static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter, |
| 40 | ValueRange dynamicBasis, |
| 41 | ArrayRef<int64_t> staticBasis, |
| 42 | bool knownNonNegative) { |
| 43 | if (staticBasis.empty()) |
| 44 | return {}; |
| 45 | |
| 46 | SmallVector<Value> result; |
| 47 | result.reserve(N: staticBasis.size()); |
| 48 | size_t dynamicIndex = dynamicBasis.size(); |
| 49 | Value dynamicPart = nullptr; |
| 50 | int64_t staticPart = 1; |
| 51 | // The products of the strides can't have overflow by definition of |
| 52 | // affine.*_index. |
| 53 | arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw; |
| 54 | if (knownNonNegative) |
| 55 | ovflags = ovflags | arith::IntegerOverflowFlags::nuw; |
| 56 | for (int64_t elem : llvm::reverse(C&: staticBasis)) { |
| 57 | if (ShapedType::isDynamic(dValue: elem)) { |
| 58 | // Note: basis elements and their products are, definitionally, |
| 59 | // non-negative, so `nuw` is justified. |
| 60 | if (dynamicPart) |
| 61 | dynamicPart = rewriter.create<arith::MulIOp>( |
| 62 | location: loc, args&: dynamicPart, args: dynamicBasis[dynamicIndex - 1], args&: ovflags); |
| 63 | else |
| 64 | dynamicPart = dynamicBasis[dynamicIndex - 1]; |
| 65 | --dynamicIndex; |
| 66 | } else { |
| 67 | staticPart *= elem; |
| 68 | } |
| 69 | |
| 70 | if (dynamicPart && staticPart == 1) { |
| 71 | result.push_back(Elt: dynamicPart); |
| 72 | } else { |
| 73 | Value stride = |
| 74 | rewriter.createOrFold<arith::ConstantIndexOp>(location: loc, args&: staticPart); |
| 75 | if (dynamicPart) |
| 76 | stride = |
| 77 | rewriter.create<arith::MulIOp>(location: loc, args&: dynamicPart, args&: stride, args&: ovflags); |
| 78 | result.push_back(Elt: stride); |
| 79 | } |
| 80 | } |
| 81 | std::reverse(first: result.begin(), last: result.end()); |
| 82 | return result; |
| 83 | } |
| 84 | |
| 85 | LogicalResult |
| 86 | affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, |
| 87 | AffineDelinearizeIndexOp op) { |
| 88 | Location loc = op.getLoc(); |
| 89 | Value linearIdx = op.getLinearIndex(); |
| 90 | unsigned numResults = op.getNumResults(); |
| 91 | ArrayRef<int64_t> staticBasis = op.getStaticBasis(); |
| 92 | if (numResults == staticBasis.size()) |
| 93 | staticBasis = staticBasis.drop_front(); |
| 94 | |
| 95 | if (numResults == 1) { |
| 96 | rewriter.replaceOp(op, newValues: linearIdx); |
| 97 | return success(); |
| 98 | } |
| 99 | |
| 100 | SmallVector<Value> results; |
| 101 | results.reserve(N: numResults); |
| 102 | SmallVector<Value> strides = |
| 103 | computeStrides(loc, rewriter, dynamicBasis: op.getDynamicBasis(), staticBasis, |
| 104 | /*knownNonNegative=*/true); |
| 105 | |
| 106 | Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(location: loc, args: 0); |
| 107 | |
| 108 | Value initialPart = |
| 109 | rewriter.create<arith::FloorDivSIOp>(location: loc, args&: linearIdx, args&: strides.front()); |
| 110 | results.push_back(Elt: initialPart); |
| 111 | |
| 112 | auto emitModTerm = [&](Value stride) -> Value { |
| 113 | Value remainder = rewriter.create<arith::RemSIOp>(location: loc, args&: linearIdx, args&: stride); |
| 114 | Value remainderNegative = rewriter.create<arith::CmpIOp>( |
| 115 | location: loc, args: arith::CmpIPredicate::slt, args&: remainder, args&: zero); |
| 116 | // If the correction is relevant, this term is <= stride, which is known |
| 117 | // to be positive in `index`. Otherwise, while 2 * stride might overflow, |
| 118 | // this branch won't be taken, so the risk of `poison` is fine. |
| 119 | Value corrected = rewriter.create<arith::AddIOp>( |
| 120 | location: loc, args&: remainder, args&: stride, args: arith::IntegerOverflowFlags::nsw); |
| 121 | Value mod = rewriter.create<arith::SelectOp>(location: loc, args&: remainderNegative, |
| 122 | args&: corrected, args&: remainder); |
| 123 | return mod; |
| 124 | }; |
| 125 | |
| 126 | // Generate all the intermediate parts |
| 127 | for (size_t i = 0, e = strides.size() - 1; i < e; ++i) { |
| 128 | Value thisStride = strides[i]; |
| 129 | Value nextStride = strides[i + 1]; |
| 130 | Value modulus = emitModTerm(thisStride); |
| 131 | // We know both inputs are positive, so floorDiv == div. |
| 132 | // This could potentially be a divui, but it's not clear if that would |
| 133 | // cause issues. |
| 134 | Value divided = rewriter.create<arith::DivSIOp>(location: loc, args&: modulus, args&: nextStride); |
| 135 | results.push_back(Elt: divided); |
| 136 | } |
| 137 | |
| 138 | results.push_back(Elt: emitModTerm(strides.back())); |
| 139 | |
| 140 | rewriter.replaceOp(op, newValues: results); |
| 141 | return success(); |
| 142 | } |
| 143 | |
| 144 | LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter, |
| 145 | AffineLinearizeIndexOp op) { |
| 146 | // Should be folded away, included here for safety. |
| 147 | if (op.getMultiIndex().empty()) { |
| 148 | rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, args: 0); |
| 149 | return success(); |
| 150 | } |
| 151 | |
| 152 | Location loc = op.getLoc(); |
| 153 | ValueRange multiIndex = op.getMultiIndex(); |
| 154 | size_t numIndexes = multiIndex.size(); |
| 155 | ArrayRef<int64_t> staticBasis = op.getStaticBasis(); |
| 156 | if (numIndexes == staticBasis.size()) |
| 157 | staticBasis = staticBasis.drop_front(); |
| 158 | |
| 159 | SmallVector<Value> strides = |
| 160 | computeStrides(loc, rewriter, dynamicBasis: op.getDynamicBasis(), staticBasis, |
| 161 | /*knownNonNegative=*/op.getDisjoint()); |
| 162 | SmallVector<std::pair<Value, int64_t>> scaledValues; |
| 163 | scaledValues.reserve(N: numIndexes); |
| 164 | |
| 165 | // Note: strides doesn't contain a value for the final element (stride 1) |
| 166 | // and everything else lines up. We use the "mutable" accessor so we can get |
| 167 | // our hands on an `OpOperand&` for the loop invariant counting function. |
| 168 | for (auto [stride, idxOp] : |
| 169 | llvm::zip_equal(t&: strides, u: llvm::drop_end(RangeOrContainer: op.getMultiIndexMutable()))) { |
| 170 | Value scaledIdx = rewriter.create<arith::MulIOp>( |
| 171 | location: loc, args: idxOp.get(), args&: stride, args: arith::IntegerOverflowFlags::nsw); |
| 172 | int64_t numHoistableLoops = numEnclosingInvariantLoops(operand&: idxOp); |
| 173 | scaledValues.emplace_back(Args&: scaledIdx, Args&: numHoistableLoops); |
| 174 | } |
| 175 | scaledValues.emplace_back( |
| 176 | Args: multiIndex.back(), |
| 177 | Args: numEnclosingInvariantLoops(operand&: op.getMultiIndexMutable()[numIndexes - 1])); |
| 178 | |
| 179 | // Sort by how many enclosing loops there are, ties implicitly broken by |
| 180 | // size of the stride. |
| 181 | llvm::stable_sort(Range&: scaledValues, |
| 182 | C: [&](auto l, auto r) { return l.second > r.second; }); |
| 183 | |
| 184 | Value result = scaledValues.front().first; |
| 185 | for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(RangeOrContainer&: scaledValues)) { |
| 186 | std::ignore = numHoistableLoops; |
| 187 | result = rewriter.create<arith::AddIOp>(location: loc, args&: result, args&: scaledValue, |
| 188 | args: arith::IntegerOverflowFlags::nsw); |
| 189 | } |
| 190 | rewriter.replaceOp(op, newValues: result); |
| 191 | return success(); |
| 192 | } |
| 193 | |
| 194 | namespace { |
| 195 | struct LowerDelinearizeIndexOps |
| 196 | : public OpRewritePattern<AffineDelinearizeIndexOp> { |
| 197 | using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern; |
| 198 | LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, |
| 199 | PatternRewriter &rewriter) const override { |
| 200 | return affine::lowerAffineDelinearizeIndexOp(rewriter, op); |
| 201 | } |
| 202 | }; |
| 203 | |
| 204 | struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { |
| 205 | using OpRewritePattern::OpRewritePattern; |
| 206 | LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, |
| 207 | PatternRewriter &rewriter) const override { |
| 208 | return affine::lowerAffineLinearizeIndexOp(rewriter, op); |
| 209 | } |
| 210 | }; |
| 211 | |
| 212 | class ExpandAffineIndexOpsPass |
| 213 | : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> { |
| 214 | public: |
| 215 | ExpandAffineIndexOpsPass() = default; |
| 216 | |
| 217 | void runOnOperation() override { |
| 218 | MLIRContext *context = &getContext(); |
| 219 | RewritePatternSet patterns(context); |
| 220 | populateAffineExpandIndexOpsPatterns(patterns); |
| 221 | if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns)))) |
| 222 | return signalPassFailure(); |
| 223 | } |
| 224 | }; |
| 225 | |
| 226 | } // namespace |
| 227 | |
| 228 | void mlir::affine::populateAffineExpandIndexOpsPatterns( |
| 229 | RewritePatternSet &patterns) { |
| 230 | patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>( |
| 231 | arg: patterns.getContext()); |
| 232 | } |
| 233 | |
| 234 | std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() { |
| 235 | return std::make_unique<ExpandAffineIndexOpsPass>(); |
| 236 | } |
| 237 | |