| 1 | //===- SimplifyFIROperations.cpp -- simplify complex FIR operations ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | //===----------------------------------------------------------------------===// |
| 10 | /// \file |
| 11 | /// This pass transforms some FIR operations into their equivalent |
| 12 | /// implementations using other FIR operations. The transformation |
| 13 | /// can legally use SCF dialect and generate Fortran runtime calls. |
| 14 | //===----------------------------------------------------------------------===// |
| 15 | |
| 16 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
| 17 | #include "flang/Optimizer/Builder/Runtime/Inquiry.h" |
| 18 | #include "flang/Optimizer/Builder/Todo.h" |
| 19 | #include "flang/Optimizer/Dialect/FIROps.h" |
| 20 | #include "flang/Optimizer/Transforms/Passes.h" |
| 21 | #include "mlir/IR/IRMapping.h" |
| 22 | #include "mlir/Pass/Pass.h" |
| 23 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 24 | #include <optional> |
| 25 | |
| 26 | namespace fir { |
| 27 | #define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS |
| 28 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
| 29 | } // namespace fir |
| 30 | |
| 31 | #define DEBUG_TYPE "flang-simplify-fir-operations" |
| 32 | |
| 33 | namespace { |
| 34 | /// Pass runner. |
| 35 | class SimplifyFIROperationsPass |
| 36 | : public fir::impl::SimplifyFIROperationsBase<SimplifyFIROperationsPass> { |
| 37 | public: |
| 38 | using fir::impl::SimplifyFIROperationsBase< |
| 39 | SimplifyFIROperationsPass>::SimplifyFIROperationsBase; |
| 40 | |
| 41 | void runOnOperation() override final; |
| 42 | }; |
| 43 | |
| 44 | /// Base class for all conversions holding the pass options. |
| 45 | template <typename Op> |
| 46 | class ConversionBase : public mlir::OpRewritePattern<Op> { |
| 47 | public: |
| 48 | using mlir::OpRewritePattern<Op>::OpRewritePattern; |
| 49 | |
| 50 | template <typename... Args> |
| 51 | ConversionBase(mlir::MLIRContext *context, Args &&...args) |
| 52 | : mlir::OpRewritePattern<Op>(context), |
| 53 | options{std::forward<Args>(args)...} {} |
| 54 | |
| 55 | mlir::LogicalResult matchAndRewrite(Op, |
| 56 | mlir::PatternRewriter &) const override; |
| 57 | |
| 58 | protected: |
| 59 | fir::SimplifyFIROperationsOptions options; |
| 60 | }; |
| 61 | |
| 62 | /// fir::IsContiguousBoxOp converter. |
| 63 | using IsContiguousBoxCoversion = ConversionBase<fir::IsContiguousBoxOp>; |
| 64 | |
| 65 | /// fir::BoxTotalElementsOp converter. |
| 66 | using BoxTotalElementsConversion = ConversionBase<fir::BoxTotalElementsOp>; |
| 67 | } // namespace |
| 68 | |
| 69 | /// Generate a call to IsContiguous/IsContiguousUpTo function or an inline |
| 70 | /// sequence reading extents/strides from the box and checking them. |
| 71 | /// This conversion may produce fir.box_elesize and a loop (for assumed |
| 72 | /// rank). |
| 73 | template <> |
| 74 | mlir::LogicalResult IsContiguousBoxCoversion::matchAndRewrite( |
| 75 | fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const { |
| 76 | mlir::Location loc = op.getLoc(); |
| 77 | fir::FirOpBuilder builder(rewriter, op.getOperation()); |
| 78 | mlir::Value box = op.getBox(); |
| 79 | |
| 80 | if (options.preferInlineImplementation) { |
| 81 | auto boxType = mlir::cast<fir::BaseBoxType>(box.getType()); |
| 82 | unsigned rank = fir::getBoxRank(boxType); |
| 83 | |
| 84 | // If rank is one, or 'innermost' attribute is set and |
| 85 | // it is not a scalar, then generate a simple comparison |
| 86 | // for the leading dimension: (stride == elem_size || extent == 0). |
| 87 | // |
| 88 | // The scalar cases are supposed to be optimized by the canonicalization. |
| 89 | if (rank == 1 || (op.getInnermost() && rank > 0)) { |
| 90 | mlir::Type idxTy = builder.getIndexType(); |
| 91 | auto eleSize = builder.create<fir::BoxEleSizeOp>(loc, idxTy, box); |
| 92 | mlir::Value zero = fir::factory::createZeroValue(builder, loc, idxTy); |
| 93 | auto dimInfo = |
| 94 | builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, zero); |
| 95 | mlir::Value stride = dimInfo.getByteStride(); |
| 96 | mlir::Value pred1 = builder.create<mlir::arith::CmpIOp>( |
| 97 | loc, mlir::arith::CmpIPredicate::eq, eleSize, stride); |
| 98 | mlir::Value extent = dimInfo.getExtent(); |
| 99 | mlir::Value pred2 = builder.create<mlir::arith::CmpIOp>( |
| 100 | loc, mlir::arith::CmpIPredicate::eq, extent, zero); |
| 101 | mlir::Value result = |
| 102 | builder.create<mlir::arith::OrIOp>(loc, pred1, pred2); |
| 103 | result = builder.createConvert(loc, op.getType(), result); |
| 104 | rewriter.replaceOp(op, result); |
| 105 | return mlir::success(); |
| 106 | } |
| 107 | // TODO: support arrays with multiple dimensions. |
| 108 | } |
| 109 | |
| 110 | // Generate Fortran runtime call. |
| 111 | mlir::Value result; |
| 112 | if (op.getInnermost()) { |
| 113 | mlir::Value one = |
| 114 | builder.createIntegerConstant(loc, builder.getI32Type(), 1); |
| 115 | result = fir::runtime::genIsContiguousUpTo(builder, loc, box, one); |
| 116 | } else { |
| 117 | result = fir::runtime::genIsContiguous(builder, loc, box); |
| 118 | } |
| 119 | result = builder.createConvert(loc, op.getType(), result); |
| 120 | rewriter.replaceOp(op, result); |
| 121 | return mlir::success(); |
| 122 | } |
| 123 | |
| 124 | /// Generate a call to Size runtime function or an inline |
| 125 | /// sequence reading extents from the box an multiplying them. |
| 126 | /// This conversion may produce a loop (for assumed rank). |
| 127 | template <> |
| 128 | mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite( |
| 129 | fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const { |
| 130 | mlir::Location loc = op.getLoc(); |
| 131 | fir::FirOpBuilder builder(rewriter, op.getOperation()); |
| 132 | // TODO: support preferInlineImplementation. |
| 133 | // Reading the extent from the box for 1D arrays probably |
| 134 | // results in less code than the call, so we can always |
| 135 | // inline it. |
| 136 | bool doInline = options.preferInlineImplementation && false; |
| 137 | if (!doInline) { |
| 138 | // Generate Fortran runtime call. |
| 139 | mlir::Value result = fir::runtime::genSize(builder, loc, op.getBox()); |
| 140 | result = builder.createConvert(loc, op.getType(), result); |
| 141 | rewriter.replaceOp(op, result); |
| 142 | return mlir::success(); |
| 143 | } |
| 144 | |
| 145 | // Generate inline implementation. |
| 146 | TODO(loc, "inline BoxTotalElementsOp" ); |
| 147 | return mlir::failure(); |
| 148 | } |
| 149 | |
| 150 | class DoConcurrentConversion |
| 151 | : public mlir::OpRewritePattern<fir::DoConcurrentOp> { |
| 152 | /// Looks up from the operation from and returns the LocalitySpecifierOp with |
| 153 | /// name symbolName |
| 154 | static fir::LocalitySpecifierOp |
| 155 | findLocalizer(mlir::Operation *from, mlir::SymbolRefAttr symbolName) { |
| 156 | fir::LocalitySpecifierOp localizer = |
| 157 | mlir::SymbolTable::lookupNearestSymbolFrom<fir::LocalitySpecifierOp>( |
| 158 | from, symbolName); |
| 159 | assert(localizer && "localizer not found in the symbol table" ); |
| 160 | return localizer; |
| 161 | } |
| 162 | |
| 163 | public: |
| 164 | using mlir::OpRewritePattern<fir::DoConcurrentOp>::OpRewritePattern; |
| 165 | |
| 166 | mlir::LogicalResult |
| 167 | matchAndRewrite(fir::DoConcurrentOp doConcurentOp, |
| 168 | mlir::PatternRewriter &rewriter) const override { |
| 169 | assert(doConcurentOp.getRegion().hasOneBlock()); |
| 170 | mlir::Block &wrapperBlock = doConcurentOp.getRegion().getBlocks().front(); |
| 171 | auto loop = |
| 172 | mlir::cast<fir::DoConcurrentLoopOp>(wrapperBlock.getTerminator()); |
| 173 | assert(loop.getRegion().hasOneBlock()); |
| 174 | mlir::Block &loopBlock = loop.getRegion().getBlocks().front(); |
| 175 | |
| 176 | // Handle localization |
| 177 | if (!loop.getLocalVars().empty()) { |
| 178 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 179 | rewriter.setInsertionPointToStart(&loop.getRegion().front()); |
| 180 | |
| 181 | std::optional<mlir::ArrayAttr> localSyms = loop.getLocalSyms(); |
| 182 | |
| 183 | for (auto localInfo : llvm::zip_equal( |
| 184 | loop.getLocalVars(), loop.getRegionLocalArgs(), *localSyms)) { |
| 185 | mlir::Value localVar = std::get<0>(localInfo); |
| 186 | mlir::BlockArgument localArg = std::get<1>(localInfo); |
| 187 | mlir::Attribute localizerSym = std::get<2>(localInfo); |
| 188 | mlir::SymbolRefAttr localizerName = |
| 189 | llvm::cast<mlir::SymbolRefAttr>(localizerSym); |
| 190 | fir::LocalitySpecifierOp localizer = findLocalizer(loop, localizerName); |
| 191 | |
| 192 | // TODO Should this be a heap allocation instead? For now, we allocate |
| 193 | // on the stack for each loop iteration. |
| 194 | mlir::Value localAlloc = |
| 195 | rewriter.create<fir::AllocaOp>(loop.getLoc(), localizer.getType()); |
| 196 | |
| 197 | auto cloneLocalizerRegion = [&](mlir::Region ®ion, |
| 198 | mlir::ValueRange regionArgs, |
| 199 | mlir::Block::iterator insertionPoint) { |
| 200 | // It is reasonable to make this assumption since, at this stage, |
| 201 | // control-flow ops are not converted yet. Therefore, things like `if` |
| 202 | // conditions will still be represented by their encapsulating `fir` |
| 203 | // dialect ops. |
| 204 | assert(region.hasOneBlock() && |
| 205 | "Expected localizer region to have a single block." ); |
| 206 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 207 | rewriter.setInsertionPoint(rewriter.getInsertionBlock(), |
| 208 | insertionPoint); |
| 209 | mlir::IRMapping mapper; |
| 210 | mapper.map(region.getArguments(), regionArgs); |
| 211 | for (mlir::Operation &op : region.front().without_terminator()) |
| 212 | (void)rewriter.clone(op, mapper); |
| 213 | }; |
| 214 | |
| 215 | if (!localizer.getInitRegion().empty()) |
| 216 | cloneLocalizerRegion(localizer.getInitRegion(), {localVar, localArg}, |
| 217 | rewriter.getInsertionPoint()); |
| 218 | |
| 219 | if (localizer.getLocalitySpecifierType() == |
| 220 | fir::LocalitySpecifierType::LocalInit) |
| 221 | cloneLocalizerRegion(localizer.getCopyRegion(), {localVar, localArg}, |
| 222 | rewriter.getInsertionPoint()); |
| 223 | |
| 224 | if (!localizer.getDeallocRegion().empty()) |
| 225 | cloneLocalizerRegion(localizer.getDeallocRegion(), {localArg}, |
| 226 | rewriter.getInsertionBlock()->end()); |
| 227 | |
| 228 | rewriter.replaceAllUsesWith(localArg, localAlloc); |
| 229 | } |
| 230 | |
| 231 | loop.getRegion().front().eraseArguments(loop.getNumInductionVars(), |
| 232 | loop.getNumLocalOperands()); |
| 233 | loop.getLocalVarsMutable().clear(); |
| 234 | loop.setLocalSymsAttr(nullptr); |
| 235 | } |
| 236 | |
| 237 | for (auto [reduceVar, reduceArg] : |
| 238 | llvm::zip_equal(loop.getReduceVars(), loop.getRegionReduceArgs())) |
| 239 | rewriter.replaceAllUsesWith(reduceArg, reduceVar); |
| 240 | |
| 241 | // Collect iteration variable(s) allocations so that we can move them |
| 242 | // outside the `fir.do_concurrent` wrapper. |
| 243 | llvm::SmallVector<mlir::Operation *> opsToMove; |
| 244 | for (mlir::Operation &op : llvm::drop_end(wrapperBlock)) |
| 245 | opsToMove.push_back(&op); |
| 246 | |
| 247 | fir::FirOpBuilder firBuilder( |
| 248 | rewriter, doConcurentOp->getParentOfType<mlir::ModuleOp>()); |
| 249 | auto *allocIt = firBuilder.getAllocaBlock(); |
| 250 | |
| 251 | for (mlir::Operation *op : llvm::reverse(opsToMove)) |
| 252 | rewriter.moveOpBefore(op, allocIt, allocIt->begin()); |
| 253 | |
| 254 | rewriter.setInsertionPointAfter(doConcurentOp); |
| 255 | fir::DoLoopOp innermostUnorderdLoop; |
| 256 | mlir::SmallVector<mlir::Value> ivArgs; |
| 257 | |
| 258 | for (auto [lb, ub, st, iv] : |
| 259 | llvm::zip_equal(loop.getLowerBound(), loop.getUpperBound(), |
| 260 | loop.getStep(), *loop.getLoopInductionVars())) { |
| 261 | innermostUnorderdLoop = rewriter.create<fir::DoLoopOp>( |
| 262 | doConcurentOp.getLoc(), lb, ub, st, |
| 263 | /*unordred=*/true, /*finalCountValue=*/false, |
| 264 | /*iterArgs=*/std::nullopt, loop.getReduceVars(), |
| 265 | loop.getReduceAttrsAttr()); |
| 266 | ivArgs.push_back(innermostUnorderdLoop.getInductionVar()); |
| 267 | rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody()); |
| 268 | } |
| 269 | |
| 270 | loop.getRegion().front().eraseArguments(loop.getNumInductionVars() + |
| 271 | loop.getNumLocalOperands(), |
| 272 | loop.getNumReduceOperands()); |
| 273 | |
| 274 | rewriter.inlineBlockBefore( |
| 275 | &loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs); |
| 276 | rewriter.eraseOp(doConcurentOp); |
| 277 | return mlir::success(); |
| 278 | } |
| 279 | }; |
| 280 | |
| 281 | void SimplifyFIROperationsPass::runOnOperation() { |
| 282 | mlir::ModuleOp module = getOperation(); |
| 283 | mlir::MLIRContext &context = getContext(); |
| 284 | mlir::RewritePatternSet patterns(&context); |
| 285 | fir::populateSimplifyFIROperationsPatterns(patterns, |
| 286 | preferInlineImplementation); |
| 287 | mlir::GreedyRewriteConfig config; |
| 288 | config.setRegionSimplificationLevel( |
| 289 | mlir::GreedySimplifyRegionLevel::Disabled); |
| 290 | |
| 291 | if (mlir::failed( |
| 292 | mlir::applyPatternsGreedily(module, std::move(patterns), config))) { |
| 293 | mlir::emitError(module.getLoc(), DEBUG_TYPE " pass failed" ); |
| 294 | signalPassFailure(); |
| 295 | } |
| 296 | } |
| 297 | |
| 298 | void fir::populateSimplifyFIROperationsPatterns( |
| 299 | mlir::RewritePatternSet &patterns, bool preferInlineImplementation) { |
| 300 | patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>( |
| 301 | patterns.getContext(), preferInlineImplementation); |
| 302 | patterns.insert<DoConcurrentConversion>(patterns.getContext()); |
| 303 | } |
| 304 | |