1//===- InlineHLFIRCopyIn.cpp - Inline hlfir.copy_in ops -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Transform hlfir.copy_in array operations into loop nests performing element
9// per element assignments. For simplicity, the inlining is done for trivial
10// data types when the copy_in does not require a corresponding copy_out and
11// when the input array is not behind a pointer. This may change in the future.
12//===----------------------------------------------------------------------===//
13
14#include "flang/Optimizer/Builder/FIRBuilder.h"
15#include "flang/Optimizer/Builder/HLFIRTools.h"
16#include "flang/Optimizer/Dialect/FIRType.h"
17#include "flang/Optimizer/HLFIR/HLFIROps.h"
18#include "flang/Optimizer/OpenMP/Passes.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Support/LLVM.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23namespace hlfir {
24#define GEN_PASS_DEF_INLINEHLFIRCOPYIN
25#include "flang/Optimizer/HLFIR/Passes.h.inc"
26} // namespace hlfir
27
28#define DEBUG_TYPE "inline-hlfir-copy-in"
29
30static llvm::cl::opt<bool> noInlineHLFIRCopyIn(
31 "no-inline-hlfir-copy-in",
32 llvm::cl::desc("Do not inline hlfir.copy_in operations"),
33 llvm::cl::init(false));
34
35namespace {
36class InlineCopyInConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> {
37public:
38 using mlir::OpRewritePattern<hlfir::CopyInOp>::OpRewritePattern;
39
40 llvm::LogicalResult
41 matchAndRewrite(hlfir::CopyInOp copyIn,
42 mlir::PatternRewriter &rewriter) const override;
43};
44
45llvm::LogicalResult
46InlineCopyInConversion::matchAndRewrite(hlfir::CopyInOp copyIn,
47 mlir::PatternRewriter &rewriter) const {
48 fir::FirOpBuilder builder(rewriter, copyIn.getOperation());
49 mlir::Location loc = copyIn.getLoc();
50 hlfir::Entity inputVariable{copyIn.getVar()};
51 mlir::Type resultAddrType = copyIn.getCopiedIn().getType();
52 if (!fir::isa_trivial(inputVariable.getFortranElementType()))
53 return rewriter.notifyMatchFailure(copyIn,
54 "CopyInOp's data type is not trivial");
55
56 // There should be exactly one user of WasCopied - the corresponding
57 // CopyOutOp.
58 if (!copyIn.getWasCopied().hasOneUse())
59 return rewriter.notifyMatchFailure(
60 copyIn, "CopyInOp's WasCopied has no single user");
61 // The copy out should always be present, either to actually copy or just
62 // deallocate memory.
63 auto copyOut = mlir::dyn_cast<hlfir::CopyOutOp>(
64 copyIn.getWasCopied().user_begin().getCurrent().getUser());
65
66 if (!copyOut)
67 return rewriter.notifyMatchFailure(copyIn,
68 "CopyInOp has no direct CopyOut");
69
70 if (mlir::cast<fir::BaseBoxType>(resultAddrType).isAssumedRank())
71 return rewriter.notifyMatchFailure(copyIn,
72 "The result array is assumed-rank");
73
74 // Only inline the copy_in when copy_out does not need to be done, i.e. in
75 // case of intent(in).
76 if (copyOut.getVar())
77 return rewriter.notifyMatchFailure(copyIn, "CopyIn needs a copy-out");
78
79 inputVariable =
80 hlfir::derefPointersAndAllocatables(loc, builder, inputVariable);
81 mlir::Type sequenceType =
82 hlfir::getFortranElementOrSequenceType(inputVariable.getType());
83 fir::BoxType resultBoxType = fir::BoxType::get(sequenceType);
84 mlir::Value isContiguous =
85 builder.create<fir::IsContiguousBoxOp>(loc, inputVariable);
86 mlir::Operation::result_range results =
87 builder
88 .genIfOp(loc, {resultBoxType, builder.getI1Type()}, isContiguous,
89 /*withElseRegion=*/true)
90 .genThen([&]() {
91 mlir::Value result = inputVariable;
92 if (fir::isPointerType(inputVariable.getType())) {
93 result = builder.create<fir::ReboxOp>(
94 loc, resultBoxType, inputVariable, mlir::Value{},
95 mlir::Value{});
96 }
97 builder.create<fir::ResultOp>(
98 loc, mlir::ValueRange{result, builder.createBool(loc, false)});
99 })
100 .genElse([&] {
101 mlir::Value shape = hlfir::genShape(loc, builder, inputVariable);
102 llvm::SmallVector<mlir::Value> extents =
103 hlfir::getIndexExtents(loc, builder, shape);
104 llvm::StringRef tmpName{".tmp.copy_in"};
105 llvm::SmallVector<mlir::Value> lenParams;
106 mlir::Value alloc = builder.createHeapTemporary(
107 loc, sequenceType, tmpName, extents, lenParams);
108
109 auto declareOp = builder.create<hlfir::DeclareOp>(
110 loc, alloc, tmpName, shape, lenParams,
111 /*dummy_scope=*/nullptr);
112 hlfir::Entity temp{declareOp.getBase()};
113 hlfir::LoopNest loopNest =
114 hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
115 flangomp::shouldUseWorkshareLowering(copyIn),
116 /*couldVectorize=*/false);
117 builder.setInsertionPointToStart(loopNest.body);
118 hlfir::Entity elem = hlfir::getElementAt(
119 loc, builder, inputVariable, loopNest.oneBasedIndices);
120 elem = hlfir::loadTrivialScalar(loc, builder, elem);
121 hlfir::Entity tempElem = hlfir::getElementAt(
122 loc, builder, temp, loopNest.oneBasedIndices);
123 builder.create<hlfir::AssignOp>(loc, elem, tempElem);
124 builder.setInsertionPointAfter(loopNest.outerOp);
125
126 mlir::Value result;
127 // Make sure the result is always a boxed array by boxing it
128 // ourselves if need be.
129 if (mlir::isa<fir::BaseBoxType>(temp.getType())) {
130 result = temp;
131 } else {
132 fir::ReferenceType refTy =
133 fir::ReferenceType::get(temp.getElementOrSequenceType());
134 mlir::Value refVal = builder.createConvert(loc, refTy, temp);
135 result = builder.create<fir::EmboxOp>(loc, resultBoxType, refVal,
136 shape);
137 }
138
139 builder.create<fir::ResultOp>(
140 loc, mlir::ValueRange{result, builder.createBool(loc, true)});
141 })
142 .getResults();
143
144 mlir::OpResult resultBox = results[0];
145 mlir::OpResult needsCleanup = results[1];
146
147 // Prepare the corresponding copyOut to free the temporary if it is required
148 auto alloca = builder.create<fir::AllocaOp>(loc, resultBox.getType());
149 auto store = builder.create<fir::StoreOp>(loc, resultBox, alloca);
150 rewriter.startOpModification(copyOut);
151 copyOut->setOperand(0, store.getMemref());
152 copyOut->setOperand(1, needsCleanup);
153 rewriter.finalizeOpModification(copyOut);
154
155 rewriter.replaceOp(copyIn, {resultBox, builder.genNot(loc, isContiguous)});
156 return mlir::success();
157}
158
159class InlineHLFIRCopyInPass
160 : public hlfir::impl::InlineHLFIRCopyInBase<InlineHLFIRCopyInPass> {
161public:
162 void runOnOperation() override {
163 mlir::MLIRContext *context = &getContext();
164
165 mlir::GreedyRewriteConfig config;
166 // Prevent the pattern driver from merging blocks.
167 config.setRegionSimplificationLevel(
168 mlir::GreedySimplifyRegionLevel::Disabled);
169
170 mlir::RewritePatternSet patterns(context);
171 if (!noInlineHLFIRCopyIn) {
172 patterns.insert<InlineCopyInConversion>(context);
173 }
174
175 if (mlir::failed(mlir::applyPatternsGreedily(
176 getOperation(), std::move(patterns), config))) {
177 mlir::emitError(getOperation()->getLoc(),
178 "failure in hlfir.copy_in inlining");
179 signalPassFailure();
180 }
181 }
182};
183} // namespace
184

// Source: flang/lib/Optimizer/HLFIR/Transforms/InlineHLFIRCopyIn.cpp