1 | //===- MarkDeclareTarget.cpp -------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Mark functions called from explicit target code as implicitly declare target. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/OpenMP/Passes.h" |
14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
16 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
17 | #include "mlir/IR/BuiltinDialect.h" |
18 | #include "mlir/IR/BuiltinOps.h" |
19 | #include "mlir/IR/Operation.h" |
20 | #include "mlir/IR/SymbolTable.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | #include "mlir/Support/LLVM.h" |
23 | #include "llvm/ADT/SmallPtrSet.h" |
24 | |
25 | namespace flangomp { |
26 | #define GEN_PASS_DEF_MARKDECLARETARGETPASS |
27 | #include "flang/Optimizer/OpenMP/Passes.h.inc" |
28 | } // namespace flangomp |
29 | |
30 | namespace { |
31 | class MarkDeclareTargetPass |
32 | : public flangomp::impl::MarkDeclareTargetPassBase<MarkDeclareTargetPass> { |
33 | |
34 | void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy, |
35 | mlir::omp::DeclareTargetCaptureClause parentCapClause, |
36 | mlir::Operation *currOp, |
37 | llvm::SmallPtrSet<mlir::Operation *, 16> visited) { |
38 | if (visited.contains(currOp)) |
39 | return; |
40 | visited.insert(currOp); |
41 | |
42 | currOp->walk([&, this](mlir::Operation *op) { |
43 | if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) { |
44 | if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>( |
45 | callOp.getCallableForCallee())) { |
46 | if (auto currFOp = |
47 | getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) { |
48 | auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( |
49 | currFOp.getOperation()); |
50 | |
51 | if (current.isDeclareTarget()) { |
52 | auto currentDt = current.getDeclareTargetDeviceType(); |
53 | |
54 | // Found the same function twice, with different device_types, |
55 | // mark as Any as it belongs to both |
56 | if (currentDt != parentDevTy && |
57 | currentDt != mlir::omp::DeclareTargetDeviceType::any) { |
58 | current.setDeclareTarget( |
59 | mlir::omp::DeclareTargetDeviceType::any, |
60 | current.getDeclareTargetCaptureClause()); |
61 | } |
62 | } else { |
63 | current.setDeclareTarget(parentDevTy, parentCapClause); |
64 | } |
65 | |
66 | markNestedFuncs(parentDevTy, parentCapClause, currFOp, visited); |
67 | } |
68 | } |
69 | } |
70 | }); |
71 | } |
72 | |
73 | // This pass executes on mlir::ModuleOp's marking functions contained within |
74 | // as implicitly declare target if they are called from within an explicitly |
75 | // marked declare target function or a target region (TargetOp) |
76 | void runOnOperation() override { |
77 | for (auto functionOp : getOperation().getOps<mlir::func::FuncOp>()) { |
78 | auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( |
79 | functionOp.getOperation()); |
80 | if (declareTargetOp.isDeclareTarget()) { |
81 | llvm::SmallPtrSet<mlir::Operation *, 16> visited; |
82 | markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(), |
83 | declareTargetOp.getDeclareTargetCaptureClause(), |
84 | functionOp, visited); |
85 | } |
86 | } |
87 | |
88 | // TODO: Extend to work with reverse-offloading, this shouldn't |
89 | // require too much effort, just need to check the device clause |
90 | // when it's lowering has been implemented and change the |
91 | // DeclareTargetDeviceType argument from nohost to host depending on |
92 | // the contents of the device clause |
93 | getOperation()->walk([&](mlir::omp::TargetOp tarOp) { |
94 | llvm::SmallPtrSet<mlir::Operation *, 16> visited; |
95 | markNestedFuncs(mlir::omp::DeclareTargetDeviceType::nohost, |
96 | mlir::omp::DeclareTargetCaptureClause::to, tarOp, |
97 | visited); |
98 | }); |
99 | } |
100 | }; |
101 | |
102 | } // namespace |
103 | |