//===- SelectObjectAttr.cpp - Implements the SelectObject attribute ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the `OffloadingLLVMTranslationAttrInterface` for the
// `SelectObject` attribute.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace mlir;

namespace {
// Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
class SelectObjectAttrImpl
    : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
          SelectObjectAttrImpl> {
  // Returns the selected object for embedding.
  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;

public:
  // Translates a `gpu.binary`, embedding the binary into a host LLVM module as
  // a global binary string; the string is loaded into a global module object
  // by a global ctor and unloaded by a global dtor.
  LogicalResult embedBinary(Attribute attribute, Operation *operation,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) const;

  // Translates a `gpu.launch_func` to a sequence of LLVM instructions
  // resulting in a kernel launch call.
  LogicalResult launchKernel(Attribute attribute,
                             Operation *launchFuncOperation,
                             Operation *binaryOperation,
                             llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) const;
};
} // namespace
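// For illustration (schematic, not copied from any test): a `gpu.binary` that
// carries this offloading handler looks roughly like
//
//   gpu.binary @kernels <#gpu.select_object<1>>
//       [#gpu.object<#nvvm.target, "...">, #gpu.object<#rocdl.target, "...">]
//
// The handler selects an object either by index (as above) or by a target
// attribute such as `#nvvm.target`; a null target selects the first object.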
gpu::ObjectAttr
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();

  // Obtain the index of the object to select.
  int64_t index = -1;
  if (Attribute target =
          cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
              .getTarget()) {
    // If the target attribute is a number, it is the index. Otherwise, compare
    // the attribute against every target inside the object array to find the
    // index.
    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      index = indexAttr.getInt();
    } else {
      for (auto [i, attr] : llvm::enumerate(objects)) {
        auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
        if (obj.getTarget() == target) {
          index = i;
        }
      }
    }
  } else {
    // If the target attribute is null, select the first object in the object
    // array.
    index = 0;
  }

  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
    op->emitError("the requested target object couldn't be found");
    return nullptr;
  }
  return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
}

static Twine getModuleIdentifier(StringRef moduleName) {
  return moduleName + "_module";
}

namespace llvm {
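// Shape of the host IR emitted below for a GPU module named `<name>`
// (illustrative only; the authoritative behavior is the code that follows):
//
//   @<name>_binary = internal constant [N x i8] c"<serialized object>"
//   @<name>_module = internal global ptr null
//   define internal void @<name>_load() section ".text.startup" {
//     %0 = call ptr @mgpuModuleLoad(ptr @<name>_binary, i64 N)
//     store ptr %0, ptr @<name>_module
//     ret void
//   }
//   define internal void @<name>_unload() section ".text.startup" {
//     %0 = load ptr, ptr @<name>_module
//     call void @mgpuModuleUnload(ptr %0)
//     ret void
//   }
//
// Both functions are registered through llvm.global_ctors / llvm.global_dtors
// with priority 123; assembly (JIT) objects call mgpuModuleLoadJIT instead.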
static LogicalResult embedBinaryImpl(StringRef moduleName,
                                     gpu::ObjectAttr object, Module &module) {

  // Embed the object as a global string.
  // Add null for assembly output for JIT paths that expect null-terminated
  // strings.
  bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
  StringRef serializedStr = object.getObject().getValue();
  Constant *serializedCst =
      ConstantDataArray::getString(module.getContext(), serializedStr, addNull);
  GlobalVariable *serializedObj =
      new GlobalVariable(module, serializedCst->getType(), true,
                         GlobalValue::LinkageTypes::InternalLinkage,
                         serializedCst, moduleName + "_binary");
  serializedObj->setAlignment(MaybeAlign(8));
  serializedObj->setUnnamedAddr(GlobalValue::UnnamedAddr::None);

  // Default JIT optimization level.
  auto optLevel = APInt::getZero(32);

  if (DictionaryAttr objectProps = object.getProperties()) {
    if (auto section = dyn_cast_or_null<StringAttr>(
            objectProps.get(gpu::elfSectionName))) {
      serializedObj->setSection(section.getValue());
    }
    // Check if there's an optimization level embedded in the object.
    if (auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get("O")))
      optLevel = optAttr.getValue();
  }

  IRBuilder<> builder(module.getContext());
  auto i32Ty = builder.getInt32Ty();
  auto i64Ty = builder.getInt64Ty();
  auto ptrTy = builder.getPtrTy(0);
  auto voidTy = builder.getVoidTy();

  // Embed the module as a global object.
  auto *modulePtr = new GlobalVariable(
      module, ptrTy, /*isConstant=*/false, GlobalValue::InternalLinkage,
      /*Initializer=*/ConstantPointerNull::get(ptrTy),
      getModuleIdentifier(moduleName));

  auto *loadFn = Function::Create(FunctionType::get(voidTy, /*IsVarArg=*/false),
                                  GlobalValue::InternalLinkage,
                                  moduleName + "_load", module);
  loadFn->setSection(".text.startup");
  auto *loadBlock = BasicBlock::Create(module.getContext(), "entry", loadFn);
  builder.SetInsertPoint(loadBlock);
  Value *moduleObj = [&] {
    if (object.getFormat() == gpu::CompilationTarget::Assembly) {
      FunctionCallee moduleLoadFn = module.getOrInsertFunction(
          "mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false));
      Constant *optValue = ConstantInt::get(i32Ty, optLevel);
      return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
    } else {
      FunctionCallee moduleLoadFn = module.getOrInsertFunction(
          "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
      Constant *binarySize =
          ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
      return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
    }
  }();
  builder.CreateStore(moduleObj, modulePtr);
  builder.CreateRetVoid();
  appendToGlobalCtors(module, loadFn, /*Priority=*/123);

  auto *unloadFn = Function::Create(
      FunctionType::get(voidTy, /*IsVarArg=*/false),
      GlobalValue::InternalLinkage, moduleName + "_unload", module);
  unloadFn->setSection(".text.startup");
  auto *unloadBlock =
      BasicBlock::Create(module.getContext(), "entry", unloadFn);
  builder.SetInsertPoint(unloadBlock);
  FunctionCallee moduleUnloadFn = module.getOrInsertFunction(
      "mgpuModuleUnload", FunctionType::get(voidTy, ptrTy, false));
  builder.CreateCall(moduleUnloadFn, builder.CreateLoad(ptrTy, modulePtr));
  builder.CreateRetVoid();
  appendToGlobalDtors(module, unloadFn, /*Priority=*/123);

  return success();
}
} // namespace llvm

LogicalResult SelectObjectAttrImpl::embedBinary(
    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(operation && "The binary operation must be non null.");
  if (!operation)
    return failure();

  auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
  if (!op) {
    operation->emitError("operation must be a GPU binary");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(op);
  if (!object)
    return failure();

  return embedBinaryImpl(op.getName(), object,
                         *moduleTranslation.getLLVMModule());
}

namespace llvm {
namespace {
class LaunchKernel {
public:
  LaunchKernel(Module &module, IRBuilderBase &builder,
               mlir::LLVM::ModuleTranslation &moduleTranslation);
  // Get the kernel launch callee.
  FunctionCallee getKernelLaunchFn();

  // Get the cluster kernel launch callee.
  FunctionCallee getClusterKernelLaunchFn();

  // Get the module function callee.
  FunctionCallee getModuleFunctionFn();

  // Get the stream create callee.
  FunctionCallee getStreamCreateFn();

  // Get the stream destroy callee.
  FunctionCallee getStreamDestroyFn();

  // Get the stream sync callee.
  FunctionCallee getStreamSyncFn();

  // Get or create the function name global string.
  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);

  // Create the void* kernel array for passing the arguments.
  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);

  // Create the full kernel launch.
  llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                         mlir::gpu::ObjectAttr object);

private:
  Module &module;
  IRBuilderBase &builder;
  mlir::LLVM::ModuleTranslation &moduleTranslation;
  Type *i32Ty{};
  Type *i64Ty{};
  Type *voidTy{};
  Type *intPtrTy{};
  PointerType *ptrTy{};
};
} // namespace
} // namespace llvm

LogicalResult SelectObjectAttrImpl::launchKernel(
    Attribute attribute, Operation *launchFuncOperation,
    Operation *binaryOperation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {

  assert(launchFuncOperation && "The launch func operation must be non null.");
  if (!launchFuncOperation)
    return failure();

  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
  if (!launchFuncOp) {
    launchFuncOperation->emitError("operation must be a GPU launch func Op.");
    return failure();
  }

  auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
  if (!binOp) {
    binaryOperation->emitError("operation must be a GPU binary.");
    return failure();
  }
  gpu::ObjectAttr object = getSelectedObject(binOp);
  if (!object)
    return failure();

  return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
                            moduleTranslation)
      .createKernelLaunch(launchFuncOp, object);
}

llvm::LaunchKernel::LaunchKernel(
    Module &module, IRBuilderBase &builder,
    mlir::LLVM::ModuleTranslation &moduleTranslation)
    : module(module), builder(builder), moduleTranslation(moduleTranslation) {
  i32Ty = builder.getInt32Ty();
  i64Ty = builder.getInt64Ty();
  ptrTy = builder.getPtrTy(0);
  voidTy = builder.getVoidTy();
  intPtrTy = builder.getIntPtrTy(module.getDataLayout());
}

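// The mgpu* callees below are expected to be provided by the MLIR GPU runtime
// wrappers (e.g. CudaRuntimeWrappers.cpp / RocmRuntimeWrappers.cpp under
// mlir/lib/ExecutionEngine). As a rough sketch, the launch entry point has a C
// signature corresponding to the LLVM types used here (parameter names are
// illustrative, not taken from the wrappers):
//
//   void mgpuLaunchKernel(void *function, intptr_t gridX, intptr_t gridY,
//                         intptr_t gridZ, intptr_t blockX, intptr_t blockY,
//                         intptr_t blockZ, int32_t smem, void *stream,
//                         void **params, void **extra, size_t paramCount);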
llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchKernel",
      FunctionType::get(voidTy,
                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy, i32Ty,
                                          ptrTy, ptrTy, ptrTy, i64Ty}),
                        false));
}

llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchClusterKernel",
      FunctionType::get(
          voidTy,
          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                            intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                            i32Ty, ptrTy, ptrTy, ptrTy}),
          false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
  return module.getOrInsertFunction(
      "mgpuModuleGetFunction",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
  return module.getOrInsertFunction("mgpuStreamCreate",
                                    FunctionType::get(ptrTy, false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
  return module.getOrInsertFunction(
      "mgpuStreamDestroy",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
  return module.getOrInsertFunction(
      "mgpuStreamSynchronize",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

// Generates an LLVM IR global that contains the name of the given kernel
// function as a C string, and returns a pointer to its beginning.
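// For example (illustrative): kernel `foo` in GPU module `kernels` yields a
// global named `kernels_foo_name` holding the null-terminated string "foo".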
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
                                                         StringRef kernelName) {
  std::string globalName =
      std::string(formatv("{0}_{1}_name", moduleName, kernelName));

  if (GlobalVariable *gv = module.getGlobalVariable(globalName, true))
    return gv;

  return builder.CreateGlobalString(kernelName, globalName);
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
llvm::Value *
llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
  SmallVector<Value *> args =
      moduleTranslation.lookupValues(op.getKernelOperands());
  SmallVector<Type *> structTypes(args.size(), nullptr);

  for (auto [i, arg] : llvm::enumerate(args))
    structTypes[i] = arg->getType();

  Type *structTy = StructType::create(module.getContext(), structTypes);
  Value *argStruct = builder.CreateAlloca(structTy, 0u);
  Value *argArray = builder.CreateAlloca(
      ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));

  for (auto [i, arg] : enumerate(args)) {
    Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
    builder.CreateStore(arg, structMember);
    Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
    builder.CreateStore(structMember, arrayMember);
  }
  return argArray;
}

// Emits LLVM IR to launch a kernel function:
// %1 = load %global_module_object
// %2 = call @mgpuModuleGetFunction(%1, %global_kernel_name)
// %3 = call @mgpuStreamCreate()
// %4 = <see createKernelArgArray()>
// call @mgpuLaunchKernel(%2, ..., %3, %4, ...)
// call @mgpuStreamSynchronize(%3)
// call @mgpuStreamDestroy(%3)
llvm::LogicalResult
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                       mlir::gpu::ObjectAttr object) {
  auto llvmValue = [&](mlir::Value value) -> Value * {
    Value *v = moduleTranslation.lookupValue(value);
    assert(v && "Value has not been translated.");
    return v;
  };

  // Get grid dimensions.
  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
        *gz = llvmValue(grid.z);

  // Get block dimensions.
  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
        *bz = llvmValue(block.z);

  // Get dynamic shared memory size.
  Value *dynamicMemorySize = nullptr;
  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
    dynamicMemorySize = llvmValue(dynSz);
  else
    dynamicMemorySize = ConstantInt::get(i32Ty, 0);

  // Create the argument array.
  Value *argArray = createKernelArgArray(op);

  // Load the kernel function.
  StringRef moduleName = op.getKernelModuleName().getValue();
  Twine moduleIdentifier = getModuleIdentifier(moduleName);
  Value *modulePtr = module.getGlobalVariable(moduleIdentifier.str(), true);
  if (!modulePtr)
    return op.emitError() << "Couldn't find the binary: " << moduleIdentifier;
  Value *moduleObj = builder.CreateLoad(ptrTy, modulePtr);
  Value *functionName = getOrCreateFunctionName(moduleName, op.getKernelName());
  Value *moduleFunction =
      builder.CreateCall(getModuleFunctionFn(), {moduleObj, functionName});

  // Get the stream to use for execution. If there's no async object then
  // create a stream to make a synchronous kernel launch.
  Value *stream = nullptr;
  // Sync & destroy the stream, for synchronous launches.
  auto destroyStream = make_scope_exit([&]() {
    builder.CreateCall(getStreamSyncFn(), {stream});
    builder.CreateCall(getStreamDestroyFn(), {stream});
  });
  if (mlir::Value asyncObject = op.getAsyncObject()) {
    stream = llvmValue(asyncObject);
    destroyStream.release();
  } else {
    stream = builder.CreateCall(getStreamCreateFn(), {});
  }

  llvm::Constant *paramsCount =
      llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());

  // Create the launch call.
  Value *nullPtr = ConstantPointerNull::get(ptrTy);

  // Launch kernel with clusters if cluster size is specified.
  if (op.hasClusterSize()) {
    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
          *cz = llvmValue(cluster.z);
    builder.CreateCall(
        getClusterKernelLaunchFn(),
        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
                           dynamicMemorySize, stream, argArray, nullPtr}));
  } else {
    builder.CreateCall(getKernelLaunchFn(),
                       ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
                                          bz, dynamicMemorySize, stream,
                                          argArray, nullPtr, paramsCount}));
  }

  return success();
}
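// Registers the external model so `#gpu.select_object` attributes pick up this
// translation. Typical client usage (illustrative):
//
//   DialectRegistry registry;
//   gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);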
void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
  });
}