1 | //===- InlineHLFIRCopyIn.cpp - Inline hlfir.copy_in ops -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // Transform hlfir.copy_in array operations into loop nests performing element |
9 | // per element assignments. For simplicity, the inlining is done for trivial |
10 | // data types when the copy_in does not require a corresponding copy_out and |
11 | // when the input array is not behind a pointer. This may change in the future. |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
15 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
16 | #include "flang/Optimizer/Dialect/FIRType.h" |
17 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
18 | #include "flang/Optimizer/OpenMP/Passes.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Support/LLVM.h" |
21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
22 | |
23 | namespace hlfir { |
24 | #define GEN_PASS_DEF_INLINEHLFIRCOPYIN |
25 | #include "flang/Optimizer/HLFIR/Passes.h.inc" |
26 | } // namespace hlfir |
27 | |
28 | #define DEBUG_TYPE "inline-hlfir-copy-in" |
29 | |
30 | static llvm::cl::opt<bool> noInlineHLFIRCopyIn( |
31 | "no-inline-hlfir-copy-in" , |
32 | llvm::cl::desc("Do not inline hlfir.copy_in operations" ), |
33 | llvm::cl::init(false)); |
34 | |
35 | namespace { |
36 | class InlineCopyInConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> { |
37 | public: |
38 | using mlir::OpRewritePattern<hlfir::CopyInOp>::OpRewritePattern; |
39 | |
40 | llvm::LogicalResult |
41 | matchAndRewrite(hlfir::CopyInOp copyIn, |
42 | mlir::PatternRewriter &rewriter) const override; |
43 | }; |
44 | |
45 | llvm::LogicalResult |
46 | InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn, |
47 | mlir::PatternRewriter &rewriter) const { |
48 | fir::FirOpBuilder builder(rewriter, copyIn.getOperation()); |
49 | mlir::Location loc = copyIn.getLoc(); |
50 | hlfir::Entity inputVariable{copyIn.getVar()}; |
51 | mlir::Type resultAddrType = copyIn.getCopiedIn().getType(); |
52 | if (!fir::isa_trivial(inputVariable.getFortranElementType())) |
53 | return rewriter.notifyMatchFailure(copyIn, |
54 | "CopyInOp's data type is not trivial" ); |
55 | |
56 | // There should be exactly one user of WasCopied - the corresponding |
57 | // CopyOutOp. |
58 | if (!copyIn.getWasCopied().hasOneUse()) |
59 | return rewriter.notifyMatchFailure( |
60 | copyIn, "CopyInOp's WasCopied has no single user" ); |
61 | // The copy out should always be present, either to actually copy or just |
62 | // deallocate memory. |
63 | auto copyOut = mlir::dyn_cast<hlfir::CopyOutOp>( |
64 | copyIn.getWasCopied().user_begin().getCurrent().getUser()); |
65 | |
66 | if (!copyOut) |
67 | return rewriter.notifyMatchFailure(copyIn, |
68 | "CopyInOp has no direct CopyOut" ); |
69 | |
70 | if (mlir::cast<fir::BaseBoxType>(resultAddrType).isAssumedRank()) |
71 | return rewriter.notifyMatchFailure(copyIn, |
72 | "The result array is assumed-rank" ); |
73 | |
74 | // Only inline the copy_in when copy_out does not need to be done, i.e. in |
75 | // case of intent(in). |
76 | if (copyOut.getVar()) |
77 | return rewriter.notifyMatchFailure(copyIn, "CopyIn needs a copy-out" ); |
78 | |
79 | inputVariable = |
80 | hlfir::derefPointersAndAllocatables(loc, builder, inputVariable); |
81 | mlir::Type sequenceType = |
82 | hlfir::getFortranElementOrSequenceType(inputVariable.getType()); |
83 | fir::BoxType resultBoxType = fir::BoxType::get(sequenceType); |
84 | mlir::Value isContiguous = |
85 | builder.create<fir::IsContiguousBoxOp>(loc, inputVariable); |
86 | mlir::Operation::result_range results = |
87 | builder |
88 | .genIfOp(loc, {resultBoxType, builder.getI1Type()}, isContiguous, |
89 | /*withElseRegion=*/true) |
90 | .genThen([&]() { |
91 | mlir::Value result = inputVariable; |
92 | if (fir::isPointerType(inputVariable.getType())) { |
93 | result = builder.create<fir::ReboxOp>( |
94 | loc, resultBoxType, inputVariable, mlir::Value{}, |
95 | mlir::Value{}); |
96 | } |
97 | builder.create<fir::ResultOp>( |
98 | loc, mlir::ValueRange{result, builder.createBool(loc, false)}); |
99 | }) |
100 | .genElse([&] { |
101 | mlir::Value shape = hlfir::genShape(loc, builder, inputVariable); |
102 | llvm::SmallVector<mlir::Value> extents = |
103 | hlfir::getIndexExtents(loc, builder, shape); |
104 | llvm::StringRef tmpName{".tmp.copy_in" }; |
105 | llvm::SmallVector<mlir::Value> lenParams; |
106 | mlir::Value alloc = builder.createHeapTemporary( |
107 | loc, sequenceType, tmpName, extents, lenParams); |
108 | |
109 | auto declareOp = builder.create<hlfir::DeclareOp>( |
110 | loc, alloc, tmpName, shape, lenParams, |
111 | /*dummy_scope=*/nullptr); |
112 | hlfir::Entity temp{declareOp.getBase()}; |
113 | hlfir::LoopNest loopNest = |
114 | hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, |
115 | flangomp::shouldUseWorkshareLowering(copyIn), |
116 | /*couldVectorize=*/false); |
117 | builder.setInsertionPointToStart(loopNest.body); |
118 | hlfir::Entity elem = hlfir::getElementAt( |
119 | loc, builder, inputVariable, loopNest.oneBasedIndices); |
120 | elem = hlfir::loadTrivialScalar(loc, builder, elem); |
121 | hlfir::Entity tempElem = hlfir::getElementAt( |
122 | loc, builder, temp, loopNest.oneBasedIndices); |
123 | builder.create<hlfir::AssignOp>(loc, elem, tempElem); |
124 | builder.setInsertionPointAfter(loopNest.outerOp); |
125 | |
126 | mlir::Value result; |
127 | // Make sure the result is always a boxed array by boxing it |
128 | // ourselves if need be. |
129 | if (mlir::isa<fir::BaseBoxType>(temp.getType())) { |
130 | result = temp; |
131 | } else { |
132 | fir::ReferenceType refTy = |
133 | fir::ReferenceType::get(temp.getElementOrSequenceType()); |
134 | mlir::Value refVal = builder.createConvert(loc, refTy, temp); |
135 | result = builder.create<fir::EmboxOp>(loc, resultBoxType, refVal, |
136 | shape); |
137 | } |
138 | |
139 | builder.create<fir::ResultOp>( |
140 | loc, mlir::ValueRange{result, builder.createBool(loc, true)}); |
141 | }) |
142 | .getResults(); |
143 | |
144 | mlir::OpResult resultBox = results[0]; |
145 | mlir::OpResult needsCleanup = results[1]; |
146 | |
147 | // Prepare the corresponding copyOut to free the temporary if it is required |
148 | auto alloca = builder.create<fir::AllocaOp>(loc, resultBox.getType()); |
149 | auto store = builder.create<fir::StoreOp>(loc, resultBox, alloca); |
150 | rewriter.startOpModification(copyOut); |
151 | copyOut->setOperand(0, store.getMemref()); |
152 | copyOut->setOperand(1, needsCleanup); |
153 | rewriter.finalizeOpModification(copyOut); |
154 | |
155 | rewriter.replaceOp(copyIn, {resultBox, builder.genNot(loc, isContiguous)}); |
156 | return mlir::success(); |
157 | } |
158 | |
159 | class InlineHLFIRCopyInPass |
160 | : public hlfir::impl::InlineHLFIRCopyInBase<InlineHLFIRCopyInPass> { |
161 | public: |
162 | void runOnOperation() override { |
163 | mlir::MLIRContext *context = &getContext(); |
164 | |
165 | mlir::GreedyRewriteConfig config; |
166 | // Prevent the pattern driver from merging blocks. |
167 | config.setRegionSimplificationLevel( |
168 | mlir::GreedySimplifyRegionLevel::Disabled); |
169 | |
170 | mlir::RewritePatternSet patterns(context); |
171 | if (!noInlineHLFIRCopyIn) { |
172 | patterns.insert<InlineCopyInConversion>(context); |
173 | } |
174 | |
175 | if (mlir::failed(mlir::applyPatternsGreedily( |
176 | getOperation(), std::move(patterns), config))) { |
177 | mlir::emitError(getOperation()->getLoc(), |
178 | "failure in hlfir.copy_in inlining" ); |
179 | signalPassFailure(); |
180 | } |
181 | } |
182 | }; |
183 | } // namespace |
184 | |