1//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements passes to convert `gpu.launch_func` op into a sequence
10// of LLVM calls that emulate the host and device sides.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
15
16#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
17#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
18#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
19#include "mlir/Conversion/LLVMCommon/Pattern.h"
20#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
21#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
22#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/GPU/IR/GPUDialect.h"
25#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
27#include "mlir/IR/BuiltinOps.h"
28#include "mlir/IR/SymbolTable.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Transforms/DialectConversion.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/StringExtras.h"
33#include "llvm/Support/FormatVariadic.h"
34
35namespace mlir {
36#define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
37#include "mlir/Conversion/Passes.h.inc"
38} // namespace mlir
39
40using namespace mlir;
41
/// Prefix that GPU-to-SPIR-V conversion prepends to a GPU module's name to form
/// the name of the corresponding SPIR-V module; also used as the module name
/// when a SPIR-V module has no symbol name.
static constexpr char kSPIRVModule[] = "__spv__";
43
44//===----------------------------------------------------------------------===//
45// Utility functions
46//===----------------------------------------------------------------------===//
47
48/// Returns the string name of the `DescriptorSet` decoration.
49static std::string descriptorSetName() {
50 return llvm::convertToSnakeFromCamelCase(
51 stringifyDecoration(spirv::Decoration::DescriptorSet));
52}
53
54/// Returns the string name of the `Binding` decoration.
55static std::string bindingName() {
56 return llvm::convertToSnakeFromCamelCase(
57 stringifyDecoration(spirv::Decoration::Binding));
58}
59
60/// Calculates the index of the kernel's operand that is represented by the
61/// given global variable with the `bind` attribute. We assume that the index of
62/// each kernel's operand is mapped to (descriptorSet, binding) by the map:
63/// i -> (0, i)
64/// which is implemented under `LowerABIAttributesPass`.
65static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
66 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
67 return binding.getInt();
68}
69
70/// Copies the given number of bytes from src to dst pointers.
71static void copy(Location loc, Value dst, Value src, Value size,
72 OpBuilder &builder) {
73 builder.create<LLVM::MemcpyOp>(loc, dst, src, size, /*isVolatile=*/false);
74}
75
76/// Encodes the binding and descriptor set numbers into a new symbolic name.
77/// The name is specified by
78/// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
79/// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
80/// binding numbers.
81static std::string
82createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
83 StringRef kernelModuleName) {
84 IntegerAttr descriptorSet =
85 op->getAttrOfType<IntegerAttr>(descriptorSetName());
86 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
87 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
88 kernelModuleName.str(), op.getSymName().str(),
89 std::to_string(descriptorSet.getInt()),
90 std::to_string(binding.getInt()));
91}
92
93/// Returns true if the given global variable has both a descriptor set number
94/// and a binding number.
95static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
96 IntegerAttr descriptorSet =
97 op->getAttrOfType<IntegerAttr>(descriptorSetName());
98 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
99 return descriptorSet && binding;
100}
101
102/// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
103/// arguments from the given SPIR-V module. We assume that the module contains a
104/// single entry point function. Hence, all `spirv.GlobalVariable`s with a bind
105/// attribute are kernel arguments.
106static LogicalResult getKernelGlobalVariables(
107 spirv::ModuleOp module,
108 DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
109 auto entryPoints = module.getOps<spirv::EntryPointOp>();
110 if (!llvm::hasSingleElement(entryPoints)) {
111 return module.emitError(
112 "The module must contain exactly one entry point function");
113 }
114 auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
115 for (auto globalOp : globalVariables) {
116 if (hasDescriptorSetAndBinding(globalOp))
117 globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
118 }
119 return success();
120}
121
122/// Encodes the SPIR-V module's symbolic name into the name of the entry point
123/// function.
124static LogicalResult encodeKernelName(spirv::ModuleOp module) {
125 StringRef spvModuleName = module.getSymName().value_or(kSPIRVModule);
126 // We already know that the module contains exactly one entry point function
127 // based on `getKernelGlobalVariables()` call. Update this function's name
128 // to:
129 // {spv_module_name}_{function_name}
130 auto entryPoints = module.getOps<spirv::EntryPointOp>();
131 if (!llvm::hasSingleElement(entryPoints)) {
132 return module.emitError(
133 "The module must contain exactly one entry point function");
134 }
135 spirv::EntryPointOp entryPoint = *entryPoints.begin();
136 StringRef funcName = entryPoint.getFn();
137 auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
138 StringAttr newFuncName =
139 StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
140 if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
141 return failure();
142 SymbolTable::setSymbolName(funcOp, newFuncName);
143 return success();
144}
145
146//===----------------------------------------------------------------------===//
147// Conversion patterns
148//===----------------------------------------------------------------------===//
149
150namespace {
151
152/// Structure to group information about the variables being copied.
153struct CopyInfo {
154 Value dst;
155 Value src;
156 Value size;
157};
158
159/// This pattern emulates a call to the kernel in LLVM dialect. For that, we
160/// copy the data to the global variable (emulating device side), call the
161/// kernel as a normal void LLVM function, and copy the data back (emulating the
162/// host side).
163class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
164 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
165
166 LogicalResult
167 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
168 ConversionPatternRewriter &rewriter) const override {
169 auto *op = launchOp.getOperation();
170 MLIRContext *context = rewriter.getContext();
171 auto module = launchOp->getParentOfType<ModuleOp>();
172
173 // Get the SPIR-V module that represents the gpu kernel module. The module
174 // is named:
175 // __spv__{kernel_module_name}
176 // based on GPU to SPIR-V conversion.
177 StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
178 std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
179 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
180 StringAttr::get(context, spvModuleName));
181 if (!spvModule) {
182 return launchOp.emitOpError("SPIR-V kernel module '")
183 << spvModuleName << "' is not found";
184 }
185
186 // Declare kernel function in the main module so that it later can be linked
187 // with its definition from the kernel module. We know that the kernel
188 // function would have no arguments and the data is passed via global
189 // variables. The name of the kernel will be
190 // {spv_module_name}_{kernel_function_name}
191 // to avoid symbolic name conflicts.
192 StringRef kernelFuncName = launchOp.getKernelName().getValue();
193 std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
194 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
195 StringAttr::get(context, newKernelFuncName));
196 if (!kernelFunc) {
197 OpBuilder::InsertionGuard guard(rewriter);
198 rewriter.setInsertionPointToStart(module.getBody());
199 kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
200 rewriter.getUnknownLoc(), newKernelFuncName,
201 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
202 ArrayRef<Type>()));
203 rewriter.setInsertionPoint(launchOp);
204 }
205
206 // Get all global variables associated with the kernel operands.
207 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
208 if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
209 return failure();
210
211 // Traverse kernel operands that were converted to MemRefDescriptors. For
212 // each operand, create a global variable and copy data from operand to it.
213 Location loc = launchOp.getLoc();
214 SmallVector<CopyInfo, 4> copyInfo;
215 auto numKernelOperands = launchOp.getNumKernelOperands();
216 auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
217 for (const auto &operand : llvm::enumerate(kernelOperands)) {
218 // Check if the kernel's operand is a ranked memref.
219 auto memRefType = dyn_cast<MemRefType>(
220 launchOp.getKernelOperand(operand.index()).getType());
221 if (!memRefType)
222 return failure();
223
224 // Calculate the size of the memref and get the pointer to the allocated
225 // buffer.
226 SmallVector<Value, 4> sizes;
227 SmallVector<Value, 4> strides;
228 Value sizeBytes;
229 getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
230 sizeBytes);
231 MemRefDescriptor descriptor(operand.value());
232 Value src = descriptor.allocatedPtr(rewriter, loc);
233
234 // Get the global variable in the SPIR-V module that is associated with
235 // the kernel operand. Construct its new name and create a corresponding
236 // LLVM dialect global variable.
237 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
238 auto pointeeType =
239 cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
240 auto dstGlobalType = typeConverter->convertType(pointeeType);
241 if (!dstGlobalType)
242 return failure();
243 std::string name =
244 createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
245 // Check if this variable has already been created.
246 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
247 if (!dstGlobal) {
248 OpBuilder::InsertionGuard guard(rewriter);
249 rewriter.setInsertionPointToStart(module.getBody());
250 dstGlobal = rewriter.create<LLVM::GlobalOp>(
251 loc, dstGlobalType,
252 /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
253 /*alignment=*/0);
254 rewriter.setInsertionPoint(launchOp);
255 }
256
257 // Copy the data from src operand pointer to dst global variable. Save
258 // src, dst and size so that we can copy data back after emulating the
259 // kernel call.
260 Value dst = rewriter.create<LLVM::AddressOfOp>(
261 loc, typeConverter->convertType(spirvGlobal.getType()),
262 dstGlobal.getSymName());
263 copy(loc, dst, src, sizeBytes, rewriter);
264
265 CopyInfo info;
266 info.dst = dst;
267 info.src = src;
268 info.size = sizeBytes;
269 copyInfo.push_back(info);
270 }
271 // Create a call to the kernel and copy the data back.
272 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
273 ArrayRef<Value>());
274 for (CopyInfo info : copyInfo)
275 copy(loc, dst: info.src, src: info.dst, size: info.size, builder&: rewriter);
276 return success();
277 }
278};
279
280class LowerHostCodeToLLVM
281 : public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
282public:
283 using Base::Base;
284
285 void runOnOperation() override {
286 ModuleOp module = getOperation();
287
288 // Erase the GPU module.
289 for (auto gpuModule :
290 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
291 gpuModule.erase();
292
293 // Request C wrapper emission.
294 for (auto func : module.getOps<func::FuncOp>()) {
295 func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
296 UnitAttr::get(&getContext()));
297 }
298
299 // Specify options to lower to LLVM and pull in the conversion patterns.
300 LowerToLLVMOptions options(module.getContext());
301
302 auto *context = module.getContext();
303 RewritePatternSet patterns(context);
304 LLVMTypeConverter typeConverter(context, options);
305 mlir::arith::populateArithToLLVMConversionPatterns(converter&: typeConverter, patterns);
306 populateFinalizeMemRefToLLVMConversionPatterns(converter&: typeConverter, patterns);
307 populateFuncToLLVMConversionPatterns(converter&: typeConverter, patterns);
308 patterns.add<GPULaunchLowering>(arg&: typeConverter);
309
310 // Pull in SPIR-V type conversion patterns to convert SPIR-V global
311 // variable's type to LLVM dialect type.
312 populateSPIRVToLLVMTypeConversion(typeConverter);
313
314 ConversionTarget target(*context);
315 target.addLegalDialect<LLVM::LLVMDialect>();
316 if (failed(applyPartialConversion(module, target, std::move(patterns))))
317 signalPassFailure();
318
319 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
320 // conflicts.
321 for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
322 if (failed(encodeKernelName(spvModule))) {
323 signalPassFailure();
324 return;
325 }
326 }
327 }
328};
329} // namespace
330

source code of mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp