1 | //===- LegalizeDataValues.cpp - -------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/OpenACC/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
13 | #include "mlir/IR/Dominance.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Transforms/RegionUtils.h" |
16 | #include "llvm/Support/ErrorHandling.h" |
17 | |
18 | namespace mlir { |
19 | namespace acc { |
20 | #define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION |
21 | #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" |
22 | } // namespace acc |
23 | } // namespace mlir |
24 | |
25 | using namespace mlir; |
26 | |
27 | namespace { |
28 | |
29 | static bool insideAccComputeRegion(mlir::Operation *op) { |
30 | mlir::Operation *parent{op->getParentOp()}; |
31 | while (parent) { |
32 | if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) { |
33 | return true; |
34 | } |
35 | parent = parent->getParentOp(); |
36 | } |
37 | return false; |
38 | } |
39 | |
40 | static void collectVars(mlir::ValueRange operands, |
41 | llvm::SmallVector<std::pair<Value, Value>> &values, |
42 | bool hostToDevice) { |
43 | for (auto operand : operands) { |
44 | Value var = acc::getVar(accDataClauseOp: operand.getDefiningOp()); |
45 | Value accVar = acc::getAccVar(accDataClauseOp: operand.getDefiningOp()); |
46 | if (var && accVar) { |
47 | if (hostToDevice) |
48 | values.push_back({var, accVar}); |
49 | else |
50 | values.push_back({accVar, var}); |
51 | } |
52 | } |
53 | } |
54 | |
55 | template <typename Op> |
56 | static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement, |
57 | Region &outerRegion) { |
58 | for (auto &use : llvm::make_early_inc_range(orig.getUses())) { |
59 | if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) { |
60 | if constexpr (std::is_same_v<Op, acc::DataOp> || |
61 | std::is_same_v<Op, acc::DeclareOp>) { |
62 | // For data construct regions, only replace uses in contained compute |
63 | // regions. |
64 | if (insideAccComputeRegion(use.getOwner())) { |
65 | use.set(replacement); |
66 | } |
67 | } else { |
68 | use.set(replacement); |
69 | } |
70 | } |
71 | } |
72 | } |
73 | |
74 | template <typename Op> |
75 | static void replaceAllUsesInUnstructuredComputeRegionWith( |
76 | Op &op, llvm::SmallVector<std::pair<Value, Value>> &values, |
77 | DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) { |
78 | |
79 | SmallVector<Operation *> exitOps; |
80 | if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) { |
81 | // For declare enter/exit pairs, collect all exit ops |
82 | for (auto *user : op.getToken().getUsers()) { |
83 | if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user)) |
84 | exitOps.push_back(declareExit); |
85 | } |
86 | if (exitOps.empty()) |
87 | return; |
88 | } |
89 | |
90 | for (auto p : values) { |
91 | Value hostVal = std::get<0>(p); |
92 | Value deviceVal = std::get<1>(p); |
93 | for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) { |
94 | Operation *owner = use.getOwner(); |
95 | |
96 | // Check It's the case that the acc entry operation dominates the use. |
97 | if (!domInfo.dominates(op.getOperation(), owner)) |
98 | continue; |
99 | |
100 | // Check It's the case that at least one of the acc exit operations |
101 | // post-dominates the use |
102 | bool hasPostDominatingExit = false; |
103 | for (auto *exit : exitOps) { |
104 | if (postDomInfo.postDominates(exit, owner)) { |
105 | hasPostDominatingExit = true; |
106 | break; |
107 | } |
108 | } |
109 | |
110 | if (!hasPostDominatingExit) |
111 | continue; |
112 | |
113 | if (insideAccComputeRegion(owner)) |
114 | use.set(deviceVal); |
115 | } |
116 | } |
117 | } |
118 | |
119 | template <typename Op> |
120 | static void |
121 | collectAndReplaceInRegion(Op &op, bool hostToDevice, |
122 | DominanceInfo *domInfo = nullptr, |
123 | PostDominanceInfo *postDomInfo = nullptr) { |
124 | llvm::SmallVector<std::pair<Value, Value>> values; |
125 | |
126 | if constexpr (std::is_same_v<Op, acc::LoopOp>) { |
127 | collectVars(op.getReductionOperands(), values, hostToDevice); |
128 | collectVars(op.getPrivateOperands(), values, hostToDevice); |
129 | } else { |
130 | collectVars(op.getDataClauseOperands(), values, hostToDevice); |
131 | if constexpr (!std::is_same_v<Op, acc::KernelsOp> && |
132 | !std::is_same_v<Op, acc::DataOp> && |
133 | !std::is_same_v<Op, acc::DeclareOp> && |
134 | !std::is_same_v<Op, acc::HostDataOp> && |
135 | !std::is_same_v<Op, acc::DeclareEnterOp>) { |
136 | collectVars(op.getReductionOperands(), values, hostToDevice); |
137 | collectVars(op.getPrivateOperands(), values, hostToDevice); |
138 | collectVars(op.getFirstprivateOperands(), values, hostToDevice); |
139 | } |
140 | } |
141 | |
142 | if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) { |
143 | assert(domInfo && postDomInfo && |
144 | "Dominance info required for DeclareEnterOp" ); |
145 | replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo, |
146 | *postDomInfo); |
147 | } else { |
148 | for (auto p : values) { |
149 | replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p), |
150 | op.getRegion()); |
151 | } |
152 | } |
153 | } |
154 | |
155 | class LegalizeDataValuesInRegion |
156 | : public acc::impl::LegalizeDataValuesInRegionBase< |
157 | LegalizeDataValuesInRegion> { |
158 | public: |
159 | using LegalizeDataValuesInRegionBase< |
160 | LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase; |
161 | |
162 | void runOnOperation() override { |
163 | func::FuncOp funcOp = getOperation(); |
164 | bool replaceHostVsDevice = this->hostToDevice.getValue(); |
165 | |
166 | // Initialize dominance info |
167 | DominanceInfo domInfo; |
168 | PostDominanceInfo postDomInfo; |
169 | bool computedDomInfo = false; |
170 | |
171 | funcOp.walk([&](Operation *op) { |
172 | if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) && |
173 | !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) && |
174 | applyToAccDataConstruct) && |
175 | !isa<acc::DeclareEnterOp>(*op)) |
176 | return; |
177 | |
178 | if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) { |
179 | collectAndReplaceInRegion(parallelOp, replaceHostVsDevice); |
180 | } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) { |
181 | collectAndReplaceInRegion(serialOp, replaceHostVsDevice); |
182 | } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) { |
183 | collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); |
184 | } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) { |
185 | collectAndReplaceInRegion(loopOp, replaceHostVsDevice); |
186 | } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) { |
187 | collectAndReplaceInRegion(dataOp, replaceHostVsDevice); |
188 | } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) { |
189 | collectAndReplaceInRegion(declareOp, replaceHostVsDevice); |
190 | } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) { |
191 | collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice); |
192 | } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) { |
193 | if (!computedDomInfo) { |
194 | domInfo = DominanceInfo(funcOp); |
195 | postDomInfo = PostDominanceInfo(funcOp); |
196 | computedDomInfo = true; |
197 | } |
198 | collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo, |
199 | &postDomInfo); |
200 | } else { |
201 | llvm_unreachable("unsupported acc region op" ); |
202 | } |
203 | }); |
204 | } |
205 | }; |
206 | |
207 | } // end anonymous namespace |
208 | |