1//===- LegalizeDataValues.cpp - -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/OpenACC/OpenACC.h"
13#include "mlir/IR/Dominance.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/RegionUtils.h"
16#include "llvm/Support/ErrorHandling.h"
17
18namespace mlir {
19namespace acc {
20#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
21#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
22} // namespace acc
23} // namespace mlir
24
25using namespace mlir;
26
27namespace {
28
29static bool insideAccComputeRegion(mlir::Operation *op) {
30 mlir::Operation *parent{op->getParentOp()};
31 while (parent) {
32 if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
33 return true;
34 }
35 parent = parent->getParentOp();
36 }
37 return false;
38}
39
40static void collectVars(mlir::ValueRange operands,
41 llvm::SmallVector<std::pair<Value, Value>> &values,
42 bool hostToDevice) {
43 for (auto operand : operands) {
44 Value var = acc::getVar(accDataClauseOp: operand.getDefiningOp());
45 Value accVar = acc::getAccVar(accDataClauseOp: operand.getDefiningOp());
46 if (var && accVar) {
47 if (hostToDevice)
48 values.push_back({var, accVar});
49 else
50 values.push_back({accVar, var});
51 }
52 }
53}
54
55template <typename Op>
56static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
57 Region &outerRegion) {
58 for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
59 if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
60 if constexpr (std::is_same_v<Op, acc::DataOp> ||
61 std::is_same_v<Op, acc::DeclareOp>) {
62 // For data construct regions, only replace uses in contained compute
63 // regions.
64 if (insideAccComputeRegion(use.getOwner())) {
65 use.set(replacement);
66 }
67 } else {
68 use.set(replacement);
69 }
70 }
71 }
72}
73
74template <typename Op>
75static void replaceAllUsesInUnstructuredComputeRegionWith(
76 Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
77 DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78
79 SmallVector<Operation *> exitOps;
80 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81 // For declare enter/exit pairs, collect all exit ops
82 for (auto *user : op.getToken().getUsers()) {
83 if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84 exitOps.push_back(declareExit);
85 }
86 if (exitOps.empty())
87 return;
88 }
89
90 for (auto p : values) {
91 Value hostVal = std::get<0>(p);
92 Value deviceVal = std::get<1>(p);
93 for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
94 Operation *owner = use.getOwner();
95
96 // Check It's the case that the acc entry operation dominates the use.
97 if (!domInfo.dominates(op.getOperation(), owner))
98 continue;
99
100 // Check It's the case that at least one of the acc exit operations
101 // post-dominates the use
102 bool hasPostDominatingExit = false;
103 for (auto *exit : exitOps) {
104 if (postDomInfo.postDominates(exit, owner)) {
105 hasPostDominatingExit = true;
106 break;
107 }
108 }
109
110 if (!hasPostDominatingExit)
111 continue;
112
113 if (insideAccComputeRegion(owner))
114 use.set(deviceVal);
115 }
116 }
117}
118
119template <typename Op>
120static void
121collectAndReplaceInRegion(Op &op, bool hostToDevice,
122 DominanceInfo *domInfo = nullptr,
123 PostDominanceInfo *postDomInfo = nullptr) {
124 llvm::SmallVector<std::pair<Value, Value>> values;
125
126 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
127 collectVars(op.getReductionOperands(), values, hostToDevice);
128 collectVars(op.getPrivateOperands(), values, hostToDevice);
129 } else {
130 collectVars(op.getDataClauseOperands(), values, hostToDevice);
131 if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
132 !std::is_same_v<Op, acc::DataOp> &&
133 !std::is_same_v<Op, acc::DeclareOp> &&
134 !std::is_same_v<Op, acc::HostDataOp> &&
135 !std::is_same_v<Op, acc::DeclareEnterOp>) {
136 collectVars(op.getReductionOperands(), values, hostToDevice);
137 collectVars(op.getPrivateOperands(), values, hostToDevice);
138 collectVars(op.getFirstprivateOperands(), values, hostToDevice);
139 }
140 }
141
142 if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
143 assert(domInfo && postDomInfo &&
144 "Dominance info required for DeclareEnterOp");
145 replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
146 *postDomInfo);
147 } else {
148 for (auto p : values) {
149 replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
150 op.getRegion());
151 }
152 }
153}
154
155class LegalizeDataValuesInRegion
156 : public acc::impl::LegalizeDataValuesInRegionBase<
157 LegalizeDataValuesInRegion> {
158public:
159 using LegalizeDataValuesInRegionBase<
160 LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
161
162 void runOnOperation() override {
163 func::FuncOp funcOp = getOperation();
164 bool replaceHostVsDevice = this->hostToDevice.getValue();
165
166 // Initialize dominance info
167 DominanceInfo domInfo;
168 PostDominanceInfo postDomInfo;
169 bool computedDomInfo = false;
170
171 funcOp.walk([&](Operation *op) {
172 if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
173 !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
174 applyToAccDataConstruct) &&
175 !isa<acc::DeclareEnterOp>(*op))
176 return;
177
178 if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
179 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
180 } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
181 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
182 } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
183 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
184 } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
185 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
186 } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
187 collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
188 } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
189 collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
190 } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
191 collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
192 } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
193 if (!computedDomInfo) {
194 domInfo = DominanceInfo(funcOp);
195 postDomInfo = PostDominanceInfo(funcOp);
196 computedDomInfo = true;
197 }
198 collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
199 &postDomInfo);
200 } else {
201 llvm_unreachable("unsupported acc region op");
202 }
203 });
204 }
205};
206
207} // end anonymous namespace
208

// Source: mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp