1//===- OMPFunctionFiltering.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a transform that, when compiling for an OpenMP target
// device, filters out functions intended only for the host.
11//
12//===----------------------------------------------------------------------===//
13
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

#include <optional>
23
24namespace fir {
25#define GEN_PASS_DEF_OMPFUNCTIONFILTERING
26#include "flang/Optimizer/Transforms/Passes.h.inc"
27} // namespace fir
28
29using namespace mlir;
30
31namespace {
32class OMPFunctionFilteringPass
33 : public fir::impl::OMPFunctionFilteringBase<OMPFunctionFilteringPass> {
34public:
35 OMPFunctionFilteringPass() = default;
36
37 void runOnOperation() override {
38 MLIRContext *context = &getContext();
39 OpBuilder opBuilder(context);
40 auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
41 if (!op || !op.getIsTargetDevice())
42 return;
43
44 op->walk<WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
45 // Do not filter functions with target regions inside, because they have
46 // to be available for both host and device so that regular and reverse
47 // offloading can be supported.
48 bool hasTargetRegion =
49 funcOp
50 ->walk<WalkOrder::PreOrder>(
51 [&](omp::TargetOp) { return WalkResult::interrupt(); })
52 .wasInterrupted();
53
54 omp::DeclareTargetDeviceType declareType =
55 omp::DeclareTargetDeviceType::host;
56 auto declareTargetOp =
57 dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
58 if (declareTargetOp && declareTargetOp.isDeclareTarget())
59 declareType = declareTargetOp.getDeclareTargetDeviceType();
60
61 // Filtering a function here means deleting it if it doesn't contain a
62 // target region. Else we explicitly set the omp.declare_target
63 // attribute. The second stage of function filtering at the MLIR to LLVM
64 // IR translation level will remove functions that contain the target
65 // region from the generated llvm IR.
66 if (declareType == omp::DeclareTargetDeviceType::host) {
67 SymbolTable::UseRange funcUses = *funcOp.getSymbolUses(op);
68 for (SymbolTable::SymbolUse use : funcUses) {
69 Operation *callOp = use.getUser();
70 if (auto internalFunc = mlir::dyn_cast<func::FuncOp>(callOp)) {
71 // Do not delete internal procedures holding the symbol of their
72 // Fortran host procedure as attribute.
73 internalFunc->removeAttr(fir::getHostSymbolAttrName());
74 // Set public visibility so that the function is not deleted by MLIR
75 // because unused. Changing it is OK here because the function will
76 // be deleted anyway in the second filtering phase.
77 internalFunc.setVisibility(mlir::SymbolTable::Visibility::Public);
78 continue;
79 }
80 // If the callOp has users then replace them with Undef values.
81 if (!callOp->use_empty()) {
82 SmallVector<Value> undefResults;
83 for (Value res : callOp->getResults()) {
84 opBuilder.setInsertionPoint(callOp);
85 undefResults.emplace_back(
86 opBuilder.create<fir::UndefOp>(res.getLoc(), res.getType()));
87 }
88 callOp->replaceAllUsesWith(undefResults);
89 }
90 // Remove the callOp
91 callOp->erase();
92 }
93 if (!hasTargetRegion) {
94 funcOp.erase();
95 return WalkResult::skip();
96 }
97 if (declareTargetOp)
98 declareTargetOp.setDeclareTarget(declareType,
99 omp::DeclareTargetCaptureClause::to);
100 }
101 return WalkResult::advance();
102 });
103 }
104};
105} // namespace
106
107std::unique_ptr<Pass> fir::createOMPFunctionFilteringPass() {
108 return std::make_unique<OMPFunctionFilteringPass>();
109}
110

source code of flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp