1//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements passes to convert `gpu.launch_func` op into a sequence
10// of LLVM calls that emulate the host and device sides.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
15#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
16#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
17#include "mlir/Conversion/LLVMCommon/Pattern.h"
18#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
19#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
20#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
21#include "mlir/Dialect/Func/IR/FuncOps.h"
22#include "mlir/Dialect/GPU/IR/GPUDialect.h"
23#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/SymbolTable.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/StringExtras.h"
31#include "llvm/Support/FormatVariadic.h"
32
33namespace mlir {
34#define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
35#include "mlir/Conversion/Passes.h.inc"
36} // namespace mlir
37
38using namespace mlir;
39
/// Prefix prepended to a GPU kernel module's name to form the name of the
/// SPIR-V module that implements it (`__spv__{kernel_module_name}`). It is also
/// the name assumed for an unnamed SPIR-V module when its entry point function
/// is renamed.
static constexpr char kSPIRVModule[] = "__spv__";
41
42//===----------------------------------------------------------------------===//
43// Utility functions
44//===----------------------------------------------------------------------===//
45
46/// Returns the string name of the `DescriptorSet` decoration.
47static std::string descriptorSetName() {
48 return llvm::convertToSnakeFromCamelCase(
49 input: stringifyDecoration(spirv::Decoration::DescriptorSet));
50}
51
52/// Returns the string name of the `Binding` decoration.
53static std::string bindingName() {
54 return llvm::convertToSnakeFromCamelCase(
55 input: stringifyDecoration(spirv::Decoration::Binding));
56}
57
58/// Calculates the index of the kernel's operand that is represented by the
59/// given global variable with the `bind` attribute. We assume that the index of
60/// each kernel's operand is mapped to (descriptorSet, binding) by the map:
61/// i -> (0, i)
62/// which is implemented under `LowerABIAttributesPass`.
63static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
64 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(name: bindingName());
65 return binding.getInt();
66}
67
68/// Copies the given number of bytes from src to dst pointers.
69static void copy(Location loc, Value dst, Value src, Value size,
70 OpBuilder &builder) {
71 builder.create<LLVM::MemcpyOp>(location: loc, args&: dst, args&: src, args&: size, /*isVolatile=*/args: false);
72}
73
74/// Encodes the binding and descriptor set numbers into a new symbolic name.
75/// The name is specified by
76/// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
77/// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
78/// binding numbers.
79static std::string
80createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
81 StringRef kernelModuleName) {
82 IntegerAttr descriptorSet =
83 op->getAttrOfType<IntegerAttr>(name: descriptorSetName());
84 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(name: bindingName());
85 return llvm::formatv(Fmt: "{0}_{1}_descriptor_set{2}_binding{3}",
86 Vals: kernelModuleName.str(), Vals: op.getSymName().str(),
87 Vals: std::to_string(val: descriptorSet.getInt()),
88 Vals: std::to_string(val: binding.getInt()));
89}
90
91/// Returns true if the given global variable has both a descriptor set number
92/// and a binding number.
93static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
94 IntegerAttr descriptorSet =
95 op->getAttrOfType<IntegerAttr>(name: descriptorSetName());
96 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(name: bindingName());
97 return descriptorSet && binding;
98}
99
100/// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
101/// arguments from the given SPIR-V module. We assume that the module contains a
102/// single entry point function. Hence, all `spirv.GlobalVariable`s with a bind
103/// attribute are kernel arguments.
104static LogicalResult getKernelGlobalVariables(
105 spirv::ModuleOp module,
106 DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
107 auto entryPoints = module.getOps<spirv::EntryPointOp>();
108 if (!llvm::hasSingleElement(C&: entryPoints)) {
109 return module.emitError(
110 message: "The module must contain exactly one entry point function");
111 }
112 auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
113 for (auto globalOp : globalVariables) {
114 if (hasDescriptorSetAndBinding(op: globalOp))
115 globalVariableMap[calculateGlobalIndex(op: globalOp)] = globalOp;
116 }
117 return success();
118}
119
120/// Encodes the SPIR-V module's symbolic name into the name of the entry point
121/// function.
122static LogicalResult encodeKernelName(spirv::ModuleOp module) {
123 StringRef spvModuleName = module.getSymName().value_or(u: kSPIRVModule);
124 // We already know that the module contains exactly one entry point function
125 // based on `getKernelGlobalVariables()` call. Update this function's name
126 // to:
127 // {spv_module_name}_{function_name}
128 auto entryPoints = module.getOps<spirv::EntryPointOp>();
129 if (!llvm::hasSingleElement(C&: entryPoints)) {
130 return module.emitError(
131 message: "The module must contain exactly one entry point function");
132 }
133 spirv::EntryPointOp entryPoint = *entryPoints.begin();
134 StringRef funcName = entryPoint.getFn();
135 auto funcOp = module.lookupSymbol<spirv::FuncOp>(symbol: entryPoint.getFnAttr());
136 StringAttr newFuncName =
137 StringAttr::get(context: module->getContext(), bytes: spvModuleName + "_" + funcName);
138 if (failed(Result: SymbolTable::replaceAllSymbolUses(oldSymbol: funcOp, newSymbolName: newFuncName, from: module)))
139 return failure();
140 SymbolTable::setSymbolName(symbol: funcOp, name: newFuncName);
141 return success();
142}
143
144//===----------------------------------------------------------------------===//
145// Conversion patterns
146//===----------------------------------------------------------------------===//
147
148namespace {
149
150/// Structure to group information about the variables being copied.
151struct CopyInfo {
152 Value dst;
153 Value src;
154 Value size;
155};
156
157/// This pattern emulates a call to the kernel in LLVM dialect. For that, we
158/// copy the data to the global variable (emulating device side), call the
159/// kernel as a normal void LLVM function, and copy the data back (emulating the
160/// host side).
161class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
162 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
163
164 LogicalResult
165 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
166 ConversionPatternRewriter &rewriter) const override {
167 auto *op = launchOp.getOperation();
168 MLIRContext *context = rewriter.getContext();
169 auto module = launchOp->getParentOfType<ModuleOp>();
170
171 // Get the SPIR-V module that represents the gpu kernel module. The module
172 // is named:
173 // __spv__{kernel_module_name}
174 // based on GPU to SPIR-V conversion.
175 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
176 std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
177 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
178 name: StringAttr::get(context, bytes: spvModuleName));
179 if (!spvModule) {
180 return launchOp.emitOpError(message: "SPIR-V kernel module '")
181 << spvModuleName << "' is not found";
182 }
183
184 // Declare kernel function in the main module so that it later can be linked
185 // with its definition from the kernel module. We know that the kernel
186 // function would have no arguments and the data is passed via global
187 // variables. The name of the kernel will be
188 // {spv_module_name}_{kernel_function_name}
189 // to avoid symbolic name conflicts.
190 StringRef kernelFuncName = launchOp.getKernelName().getValue();
191 std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
192 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
193 name: StringAttr::get(context, bytes: newKernelFuncName));
194 if (!kernelFunc) {
195 OpBuilder::InsertionGuard guard(rewriter);
196 rewriter.setInsertionPointToStart(module.getBody());
197 kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
198 location: rewriter.getUnknownLoc(), args&: newKernelFuncName,
199 args: LLVM::LLVMFunctionType::get(result: LLVM::LLVMVoidType::get(ctx: context),
200 arguments: ArrayRef<Type>()));
201 rewriter.setInsertionPoint(launchOp);
202 }
203
204 // Get all global variables associated with the kernel operands.
205 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
206 if (failed(Result: getKernelGlobalVariables(module: spvModule, globalVariableMap)))
207 return failure();
208
209 // Traverse kernel operands that were converted to MemRefDescriptors. For
210 // each operand, create a global variable and copy data from operand to it.
211 Location loc = launchOp.getLoc();
212 SmallVector<CopyInfo, 4> copyInfo;
213 auto numKernelOperands = launchOp.getNumKernelOperands();
214 auto kernelOperands = adaptor.getOperands().take_back(n: numKernelOperands);
215 for (const auto &operand : llvm::enumerate(First&: kernelOperands)) {
216 // Check if the kernel's operand is a ranked memref.
217 auto memRefType = dyn_cast<MemRefType>(
218 Val: launchOp.getKernelOperand(i: operand.index()).getType());
219 if (!memRefType)
220 return failure();
221
222 // Calculate the size of the memref and get the pointer to the allocated
223 // buffer.
224 SmallVector<Value, 4> sizes;
225 SmallVector<Value, 4> strides;
226 Value sizeBytes;
227 getMemRefDescriptorSizes(loc, memRefType, dynamicSizes: {}, rewriter, sizes, strides,
228 size&: sizeBytes);
229 MemRefDescriptor descriptor(operand.value());
230 Value src = descriptor.allocatedPtr(builder&: rewriter, loc);
231
232 // Get the global variable in the SPIR-V module that is associated with
233 // the kernel operand. Construct its new name and create a corresponding
234 // LLVM dialect global variable.
235 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
236 auto pointeeType =
237 cast<spirv::PointerType>(Val: spirvGlobal.getType()).getPointeeType();
238 auto dstGlobalType = typeConverter->convertType(t: pointeeType);
239 if (!dstGlobalType)
240 return failure();
241 std::string name =
242 createGlobalVariableWithBindName(op: spirvGlobal, kernelModuleName: spvModuleName);
243 // Check if this variable has already been created.
244 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
245 if (!dstGlobal) {
246 OpBuilder::InsertionGuard guard(rewriter);
247 rewriter.setInsertionPointToStart(module.getBody());
248 dstGlobal = rewriter.create<LLVM::GlobalOp>(
249 location: loc, args&: dstGlobalType,
250 /*isConstant=*/args: false, args: LLVM::Linkage::Linkonce, args&: name, args: Attribute(),
251 /*alignment=*/args: 0);
252 rewriter.setInsertionPoint(launchOp);
253 }
254
255 // Copy the data from src operand pointer to dst global variable. Save
256 // src, dst and size so that we can copy data back after emulating the
257 // kernel call.
258 Value dst = rewriter.create<LLVM::AddressOfOp>(
259 location: loc, args: typeConverter->convertType(t: spirvGlobal.getType()),
260 args: dstGlobal.getSymName());
261 copy(loc, dst, src, size: sizeBytes, builder&: rewriter);
262
263 CopyInfo info;
264 info.dst = dst;
265 info.src = src;
266 info.size = sizeBytes;
267 copyInfo.push_back(Elt: info);
268 }
269 // Create a call to the kernel and copy the data back.
270 Operation *callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
271 op, args&: kernelFunc, args: ArrayRef<Value>());
272 rewriter.setInsertionPointAfter(callOp);
273 for (CopyInfo info : copyInfo)
274 copy(loc, dst: info.src, src: info.dst, size: info.size, builder&: rewriter);
275 return success();
276 }
277};
278
279class LowerHostCodeToLLVM
280 : public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
281public:
282 using Base::Base;
283
284 void runOnOperation() override {
285 ModuleOp module = getOperation();
286
287 // Erase the GPU module.
288 for (auto gpuModule :
289 llvm::make_early_inc_range(Range: module.getOps<gpu::GPUModuleOp>()))
290 gpuModule.erase();
291
292 // Request C wrapper emission.
293 for (auto func : module.getOps<func::FuncOp>()) {
294 func->setAttr(name: LLVM::LLVMDialect::getEmitCWrapperAttrName(),
295 value: UnitAttr::get(context: &getContext()));
296 }
297
298 // Specify options to lower to LLVM and pull in the conversion patterns.
299 LowerToLLVMOptions options(module.getContext());
300
301 auto *context = module.getContext();
302 RewritePatternSet patterns(context);
303 LLVMTypeConverter typeConverter(context, options);
304 mlir::arith::populateArithToLLVMConversionPatterns(converter: typeConverter, patterns);
305 populateFinalizeMemRefToLLVMConversionPatterns(converter: typeConverter, patterns);
306 populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns);
307 patterns.add<GPULaunchLowering>(arg&: typeConverter);
308
309 // Pull in SPIR-V type conversion patterns to convert SPIR-V global
310 // variable's type to LLVM dialect type.
311 populateSPIRVToLLVMTypeConversion(typeConverter);
312
313 ConversionTarget target(*context);
314 target.addLegalDialect<LLVM::LLVMDialect>();
315 if (failed(Result: applyPartialConversion(op: module, target, patterns: std::move(patterns))))
316 signalPassFailure();
317
318 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
319 // conflicts.
320 for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
321 if (failed(Result: encodeKernelName(module: spvModule))) {
322 signalPassFailure();
323 return;
324 }
325 }
326 }
327};
328} // namespace
329

source code of mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp