1 | //===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements a pass that expands the affine index ops
// (`affine.delinearize_index` and `affine.linearize_index`) into one or more
// lower-level operations.
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/Passes.h" |
14 | |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
17 | #include "mlir/Dialect/Affine/Utils.h" |
18 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
19 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
20 | |
21 | namespace mlir { |
22 | namespace affine { |
23 | #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE |
24 | #include "mlir/Dialect/Affine/Passes.h.inc" |
25 | } // namespace affine |
26 | } // namespace mlir |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::affine; |
30 | |
31 | namespace { |
32 | /// Lowers `affine.delinearize_index` into a sequence of division and remainder |
33 | /// operations. |
34 | struct LowerDelinearizeIndexOps |
35 | : public OpRewritePattern<AffineDelinearizeIndexOp> { |
36 | using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern; |
37 | LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, |
38 | PatternRewriter &rewriter) const override { |
39 | FailureOr<SmallVector<Value>> multiIndex = |
40 | delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), |
41 | op.getEffectiveBasis(), /*hasOuterBound=*/false); |
42 | if (failed(multiIndex)) |
43 | return failure(); |
44 | rewriter.replaceOp(op, *multiIndex); |
45 | return success(); |
46 | } |
47 | }; |
48 | |
49 | /// Lowers `affine.linearize_index` into a sequence of multiplications and |
50 | /// additions. |
51 | struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { |
52 | using OpRewritePattern::OpRewritePattern; |
53 | LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, |
54 | PatternRewriter &rewriter) const override { |
55 | // Should be folded away, included here for safety. |
56 | if (op.getMultiIndex().empty()) { |
57 | rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0); |
58 | return success(); |
59 | } |
60 | |
61 | SmallVector<OpFoldResult> multiIndex = |
62 | getAsOpFoldResult(op.getMultiIndex()); |
63 | OpFoldResult linearIndex = |
64 | linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis()); |
65 | Value linearIndexValue = |
66 | getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex); |
67 | rewriter.replaceOp(op, linearIndexValue); |
68 | return success(); |
69 | } |
70 | }; |
71 | |
72 | class ExpandAffineIndexOpsAsAffinePass |
73 | : public affine::impl::AffineExpandIndexOpsAsAffineBase< |
74 | ExpandAffineIndexOpsAsAffinePass> { |
75 | public: |
76 | ExpandAffineIndexOpsAsAffinePass() = default; |
77 | |
78 | void runOnOperation() override { |
79 | MLIRContext *context = &getContext(); |
80 | RewritePatternSet patterns(context); |
81 | populateAffineExpandIndexOpsAsAffinePatterns(patterns); |
82 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
83 | return signalPassFailure(); |
84 | } |
85 | }; |
86 | |
87 | } // namespace |
88 | |
89 | void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns( |
90 | RewritePatternSet &patterns) { |
91 | patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>( |
92 | arg: patterns.getContext()); |
93 | } |
94 | |
95 | std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() { |
96 | return std::make_unique<ExpandAffineIndexOpsAsAffinePass>(); |
97 | } |
98 | |