1 | //===-- AffineDemotion.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This transformation is a prototype that demotes affine dialect operations
// back to FIR loop operations once optimizations have been applied.
11 | // It is used after the AffinePromotion pass. |
12 | // It is not part of the production pipeline and would need more work in order |
13 | // to be used in production. |
14 | // More information can be found in this presentation: |
15 | // https://slides.com/rajanwalia/deck |
16 | // |
17 | //===----------------------------------------------------------------------===// |
18 | |
19 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
20 | #include "flang/Optimizer/Dialect/FIROps.h" |
21 | #include "flang/Optimizer/Dialect/FIRType.h" |
22 | #include "flang/Optimizer/Transforms/Passes.h" |
23 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
24 | #include "mlir/Dialect/Affine/Utils.h" |
25 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
26 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
27 | #include "mlir/Dialect/SCF/IR/SCF.h" |
28 | #include "mlir/IR/BuiltinAttributes.h" |
29 | #include "mlir/IR/IntegerSet.h" |
30 | #include "mlir/IR/Visitors.h" |
31 | #include "mlir/Pass/Pass.h" |
32 | #include "mlir/Transforms/DialectConversion.h" |
33 | #include "llvm/ADT/DenseMap.h" |
34 | #include "llvm/Support/CommandLine.h" |
35 | #include "llvm/Support/Debug.h" |
36 | |
37 | namespace fir { |
38 | #define GEN_PASS_DEF_AFFINEDIALECTDEMOTION |
39 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
40 | } // namespace fir |
41 | |
42 | #define DEBUG_TYPE "flang-affine-demotion" |
43 | |
44 | using namespace fir; |
45 | using namespace mlir; |
46 | |
47 | namespace { |
48 | |
49 | class AffineLoadConversion |
50 | : public OpConversionPattern<mlir::affine::AffineLoadOp> { |
51 | public: |
52 | using OpConversionPattern<mlir::affine::AffineLoadOp>::OpConversionPattern; |
53 | |
54 | LogicalResult |
55 | matchAndRewrite(mlir::affine::AffineLoadOp op, OpAdaptor adaptor, |
56 | ConversionPatternRewriter &rewriter) const override { |
57 | SmallVector<Value> indices(adaptor.getIndices()); |
58 | auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), |
59 | op.getAffineMap(), indices); |
60 | if (!maybeExpandedMap) |
61 | return failure(); |
62 | |
63 | auto coorOp = rewriter.create<fir::CoordinateOp>( |
64 | op.getLoc(), fir::ReferenceType::get(op.getResult().getType()), |
65 | adaptor.getMemref(), *maybeExpandedMap); |
66 | |
67 | rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult()); |
68 | return success(); |
69 | } |
70 | }; |
71 | |
72 | class AffineStoreConversion |
73 | : public OpConversionPattern<mlir::affine::AffineStoreOp> { |
74 | public: |
75 | using OpConversionPattern<mlir::affine::AffineStoreOp>::OpConversionPattern; |
76 | |
77 | LogicalResult |
78 | matchAndRewrite(mlir::affine::AffineStoreOp op, OpAdaptor adaptor, |
79 | ConversionPatternRewriter &rewriter) const override { |
80 | SmallVector<Value> indices(op.getIndices()); |
81 | auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), |
82 | op.getAffineMap(), indices); |
83 | if (!maybeExpandedMap) |
84 | return failure(); |
85 | |
86 | auto coorOp = rewriter.create<fir::CoordinateOp>( |
87 | op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()), |
88 | adaptor.getMemref(), *maybeExpandedMap); |
89 | rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.getValue(), |
90 | coorOp.getResult()); |
91 | return success(); |
92 | } |
93 | }; |
94 | |
95 | class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> { |
96 | public: |
97 | using OpRewritePattern::OpRewritePattern; |
98 | mlir::LogicalResult |
99 | matchAndRewrite(fir::ConvertOp op, |
100 | mlir::PatternRewriter &rewriter) const override { |
101 | if (op.getRes().getType().isa<mlir::MemRefType>()) { |
102 | // due to index calculation moving to affine maps we still need to |
103 | // add converts for sequence types this has a side effect of losing |
104 | // some information about arrays with known dimensions by creating: |
105 | // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> |
106 | // !fir.ref<!fir.array<?xi32>> |
107 | if (auto refTy = op.getValue().getType().dyn_cast<fir::ReferenceType>()) |
108 | if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) { |
109 | fir::SequenceType::Shape flatShape = { |
110 | fir::SequenceType::getUnknownExtent()}; |
111 | auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy()); |
112 | auto flatTy = fir::ReferenceType::get(flatArrTy); |
113 | rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, |
114 | op.getValue()); |
115 | return success(); |
116 | } |
117 | rewriter.startOpModification(op->getParentOp()); |
118 | op.getResult().replaceAllUsesWith(op.getValue()); |
119 | rewriter.finalizeOpModification(op->getParentOp()); |
120 | rewriter.eraseOp(op); |
121 | } |
122 | return success(); |
123 | } |
124 | }; |
125 | |
126 | mlir::Type convertMemRef(mlir::MemRefType type) { |
127 | return fir::SequenceType::get( |
128 | SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()), |
129 | type.getElementType()); |
130 | } |
131 | |
132 | class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> { |
133 | public: |
134 | using OpRewritePattern::OpRewritePattern; |
135 | mlir::LogicalResult |
136 | matchAndRewrite(memref::AllocOp op, |
137 | mlir::PatternRewriter &rewriter) const override { |
138 | rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()), |
139 | op.getMemref()); |
140 | return success(); |
141 | } |
142 | }; |
143 | |
144 | class AffineDialectDemotion |
145 | : public fir::impl::AffineDialectDemotionBase<AffineDialectDemotion> { |
146 | public: |
147 | void runOnOperation() override { |
148 | auto *context = &getContext(); |
149 | auto function = getOperation(); |
150 | LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n" ; |
151 | function.print(llvm::dbgs());); |
152 | |
153 | mlir::RewritePatternSet patterns(context); |
154 | patterns.insert<ConvertConversion>(context); |
155 | patterns.insert<AffineLoadConversion>(context); |
156 | patterns.insert<AffineStoreConversion>(context); |
157 | patterns.insert<StdAllocConversion>(context); |
158 | mlir::ConversionTarget target(*context); |
159 | target.addIllegalOp<memref::AllocOp>(); |
160 | target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) { |
161 | if (op.getRes().getType().isa<mlir::MemRefType>()) |
162 | return false; |
163 | return true; |
164 | }); |
165 | target |
166 | .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, |
167 | mlir::arith::ArithDialect, mlir::func::FuncDialect>(); |
168 | |
169 | if (mlir::failed(mlir::applyPartialConversion(function, target, |
170 | std::move(patterns)))) { |
171 | mlir::emitError(mlir::UnknownLoc::get(context), |
172 | "error in converting affine dialect\n" ); |
173 | signalPassFailure(); |
174 | } |
175 | } |
176 | }; |
177 | |
178 | } // namespace |
179 | |
180 | std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() { |
181 | return std::make_unique<AffineDialectDemotion>(); |
182 | } |
183 | |