1//===-- PreCGRewrite.cpp --------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/CodeGen/CodeGen.h"
14
15#include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
16#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
17#include "flang/Optimizer/Dialect/FIRDialect.h"
18#include "flang/Optimizer/Dialect/FIROps.h"
19#include "flang/Optimizer/Dialect/FIRType.h"
20#include "flang/Optimizer/Dialect/Support/FIRContext.h"
21#include "mlir/IR/Iterators.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25
26namespace fir {
27#define GEN_PASS_DEF_CODEGENREWRITE
28#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
29} // namespace fir
30
31//===----------------------------------------------------------------------===//
32// Codegen rewrite: rewriting of subgraphs of ops
33//===----------------------------------------------------------------------===//
34
35#define DEBUG_TYPE "flang-codegen-rewrite"
36
37static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
38 fir::ShapeOp shape) {
39 vec.append(shape.getExtents().begin(), shape.getExtents().end());
40}
41
42// Operands of fir.shape_shift split into two vectors.
43static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
44 llvm::SmallVectorImpl<mlir::Value> &shiftVec,
45 fir::ShapeShiftOp shift) {
46 for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end();
47 i != endIter;) {
48 shiftVec.push_back(*i++);
49 shapeVec.push_back(*i++);
50 }
51}
52
53static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
54 fir::ShiftOp shift) {
55 vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
56}
57
58namespace {
59
60/// Convert fir.embox to the extended form where necessary.
61///
62/// The embox operation can take arguments that specify multidimensional array
63/// properties at runtime. These properties may be shared between distinct
64/// objects that have the same properties. Before we lower these small DAGs to
65/// LLVM-IR, we gather all the information into a single extended operation. For
66/// example,
67/// ```
68/// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
69/// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
70/// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
71/// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
72/// ```
73/// can be rewritten as
74/// ```
75/// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
76/// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
77/// !fir.box<!fir.array<?xi32>>
78/// ```
79class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> {
80public:
81 using OpRewritePattern::OpRewritePattern;
82
83 llvm::LogicalResult
84 matchAndRewrite(fir::EmboxOp embox,
85 mlir::PatternRewriter &rewriter) const override {
86 // If the embox does not include a shape, then do not convert it
87 if (auto shapeVal = embox.getShape())
88 return rewriteDynamicShape(embox, rewriter, shapeVal);
89 if (mlir::isa<fir::ClassType>(embox.getType()))
90 TODO(embox.getLoc(), "embox conversion for fir.class type");
91 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(embox.getType()))
92 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
93 if (!seqTy.hasDynamicExtents())
94 return rewriteStaticShape(embox, rewriter, seqTy);
95 return mlir::failure();
96 }
97
98 llvm::LogicalResult rewriteStaticShape(fir::EmboxOp embox,
99 mlir::PatternRewriter &rewriter,
100 fir::SequenceType seqTy) const {
101 auto loc = embox.getLoc();
102 llvm::SmallVector<mlir::Value> shapeOpers;
103 auto idxTy = rewriter.getIndexType();
104 for (auto ext : seqTy.getShape()) {
105 auto iAttr = rewriter.getIndexAttr(ext);
106 auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
107 shapeOpers.push_back(extVal);
108 }
109 auto xbox = rewriter.create<fir::cg::XEmboxOp>(
110 loc, embox.getType(), embox.getMemref(), shapeOpers, mlir::ValueRange{},
111 mlir::ValueRange{}, mlir::ValueRange{}, mlir::ValueRange{},
112 embox.getTypeparams(), embox.getSourceBox(),
113 embox.getAllocatorIdxAttr());
114 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
115 rewriter.replaceOp(embox, xbox.getOperation()->getResults());
116 return mlir::success();
117 }
118
119 llvm::LogicalResult rewriteDynamicShape(fir::EmboxOp embox,
120 mlir::PatternRewriter &rewriter,
121 mlir::Value shapeVal) const {
122 auto loc = embox.getLoc();
123 llvm::SmallVector<mlir::Value> shapeOpers;
124 llvm::SmallVector<mlir::Value> shiftOpers;
125 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) {
126 populateShape(shapeOpers, shapeOp);
127 } else {
128 auto shiftOp =
129 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp());
130 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
131 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
132 }
133 llvm::SmallVector<mlir::Value> sliceOpers;
134 llvm::SmallVector<mlir::Value> subcompOpers;
135 llvm::SmallVector<mlir::Value> substrOpers;
136 if (auto s = embox.getSlice())
137 if (auto sliceOp =
138 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
139 sliceOpers.assign(sliceOp.getTriples().begin(),
140 sliceOp.getTriples().end());
141 subcompOpers.assign(sliceOp.getFields().begin(),
142 sliceOp.getFields().end());
143 substrOpers.assign(sliceOp.getSubstr().begin(),
144 sliceOp.getSubstr().end());
145 }
146 auto xbox = rewriter.create<fir::cg::XEmboxOp>(
147 loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers,
148 sliceOpers, subcompOpers, substrOpers, embox.getTypeparams(),
149 embox.getSourceBox(), embox.getAllocatorIdxAttr());
150 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
151 rewriter.replaceOp(embox, xbox.getOperation()->getResults());
152 return mlir::success();
153 }
154};
155
156/// Convert fir.rebox to the extended form where necessary.
157///
158/// For example,
159/// ```
160/// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
161/// !fir.box<!fir.array<?xi32>>
162/// ```
163/// converted to
164/// ```
165/// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
166/// index, index) -> !fir.box<!fir.array<?xi32>>
167/// ```
168class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> {
169public:
170 using OpRewritePattern::OpRewritePattern;
171
172 llvm::LogicalResult
173 matchAndRewrite(fir::ReboxOp rebox,
174 mlir::PatternRewriter &rewriter) const override {
175 auto loc = rebox.getLoc();
176 llvm::SmallVector<mlir::Value> shapeOpers;
177 llvm::SmallVector<mlir::Value> shiftOpers;
178 if (auto shapeVal = rebox.getShape()) {
179 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
180 populateShape(shapeOpers, shapeOp);
181 else if (auto shiftOp =
182 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
183 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
184 else if (auto shiftOp =
185 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
186 populateShift(shiftOpers, shiftOp);
187 else
188 return mlir::failure();
189 }
190 llvm::SmallVector<mlir::Value> sliceOpers;
191 llvm::SmallVector<mlir::Value> subcompOpers;
192 llvm::SmallVector<mlir::Value> substrOpers;
193 if (auto s = rebox.getSlice())
194 if (auto sliceOp =
195 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
196 sliceOpers.append(sliceOp.getTriples().begin(),
197 sliceOp.getTriples().end());
198 subcompOpers.append(sliceOp.getFields().begin(),
199 sliceOp.getFields().end());
200 substrOpers.append(sliceOp.getSubstr().begin(),
201 sliceOp.getSubstr().end());
202 }
203
204 auto xRebox = rewriter.create<fir::cg::XReboxOp>(
205 loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers,
206 sliceOpers, subcompOpers, substrOpers);
207 LLVM_DEBUG(llvm::dbgs()
208 << "rewriting " << rebox << " to " << xRebox << '\n');
209 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
210 return mlir::success();
211 }
212};
213
214/// Convert all fir.array_coor to the extended form.
215///
216/// For example,
217/// ```
218/// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
219/// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
220/// ```
221/// converted to
222/// ```
223/// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
224/// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
225/// !fir.ref<i32>
226/// ```
227class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
228public:
229 using OpRewritePattern::OpRewritePattern;
230
231 llvm::LogicalResult
232 matchAndRewrite(fir::ArrayCoorOp arrCoor,
233 mlir::PatternRewriter &rewriter) const override {
234 auto loc = arrCoor.getLoc();
235 llvm::SmallVector<mlir::Value> shapeOpers;
236 llvm::SmallVector<mlir::Value> shiftOpers;
237 if (auto shapeVal = arrCoor.getShape()) {
238 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
239 populateShape(shapeOpers, shapeOp);
240 else if (auto shiftOp =
241 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
242 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
243 else if (auto shiftOp =
244 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
245 populateShift(shiftOpers, shiftOp);
246 else
247 return mlir::failure();
248 }
249 llvm::SmallVector<mlir::Value> sliceOpers;
250 llvm::SmallVector<mlir::Value> subcompOpers;
251 if (auto s = arrCoor.getSlice())
252 if (auto sliceOp =
253 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
254 sliceOpers.append(sliceOp.getTriples().begin(),
255 sliceOp.getTriples().end());
256 subcompOpers.append(sliceOp.getFields().begin(),
257 sliceOp.getFields().end());
258 assert(sliceOp.getSubstr().empty() &&
259 "Don't allow substring operations on array_coor. This "
260 "restriction may be lifted in the future.");
261 }
262 auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>(
263 loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers,
264 sliceOpers, subcompOpers, arrCoor.getIndices(),
265 arrCoor.getTypeparams());
266 LLVM_DEBUG(llvm::dbgs()
267 << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
268 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
269 return mlir::success();
270 }
271};
272
273class DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
274 bool preserveDeclare;
275
276public:
277 using OpRewritePattern::OpRewritePattern;
278 DeclareOpConversion(mlir::MLIRContext *ctx, bool preserveDecl)
279 : OpRewritePattern(ctx), preserveDeclare(preserveDecl) {}
280
281 llvm::LogicalResult
282 matchAndRewrite(fir::DeclareOp declareOp,
283 mlir::PatternRewriter &rewriter) const override {
284 if (!preserveDeclare) {
285 rewriter.replaceOp(declareOp, declareOp.getMemref());
286 return mlir::success();
287 }
288 auto loc = declareOp.getLoc();
289 llvm::SmallVector<mlir::Value> shapeOpers;
290 llvm::SmallVector<mlir::Value> shiftOpers;
291 if (auto shapeVal = declareOp.getShape()) {
292 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
293 populateShape(shapeOpers, shapeOp);
294 else if (auto shiftOp =
295 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
296 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
297 else if (auto shiftOp =
298 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
299 populateShift(shiftOpers, shiftOp);
300 else
301 return mlir::failure();
302 }
303 // FIXME: Add FortranAttrs and CudaAttrs
304 auto xDeclOp = rewriter.create<fir::cg::XDeclareOp>(
305 loc, declareOp.getType(), declareOp.getMemref(), shapeOpers, shiftOpers,
306 declareOp.getTypeparams(), declareOp.getDummyScope(),
307 declareOp.getUniqName());
308 LLVM_DEBUG(llvm::dbgs()
309 << "rewriting " << declareOp << " to " << xDeclOp << '\n');
310 rewriter.replaceOp(declareOp, xDeclOp.getOperation()->getResults());
311 return mlir::success();
312 }
313};
314
315class DummyScopeOpConversion
316 : public mlir::OpRewritePattern<fir::DummyScopeOp> {
317public:
318 using OpRewritePattern::OpRewritePattern;
319
320 llvm::LogicalResult
321 matchAndRewrite(fir::DummyScopeOp dummyScopeOp,
322 mlir::PatternRewriter &rewriter) const override {
323 rewriter.replaceOpWithNewOp<fir::UndefOp>(dummyScopeOp,
324 dummyScopeOp.getType());
325 return mlir::success();
326 }
327};
328
329/// Simple DCE to erase fir.shape/shift/slice/unused shape operands after this
330/// pass (fir.shape and like have no codegen).
331/// mlir::RegionDCE is expensive and requires running
332/// mlir::eraseUnreachableBlocks. It does things that are not needed here, like
333/// removing unused block arguments. fir.shape/shift/slice cannot be block
334/// arguments.
335/// This helper does a naive backward walk of the IR. It is not even guaranteed
336/// to walk blocks according to backward dominance, but that is good enough for
337/// what is done here, fir.shape/shift/slice have no usages anymore. The
338/// backward walk allows getting rid of most of the unused operands, it is not a
339/// problem to leave some in the weird cases.
340static void simpleDCE(mlir::RewriterBase &rewriter, mlir::Operation *op) {
341 op->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
342 [&](mlir::Operation *subOp) {
343 if (mlir::isOpTriviallyDead(subOp))
344 rewriter.eraseOp(subOp);
345 });
346}
347
348class CodeGenRewrite : public fir::impl::CodeGenRewriteBase<CodeGenRewrite> {
349public:
350 using CodeGenRewriteBase<CodeGenRewrite>::CodeGenRewriteBase;
351
352 void runOnOperation() override final {
353 mlir::ModuleOp mod = getOperation();
354
355 auto &context = getContext();
356 mlir::ConversionTarget target(context);
357 target.addLegalDialect<mlir::arith::ArithDialect, fir::FIROpsDialect,
358 fir::FIRCodeGenDialect, mlir::func::FuncDialect>();
359 target.addIllegalOp<fir::ArrayCoorOp>();
360 target.addIllegalOp<fir::ReboxOp>();
361 target.addIllegalOp<fir::DeclareOp>();
362 target.addIllegalOp<fir::DummyScopeOp>();
363 target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) {
364 return !(embox.getShape() ||
365 mlir::isa<fir::SequenceType>(
366 mlir::cast<fir::BaseBoxType>(embox.getType()).getEleTy()));
367 });
368 mlir::RewritePatternSet patterns(&context);
369 fir::populatePreCGRewritePatterns(patterns, preserveDeclare);
370 if (mlir::failed(
371 mlir::applyPartialConversion(mod, target, std::move(patterns)))) {
372 mlir::emitError(mlir::UnknownLoc::get(&context),
373 "error in running the pre-codegen conversions");
374 signalPassFailure();
375 return;
376 }
377 // Erase any residual (fir.shape, fir.slice...).
378 mlir::IRRewriter rewriter(&context);
379 simpleDCE(rewriter, mod.getOperation());
380 }
381};
382
383} // namespace
384
385void fir::populatePreCGRewritePatterns(mlir::RewritePatternSet &patterns,
386 bool preserveDeclare) {
387 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion,
388 DummyScopeOpConversion>(patterns.getContext());
389 patterns.add<DeclareOpConversion>(patterns.getContext(), preserveDeclare);
390}
391

source code of flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp