1 | //===-- CUFAddConstructor.cpp ---------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Builder/BoxValue.h" |
10 | #include "flang/Optimizer/Builder/CUFCommon.h" |
11 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
12 | #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" |
13 | #include "flang/Optimizer/Builder/Todo.h" |
14 | #include "flang/Optimizer/CodeGen/Target.h" |
15 | #include "flang/Optimizer/CodeGen/TypeConverter.h" |
16 | #include "flang/Optimizer/Dialect/CUF/CUFOps.h" |
17 | #include "flang/Optimizer/Dialect/FIRAttr.h" |
18 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
19 | #include "flang/Optimizer/Dialect/FIROps.h" |
20 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
21 | #include "flang/Optimizer/Dialect/FIRType.h" |
22 | #include "flang/Optimizer/Support/DataLayout.h" |
23 | #include "flang/Runtime/CUDA/registration.h" |
24 | #include "flang/Runtime/entry-names.h" |
25 | #include "mlir/Dialect/DLTI/DLTI.h" |
26 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
27 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
28 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
29 | #include "mlir/IR/Value.h" |
30 | #include "mlir/Pass/Pass.h" |
31 | #include "llvm/ADT/SmallVector.h" |
32 | |
33 | namespace fir { |
34 | #define GEN_PASS_DEF_CUFADDCONSTRUCTOR |
35 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
36 | } // namespace fir |
37 | |
38 | using namespace Fortran::runtime::cuda; |
39 | |
40 | namespace { |
41 | |
42 | static constexpr llvm::StringRef cudaFortranCtorName{ |
43 | "__cudaFortranConstructor" }; |
44 | |
45 | struct CUFAddConstructor |
46 | : public fir::impl::CUFAddConstructorBase<CUFAddConstructor> { |
47 | |
48 | void runOnOperation() override { |
49 | mlir::ModuleOp mod = getOperation(); |
50 | mlir::SymbolTable symTab(mod); |
51 | mlir::OpBuilder opBuilder{mod.getBodyRegion()}; |
52 | fir::FirOpBuilder builder(opBuilder, mod); |
53 | fir::KindMapping kindMap{fir::getKindMapping(mod)}; |
54 | builder.setInsertionPointToEnd(mod.getBody()); |
55 | mlir::Location loc = mod.getLoc(); |
56 | auto *ctx = mod.getContext(); |
57 | auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx); |
58 | auto idxTy = builder.getIndexType(); |
59 | auto funcTy = |
60 | mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false); |
61 | std::optional<mlir::DataLayout> dl = |
62 | fir::support::getOrSetMLIRDataLayout(mod, /*allowDefaultLayout=*/false); |
63 | if (!dl) { |
64 | mlir::emitError(mod.getLoc(), |
65 | "data layout attribute is required to perform " + |
66 | getName() + "pass" ); |
67 | } |
68 | |
69 | // Symbol reference to CUFRegisterAllocator. |
70 | builder.setInsertionPointToEnd(mod.getBody()); |
71 | auto registerFuncOp = builder.create<mlir::LLVM::LLVMFuncOp>( |
72 | loc, RTNAME_STRING(CUFRegisterAllocator), funcTy); |
73 | registerFuncOp.setVisibility(mlir::SymbolTable::Visibility::Private); |
74 | auto cufRegisterAllocatorRef = mlir::SymbolRefAttr::get( |
75 | mod.getContext(), RTNAME_STRING(CUFRegisterAllocator)); |
76 | builder.setInsertionPointToEnd(mod.getBody()); |
77 | |
78 | // Create the constructor function that call CUFRegisterAllocator. |
79 | auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName, |
80 | funcTy); |
81 | func.setLinkage(mlir::LLVM::Linkage::Internal); |
82 | builder.setInsertionPointToStart(func.addEntryBlock(builder)); |
83 | builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef); |
84 | |
85 | auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName); |
86 | if (gpuMod) { |
87 | auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx); |
88 | auto registeredMod = builder.create<cuf::RegisterModuleOp>( |
89 | loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName())); |
90 | |
91 | fir::LLVMTypeConverter typeConverter(mod, /*applyTBAA=*/false, |
92 | /*forceUnifiedTBAATree=*/false, *dl); |
93 | // Register kernels |
94 | for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) { |
95 | if (func.isKernel()) { |
96 | auto kernelName = mlir::SymbolRefAttr::get( |
97 | builder.getStringAttr(cudaDeviceModuleName), |
98 | {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())}); |
99 | builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod); |
100 | } |
101 | } |
102 | |
103 | // Register variables |
104 | for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) { |
105 | auto attr = globalOp.getDataAttrAttr(); |
106 | if (!attr) |
107 | continue; |
108 | |
109 | if (attr.getValue() == cuf::DataAttribute::Managed && |
110 | !mlir::isa<fir::BaseBoxType>(globalOp.getType())) |
111 | TODO(loc, "registration of non-allocatable managed variables" ); |
112 | |
113 | mlir::func::FuncOp func; |
114 | switch (attr.getValue()) { |
115 | case cuf::DataAttribute::Device: |
116 | case cuf::DataAttribute::Constant: |
117 | case cuf::DataAttribute::Managed: { |
118 | func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>( |
119 | loc, builder); |
120 | auto fTy = func.getFunctionType(); |
121 | |
122 | // Global variable name |
123 | std::string gblNameStr = globalOp.getSymbol().getValue().str(); |
124 | gblNameStr += '\0'; |
125 | mlir::Value gblName = fir::getBase( |
126 | fir::factory::createStringLiteral(builder, loc, gblNameStr)); |
127 | |
128 | // Global variable size |
129 | std::optional<uint64_t> size; |
130 | if (auto boxTy = |
131 | mlir::dyn_cast<fir::BaseBoxType>(globalOp.getType())) { |
132 | mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy); |
133 | size = dl->getTypeSizeInBits(structTy) / 8; |
134 | } |
135 | if (!size) { |
136 | size = fir::getTypeSizeAndAlignmentOrCrash(loc, globalOp.getType(), |
137 | *dl, kindMap) |
138 | .first; |
139 | } |
140 | auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size); |
141 | |
142 | // Global variable address |
143 | mlir::Value addr = builder.create<fir::AddrOfOp>( |
144 | loc, globalOp.resultType(), globalOp.getSymbol()); |
145 | |
146 | llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( |
147 | builder, loc, fTy, registeredMod, addr, gblName, sizeVal)}; |
148 | builder.create<fir::CallOp>(loc, func, args); |
149 | } break; |
150 | default: |
151 | break; |
152 | } |
153 | } |
154 | } |
155 | builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{}); |
156 | |
157 | // Create the llvm.global_ctor with the function. |
158 | // TODO: We might want to have a utility that retrieve it if already |
159 | // created and adds new functions. |
160 | builder.setInsertionPointToEnd(mod.getBody()); |
161 | llvm::SmallVector<mlir::Attribute> funcs; |
162 | funcs.push_back( |
163 | mlir::FlatSymbolRefAttr::get(mod.getContext(), func.getSymName())); |
164 | llvm::SmallVector<int> priorities; |
165 | llvm::SmallVector<mlir::Attribute> data; |
166 | priorities.push_back(0); |
167 | data.push_back(mlir::LLVM::ZeroAttr::get(mod.getContext())); |
168 | builder.create<mlir::LLVM::GlobalCtorsOp>( |
169 | mod.getLoc(), builder.getArrayAttr(funcs), |
170 | builder.getI32ArrayAttr(priorities), builder.getArrayAttr(data)); |
171 | } |
172 | }; |
173 | |
174 | } // end anonymous namespace |
175 | |