| 1 | //===- SimplifyFIROperations.cpp -- simplify complex FIR operations ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | //===----------------------------------------------------------------------===// |
| 10 | /// \file |
| 11 | /// This pass transforms some FIR operations into their equivalent |
| 12 | /// implementations using other FIR operations. The transformation |
| 13 | /// can legally use SCF dialect and generate Fortran runtime calls. |
| 14 | //===----------------------------------------------------------------------===// |
| 15 | |
| 16 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
| 17 | #include "flang/Optimizer/Builder/Runtime/Inquiry.h" |
| 18 | #include "flang/Optimizer/Builder/Todo.h" |
| 19 | #include "flang/Optimizer/Dialect/FIROps.h" |
| 20 | #include "flang/Optimizer/Transforms/Passes.h" |
| 21 | #include "mlir/IR/IRMapping.h" |
| 22 | #include "mlir/Pass/Pass.h" |
| 23 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 24 | #include <optional> |
| 25 | |
// Request the TableGen-generated definitions for the SimplifyFIROperations
// pass; this provides fir::impl::SimplifyFIROperationsBase, the base class of
// the pass defined below.
namespace fir {
#define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

// Identifier of this pass; also used as the prefix of the error message that
// runOnOperation emits when the greedy rewrite fails.
#define DEBUG_TYPE "flang-simplify-fir-operations"
| 32 | |
| 33 | namespace { |
| 34 | /// Pass runner. |
| 35 | class SimplifyFIROperationsPass |
| 36 | : public fir::impl::SimplifyFIROperationsBase<SimplifyFIROperationsPass> { |
| 37 | public: |
| 38 | using fir::impl::SimplifyFIROperationsBase< |
| 39 | SimplifyFIROperationsPass>::SimplifyFIROperationsBase; |
| 40 | |
| 41 | void runOnOperation() override final; |
| 42 | }; |
| 43 | |
| 44 | /// Base class for all conversions holding the pass options. |
| 45 | template <typename Op> |
| 46 | class ConversionBase : public mlir::OpRewritePattern<Op> { |
| 47 | public: |
| 48 | using mlir::OpRewritePattern<Op>::OpRewritePattern; |
| 49 | |
| 50 | template <typename... Args> |
| 51 | ConversionBase(mlir::MLIRContext *context, Args &&...args) |
| 52 | : mlir::OpRewritePattern<Op>(context), |
| 53 | options{std::forward<Args>(args)...} {} |
| 54 | |
| 55 | mlir::LogicalResult matchAndRewrite(Op, |
| 56 | mlir::PatternRewriter &) const override; |
| 57 | |
| 58 | protected: |
| 59 | fir::SimplifyFIROperationsOptions options; |
| 60 | }; |
| 61 | |
| 62 | /// fir::IsContiguousBoxOp converter. |
| 63 | using IsContiguousBoxCoversion = ConversionBase<fir::IsContiguousBoxOp>; |
| 64 | |
| 65 | /// fir::BoxTotalElementsOp converter. |
| 66 | using BoxTotalElementsConversion = ConversionBase<fir::BoxTotalElementsOp>; |
| 67 | } // namespace |
| 68 | |
| 69 | /// Generate a call to IsContiguous/IsContiguousUpTo function or an inline |
| 70 | /// sequence reading extents/strides from the box and checking them. |
| 71 | /// This conversion may produce fir.box_elesize and a loop (for assumed |
| 72 | /// rank). |
| 73 | template <> |
| 74 | mlir::LogicalResult IsContiguousBoxCoversion::matchAndRewrite( |
| 75 | fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const { |
| 76 | mlir::Location loc = op.getLoc(); |
| 77 | fir::FirOpBuilder builder(rewriter, op.getOperation()); |
| 78 | mlir::Value box = op.getBox(); |
| 79 | |
| 80 | if (options.preferInlineImplementation) { |
| 81 | auto boxType = mlir::cast<fir::BaseBoxType>(box.getType()); |
| 82 | unsigned rank = fir::getBoxRank(boxType); |
| 83 | |
| 84 | // If rank is one, or 'innermost' attribute is set and |
| 85 | // it is not a scalar, then generate a simple comparison |
| 86 | // for the leading dimension: (stride == elem_size || extent == 0). |
| 87 | // |
| 88 | // The scalar cases are supposed to be optimized by the canonicalization. |
| 89 | if (rank == 1 || (op.getInnermost() && rank > 0)) { |
| 90 | mlir::Type idxTy = builder.getIndexType(); |
| 91 | auto eleSize = builder.create<fir::BoxEleSizeOp>(loc, idxTy, box); |
| 92 | mlir::Value zero = fir::factory::createZeroValue(builder, loc, idxTy); |
| 93 | auto dimInfo = |
| 94 | builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, zero); |
| 95 | mlir::Value stride = dimInfo.getByteStride(); |
| 96 | mlir::Value pred1 = builder.create<mlir::arith::CmpIOp>( |
| 97 | loc, mlir::arith::CmpIPredicate::eq, eleSize, stride); |
| 98 | mlir::Value extent = dimInfo.getExtent(); |
| 99 | mlir::Value pred2 = builder.create<mlir::arith::CmpIOp>( |
| 100 | loc, mlir::arith::CmpIPredicate::eq, extent, zero); |
| 101 | mlir::Value result = |
| 102 | builder.create<mlir::arith::OrIOp>(loc, pred1, pred2); |
| 103 | result = builder.createConvert(loc, op.getType(), result); |
| 104 | rewriter.replaceOp(op, result); |
| 105 | return mlir::success(); |
| 106 | } |
| 107 | // TODO: support arrays with multiple dimensions. |
| 108 | } |
| 109 | |
| 110 | // Generate Fortran runtime call. |
| 111 | mlir::Value result; |
| 112 | if (op.getInnermost()) { |
| 113 | mlir::Value one = |
| 114 | builder.createIntegerConstant(loc, builder.getI32Type(), 1); |
| 115 | result = fir::runtime::genIsContiguousUpTo(builder, loc, box, one); |
| 116 | } else { |
| 117 | result = fir::runtime::genIsContiguous(builder, loc, box); |
| 118 | } |
| 119 | result = builder.createConvert(loc, op.getType(), result); |
| 120 | rewriter.replaceOp(op, result); |
| 121 | return mlir::success(); |
| 122 | } |
| 123 | |
| 124 | /// Generate a call to Size runtime function or an inline |
| 125 | /// sequence reading extents from the box an multiplying them. |
| 126 | /// This conversion may produce a loop (for assumed rank). |
| 127 | template <> |
| 128 | mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite( |
| 129 | fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const { |
| 130 | mlir::Location loc = op.getLoc(); |
| 131 | fir::FirOpBuilder builder(rewriter, op.getOperation()); |
| 132 | // TODO: support preferInlineImplementation. |
| 133 | // Reading the extent from the box for 1D arrays probably |
| 134 | // results in less code than the call, so we can always |
| 135 | // inline it. |
| 136 | bool doInline = options.preferInlineImplementation && false; |
| 137 | if (!doInline) { |
| 138 | // Generate Fortran runtime call. |
| 139 | mlir::Value result = fir::runtime::genSize(builder, loc, op.getBox()); |
| 140 | result = builder.createConvert(loc, op.getType(), result); |
| 141 | rewriter.replaceOp(op, result); |
| 142 | return mlir::success(); |
| 143 | } |
| 144 | |
| 145 | // Generate inline implementation. |
| 146 | TODO(loc, "inline BoxTotalElementsOp" ); |
| 147 | return mlir::failure(); |
| 148 | } |
| 149 | |
| 150 | class DoConcurrentConversion |
| 151 | : public mlir::OpRewritePattern<fir::DoConcurrentOp> { |
| 152 | /// Looks up from the operation from and returns the LocalitySpecifierOp with |
| 153 | /// name symbolName |
| 154 | static fir::LocalitySpecifierOp |
| 155 | findLocalizer(mlir::Operation *from, mlir::SymbolRefAttr symbolName) { |
| 156 | fir::LocalitySpecifierOp localizer = |
| 157 | mlir::SymbolTable::lookupNearestSymbolFrom<fir::LocalitySpecifierOp>( |
| 158 | from, symbolName); |
| 159 | assert(localizer && "localizer not found in the symbol table" ); |
| 160 | return localizer; |
| 161 | } |
| 162 | |
| 163 | public: |
| 164 | using mlir::OpRewritePattern<fir::DoConcurrentOp>::OpRewritePattern; |
| 165 | |
| 166 | mlir::LogicalResult |
| 167 | matchAndRewrite(fir::DoConcurrentOp doConcurentOp, |
| 168 | mlir::PatternRewriter &rewriter) const override { |
| 169 | assert(doConcurentOp.getRegion().hasOneBlock()); |
| 170 | mlir::Block &wrapperBlock = doConcurentOp.getRegion().getBlocks().front(); |
| 171 | auto loop = |
| 172 | mlir::cast<fir::DoConcurrentLoopOp>(wrapperBlock.getTerminator()); |
| 173 | assert(loop.getRegion().hasOneBlock()); |
| 174 | mlir::Block &loopBlock = loop.getRegion().getBlocks().front(); |
| 175 | |
| 176 | // Handle localization |
| 177 | if (!loop.getLocalVars().empty()) { |
| 178 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 179 | rewriter.setInsertionPointToStart(&loop.getRegion().front()); |
| 180 | |
| 181 | std::optional<mlir::ArrayAttr> localSyms = loop.getLocalSyms(); |
| 182 | |
| 183 | for (auto [localVar, localArg, localizerSym] : llvm::zip_equal( |
| 184 | loop.getLocalVars(), loop.getRegionLocalArgs(), *localSyms)) { |
| 185 | mlir::SymbolRefAttr localizerName = |
| 186 | llvm::cast<mlir::SymbolRefAttr>(localizerSym); |
| 187 | fir::LocalitySpecifierOp localizer = findLocalizer(loop, localizerName); |
| 188 | |
| 189 | if (!localizer.getInitRegion().empty() || |
| 190 | !localizer.getDeallocRegion().empty()) |
| 191 | TODO(localizer.getLoc(), "localizers with `init` and `dealloc` " |
| 192 | "regions are not handled yet." ); |
| 193 | |
| 194 | // TODO Should this be a heap allocation instead? For now, we allocate |
| 195 | // on the stack for each loop iteration. |
| 196 | mlir::Value localAlloc = |
| 197 | rewriter.create<fir::AllocaOp>(loop.getLoc(), localizer.getType()); |
| 198 | |
| 199 | if (localizer.getLocalitySpecifierType() == |
| 200 | fir::LocalitySpecifierType::LocalInit) { |
| 201 | // It is reasonable to make this assumption since, at this stage, |
| 202 | // control-flow ops are not converted yet. Therefore, things like `if` |
| 203 | // conditions will still be represented by their encapsulating `fir` |
| 204 | // dialect ops. |
| 205 | assert(localizer.getCopyRegion().hasOneBlock() && |
| 206 | "Expected localizer to have a single block." ); |
| 207 | mlir::Block *beforeLocalInit = rewriter.getInsertionBlock(); |
| 208 | mlir::Block *afterLocalInit = rewriter.splitBlock( |
| 209 | rewriter.getInsertionBlock(), rewriter.getInsertionPoint()); |
| 210 | rewriter.cloneRegionBefore(localizer.getCopyRegion(), afterLocalInit); |
| 211 | mlir::Block *copyRegionBody = beforeLocalInit->getNextNode(); |
| 212 | |
| 213 | rewriter.eraseOp(copyRegionBody->getTerminator()); |
| 214 | rewriter.mergeBlocks(afterLocalInit, copyRegionBody); |
| 215 | rewriter.mergeBlocks(copyRegionBody, beforeLocalInit, |
| 216 | {localVar, localArg}); |
| 217 | } |
| 218 | |
| 219 | rewriter.replaceAllUsesWith(localArg, localAlloc); |
| 220 | } |
| 221 | |
| 222 | loop.getRegion().front().eraseArguments(loop.getNumInductionVars(), |
| 223 | loop.getNumLocalOperands()); |
| 224 | loop.getLocalVarsMutable().clear(); |
| 225 | loop.setLocalSymsAttr(nullptr); |
| 226 | } |
| 227 | |
| 228 | // Collect iteration variable(s) allocations so that we can move them |
| 229 | // outside the `fir.do_concurrent` wrapper. |
| 230 | llvm::SmallVector<mlir::Operation *> opsToMove; |
| 231 | for (mlir::Operation &op : llvm::drop_end(wrapperBlock)) |
| 232 | opsToMove.push_back(&op); |
| 233 | |
| 234 | fir::FirOpBuilder firBuilder( |
| 235 | rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>()); |
| 236 | auto *allocIt = firBuilder.getAllocaBlock(); |
| 237 | |
| 238 | for (mlir::Operation *op : llvm::reverse(opsToMove)) |
| 239 | rewriter.moveOpBefore(op, allocIt, allocIt->begin()); |
| 240 | |
| 241 | rewriter.setInsertionPointAfter(doConcurentOp); |
| 242 | fir::DoLoopOp innermostUnorderdLoop; |
| 243 | mlir::SmallVector<mlir::Value> ivArgs; |
| 244 | |
| 245 | for (auto [lb, ub, st, iv] : |
| 246 | llvm::zip_equal(loop.getLowerBound(), loop.getUpperBound(), |
| 247 | loop.getStep(), *loop.getLoopInductionVars())) { |
| 248 | innermostUnorderdLoop = rewriter.create<fir::DoLoopOp>( |
| 249 | doConcurentOp.getLoc(), lb, ub, st, |
| 250 | /*unordred=*/true, /*finalCountValue=*/false, |
| 251 | /*iterArgs=*/std::nullopt, loop.getReduceOperands(), |
| 252 | loop.getReduceAttrsAttr()); |
| 253 | ivArgs.push_back(innermostUnorderdLoop.getInductionVar()); |
| 254 | rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody()); |
| 255 | } |
| 256 | |
| 257 | rewriter.inlineBlockBefore( |
| 258 | &loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs); |
| 259 | rewriter.eraseOp(doConcurentOp); |
| 260 | return mlir::success(); |
| 261 | } |
| 262 | }; |
| 263 | |
| 264 | void SimplifyFIROperationsPass::runOnOperation() { |
| 265 | mlir::ModuleOp module = getOperation(); |
| 266 | mlir::MLIRContext &context = getContext(); |
| 267 | mlir::RewritePatternSet patterns(&context); |
| 268 | fir::populateSimplifyFIROperationsPatterns(patterns, |
| 269 | preferInlineImplementation); |
| 270 | mlir::GreedyRewriteConfig config; |
| 271 | config.setRegionSimplificationLevel( |
| 272 | mlir::GreedySimplifyRegionLevel::Disabled); |
| 273 | |
| 274 | if (mlir::failed( |
| 275 | mlir::applyPatternsGreedily(module, std::move(patterns), config))) { |
| 276 | mlir::emitError(module.getLoc(), DEBUG_TYPE " pass failed" ); |
| 277 | signalPassFailure(); |
| 278 | } |
| 279 | } |
| 280 | |
| 281 | void fir::populateSimplifyFIROperationsPatterns( |
| 282 | mlir::RewritePatternSet &patterns, bool preferInlineImplementation) { |
| 283 | patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>( |
| 284 | patterns.getContext(), preferInlineImplementation); |
| 285 | patterns.insert<DoConcurrentConversion>(patterns.getContext()); |
| 286 | } |
| 287 | |