1 | //===-- PreCGRewrite.cpp --------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/CodeGen/CodeGen.h" |
14 | |
15 | #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done |
16 | #include "flang/Optimizer/Dialect/FIRCG/CGOps.h" |
17 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
18 | #include "flang/Optimizer/Dialect/FIROps.h" |
19 | #include "flang/Optimizer/Dialect/FIRType.h" |
20 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
21 | #include "mlir/IR/Iterators.h" |
22 | #include "mlir/Transforms/DialectConversion.h" |
23 | #include "llvm/ADT/STLExtras.h" |
24 | #include "llvm/Support/Debug.h" |
25 | |
// Instantiate the CodeGenRewrite pass definition from CGPasses.h.inc. This is
// expected to provide fir::impl::CodeGenRewriteBase, which the CodeGenRewrite
// pass class below derives from.
namespace fir {
#define GEN_PASS_DEF_CODEGENREWRITE
#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
} // namespace fir
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // Codegen rewrite: rewriting of subgraphs of ops |
33 | //===----------------------------------------------------------------------===// |
34 | |
// Debug tag consumed by the LLVM_DEBUG statements in this file.
#define DEBUG_TYPE "flang-codegen-rewrite"
36 | |
37 | static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec, |
38 | fir::ShapeOp shape) { |
39 | vec.append(shape.getExtents().begin(), shape.getExtents().end()); |
40 | } |
41 | |
42 | // Operands of fir.shape_shift split into two vectors. |
43 | static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec, |
44 | llvm::SmallVectorImpl<mlir::Value> &shiftVec, |
45 | fir::ShapeShiftOp shift) { |
46 | for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end(); |
47 | i != endIter;) { |
48 | shiftVec.push_back(*i++); |
49 | shapeVec.push_back(*i++); |
50 | } |
51 | } |
52 | |
53 | static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec, |
54 | fir::ShiftOp shift) { |
55 | vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); |
56 | } |
57 | |
58 | namespace { |
59 | |
60 | /// Convert fir.embox to the extended form where necessary. |
61 | /// |
62 | /// The embox operation can take arguments that specify multidimensional array |
63 | /// properties at runtime. These properties may be shared between distinct |
64 | /// objects that have the same properties. Before we lower these small DAGs to |
65 | /// LLVM-IR, we gather all the information into a single extended operation. For |
66 | /// example, |
67 | /// ``` |
68 | /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> |
69 | /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> |
70 | /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, |
71 | /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>> |
72 | /// ``` |
73 | /// can be rewritten as |
74 | /// ``` |
75 | /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : |
76 | /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> |
77 | /// !fir.box<!fir.array<?xi32>> |
78 | /// ``` |
79 | class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> { |
80 | public: |
81 | using OpRewritePattern::OpRewritePattern; |
82 | |
83 | llvm::LogicalResult |
84 | matchAndRewrite(fir::EmboxOp embox, |
85 | mlir::PatternRewriter &rewriter) const override { |
86 | // If the embox does not include a shape, then do not convert it |
87 | if (auto shapeVal = embox.getShape()) |
88 | return rewriteDynamicShape(embox, rewriter, shapeVal); |
89 | if (mlir::isa<fir::ClassType>(embox.getType())) |
90 | TODO(embox.getLoc(), "embox conversion for fir.class type" ); |
91 | if (auto boxTy = mlir::dyn_cast<fir::BoxType>(embox.getType())) |
92 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy())) |
93 | if (!seqTy.hasDynamicExtents()) |
94 | return rewriteStaticShape(embox, rewriter, seqTy); |
95 | return mlir::failure(); |
96 | } |
97 | |
98 | llvm::LogicalResult rewriteStaticShape(fir::EmboxOp embox, |
99 | mlir::PatternRewriter &rewriter, |
100 | fir::SequenceType seqTy) const { |
101 | auto loc = embox.getLoc(); |
102 | llvm::SmallVector<mlir::Value> shapeOpers; |
103 | auto idxTy = rewriter.getIndexType(); |
104 | for (auto ext : seqTy.getShape()) { |
105 | auto iAttr = rewriter.getIndexAttr(ext); |
106 | auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr); |
107 | shapeOpers.push_back(extVal); |
108 | } |
109 | auto xbox = rewriter.create<fir::cg::XEmboxOp>( |
110 | loc, embox.getType(), embox.getMemref(), shapeOpers, std::nullopt, |
111 | std::nullopt, std::nullopt, std::nullopt, embox.getTypeparams(), |
112 | embox.getSourceBox(), embox.getAllocatorIdxAttr()); |
113 | LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); |
114 | rewriter.replaceOp(embox, xbox.getOperation()->getResults()); |
115 | return mlir::success(); |
116 | } |
117 | |
118 | llvm::LogicalResult rewriteDynamicShape(fir::EmboxOp embox, |
119 | mlir::PatternRewriter &rewriter, |
120 | mlir::Value shapeVal) const { |
121 | auto loc = embox.getLoc(); |
122 | llvm::SmallVector<mlir::Value> shapeOpers; |
123 | llvm::SmallVector<mlir::Value> shiftOpers; |
124 | if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) { |
125 | populateShape(shapeOpers, shapeOp); |
126 | } else { |
127 | auto shiftOp = |
128 | mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()); |
129 | assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift" ); |
130 | populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); |
131 | } |
132 | llvm::SmallVector<mlir::Value> sliceOpers; |
133 | llvm::SmallVector<mlir::Value> subcompOpers; |
134 | llvm::SmallVector<mlir::Value> substrOpers; |
135 | if (auto s = embox.getSlice()) |
136 | if (auto sliceOp = |
137 | mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { |
138 | sliceOpers.assign(sliceOp.getTriples().begin(), |
139 | sliceOp.getTriples().end()); |
140 | subcompOpers.assign(sliceOp.getFields().begin(), |
141 | sliceOp.getFields().end()); |
142 | substrOpers.assign(sliceOp.getSubstr().begin(), |
143 | sliceOp.getSubstr().end()); |
144 | } |
145 | auto xbox = rewriter.create<fir::cg::XEmboxOp>( |
146 | loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers, |
147 | sliceOpers, subcompOpers, substrOpers, embox.getTypeparams(), |
148 | embox.getSourceBox(), embox.getAllocatorIdxAttr()); |
149 | LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); |
150 | rewriter.replaceOp(embox, xbox.getOperation()->getResults()); |
151 | return mlir::success(); |
152 | } |
153 | }; |
154 | |
155 | /// Convert fir.rebox to the extended form where necessary. |
156 | /// |
157 | /// For example, |
158 | /// ``` |
159 | /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> |
160 | /// !fir.box<!fir.array<?xi32>> |
161 | /// ``` |
162 | /// converted to |
163 | /// ``` |
164 | /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, |
165 | /// index, index) -> !fir.box<!fir.array<?xi32>> |
166 | /// ``` |
167 | class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> { |
168 | public: |
169 | using OpRewritePattern::OpRewritePattern; |
170 | |
171 | llvm::LogicalResult |
172 | matchAndRewrite(fir::ReboxOp rebox, |
173 | mlir::PatternRewriter &rewriter) const override { |
174 | auto loc = rebox.getLoc(); |
175 | llvm::SmallVector<mlir::Value> shapeOpers; |
176 | llvm::SmallVector<mlir::Value> shiftOpers; |
177 | if (auto shapeVal = rebox.getShape()) { |
178 | if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) |
179 | populateShape(shapeOpers, shapeOp); |
180 | else if (auto shiftOp = |
181 | mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) |
182 | populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); |
183 | else if (auto shiftOp = |
184 | mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) |
185 | populateShift(shiftOpers, shiftOp); |
186 | else |
187 | return mlir::failure(); |
188 | } |
189 | llvm::SmallVector<mlir::Value> sliceOpers; |
190 | llvm::SmallVector<mlir::Value> subcompOpers; |
191 | llvm::SmallVector<mlir::Value> substrOpers; |
192 | if (auto s = rebox.getSlice()) |
193 | if (auto sliceOp = |
194 | mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { |
195 | sliceOpers.append(sliceOp.getTriples().begin(), |
196 | sliceOp.getTriples().end()); |
197 | subcompOpers.append(sliceOp.getFields().begin(), |
198 | sliceOp.getFields().end()); |
199 | substrOpers.append(sliceOp.getSubstr().begin(), |
200 | sliceOp.getSubstr().end()); |
201 | } |
202 | |
203 | auto xRebox = rewriter.create<fir::cg::XReboxOp>( |
204 | loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers, |
205 | sliceOpers, subcompOpers, substrOpers); |
206 | LLVM_DEBUG(llvm::dbgs() |
207 | << "rewriting " << rebox << " to " << xRebox << '\n'); |
208 | rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); |
209 | return mlir::success(); |
210 | } |
211 | }; |
212 | |
213 | /// Convert all fir.array_coor to the extended form. |
214 | /// |
215 | /// For example, |
216 | /// ``` |
217 | /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, |
218 | /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32> |
219 | /// ``` |
220 | /// converted to |
221 | /// ``` |
222 | /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : |
223 | /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> |
224 | /// !fir.ref<i32> |
225 | /// ``` |
226 | class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> { |
227 | public: |
228 | using OpRewritePattern::OpRewritePattern; |
229 | |
230 | llvm::LogicalResult |
231 | matchAndRewrite(fir::ArrayCoorOp arrCoor, |
232 | mlir::PatternRewriter &rewriter) const override { |
233 | auto loc = arrCoor.getLoc(); |
234 | llvm::SmallVector<mlir::Value> shapeOpers; |
235 | llvm::SmallVector<mlir::Value> shiftOpers; |
236 | if (auto shapeVal = arrCoor.getShape()) { |
237 | if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) |
238 | populateShape(shapeOpers, shapeOp); |
239 | else if (auto shiftOp = |
240 | mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) |
241 | populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); |
242 | else if (auto shiftOp = |
243 | mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) |
244 | populateShift(shiftOpers, shiftOp); |
245 | else |
246 | return mlir::failure(); |
247 | } |
248 | llvm::SmallVector<mlir::Value> sliceOpers; |
249 | llvm::SmallVector<mlir::Value> subcompOpers; |
250 | if (auto s = arrCoor.getSlice()) |
251 | if (auto sliceOp = |
252 | mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) { |
253 | sliceOpers.append(sliceOp.getTriples().begin(), |
254 | sliceOp.getTriples().end()); |
255 | subcompOpers.append(sliceOp.getFields().begin(), |
256 | sliceOp.getFields().end()); |
257 | assert(sliceOp.getSubstr().empty() && |
258 | "Don't allow substring operations on array_coor. This " |
259 | "restriction may be lifted in the future." ); |
260 | } |
261 | auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>( |
262 | loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers, |
263 | sliceOpers, subcompOpers, arrCoor.getIndices(), |
264 | arrCoor.getTypeparams()); |
265 | LLVM_DEBUG(llvm::dbgs() |
266 | << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); |
267 | rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); |
268 | return mlir::success(); |
269 | } |
270 | }; |
271 | |
272 | class DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { |
273 | bool preserveDeclare; |
274 | |
275 | public: |
276 | using OpRewritePattern::OpRewritePattern; |
277 | DeclareOpConversion(mlir::MLIRContext *ctx, bool preserveDecl) |
278 | : OpRewritePattern(ctx), preserveDeclare(preserveDecl) {} |
279 | |
280 | llvm::LogicalResult |
281 | matchAndRewrite(fir::DeclareOp declareOp, |
282 | mlir::PatternRewriter &rewriter) const override { |
283 | if (!preserveDeclare) { |
284 | rewriter.replaceOp(declareOp, declareOp.getMemref()); |
285 | return mlir::success(); |
286 | } |
287 | auto loc = declareOp.getLoc(); |
288 | llvm::SmallVector<mlir::Value> shapeOpers; |
289 | llvm::SmallVector<mlir::Value> shiftOpers; |
290 | if (auto shapeVal = declareOp.getShape()) { |
291 | if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) |
292 | populateShape(shapeOpers, shapeOp); |
293 | else if (auto shiftOp = |
294 | mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp())) |
295 | populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); |
296 | else if (auto shiftOp = |
297 | mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp())) |
298 | populateShift(shiftOpers, shiftOp); |
299 | else |
300 | return mlir::failure(); |
301 | } |
302 | // FIXME: Add FortranAttrs and CudaAttrs |
303 | auto xDeclOp = rewriter.create<fir::cg::XDeclareOp>( |
304 | loc, declareOp.getType(), declareOp.getMemref(), shapeOpers, shiftOpers, |
305 | declareOp.getTypeparams(), declareOp.getDummyScope(), |
306 | declareOp.getUniqName()); |
307 | LLVM_DEBUG(llvm::dbgs() |
308 | << "rewriting " << declareOp << " to " << xDeclOp << '\n'); |
309 | rewriter.replaceOp(declareOp, xDeclOp.getOperation()->getResults()); |
310 | return mlir::success(); |
311 | } |
312 | }; |
313 | |
314 | class DummyScopeOpConversion |
315 | : public mlir::OpRewritePattern<fir::DummyScopeOp> { |
316 | public: |
317 | using OpRewritePattern::OpRewritePattern; |
318 | |
319 | llvm::LogicalResult |
320 | matchAndRewrite(fir::DummyScopeOp dummyScopeOp, |
321 | mlir::PatternRewriter &rewriter) const override { |
322 | rewriter.replaceOpWithNewOp<fir::UndefOp>(dummyScopeOp, |
323 | dummyScopeOp.getType()); |
324 | return mlir::success(); |
325 | } |
326 | }; |
327 | |
328 | /// Simple DCE to erase fir.shape/shift/slice/unused shape operands after this |
329 | /// pass (fir.shape and like have no codegen). |
330 | /// mlir::RegionDCE is expensive and requires running |
331 | /// mlir::eraseUnreachableBlocks. It does things that are not needed here, like |
332 | /// removing unused block arguments. fir.shape/shift/slice cannot be block |
333 | /// arguments. |
334 | /// This helper does a naive backward walk of the IR. It is not even guaranteed |
335 | /// to walk blocks according to backward dominance, but that is good enough for |
336 | /// what is done here, fir.shape/shift/slice have no usages anymore. The |
337 | /// backward walk allows getting rid of most of the unused operands, it is not a |
338 | /// problem to leave some in the weird cases. |
339 | static void simpleDCE(mlir::RewriterBase &rewriter, mlir::Operation *op) { |
340 | op->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>( |
341 | [&](mlir::Operation *subOp) { |
342 | if (mlir::isOpTriviallyDead(subOp)) |
343 | rewriter.eraseOp(subOp); |
344 | }); |
345 | } |
346 | |
347 | class CodeGenRewrite : public fir::impl::CodeGenRewriteBase<CodeGenRewrite> { |
348 | public: |
349 | using CodeGenRewriteBase<CodeGenRewrite>::CodeGenRewriteBase; |
350 | |
351 | void runOnOperation() override final { |
352 | mlir::ModuleOp mod = getOperation(); |
353 | |
354 | auto &context = getContext(); |
355 | mlir::ConversionTarget target(context); |
356 | target.addLegalDialect<mlir::arith::ArithDialect, fir::FIROpsDialect, |
357 | fir::FIRCodeGenDialect, mlir::func::FuncDialect>(); |
358 | target.addIllegalOp<fir::ArrayCoorOp>(); |
359 | target.addIllegalOp<fir::ReboxOp>(); |
360 | target.addIllegalOp<fir::DeclareOp>(); |
361 | target.addIllegalOp<fir::DummyScopeOp>(); |
362 | target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) { |
363 | return !(embox.getShape() || |
364 | mlir::isa<fir::SequenceType>( |
365 | mlir::cast<fir::BaseBoxType>(embox.getType()).getEleTy())); |
366 | }); |
367 | mlir::RewritePatternSet patterns(&context); |
368 | fir::populatePreCGRewritePatterns(patterns, preserveDeclare); |
369 | if (mlir::failed( |
370 | mlir::applyPartialConversion(mod, target, std::move(patterns)))) { |
371 | mlir::emitError(mlir::UnknownLoc::get(&context), |
372 | "error in running the pre-codegen conversions" ); |
373 | signalPassFailure(); |
374 | return; |
375 | } |
376 | // Erase any residual (fir.shape, fir.slice...). |
377 | mlir::IRRewriter rewriter(&context); |
378 | simpleDCE(rewriter, mod.getOperation()); |
379 | } |
380 | }; |
381 | |
382 | } // namespace |
383 | |
384 | void fir::populatePreCGRewritePatterns(mlir::RewritePatternSet &patterns, |
385 | bool preserveDeclare) { |
386 | patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion, |
387 | DummyScopeOpConversion>(patterns.getContext()); |
388 | patterns.add<DeclareOpConversion>(patterns.getContext(), preserveDeclare); |
389 | } |
390 | |