//===- ObjectHandler.cpp - Implements base ObjectManager attributes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the `OffloadingLLVMTranslationAttrInterface` for the
// `SelectObject` attribute.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace mlir;

namespace {
// Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
class SelectObjectAttrImpl
    : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
          SelectObjectAttrImpl> {
  // Returns the selected object for embedding.
  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;

public:
  // Translates a `gpu.binary`, embedding the binary into a host LLVM module
  // as a global binary string that is loaded/unloaded into a global module
  // object through a global ctor/dtor.
  LogicalResult embedBinary(Attribute attribute, Operation *operation,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) const;

  // Translates a `gpu.launch_func` to a sequence of LLVM instructions
  // resulting in a kernel launch call.
  LogicalResult launchKernel(Attribute attribute,
                             Operation *launchFuncOperation,
                             Operation *binaryOperation,
                             llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) const;
};
} // namespace

gpu::ObjectAttr
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();

  // Obtain the index of the object to select.
  int64_t index = -1;
  if (Attribute target =
          cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
              .getTarget()) {
    // If the target attribute is a number it is the index. Otherwise compare
    // the attribute to every target inside the object array to find the index.
    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      index = indexAttr.getInt();
    } else {
      for (auto [i, attr] : llvm::enumerate(objects)) {
        auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
        if (obj.getTarget() == target) {
          index = i;
        }
      }
    }
  } else {
    // If the target attribute is null then it's selecting the first object in
    // the object array.
    index = 0;
  }

  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
    op->emitError("the requested target object couldn't be found");
    return nullptr;
  }
  return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
}

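// Returns the name of the global that holds the loaded module object for a
// given GPU module. Both the binary embedding and the `gpu.launch_func`
// lowering use this helper so they agree on the symbol name.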
static Twine getModuleIdentifier(StringRef moduleName) {
  return moduleName + "_module";
}

namespace llvm {
static LogicalResult embedBinaryImpl(StringRef moduleName,
                                     gpu::ObjectAttr object, Module &module) {

  // Embed the object as a global string.
  // Add null for assembly output for JIT paths that expect null-terminated
  // strings.
  bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
  StringRef serializedStr = object.getObject().getValue();
  Constant *serializedCst =
      ConstantDataArray::getString(module.getContext(), serializedStr, addNull);
  GlobalVariable *serializedObj =
      new GlobalVariable(module, serializedCst->getType(), true,
                         GlobalValue::LinkageTypes::InternalLinkage,
                         serializedCst, moduleName + "_binary");
  serializedObj->setAlignment(MaybeAlign(8));
  serializedObj->setUnnamedAddr(GlobalValue::UnnamedAddr::None);

  // Default JIT optimization level.
  auto optLevel = APInt::getZero(32);

  if (DictionaryAttr objectProps = object.getProperties()) {
    if (auto section = dyn_cast_or_null<StringAttr>(
            objectProps.get(gpu::elfSectionName))) {
      serializedObj->setSection(section.getValue());
    }
    // Check if there's an optimization level embedded in the object.
    if (auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get("O")))
      optLevel = optAttr.getValue();
  }

  IRBuilder<> builder(module.getContext());
  auto i32Ty = builder.getInt32Ty();
  auto i64Ty = builder.getInt64Ty();
  auto ptrTy = builder.getPtrTy(0);
  auto voidTy = builder.getVoidTy();
  // Create a global that will hold the loaded module object; it is
  // initialized to null and set at runtime.
  auto *modulePtr = new GlobalVariable(
      module, ptrTy, /*isConstant=*/false, GlobalValue::InternalLinkage,
      /*Initializer=*/ConstantPointerNull::get(ptrTy),
      getModuleIdentifier(moduleName));

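  // Build a constructor that loads the serialized object at program startup
  // and stores the resulting module handle in the global defined above.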
  auto *loadFn = Function::Create(FunctionType::get(voidTy, /*IsVarArg=*/false),
                                  GlobalValue::InternalLinkage,
                                  moduleName + "_load", module);
  loadFn->setSection(".text.startup");
  auto *loadBlock = BasicBlock::Create(module.getContext(), "entry", loadFn);
  builder.SetInsertPoint(loadBlock);
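  // Assembly objects (e.g. PTX) are JIT-compiled by the runtime and take an
  // optimization level; other formats are loaded directly with their size.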
  Value *moduleObj = [&] {
    if (object.getFormat() == gpu::CompilationTarget::Assembly) {
      FunctionCallee moduleLoadFn = module.getOrInsertFunction(
          "mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false));
      Constant *optValue = ConstantInt::get(i32Ty, optLevel);
      return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
    } else {
      FunctionCallee moduleLoadFn = module.getOrInsertFunction(
          "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
      Constant *binarySize =
          ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
      return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
    }
  }();
  builder.CreateStore(moduleObj, modulePtr);
  builder.CreateRetVoid();
  appendToGlobalCtors(module, loadFn, /*Priority=*/123);

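  // Build a matching destructor that unloads the module at program exit.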
  auto *unloadFn = Function::Create(
      FunctionType::get(voidTy, /*IsVarArg=*/false),
      GlobalValue::InternalLinkage, moduleName + "_unload", module);
  unloadFn->setSection(".text.startup");
  auto *unloadBlock =
      BasicBlock::Create(module.getContext(), "entry", unloadFn);
  builder.SetInsertPoint(unloadBlock);
  FunctionCallee moduleUnloadFn = module.getOrInsertFunction(
      "mgpuModuleUnload", FunctionType::get(voidTy, ptrTy, false));
  builder.CreateCall(moduleUnloadFn, builder.CreateLoad(ptrTy, modulePtr));
  builder.CreateRetVoid();
  appendToGlobalDtors(module, unloadFn, /*Priority=*/123);

  return success();
}
} // namespace llvm

LogicalResult SelectObjectAttrImpl::embedBinary(
    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(operation && "The binary operation must be non null.");
  if (!operation)
    return failure();

  auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
  if (!op) {
    operation->emitError("operation must be a GPU binary");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(op);
  if (!object)
    return failure();

  return embedBinaryImpl(op.getName(), object,
                         *moduleTranslation.getLLVMModule());
}

namespace llvm {
namespace {
class LaunchKernel {
public:
  LaunchKernel(Module &module, IRBuilderBase &builder,
               mlir::LLVM::ModuleTranslation &moduleTranslation);
  // Get the kernel launch callee.
  FunctionCallee getKernelLaunchFn();

  // Get the cluster kernel launch callee.
  FunctionCallee getClusterKernelLaunchFn();

  // Get the module function callee.
  FunctionCallee getModuleFunctionFn();

  // Get the stream create callee.
  FunctionCallee getStreamCreateFn();

  // Get the stream destroy callee.
  FunctionCallee getStreamDestroyFn();

  // Get the stream sync callee.
  FunctionCallee getStreamSyncFn();

  // Get or create the function name global string.
  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);

  // Create the void* kernel array for passing the arguments.
  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);

  // Create the full kernel launch.
  llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                         mlir::gpu::ObjectAttr object);

private:
  Module &module;
  IRBuilderBase &builder;
  mlir::LLVM::ModuleTranslation &moduleTranslation;
  Type *i32Ty{};
  Type *i64Ty{};
  Type *voidTy{};
  Type *intPtrTy{};
  PointerType *ptrTy{};
};
} // namespace
} // namespace llvm

LogicalResult SelectObjectAttrImpl::launchKernel(
    Attribute attribute, Operation *launchFuncOperation,
    Operation *binaryOperation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {

  assert(launchFuncOperation && "The launch func operation must be non null.");
  if (!launchFuncOperation)
    return failure();

  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
  if (!launchFuncOp) {
    launchFuncOperation->emitError("operation must be a GPU launch func Op.");
    return failure();
  }

  auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
  if (!binOp) {
    binaryOperation->emitError("operation must be a GPU binary.");
    return failure();
  }
  gpu::ObjectAttr object = getSelectedObject(binOp);
  if (!object)
    return failure();

  return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
                            moduleTranslation)
      .createKernelLaunch(launchFuncOp, object);
}

llvm::LaunchKernel::LaunchKernel(
    Module &module, IRBuilderBase &builder,
    mlir::LLVM::ModuleTranslation &moduleTranslation)
    : module(module), builder(builder), moduleTranslation(moduleTranslation) {
  i32Ty = builder.getInt32Ty();
  i64Ty = builder.getInt64Ty();
  ptrTy = builder.getPtrTy(0);
  voidTy = builder.getVoidTy();
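  // The runtime wrappers take grid, block, and cluster dimensions as
  // pointer-sized integers, hence the data-layout dependent type.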
  intPtrTy = builder.getIntPtrTy(module.getDataLayout());
}

llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchKernel",
      FunctionType::get(voidTy,
                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy, i32Ty,
                                          ptrTy, ptrTy, ptrTy, i64Ty}),
                        false));
}

llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchClusterKernel",
      FunctionType::get(
          voidTy,
          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                            intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                            i32Ty, ptrTy, ptrTy, ptrTy}),
          false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
  return module.getOrInsertFunction(
      "mgpuModuleGetFunction",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
  return module.getOrInsertFunction("mgpuStreamCreate",
                                    FunctionType::get(ptrTy, false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
  return module.getOrInsertFunction(
      "mgpuStreamDestroy",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
  return module.getOrInsertFunction(
      "mgpuStreamSynchronize",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

// Generates an LLVM IR global that contains the name of the given kernel
// function as a C string, and returns a pointer to its beginning.
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
                                                         StringRef kernelName) {
  std::string globalName =
      std::string(formatv("{0}_{1}_name", moduleName, kernelName));

  if (GlobalVariable *gv = module.getGlobalVariable(globalName, true))
    return gv;

  return builder.CreateGlobalString(kernelName, globalName);
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
llvm::Value *
llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
  SmallVector<Value *> args =
      moduleTranslation.lookupValues(op.getKernelOperands());
  SmallVector<Type *> structTypes(args.size(), nullptr);

  for (auto [i, arg] : llvm::enumerate(args))
    structTypes[i] = arg->getType();

  Type *structTy = StructType::create(module.getContext(), structTypes);
  Value *argStruct = builder.CreateAlloca(structTy, 0u);
  Value *argArray = builder.CreateAlloca(
      ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));

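  // Store each argument into its struct field and record a type-erased
  // pointer to that field in the argument array.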
  for (auto [i, arg] : enumerate(args)) {
    Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
    builder.CreateStore(arg, structMember);
    Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
    builder.CreateStore(structMember, arrayMember);
  }
  return argArray;
}

// Emits LLVM IR to launch a kernel function:
// %1 = load %global_module_object
// %2 = call @mgpuModuleGetFunction(%1, %global_kernel_name)
// %3 = call @mgpuStreamCreate()
// %4 = <see createKernelArgArray()>
// call @mgpuLaunchKernel(%2, ..., %3, %4, ...)
// call @mgpuStreamSynchronize(%3)
// call @mgpuStreamDestroy(%3)
llvm::LogicalResult
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                       mlir::gpu::ObjectAttr object) {
  auto llvmValue = [&](mlir::Value value) -> Value * {
    Value *v = moduleTranslation.lookupValue(value);
    assert(v && "Value has not been translated.");
    return v;
  };

  // Get grid dimensions.
  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
        *gz = llvmValue(grid.z);

  // Get block dimensions.
  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
        *bz = llvmValue(block.z);

  // Get dynamic shared memory size.
  Value *dynamicMemorySize = nullptr;
  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
    dynamicMemorySize = llvmValue(dynSz);
  else
    dynamicMemorySize = ConstantInt::get(i32Ty, 0);

  // Create the argument array.
  Value *argArray = createKernelArgArray(op);

  // Load the kernel function from the module global created when the
  // `gpu.binary` op was embedded (see embedBinaryImpl).
  StringRef moduleName = op.getKernelModuleName().getValue();
  Twine moduleIdentifier = getModuleIdentifier(moduleName);
  Value *modulePtr = module.getGlobalVariable(moduleIdentifier.str(), true);
  if (!modulePtr)
    return op.emitError() << "Couldn't find the binary: " << moduleIdentifier;
  Value *moduleObj = builder.CreateLoad(ptrTy, modulePtr);
  Value *functionName =
      getOrCreateFunctionName(moduleName, op.getKernelName());
  Value *moduleFunction =
      builder.CreateCall(getModuleFunctionFn(), {moduleObj, functionName});

  // Get the stream to use for execution. If there's no async object then
  // create a stream to make a synchronous kernel launch.
  Value *stream = nullptr;
  // Sync & destroy the stream, for synchronous launches.
  auto destroyStream = make_scope_exit([&]() {
    builder.CreateCall(getStreamSyncFn(), {stream});
    builder.CreateCall(getStreamDestroyFn(), {stream});
  });
  if (mlir::Value asyncObject = op.getAsyncObject()) {
    stream = llvmValue(asyncObject);
    destroyStream.release();
  } else {
    stream = builder.CreateCall(getStreamCreateFn(), {});
  }

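  // The non-cluster launch wrapper also takes the number of kernel parameters
  // as its last argument.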
  llvm::Constant *paramsCount =
      llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());

  // Create the launch call.
  Value *nullPtr = ConstantPointerNull::get(ptrTy);

  // Launch kernel with clusters if cluster size is specified.
  if (op.hasClusterSize()) {
    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
          *cz = llvmValue(cluster.z);
    builder.CreateCall(
        getClusterKernelLaunchFn(),
        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
                           dynamicMemorySize, stream, argArray, nullPtr}));
  } else {
    builder.CreateCall(getKernelLaunchFn(),
                       ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
                                          bz, dynamicMemorySize, stream,
                                          argArray, nullPtr, paramsCount}));
  }

  return success();
}

void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
  });
}