1 | #include "flang/Optimizer/Transforms/Passes.h" |
2 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
3 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
4 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
5 | #include "mlir/IR/BuiltinDialect.h" |
6 | #include "mlir/IR/BuiltinOps.h" |
7 | #include "mlir/IR/Operation.h" |
8 | #include "mlir/IR/SymbolTable.h" |
9 | #include "mlir/Pass/Pass.h" |
10 | #include "mlir/Support/LLVM.h" |
11 | #include "llvm/ADT/SmallPtrSet.h" |
12 | |
13 | namespace fir { |
14 | #define GEN_PASS_DEF_OMPMARKDECLARETARGETPASS |
15 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
16 | } // namespace fir |
17 | |
18 | namespace { |
19 | class OMPMarkDeclareTargetPass |
20 | : public fir::impl::OMPMarkDeclareTargetPassBase<OMPMarkDeclareTargetPass> { |
21 | |
22 | void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy, |
23 | mlir::omp::DeclareTargetCaptureClause parentCapClause, |
24 | mlir::Operation *currOp, |
25 | llvm::SmallPtrSet<mlir::Operation *, 16> visited) { |
26 | if (visited.contains(currOp)) |
27 | return; |
28 | visited.insert(currOp); |
29 | |
30 | currOp->walk([&, this](mlir::Operation *op) { |
31 | if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) { |
32 | if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>( |
33 | callOp.getCallableForCallee())) { |
34 | if (auto currFOp = |
35 | getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) { |
36 | auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( |
37 | currFOp.getOperation()); |
38 | |
39 | if (current.isDeclareTarget()) { |
40 | auto currentDt = current.getDeclareTargetDeviceType(); |
41 | |
42 | // Found the same function twice, with different device_types, |
43 | // mark as Any as it belongs to both |
44 | if (currentDt != parentDevTy && |
45 | currentDt != mlir::omp::DeclareTargetDeviceType::any) { |
46 | current.setDeclareTarget( |
47 | mlir::omp::DeclareTargetDeviceType::any, |
48 | current.getDeclareTargetCaptureClause()); |
49 | } |
50 | } else { |
51 | current.setDeclareTarget(parentDevTy, parentCapClause); |
52 | } |
53 | |
54 | markNestedFuncs(parentDevTy, parentCapClause, currFOp, visited); |
55 | } |
56 | } |
57 | } |
58 | }); |
59 | } |
60 | |
61 | // This pass executes on mlir::ModuleOp's marking functions contained within |
62 | // as implicitly declare target if they are called from within an explicitly |
63 | // marked declare target function or a target region (TargetOp) |
64 | void runOnOperation() override { |
65 | for (auto functionOp : getOperation().getOps<mlir::func::FuncOp>()) { |
66 | auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( |
67 | functionOp.getOperation()); |
68 | if (declareTargetOp.isDeclareTarget()) { |
69 | llvm::SmallPtrSet<mlir::Operation *, 16> visited; |
70 | markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(), |
71 | declareTargetOp.getDeclareTargetCaptureClause(), |
72 | functionOp, visited); |
73 | } |
74 | } |
75 | |
76 | // TODO: Extend to work with reverse-offloading, this shouldn't |
77 | // require too much effort, just need to check the device clause |
78 | // when it's lowering has been implemented and change the |
79 | // DeclareTargetDeviceType argument from nohost to host depending on |
80 | // the contents of the device clause |
81 | getOperation()->walk([&](mlir::omp::TargetOp tarOp) { |
82 | llvm::SmallPtrSet<mlir::Operation *, 16> visited; |
83 | markNestedFuncs(mlir::omp::DeclareTargetDeviceType::nohost, |
84 | mlir::omp::DeclareTargetCaptureClause::to, tarOp, |
85 | visited); |
86 | }); |
87 | } |
88 | }; |
89 | |
90 | } // namespace |
91 | |
92 | namespace fir { |
93 | std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
94 | createOMPMarkDeclareTargetPass() { |
95 | return std::make_unique<OMPMarkDeclareTargetPass>(); |
96 | } |
97 | } // namespace fir |
98 | |