1 | //===- SimplifyFIROperations.cpp -- simplify complex FIR operations ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | //===----------------------------------------------------------------------===// |
10 | /// \file |
11 | /// This pass transforms some FIR operations into their equivalent |
12 | /// implementations using other FIR operations. The transformation |
13 | /// can legally use SCF dialect and generate Fortran runtime calls. |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
17 | #include "flang/Optimizer/Builder/Runtime/Inquiry.h" |
18 | #include "flang/Optimizer/Builder/Todo.h" |
19 | #include "flang/Optimizer/Dialect/FIROps.h" |
20 | #include "flang/Optimizer/Transforms/Passes.h" |
21 | #include "mlir/IR/IRMapping.h" |
22 | #include "mlir/Pass/Pass.h" |
23 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
24 | #include <optional> |
25 | |
26 | namespace fir { |
27 | #define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS |
28 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
29 | } // namespace fir |
30 | |
31 | #define DEBUG_TYPE "flang-simplify-fir-operations" |
32 | |
33 | namespace { |
34 | /// Pass runner. |
35 | class SimplifyFIROperationsPass |
36 | : public fir::impl::SimplifyFIROperationsBase<SimplifyFIROperationsPass> { |
37 | public: |
38 | using fir::impl::SimplifyFIROperationsBase< |
39 | SimplifyFIROperationsPass>::SimplifyFIROperationsBase; |
40 | |
41 | void runOnOperation() override final; |
42 | }; |
43 | |
44 | /// Base class for all conversions holding the pass options. |
45 | template <typename Op> |
46 | class ConversionBase : public mlir::OpRewritePattern<Op> { |
47 | public: |
48 | using mlir::OpRewritePattern<Op>::OpRewritePattern; |
49 | |
50 | template <typename... Args> |
51 | ConversionBase(mlir::MLIRContext *context, Args &&...args) |
52 | : mlir::OpRewritePattern<Op>(context), |
53 | options{std::forward<Args>(args)...} {} |
54 | |
55 | mlir::LogicalResult matchAndRewrite(Op, |
56 | mlir::PatternRewriter &) const override; |
57 | |
58 | protected: |
59 | fir::SimplifyFIROperationsOptions options; |
60 | }; |
61 | |
62 | /// fir::IsContiguousBoxOp converter. |
63 | using IsContiguousBoxCoversion = ConversionBase<fir::IsContiguousBoxOp>; |
64 | |
65 | /// fir::BoxTotalElementsOp converter. |
66 | using BoxTotalElementsConversion = ConversionBase<fir::BoxTotalElementsOp>; |
67 | } // namespace |
68 | |
69 | /// Generate a call to IsContiguous/IsContiguousUpTo function or an inline |
70 | /// sequence reading extents/strides from the box and checking them. |
71 | /// This conversion may produce fir.box_elesize and a loop (for assumed |
72 | /// rank). |
73 | template <> |
74 | mlir::LogicalResult IsContiguousBoxCoversion::matchAndRewrite( |
75 | fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const { |
76 | mlir::Location loc = op.getLoc(); |
77 | fir::FirOpBuilder builder(rewriter, op.getOperation()); |
78 | mlir::Value box = op.getBox(); |
79 | |
80 | if (options.preferInlineImplementation) { |
81 | auto boxType = mlir::cast<fir::BaseBoxType>(box.getType()); |
82 | unsigned rank = fir::getBoxRank(boxType); |
83 | |
84 | // If rank is one, or 'innermost' attribute is set and |
85 | // it is not a scalar, then generate a simple comparison |
86 | // for the leading dimension: (stride == elem_size || extent == 0). |
87 | // |
88 | // The scalar cases are supposed to be optimized by the canonicalization. |
89 | if (rank == 1 || (op.getInnermost() && rank > 0)) { |
90 | mlir::Type idxTy = builder.getIndexType(); |
91 | auto eleSize = builder.create<fir::BoxEleSizeOp>(loc, idxTy, box); |
92 | mlir::Value zero = fir::factory::createZeroValue(builder, loc, idxTy); |
93 | auto dimInfo = |
94 | builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, zero); |
95 | mlir::Value stride = dimInfo.getByteStride(); |
96 | mlir::Value pred1 = builder.create<mlir::arith::CmpIOp>( |
97 | loc, mlir::arith::CmpIPredicate::eq, eleSize, stride); |
98 | mlir::Value extent = dimInfo.getExtent(); |
99 | mlir::Value pred2 = builder.create<mlir::arith::CmpIOp>( |
100 | loc, mlir::arith::CmpIPredicate::eq, extent, zero); |
101 | mlir::Value result = |
102 | builder.create<mlir::arith::OrIOp>(loc, pred1, pred2); |
103 | result = builder.createConvert(loc, op.getType(), result); |
104 | rewriter.replaceOp(op, result); |
105 | return mlir::success(); |
106 | } |
107 | // TODO: support arrays with multiple dimensions. |
108 | } |
109 | |
110 | // Generate Fortran runtime call. |
111 | mlir::Value result; |
112 | if (op.getInnermost()) { |
113 | mlir::Value one = |
114 | builder.createIntegerConstant(loc, builder.getI32Type(), 1); |
115 | result = fir::runtime::genIsContiguousUpTo(builder, loc, box, one); |
116 | } else { |
117 | result = fir::runtime::genIsContiguous(builder, loc, box); |
118 | } |
119 | result = builder.createConvert(loc, op.getType(), result); |
120 | rewriter.replaceOp(op, result); |
121 | return mlir::success(); |
122 | } |
123 | |
124 | /// Generate a call to Size runtime function or an inline |
125 | /// sequence reading extents from the box an multiplying them. |
126 | /// This conversion may produce a loop (for assumed rank). |
127 | template <> |
128 | mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite( |
129 | fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const { |
130 | mlir::Location loc = op.getLoc(); |
131 | fir::FirOpBuilder builder(rewriter, op.getOperation()); |
132 | // TODO: support preferInlineImplementation. |
133 | // Reading the extent from the box for 1D arrays probably |
134 | // results in less code than the call, so we can always |
135 | // inline it. |
136 | bool doInline = options.preferInlineImplementation && false; |
137 | if (!doInline) { |
138 | // Generate Fortran runtime call. |
139 | mlir::Value result = fir::runtime::genSize(builder, loc, op.getBox()); |
140 | result = builder.createConvert(loc, op.getType(), result); |
141 | rewriter.replaceOp(op, result); |
142 | return mlir::success(); |
143 | } |
144 | |
145 | // Generate inline implementation. |
146 | TODO(loc, "inline BoxTotalElementsOp" ); |
147 | return mlir::failure(); |
148 | } |
149 | |
150 | class DoConcurrentConversion |
151 | : public mlir::OpRewritePattern<fir::DoConcurrentOp> { |
152 | /// Looks up from the operation from and returns the LocalitySpecifierOp with |
153 | /// name symbolName |
154 | static fir::LocalitySpecifierOp |
155 | findLocalizer(mlir::Operation *from, mlir::SymbolRefAttr symbolName) { |
156 | fir::LocalitySpecifierOp localizer = |
157 | mlir::SymbolTable::lookupNearestSymbolFrom<fir::LocalitySpecifierOp>( |
158 | from, symbolName); |
159 | assert(localizer && "localizer not found in the symbol table" ); |
160 | return localizer; |
161 | } |
162 | |
163 | public: |
164 | using mlir::OpRewritePattern<fir::DoConcurrentOp>::OpRewritePattern; |
165 | |
166 | mlir::LogicalResult |
167 | matchAndRewrite(fir::DoConcurrentOp doConcurentOp, |
168 | mlir::PatternRewriter &rewriter) const override { |
169 | assert(doConcurentOp.getRegion().hasOneBlock()); |
170 | mlir::Block &wrapperBlock = doConcurentOp.getRegion().getBlocks().front(); |
171 | auto loop = |
172 | mlir::cast<fir::DoConcurrentLoopOp>(wrapperBlock.getTerminator()); |
173 | assert(loop.getRegion().hasOneBlock()); |
174 | mlir::Block &loopBlock = loop.getRegion().getBlocks().front(); |
175 | |
176 | // Handle localization |
177 | if (!loop.getLocalVars().empty()) { |
178 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
179 | rewriter.setInsertionPointToStart(&loop.getRegion().front()); |
180 | |
181 | std::optional<mlir::ArrayAttr> localSyms = loop.getLocalSyms(); |
182 | |
183 | for (auto [localVar, localArg, localizerSym] : llvm::zip_equal( |
184 | loop.getLocalVars(), loop.getRegionLocalArgs(), *localSyms)) { |
185 | mlir::SymbolRefAttr localizerName = |
186 | llvm::cast<mlir::SymbolRefAttr>(localizerSym); |
187 | fir::LocalitySpecifierOp localizer = findLocalizer(loop, localizerName); |
188 | |
189 | if (!localizer.getInitRegion().empty() || |
190 | !localizer.getDeallocRegion().empty()) |
191 | TODO(localizer.getLoc(), "localizers with `init` and `dealloc` " |
192 | "regions are not handled yet." ); |
193 | |
194 | // TODO Should this be a heap allocation instead? For now, we allocate |
195 | // on the stack for each loop iteration. |
196 | mlir::Value localAlloc = |
197 | rewriter.create<fir::AllocaOp>(loop.getLoc(), localizer.getType()); |
198 | |
199 | if (localizer.getLocalitySpecifierType() == |
200 | fir::LocalitySpecifierType::LocalInit) { |
201 | // It is reasonable to make this assumption since, at this stage, |
202 | // control-flow ops are not converted yet. Therefore, things like `if` |
203 | // conditions will still be represented by their encapsulating `fir` |
204 | // dialect ops. |
205 | assert(localizer.getCopyRegion().hasOneBlock() && |
206 | "Expected localizer to have a single block." ); |
207 | mlir::Block *beforeLocalInit = rewriter.getInsertionBlock(); |
208 | mlir::Block *afterLocalInit = rewriter.splitBlock( |
209 | rewriter.getInsertionBlock(), rewriter.getInsertionPoint()); |
210 | rewriter.cloneRegionBefore(localizer.getCopyRegion(), afterLocalInit); |
211 | mlir::Block *copyRegionBody = beforeLocalInit->getNextNode(); |
212 | |
213 | rewriter.eraseOp(copyRegionBody->getTerminator()); |
214 | rewriter.mergeBlocks(afterLocalInit, copyRegionBody); |
215 | rewriter.mergeBlocks(copyRegionBody, beforeLocalInit, |
216 | {localVar, localArg}); |
217 | } |
218 | |
219 | rewriter.replaceAllUsesWith(localArg, localAlloc); |
220 | } |
221 | |
222 | loop.getRegion().front().eraseArguments(loop.getNumInductionVars(), |
223 | loop.getNumLocalOperands()); |
224 | loop.getLocalVarsMutable().clear(); |
225 | loop.setLocalSymsAttr(nullptr); |
226 | } |
227 | |
228 | // Collect iteration variable(s) allocations so that we can move them |
229 | // outside the `fir.do_concurrent` wrapper. |
230 | llvm::SmallVector<mlir::Operation *> opsToMove; |
231 | for (mlir::Operation &op : llvm::drop_end(wrapperBlock)) |
232 | opsToMove.push_back(&op); |
233 | |
234 | fir::FirOpBuilder firBuilder( |
235 | rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>()); |
236 | auto *allocIt = firBuilder.getAllocaBlock(); |
237 | |
238 | for (mlir::Operation *op : llvm::reverse(opsToMove)) |
239 | rewriter.moveOpBefore(op, allocIt, allocIt->begin()); |
240 | |
241 | rewriter.setInsertionPointAfter(doConcurentOp); |
242 | fir::DoLoopOp innermostUnorderdLoop; |
243 | mlir::SmallVector<mlir::Value> ivArgs; |
244 | |
245 | for (auto [lb, ub, st, iv] : |
246 | llvm::zip_equal(loop.getLowerBound(), loop.getUpperBound(), |
247 | loop.getStep(), *loop.getLoopInductionVars())) { |
248 | innermostUnorderdLoop = rewriter.create<fir::DoLoopOp>( |
249 | doConcurentOp.getLoc(), lb, ub, st, |
250 | /*unordred=*/true, /*finalCountValue=*/false, |
251 | /*iterArgs=*/std::nullopt, loop.getReduceOperands(), |
252 | loop.getReduceAttrsAttr()); |
253 | ivArgs.push_back(innermostUnorderdLoop.getInductionVar()); |
254 | rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody()); |
255 | } |
256 | |
257 | rewriter.inlineBlockBefore( |
258 | &loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs); |
259 | rewriter.eraseOp(doConcurentOp); |
260 | return mlir::success(); |
261 | } |
262 | }; |
263 | |
264 | void SimplifyFIROperationsPass::runOnOperation() { |
265 | mlir::ModuleOp module = getOperation(); |
266 | mlir::MLIRContext &context = getContext(); |
267 | mlir::RewritePatternSet patterns(&context); |
268 | fir::populateSimplifyFIROperationsPatterns(patterns, |
269 | preferInlineImplementation); |
270 | mlir::GreedyRewriteConfig config; |
271 | config.setRegionSimplificationLevel( |
272 | mlir::GreedySimplifyRegionLevel::Disabled); |
273 | |
274 | if (mlir::failed( |
275 | mlir::applyPatternsGreedily(module, std::move(patterns), config))) { |
276 | mlir::emitError(module.getLoc(), DEBUG_TYPE " pass failed" ); |
277 | signalPassFailure(); |
278 | } |
279 | } |
280 | |
281 | void fir::populateSimplifyFIROperationsPatterns( |
282 | mlir::RewritePatternSet &patterns, bool preferInlineImplementation) { |
283 | patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>( |
284 | patterns.getContext(), preferInlineImplementation); |
285 | patterns.insert<DoConcurrentConversion>(patterns.getContext()); |
286 | } |
287 | |