1//===-- CUFDeviceGlobal.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Builder/CUFCommon.h"
10#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
11#include "flang/Optimizer/Dialect/FIRDialect.h"
12#include "flang/Optimizer/Dialect/FIROps.h"
13#include "flang/Optimizer/HLFIR/HLFIROps.h"
14#include "flang/Optimizer/Support/InternalNames.h"
15#include "flang/Runtime/CUDA/common.h"
16#include "flang/Runtime/allocatable.h"
17#include "flang/Support/Fortran.h"
18#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19#include "mlir/IR/SymbolTable.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "llvm/ADT/DenseSet.h"
23
24namespace fir {
25#define GEN_PASS_DEF_CUFDEVICEGLOBAL
26#include "flang/Optimizer/Transforms/Passes.h.inc"
27} // namespace fir
28
29namespace {
30
31static void processAddrOfOp(fir::AddrOfOp addrOfOp,
32 mlir::SymbolTable &symbolTable,
33 llvm::DenseSet<fir::GlobalOp> &candidates,
34 bool recurseInGlobal) {
35
36 // Check if there is a real use of the global.
37 if (addrOfOp.getOperation()->hasOneUse()) {
38 mlir::OpOperand &addrUse = *addrOfOp.getOperation()->getUses().begin();
39 if (mlir::isa<fir::DeclareOp>(addrUse.getOwner()) &&
40 addrUse.getOwner()->use_empty())
41 return;
42 }
43
44 if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
45 addrOfOp.getSymbol().getRootReference().getValue())) {
46 // TO DO: limit candidates to non-scalars. Scalars appear to have been
47 // folded in already.
48 if (recurseInGlobal)
49 globalOp.walk([&](fir::AddrOfOp op) {
50 processAddrOfOp(op, symbolTable, candidates, recurseInGlobal);
51 });
52 candidates.insert(globalOp);
53 }
54}
55
56static void processEmboxOp(fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable,
57 llvm::DenseSet<fir::GlobalOp> &candidates) {
58 if (auto recTy = mlir::dyn_cast<fir::RecordType>(
59 fir::unwrapRefType(emboxOp.getMemref().getType()))) {
60 if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
61 fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) {
62 if (!candidates.contains(globalOp)) {
63 globalOp.walk([&](fir::AddrOfOp op) {
64 processAddrOfOp(op, symbolTable, candidates,
65 /*recurseInGlobal=*/true);
66 });
67 candidates.insert(globalOp);
68 }
69 }
70 }
71}
72
73static void
74prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
75 mlir::SymbolTable &symbolTable,
76 llvm::DenseSet<fir::GlobalOp> &candidates) {
77 auto cudaProcAttr{
78 funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
79 if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) {
80 funcOp.walk([&](fir::AddrOfOp op) {
81 processAddrOfOp(op, symbolTable, candidates, /*recurseInGlobal=*/false);
82 });
83 funcOp.walk(
84 [&](fir::EmboxOp op) { processEmboxOp(op, symbolTable, candidates); });
85 }
86}
87
88class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
89public:
90 void runOnOperation() override {
91 mlir::Operation *op = getOperation();
92 mlir::ModuleOp mod = mlir::dyn_cast<mlir::ModuleOp>(op);
93 if (!mod)
94 return signalPassFailure();
95
96 llvm::DenseSet<fir::GlobalOp> candidates;
97 mlir::SymbolTable symTable(mod);
98 mod.walk([&](mlir::func::FuncOp funcOp) {
99 prepareImplicitDeviceGlobals(funcOp, symTable, candidates);
100 return mlir::WalkResult::advance();
101 });
102 mod.walk([&](cuf::KernelOp kernelOp) {
103 kernelOp.walk([&](fir::AddrOfOp addrOfOp) {
104 processAddrOfOp(addrOfOp, symTable, candidates,
105 /*recurseInGlobal=*/false);
106 });
107 });
108
109 // Copying the device global variable into the gpu module
110 mlir::SymbolTable parentSymTable(mod);
111 auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable);
112 if (!gpuMod)
113 return signalPassFailure();
114 mlir::SymbolTable gpuSymTable(gpuMod);
115 for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
116 if (cuf::isRegisteredDeviceGlobal(globalOp))
117 candidates.insert(globalOp);
118 }
119 for (auto globalOp : candidates) {
120 auto globalName{globalOp.getSymbol().getValue()};
121 if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
122 break;
123 }
124 gpuSymTable.insert(globalOp->clone());
125 }
126 }
127};
128} // namespace
129

source code of flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp