1//===-- PreCGRewrite.cpp --------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/CodeGen/CodeGen.h"
14
15#include "CGOps.h"
16#include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
17#include "flang/Optimizer/Dialect/FIRDialect.h"
18#include "flang/Optimizer/Dialect/FIROps.h"
19#include "flang/Optimizer/Dialect/FIRType.h"
20#include "flang/Optimizer/Dialect/Support/FIRContext.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "mlir/Transforms/RegionUtils.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25
26namespace fir {
27#define GEN_PASS_DEF_CODEGENREWRITE
28#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
29} // namespace fir
30
31//===----------------------------------------------------------------------===//
32// Codegen rewrite: rewriting of subgraphs of ops
33//===----------------------------------------------------------------------===//
34
35#define DEBUG_TYPE "flang-codegen-rewrite"
36
37static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
38 fir::ShapeOp shape) {
39 vec.append(shape.getExtents().begin(), shape.getExtents().end());
40}
41
42// Operands of fir.shape_shift split into two vectors.
43static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
44 llvm::SmallVectorImpl<mlir::Value> &shiftVec,
45 fir::ShapeShiftOp shift) {
46 for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end();
47 i != endIter;) {
48 shiftVec.push_back(*i++);
49 shapeVec.push_back(*i++);
50 }
51}
52
53static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
54 fir::ShiftOp shift) {
55 vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
56}
57
58namespace {
59
60/// Convert fir.embox to the extended form where necessary.
61///
62/// The embox operation can take arguments that specify multidimensional array
63/// properties at runtime. These properties may be shared between distinct
64/// objects that have the same properties. Before we lower these small DAGs to
65/// LLVM-IR, we gather all the information into a single extended operation. For
66/// example,
67/// ```
68/// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
69/// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
70/// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
71/// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
72/// ```
73/// can be rewritten as
74/// ```
75/// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
76/// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
77/// !fir.box<!fir.array<?xi32>>
78/// ```
79class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> {
80public:
81 using OpRewritePattern::OpRewritePattern;
82
83 mlir::LogicalResult
84 matchAndRewrite(fir::EmboxOp embox,
85 mlir::PatternRewriter &rewriter) const override {
86 // If the embox does not include a shape, then do not convert it
87 if (auto shapeVal = embox.getShape())
88 return rewriteDynamicShape(embox, rewriter, shapeVal);
89 if (embox.getType().isa<fir::ClassType>())
90 TODO(embox.getLoc(), "embox conversion for fir.class type");
91 if (auto boxTy = embox.getType().dyn_cast<fir::BoxType>())
92 if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
93 if (!seqTy.hasDynamicExtents())
94 return rewriteStaticShape(embox, rewriter, seqTy);
95 return mlir::failure();
96 }
97
98 mlir::LogicalResult rewriteStaticShape(fir::EmboxOp embox,
99 mlir::PatternRewriter &rewriter,
100 fir::SequenceType seqTy) const {
101 auto loc = embox.getLoc();
102 llvm::SmallVector<mlir::Value> shapeOpers;
103 auto idxTy = rewriter.getIndexType();
104 for (auto ext : seqTy.getShape()) {
105 auto iAttr = rewriter.getIndexAttr(ext);
106 auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
107 shapeOpers.push_back(extVal);
108 }
109 auto xbox = rewriter.create<fir::cg::XEmboxOp>(
110 loc, embox.getType(), embox.getMemref(), shapeOpers, std::nullopt,
111 std::nullopt, std::nullopt, std::nullopt, embox.getTypeparams(),
112 embox.getSourceBox());
113 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
114 rewriter.replaceOp(embox, xbox.getOperation()->getResults());
115 return mlir::success();
116 }
117
118 mlir::LogicalResult rewriteDynamicShape(fir::EmboxOp embox,
119 mlir::PatternRewriter &rewriter,
120 mlir::Value shapeVal) const {
121 auto loc = embox.getLoc();
122 llvm::SmallVector<mlir::Value> shapeOpers;
123 llvm::SmallVector<mlir::Value> shiftOpers;
124 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) {
125 populateShape(shapeOpers, shapeOp);
126 } else {
127 auto shiftOp =
128 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp());
129 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
130 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
131 }
132 llvm::SmallVector<mlir::Value> sliceOpers;
133 llvm::SmallVector<mlir::Value> subcompOpers;
134 llvm::SmallVector<mlir::Value> substrOpers;
135 if (auto s = embox.getSlice())
136 if (auto sliceOp =
137 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
138 sliceOpers.assign(sliceOp.getTriples().begin(),
139 sliceOp.getTriples().end());
140 subcompOpers.assign(sliceOp.getFields().begin(),
141 sliceOp.getFields().end());
142 substrOpers.assign(sliceOp.getSubstr().begin(),
143 sliceOp.getSubstr().end());
144 }
145 auto xbox = rewriter.create<fir::cg::XEmboxOp>(
146 loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers,
147 sliceOpers, subcompOpers, substrOpers, embox.getTypeparams(),
148 embox.getSourceBox());
149 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
150 rewriter.replaceOp(embox, xbox.getOperation()->getResults());
151 return mlir::success();
152 }
153};
154
155/// Convert fir.rebox to the extended form where necessary.
156///
157/// For example,
158/// ```
159/// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
160/// !fir.box<!fir.array<?xi32>>
161/// ```
162/// converted to
163/// ```
164/// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
165/// index, index) -> !fir.box<!fir.array<?xi32>>
166/// ```
167class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> {
168public:
169 using OpRewritePattern::OpRewritePattern;
170
171 mlir::LogicalResult
172 matchAndRewrite(fir::ReboxOp rebox,
173 mlir::PatternRewriter &rewriter) const override {
174 auto loc = rebox.getLoc();
175 llvm::SmallVector<mlir::Value> shapeOpers;
176 llvm::SmallVector<mlir::Value> shiftOpers;
177 if (auto shapeVal = rebox.getShape()) {
178 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
179 populateShape(shapeOpers, shapeOp);
180 else if (auto shiftOp =
181 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
182 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
183 else if (auto shiftOp =
184 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
185 populateShift(shiftOpers, shiftOp);
186 else
187 return mlir::failure();
188 }
189 llvm::SmallVector<mlir::Value> sliceOpers;
190 llvm::SmallVector<mlir::Value> subcompOpers;
191 llvm::SmallVector<mlir::Value> substrOpers;
192 if (auto s = rebox.getSlice())
193 if (auto sliceOp =
194 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
195 sliceOpers.append(sliceOp.getTriples().begin(),
196 sliceOp.getTriples().end());
197 subcompOpers.append(sliceOp.getFields().begin(),
198 sliceOp.getFields().end());
199 substrOpers.append(sliceOp.getSubstr().begin(),
200 sliceOp.getSubstr().end());
201 }
202
203 auto xRebox = rewriter.create<fir::cg::XReboxOp>(
204 loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers,
205 sliceOpers, subcompOpers, substrOpers);
206 LLVM_DEBUG(llvm::dbgs()
207 << "rewriting " << rebox << " to " << xRebox << '\n');
208 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
209 return mlir::success();
210 }
211};
212
213/// Convert all fir.array_coor to the extended form.
214///
215/// For example,
216/// ```
217/// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
218/// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
219/// ```
220/// converted to
221/// ```
222/// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
223/// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
224/// !fir.ref<i32>
225/// ```
226class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
227public:
228 using OpRewritePattern::OpRewritePattern;
229
230 mlir::LogicalResult
231 matchAndRewrite(fir::ArrayCoorOp arrCoor,
232 mlir::PatternRewriter &rewriter) const override {
233 auto loc = arrCoor.getLoc();
234 llvm::SmallVector<mlir::Value> shapeOpers;
235 llvm::SmallVector<mlir::Value> shiftOpers;
236 if (auto shapeVal = arrCoor.getShape()) {
237 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
238 populateShape(shapeOpers, shapeOp);
239 else if (auto shiftOp =
240 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
241 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
242 else if (auto shiftOp =
243 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
244 populateShift(shiftOpers, shiftOp);
245 else
246 return mlir::failure();
247 }
248 llvm::SmallVector<mlir::Value> sliceOpers;
249 llvm::SmallVector<mlir::Value> subcompOpers;
250 if (auto s = arrCoor.getSlice())
251 if (auto sliceOp =
252 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
253 sliceOpers.append(sliceOp.getTriples().begin(),
254 sliceOp.getTriples().end());
255 subcompOpers.append(sliceOp.getFields().begin(),
256 sliceOp.getFields().end());
257 assert(sliceOp.getSubstr().empty() &&
258 "Don't allow substring operations on array_coor. This "
259 "restriction may be lifted in the future.");
260 }
261 auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>(
262 loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers,
263 sliceOpers, subcompOpers, arrCoor.getIndices(),
264 arrCoor.getTypeparams());
265 LLVM_DEBUG(llvm::dbgs()
266 << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
267 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
268 return mlir::success();
269 }
270};
271
272class DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
273public:
274 using OpRewritePattern::OpRewritePattern;
275
276 mlir::LogicalResult
277 matchAndRewrite(fir::DeclareOp declareOp,
278 mlir::PatternRewriter &rewriter) const override {
279 rewriter.replaceOp(declareOp, declareOp.getMemref());
280 return mlir::success();
281 }
282};
283
284class CodeGenRewrite : public fir::impl::CodeGenRewriteBase<CodeGenRewrite> {
285public:
286 void runOnOperation() override final {
287 mlir::ModuleOp mod = getOperation();
288
289 auto &context = getContext();
290 mlir::ConversionTarget target(context);
291 target.addLegalDialect<mlir::arith::ArithDialect, fir::FIROpsDialect,
292 fir::FIRCodeGenDialect, mlir::func::FuncDialect>();
293 target.addIllegalOp<fir::ArrayCoorOp>();
294 target.addIllegalOp<fir::ReboxOp>();
295 target.addIllegalOp<fir::DeclareOp>();
296 target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) {
297 return !(embox.getShape() || embox.getType()
298 .cast<fir::BaseBoxType>()
299 .getEleTy()
300 .isa<fir::SequenceType>());
301 });
302 mlir::RewritePatternSet patterns(&context);
303 fir::populatePreCGRewritePatterns(patterns);
304 if (mlir::failed(
305 mlir::applyPartialConversion(mod, target, std::move(patterns)))) {
306 mlir::emitError(mlir::UnknownLoc::get(&context),
307 "error in running the pre-codegen conversions");
308 signalPassFailure();
309 return;
310 }
311 // Erase any residual (fir.shape, fir.slice...).
312 mlir::IRRewriter rewriter(&context);
313 (void)mlir::runRegionDCE(rewriter, mod->getRegions());
314 }
315};
316
317} // namespace
318
319std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
320 return std::make_unique<CodeGenRewrite>();
321}
322
323void fir::populatePreCGRewritePatterns(mlir::RewritePatternSet &patterns) {
324 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion,
325 DeclareOpConversion>(patterns.getContext());
326}
327

source code of flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp