| 1 | //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements passes to convert `gpu.launch_func` op into a sequence |
| 10 | // of LLVM calls that emulate the host and device sides. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" |
| 15 | |
| 16 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
| 17 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
| 18 | #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" |
| 19 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 20 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| 21 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
| 22 | #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" |
| 23 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 24 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| 25 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 26 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 27 | #include "mlir/IR/BuiltinOps.h" |
| 28 | #include "mlir/IR/SymbolTable.h" |
| 29 | #include "mlir/Pass/Pass.h" |
| 30 | #include "mlir/Transforms/DialectConversion.h" |
| 31 | #include "llvm/ADT/DenseMap.h" |
| 32 | #include "llvm/ADT/StringExtras.h" |
| 33 | #include "llvm/Support/FormatVariadic.h" |
| 34 | |
| 35 | namespace mlir { |
| 36 | #define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS |
| 37 | #include "mlir/Conversion/Passes.h.inc" |
| 38 | } // namespace mlir |
| 39 | |
| 40 | using namespace mlir; |
| 41 | |
/// Prefix of the symbol name given to SPIR-V modules produced from GPU kernel
/// modules (`__spv__{kernel_module_name}`); also the fallback module name used
/// when a SPIR-V module has no symbol name.
static constexpr char kSPIRVModule[] = "__spv__";
| 43 | |
| 44 | //===----------------------------------------------------------------------===// |
| 45 | // Utility functions |
| 46 | //===----------------------------------------------------------------------===// |
| 47 | |
| 48 | /// Returns the string name of the `DescriptorSet` decoration. |
| 49 | static std::string descriptorSetName() { |
| 50 | return llvm::convertToSnakeFromCamelCase( |
| 51 | stringifyDecoration(spirv::Decoration::DescriptorSet)); |
| 52 | } |
| 53 | |
| 54 | /// Returns the string name of the `Binding` decoration. |
| 55 | static std::string bindingName() { |
| 56 | return llvm::convertToSnakeFromCamelCase( |
| 57 | stringifyDecoration(spirv::Decoration::Binding)); |
| 58 | } |
| 59 | |
| 60 | /// Calculates the index of the kernel's operand that is represented by the |
| 61 | /// given global variable with the `bind` attribute. We assume that the index of |
| 62 | /// each kernel's operand is mapped to (descriptorSet, binding) by the map: |
| 63 | /// i -> (0, i) |
| 64 | /// which is implemented under `LowerABIAttributesPass`. |
| 65 | static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) { |
| 66 | IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); |
| 67 | return binding.getInt(); |
| 68 | } |
| 69 | |
| 70 | /// Copies the given number of bytes from src to dst pointers. |
| 71 | static void copy(Location loc, Value dst, Value src, Value size, |
| 72 | OpBuilder &builder) { |
| 73 | builder.create<LLVM::MemcpyOp>(loc, dst, src, size, /*isVolatile=*/false); |
| 74 | } |
| 75 | |
| 76 | /// Encodes the binding and descriptor set numbers into a new symbolic name. |
| 77 | /// The name is specified by |
| 78 | /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b} |
| 79 | /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and |
| 80 | /// binding numbers. |
| 81 | static std::string |
| 82 | createGlobalVariableWithBindName(spirv::GlobalVariableOp op, |
| 83 | StringRef kernelModuleName) { |
| 84 | IntegerAttr descriptorSet = |
| 85 | op->getAttrOfType<IntegerAttr>(descriptorSetName()); |
| 86 | IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); |
| 87 | return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}" , |
| 88 | kernelModuleName.str(), op.getSymName().str(), |
| 89 | std::to_string(descriptorSet.getInt()), |
| 90 | std::to_string(binding.getInt())); |
| 91 | } |
| 92 | |
| 93 | /// Returns true if the given global variable has both a descriptor set number |
| 94 | /// and a binding number. |
| 95 | static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) { |
| 96 | IntegerAttr descriptorSet = |
| 97 | op->getAttrOfType<IntegerAttr>(descriptorSetName()); |
| 98 | IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); |
| 99 | return descriptorSet && binding; |
| 100 | } |
| 101 | |
| 102 | /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel |
| 103 | /// arguments from the given SPIR-V module. We assume that the module contains a |
| 104 | /// single entry point function. Hence, all `spirv.GlobalVariable`s with a bind |
| 105 | /// attribute are kernel arguments. |
| 106 | static LogicalResult getKernelGlobalVariables( |
| 107 | spirv::ModuleOp module, |
| 108 | DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) { |
| 109 | auto entryPoints = module.getOps<spirv::EntryPointOp>(); |
| 110 | if (!llvm::hasSingleElement(entryPoints)) { |
| 111 | return module.emitError( |
| 112 | "The module must contain exactly one entry point function" ); |
| 113 | } |
| 114 | auto globalVariables = module.getOps<spirv::GlobalVariableOp>(); |
| 115 | for (auto globalOp : globalVariables) { |
| 116 | if (hasDescriptorSetAndBinding(globalOp)) |
| 117 | globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp; |
| 118 | } |
| 119 | return success(); |
| 120 | } |
| 121 | |
| 122 | /// Encodes the SPIR-V module's symbolic name into the name of the entry point |
| 123 | /// function. |
| 124 | static LogicalResult encodeKernelName(spirv::ModuleOp module) { |
| 125 | StringRef spvModuleName = module.getSymName().value_or(kSPIRVModule); |
| 126 | // We already know that the module contains exactly one entry point function |
| 127 | // based on `getKernelGlobalVariables()` call. Update this function's name |
| 128 | // to: |
| 129 | // {spv_module_name}_{function_name} |
| 130 | auto entryPoints = module.getOps<spirv::EntryPointOp>(); |
| 131 | if (!llvm::hasSingleElement(entryPoints)) { |
| 132 | return module.emitError( |
| 133 | "The module must contain exactly one entry point function" ); |
| 134 | } |
| 135 | spirv::EntryPointOp entryPoint = *entryPoints.begin(); |
| 136 | StringRef funcName = entryPoint.getFn(); |
| 137 | auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr()); |
| 138 | StringAttr newFuncName = |
| 139 | StringAttr::get(module->getContext(), spvModuleName + "_" + funcName); |
| 140 | if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) |
| 141 | return failure(); |
| 142 | SymbolTable::setSymbolName(funcOp, newFuncName); |
| 143 | return success(); |
| 144 | } |
| 145 | |
| 146 | //===----------------------------------------------------------------------===// |
| 147 | // Conversion patterns |
| 148 | //===----------------------------------------------------------------------===// |
| 149 | |
| 150 | namespace { |
| 151 | |
| 152 | /// Structure to group information about the variables being copied. |
| 153 | struct CopyInfo { |
| 154 | Value dst; |
| 155 | Value src; |
| 156 | Value size; |
| 157 | }; |
| 158 | |
| 159 | /// This pattern emulates a call to the kernel in LLVM dialect. For that, we |
| 160 | /// copy the data to the global variable (emulating device side), call the |
| 161 | /// kernel as a normal void LLVM function, and copy the data back (emulating the |
| 162 | /// host side). |
| 163 | class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { |
| 164 | using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern; |
| 165 | |
| 166 | LogicalResult |
| 167 | matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, |
| 168 | ConversionPatternRewriter &rewriter) const override { |
| 169 | auto *op = launchOp.getOperation(); |
| 170 | MLIRContext *context = rewriter.getContext(); |
| 171 | auto module = launchOp->getParentOfType<ModuleOp>(); |
| 172 | |
| 173 | // Get the SPIR-V module that represents the gpu kernel module. The module |
| 174 | // is named: |
| 175 | // __spv__{kernel_module_name} |
| 176 | // based on GPU to SPIR-V conversion. |
| 177 | StringRef kernelModuleName = launchOp.getKernelModuleName().getValue(); |
| 178 | std::string spvModuleName = kSPIRVModule + kernelModuleName.str(); |
| 179 | auto spvModule = module.lookupSymbol<spirv::ModuleOp>( |
| 180 | StringAttr::get(context, spvModuleName)); |
| 181 | if (!spvModule) { |
| 182 | return launchOp.emitOpError("SPIR-V kernel module '" ) |
| 183 | << spvModuleName << "' is not found" ; |
| 184 | } |
| 185 | |
| 186 | // Declare kernel function in the main module so that it later can be linked |
| 187 | // with its definition from the kernel module. We know that the kernel |
| 188 | // function would have no arguments and the data is passed via global |
| 189 | // variables. The name of the kernel will be |
| 190 | // {spv_module_name}_{kernel_function_name} |
| 191 | // to avoid symbolic name conflicts. |
| 192 | StringRef kernelFuncName = launchOp.getKernelName().getValue(); |
| 193 | std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str(); |
| 194 | auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>( |
| 195 | StringAttr::get(context, newKernelFuncName)); |
| 196 | if (!kernelFunc) { |
| 197 | OpBuilder::InsertionGuard guard(rewriter); |
| 198 | rewriter.setInsertionPointToStart(module.getBody()); |
| 199 | kernelFunc = rewriter.create<LLVM::LLVMFuncOp>( |
| 200 | rewriter.getUnknownLoc(), newKernelFuncName, |
| 201 | LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), |
| 202 | ArrayRef<Type>())); |
| 203 | rewriter.setInsertionPoint(launchOp); |
| 204 | } |
| 205 | |
| 206 | // Get all global variables associated with the kernel operands. |
| 207 | DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; |
| 208 | if (failed(getKernelGlobalVariables(spvModule, globalVariableMap))) |
| 209 | return failure(); |
| 210 | |
| 211 | // Traverse kernel operands that were converted to MemRefDescriptors. For |
| 212 | // each operand, create a global variable and copy data from operand to it. |
| 213 | Location loc = launchOp.getLoc(); |
| 214 | SmallVector<CopyInfo, 4> copyInfo; |
| 215 | auto numKernelOperands = launchOp.getNumKernelOperands(); |
| 216 | auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands); |
| 217 | for (const auto &operand : llvm::enumerate(kernelOperands)) { |
| 218 | // Check if the kernel's operand is a ranked memref. |
| 219 | auto memRefType = dyn_cast<MemRefType>( |
| 220 | launchOp.getKernelOperand(operand.index()).getType()); |
| 221 | if (!memRefType) |
| 222 | return failure(); |
| 223 | |
| 224 | // Calculate the size of the memref and get the pointer to the allocated |
| 225 | // buffer. |
| 226 | SmallVector<Value, 4> sizes; |
| 227 | SmallVector<Value, 4> strides; |
| 228 | Value sizeBytes; |
| 229 | getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides, |
| 230 | sizeBytes); |
| 231 | MemRefDescriptor descriptor(operand.value()); |
| 232 | Value src = descriptor.allocatedPtr(rewriter, loc); |
| 233 | |
| 234 | // Get the global variable in the SPIR-V module that is associated with |
| 235 | // the kernel operand. Construct its new name and create a corresponding |
| 236 | // LLVM dialect global variable. |
| 237 | spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; |
| 238 | auto pointeeType = |
| 239 | cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType(); |
| 240 | auto dstGlobalType = typeConverter->convertType(pointeeType); |
| 241 | if (!dstGlobalType) |
| 242 | return failure(); |
| 243 | std::string name = |
| 244 | createGlobalVariableWithBindName(spirvGlobal, spvModuleName); |
| 245 | // Check if this variable has already been created. |
| 246 | auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name); |
| 247 | if (!dstGlobal) { |
| 248 | OpBuilder::InsertionGuard guard(rewriter); |
| 249 | rewriter.setInsertionPointToStart(module.getBody()); |
| 250 | dstGlobal = rewriter.create<LLVM::GlobalOp>( |
| 251 | loc, dstGlobalType, |
| 252 | /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(), |
| 253 | /*alignment=*/0); |
| 254 | rewriter.setInsertionPoint(launchOp); |
| 255 | } |
| 256 | |
| 257 | // Copy the data from src operand pointer to dst global variable. Save |
| 258 | // src, dst and size so that we can copy data back after emulating the |
| 259 | // kernel call. |
| 260 | Value dst = rewriter.create<LLVM::AddressOfOp>( |
| 261 | loc, typeConverter->convertType(spirvGlobal.getType()), |
| 262 | dstGlobal.getSymName()); |
| 263 | copy(loc, dst, src, sizeBytes, rewriter); |
| 264 | |
| 265 | CopyInfo info; |
| 266 | info.dst = dst; |
| 267 | info.src = src; |
| 268 | info.size = sizeBytes; |
| 269 | copyInfo.push_back(info); |
| 270 | } |
| 271 | // Create a call to the kernel and copy the data back. |
| 272 | rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc, |
| 273 | ArrayRef<Value>()); |
| 274 | for (CopyInfo info : copyInfo) |
| 275 | copy(loc, dst: info.src, src: info.dst, size: info.size, builder&: rewriter); |
| 276 | return success(); |
| 277 | } |
| 278 | }; |
| 279 | |
| 280 | class LowerHostCodeToLLVM |
| 281 | : public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> { |
| 282 | public: |
| 283 | using Base::Base; |
| 284 | |
| 285 | void runOnOperation() override { |
| 286 | ModuleOp module = getOperation(); |
| 287 | |
| 288 | // Erase the GPU module. |
| 289 | for (auto gpuModule : |
| 290 | llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>())) |
| 291 | gpuModule.erase(); |
| 292 | |
| 293 | // Request C wrapper emission. |
| 294 | for (auto func : module.getOps<func::FuncOp>()) { |
| 295 | func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), |
| 296 | UnitAttr::get(&getContext())); |
| 297 | } |
| 298 | |
| 299 | // Specify options to lower to LLVM and pull in the conversion patterns. |
| 300 | LowerToLLVMOptions options(module.getContext()); |
| 301 | |
| 302 | auto *context = module.getContext(); |
| 303 | RewritePatternSet patterns(context); |
| 304 | LLVMTypeConverter typeConverter(context, options); |
| 305 | mlir::arith::populateArithToLLVMConversionPatterns(converter: typeConverter, patterns); |
| 306 | populateFinalizeMemRefToLLVMConversionPatterns(converter: typeConverter, patterns); |
| 307 | populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns); |
| 308 | patterns.add<GPULaunchLowering>(arg&: typeConverter); |
| 309 | |
| 310 | // Pull in SPIR-V type conversion patterns to convert SPIR-V global |
| 311 | // variable's type to LLVM dialect type. |
| 312 | populateSPIRVToLLVMTypeConversion(typeConverter); |
| 313 | |
| 314 | ConversionTarget target(*context); |
| 315 | target.addLegalDialect<LLVM::LLVMDialect>(); |
| 316 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
| 317 | signalPassFailure(); |
| 318 | |
| 319 | // Finally, modify the kernel function in SPIR-V modules to avoid symbolic |
| 320 | // conflicts. |
| 321 | for (auto spvModule : module.getOps<spirv::ModuleOp>()) { |
| 322 | if (failed(encodeKernelName(spvModule))) { |
| 323 | signalPassFailure(); |
| 324 | return; |
| 325 | } |
| 326 | } |
| 327 | } |
| 328 | }; |
| 329 | } // namespace |
| 330 | |