1//===- SimplifyFIROperations.cpp -- simplify complex FIR operations ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9//===----------------------------------------------------------------------===//
10/// \file
11/// This pass transforms some FIR operations into their equivalent
12/// implementations using other FIR operations. The transformation
13/// can legally use SCF dialect and generate Fortran runtime calls.
14//===----------------------------------------------------------------------===//
15
16#include "flang/Optimizer/Builder/FIRBuilder.h"
17#include "flang/Optimizer/Builder/Runtime/Inquiry.h"
18#include "flang/Optimizer/Builder/Todo.h"
19#include "flang/Optimizer/Dialect/FIROps.h"
20#include "flang/Optimizer/Transforms/Passes.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24#include <optional>
25
26namespace fir {
27#define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS
28#include "flang/Optimizer/Transforms/Passes.h.inc"
29} // namespace fir
30
31#define DEBUG_TYPE "flang-simplify-fir-operations"
32
33namespace {
34/// Pass runner.
35class SimplifyFIROperationsPass
36 : public fir::impl::SimplifyFIROperationsBase<SimplifyFIROperationsPass> {
37public:
38 using fir::impl::SimplifyFIROperationsBase<
39 SimplifyFIROperationsPass>::SimplifyFIROperationsBase;
40
41 void runOnOperation() override final;
42};
43
44/// Base class for all conversions holding the pass options.
45template <typename Op>
46class ConversionBase : public mlir::OpRewritePattern<Op> {
47public:
48 using mlir::OpRewritePattern<Op>::OpRewritePattern;
49
50 template <typename... Args>
51 ConversionBase(mlir::MLIRContext *context, Args &&...args)
52 : mlir::OpRewritePattern<Op>(context),
53 options{std::forward<Args>(args)...} {}
54
55 mlir::LogicalResult matchAndRewrite(Op,
56 mlir::PatternRewriter &) const override;
57
58protected:
59 fir::SimplifyFIROperationsOptions options;
60};
61
62/// fir::IsContiguousBoxOp converter.
63using IsContiguousBoxCoversion = ConversionBase<fir::IsContiguousBoxOp>;
64
65/// fir::BoxTotalElementsOp converter.
66using BoxTotalElementsConversion = ConversionBase<fir::BoxTotalElementsOp>;
67} // namespace
68
69/// Generate a call to IsContiguous/IsContiguousUpTo function or an inline
70/// sequence reading extents/strides from the box and checking them.
71/// This conversion may produce fir.box_elesize and a loop (for assumed
72/// rank).
73template <>
74mlir::LogicalResult IsContiguousBoxCoversion::matchAndRewrite(
75 fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const {
76 mlir::Location loc = op.getLoc();
77 fir::FirOpBuilder builder(rewriter, op.getOperation());
78 mlir::Value box = op.getBox();
79
80 if (options.preferInlineImplementation) {
81 auto boxType = mlir::cast<fir::BaseBoxType>(box.getType());
82 unsigned rank = fir::getBoxRank(boxType);
83
84 // If rank is one, or 'innermost' attribute is set and
85 // it is not a scalar, then generate a simple comparison
86 // for the leading dimension: (stride == elem_size || extent == 0).
87 //
88 // The scalar cases are supposed to be optimized by the canonicalization.
89 if (rank == 1 || (op.getInnermost() && rank > 0)) {
90 mlir::Type idxTy = builder.getIndexType();
91 auto eleSize = builder.create<fir::BoxEleSizeOp>(loc, idxTy, box);
92 mlir::Value zero = fir::factory::createZeroValue(builder, loc, idxTy);
93 auto dimInfo =
94 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, zero);
95 mlir::Value stride = dimInfo.getByteStride();
96 mlir::Value pred1 = builder.create<mlir::arith::CmpIOp>(
97 loc, mlir::arith::CmpIPredicate::eq, eleSize, stride);
98 mlir::Value extent = dimInfo.getExtent();
99 mlir::Value pred2 = builder.create<mlir::arith::CmpIOp>(
100 loc, mlir::arith::CmpIPredicate::eq, extent, zero);
101 mlir::Value result =
102 builder.create<mlir::arith::OrIOp>(loc, pred1, pred2);
103 result = builder.createConvert(loc, op.getType(), result);
104 rewriter.replaceOp(op, result);
105 return mlir::success();
106 }
107 // TODO: support arrays with multiple dimensions.
108 }
109
110 // Generate Fortran runtime call.
111 mlir::Value result;
112 if (op.getInnermost()) {
113 mlir::Value one =
114 builder.createIntegerConstant(loc, builder.getI32Type(), 1);
115 result = fir::runtime::genIsContiguousUpTo(builder, loc, box, one);
116 } else {
117 result = fir::runtime::genIsContiguous(builder, loc, box);
118 }
119 result = builder.createConvert(loc, op.getType(), result);
120 rewriter.replaceOp(op, result);
121 return mlir::success();
122}
123
124/// Generate a call to Size runtime function or an inline
125/// sequence reading extents from the box an multiplying them.
126/// This conversion may produce a loop (for assumed rank).
127template <>
128mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite(
129 fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const {
130 mlir::Location loc = op.getLoc();
131 fir::FirOpBuilder builder(rewriter, op.getOperation());
132 // TODO: support preferInlineImplementation.
133 // Reading the extent from the box for 1D arrays probably
134 // results in less code than the call, so we can always
135 // inline it.
136 bool doInline = options.preferInlineImplementation && false;
137 if (!doInline) {
138 // Generate Fortran runtime call.
139 mlir::Value result = fir::runtime::genSize(builder, loc, op.getBox());
140 result = builder.createConvert(loc, op.getType(), result);
141 rewriter.replaceOp(op, result);
142 return mlir::success();
143 }
144
145 // Generate inline implementation.
146 TODO(loc, "inline BoxTotalElementsOp");
147 return mlir::failure();
148}
149
150class DoConcurrentConversion
151 : public mlir::OpRewritePattern<fir::DoConcurrentOp> {
152 /// Looks up from the operation from and returns the LocalitySpecifierOp with
153 /// name symbolName
154 static fir::LocalitySpecifierOp
155 findLocalizer(mlir::Operation *from, mlir::SymbolRefAttr symbolName) {
156 fir::LocalitySpecifierOp localizer =
157 mlir::SymbolTable::lookupNearestSymbolFrom<fir::LocalitySpecifierOp>(
158 from, symbolName);
159 assert(localizer && "localizer not found in the symbol table");
160 return localizer;
161 }
162
163public:
164 using mlir::OpRewritePattern<fir::DoConcurrentOp>::OpRewritePattern;
165
166 mlir::LogicalResult
167 matchAndRewrite(fir::DoConcurrentOp doConcurentOp,
168 mlir::PatternRewriter &rewriter) const override {
169 assert(doConcurentOp.getRegion().hasOneBlock());
170 mlir::Block &wrapperBlock = doConcurentOp.getRegion().getBlocks().front();
171 auto loop =
172 mlir::cast<fir::DoConcurrentLoopOp>(wrapperBlock.getTerminator());
173 assert(loop.getRegion().hasOneBlock());
174 mlir::Block &loopBlock = loop.getRegion().getBlocks().front();
175
176 // Handle localization
177 if (!loop.getLocalVars().empty()) {
178 mlir::OpBuilder::InsertionGuard guard(rewriter);
179 rewriter.setInsertionPointToStart(&loop.getRegion().front());
180
181 std::optional<mlir::ArrayAttr> localSyms = loop.getLocalSyms();
182
183 for (auto [localVar, localArg, localizerSym] : llvm::zip_equal(
184 loop.getLocalVars(), loop.getRegionLocalArgs(), *localSyms)) {
185 mlir::SymbolRefAttr localizerName =
186 llvm::cast<mlir::SymbolRefAttr>(localizerSym);
187 fir::LocalitySpecifierOp localizer = findLocalizer(loop, localizerName);
188
189 if (!localizer.getInitRegion().empty() ||
190 !localizer.getDeallocRegion().empty())
191 TODO(localizer.getLoc(), "localizers with `init` and `dealloc` "
192 "regions are not handled yet.");
193
194 // TODO Should this be a heap allocation instead? For now, we allocate
195 // on the stack for each loop iteration.
196 mlir::Value localAlloc =
197 rewriter.create<fir::AllocaOp>(loop.getLoc(), localizer.getType());
198
199 if (localizer.getLocalitySpecifierType() ==
200 fir::LocalitySpecifierType::LocalInit) {
201 // It is reasonable to make this assumption since, at this stage,
202 // control-flow ops are not converted yet. Therefore, things like `if`
203 // conditions will still be represented by their encapsulating `fir`
204 // dialect ops.
205 assert(localizer.getCopyRegion().hasOneBlock() &&
206 "Expected localizer to have a single block.");
207 mlir::Block *beforeLocalInit = rewriter.getInsertionBlock();
208 mlir::Block *afterLocalInit = rewriter.splitBlock(
209 rewriter.getInsertionBlock(), rewriter.getInsertionPoint());
210 rewriter.cloneRegionBefore(localizer.getCopyRegion(), afterLocalInit);
211 mlir::Block *copyRegionBody = beforeLocalInit->getNextNode();
212
213 rewriter.eraseOp(copyRegionBody->getTerminator());
214 rewriter.mergeBlocks(afterLocalInit, copyRegionBody);
215 rewriter.mergeBlocks(copyRegionBody, beforeLocalInit,
216 {localVar, localArg});
217 }
218
219 rewriter.replaceAllUsesWith(localArg, localAlloc);
220 }
221
222 loop.getRegion().front().eraseArguments(loop.getNumInductionVars(),
223 loop.getNumLocalOperands());
224 loop.getLocalVarsMutable().clear();
225 loop.setLocalSymsAttr(nullptr);
226 }
227
228 // Collect iteration variable(s) allocations so that we can move them
229 // outside the `fir.do_concurrent` wrapper.
230 llvm::SmallVector<mlir::Operation *> opsToMove;
231 for (mlir::Operation &op : llvm::drop_end(wrapperBlock))
232 opsToMove.push_back(&op);
233
234 fir::FirOpBuilder firBuilder(
235 rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>());
236 auto *allocIt = firBuilder.getAllocaBlock();
237
238 for (mlir::Operation *op : llvm::reverse(opsToMove))
239 rewriter.moveOpBefore(op, allocIt, allocIt->begin());
240
241 rewriter.setInsertionPointAfter(doConcurentOp);
242 fir::DoLoopOp innermostUnorderdLoop;
243 mlir::SmallVector<mlir::Value> ivArgs;
244
245 for (auto [lb, ub, st, iv] :
246 llvm::zip_equal(loop.getLowerBound(), loop.getUpperBound(),
247 loop.getStep(), *loop.getLoopInductionVars())) {
248 innermostUnorderdLoop = rewriter.create<fir::DoLoopOp>(
249 doConcurentOp.getLoc(), lb, ub, st,
250 /*unordred=*/true, /*finalCountValue=*/false,
251 /*iterArgs=*/std::nullopt, loop.getReduceOperands(),
252 loop.getReduceAttrsAttr());
253 ivArgs.push_back(innermostUnorderdLoop.getInductionVar());
254 rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody());
255 }
256
257 rewriter.inlineBlockBefore(
258 &loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs);
259 rewriter.eraseOp(doConcurentOp);
260 return mlir::success();
261 }
262};
263
264void SimplifyFIROperationsPass::runOnOperation() {
265 mlir::ModuleOp module = getOperation();
266 mlir::MLIRContext &context = getContext();
267 mlir::RewritePatternSet patterns(&context);
268 fir::populateSimplifyFIROperationsPatterns(patterns,
269 preferInlineImplementation);
270 mlir::GreedyRewriteConfig config;
271 config.setRegionSimplificationLevel(
272 mlir::GreedySimplifyRegionLevel::Disabled);
273
274 if (mlir::failed(
275 mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
276 mlir::emitError(module.getLoc(), DEBUG_TYPE " pass failed");
277 signalPassFailure();
278 }
279}
280
281void fir::populateSimplifyFIROperationsPatterns(
282 mlir::RewritePatternSet &patterns, bool preferInlineImplementation) {
283 patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>(
284 patterns.getContext(), preferInlineImplementation);
285 patterns.insert<DoConcurrentConversion>(patterns.getContext());
286}
287

source code of flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp