| 1 | //===- InlineHLFIRCopyIn.cpp - Inline hlfir.copy_in ops -------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
// Transform hlfir.copy_in array operations into loop nests performing
// element-by-element assignments. For simplicity, the inlining is done only
// for trivial data types, when the result is not assumed-rank, and when the
// copy_in does not require a corresponding copy_out. Pointer and allocatable
// inputs are dereferenced before the copy is generated. This may change in
// the future.
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
| 15 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
| 16 | #include "flang/Optimizer/Dialect/FIRType.h" |
| 17 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
| 18 | #include "flang/Optimizer/OpenMP/Passes.h" |
| 19 | #include "mlir/IR/PatternMatch.h" |
| 20 | #include "mlir/Support/LLVM.h" |
| 21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 22 | |
| 23 | namespace hlfir { |
| 24 | #define GEN_PASS_DEF_INLINEHLFIRCOPYIN |
| 25 | #include "flang/Optimizer/HLFIR/Passes.h.inc" |
| 26 | } // namespace hlfir |
| 27 | |
| 28 | #define DEBUG_TYPE "inline-hlfir-copy-in" |
| 29 | |
| 30 | static llvm::cl::opt<bool> noInlineHLFIRCopyIn( |
| 31 | "no-inline-hlfir-copy-in" , |
| 32 | llvm::cl::desc("Do not inline hlfir.copy_in operations" ), |
| 33 | llvm::cl::init(false)); |
| 34 | |
| 35 | namespace { |
| 36 | class InlineCopyInConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> { |
| 37 | public: |
| 38 | using mlir::OpRewritePattern<hlfir::CopyInOp>::OpRewritePattern; |
| 39 | |
| 40 | llvm::LogicalResult |
| 41 | matchAndRewrite(hlfir::CopyInOp copyIn, |
| 42 | mlir::PatternRewriter &rewriter) const override; |
| 43 | }; |
| 44 | |
| 45 | llvm::LogicalResult |
| 46 | InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn, |
| 47 | mlir::PatternRewriter &rewriter) const { |
| 48 | fir::FirOpBuilder builder(rewriter, copyIn.getOperation()); |
| 49 | mlir::Location loc = copyIn.getLoc(); |
| 50 | hlfir::Entity inputVariable{copyIn.getVar()}; |
| 51 | mlir::Type resultAddrType = copyIn.getCopiedIn().getType(); |
| 52 | if (!fir::isa_trivial(inputVariable.getFortranElementType())) |
| 53 | return rewriter.notifyMatchFailure(copyIn, |
| 54 | "CopyInOp's data type is not trivial" ); |
| 55 | |
| 56 | // There should be exactly one user of WasCopied - the corresponding |
| 57 | // CopyOutOp. |
| 58 | if (!copyIn.getWasCopied().hasOneUse()) |
| 59 | return rewriter.notifyMatchFailure( |
| 60 | copyIn, "CopyInOp's WasCopied has no single user" ); |
| 61 | // The copy out should always be present, either to actually copy or just |
| 62 | // deallocate memory. |
| 63 | auto copyOut = mlir::dyn_cast<hlfir::CopyOutOp>( |
| 64 | copyIn.getWasCopied().user_begin().getCurrent().getUser()); |
| 65 | |
| 66 | if (!copyOut) |
| 67 | return rewriter.notifyMatchFailure(copyIn, |
| 68 | "CopyInOp has no direct CopyOut" ); |
| 69 | |
| 70 | if (mlir::cast<fir::BaseBoxType>(resultAddrType).isAssumedRank()) |
| 71 | return rewriter.notifyMatchFailure(copyIn, |
| 72 | "The result array is assumed-rank" ); |
| 73 | |
| 74 | // Only inline the copy_in when copy_out does not need to be done, i.e. in |
| 75 | // case of intent(in). |
| 76 | if (copyOut.getVar()) |
| 77 | return rewriter.notifyMatchFailure(copyIn, "CopyIn needs a copy-out" ); |
| 78 | |
| 79 | inputVariable = |
| 80 | hlfir::derefPointersAndAllocatables(loc, builder, inputVariable); |
| 81 | mlir::Type sequenceType = |
| 82 | hlfir::getFortranElementOrSequenceType(inputVariable.getType()); |
| 83 | fir::BoxType resultBoxType = fir::BoxType::get(sequenceType); |
| 84 | mlir::Value isContiguous = |
| 85 | builder.create<fir::IsContiguousBoxOp>(loc, inputVariable); |
| 86 | mlir::Operation::result_range results = |
| 87 | builder |
| 88 | .genIfOp(loc, {resultBoxType, builder.getI1Type()}, isContiguous, |
| 89 | /*withElseRegion=*/true) |
| 90 | .genThen([&]() { |
| 91 | mlir::Value result = inputVariable; |
| 92 | if (fir::isPointerType(inputVariable.getType())) { |
| 93 | result = builder.create<fir::ReboxOp>( |
| 94 | loc, resultBoxType, inputVariable, mlir::Value{}, |
| 95 | mlir::Value{}); |
| 96 | } |
| 97 | builder.create<fir::ResultOp>( |
| 98 | loc, mlir::ValueRange{result, builder.createBool(loc, false)}); |
| 99 | }) |
| 100 | .genElse([&] { |
| 101 | mlir::Value shape = hlfir::genShape(loc, builder, inputVariable); |
| 102 | llvm::SmallVector<mlir::Value> extents = |
| 103 | hlfir::getIndexExtents(loc, builder, shape); |
| 104 | llvm::StringRef tmpName{".tmp.copy_in" }; |
| 105 | llvm::SmallVector<mlir::Value> lenParams; |
| 106 | mlir::Value alloc = builder.createHeapTemporary( |
| 107 | loc, sequenceType, tmpName, extents, lenParams); |
| 108 | |
| 109 | auto declareOp = builder.create<hlfir::DeclareOp>( |
| 110 | loc, alloc, tmpName, shape, lenParams, |
| 111 | /*dummy_scope=*/nullptr); |
| 112 | hlfir::Entity temp{declareOp.getBase()}; |
| 113 | hlfir::LoopNest loopNest = |
| 114 | hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, |
| 115 | flangomp::shouldUseWorkshareLowering(copyIn), |
| 116 | /*couldVectorize=*/false); |
| 117 | builder.setInsertionPointToStart(loopNest.body); |
| 118 | hlfir::Entity elem = hlfir::getElementAt( |
| 119 | loc, builder, inputVariable, loopNest.oneBasedIndices); |
| 120 | elem = hlfir::loadTrivialScalar(loc, builder, elem); |
| 121 | hlfir::Entity tempElem = hlfir::getElementAt( |
| 122 | loc, builder, temp, loopNest.oneBasedIndices); |
| 123 | builder.create<hlfir::AssignOp>(loc, elem, tempElem); |
| 124 | builder.setInsertionPointAfter(loopNest.outerOp); |
| 125 | |
| 126 | mlir::Value result; |
| 127 | // Make sure the result is always a boxed array by boxing it |
| 128 | // ourselves if need be. |
| 129 | if (mlir::isa<fir::BaseBoxType>(temp.getType())) { |
| 130 | result = temp; |
| 131 | } else { |
| 132 | fir::ReferenceType refTy = |
| 133 | fir::ReferenceType::get(temp.getElementOrSequenceType()); |
| 134 | mlir::Value refVal = builder.createConvert(loc, refTy, temp); |
| 135 | result = builder.create<fir::EmboxOp>(loc, resultBoxType, refVal, |
| 136 | shape); |
| 137 | } |
| 138 | |
| 139 | builder.create<fir::ResultOp>( |
| 140 | loc, mlir::ValueRange{result, builder.createBool(loc, true)}); |
| 141 | }) |
| 142 | .getResults(); |
| 143 | |
| 144 | mlir::OpResult resultBox = results[0]; |
| 145 | mlir::OpResult needsCleanup = results[1]; |
| 146 | |
| 147 | // Prepare the corresponding copyOut to free the temporary if it is required |
| 148 | auto alloca = builder.create<fir::AllocaOp>(loc, resultBox.getType()); |
| 149 | auto store = builder.create<fir::StoreOp>(loc, resultBox, alloca); |
| 150 | rewriter.startOpModification(copyOut); |
| 151 | copyOut->setOperand(0, store.getMemref()); |
| 152 | copyOut->setOperand(1, needsCleanup); |
| 153 | rewriter.finalizeOpModification(copyOut); |
| 154 | |
| 155 | rewriter.replaceOp(copyIn, {resultBox, builder.genNot(loc, isContiguous)}); |
| 156 | return mlir::success(); |
| 157 | } |
| 158 | |
| 159 | class InlineHLFIRCopyInPass |
| 160 | : public hlfir::impl::InlineHLFIRCopyInBase<InlineHLFIRCopyInPass> { |
| 161 | public: |
| 162 | void runOnOperation() override { |
| 163 | mlir::MLIRContext *context = &getContext(); |
| 164 | |
| 165 | mlir::GreedyRewriteConfig config; |
| 166 | // Prevent the pattern driver from merging blocks. |
| 167 | config.setRegionSimplificationLevel( |
| 168 | mlir::GreedySimplifyRegionLevel::Disabled); |
| 169 | |
| 170 | mlir::RewritePatternSet patterns(context); |
| 171 | if (!noInlineHLFIRCopyIn) { |
| 172 | patterns.insert<InlineCopyInConversion>(context); |
| 173 | } |
| 174 | |
| 175 | if (mlir::failed(mlir::applyPatternsGreedily( |
| 176 | getOperation(), std::move(patterns), config))) { |
| 177 | mlir::emitError(getOperation()->getLoc(), |
| 178 | "failure in hlfir.copy_in inlining" ); |
| 179 | signalPassFailure(); |
| 180 | } |
| 181 | } |
| 182 | }; |
| 183 | } // namespace |
| 184 | |