1 | //===- LegalizeData.cpp - -------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/OpenACC/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
12 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
13 | #include "mlir/Pass/Pass.h" |
14 | #include "mlir/Transforms/RegionUtils.h" |
15 | |
16 | namespace mlir { |
17 | namespace acc { |
18 | #define GEN_PASS_DEF_LEGALIZEDATAINREGION |
19 | #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" |
20 | } // namespace acc |
21 | } // namespace mlir |
22 | |
23 | using namespace mlir; |
24 | |
25 | namespace { |
26 | |
27 | static void collectPtrs(mlir::ValueRange operands, |
28 | llvm::SmallVector<std::pair<Value, Value>> &values, |
29 | bool hostToDevice) { |
30 | for (auto operand : operands) { |
31 | Value varPtr = acc::getVarPtr(accDataClauseOp: operand.getDefiningOp()); |
32 | Value accPtr = acc::getAccPtr(accDataClauseOp: operand.getDefiningOp()); |
33 | if (varPtr && accPtr) { |
34 | if (hostToDevice) |
35 | values.push_back({varPtr, accPtr}); |
36 | else |
37 | values.push_back({accPtr, varPtr}); |
38 | } |
39 | } |
40 | } |
41 | |
42 | template <typename Op> |
43 | static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { |
44 | llvm::SmallVector<std::pair<Value, Value>> values; |
45 | |
46 | if constexpr (std::is_same_v<Op, acc::LoopOp>) { |
47 | collectPtrs(op.getReductionOperands(), values, hostToDevice); |
48 | collectPtrs(op.getPrivateOperands(), values, hostToDevice); |
49 | } else { |
50 | collectPtrs(op.getDataClauseOperands(), values, hostToDevice); |
51 | if constexpr (!std::is_same_v<Op, acc::KernelsOp>) { |
52 | collectPtrs(op.getReductionOperands(), values, hostToDevice); |
53 | collectPtrs(op.getGangPrivateOperands(), values, hostToDevice); |
54 | collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice); |
55 | } |
56 | } |
57 | |
58 | for (auto p : values) |
59 | replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion()); |
60 | } |
61 | |
62 | struct LegalizeDataInRegion |
63 | : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> { |
64 | |
65 | void runOnOperation() override { |
66 | func::FuncOp funcOp = getOperation(); |
67 | bool replaceHostVsDevice = this->hostToDevice.getValue(); |
68 | |
69 | funcOp.walk([&](Operation *op) { |
70 | if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op)) |
71 | return; |
72 | |
73 | if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) { |
74 | collectAndReplaceInRegion(parallelOp, replaceHostVsDevice); |
75 | } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) { |
76 | collectAndReplaceInRegion(serialOp, replaceHostVsDevice); |
77 | } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) { |
78 | collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); |
79 | } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) { |
80 | collectAndReplaceInRegion(loopOp, replaceHostVsDevice); |
81 | } |
82 | }); |
83 | } |
84 | }; |
85 | |
86 | } // end anonymous namespace |
87 | |
88 | std::unique_ptr<OperationPass<func::FuncOp>> |
89 | mlir::acc::createLegalizeDataInRegion() { |
90 | return std::make_unique<LegalizeDataInRegion>(); |
91 | } |
92 | |