1//===-- CUFAddConstructor.cpp ---------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Builder/BoxValue.h"
10#include "flang/Optimizer/Builder/CUFCommon.h"
11#include "flang/Optimizer/Builder/FIRBuilder.h"
12#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
13#include "flang/Optimizer/Builder/Todo.h"
14#include "flang/Optimizer/CodeGen/Target.h"
15#include "flang/Optimizer/CodeGen/TypeConverter.h"
16#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
17#include "flang/Optimizer/Dialect/FIRAttr.h"
18#include "flang/Optimizer/Dialect/FIRDialect.h"
19#include "flang/Optimizer/Dialect/FIROps.h"
20#include "flang/Optimizer/Dialect/FIROpsSupport.h"
21#include "flang/Optimizer/Dialect/FIRType.h"
22#include "flang/Optimizer/Support/DataLayout.h"
23#include "flang/Runtime/CUDA/registration.h"
24#include "flang/Runtime/entry-names.h"
25#include "mlir/Dialect/DLTI/DLTI.h"
26#include "mlir/Dialect/GPU/IR/GPUDialect.h"
27#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
28#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
29#include "mlir/IR/Value.h"
30#include "mlir/Pass/Pass.h"
31#include "llvm/ADT/SmallVector.h"
32
33namespace fir {
34#define GEN_PASS_DEF_CUFADDCONSTRUCTOR
35#include "flang/Optimizer/Transforms/Passes.h.inc"
36} // namespace fir
37
38using namespace Fortran::runtime::cuda;
39
40namespace {
41
42static constexpr llvm::StringRef cudaFortranCtorName{
43 "__cudaFortranConstructor"};
44
45struct CUFAddConstructor
46 : public fir::impl::CUFAddConstructorBase<CUFAddConstructor> {
47
48 void runOnOperation() override {
49 mlir::ModuleOp mod = getOperation();
50 mlir::SymbolTable symTab(mod);
51 mlir::OpBuilder opBuilder{mod.getBodyRegion()};
52 fir::FirOpBuilder builder(opBuilder, mod);
53 fir::KindMapping kindMap{fir::getKindMapping(mod)};
54 builder.setInsertionPointToEnd(mod.getBody());
55 mlir::Location loc = mod.getLoc();
56 auto *ctx = mod.getContext();
57 auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);
58 auto idxTy = builder.getIndexType();
59 auto funcTy =
60 mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false);
61 std::optional<mlir::DataLayout> dl =
62 fir::support::getOrSetMLIRDataLayout(mod, /*allowDefaultLayout=*/false);
63 if (!dl) {
64 mlir::emitError(mod.getLoc(),
65 "data layout attribute is required to perform " +
66 getName() + "pass");
67 }
68
69 // Symbol reference to CUFRegisterAllocator.
70 builder.setInsertionPointToEnd(mod.getBody());
71 auto registerFuncOp = builder.create<mlir::LLVM::LLVMFuncOp>(
72 loc, RTNAME_STRING(CUFRegisterAllocator), funcTy);
73 registerFuncOp.setVisibility(mlir::SymbolTable::Visibility::Private);
74 auto cufRegisterAllocatorRef = mlir::SymbolRefAttr::get(
75 mod.getContext(), RTNAME_STRING(CUFRegisterAllocator));
76 builder.setInsertionPointToEnd(mod.getBody());
77
78 // Create the constructor function that call CUFRegisterAllocator.
79 auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName,
80 funcTy);
81 func.setLinkage(mlir::LLVM::Linkage::Internal);
82 builder.setInsertionPointToStart(func.addEntryBlock(builder));
83 builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
84
85 auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
86 if (gpuMod) {
87 auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
88 auto registeredMod = builder.create<cuf::RegisterModuleOp>(
89 loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
90
91 fir::LLVMTypeConverter typeConverter(mod, /*applyTBAA=*/false,
92 /*forceUnifiedTBAATree=*/false, *dl);
93 // Register kernels
94 for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
95 if (func.isKernel()) {
96 auto kernelName = mlir::SymbolRefAttr::get(
97 builder.getStringAttr(cudaDeviceModuleName),
98 {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
99 builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
100 }
101 }
102
103 // Register variables
104 for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) {
105 auto attr = globalOp.getDataAttrAttr();
106 if (!attr)
107 continue;
108
109 if (attr.getValue() == cuf::DataAttribute::Managed &&
110 !mlir::isa<fir::BaseBoxType>(globalOp.getType()))
111 TODO(loc, "registration of non-allocatable managed variables");
112
113 mlir::func::FuncOp func;
114 switch (attr.getValue()) {
115 case cuf::DataAttribute::Device:
116 case cuf::DataAttribute::Constant:
117 case cuf::DataAttribute::Managed: {
118 func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
119 loc, builder);
120 auto fTy = func.getFunctionType();
121
122 // Global variable name
123 std::string gblNameStr = globalOp.getSymbol().getValue().str();
124 gblNameStr += '\0';
125 mlir::Value gblName = fir::getBase(
126 fir::factory::createStringLiteral(builder, loc, gblNameStr));
127
128 // Global variable size
129 std::optional<uint64_t> size;
130 if (auto boxTy =
131 mlir::dyn_cast<fir::BaseBoxType>(globalOp.getType())) {
132 mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
133 size = dl->getTypeSizeInBits(structTy) / 8;
134 }
135 if (!size) {
136 size = fir::getTypeSizeAndAlignmentOrCrash(loc, globalOp.getType(),
137 *dl, kindMap)
138 .first;
139 }
140 auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size);
141
142 // Global variable address
143 mlir::Value addr = builder.create<fir::AddrOfOp>(
144 loc, globalOp.resultType(), globalOp.getSymbol());
145
146 llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
147 builder, loc, fTy, registeredMod, addr, gblName, sizeVal)};
148 builder.create<fir::CallOp>(loc, func, args);
149 } break;
150 default:
151 break;
152 }
153 }
154 }
155 builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
156
157 // Create the llvm.global_ctor with the function.
158 // TODO: We might want to have a utility that retrieve it if already
159 // created and adds new functions.
160 builder.setInsertionPointToEnd(mod.getBody());
161 llvm::SmallVector<mlir::Attribute> funcs;
162 funcs.push_back(
163 mlir::FlatSymbolRefAttr::get(mod.getContext(), func.getSymName()));
164 llvm::SmallVector<int> priorities;
165 llvm::SmallVector<mlir::Attribute> data;
166 priorities.push_back(0);
167 data.push_back(mlir::LLVM::ZeroAttr::get(mod.getContext()));
168 builder.create<mlir::LLVM::GlobalCtorsOp>(
169 mod.getLoc(), builder.getArrayAttr(funcs),
170 builder.getI32ArrayAttr(priorities), builder.getArrayAttr(data));
171 }
172};
173
174} // end anonymous namespace
175

// Source: flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp