//===- LegalizeData.cpp - -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/OpenACC/OpenACC.h"
13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/RegionUtils.h"
15
16namespace mlir {
17namespace acc {
18#define GEN_PASS_DEF_LEGALIZEDATAINREGION
19#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
20} // namespace acc
21} // namespace mlir
22
23using namespace mlir;
24
25namespace {
26
27static void collectPtrs(mlir::ValueRange operands,
28 llvm::SmallVector<std::pair<Value, Value>> &values,
29 bool hostToDevice) {
30 for (auto operand : operands) {
31 Value varPtr = acc::getVarPtr(accDataClauseOp: operand.getDefiningOp());
32 Value accPtr = acc::getAccPtr(accDataClauseOp: operand.getDefiningOp());
33 if (varPtr && accPtr) {
34 if (hostToDevice)
35 values.push_back({varPtr, accPtr});
36 else
37 values.push_back({accPtr, varPtr});
38 }
39 }
40}
41
42template <typename Op>
43static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
44 llvm::SmallVector<std::pair<Value, Value>> values;
45
46 if constexpr (std::is_same_v<Op, acc::LoopOp>) {
47 collectPtrs(op.getReductionOperands(), values, hostToDevice);
48 collectPtrs(op.getPrivateOperands(), values, hostToDevice);
49 } else {
50 collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
51 if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
52 collectPtrs(op.getReductionOperands(), values, hostToDevice);
53 collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
54 collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
55 }
56 }
57
58 for (auto p : values)
59 replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
60}
61
62struct LegalizeDataInRegion
63 : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
64
65 void runOnOperation() override {
66 func::FuncOp funcOp = getOperation();
67 bool replaceHostVsDevice = this->hostToDevice.getValue();
68
69 funcOp.walk([&](Operation *op) {
70 if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
71 return;
72
73 if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
74 collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
75 } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
76 collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
77 } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
78 collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
79 } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
80 collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
81 }
82 });
83 }
84};
85
86} // end anonymous namespace
87
88std::unique_ptr<OperationPass<func::FuncOp>>
89mlir::acc::createLegalizeDataInRegion() {
90 return std::make_unique<LegalizeDataInRegion>();
91}
92

// Source: mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp