1 | //===-- PreCGRewrite.cpp --------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/CodeGen/CodeGen.h" |
14 | |
15 | #include "CGOps.h" |
16 | #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done |
17 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
18 | #include "flang/Optimizer/Dialect/FIROps.h" |
19 | #include "flang/Optimizer/Dialect/FIRType.h" |
20 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
21 | #include "mlir/Transforms/DialectConversion.h" |
22 | #include "mlir/Transforms/RegionUtils.h" |
23 | #include "llvm/ADT/STLExtras.h" |
24 | #include "llvm/Support/Debug.h" |
25 | |
26 | namespace fir { |
27 | #define GEN_PASS_DEF_CODEGENREWRITE |
28 | #include "flang/Optimizer/CodeGen/CGPasses.h.inc" |
29 | } // namespace fir |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // Codegen rewrite: rewriting of subgraphs of ops |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | #define DEBUG_TYPE "flang-codegen-rewrite" |
36 | |
37 | static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, |
38 | fir::ShapeOp shape) { |
39 | vec.append(shape.getExtents().begin(), shape.getExtents().end()); |
40 | } |
41 | |
42 | // Operands of fir.shape_shift split into two vectors. |
43 | static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, |
44 | llvm::SmallVectorImpl<mlir::Value> &shiftVec, |
45 | fir::ShapeShiftOp shift) { |
46 | for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end(); |
47 | i != endIter;) { |
48 | shiftVec.push_back(*i++); |
49 | shapeVec.push_back(*i++); |
50 | } |
51 | } |
52 | |
53 | static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, |
54 | fir::ShiftOp shift) { |
55 | vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); |
56 | } |
57 | |
58 | namespace { |
59 | |
60 | /// Convert fir.embox to the extended form where necessary. |
61 | /// |
62 | /// The embox operation can take arguments that specify multidimensional array |
63 | /// properties at runtime. These properties may be shared between distinct |
64 | /// objects that have the same properties. Before we lower these small DAGs to |
65 | /// LLVM-IR, we gather all the information into a single extended operation. For |
66 | /// example, |
67 | /// ``` |
68 | /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> |
69 | /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> |
70 | /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, |
71 | /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> |
72 | /// ``` |
73 | /// can be rewritten as |
74 | /// ``` |
75 | /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : |
76 | /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> |
77 | /// !fir.box<!fir.array<?xi32>> |
78 | /// ``` |
79 | class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> { |
80 | public: |
81 | using OpRewritePattern::OpRewritePattern; |
82 | |
83 | mlir::LogicalResult |
84 | matchAndRewrite(fir::EmboxOp embox, |
85 | mlir::PatternRewriter &rewriter) const override { |
86 | // If the embox does not include a shape, then do not convert it |
87 | if (auto shapeVal = embox.getShape()) |
88 | return rewriteDynamicShape(embox, rewriter, shapeVal); |
89 | if (embox.getType().isa<fir::ClassType>()) |
90 | TODO(embox.getLoc(), "embox conversion for fir.class type" ); |
91 | if (auto boxTy = embox.getType().dyn_cast<fir::BoxType>()) |
92 | if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>()) |
93 | if (!seqTy.hasDynamicExtents()) |
94 | return rewriteStaticShape(embox, rewriter, seqTy); |
95 | return mlir::failure(); |
96 | } |
97 | |
98 | mlir::LogicalResult rewriteStaticShape(fir::EmboxOp embox, |
99 | mlir::PatternRewriter &rewriter, |
100 | fir::SequenceType seqTy) const { |
101 | auto loc = embox.getLoc(); |
102 | llvm::SmallVector<mlir::Value> shapeOpers; |
103 | auto idxTy = rewriter.getIndexType(); |
104 | for (auto ext : seqTy.getShape()) { |
105 | auto iAttr = rewriter.getIndexAttr(ext); |
106 | auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr); |
107 | shapeOpers.push_back(extVal); |
108 | } |
109 | auto xbox = rewriter.create<fir::cg::XEmboxOp>( |
110 | loc, embox.getType(), embox.getMemref(), shapeOpers, std::nullopt, |
111 | std::nullopt, std::nullopt, std::nullopt, embox.getTypeparams(), |
112 | embox.getSourceBox()); |
113 | LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); |
114 | rewriter.replaceOp(embox, xbox.getOperation()->getResults()); |
115 | return mlir::success(); |
116 | } |
117 | |
118 | mlir::LogicalResult rewriteDynamicShape(fir::EmboxOp embox, |
119 | mlir::PatternRewriter &rewriter, |
120 | mlir::Value shapeVal) const { |
121 | auto loc = embox.getLoc(); |
122 | llvm::SmallVector<mlir::Value> shapeOpers; |
123 | llvm::SmallVector<mlir::Value> shiftOpers; |
124 | if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) { |
125 | populateShape(shapeOpers, shapeOp); |
126 | } else { |
127 | auto shiftOp = |
128 | mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()); |
129 | assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift" ); |
130 | populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); |
131 | } |
132 | llvm::SmallVector<mlir::Value> sliceOpers; |
133 | llvm::SmallVector<mlir::Value> subcompOpers; |
134 | llvm::SmallVector<mlir::Value> substrOpers; |
135 | if (auto s = embox.getSlice()) |
136 | if (auto sliceOp = |
137 | mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { |
138 | sliceOpers.assign(sliceOp.getTriples().begin(), |
139 | sliceOp.getTriples().end()); |
140 | subcompOpers.assign(sliceOp.getFields().begin(), |
141 | sliceOp.getFields().end()); |
142 | substrOpers.assign(sliceOp.getSubstr().begin(), |
143 | sliceOp.getSubstr().end()); |
144 | } |
145 | auto xbox = rewriter.create<fir::cg::XEmboxOp>( |
146 | loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers, |
147 | sliceOpers, subcompOpers, substrOpers, embox.getTypeparams(), |
148 | embox.getSourceBox()); |
149 | LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); |
150 | rewriter.replaceOp(embox, xbox.getOperation()->getResults()); |
151 | return mlir::success(); |
152 | } |
153 | }; |
154 | |
155 | /// Convert fir.rebox to the extended form where necessary. |
156 | /// |
157 | /// For example, |
158 | /// ``` |
159 | /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> |
160 | /// !fir.box<!fir.array<?xi32>> |
161 | /// ``` |
162 | /// converted to |
163 | /// ``` |
164 | /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, |
165 | /// index, index) -> !fir.box<!fir.array<?xi32>> |
166 | /// ``` |
167 | class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> { |
168 | public: |
169 | using OpRewritePattern::OpRewritePattern; |
170 | |
171 | mlir::LogicalResult |
172 | matchAndRewrite(fir::ReboxOp rebox, |
173 | mlir::PatternRewriter &rewriter) const override { |
174 | auto loc = rebox.getLoc(); |
175 | llvm::SmallVector<mlir::Value> shapeOpers; |
176 | llvm::SmallVector<mlir::Value> shiftOpers; |
177 | if (auto shapeVal = rebox.getShape()) { |
178 | if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) |
179 | populateShape(shapeOpers, shapeOp); |
180 | else if (auto shiftOp = |
181 | mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) |
182 | populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); |
183 | else if (auto shiftOp = |
184 | mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) |
185 | populateShift(shiftOpers, shiftOp); |
186 | else |
187 | return mlir::failure(); |
188 | } |
189 | llvm::SmallVector<mlir::Value> sliceOpers; |
190 | llvm::SmallVector<mlir::Value> subcompOpers; |
191 | llvm::SmallVector<mlir::Value> substrOpers; |
192 | if (auto s = rebox.getSlice()) |
193 | if (auto sliceOp = |
194 | mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { |
195 | sliceOpers.append(sliceOp.getTriples().begin(), |
196 | sliceOp.getTriples().end()); |
197 | subcompOpers.append(sliceOp.getFields().begin(), |
198 | sliceOp.getFields().end()); |
199 | substrOpers.append(sliceOp.getSubstr().begin(), |
200 | sliceOp.getSubstr().end()); |
201 | } |
202 | |
203 | auto xRebox = rewriter.create<fir::cg::XReboxOp>( |
204 | loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers, |
205 | sliceOpers, subcompOpers, substrOpers); |
206 | LLVM_DEBUG(llvm::dbgs() |
207 | << "rewriting " << rebox << " to " << xRebox << '\n'); |
208 | rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); |
209 | return mlir::success(); |
210 | } |
211 | }; |
212 | |
213 | /// Convert all fir.array_coor to the extended form. |
214 | /// |
215 | /// For example, |
216 | /// ``` |
217 | /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, |
218 | /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> |
219 | /// ``` |
220 | /// converted to |
221 | /// ``` |
222 | /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : |
223 | /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> |
224 | /// !fir.ref<i32> |
225 | /// ``` |
226 | class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> { |
227 | public: |
228 | using OpRewritePattern::OpRewritePattern; |
229 | |
230 | mlir::LogicalResult |
231 | matchAndRewrite(fir::ArrayCoorOp arrCoor, |
232 | mlir::PatternRewriter &rewriter) const override { |
233 | auto loc = arrCoor.getLoc(); |
234 | llvm::SmallVector<mlir::Value> shapeOpers; |
235 | llvm::SmallVector<mlir::Value> shiftOpers; |
236 | if (auto shapeVal = arrCoor.getShape()) { |
237 | if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) |
238 | populateShape(shapeOpers, shapeOp); |
239 | else if (auto shiftOp = |
240 | mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) |
241 | populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); |
242 | else if (auto shiftOp = |
243 | mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) |
244 | populateShift(shiftOpers, shiftOp); |
245 | else |
246 | return mlir::failure(); |
247 | } |
248 | llvm::SmallVector<mlir::Value> sliceOpers; |
249 | llvm::SmallVector<mlir::Value> subcompOpers; |
250 | if (auto s = arrCoor.getSlice()) |
251 | if (auto sliceOp = |
252 | mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { |
253 | sliceOpers.append(sliceOp.getTriples().begin(), |
254 | sliceOp.getTriples().end()); |
255 | subcompOpers.append(sliceOp.getFields().begin(), |
256 | sliceOp.getFields().end()); |
257 | assert(sliceOp.getSubstr().empty() && |
258 | "Don't allow substring operations on array_coor. This " |
259 | "restriction may be lifted in the future." ); |
260 | } |
261 | auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>( |
262 | loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers, |
263 | sliceOpers, subcompOpers, arrCoor.getIndices(), |
264 | arrCoor.getTypeparams()); |
265 | LLVM_DEBUG(llvm::dbgs() |
266 | << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); |
267 | rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); |
268 | return mlir::success(); |
269 | } |
270 | }; |
271 | |
272 | class DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { |
273 | public: |
274 | using OpRewritePattern::OpRewritePattern; |
275 | |
276 | mlir::LogicalResult |
277 | matchAndRewrite(fir::DeclareOp declareOp, |
278 | mlir::PatternRewriter &rewriter) const override { |
279 | rewriter.replaceOp(declareOp, declareOp.getMemref()); |
280 | return mlir::success(); |
281 | } |
282 | }; |
283 | |
284 | class CodeGenRewrite : public fir::impl::CodeGenRewriteBase<CodeGenRewrite> { |
285 | public: |
286 | void runOnOperation() override final { |
287 | mlir::ModuleOp mod = getOperation(); |
288 | |
289 | auto &context = getContext(); |
290 | mlir::ConversionTarget target(context); |
291 | target.addLegalDialect<mlir::arith::ArithDialect, fir::FIROpsDialect, |
292 | fir::FIRCodeGenDialect, mlir::func::FuncDialect>(); |
293 | target.addIllegalOp<fir::ArrayCoorOp>(); |
294 | target.addIllegalOp<fir::ReboxOp>(); |
295 | target.addIllegalOp<fir::DeclareOp>(); |
296 | target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) { |
297 | return !(embox.getShape() || embox.getType() |
298 | .cast<fir::BaseBoxType>() |
299 | .getEleTy() |
300 | .isa<fir::SequenceType>()); |
301 | }); |
302 | mlir::RewritePatternSet patterns(&context); |
303 | fir::populatePreCGRewritePatterns(patterns); |
304 | if (mlir::failed( |
305 | mlir::applyPartialConversion(mod, target, std::move(patterns)))) { |
306 | mlir::emitError(mlir::UnknownLoc::get(&context), |
307 | "error in running the pre-codegen conversions" ); |
308 | signalPassFailure(); |
309 | return; |
310 | } |
311 | // Erase any residual (fir.shape, fir.slice...). |
312 | mlir::IRRewriter rewriter(&context); |
313 | (void)mlir::runRegionDCE(rewriter, mod->getRegions()); |
314 | } |
315 | }; |
316 | |
317 | } // namespace |
318 | |
319 | std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() { |
320 | return std::make_unique<CodeGenRewrite>(); |
321 | } |
322 | |
323 | void fir::populatePreCGRewritePatterns(mlir::RewritePatternSet &patterns) { |
324 | patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion, |
325 | DeclareOpConversion>(patterns.getContext()); |
326 | } |
327 | |