1//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of the
// public API from the GPU runtime headers.
13//
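// For illustration only (the exact output depends on pass options and the
// input IR), a gpu.launch_func whose module carries a serialized binary lowers
// roughly to:
//
//   llvm.call @mgpuModuleLoad(%binary, %binarySize)
//   llvm.call @mgpuModuleGetFunction(%module, %kernelName)
//   llvm.call @mgpuStreamCreate()
//   llvm.call @mgpuLaunchKernel(%function, %gridX, %gridY, %gridZ, %blockX,
//                               %blockY, %blockZ, %smem, %stream, %params,
//                               %extra, %paramCount)
//   llvm.call @mgpuStreamSynchronize(%stream)
//   llvm.call @mgpuStreamDestroy(%stream)
//   llvm.call @mgpuModuleUnload(%module)
//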
14//===----------------------------------------------------------------------===//
15
16#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
17
18#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
19#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
20#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
21#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
22#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
23#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
24#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
25#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
26#include "mlir/Conversion/LLVMCommon/Pattern.h"
27#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
28#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
29#include "mlir/Dialect/Async/IR/Async.h"
30#include "mlir/Dialect/GPU/IR/GPUDialect.h"
31#include "mlir/Dialect/GPU/Transforms/Passes.h"
32#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
33#include "mlir/Dialect/MemRef/IR/MemRef.h"
34#include "mlir/IR/Attributes.h"
35#include "mlir/IR/Builders.h"
36#include "mlir/IR/BuiltinOps.h"
37#include "mlir/IR/BuiltinTypes.h"
38
39#include "llvm/ADT/STLExtras.h"
40#include "llvm/Support/Error.h"
41#include "llvm/Support/FormatVariadic.h"
42
43#define DEBUG_TYPE "gpu-to-llvm"
44
45namespace mlir {
46#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
47#include "mlir/Conversion/Passes.h.inc"
48} // namespace mlir
49
50using namespace mlir;
51
52static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
53
54namespace {
55class GpuToLLVMConversionPass
56 : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
57public:
58 using Base::Base;
59 void getDependentDialects(DialectRegistry &registry) const final {
60 Base::getDependentDialects(registry);
61 registerConvertToLLVMDependentDialectLoading(registry);
62 }
63 // Run the dialect converter on the module.
64 void runOnOperation() override;
65};
66
67template <typename OpTy>
68class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
69public:
70 explicit ConvertOpToGpuRuntimeCallPattern(
71 const LLVMTypeConverter &typeConverter)
72 : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
73
74protected:
75 Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
76 MemRefType type, MemRefDescriptor desc) const {
77 Type indexType = ConvertToLLVMPattern::getIndexType();
78 return type.hasStaticShape()
79 ? ConvertToLLVMPattern::createIndexAttrConstant(
80 rewriter, loc, indexType, type.getNumElements())
81 // For identity maps (verified by caller), the number of
82 // elements is stride[0] * size[0].
83 : rewriter.create<LLVM::MulOp>(loc,
84 desc.stride(rewriter, loc, 0),
85 desc.size(rewriter, loc, 0));
86 }
87
88 MLIRContext *context = &this->getTypeConverter()->getContext();
89
  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
91 LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
92 Type llvmInt8Type = IntegerType::get(context, 8);
93 Type llvmInt16Type = IntegerType::get(context, 16);
94 Type llvmInt32Type = IntegerType::get(context, 32);
95 Type llvmInt64Type = IntegerType::get(context, 64);
96 Type llvmFloat32Type = Float32Type::get(context);
97 Type llvmIntPtrType = IntegerType::get(
98 context, this->getTypeConverter()->getPointerBitwidth(0));
99
100 FunctionCallBuilder moduleLoadCallBuilder = {
101 "mgpuModuleLoad",
102 llvmPointerType /* void *module */,
103 {llvmPointerType /* void *cubin */, llvmInt64Type /* size_t size */}};
104 FunctionCallBuilder moduleUnloadCallBuilder = {
105 "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
106 FunctionCallBuilder moduleGetFunctionCallBuilder = {
107 "mgpuModuleGetFunction",
108 llvmPointerType /* void *function */,
109 {
110 llvmPointerType, /* void *module */
111 llvmPointerType /* char *name */
112 }};
113 FunctionCallBuilder launchKernelCallBuilder = {
114 "mgpuLaunchKernel",
115 llvmVoidType,
116 {
117 llvmPointerType, /* void* f */
118 llvmIntPtrType, /* intptr_t gridXDim */
119 llvmIntPtrType, /* intptr_t gridyDim */
120 llvmIntPtrType, /* intptr_t gridZDim */
121 llvmIntPtrType, /* intptr_t blockXDim */
122 llvmIntPtrType, /* intptr_t blockYDim */
123 llvmIntPtrType, /* intptr_t blockZDim */
124 llvmInt32Type, /* unsigned int sharedMemBytes */
125 llvmPointerType, /* void *hstream */
126 llvmPointerType, /* void **kernelParams */
127 llvmPointerType, /* void **extra */
128 llvmInt64Type /* size_t paramsCount */
129 }};
130 FunctionCallBuilder streamCreateCallBuilder = {
131 "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
132 FunctionCallBuilder streamDestroyCallBuilder = {
133 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
134 FunctionCallBuilder streamSynchronizeCallBuilder = {
135 "mgpuStreamSynchronize",
136 llvmVoidType,
137 {llvmPointerType /* void *stream */}};
138 FunctionCallBuilder streamWaitEventCallBuilder = {
139 "mgpuStreamWaitEvent",
140 llvmVoidType,
141 {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
142 FunctionCallBuilder eventCreateCallBuilder = {
143 "mgpuEventCreate", llvmPointerType /* void *event */, {}};
144 FunctionCallBuilder eventDestroyCallBuilder = {
145 "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
146 FunctionCallBuilder eventSynchronizeCallBuilder = {
147 "mgpuEventSynchronize",
148 llvmVoidType,
149 {llvmPointerType /* void *event */}};
150 FunctionCallBuilder eventRecordCallBuilder = {
151 "mgpuEventRecord",
152 llvmVoidType,
153 {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
154 FunctionCallBuilder hostRegisterCallBuilder = {
155 "mgpuMemHostRegisterMemRef",
156 llvmVoidType,
157 {llvmIntPtrType /* intptr_t rank */,
158 llvmPointerType /* void *memrefDesc */,
159 llvmIntPtrType /* intptr_t elementSizeBytes */}};
160 FunctionCallBuilder hostUnregisterCallBuilder = {
161 "mgpuMemHostUnregisterMemRef",
162 llvmVoidType,
163 {llvmIntPtrType /* intptr_t rank */,
164 llvmPointerType /* void *memrefDesc */,
165 llvmIntPtrType /* intptr_t elementSizeBytes */}};
166 FunctionCallBuilder allocCallBuilder = {
167 "mgpuMemAlloc",
168 llvmPointerType /* void * */,
169 {llvmIntPtrType /* intptr_t sizeBytes */,
170 llvmPointerType /* void *stream */,
171 llvmInt8Type /* bool isHostShared */}};
172 FunctionCallBuilder deallocCallBuilder = {
173 "mgpuMemFree",
174 llvmVoidType,
175 {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
176 FunctionCallBuilder memcpyCallBuilder = {
177 "mgpuMemcpy",
178 llvmVoidType,
179 {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
180 llvmIntPtrType /* intptr_t sizeBytes */,
181 llvmPointerType /* void *stream */}};
182 FunctionCallBuilder memset16CallBuilder = {
183 "mgpuMemset16",
184 llvmVoidType,
185 {llvmPointerType /* void *dst */,
186 llvmInt16Type /* unsigned short value */,
187 llvmIntPtrType /* intptr_t sizeBytes */,
188 llvmPointerType /* void *stream */}};
189 FunctionCallBuilder memset32CallBuilder = {
190 "mgpuMemset32",
191 llvmVoidType,
192 {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
193 llvmIntPtrType /* intptr_t sizeBytes */,
194 llvmPointerType /* void *stream */}};
195 FunctionCallBuilder setDefaultDeviceCallBuilder = {
196 "mgpuSetDefaultDevice",
197 llvmVoidType,
198 {llvmInt32Type /* uint32_t devIndex */}};
199 FunctionCallBuilder createDnVecCallBuilder = {
200 "mgpuCreateDnVec",
201 llvmPointerType,
202 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
203 llvmPointerType /* void *stream */}};
204 FunctionCallBuilder destroyDnVecCallBuilder = {
205 "mgpuDestroyDnVec",
206 llvmVoidType,
207 {llvmPointerType, llvmPointerType /* void *stream */}};
208 FunctionCallBuilder createDnMatCallBuilder = {
209 "mgpuCreateDnMat",
210 llvmPointerType,
211 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
212 llvmPointerType /* void *stream */}};
213 FunctionCallBuilder destroyDnMatCallBuilder = {
214 "mgpuDestroyDnMat",
215 llvmVoidType,
216 {llvmPointerType, llvmPointerType /* void *stream */}};
217 FunctionCallBuilder createCooCallBuilder = {
218 "mgpuCreateCoo",
219 llvmPointerType,
220 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
221 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
222 llvmPointerType /* void *stream */}};
223 FunctionCallBuilder createCooAoSCallBuilder = {
224 "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
225 llvmPointerType,
226 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
227 llvmPointerType, llvmInt32Type, llvmInt32Type,
228 llvmPointerType /* void *stream */}};
229 FunctionCallBuilder createCsrCallBuilder = {
230 "mgpuCreateCsr",
231 llvmPointerType,
232 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
233 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
234 llvmInt32Type, llvmPointerType /* void *stream */}};
235 FunctionCallBuilder createCscCallBuilder = {
236 "mgpuCreateCsc",
237 llvmPointerType,
238 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
239 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
240 llvmInt32Type, llvmPointerType /* void *stream */}};
241 FunctionCallBuilder createBsrCallBuilder = {
242 "mgpuCreateBsr",
243 llvmPointerType,
244 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
245 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
246 llvmInt32Type, llvmInt32Type, llvmInt32Type,
247 llvmPointerType /* void *stream */}};
248 FunctionCallBuilder destroySpMatCallBuilder = {
249 "mgpuDestroySpMat",
250 llvmVoidType,
251 {llvmPointerType, llvmPointerType /* void *stream */}};
252 FunctionCallBuilder spMVBufferSizeCallBuilder = {
253 "mgpuSpMVBufferSize",
254 llvmIntPtrType,
255 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
256 llvmInt32Type, llvmPointerType /* void *stream */}};
257 FunctionCallBuilder spMVCallBuilder = {
258 "mgpuSpMV",
259 llvmVoidType,
260 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
261 llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
262 FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
263 "mgpuSpMMBufferSize",
264 llvmIntPtrType,
265 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
266 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
267 FunctionCallBuilder createSpMMCallBuilder = {
268 "mgpuSpMM",
269 llvmVoidType,
270 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
271 llvmPointerType, llvmInt32Type, llvmPointerType,
272 llvmPointerType /* void *stream */}};
273 FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
274 "mgpuSDDMMBufferSize",
275 llvmIntPtrType,
276 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
277 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
278 FunctionCallBuilder createSDDMMCallBuilder = {
279 "mgpuSDDMM",
280 llvmVoidType,
281 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
282 llvmPointerType, llvmInt32Type, llvmPointerType,
283 llvmPointerType /* void *stream */}};
284 FunctionCallBuilder createLtDnMatCallBuilder = {
285 "mgpuCreateCuSparseLtDnMat",
286 llvmVoidType,
287 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
288 llvmInt32Type, llvmPointerType /* void *stream */}};
289 FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
290 "mgpuDestroyCuSparseLtSpMat",
291 llvmVoidType,
292 {llvmPointerType, llvmPointerType /* void *stream */}};
293 FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
294 "mgpuDestroyCuSparseLtDnMat",
295 llvmVoidType,
296 {llvmPointerType, llvmPointerType /* void *stream */}};
297 FunctionCallBuilder create2To4SpMatCallBuilder = {
298 "mgpuCusparseLtCreate2To4SpMat",
299 llvmVoidType,
300 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
301 llvmInt32Type, llvmPointerType /* void *stream */}};
302 FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
303 "mgpuCuSparseLtSpMMBufferSize",
304 llvmVoidType,
305 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
306 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
307 llvmPointerType /*void *stream*/}};
308 FunctionCallBuilder createCuSparseLtSpMMBuilder = {
309 "mgpuCuSparseLtSpMM",
310 llvmVoidType,
311 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
312 llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
313 FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
314 "mgpuSpGEMMCreateDescr",
315 llvmPointerType,
316 {llvmPointerType /*void *stream*/}};
317 FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
318 "mgpuSpGEMMDestroyDescr",
319 llvmVoidType,
320 {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
321 FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
322 "mgpuSpGEMMWorkEstimation",
323 llvmIntPtrType,
324 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
325 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
326 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
327 llvmPointerType /*void *stream*/}};
328 FunctionCallBuilder createSpGEMMComputeBuilder = {
329 "mgpuSpGEMMCompute",
330 llvmIntPtrType,
331 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
332 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
333 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
334 llvmPointerType /*void *stream*/}};
335 FunctionCallBuilder createSpGEMMCopyBuilder = {
336 "mgpuSpGEMMCopy",
337 llvmVoidType,
338 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
339 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
340 llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
341 FunctionCallBuilder createSpMatGetSizeBuilder = {
342 "mgpuSpMatGetSize",
343 llvmVoidType,
344 {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
345 llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
346 FunctionCallBuilder createSetCsrPointersBuilder = {
347 "mgpuSetCsrPointers",
348 llvmVoidType,
349 {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
350 llvmPointerType /*crd*/, llvmPointerType /*val*/,
351 llvmPointerType /*void *stream*/}};
352};
353
354/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
355/// call. Currently it supports CUDA and ROCm (HIP).
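/// For illustration, `gpu.host_register %ptr : memref<*xf32>` lowers roughly
/// to a call to mgpuMemHostRegisterMemRef(rank, memrefDescPtr,
/// elementSizeBytes), where the unranked memref descriptor is promoted to a
/// pointer by the type converter and the element size is computed on the host.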
356class ConvertHostRegisterOpToGpuRuntimeCallPattern
357 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
358public:
359 ConvertHostRegisterOpToGpuRuntimeCallPattern(
360 const LLVMTypeConverter &typeConverter)
361 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
362
363private:
364 LogicalResult
365 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
366 ConversionPatternRewriter &rewriter) const override;
367};
368
369class ConvertHostUnregisterOpToGpuRuntimeCallPattern
370 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
371public:
372 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
373 const LLVMTypeConverter &typeConverter)
374 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
375 }
376
377private:
378 LogicalResult
379 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
380 ConversionPatternRewriter &rewriter) const override;
381};
382
383/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
384/// call. Currently it supports CUDA and ROCm (HIP).
385class ConvertAllocOpToGpuRuntimeCallPattern
386 : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
387public:
388 ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
389 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
390
391private:
392 LogicalResult
393 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
394 ConversionPatternRewriter &rewriter) const override;
395};
396
397/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
398/// call. Currently it supports CUDA and ROCm (HIP).
399class ConvertDeallocOpToGpuRuntimeCallPattern
400 : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
401public:
402 ConvertDeallocOpToGpuRuntimeCallPattern(
403 const LLVMTypeConverter &typeConverter)
404 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
405
406private:
407 LogicalResult
408 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
409 ConversionPatternRewriter &rewriter) const override;
410};
411
412class ConvertAsyncYieldToGpuRuntimeCallPattern
413 : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
414public:
415 ConvertAsyncYieldToGpuRuntimeCallPattern(
416 const LLVMTypeConverter &typeConverter)
417 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
418
419private:
420 LogicalResult
421 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
422 ConversionPatternRewriter &rewriter) const override;
423};
424
425/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
426/// call. Currently it supports CUDA and ROCm (HIP).
427class ConvertWaitOpToGpuRuntimeCallPattern
428 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
429public:
430 ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
431 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
432
433private:
434 LogicalResult
435 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
436 ConversionPatternRewriter &rewriter) const override;
437};
438
439/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
440/// call. Currently it supports CUDA and ROCm (HIP).
441class ConvertWaitAsyncOpToGpuRuntimeCallPattern
442 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
443public:
444 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
445 const LLVMTypeConverter &typeConverter)
446 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
447
448private:
449 LogicalResult
450 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter) const override;
452};
453
/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad        -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate      -- creates a new compute stream on the GPU
/// * launchKernel      -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
465///
466/// Intermediate data structures are allocated on the stack.
467class ConvertLaunchFuncOpToGpuRuntimeCallPattern
468 : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
469public:
470 ConvertLaunchFuncOpToGpuRuntimeCallPattern(
471 const LLVMTypeConverter &typeConverter, StringRef gpuBinaryAnnotation,
472 bool kernelBarePtrCallConv, SymbolTable *cachedModuleTable)
473 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
474 gpuBinaryAnnotation(gpuBinaryAnnotation),
475 kernelBarePtrCallConv(kernelBarePtrCallConv),
476 cachedModuleTable(cachedModuleTable) {}
477
478private:
479 Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
480 OpBuilder &builder) const;
481 Value generateKernelNameConstant(StringRef moduleName, StringRef name,
482 Location loc, OpBuilder &builder) const;
483
484 LogicalResult
485 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
486 ConversionPatternRewriter &rewriter) const override;
487
488 llvm::SmallString<32> gpuBinaryAnnotation;
489 bool kernelBarePtrCallConv;
490 SymbolTable *cachedModuleTable;
491};
492
493class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
494 using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;
495
496 LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
497 PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
501 return success();
502 }
503};
504
505/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
506/// call. Currently it supports CUDA and ROCm (HIP).
507class ConvertMemcpyOpToGpuRuntimeCallPattern
508 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
509public:
510 ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
511 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
512
513private:
514 LogicalResult
515 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
516 ConversionPatternRewriter &rewriter) const override;
517};
518
519/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
520/// call. Currently it supports CUDA and ROCm (HIP).
521class ConvertMemsetOpToGpuRuntimeCallPattern
522 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
523public:
524 ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
525 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
526
527private:
528 LogicalResult
529 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
530 ConversionPatternRewriter &rewriter) const override;
531};
532
533/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
534/// Currently supports CUDA and ROCm (HIP)
535class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
536 : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
537public:
538 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
539 const LLVMTypeConverter &typeConverter)
540 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
541 typeConverter) {}
542
543 LogicalResult
544 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
545 ConversionPatternRewriter &rewriter) const override;
546};
547
/// Generic rewriting rule for operations on sparse matrices.
549/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
550#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \
551 class Convert##op_name##ToGpuRuntimeCallPattern \
552 : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \
553 public: \
554 Convert##op_name##ToGpuRuntimeCallPattern( \
555 const LLVMTypeConverter &typeConverter) \
556 : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \
557 \
558 private: \
559 LogicalResult \
560 matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \
561 ConversionPatternRewriter &rewriter) const override; \
562 };
563
564DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
565DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
566DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
567DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
568DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
569DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
570DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
571DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
572DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
573DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
574DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
575DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
576DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
577DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
578DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
579DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
580DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
581DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
582DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
583DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
584DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
585
586} // namespace
587
588void GpuToLLVMConversionPass::runOnOperation() {
589 MLIRContext *context = &getContext();
590 SymbolTable symbolTable = SymbolTable(getOperation());
591 LowerToLLVMOptions options(context);
592 options.useBarePtrCallConv = hostBarePtrCallConv;
593 RewritePatternSet patterns(context);
594 ConversionTarget target(*context);
595 target.addLegalDialect<LLVM::LLVMDialect>();
596 LLVMTypeConverter converter(context, options);
597
598 // Populate all patterns from all dialects that implement the
599 // `ConvertToLLVMPatternInterface` interface.
600 for (Dialect *dialect : context->getLoadedDialects()) {
601 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
602 if (!iface)
603 continue;
604 iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
605 }
606
607 // Preserve GPU modules if they have target attributes.
608 target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
      [](gpu::GPUModuleOp module) -> bool {
610 return module.getTargetsAttr() != nullptr;
611 });
  // Accept gpu.launch_func ops as legal if they refer to GPU modules with
  // targets and their operands have been lowered.
614 target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
615 [&](gpu::LaunchFuncOp op) -> bool {
616 auto module =
617 symbolTable.lookup<gpu::GPUModuleOp>(op.getKernelModuleName());
618 return converter.isLegal(op->getOperandTypes()) &&
619 converter.isLegal(op->getResultTypes()) &&
620 (module && module.getTargetsAttr() &&
621 !module.getTargetsAttr().empty());
622 });
623
624 // These aren't covered by the ConvertToLLVMPatternInterface right now.
625 populateVectorToLLVMConversionPatterns(converter, patterns);
626 populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
629 populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
630 kernelBarePtrCallConv, &symbolTable);
631
632 if (failed(
633 applyPartialConversion(getOperation(), target, std::move(patterns))))
634 signalPassFailure();
635}
636
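// Emits (or reuses) the llvm.func declaration for `functionName` at module
// scope and creates a call to it. For example (illustrative),
//   streamSynchronizeCallBuilder.create(loc, rewriter, {stream});
// declares `llvm.func @mgpuStreamSynchronize(!llvm.ptr)` if not already
// present and then emits
//   llvm.call @mgpuStreamSynchronize(%stream) : (!llvm.ptr) -> ()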
637LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
638 ArrayRef<Value> arguments) const {
639 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
640 auto function = [&] {
641 if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
642 return function;
643 return OpBuilder::atBlockEnd(module.getBody())
644 .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
645 }();
646 return builder.create<LLVM::CallOp>(loc, function, arguments);
647}
648
649// Corresponding to cusparseIndexType_t defined in cusparse.h.
650static int32_t getCuSparseIndexTypeFrom(Type type) {
  if (type.isInteger(16))
    return 1; // CUSPARSE_INDEX_16U
  if (type.isInteger(32))
654 return 2; // CUSPARSE_INDEX_32I
655 return 3; // CUSPARSE_INDEX_64I
656}
657
658static int32_t getCuSparseLtDataTypeFrom(Type type) {
  if (type.isF16())
    return 0; // CUSPARSE_COMPUTE_16F
  if (type.isInteger(32))
    return 1; // CUSPARSE_COMPUTE_32I
  llvm_unreachable("unsupported type");
  // TODO: add support for TF32
665}
666
667// Corresponding to cudaDataType_t defined in CUDA library_types.h.
668static int32_t getCuSparseDataTypeFrom(Type type) {
669 if (llvm::isa<ComplexType>(type)) {
670 // get the element type
671 auto elementType = cast<ComplexType>(type).getElementType();
672 if (elementType.isBF16())
673 return 15; // CUDA_C_16BF
674 if (elementType.isF16())
675 return 6; // CUDA_C_16F
676 if (elementType.isF32())
677 return 4; // CUDA_C_32F
678 if (elementType.isF64())
679 return 5; // CUDA_C_64F
680 if (elementType.isInteger(8))
681 return 7; // CUDA_C_8I
682 if (elementType.isInteger(16))
683 return 21; // CUDA_C_16I
684 if (elementType.isInteger(32))
685 return 11; // CUDA_C_32I
686 }
687 if (type.isBF16())
688 return 14; // CUDA_R_16BF
689 if (type.isF16())
690 return 2; // CUDA_R_16F
691 if (type.isF32())
692 return 0; // CUDA_R_32F
693 if (type.isF64())
694 return 1; // CUDA_R_64F
  if (type.isInteger(8))
    return 3; // CUDA_R_8I
  if (type.isInteger(16))
    return 20; // CUDA_R_16I
  if (type.isInteger(32))
700 return 10; // CUDA_R_32I
701
702 llvm_unreachable("unsupported element type");
703}
704
705static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
706 return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
707}
708
// TODO: Consider a compile-time diagnostic: cuSPARSELt currently does not work
// for CUDA architectures below 8.0 and triggers a runtime error in the CUDA
// program. It would be better to emit at least a warning when the target
// architecture is below 8.0 and the user still requests cuSPARSELt, or to
// disable the cuSPARSELt calls altogether when lowering the GPU sparse ops to
// LLVM calls for such architectures.
716static bool is2To4Sparsity(Value spMat) {
717 if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
718 return true;
719 if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
720 return false;
721 if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
722 return false;
723 if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
724 return false;
725 if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
726 return false;
727 if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
728 return false;
  // Print the spMat defining op.
  spMat.getDefiningOp()->print(llvm::errs());
731 llvm_unreachable("cannot find spmat def");
732}
733
734static bool isSpMMCusparseLtOp(Value op) {
735 for (Operation *user : op.getUsers()) {
736 auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    // If the SpMM's A operand has 2:4 (50%) sparsity, we should use cusparseLt.
738 if (!spmmOp)
739 continue;
740 if (is2To4Sparsity(spmmOp.getSpmatA()))
741 return true;
742 }
743 return false;
744}
745
746// Returns whether all operands are of LLVM type.
747static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
748 ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
754 return success();
755}
756
757static LogicalResult
758isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
759 gpu::AsyncOpInterface op) {
760 if (op.getAsyncDependencies().size() != 1)
761 return rewriter.notifyMatchFailure(
762 op, "Can only convert with exactly one async dependency.");
763
764 if (!op.getAsyncToken())
765 return rewriter.notifyMatchFailure(op, "Can convert only async version.");
766
767 return success();
768}
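// For example, `%token = gpu.dealloc async [%dep] %buf` satisfies this check,
// whereas the synchronous form `gpu.dealloc %buf`, or a form with several
// async dependencies, does not match.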
769
770LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
771 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
772 ConversionPatternRewriter &rewriter) const {
773 auto *op = hostRegisterOp.getOperation();
774 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
775 return failure();
776
777 Location loc = op->getLoc();
778
779 auto memRefType = hostRegisterOp.getValue().getType();
780 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
781 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
782
783 auto arguments = getTypeConverter()->promoteOperands(
784 loc, op->getOperands(), adaptor.getOperands(), rewriter);
785 arguments.push_back(elementSize);
786 hostRegisterCallBuilder.create(loc, rewriter, arguments);
787
  rewriter.eraseOp(op);
789 return success();
790}
791
792LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
793 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
794 ConversionPatternRewriter &rewriter) const {
795 Operation *op = hostUnregisterOp.getOperation();
796 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
797 return failure();
798
799 Location loc = op->getLoc();
800
801 auto memRefType = hostUnregisterOp.getValue().getType();
802 auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
803 auto elementSize = getSizeInBytes(loc, elementType, rewriter);
804
805 auto arguments = getTypeConverter()->promoteOperands(
806 loc, op->getOperands(), adaptor.getOperands(), rewriter);
807 arguments.push_back(elementSize);
808 hostUnregisterCallBuilder.create(loc, rewriter, arguments);
809
810 rewriter.eraseOp(op);
811 return success();
812}
813
814LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
815 gpu::AllocOp allocOp, OpAdaptor adaptor,
816 ConversionPatternRewriter &rewriter) const {
817
818 MemRefType memRefType = allocOp.getType();
819
820 if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
821 !isConvertibleAndHasIdentityMaps(memRefType))
822 return failure();
823
824 auto loc = allocOp.getLoc();
825
826 bool isShared = allocOp.getHostShared();
827
828 if (isShared && allocOp.getAsyncToken())
829 return rewriter.notifyMatchFailure(
830 allocOp, "Host Shared allocation cannot be done async");
831 if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
832 return failure();
833
834 // Get shape of the memref as values: static sizes are constant
835 // values and dynamic sizes are passed to 'alloc' as operands.
836 SmallVector<Value, 4> shape;
837 SmallVector<Value, 4> strides;
838 Value sizeBytes;
839 getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
840 shape, strides, sizeBytes);
841
842 // Allocate the underlying buffer and store a pointer to it in the MemRef
843 // descriptor.
844 auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
845 Value stream = adaptor.getAsyncDependencies().empty()
846 ? nullPtr
847 : adaptor.getAsyncDependencies().front();
848
849 auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
850 loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));
851
852 Value allocatedPtr =
853 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
854 .getResult();
855
856 // No alignment.
857 Value alignedPtr = allocatedPtr;
858
859 // Create the MemRef descriptor.
860 auto memRefDescriptor = this->createMemRefDescriptor(
861 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
862
863 if (allocOp.getAsyncToken()) {
864 // Async alloc: make dependent ops use the same stream.
865 rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
866 } else {
867 rewriter.replaceOp(allocOp, {memRefDescriptor});
868 }
869
870 return success();
871}
872
873LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
874 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
875 ConversionPatternRewriter &rewriter) const {
876 if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
877 failed(isAsyncWithOneDependency(rewriter, deallocOp)))
878 return failure();
879
880 Location loc = deallocOp.getLoc();
881
882 Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
884 Value stream = adaptor.getAsyncDependencies().front();
885 deallocCallBuilder.create(loc, rewriter, {pointer, stream});
886
887 rewriter.replaceOp(deallocOp, {stream});
888 return success();
889}
890
891static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
893}
894
// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
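// For example (illustrative), `async.yield %t : !gpu.async.token`, where %t
// was lowered to a stream %s, becomes roughly:
//   %e = llvm.call @mgpuEventCreate()
//   llvm.call @mgpuEventRecord(%e, %s)
//   llvm.call @mgpuStreamDestroy(%s)
//   async.yield %e : ...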
899LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
900 async::YieldOp yieldOp, OpAdaptor adaptor,
901 ConversionPatternRewriter &rewriter) const {
902 if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
903 return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
904
905 Location loc = yieldOp.getLoc();
906 SmallVector<Value, 4> newOperands(adaptor.getOperands());
907 llvm::SmallDenseSet<Value> streams;
908 for (auto &operand : yieldOp->getOpOperands()) {
909 if (!isGpuAsyncTokenType(operand.get()))
910 continue;
911 auto idx = operand.getOperandNumber();
912 auto stream = adaptor.getOperands()[idx];
913 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
914 eventRecordCallBuilder.create(loc, rewriter, {event, stream});
915 newOperands[idx] = event;
916 streams.insert(stream);
917 }
918 for (auto stream : streams)
919 streamDestroyCallBuilder.create(loc, rewriter, {stream});
920
921 rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
922 return success();
923}
924
925// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
926static bool isDefinedByCallTo(Value value, StringRef functionName) {
927 assert(isa<LLVM::LLVMPointerType>(value.getType()));
928 if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
929 return defOp.getCallee()->equals(functionName);
930 return false;
931}
932
// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed, i.e. it is
// assumed that they are not used afterwards or elsewhere; otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
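// For example (illustrative), `gpu.wait [%t]`, where %t was lowered to a
// stream %s, becomes roughly:
//   llvm.call @mgpuStreamSynchronize(%s)
//   llvm.call @mgpuStreamDestroy(%s)
// If %t was lowered to an event instead, mgpuEventSynchronize and
// mgpuEventDestroy are used.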
937LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
938 gpu::WaitOp waitOp, OpAdaptor adaptor,
939 ConversionPatternRewriter &rewriter) const {
940 if (waitOp.getAsyncToken())
941 return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
942
943 Location loc = waitOp.getLoc();
944
945 for (auto operand : adaptor.getOperands()) {
946 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
947 // The converted operand's definition created a stream.
948 streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
949 streamDestroyCallBuilder.create(loc, rewriter, {operand});
950 } else {
951 // Otherwise the converted operand is an event. This assumes that we use
952 // events in control flow code as well.
953 eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
954 eventDestroyCallBuilder.create(loc, rewriter, {operand});
955 }
956 }
957
  rewriter.eraseOp(waitOp);
959 return success();
960}
961
// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands are
// destroyed, i.e. it is assumed that they are not used afterwards or
// elsewhere; otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
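// For example (illustrative), `%t2 = gpu.wait async [%t1]`, where %t1 was
// lowered to a stream %s1, becomes roughly:
//   %e = llvm.call @mgpuEventCreate()
//   llvm.call @mgpuEventRecord(%e, %s1)   // right after the last use of %t1
//   %s2 = llvm.call @mgpuStreamCreate()
//   llvm.call @mgpuStreamWaitEvent(%s2, %e)
//   llvm.call @mgpuEventDestroy(%e)
// and %t2 is replaced by %s2.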
967LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
968 gpu::WaitOp waitOp, OpAdaptor adaptor,
969 ConversionPatternRewriter &rewriter) const {
970 if (!waitOp.getAsyncToken())
971 return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
972
973 Location loc = waitOp.getLoc();
974
975 auto insertionPoint = rewriter.saveInsertionPoint();
976 SmallVector<Value, 1> events;
977 for (auto pair :
978 llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
979 auto operand = std::get<1>(pair);
980 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
981 // The converted operand's definition created a stream. Insert an event
982 // into the stream just after the last use of the original token operand.
983 auto *defOp = std::get<0>(pair).getDefiningOp();
984 rewriter.setInsertionPointAfter(defOp);
985 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
986 eventRecordCallBuilder.create(loc, rewriter, {event, operand});
987 events.push_back(event);
988 } else {
989 // Otherwise the converted operand is an event. This assumes that we use
990 // events in control flow code as well.
991 events.push_back(operand);
992 }
993 }
  rewriter.restoreInsertionPoint(insertionPoint);
995 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
996 for (auto event : events)
997 streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
998 for (auto event : events)
999 eventDestroyCallBuilder.create(loc, rewriter, {event});
1000 rewriter.replaceOp(waitOp, {stream});
1001
1002 return success();
1003}
1004
1005// Creates a struct containing all kernel parameters on the stack and returns
1006// an array of type-erased pointers to the fields of the struct. The array can
1007// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
1008// The generated code is essentially as follows:
1009//
1010// %struct = alloca(sizeof(struct { Parameters... }))
1011// %array = alloca(NumParameters * sizeof(void *))
1012// for (i : [0, NumParameters))
1013// %fieldPtr = llvm.getelementptr %struct[0, i]
1014// llvm.store parameters[i], %fieldPtr
1015// %elementPtr = llvm.getelementptr %array[i]
1016// llvm.store %fieldPtr, %elementPtr
1017// return %array
1018Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
1019 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
1020 auto loc = launchOp.getLoc();
1021 auto numKernelOperands = launchOp.getNumKernelOperands();
1022 // Note: If `useBarePtrCallConv` is set in the type converter's options,
1023 // the value of `kernelBarePtrCallConv` will be ignored.
1024 SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
1025 loc, launchOp.getOperands().take_back(numKernelOperands),
1026 adaptor.getOperands().take_back(numKernelOperands), builder,
1027 /*useBarePtrCallConv=*/kernelBarePtrCallConv);
1028 auto numArguments = arguments.size();
1029 SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
1031 for (auto argument : arguments)
1032 argumentTypes.push_back(argument.getType());
1033 auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
1034 argumentTypes);
1035 auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 1);
1036 auto structPtr =
1037 builder.create<LLVM::AllocaOp>(loc, llvmPointerType, structType, one,
1038 /*alignment=*/0);
1039 auto arraySize =
1040 builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments);
1041 auto arrayPtr = builder.create<LLVM::AllocaOp>(
1042 loc, llvmPointerType, llvmPointerType, arraySize, /*alignment=*/0);
1043 for (const auto &en : llvm::enumerate(arguments)) {
1044 const auto index = static_cast<int32_t>(en.index());
1045 Value fieldPtr =
1046 builder.create<LLVM::GEPOp>(loc, llvmPointerType, structType, structPtr,
1047 ArrayRef<LLVM::GEPArg>{0, index});
1048 builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
1049 auto elementPtr =
1050 builder.create<LLVM::GEPOp>(loc, llvmPointerType, llvmPointerType,
1051 arrayPtr, ArrayRef<LLVM::GEPArg>{index});
1052 builder.create<LLVM::StoreOp>(loc, fieldPtr, elementPtr);
1053 }
1054 return arrayPtr;
1055}
1056
1057// Generates an LLVM IR dialect global that contains the name of the given
1058// kernel function as a C string, and returns a pointer to its beginning.
1059// The code is essentially:
1060//
1061// llvm.global constant @kernel_name("function_name\00")
1062// func(...) {
1063// %0 = llvm.addressof @kernel_name
1064// %1 = llvm.constant (0 : index)
1065// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
1066// }
1067Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
1068 StringRef moduleName, StringRef name, Location loc,
1069 OpBuilder &builder) const {
1070 // Make sure the trailing zero is included in the constant.
1071 std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');
1073
1074 std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
1076 return LLVM::createGlobalString(
1077 loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
1078 LLVM::Linkage::Internal);
1079}
1080
1081// Emits LLVM IR to launch a kernel function. Expects the module that contains
1082// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
1083// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
1084//
1085// %0 = call %binarygetter
1086// %1 = call %moduleLoad(%0)
1087// %2 = <see generateKernelNameConstant>
1088// %3 = call %moduleGetFunction(%1, %2)
1089// %4 = call %streamCreate()
1090// %5 = <see generateParamsArray>
1091// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
1092// call %streamSynchronize(%4)
1093// call %streamDestroy(%4)
1094// call %moduleUnload(%1)
1095//
1096// If the op is async, the stream corresponds to the (single) async dependency
1097// as well as the async token the op produces.
1098LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
1099 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
1100 ConversionPatternRewriter &rewriter) const {
1101 if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
1102 return failure();
1103
1104 if (launchOp.getAsyncDependencies().size() > 1)
1105 return rewriter.notifyMatchFailure(
1106 launchOp, "Cannot convert with more than one async dependency.");
1107
1108 // Fail when the synchronous version of the op has async dependencies. The
1109 // lowering destroys the stream, and we do not want to check that there is no
1110 // use of the stream after this op.
1111 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
1112 return rewriter.notifyMatchFailure(
1113 launchOp, "Cannot convert non-async op with async dependencies.");
1114
1115 Location loc = launchOp.getLoc();
1116
1117 // Create an LLVM global with CUBIN extracted from the kernel annotation and
1118 // obtain a pointer to the first byte in it.
1119 gpu::GPUModuleOp kernelModule;
1120 if (cachedModuleTable)
1121 kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>(
1122 launchOp.getKernelModuleName());
1123 else
1124 kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
1125 launchOp, launchOp.getKernelModuleName());
1126 assert(kernelModule && "expected a kernel module");
1127
1128 // If the module has Targets then just update the op operands.
1129 if (ArrayAttr targets = kernelModule.getTargetsAttr()) {
1130 Value stream = Value();
1131 if (!adaptor.getAsyncDependencies().empty())
1132 stream = adaptor.getAsyncDependencies().front();
1133 // If the async keyword is present and there are no dependencies, then a
1134 // stream must be created to pass to subsequent operations.
1135 else if (launchOp.getAsyncToken())
1136 stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
1137
1138 // Lower the kernel operands to match kernel parameters.
1139 // Note: If `useBarePtrCallConv` is set in the type converter's options,
1140 // the value of `kernelBarePtrCallConv` will be ignored.
1141 SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
1142 loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
1143 rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv);
1144
1145 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1146 if (launchOp.hasClusterSize()) {
1147 clusterSize =
1148 gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
1149 adaptor.getClusterSizeZ()};
1150 }
1151 rewriter.create<gpu::LaunchFuncOp>(
1152 launchOp.getLoc(), launchOp.getKernelAttr(),
1153 gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1154 adaptor.getGridSizeZ()},
1155 gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1156 adaptor.getBlockSizeZ()},
1157 adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
1158 if (launchOp.getAsyncToken())
1159 rewriter.replaceOp(launchOp, {stream});
1160 else
      rewriter.eraseOp(launchOp);
1162 return success();
1163 }
1164
1165 auto binaryAttr =
1166 kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
1167 if (!binaryAttr) {
1168 kernelModule.emitOpError()
1169 << "missing " << gpuBinaryAnnotation << " attribute";
1170 return failure();
1171 }
1172
1173 SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
1175 Value data =
1176 LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
1177 binaryAttr.getValue(), LLVM::Linkage::Internal);
1178
  // Pass the binary size; the SPIR-V path requires it.
1180 auto gpuBlob = binaryAttr.getValue();
1181 auto gpuBlobSize = rewriter.create<mlir::LLVM::ConstantOp>(
1182 loc, llvmInt64Type,
1183 mlir::IntegerAttr::get(llvmInt64Type,
1184 static_cast<int64_t>(gpuBlob.size())));
1185
1186 auto module =
1187 moduleLoadCallBuilder.create(loc, rewriter, {data, gpuBlobSize});
1188
  // Pass the number of kernel parameters to the runtime wrappers.
1190 auto paramsCount = rewriter.create<mlir::LLVM::ConstantOp>(
1191 loc, llvmInt64Type,
1192 mlir::IntegerAttr::get(
1193 llvmInt64Type,
1194 static_cast<int64_t>(launchOp.getNumKernelOperands())));
1195
1196 // Get the function from the module. The name corresponds to the name of
1197 // the kernel function.
1198 auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName().getValue(),
      launchOp.getKernelName().getValue(), loc, rewriter);
1201 auto function = moduleGetFunctionCallBuilder.create(
1202 loc, rewriter, {module.getResult(), kernelName});
1203 Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
1204 Value stream =
1205 adaptor.getAsyncDependencies().empty()
1206 ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult()
1207 : adaptor.getAsyncDependencies().front();
1208 // Create array of pointers to kernel arguments.
1209 auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
1210 auto nullpointer = rewriter.create<LLVM::ZeroOp>(loc, llvmPointerType);
1211 Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize()
1212 ? launchOp.getDynamicSharedMemorySize()
1213 : zero;
1214 launchKernelCallBuilder.create(
1215 loc, rewriter,
1216 {function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(),
1217 adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
1218 adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
1219 /*extra=*/nullpointer, paramsCount});
1220
1221 if (launchOp.getAsyncToken()) {
1222 // Async launch: make dependent ops use the same stream.
1223 rewriter.replaceOp(launchOp, {stream});
1224 } else {
1225 // Synchronize with host and destroy stream. This must be the stream created
1226 // above (with no other uses) because we check that the synchronous version
1227 // does not have any async dependencies.
1228 streamSynchronizeCallBuilder.create(loc, rewriter, stream);
1229 streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
1231 }
1232 moduleUnloadCallBuilder.create(loc, rewriter, module.getResult());
1233
1234 return success();
1235}
1236
1237static Value bitAndAddrspaceCast(Location loc,
1238 ConversionPatternRewriter &rewriter,
1239 LLVM::LLVMPointerType destinationType,
1240 Value sourcePtr,
1241 const LLVMTypeConverter &typeConverter) {
1242 auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
1243 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1244 sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1245 loc,
1246 LLVM::LLVMPointerType::get(rewriter.getContext(),
1247 destinationType.getAddressSpace()),
1248 sourcePtr);
1249 return sourcePtr;
1250}
1251
1252LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1253 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1254 ConversionPatternRewriter &rewriter) const {
1255 auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
1256
1257 if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
1258 !isConvertibleAndHasIdentityMaps(memRefType) ||
1259 failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
1260 return failure();
1261
1262 auto loc = memcpyOp.getLoc();
1263
1264 MemRefDescriptor srcDesc(adaptor.getSrc());
1265 Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
1266
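  // Compute the copy size in bytes with the null-pointer GEP idiom: the
  // address of element `numElements` in an array based at null equals the
  // buffer size in bytes.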
1267 Type elementPtrType = getElementPtrType(memRefType);
1268 Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
1269 Value gepPtr = rewriter.create<LLVM::GEPOp>(
1270 loc, elementPtrType,
1271 typeConverter->convertType(memRefType.getElementType()), nullPtr,
1272 numElements);
1273 auto sizeBytes =
1274 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
1275
1276 auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1277 srcDesc.alignedPtr(rewriter, loc),
1278 *getTypeConverter());
1279 auto dst = bitAndAddrspaceCast(
1280 loc, rewriter, llvmPointerType,
1281 MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
1282 *getTypeConverter());
1283
1284 auto stream = adaptor.getAsyncDependencies().front();
1285 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
1286
1287 rewriter.replaceOp(memcpyOp, {stream});
1288
1289 return success();
1290}
1291
1292LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1293 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1294 ConversionPatternRewriter &rewriter) const {
1295 auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
1296
1297 if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
1298 !isConvertibleAndHasIdentityMaps(memRefType) ||
1299 failed(isAsyncWithOneDependency(rewriter, memsetOp)))
1300 return failure();
1301
1302 auto loc = memsetOp.getLoc();
1303
1304 Type valueType = adaptor.getValue().getType();
1305 unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1306 // Ints and floats of 16 or 32 bit width are allowed.
1307 if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1308 return rewriter.notifyMatchFailure(
1309 memsetOp, "value must be a 16 or 32 bit int or float");
1310 }
1311
1312 unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1313 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1314
1315 MemRefDescriptor dstDesc(adaptor.getDst());
1316 Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
1317
1318 auto value =
1319 rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
1320 auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
1321 dstDesc.alignedPtr(rewriter, loc),
1322 *getTypeConverter());
1323
1324 auto stream = adaptor.getAsyncDependencies().front();
1325 FunctionCallBuilder builder =
1326 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1327 builder.create(loc, rewriter, {dst, value, numElements, stream});
1328
1329 rewriter.replaceOp(memsetOp, {stream});
1330 return success();
1331}
1332
1333LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1334 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1335 ConversionPatternRewriter &rewriter) const {
1336 Location loc = op.getLoc();
1337 auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
1338 {adaptor.getDevIndex()});
1339 rewriter.replaceOp(op, call);
1340 return success();
1341}
1342
1343template <typename T>
1344static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1345 Type llvmInt32Type = builder.getIntegerType(32);
1346 return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
1347 static_cast<int32_t>(tValue));
1348}
1349
1350template <typename T>
1351static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1352 Type llvmFloat32Type = builder.getF32Type();
1353 return builder.create<LLVM::ConstantOp>(
1354 loc, llvmFloat32Type,
1355 builder.getF32FloatAttr(static_cast<float>(tValue)));
1356}
1357
1358LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1359 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1360 ConversionPatternRewriter &rewriter) const {
1361 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1362 failed(isAsyncWithOneDependency(rewriter, op)))
1363 return failure();
1364 Location loc = op.getLoc();
1365 auto stream = adaptor.getAsyncDependencies().front();
1366 Value pTensor =
1367 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1368 Type dType = op.getMemref().getType().getElementType();
1369 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1370
1371 SmallVector<Value, 4> dims;
1372 for (Value dim : adaptor.getDims()) {
1373 dims.push_back(dim);
1374 }
1375
1376 Value handle;
1377 // TODO: For now, we track how the handle is used and lower it to cuSPARSE or
1378 // cuSPARSELt accordingly. If a block mixes cuSPARSE and cuSPARSELt uses, two
1379 // separate creation ops are required for the lowering to be correct. In the
1380 // future, one handle may be shared between cuSPARSE and cuSPARSELt in the
1381 // sparse tensor / GPU dialects. Use the cuSPARSELt create call if the dnmat
1382 // is used by an SpMM whose sparse operand has 2:4 sparsity.
1383 if (dims.size() == 2) {
1384 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1385 auto handleSz = rewriter.create<LLVM::ConstantOp>(
1386 loc, getIndexType(), rewriter.getIndexAttr(11032));
1387 handle = rewriter.create<LLVM::AllocaOp>(
1388 loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1389 handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1390
1391 createLtDnMatCallBuilder
1392 .create(loc, rewriter,
1393 {handle, dims[0], dims[1], pTensor, dtp, stream})
1394 .getResult();
1395 } else {
1396 handle =
1397 createDnMatCallBuilder
1398 .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1399 .getResult();
1400 }
1401 } else {
1402 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1403 handle = createDnVecCallBuilder
1404 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1405 .getResult();
1406 }
1407 rewriter.replaceOp(op, {handle, stream});
1408 return success();
1409}
1410
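// The destroy pattern mirrors the create pattern above: the rank and the
// cusparse-vs-cusparseLt choice are recovered from the defining
// gpu.create_dn_tensor op rather than from the opaque handle itself.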
1411LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1412 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1413 ConversionPatternRewriter &rewriter) const {
1414 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1415 failed(isAsyncWithOneDependency(rewriter, op)))
1416 return failure();
1417 Location loc = op.getLoc();
1418 auto stream = adaptor.getAsyncDependencies().front();
1419 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1420 SmallVector<Value, 4> dims;
1421 for (Value dim : definingOp.getDims()) {
1422 dims.push_back(dim);
1423 }
1424 if (dims.size() == 2) {
1425 // Use the cusparseLt destroy call if the dnmat is used together with a
1426 // 2:4 sparse spmat.
1427 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1428 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1429 {adaptor.getDnTensor(), stream});
1430 } else {
1431 destroyDnMatCallBuilder.create(loc, rewriter,
1432 {adaptor.getDnTensor(), stream});
1433 }
1434 } else {
1435 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1436 destroyDnVecCallBuilder.create(loc, rewriter,
1437 {adaptor.getDnTensor(), stream});
1438 }
1439 rewriter.replaceOp(op, {stream});
1440 return success();
1441}
1442
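// The sparse matrix creation patterns below (COO, COO AoS, CSR, CSC, BSR) all
// follow the same recipe: extract raw pointers from the index/value memref
// descriptors, encode index and data element types as cuSPARSE enums, and emit
// one runtime call whose returned handle replaces the op together with the
// stream.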
1443LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1444 gpu::CreateCooOp op, OpAdaptor adaptor,
1445 ConversionPatternRewriter &rewriter) const {
1446 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1447 failed(isAsyncWithOneDependency(rewriter, op)))
1448 return failure();
1449 Location loc = op.getLoc();
1450 auto stream = adaptor.getAsyncDependencies().front();
1451 Value pRowIdxs =
1452 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1453 Value pColIdxs =
1454 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1455 Value pValues =
1456 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1457 Type iType =
1458 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1459 Type dType =
1460 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1461 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1462 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1463 auto handle =
1464 createCooCallBuilder
1465 .create(loc, rewriter,
1466 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1467 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1468 .getResult();
1469 rewriter.replaceOp(op, {handle, stream});
1470 return success();
1471}
1472
1473LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1474 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1475 ConversionPatternRewriter &rewriter) const {
1476 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1477 failed(isAsyncWithOneDependency(rewriter, op)))
1478 return failure();
1479 Location loc = op.getLoc();
1480 auto stream = adaptor.getAsyncDependencies().front();
1481 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
1482 Value pValues =
1483 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1484 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1485 Type dType =
1486 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1487 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1488 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1489 auto handle =
1490 createCooAoSCallBuilder
1491 .create(loc, rewriter,
1492 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1493 pIdxs, pValues, itp, dtp, stream})
1494 .getResult();
1495 rewriter.replaceOp(op, {handle, stream});
1496 return success();
1497}
1498
1499LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1500 gpu::CreateCsrOp op, OpAdaptor adaptor,
1501 ConversionPatternRewriter &rewriter) const {
1502 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1503 failed(isAsyncWithOneDependency(rewriter, op)))
1504 return failure();
1505 Location loc = op.getLoc();
1506 auto stream = adaptor.getAsyncDependencies().front();
1507 Value pRowPos =
1508 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
1509 Value pColIdxs =
1510 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
1511 Value pValues =
1512 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1513 Type pType =
1514 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1515 Type iType =
1516 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1517 Type dType =
1518 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1519 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1520 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1521 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1522 auto handle =
1523 createCsrCallBuilder
1524 .create(loc, rewriter,
1525 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1526 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1527 .getResult();
1528 rewriter.replaceOp(op, {handle, stream});
1529 return success();
1530}
1531
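// 2:4 structured sparsity goes through cusparseLt rather than cuSPARSE. The
// cusparseLt matmul descriptor is an opaque fixed-size blob, so the handle is
// alloca'ed on the host stack here and filled in by the runtime call instead
// of being returned by it.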
1532LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1533 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1534 ConversionPatternRewriter &rewriter) const {
1535 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1536 failed(isAsyncWithOneDependency(rewriter, op)))
1537 return failure();
1538 Location loc = op.getLoc();
1539 auto stream = adaptor.getAsyncDependencies().front();
1540 Value pMat =
1541 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1542 Type dType =
1543 llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
1544 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1545
1546 // CUDA runner asserts the size is 44104 bytes.
1547 auto handleSz = rewriter.create<LLVM::ConstantOp>(
1548 loc, getIndexType(), rewriter.getIndexAttr(44104));
1549 Value handle = rewriter.create<LLVM::AllocaOp>(
1550 loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1551 handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1552
1553 create2To4SpMatCallBuilder
1554 .create(loc, rewriter,
1555 {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1556 .getResult();
1557 rewriter.replaceOp(op, {handle, stream});
1558 return success();
1559}
1560
1561LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1562 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1563 ConversionPatternRewriter &rewriter) const {
1564 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1565 failed(isAsyncWithOneDependency(rewriter, op)))
1566 return failure();
1567 Location loc = op.getLoc();
1568 auto stream = adaptor.getAsyncDependencies().front();
1569 // Use the cusparseLt destroy call if the spmat has 2:4 sparsity.
1570 if (is2To4Sparsity(op.getSpmat())) {
1571 destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
1572 {adaptor.getSpmat(), stream});
1573
1574 } else {
1575 destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
1576 }
1577 rewriter.replaceOp(op, {stream});
1578 return success();
1579}
1580
1581LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1582 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1583 ConversionPatternRewriter &rewriter) const {
1584 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1585 failed(isAsyncWithOneDependency(rewriter, op)))
1586 return failure();
1587 Location loc = op.getLoc();
1588 auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
1589 auto computeType = genConstInt32From(
1590 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1591 auto stream = adaptor.getAsyncDependencies().front();
1592 auto bufferSize = spMVBufferSizeCallBuilder
1593 .create(loc, rewriter,
1594 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1595 adaptor.getDnY(), computeType, stream})
1596 .getResult();
1597 rewriter.replaceOp(op, {bufferSize, stream});
1598 return success();
1599}
1600
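// gpu.spmv lowers to a single runtime call; the external workspace sized by
// the buffer-size op above is passed in as a raw pointer.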
1601LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1602 gpu::SpMVOp op, OpAdaptor adaptor,
1603 ConversionPatternRewriter &rewriter) const {
1604 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1605 failed(isAsyncWithOneDependency(rewriter, op)))
1606 return failure();
1607 Location loc = op.getLoc();
1608 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1609 auto computeType = genConstInt32From(
1610 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1611 auto stream = adaptor.getAsyncDependencies().front();
1612 Value pBuf =
1613 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1614 spMVCallBuilder.create(loc, rewriter,
1615 {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1616 adaptor.getDnY(), computeType, pBuf, stream});
1617 rewriter.replaceOp(op, {stream});
1618 return success();
1619}
1620
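// For SpMM the buffer-size query has two shapes: the cuSPARSE path returns a
// single size, while the cusparseLt path fills a small stack array with three
// sizes (one per workspace buffer) that are loaded back and returned as three
// results.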
1621LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1622 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1623 ConversionPatternRewriter &rewriter) const {
1624 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1625 failed(isAsyncWithOneDependency(rewriter, op)))
1626 return failure();
1627 Location loc = op.getLoc();
1628 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1629 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1630 auto stream = adaptor.getAsyncDependencies().front();
1631 Value bufferSize;
1632 if (is2To4Sparsity(op.getSpmatA())) {
1633 auto pruneFlag =
1634 genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
1635 auto computeType = genConstInt32From(
1636 rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
1637 auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1638 rewriter.getIndexAttr(3));
1639 auto bufferSize = rewriter.create<LLVM::AllocaOp>(
1640 loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
1641 createCuSparseLtSpMMBufferSizeBuilder
1642 .create(loc, rewriter,
1643 {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1644 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1645 pruneFlag, stream})
1646 .getResult();
1647
1648 auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
1649 loc, llvmPointerType, llvmPointerType, bufferSize,
1650 ValueRange{rewriter.create<LLVM::ConstantOp>(
1651 loc, getIndexType(), rewriter.getIndexAttr(1))});
1652 auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
1653 loc, llvmPointerType, llvmPointerType, bufferSize,
1654 ValueRange{rewriter.create<LLVM::ConstantOp>(
1655 loc, getIndexType(), rewriter.getIndexAttr(2))});
1656 auto bufferSize0 =
1657 rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
1658 auto bufferSize1 =
1659 rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
1660 auto bufferSize2 =
1661 rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
1662
1663 rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
1664 } else {
1665 auto computeType = genConstInt32From(
1666 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1667 bufferSize =
1668 createSpMMBufferSizeCallBuilder
1669 .create(loc, rewriter,
1670 {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1671 adaptor.getDnmatC(), computeType, stream})
1672 .getResult();
1673 rewriter.replaceOp(op, {bufferSize, stream});
1674 }
1675 return success();
1676}
1677
1678LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1679 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1680 ConversionPatternRewriter &rewriter) const {
1681 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1682 failed(isAsyncWithOneDependency(rewriter, op)))
1683 return failure();
1684 Location loc = op.getLoc();
1685 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1686 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1687 auto computeType = genConstInt32From(
1688 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1689 auto stream = adaptor.getAsyncDependencies().front();
1690 auto bufferSize =
1691 createSDDMMBufferSizeCallBuilder
1692 .create(loc, rewriter,
1693 {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1694 adaptor.getSpmatC(), computeType, stream})
1695 .getResult();
1696 rewriter.replaceOp(op, {bufferSize, stream});
1697 return success();
1698}
1699
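// The SpMM lowering consumes the workspace(s) sized above: three buffers on
// the cusparseLt path, a single buffer on the plain cuSPARSE path.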
1700LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1701 gpu::SpMMOp op, OpAdaptor adaptor,
1702 ConversionPatternRewriter &rewriter) const {
1703 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1704 failed(isAsyncWithOneDependency(rewriter, op)))
1705 return failure();
1706 Location loc = op.getLoc();
1707 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1708 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1709 auto computeType = genConstInt32From(
1710 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1711
1712 auto stream = adaptor.getAsyncDependencies().front();
1713
1714 // Lower to cusparseLt if applicable
1715 if (is2To4Sparsity(op.getSpmatA())) {
1716 SmallVector<Value> pBufs;
1717 for (Value buffer : adaptor.getBuffers()) {
1718 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
1719 pBufs.push_back(pBuf);
1720 }
1721 createCuSparseLtSpMMBuilder.create(
1722 loc, rewriter,
1723 {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1724 pBufs[0], pBufs[1], pBufs[2], stream});
1725 } else {
1726 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1727 .allocatedPtr(rewriter, loc);
1728 createSpMMCallBuilder.create(loc, rewriter,
1729 {modeA, modeB, adaptor.getSpmatA(),
1730 adaptor.getDnmatB(), adaptor.getDnmatC(),
1731 computeType, pBuf, stream});
1732 }
1733 rewriter.replaceOp(op, {stream});
1734 return success();
1735}
1736
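// All GPU handle-like types (async tokens, sparse environment and matrix
// handles) lower to an opaque LLVM pointer; the helper below registers that
// conversion for one such type.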
1737template <typename T>
1738static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
1739 converter.addConversion([&converter](T) -> Type {
1740 return LLVM::LLVMPointerType::get(&converter.getContext());
1741 });
1742}
1743
1744LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1745 gpu::SDDMMOp op, OpAdaptor adaptor,
1746 ConversionPatternRewriter &rewriter) const {
1747 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1748 failed(isAsyncWithOneDependency(rewriter, op)))
1749 return failure();
1750 Location loc = op.getLoc();
1751 auto computeType = genConstInt32From(
1752 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1753 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1754 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1755 auto stream = adaptor.getAsyncDependencies().front();
1756 Value pBuf =
1757 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1758 createSDDMMCallBuilder.create(loc, rewriter,
1759 {modeA, modeB, adaptor.getDnmatA(),
1760 adaptor.getDnmatB(), adaptor.getSpmatC(),
1761 computeType, pBuf, stream});
1762 rewriter.replaceOp(op, {stream});
1763 return success();
1764}
1765
1766LogicalResult
1767ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1768 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1769 ConversionPatternRewriter &rewriter) const {
1770 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1771 failed(isAsyncWithOneDependency(rewriter, op)))
1772 return failure();
1773 Location loc = op.getLoc();
1774 auto stream = adaptor.getAsyncDependencies().front();
1775 Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
1776 .getResult();
1777 rewriter.replaceOp(op, {descr, stream});
1778 return success();
1779}
1780
1781LogicalResult
1782ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1783 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1784 ConversionPatternRewriter &rewriter) const {
1785 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1786 failed(isAsyncWithOneDependency(rewriter, op)))
1787 return failure();
1788 Location loc = op.getLoc();
1789 auto stream = adaptor.getAsyncDependencies().front();
1790 createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
1791 {adaptor.getDesc(), stream});
1792 rewriter.replaceOp(op, {stream});
1793 return success();
1794}
1795
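// Work estimation and compute share one op, distinguished by its kind
// attribute; both paths emit a runtime call that returns the (possibly
// updated) buffer size, which replaces the op result together with the stream.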
1796LogicalResult
1797ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1798 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1799 ConversionPatternRewriter &rewriter) const {
1800 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1801 failed(isAsyncWithOneDependency(rewriter, op)))
1802 return failure();
1803 Location loc = op.getLoc();
1804 auto computeType = genConstInt32From(
1805 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1806 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1807 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1808 auto stream = adaptor.getAsyncDependencies().front();
1809
1810 Value pBuf =
1811 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1812 Value bufferSizeNew;
1813
1814 if (adaptor.getKind() ==
1815 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1816 bufferSizeNew =
1817 createSpGEMMWorkEstimationBuilder
1818 .create(loc, rewriter,
1819 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1820 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1821 adaptor.getBufferSz(), pBuf, stream})
1822 .getResult();
1823 } else {
1824 bufferSizeNew =
1825 createSpGEMMComputeBuilder
1826 .create(loc, rewriter,
1827 {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1828 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1829 adaptor.getBufferSz(), pBuf, stream})
1830 .getResult();
1831 }
1832 rewriter.replaceOp(op, {bufferSizeNew, stream});
1833 return success();
1834}
1835
1836LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1837 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1838 ConversionPatternRewriter &rewriter) const {
1839 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1840 failed(isAsyncWithOneDependency(rewriter, op)))
1841 return failure();
1842 Location loc = op.getLoc();
1843 auto computeType = genConstInt32From(
1844 rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
1845 auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
1846 auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
1847 auto stream = adaptor.getAsyncDependencies().front();
1848 createSpGEMMCopyBuilder.create(loc, rewriter,
1849 {adaptor.getDesc(), modeA, modeB,
1850 adaptor.getSpmatA(), adaptor.getSpmatB(),
1851 adaptor.getSpmatC(), computeType, stream});
1852 rewriter.replaceOp(op, {stream});
1853 return success();
1854}
1855
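// gpu.spmat_get_size has no single-result runtime analogue, so three i64
// slots are alloca'ed, passed to the runtime call as out-parameters, and
// loaded back as the rows/cols/nnz results.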
1856LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1857 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1858 ConversionPatternRewriter &rewriter) const {
1859 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1860 failed(isAsyncWithOneDependency(rewriter, op)))
1861 return failure();
1862 Location loc = op.getLoc();
1863 auto stream = adaptor.getAsyncDependencies().front();
1864
1865 auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1866 rewriter.getIndexAttr(3));
1867 auto buffer = rewriter.create<LLVM::AllocaOp>(
1868 loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);
1869
1870 auto rowsPtr = rewriter.create<LLVM::GEPOp>(
1871 loc, llvmPointerType, llvmPointerType, buffer,
1872 ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1873 rewriter.getIndexAttr(0))});
1874 auto colsPtr = rewriter.create<LLVM::GEPOp>(
1875 loc, llvmPointerType, llvmPointerType, buffer,
1876 ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1877 rewriter.getIndexAttr(1))});
1878 auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
1879 loc, llvmPointerType, llvmPointerType, buffer,
1880 ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
1881 rewriter.getIndexAttr(2))});
1882 createSpMatGetSizeBuilder.create(
1883 loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1884 auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
1885 auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
1886 auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
1887
1888 rewriter.replaceOp(op, {rows, cols, nnzs, stream});
1889 return success();
1890}
1891
1892LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1893 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1894 ConversionPatternRewriter &rewriter) const {
1895 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1896 failed(isAsyncWithOneDependency(rewriter, op)))
1897 return failure();
1898 Location loc = op.getLoc();
1899 auto stream = adaptor.getAsyncDependencies().front();
1900 Value pPos =
1901 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
1902 Value pCrd =
1903 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
1904 Value pVal =
1905 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1906 createSetCsrPointersBuilder.create(
1907 loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1908 rewriter.replaceOp(op, {stream});
1909 return success();
1910}
1911
1912LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1913 gpu::CreateCscOp op, OpAdaptor adaptor,
1914 ConversionPatternRewriter &rewriter) const {
1915 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1916 failed(isAsyncWithOneDependency(rewriter, op)))
1917 return failure();
1918 Location loc = op.getLoc();
1919 auto stream = adaptor.getAsyncDependencies().front();
1920 Value pColPos =
1921 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
1922 Value pRowIdxs =
1923 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
1924 Value pValues =
1925 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1926 Type pType =
1927 llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
1928 Type iType =
1929 llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
1930 Type dType =
1931 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1932 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1933 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1934 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1935 auto handle =
1936 createCscCallBuilder
1937 .create(loc, rewriter,
1938 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1939 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1940 .getResult();
1941 rewriter.replaceOp(op, {handle, stream});
1942 return success();
1943}
1944
1945LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1946 gpu::CreateBsrOp op, OpAdaptor adaptor,
1947 ConversionPatternRewriter &rewriter) const {
1948 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1949 failed(isAsyncWithOneDependency(rewriter, op)))
1950 return failure();
1951 Location loc = op.getLoc();
1952 auto stream = adaptor.getAsyncDependencies().front();
1953 Value pRowPos =
1954 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
1955 Value pColIdxs =
1956 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
1957 Value pValues =
1958 MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1959 Type pType =
1960 llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
1961 Type iType =
1962 llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
1963 Type dType =
1964 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1965 auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
1966 auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
1967 auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1968 auto handle =
1969 createBsrCallBuilder
1970 .create(loc, rewriter,
1971 {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1972 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1973 pColIdxs, pValues, ptp, itp, dtp, stream})
1974 .getResult();
1975 rewriter.replaceOp(op, {handle, stream});
1976 return success();
1977}
1978
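// Registers every pattern defined above with the given pattern set. A minimal
// usage sketch (assuming an existing LLVMTypeConverter and RewritePatternSet;
// the annotation string and flags shown are placeholders, and nullptr is
// passed for the optional symbol table cache):
//
//   populateGpuToLLVMConversionPatterns(converter, patterns,
//                                       /*gpuBinaryAnnotation=*/"gpu.binary",
//                                       /*kernelBarePtrCallConv=*/false,
//                                       /*cachedModuleTable=*/nullptr);
//
// followed by applyPartialConversion with an LLVMConversionTarget.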
1979void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
1980 RewritePatternSet &patterns,
1981 StringRef gpuBinaryAnnotation,
1982 bool kernelBarePtrCallConv,
1983 SymbolTable *cachedModuleTable) {
1984 addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1985 addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1986 addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1987 addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1988
1989 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1990 ConvertDeallocOpToGpuRuntimeCallPattern,
1991 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1992 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1993 ConvertMemcpyOpToGpuRuntimeCallPattern,
1994 ConvertMemsetOpToGpuRuntimeCallPattern,
1995 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1996 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1997 ConvertWaitOpToGpuRuntimeCallPattern,
1998 ConvertAsyncYieldToGpuRuntimeCallPattern,
1999 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
2000 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
2001 ConvertCreateCooOpToGpuRuntimeCallPattern,
2002 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
2003 ConvertCreateCsrOpToGpuRuntimeCallPattern,
2004 ConvertCreateCscOpToGpuRuntimeCallPattern,
2005 ConvertCreateBsrOpToGpuRuntimeCallPattern,
2006 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
2007 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
2008 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
2009 ConvertSpMVOpToGpuRuntimeCallPattern,
2010 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
2011 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
2012 ConvertSpMMOpToGpuRuntimeCallPattern,
2013 ConvertSDDMMOpToGpuRuntimeCallPattern,
2014 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
2015 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
2016 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
2017 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
2018 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
2019 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
2020 patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
2021 converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
2022 patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
2023}
2024
