1#include "flang/Optimizer/Transforms/Passes.h"
2#include "mlir/Dialect/Func/IR/FuncOps.h"
3#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
4#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
5#include "mlir/IR/BuiltinDialect.h"
6#include "mlir/IR/BuiltinOps.h"
7#include "mlir/IR/Operation.h"
8#include "mlir/IR/SymbolTable.h"
9#include "mlir/Pass/Pass.h"
10#include "mlir/Support/LLVM.h"
11#include "llvm/ADT/SmallPtrSet.h"
12
13namespace fir {
14#define GEN_PASS_DEF_OMPMARKDECLARETARGETPASS
15#include "flang/Optimizer/Transforms/Passes.h.inc"
16} // namespace fir
17
18namespace {
19class OMPMarkDeclareTargetPass
20 : public fir::impl::OMPMarkDeclareTargetPassBase<OMPMarkDeclareTargetPass> {
21
22 void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
23 mlir::omp::DeclareTargetCaptureClause parentCapClause,
24 mlir::Operation *currOp,
25 llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
26 if (visited.contains(currOp))
27 return;
28 visited.insert(currOp);
29
30 currOp->walk([&, this](mlir::Operation *op) {
31 if (auto callOp = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
32 if (auto symRef = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
33 callOp.getCallableForCallee())) {
34 if (auto currFOp =
35 getOperation().lookupSymbol<mlir::func::FuncOp>(symRef)) {
36 auto current = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
37 currFOp.getOperation());
38
39 if (current.isDeclareTarget()) {
40 auto currentDt = current.getDeclareTargetDeviceType();
41
42 // Found the same function twice, with different device_types,
43 // mark as Any as it belongs to both
44 if (currentDt != parentDevTy &&
45 currentDt != mlir::omp::DeclareTargetDeviceType::any) {
46 current.setDeclareTarget(
47 mlir::omp::DeclareTargetDeviceType::any,
48 current.getDeclareTargetCaptureClause());
49 }
50 } else {
51 current.setDeclareTarget(parentDevTy, parentCapClause);
52 }
53
54 markNestedFuncs(parentDevTy, parentCapClause, currFOp, visited);
55 }
56 }
57 }
58 });
59 }
60
61 // This pass executes on mlir::ModuleOp's marking functions contained within
62 // as implicitly declare target if they are called from within an explicitly
63 // marked declare target function or a target region (TargetOp)
64 void runOnOperation() override {
65 for (auto functionOp : getOperation().getOps<mlir::func::FuncOp>()) {
66 auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
67 functionOp.getOperation());
68 if (declareTargetOp.isDeclareTarget()) {
69 llvm::SmallPtrSet<mlir::Operation *, 16> visited;
70 markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
71 declareTargetOp.getDeclareTargetCaptureClause(),
72 functionOp, visited);
73 }
74 }
75
76 // TODO: Extend to work with reverse-offloading, this shouldn't
77 // require too much effort, just need to check the device clause
78 // when it's lowering has been implemented and change the
79 // DeclareTargetDeviceType argument from nohost to host depending on
80 // the contents of the device clause
81 getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
82 llvm::SmallPtrSet<mlir::Operation *, 16> visited;
83 markNestedFuncs(mlir::omp::DeclareTargetDeviceType::nohost,
84 mlir::omp::DeclareTargetCaptureClause::to, tarOp,
85 visited);
86 });
87 }
88};
89
90} // namespace
91
92namespace fir {
93std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
94createOMPMarkDeclareTargetPass() {
95 return std::make_unique<OMPMarkDeclareTargetPass>();
96}
97} // namespace fir
98

// Source: flang/lib/Optimizer/Transforms/OMPMarkDeclareTarget.cpp