1//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert gpu.launch_func op into a sequence of
// GPU runtime calls. As most GPU runtimes do not have a stable published
11// ABI, this pass uses a slim runtime layer that builds on top of the public
12// API from GPU runtime headers.
13//
14//===----------------------------------------------------------------------===//
15
16#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
17
18#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
19#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
20#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
21#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
22#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
23#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
24#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
25#include "mlir/Conversion/LLVMCommon/Pattern.h"
26#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
27#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
28#include "mlir/Dialect/Async/IR/Async.h"
29#include "mlir/Dialect/GPU/IR/GPUDialect.h"
30#include "mlir/Dialect/GPU/Transforms/Passes.h"
31#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
32#include "mlir/Dialect/MemRef/IR/MemRef.h"
33#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
34#include "mlir/IR/Attributes.h"
35#include "mlir/IR/Builders.h"
36#include "mlir/IR/BuiltinOps.h"
37#include "mlir/IR/BuiltinTypes.h"
38#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
39
40#include "llvm/ADT/STLExtras.h"
41
42#define DEBUG_TYPE "gpu-to-llvm"
43
44namespace mlir {
45#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
46#include "mlir/Conversion/Passes.h.inc"
47} // namespace mlir
48
49using namespace mlir;
50
51namespace {
52class GpuToLLVMConversionPass
53 : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
54public:
55 using Base::Base;
56 void getDependentDialects(DialectRegistry &registry) const final {
57 Base::getDependentDialects(registry);
58 registerConvertToLLVMDependentDialectLoading(registry);
59 }
60 // Run the dialect converter on the module.
61 void runOnOperation() override;
62};
63
64template <typename OpTy>
65class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
66public:
67 explicit ConvertOpToGpuRuntimeCallPattern(
68 const LLVMTypeConverter &typeConverter)
69 : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
70
71protected:
72 Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
73 MemRefType type, MemRefDescriptor desc) const {
74 Type indexType = ConvertToLLVMPattern::getIndexType();
75 if (type.hasStaticShape())
76 return ConvertToLLVMPattern::createIndexAttrConstant(
77 builder&: rewriter, loc, resultType: indexType, value: type.getNumElements());
78 // Compute the number of elements by multiplying all the dim sizes.
79 uint64_t rank = type.getRank();
80 Value numElements = desc.size(builder&: rewriter, loc, /*pos=*/0);
81 for (unsigned i = 1; i < rank; i++)
82 numElements = rewriter.create<LLVM::MulOp>(
83 location: loc, args&: numElements, args: desc.size(builder&: rewriter, loc, /*pos=*/i));
84 return numElements;
85 }
86
87 MLIRContext *context = &this->getTypeConverter()->getContext();
88
89 Type llvmVoidType = LLVM::LLVMVoidType::get(ctx: context);
90 LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
91 Type llvmInt8Type = IntegerType::get(context, width: 8);
92 Type llvmInt16Type = IntegerType::get(context, width: 16);
93 Type llvmInt32Type = IntegerType::get(context, width: 32);
94 Type llvmInt64Type = IntegerType::get(context, width: 64);
95 Type llvmFloat32Type = Float32Type::get(context);
96 Type llvmIntPtrType = IntegerType::get(
97 context, width: this->getTypeConverter()->getPointerBitwidth(0));
98
99 FunctionCallBuilder streamCreateCallBuilder = {
100 "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
101 FunctionCallBuilder streamDestroyCallBuilder = {
102 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
103 FunctionCallBuilder streamSynchronizeCallBuilder = {
104 "mgpuStreamSynchronize",
105 llvmVoidType,
106 {llvmPointerType /* void *stream */}};
107 FunctionCallBuilder streamWaitEventCallBuilder = {
108 "mgpuStreamWaitEvent",
109 llvmVoidType,
110 {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
111 FunctionCallBuilder eventCreateCallBuilder = {
112 "mgpuEventCreate", llvmPointerType /* void *event */, {}};
113 FunctionCallBuilder eventDestroyCallBuilder = {
114 "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
115 FunctionCallBuilder eventSynchronizeCallBuilder = {
116 "mgpuEventSynchronize",
117 llvmVoidType,
118 {llvmPointerType /* void *event */}};
119 FunctionCallBuilder eventRecordCallBuilder = {
120 "mgpuEventRecord",
121 llvmVoidType,
122 {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
123 FunctionCallBuilder hostRegisterCallBuilder = {
124 "mgpuMemHostRegisterMemRef",
125 llvmVoidType,
126 {llvmIntPtrType /* intptr_t rank */,
127 llvmPointerType /* void *memrefDesc */,
128 llvmIntPtrType /* intptr_t elementSizeBytes */}};
129 FunctionCallBuilder hostUnregisterCallBuilder = {
130 "mgpuMemHostUnregisterMemRef",
131 llvmVoidType,
132 {llvmIntPtrType /* intptr_t rank */,
133 llvmPointerType /* void *memrefDesc */,
134 llvmIntPtrType /* intptr_t elementSizeBytes */}};
135 FunctionCallBuilder allocCallBuilder = {
136 "mgpuMemAlloc",
137 llvmPointerType /* void * */,
138 {llvmIntPtrType /* intptr_t sizeBytes */,
139 llvmPointerType /* void *stream */,
140 llvmInt8Type /* bool isHostShared */}};
141 FunctionCallBuilder deallocCallBuilder = {
142 "mgpuMemFree",
143 llvmVoidType,
144 {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
145 FunctionCallBuilder memcpyCallBuilder = {
146 "mgpuMemcpy",
147 llvmVoidType,
148 {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
149 llvmIntPtrType /* intptr_t sizeBytes */,
150 llvmPointerType /* void *stream */}};
151 FunctionCallBuilder memset16CallBuilder = {
152 "mgpuMemset16",
153 llvmVoidType,
154 {llvmPointerType /* void *dst */,
155 llvmInt16Type /* unsigned short value */,
156 llvmIntPtrType /* intptr_t sizeBytes */,
157 llvmPointerType /* void *stream */}};
158 FunctionCallBuilder memset32CallBuilder = {
159 "mgpuMemset32",
160 llvmVoidType,
161 {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
162 llvmIntPtrType /* intptr_t sizeBytes */,
163 llvmPointerType /* void *stream */}};
164 FunctionCallBuilder setDefaultDeviceCallBuilder = {
165 "mgpuSetDefaultDevice",
166 llvmVoidType,
167 {llvmInt32Type /* uint32_t devIndex */}};
168 FunctionCallBuilder createDnVecCallBuilder = {
169 "mgpuCreateDnVec",
170 llvmPointerType,
171 {llvmIntPtrType, llvmPointerType, llvmInt32Type,
172 llvmPointerType /* void *stream */}};
173 FunctionCallBuilder destroyDnVecCallBuilder = {
174 "mgpuDestroyDnVec",
175 llvmVoidType,
176 {llvmPointerType, llvmPointerType /* void *stream */}};
177 FunctionCallBuilder createDnMatCallBuilder = {
178 "mgpuCreateDnMat",
179 llvmPointerType,
180 {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
181 llvmPointerType /* void *stream */}};
182 FunctionCallBuilder destroyDnMatCallBuilder = {
183 "mgpuDestroyDnMat",
184 llvmVoidType,
185 {llvmPointerType, llvmPointerType /* void *stream */}};
186 FunctionCallBuilder createCooCallBuilder = {
187 "mgpuCreateCoo",
188 llvmPointerType,
189 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
190 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
191 llvmPointerType /* void *stream */}};
192 FunctionCallBuilder createCooAoSCallBuilder = {
193 "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
194 llvmPointerType,
195 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
196 llvmPointerType, llvmInt32Type, llvmInt32Type,
197 llvmPointerType /* void *stream */}};
198 FunctionCallBuilder createCsrCallBuilder = {
199 "mgpuCreateCsr",
200 llvmPointerType,
201 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
202 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
203 llvmInt32Type, llvmPointerType /* void *stream */}};
204 FunctionCallBuilder createCscCallBuilder = {
205 "mgpuCreateCsc",
206 llvmPointerType,
207 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
208 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
209 llvmInt32Type, llvmPointerType /* void *stream */}};
210 FunctionCallBuilder createBsrCallBuilder = {
211 "mgpuCreateBsr",
212 llvmPointerType,
213 {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
214 llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
215 llvmInt32Type, llvmInt32Type, llvmInt32Type,
216 llvmPointerType /* void *stream */}};
217 FunctionCallBuilder destroySpMatCallBuilder = {
218 "mgpuDestroySpMat",
219 llvmVoidType,
220 {llvmPointerType, llvmPointerType /* void *stream */}};
221 FunctionCallBuilder spMVBufferSizeCallBuilder = {
222 "mgpuSpMVBufferSize",
223 llvmIntPtrType,
224 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
225 llvmInt32Type, llvmPointerType /* void *stream */}};
226 FunctionCallBuilder spMVCallBuilder = {
227 "mgpuSpMV",
228 llvmVoidType,
229 {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
230 llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
231 FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
232 "mgpuSpMMBufferSize",
233 llvmIntPtrType,
234 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
235 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
236 FunctionCallBuilder createSpMMCallBuilder = {
237 "mgpuSpMM",
238 llvmVoidType,
239 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
240 llvmPointerType, llvmInt32Type, llvmPointerType,
241 llvmPointerType /* void *stream */}};
242 FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
243 "mgpuSDDMMBufferSize",
244 llvmIntPtrType,
245 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
246 llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
247 FunctionCallBuilder createSDDMMCallBuilder = {
248 "mgpuSDDMM",
249 llvmVoidType,
250 {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
251 llvmPointerType, llvmInt32Type, llvmPointerType,
252 llvmPointerType /* void *stream */}};
253 FunctionCallBuilder createLtDnMatCallBuilder = {
254 "mgpuCreateCuSparseLtDnMat",
255 llvmVoidType,
256 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
257 llvmInt32Type, llvmPointerType /* void *stream */}};
258 FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
259 "mgpuDestroyCuSparseLtSpMat",
260 llvmVoidType,
261 {llvmPointerType, llvmPointerType /* void *stream */}};
262 FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
263 "mgpuDestroyCuSparseLtDnMat",
264 llvmVoidType,
265 {llvmPointerType, llvmPointerType /* void *stream */}};
266 FunctionCallBuilder create2To4SpMatCallBuilder = {
267 "mgpuCusparseLtCreate2To4SpMat",
268 llvmVoidType,
269 {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
270 llvmInt32Type, llvmPointerType /* void *stream */}};
271 FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
272 "mgpuCuSparseLtSpMMBufferSize",
273 llvmVoidType,
274 {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
275 llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
276 llvmPointerType /*void *stream*/}};
277 FunctionCallBuilder createCuSparseLtSpMMBuilder = {
278 "mgpuCuSparseLtSpMM",
279 llvmVoidType,
280 {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
281 llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
282 FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
283 "mgpuSpGEMMCreateDescr",
284 llvmPointerType,
285 {llvmPointerType /*void *stream*/}};
286 FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
287 "mgpuSpGEMMDestroyDescr",
288 llvmVoidType,
289 {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
290 FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
291 "mgpuSpGEMMWorkEstimation",
292 llvmIntPtrType,
293 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
294 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
295 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
296 llvmPointerType /*void *stream*/}};
297 FunctionCallBuilder createSpGEMMComputeBuilder = {
298 "mgpuSpGEMMCompute",
299 llvmIntPtrType,
300 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
301 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
302 llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
303 llvmPointerType /*void *stream*/}};
304 FunctionCallBuilder createSpGEMMCopyBuilder = {
305 "mgpuSpGEMMCopy",
306 llvmVoidType,
307 {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
308 llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
309 llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
310 FunctionCallBuilder createSpMatGetSizeBuilder = {
311 "mgpuSpMatGetSize",
312 llvmVoidType,
313 {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
314 llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
315 FunctionCallBuilder createSetCsrPointersBuilder = {
316 "mgpuSetCsrPointers",
317 llvmVoidType,
318 {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
319 llvmPointerType /*crd*/, llvmPointerType /*val*/,
320 llvmPointerType /*void *stream*/}};
321};
322
323/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
324/// call. Currently it supports CUDA and ROCm (HIP).
325class ConvertHostRegisterOpToGpuRuntimeCallPattern
326 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
327public:
328 ConvertHostRegisterOpToGpuRuntimeCallPattern(
329 const LLVMTypeConverter &typeConverter)
330 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
331
332private:
333 LogicalResult
334 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override;
336};
337
338class ConvertHostUnregisterOpToGpuRuntimeCallPattern
339 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
340public:
341 ConvertHostUnregisterOpToGpuRuntimeCallPattern(
342 const LLVMTypeConverter &typeConverter)
343 : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
344 }
345
346private:
347 LogicalResult
348 matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
349 ConversionPatternRewriter &rewriter) const override;
350};
351
352/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
353/// call. Currently it supports CUDA and ROCm (HIP).
354class ConvertAllocOpToGpuRuntimeCallPattern
355 : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
356public:
357 ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
358 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
359
360private:
361 LogicalResult
362 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const override;
364};
365
366/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
367/// call. Currently it supports CUDA and ROCm (HIP).
368class ConvertDeallocOpToGpuRuntimeCallPattern
369 : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
370public:
371 ConvertDeallocOpToGpuRuntimeCallPattern(
372 const LLVMTypeConverter &typeConverter)
373 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
374
375private:
376 LogicalResult
377 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
378 ConversionPatternRewriter &rewriter) const override;
379};
380
381class ConvertAsyncYieldToGpuRuntimeCallPattern
382 : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
383public:
384 ConvertAsyncYieldToGpuRuntimeCallPattern(
385 const LLVMTypeConverter &typeConverter)
386 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
387
388private:
389 LogicalResult
390 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
391 ConversionPatternRewriter &rewriter) const override;
392};
393
394/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
395/// call. Currently it supports CUDA and ROCm (HIP).
396class ConvertWaitOpToGpuRuntimeCallPattern
397 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
398public:
399 ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
400 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
401
402private:
403 LogicalResult
404 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
405 ConversionPatternRewriter &rewriter) const override;
406};
407
408/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
409/// call. Currently it supports CUDA and ROCm (HIP).
410class ConvertWaitAsyncOpToGpuRuntimeCallPattern
411 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
412public:
413 ConvertWaitAsyncOpToGpuRuntimeCallPattern(
414 const LLVMTypeConverter &typeConverter)
415 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
416
417private:
418 LogicalResult
419 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter) const override;
421};
422
423/// A rewrite patter to legalize gpu.launch_func with LLVM types.
424class LegalizeLaunchFuncOpPattern
425 : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
426public:
427 LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
428 bool kernelBarePtrCallConv,
429 bool kernelIntersperseSizeCallConv)
430 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
431 kernelBarePtrCallConv(kernelBarePtrCallConv),
432 kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
433
434private:
435 LogicalResult
436 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
437 ConversionPatternRewriter &rewriter) const override;
438
439 bool kernelBarePtrCallConv;
440 bool kernelIntersperseSizeCallConv;
441};
442
443/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
444/// call. Currently it supports CUDA and ROCm (HIP).
445class ConvertMemcpyOpToGpuRuntimeCallPattern
446 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
447public:
448 ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
449 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
450
451private:
452 LogicalResult
453 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
454 ConversionPatternRewriter &rewriter) const override;
455};
456
457/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
458/// call. Currently it supports CUDA and ROCm (HIP).
459class ConvertMemsetOpToGpuRuntimeCallPattern
460 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
461public:
462 ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
463 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
464
465private:
466 LogicalResult
467 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter) const override;
469};
470
471/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
472/// Currently supports CUDA and ROCm (HIP)
473class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
474 : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
475public:
476 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
477 const LLVMTypeConverter &typeConverter)
478 : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
479 typeConverter) {}
480
481 LogicalResult
482 matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
483 ConversionPatternRewriter &rewriter) const override;
484};
485
/// Generic rewriting rule for operation on sparse matrices: expands to the
/// declaration of `Convert<op_name>ToGpuRuntimeCallPattern`, whose
/// `matchAndRewrite` must be defined separately.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)                \
  class Convert##op_name##ToGpuRuntimeCallPattern                              \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {                \
    using Super = ConvertOpToGpuRuntimeCallPattern<gpu::op_name>;              \
                                                                               \
  public:                                                                      \
    Convert##op_name##ToGpuRuntimeCallPattern(                                 \
        const LLVMTypeConverter &typeConverter)                                \
        : Super(typeConverter) {}                                              \
                                                                               \
  private:                                                                     \
    LogicalResult                                                              \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                        \
                    ConversionPatternRewriter &rewriter) const override;       \
  };
501
502DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
503DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
504DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
505DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
506DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
507DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
508DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
509DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
510DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
511DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
512DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
513DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
514DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
515DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
516DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
517DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
518DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
519DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
520DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
521DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
522DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
523
524} // namespace
525
526void GpuToLLVMConversionPass::runOnOperation() {
527 MLIRContext *context = &getContext();
528
529 // Perform progressive lowering of vector transfer operations.
530 {
531 RewritePatternSet patterns(&getContext());
532 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
533 vector::populateVectorTransferLoweringPatterns(patterns,
534 /*maxTransferRank=*/1);
535 if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns))))
536 return signalPassFailure();
537 }
538
539 LowerToLLVMOptions options(context);
540 options.useBarePtrCallConv = hostBarePtrCallConv;
541 RewritePatternSet patterns(context);
542 ConversionTarget target(*context);
543 target.addLegalDialect<LLVM::LLVMDialect>();
544 LLVMTypeConverter converter(context, options);
545
546 // Populate all patterns from all dialects that implement the
547 // `ConvertToLLVMPatternInterface` interface.
548 for (Dialect *dialect : context->getLoadedDialects()) {
549 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
550 if (!iface)
551 continue;
552 iface->populateConvertToLLVMConversionPatterns(target, typeConverter&: converter, patterns);
553 }
554
555 // Preserve GPU modules and binaries. Modules are preserved as they can be
556 // converted later by `gpu-module-to-binary`.
557 target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
558 // Accept as legal LaunchFuncOps if the operands have been lowered.
559 target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
560 callback: [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });
561
562 // These aren't covered by the ConvertToLLVMPatternInterface right now.
563 populateVectorToLLVMConversionPatterns(converter, patterns);
564 populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
565 populateAsyncStructuralTypeConversionsAndLegality(typeConverter&: converter, patterns,
566 target);
567 populateGpuToLLVMConversionPatterns(converter, patterns,
568 kernelBarePtrCallConv,
569 kernelIntersperseSizeCallConv);
570
571 if (failed(
572 Result: applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns))))
573 signalPassFailure();
574}
575
576LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
577 ArrayRef<Value> arguments) const {
578 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
579 auto function = [&] {
580 if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(name: functionName))
581 return function;
582 return OpBuilder::atBlockEnd(block: module.getBody())
583 .create<LLVM::LLVMFuncOp>(location: loc, args: functionName, args: functionType);
584 }();
585 return builder.create<LLVM::CallOp>(location: loc, args&: function, args&: arguments);
586}
587
588// Corresponding to cusparseIndexType_t defined in cusparse.h.
589static int32_t getCuSparseIndexTypeFrom(Type type) {
590 if (type.isInteger(width: 16))
591 return 1; // CUSPARSE_INDEX_16U
592 if (type.isInteger(width: 32))
593 return 2; // CUSPARSE_INDEX_32I
594 return 3; // CUSPARSE_INDEX_64I
595}
596
597static int32_t getCuSparseLtDataTypeFrom(Type type) {
598 if (type.isF16())
599 return 0; // CUSPARSE_COMPUTE_16F,
600 if (type.isInteger(width: 32))
601 return 1; // CUSPARSE_COMPUTE_32I
602 llvm_unreachable("unsupported type");
603 // TODO: add support to TF32
604}
605
606// Corresponding to cudaDataType_t defined in CUDA library_types.h.
607static int32_t getCuSparseDataTypeFrom(Type type) {
608 if (llvm::isa<ComplexType>(Val: type)) {
609 // get the element type
610 auto elementType = cast<ComplexType>(Val&: type).getElementType();
611 if (elementType.isBF16())
612 return 15; // CUDA_C_16BF
613 if (elementType.isF16())
614 return 6; // CUDA_C_16F
615 if (elementType.isF32())
616 return 4; // CUDA_C_32F
617 if (elementType.isF64())
618 return 5; // CUDA_C_64F
619 if (elementType.isInteger(width: 8))
620 return 7; // CUDA_C_8I
621 if (elementType.isInteger(width: 16))
622 return 21; // CUDA_C_16I
623 if (elementType.isInteger(width: 32))
624 return 11; // CUDA_C_32I
625 }
626 if (type.isBF16())
627 return 14; // CUDA_R_16BF
628 if (type.isF16())
629 return 2; // CUDA_R_16F
630 if (type.isF32())
631 return 0; // CUDA_R_32F
632 if (type.isF64())
633 return 1; // CUDA_R_64F
634 if (type.isInteger(width: 8))
635 return 3; // CUDA_R_8I
636 if (type.isInteger(width: 16))
637 return 20; // CUDA_R_16I
638 if (type.isInteger(width: 32))
639 return 10; // CUDA_R_32I
640
641 llvm_unreachable("unsupported element type");
642}
643
644static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
645 return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
646}
647
// TODO: Consider a check at compile time (i.e., while the MLIR compiler runs)
// that disables or warns about cuSPARSELt usage. cuSPARSELt currently does
// not work on CUDA architectures below 8.0 and triggers an error at the
// runtime of the CUDA program. It would be helpful to at least emit a warning
// when the target architecture is below 8.0 and the user still requests
// cuSPARSELt, so that the cuSPARSELt calls are disabled for such
// architectures when lowering the GPU sparse dialect to LLVM calls.
655static bool is2To4Sparsity(Value spMat) {
656 if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
657 return true;
658 if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
659 return false;
660 if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
661 return false;
662 if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
663 return false;
664 if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
665 return false;
666 if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
667 return false;
668 // Print the spMat defining op
669 spMat.getDefiningOp()->print(os&: llvm::errs());
670 llvm_unreachable("cannot find spmat def");
671}
672
673static bool isSpMMCusparseLtOp(Value op) {
674 for (Operation *user : op.getUsers()) {
675 auto spmmOp = dyn_cast<gpu::SpMMOp>(Val: user);
676 // If the other operator is 50% sparsity then we should use cusparseLt
677 if (!spmmOp)
678 continue;
679 if (is2To4Sparsity(spMat: spmmOp.getSpmatA()))
680 return true;
681 }
682 return false;
683}
684
685// Returns whether all operands are of LLVM type.
686static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
687 ConversionPatternRewriter &rewriter) {
688 if (!llvm::all_of(Range&: operands, P: [](Value value) {
689 return LLVM::isCompatibleType(type: value.getType());
690 }))
691 return rewriter.notifyMatchFailure(
692 arg&: op, msg: "Cannot convert if operands aren't of LLVM type.");
693 return success();
694}
695
696static LogicalResult
697isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
698 gpu::AsyncOpInterface op) {
699 if (op.getAsyncDependencies().size() != 1)
700 return rewriter.notifyMatchFailure(
701 arg&: op, msg: "Can only convert with exactly one async dependency.");
702
703 if (!op.getAsyncToken())
704 return rewriter.notifyMatchFailure(arg&: op, msg: "Can convert only async version.");
705
706 return success();
707}
708
709LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
710 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
711 ConversionPatternRewriter &rewriter) const {
712 auto *op = hostRegisterOp.getOperation();
713 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)))
714 return failure();
715
716 Location loc = op->getLoc();
717
718 auto memRefType = hostRegisterOp.getValue().getType();
719 auto elementType = cast<UnrankedMemRefType>(Val&: memRefType).getElementType();
720 auto elementSize = getSizeInBytes(loc, type: elementType, rewriter);
721
722 auto arguments = getTypeConverter()->promoteOperands(
723 loc, opOperands: op->getOperands(), operands: adaptor.getOperands(), builder&: rewriter);
724 arguments.push_back(Elt: elementSize);
725 hostRegisterCallBuilder.create(loc, builder&: rewriter, arguments);
726
727 rewriter.eraseOp(op);
728 return success();
729}
730
731LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
732 gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
733 ConversionPatternRewriter &rewriter) const {
734 Operation *op = hostUnregisterOp.getOperation();
735 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)))
736 return failure();
737
738 Location loc = op->getLoc();
739
740 auto memRefType = hostUnregisterOp.getValue().getType();
741 auto elementType = cast<UnrankedMemRefType>(Val&: memRefType).getElementType();
742 auto elementSize = getSizeInBytes(loc, type: elementType, rewriter);
743
744 auto arguments = getTypeConverter()->promoteOperands(
745 loc, opOperands: op->getOperands(), operands: adaptor.getOperands(), builder&: rewriter);
746 arguments.push_back(Elt: elementSize);
747 hostUnregisterCallBuilder.create(loc, builder&: rewriter, arguments);
748
749 rewriter.eraseOp(op);
750 return success();
751}
752
753LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
754 gpu::AllocOp allocOp, OpAdaptor adaptor,
755 ConversionPatternRewriter &rewriter) const {
756
757 MemRefType memRefType = allocOp.getType();
758
759 if (failed(Result: areAllLLVMTypes(op: allocOp, operands: adaptor.getOperands(), rewriter)) ||
760 !isConvertibleAndHasIdentityMaps(type: memRefType))
761 return failure();
762
763 auto loc = allocOp.getLoc();
764
765 bool isShared = allocOp.getHostShared();
766
767 if (isShared && allocOp.getAsyncToken())
768 return rewriter.notifyMatchFailure(
769 arg&: allocOp, msg: "Host Shared allocation cannot be done async");
770 if (!isShared && failed(Result: isAsyncWithOneDependency(rewriter, op: allocOp)))
771 return failure();
772
773 // Get shape of the memref as values: static sizes are constant
774 // values and dynamic sizes are passed to 'alloc' as operands.
775 SmallVector<Value, 4> shape;
776 SmallVector<Value, 4> strides;
777 Value sizeBytes;
778 getMemRefDescriptorSizes(loc, memRefType, dynamicSizes: adaptor.getDynamicSizes(), rewriter,
779 sizes&: shape, strides, size&: sizeBytes);
780
781 // Allocate the underlying buffer and store a pointer to it in the MemRef
782 // descriptor.
783 auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(location: loc, args: llvmPointerType);
784 Value stream = adaptor.getAsyncDependencies().empty()
785 ? nullPtr
786 : adaptor.getAsyncDependencies().front();
787
788 auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
789 location: loc, args: llvmInt8Type, args: rewriter.getI8IntegerAttr(value: isShared));
790
791 Value allocatedPtr =
792 allocCallBuilder.create(loc, builder&: rewriter, arguments: {sizeBytes, stream, isHostShared})
793 .getResult();
794
795 // No alignment.
796 Value alignedPtr = allocatedPtr;
797
798 // Create the MemRef descriptor.
799 auto memRefDescriptor = this->createMemRefDescriptor(
800 loc, memRefType, allocatedPtr, alignedPtr, sizes: shape, strides, rewriter);
801
802 if (allocOp.getAsyncToken()) {
803 // Async alloc: make dependent ops use the same stream.
804 rewriter.replaceOp(op: allocOp, newValues: {memRefDescriptor, stream});
805 } else {
806 rewriter.replaceOp(op: allocOp, newValues: {memRefDescriptor});
807 }
808
809 return success();
810}
811
812LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
813 gpu::DeallocOp deallocOp, OpAdaptor adaptor,
814 ConversionPatternRewriter &rewriter) const {
815 if (failed(Result: areAllLLVMTypes(op: deallocOp, operands: adaptor.getOperands(), rewriter)) ||
816 failed(Result: isAsyncWithOneDependency(rewriter, op: deallocOp)))
817 return failure();
818
819 Location loc = deallocOp.getLoc();
820
821 Value pointer =
822 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(builder&: rewriter, loc);
823 Value stream = adaptor.getAsyncDependencies().front();
824 deallocCallBuilder.create(loc, builder&: rewriter, arguments: {pointer, stream});
825
826 rewriter.replaceOp(op: deallocOp, newValues: {stream});
827 return success();
828}
829
830static bool isGpuAsyncTokenType(Value value) {
831 return isa<gpu::AsyncTokenType>(Val: value.getType());
832}
833
834// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
835// !gpu.async.token are lowered to stream within the async.execute region, but
836// are passed as events between them. For each !gpu.async.token operand, we
837// create an event and record it on the stream.
838LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
839 async::YieldOp yieldOp, OpAdaptor adaptor,
840 ConversionPatternRewriter &rewriter) const {
841 if (llvm::none_of(Range: yieldOp.getOperands(), P: isGpuAsyncTokenType))
842 return rewriter.notifyMatchFailure(arg&: yieldOp, msg: "no gpu async token operand");
843
844 Location loc = yieldOp.getLoc();
845 SmallVector<Value, 4> newOperands(adaptor.getOperands());
846 llvm::SmallDenseSet<Value> streams;
847 for (auto &operand : yieldOp->getOpOperands()) {
848 if (!isGpuAsyncTokenType(value: operand.get()))
849 continue;
850 auto idx = operand.getOperandNumber();
851 auto stream = adaptor.getOperands()[idx];
852 auto event = eventCreateCallBuilder.create(loc, builder&: rewriter, arguments: {}).getResult();
853 eventRecordCallBuilder.create(loc, builder&: rewriter, arguments: {event, stream});
854 newOperands[idx] = event;
855 streams.insert(V: stream);
856 }
857 for (auto stream : streams)
858 streamDestroyCallBuilder.create(loc, builder&: rewriter, arguments: {stream});
859
860 rewriter.modifyOpInPlace(root: yieldOp, callable: [&] { yieldOp->setOperands(newOperands); });
861 return success();
862}
863
864// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
865static bool isDefinedByCallTo(Value value, StringRef functionName) {
866 assert(isa<LLVM::LLVMPointerType>(value.getType()));
867 if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
868 return *defOp.getCallee() == functionName;
869 return false;
870}
871
872// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
873// with the stream/event operands. The operands are destroyed. That is, it
874// assumes that it is not used afterwards or elsewhere. Otherwise we will get a
875// runtime error. Eventually, we should guarantee this property.
876LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
877 gpu::WaitOp waitOp, OpAdaptor adaptor,
878 ConversionPatternRewriter &rewriter) const {
879 if (waitOp.getAsyncToken())
880 return rewriter.notifyMatchFailure(arg&: waitOp, msg: "Cannot convert async op.");
881
882 Location loc = waitOp.getLoc();
883
884 for (auto operand : adaptor.getOperands()) {
885 if (isDefinedByCallTo(value: operand, functionName: streamCreateCallBuilder.functionName)) {
886 // The converted operand's definition created a stream.
887 streamSynchronizeCallBuilder.create(loc, builder&: rewriter, arguments: {operand});
888 streamDestroyCallBuilder.create(loc, builder&: rewriter, arguments: {operand});
889 } else {
890 // Otherwise the converted operand is an event. This assumes that we use
891 // events in control flow code as well.
892 eventSynchronizeCallBuilder.create(loc, builder&: rewriter, arguments: {operand});
893 eventDestroyCallBuilder.create(loc, builder&: rewriter, arguments: {operand});
894 }
895 }
896
897 rewriter.eraseOp(op: waitOp);
898 return success();
899}
900
901// Converts `gpu.wait async` to runtime calls. The converted op creates a new
902// stream that is synchronized with stream/event operands. The operands are
903// destroyed. That is, it assumes that it is not used afterwards or elsewhere.
904// Otherwise we will get a runtime error. Eventually, we should guarantee this
905// property.
906LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
907 gpu::WaitOp waitOp, OpAdaptor adaptor,
908 ConversionPatternRewriter &rewriter) const {
909 if (!waitOp.getAsyncToken())
910 return rewriter.notifyMatchFailure(arg&: waitOp, msg: "Can only convert async op.");
911
912 Location loc = waitOp.getLoc();
913
914 auto insertionPoint = rewriter.saveInsertionPoint();
915 SmallVector<Value, 1> events;
916 for (auto pair :
917 llvm::zip(t: waitOp.getAsyncDependencies(), u: adaptor.getOperands())) {
918 auto operand = std::get<1>(t&: pair);
919 if (isDefinedByCallTo(value: operand, functionName: streamCreateCallBuilder.functionName)) {
920 // The converted operand's definition created a stream. Insert an event
921 // into the stream just after the last use of the original token operand.
922 auto *defOp = std::get<0>(t&: pair).getDefiningOp();
923 rewriter.setInsertionPointAfter(defOp);
924 auto event = eventCreateCallBuilder.create(loc, builder&: rewriter, arguments: {}).getResult();
925 eventRecordCallBuilder.create(loc, builder&: rewriter, arguments: {event, operand});
926 events.push_back(Elt: event);
927 } else {
928 // Otherwise the converted operand is an event. This assumes that we use
929 // events in control flow code as well.
930 events.push_back(Elt: operand);
931 }
932 }
933 rewriter.restoreInsertionPoint(ip: insertionPoint);
934 auto stream = streamCreateCallBuilder.create(loc, builder&: rewriter, arguments: {}).getResult();
935 for (auto event : events)
936 streamWaitEventCallBuilder.create(loc, builder&: rewriter, arguments: {stream, event});
937 for (auto event : events)
938 eventDestroyCallBuilder.create(loc, builder&: rewriter, arguments: {event});
939 rewriter.replaceOp(op: waitOp, newValues: {stream});
940
941 return success();
942}
943
944// Legalize the op's operands.
945LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
946 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
947 ConversionPatternRewriter &rewriter) const {
948 if (failed(Result: areAllLLVMTypes(op: launchOp, operands: adaptor.getOperands(), rewriter)))
949 return failure();
950
951 if (launchOp.getAsyncDependencies().size() > 1)
952 return rewriter.notifyMatchFailure(
953 arg&: launchOp, msg: "Cannot convert with more than one async dependency.");
954
955 // Fail when the synchronous version of the op has async dependencies. The
956 // lowering destroys the stream, and we do not want to check that there is no
957 // use of the stream after this op.
958 if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
959 return rewriter.notifyMatchFailure(
960 arg&: launchOp, msg: "Cannot convert non-async op with async dependencies.");
961
962 Location loc = launchOp.getLoc();
963
964 Value stream = Value();
965 if (!adaptor.getAsyncDependencies().empty())
966 stream = adaptor.getAsyncDependencies().front();
967 // If the async keyword is present and there are no dependencies, then a
968 // stream must be created to pass to subsequent operations.
969 else if (launchOp.getAsyncToken())
970 stream = streamCreateCallBuilder.create(loc, builder&: rewriter, arguments: {}).getResult();
971
972 // Lower the kernel operands to match kernel parameters.
973 // Note: If `useBarePtrCallConv` is set in the type converter's options,
974 // the value of `kernelBarePtrCallConv` will be ignored.
975 OperandRange origArguments = launchOp.getKernelOperands();
976 SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
977 loc, opOperands: origArguments, operands: adaptor.getKernelOperands(), builder&: rewriter,
978 /*useBarePtrCallConv=*/kernelBarePtrCallConv);
979 SmallVector<Value, 8> llvmArgumentsWithSizes;
980
981 // Intersperse size information if requested.
982 if (kernelIntersperseSizeCallConv) {
983 if (origArguments.size() != llvmArguments.size()) {
984 // This shouldn't happen if the bare-pointer calling convention is used.
985 return rewriter.notifyMatchFailure(
986 arg&: launchOp,
987 msg: "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
988 }
989
990 llvmArgumentsWithSizes.reserve(N: llvmArguments.size() * 2);
991 for (auto [llvmArg, origArg] : zip_equal(t&: llvmArguments, u&: origArguments)) {
992 auto memrefTy = dyn_cast<MemRefType>(Val: origArg.getType());
993 if (!memrefTy) {
994 return rewriter.notifyMatchFailure(
995 arg&: launchOp, msg: "Operand to launch op is not a memref.");
996 }
997
998 if (!memrefTy.hasStaticShape() ||
999 !memrefTy.getElementType().isIntOrFloat()) {
1000 return rewriter.notifyMatchFailure(
1001 arg&: launchOp, msg: "Operand to launch op is not a memref with a static "
1002 "shape and an integer or float element type.");
1003 }
1004
1005 unsigned bitwidth = memrefTy.getElementTypeBitWidth();
1006 if (bitwidth % 8 != 0) {
1007 return rewriter.notifyMatchFailure(
1008 arg&: launchOp, msg: "Operand to launch op is not a memref with a "
1009 "byte-aligned element type.");
1010 }
1011
1012 uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
1013 static_cast<uint64_t>(memrefTy.getNumElements());
1014
1015 Value sizeArg = rewriter.create<LLVM::ConstantOp>(
1016 location: loc, args: getIndexType(), args: rewriter.getIndexAttr(value: staticSize));
1017 llvmArgumentsWithSizes.push_back(Elt: llvmArg); // Presumably a bare pointer.
1018 llvmArgumentsWithSizes.push_back(Elt: sizeArg);
1019 }
1020 }
1021
1022 std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
1023 if (launchOp.hasClusterSize()) {
1024 clusterSize =
1025 gpu::KernelDim3{.x: adaptor.getClusterSizeX(), .y: adaptor.getClusterSizeY(),
1026 .z: adaptor.getClusterSizeZ()};
1027 }
1028 rewriter.create<gpu::LaunchFuncOp>(
1029 location: launchOp.getLoc(), args: launchOp.getKernelAttr(),
1030 args: gpu::KernelDim3{.x: adaptor.getGridSizeX(), .y: adaptor.getGridSizeY(),
1031 .z: adaptor.getGridSizeZ()},
1032 args: gpu::KernelDim3{.x: adaptor.getBlockSizeX(), .y: adaptor.getBlockSizeY(),
1033 .z: adaptor.getBlockSizeZ()},
1034 args: adaptor.getDynamicSharedMemorySize(),
1035 args&: llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
1036 args&: stream, args&: clusterSize);
1037 if (launchOp.getAsyncToken())
1038 rewriter.replaceOp(op: launchOp, newValues: {stream});
1039 else
1040 rewriter.eraseOp(op: launchOp);
1041 return success();
1042}
1043
1044static Value bitAndAddrspaceCast(Location loc,
1045 ConversionPatternRewriter &rewriter,
1046 LLVM::LLVMPointerType destinationType,
1047 Value sourcePtr,
1048 const LLVMTypeConverter &typeConverter) {
1049 auto sourceTy = cast<LLVM::LLVMPointerType>(Val: sourcePtr.getType());
1050 if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
1051 sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1052 location: loc,
1053 args: LLVM::LLVMPointerType::get(context: rewriter.getContext(),
1054 addressSpace: destinationType.getAddressSpace()),
1055 args&: sourcePtr);
1056 return sourcePtr;
1057}
1058
1059LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
1060 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
1061 ConversionPatternRewriter &rewriter) const {
1062 auto memRefType = cast<MemRefType>(Val: memcpyOp.getSrc().getType());
1063
1064 if (failed(Result: areAllLLVMTypes(op: memcpyOp, operands: adaptor.getOperands(), rewriter)) ||
1065 !isConvertibleAndHasIdentityMaps(type: memRefType) ||
1066 failed(Result: isAsyncWithOneDependency(rewriter, op: memcpyOp)))
1067 return failure();
1068
1069 auto loc = memcpyOp.getLoc();
1070
1071 MemRefDescriptor srcDesc(adaptor.getSrc());
1072 Value numElements = getNumElements(rewriter, loc, type: memRefType, desc: srcDesc);
1073
1074 Type elementPtrType = getElementPtrType(type: memRefType);
1075 Value nullPtr = rewriter.create<LLVM::ZeroOp>(location: loc, args&: elementPtrType);
1076 Value gepPtr = rewriter.create<LLVM::GEPOp>(
1077 location: loc, args&: elementPtrType,
1078 args: typeConverter->convertType(t: memRefType.getElementType()), args&: nullPtr,
1079 args&: numElements);
1080 auto sizeBytes =
1081 rewriter.create<LLVM::PtrToIntOp>(location: loc, args: getIndexType(), args&: gepPtr);
1082
1083 auto src = bitAndAddrspaceCast(loc, rewriter, destinationType: llvmPointerType,
1084 sourcePtr: srcDesc.alignedPtr(builder&: rewriter, loc),
1085 typeConverter: *getTypeConverter());
1086 auto dst = bitAndAddrspaceCast(
1087 loc, rewriter, destinationType: llvmPointerType,
1088 sourcePtr: MemRefDescriptor(adaptor.getDst()).alignedPtr(builder&: rewriter, loc),
1089 typeConverter: *getTypeConverter());
1090
1091 auto stream = adaptor.getAsyncDependencies().front();
1092 memcpyCallBuilder.create(loc, builder&: rewriter, arguments: {dst, src, sizeBytes, stream});
1093
1094 rewriter.replaceOp(op: memcpyOp, newValues: {stream});
1095
1096 return success();
1097}
1098
1099LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
1100 gpu::MemsetOp memsetOp, OpAdaptor adaptor,
1101 ConversionPatternRewriter &rewriter) const {
1102 auto memRefType = cast<MemRefType>(Val: memsetOp.getDst().getType());
1103
1104 if (failed(Result: areAllLLVMTypes(op: memsetOp, operands: adaptor.getOperands(), rewriter)) ||
1105 !isConvertibleAndHasIdentityMaps(type: memRefType) ||
1106 failed(Result: isAsyncWithOneDependency(rewriter, op: memsetOp)))
1107 return failure();
1108
1109 auto loc = memsetOp.getLoc();
1110
1111 Type valueType = adaptor.getValue().getType();
1112 unsigned bitWidth = valueType.getIntOrFloatBitWidth();
1113 // Ints and floats of 16 or 32 bit width are allowed.
1114 if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
1115 return rewriter.notifyMatchFailure(
1116 arg&: memsetOp, msg: "value must be a 16 or 32 bit int or float");
1117 }
1118
1119 unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
1120 Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;
1121
1122 MemRefDescriptor dstDesc(adaptor.getDst());
1123 Value numElements = getNumElements(rewriter, loc, type: memRefType, desc: dstDesc);
1124
1125 auto value =
1126 rewriter.create<LLVM::BitcastOp>(location: loc, args&: bitCastType, args: adaptor.getValue());
1127 auto dst = bitAndAddrspaceCast(loc, rewriter, destinationType: llvmPointerType,
1128 sourcePtr: dstDesc.alignedPtr(builder&: rewriter, loc),
1129 typeConverter: *getTypeConverter());
1130
1131 auto stream = adaptor.getAsyncDependencies().front();
1132 FunctionCallBuilder builder =
1133 valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1134 builder.create(loc, builder&: rewriter, arguments: {dst, value, numElements, stream});
1135
1136 rewriter.replaceOp(op: memsetOp, newValues: {stream});
1137 return success();
1138}
1139
1140LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
1141 gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
1142 ConversionPatternRewriter &rewriter) const {
1143 Location loc = op.getLoc();
1144 auto call = setDefaultDeviceCallBuilder.create(loc, builder&: rewriter,
1145 arguments: {adaptor.getDevIndex()});
1146 rewriter.replaceOp(op, newOp: call);
1147 return success();
1148}
1149
1150template <typename T>
1151static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
1152 Type llvmInt32Type = builder.getIntegerType(width: 32);
1153 return builder.create<LLVM::ConstantOp>(location: loc, args&: llvmInt32Type,
1154 args: static_cast<int32_t>(tValue));
1155}
1156
1157template <typename T>
1158static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
1159 Type llvmFloat32Type = builder.getF32Type();
1160 return builder.create<LLVM::ConstantOp>(
1161 location: loc, args&: llvmFloat32Type,
1162 args: builder.getF32FloatAttr(value: static_cast<float>(tValue)));
1163}
1164
1165LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1166 gpu::CreateDnTensorOp op, OpAdaptor adaptor,
1167 ConversionPatternRewriter &rewriter) const {
1168 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1169 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1170 return failure();
1171 Location loc = op.getLoc();
1172 auto stream = adaptor.getAsyncDependencies().front();
1173 Value pTensor =
1174 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(builder&: rewriter, loc);
1175 Type dType = op.getMemref().getType().getElementType();
1176 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1177
1178 SmallVector<Value, 4> dims;
1179 for (Value dim : adaptor.getDims()) {
1180 dims.push_back(Elt: dim);
1181 }
1182
1183 Value handle;
1184 // TODO: For now, we track the use of the handle and lower it to cusparse /
1185 // cusparseLt accordingly. If in a block, both cusparse and cusparseLt are
1186 // used, we require two separate Creation ops to be the correct logic. In
1187 // future, we may add support to using one handle in sparse tensor / GPU
1188 // dialect in both cusparse and cusparseLt. use the cusparseLt create call if
1189 // the dnmat is used with spmat with 2:4 sparsity
1190 if (dims.size() == 2) {
1191 if (isSpMMCusparseLtOp(op: op.getDnTensor())) {
1192 auto handleSz = rewriter.create<LLVM::ConstantOp>(
1193 location: loc, args: getIndexType(), args: rewriter.getIndexAttr(value: 11032));
1194 handle = rewriter.create<LLVM::AllocaOp>(
1195 location: loc, args: llvmPointerType, args: llvmInt8Type, args&: handleSz, /*alignment=*/args: 16);
1196 handle = rewriter.create<LLVM::BitcastOp>(location: loc, args: llvmPointerType, args&: handle);
1197
1198 createLtDnMatCallBuilder
1199 .create(loc, builder&: rewriter,
1200 arguments: {handle, dims[0], dims[1], pTensor, dtp, stream})
1201 .getResult();
1202 } else {
1203 handle =
1204 createDnMatCallBuilder
1205 .create(loc, builder&: rewriter, arguments: {dims[0], dims[1], pTensor, dtp, stream})
1206 .getResult();
1207 }
1208 } else {
1209 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1210 handle = createDnVecCallBuilder
1211 .create(loc, builder&: rewriter, arguments: {dims[0], pTensor, dtp, stream})
1212 .getResult();
1213 }
1214 rewriter.replaceOp(op, newValues: {handle, stream});
1215 return success();
1216}
1217
1218LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1219 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1220 ConversionPatternRewriter &rewriter) const {
1221 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1222 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1223 return failure();
1224 Location loc = op.getLoc();
1225 auto stream = adaptor.getAsyncDependencies().front();
1226 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1227 SmallVector<Value, 4> dims;
1228 for (Value dim : definingOp.getDims()) {
1229 dims.push_back(Elt: dim);
1230 }
1231 if (dims.size() == 2) {
1232 // Use the cusparseLt destroy call if the dnmat is used with spmat with
1233 // 2:4 sparsity
1234 if (isSpMMCusparseLtOp(op: op.getDnTensor())) {
1235 destroyCuSparseLtDnMatBuilder.create(loc, builder&: rewriter,
1236 arguments: {adaptor.getDnTensor(), stream});
1237 } else {
1238 destroyDnMatCallBuilder.create(loc, builder&: rewriter,
1239 arguments: {adaptor.getDnTensor(), stream});
1240 }
1241 } else {
1242 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1243 destroyDnVecCallBuilder.create(loc, builder&: rewriter,
1244 arguments: {adaptor.getDnTensor(), stream});
1245 }
1246 rewriter.replaceOp(op, newValues: {stream});
1247 return success();
1248}
1249
1250LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1251 gpu::CreateCooOp op, OpAdaptor adaptor,
1252 ConversionPatternRewriter &rewriter) const {
1253 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1254 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1255 return failure();
1256 Location loc = op.getLoc();
1257 auto stream = adaptor.getAsyncDependencies().front();
1258 Value pRowIdxs =
1259 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(builder&: rewriter, loc);
1260 Value pColIdxs =
1261 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(builder&: rewriter, loc);
1262 Value pValues =
1263 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1264 Type iType =
1265 llvm::cast<MemRefType>(Val: op.getColIdxs().getType()).getElementType();
1266 Type dType =
1267 llvm::cast<MemRefType>(Val: op.getValues().getType()).getElementType();
1268 auto itp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: iType));
1269 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1270 auto handle =
1271 createCooCallBuilder
1272 .create(loc, builder&: rewriter,
1273 arguments: {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1274 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1275 .getResult();
1276 rewriter.replaceOp(op, newValues: {handle, stream});
1277 return success();
1278}
1279
1280LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1281 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1282 ConversionPatternRewriter &rewriter) const {
1283 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1284 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1285 return failure();
1286 Location loc = op.getLoc();
1287 auto stream = adaptor.getAsyncDependencies().front();
1288 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(builder&: rewriter, loc);
1289 Value pValues =
1290 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1291 Type iType = llvm::cast<MemRefType>(Val: op.getIdxs().getType()).getElementType();
1292 Type dType =
1293 llvm::cast<MemRefType>(Val: op.getValues().getType()).getElementType();
1294 auto itp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: iType));
1295 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1296 auto handle =
1297 createCooAoSCallBuilder
1298 .create(loc, builder&: rewriter,
1299 arguments: {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1300 pIdxs, pValues, itp, dtp, stream})
1301 .getResult();
1302 rewriter.replaceOp(op, newValues: {handle, stream});
1303 return success();
1304}
1305
1306LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1307 gpu::CreateCsrOp op, OpAdaptor adaptor,
1308 ConversionPatternRewriter &rewriter) const {
1309 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1310 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1311 return failure();
1312 Location loc = op.getLoc();
1313 auto stream = adaptor.getAsyncDependencies().front();
1314 Value pRowPos =
1315 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(builder&: rewriter, loc);
1316 Value pColIdxs =
1317 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(builder&: rewriter, loc);
1318 Value pValues =
1319 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1320 Type pType =
1321 llvm::cast<MemRefType>(Val: op.getRowPos().getType()).getElementType();
1322 Type iType =
1323 llvm::cast<MemRefType>(Val: op.getColIdxs().getType()).getElementType();
1324 Type dType =
1325 llvm::cast<MemRefType>(Val: op.getValues().getType()).getElementType();
1326 auto ptp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: pType));
1327 auto itp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: iType));
1328 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1329 auto handle =
1330 createCsrCallBuilder
1331 .create(loc, builder&: rewriter,
1332 arguments: {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1333 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1334 .getResult();
1335 rewriter.replaceOp(op, newValues: {handle, stream});
1336 return success();
1337}
1338
1339LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1340 gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
1341 ConversionPatternRewriter &rewriter) const {
1342 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1343 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1344 return failure();
1345 Location loc = op.getLoc();
1346 auto stream = adaptor.getAsyncDependencies().front();
1347 Value pMat =
1348 MemRefDescriptor(adaptor.getMemref()).allocatedPtr(builder&: rewriter, loc);
1349 Type dType =
1350 llvm::cast<MemRefType>(Val: op.getMemref().getType()).getElementType();
1351 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1352
1353 // CUDA runner asserts the size is 44104 bytes.
1354 auto handleSz = rewriter.create<LLVM::ConstantOp>(
1355 location: loc, args: getIndexType(), args: rewriter.getIndexAttr(value: 44104));
1356 Value handle = rewriter.create<LLVM::AllocaOp>(
1357 location: loc, args: llvmPointerType, args: llvmInt8Type, args&: handleSz, /*alignment=*/args: 16);
1358 handle = rewriter.create<LLVM::BitcastOp>(location: loc, args: llvmPointerType, args&: handle);
1359
1360 create2To4SpMatCallBuilder
1361 .create(loc, builder&: rewriter,
1362 arguments: {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
1363 .getResult();
1364 rewriter.replaceOp(op, newValues: {handle, stream});
1365 return success();
1366}
1367
1368LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
1369 gpu::DestroySpMatOp op, OpAdaptor adaptor,
1370 ConversionPatternRewriter &rewriter) const {
1371 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1372 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1373 return failure();
1374 Location loc = op.getLoc();
1375 auto stream = adaptor.getAsyncDependencies().front();
1376 // Use the cusparseLt destroy call if the spmat is 2:4 sparsity
1377 if (is2To4Sparsity(spMat: op.getSpmat())) {
1378 destroyCuSparseLtSpMatBuilder.create(loc, builder&: rewriter,
1379 arguments: {adaptor.getSpmat(), stream});
1380
1381 } else {
1382 destroySpMatCallBuilder.create(loc, builder&: rewriter, arguments: {adaptor.getSpmat(), stream});
1383 }
1384 rewriter.replaceOp(op, newValues: {stream});
1385 return success();
1386}
1387
1388LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1389 gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
1390 ConversionPatternRewriter &rewriter) const {
1391 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1392 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1393 return failure();
1394 Location loc = op.getLoc();
1395 auto modeA = genConstInt32From(builder&: rewriter, loc, tValue: op.getModeA());
1396 auto computeType = genConstInt32From(
1397 builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: adaptor.getComputeType()));
1398 auto stream = adaptor.getAsyncDependencies().front();
1399 auto bufferSize = spMVBufferSizeCallBuilder
1400 .create(loc, builder&: rewriter,
1401 arguments: {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1402 adaptor.getDnY(), computeType, stream})
1403 .getResult();
1404 rewriter.replaceOp(op, newValues: {bufferSize, stream});
1405 return success();
1406}
1407
1408LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
1409 gpu::SpMVOp op, OpAdaptor adaptor,
1410 ConversionPatternRewriter &rewriter) const {
1411 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1412 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1413 return failure();
1414 Location loc = op.getLoc();
1415 auto modeA = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeA());
1416 auto computeType = genConstInt32From(
1417 builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: adaptor.getComputeType()));
1418 auto stream = adaptor.getAsyncDependencies().front();
1419 Value pBuf =
1420 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(builder&: rewriter, loc);
1421 spMVCallBuilder.create(loc, builder&: rewriter,
1422 arguments: {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
1423 adaptor.getDnY(), computeType, pBuf, stream});
1424 rewriter.replaceOp(op, newValues: {stream});
1425 return success();
1426}
1427
1428LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1429 gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
1430 ConversionPatternRewriter &rewriter) const {
1431 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1432 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1433 return failure();
1434 Location loc = op.getLoc();
1435 auto modeA = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeA());
1436 auto modeB = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeB());
1437 auto stream = adaptor.getAsyncDependencies().front();
1438 Value bufferSize;
1439 if (is2To4Sparsity(spMat: op.getSpmatA())) {
1440 auto pruneFlag =
1441 genConstInt32From(builder&: rewriter, loc, tValue: get2To4PruneFlag(spMat: op.getSpmatA()));
1442 auto computeType = genConstInt32From(
1443 builder&: rewriter, loc, tValue: getCuSparseLtDataTypeFrom(type: adaptor.getComputeType()));
1444 auto three = rewriter.create<LLVM::ConstantOp>(location: loc, args: getIndexType(),
1445 args: rewriter.getIndexAttr(value: 3));
1446 auto bufferSize = rewriter.create<LLVM::AllocaOp>(
1447 location: loc, args: llvmPointerType, args: llvmPointerType, args&: three, /*alignment=*/args: 16);
1448 createCuSparseLtSpMMBufferSizeBuilder
1449 .create(loc, builder&: rewriter,
1450 arguments: {bufferSize, modeA, modeB, adaptor.getSpmatA(),
1451 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
1452 pruneFlag, stream})
1453 .getResult();
1454
1455 auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
1456 location: loc, args: llvmPointerType, args: llvmPointerType, args&: bufferSize,
1457 args: ValueRange{rewriter.create<LLVM::ConstantOp>(
1458 location: loc, args: getIndexType(), args: rewriter.getIndexAttr(value: 1))});
1459 auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
1460 location: loc, args: llvmPointerType, args: llvmPointerType, args&: bufferSize,
1461 args: ValueRange{rewriter.create<LLVM::ConstantOp>(
1462 location: loc, args: getIndexType(), args: rewriter.getIndexAttr(value: 2))});
1463 auto bufferSize0 =
1464 rewriter.create<LLVM::LoadOp>(location: loc, args: llvmInt64Type, args&: bufferSize);
1465 auto bufferSize1 =
1466 rewriter.create<LLVM::LoadOp>(location: loc, args: llvmInt64Type, args&: bufferSizePtr1);
1467 auto bufferSize2 =
1468 rewriter.create<LLVM::LoadOp>(location: loc, args: llvmInt64Type, args&: bufferSizePtr2);
1469
1470 rewriter.replaceOp(op, newValues: {bufferSize0, bufferSize1, bufferSize2, stream});
1471 } else {
1472 auto computeType = genConstInt32From(
1473 builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: adaptor.getComputeType()));
1474 bufferSize =
1475 createSpMMBufferSizeCallBuilder
1476 .create(loc, builder&: rewriter,
1477 arguments: {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
1478 adaptor.getDnmatC(), computeType, stream})
1479 .getResult();
1480 rewriter.replaceOp(op, newValues: {bufferSize, stream});
1481 }
1482 return success();
1483}
1484
1485LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1486 gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
1487 ConversionPatternRewriter &rewriter) const {
1488 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1489 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1490 return failure();
1491 Location loc = op.getLoc();
1492 auto modeA = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeA());
1493 auto modeB = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeB());
1494 auto computeType = genConstInt32From(
1495 builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: adaptor.getComputeType()));
1496 auto stream = adaptor.getAsyncDependencies().front();
1497 auto bufferSize =
1498 createSDDMMBufferSizeCallBuilder
1499 .create(loc, builder&: rewriter,
1500 arguments: {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
1501 adaptor.getSpmatC(), computeType, stream})
1502 .getResult();
1503 rewriter.replaceOp(op, newValues: {bufferSize, stream});
1504 return success();
1505}
1506
1507LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1508 gpu::SpMMOp op, OpAdaptor adaptor,
1509 ConversionPatternRewriter &rewriter) const {
1510 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1511 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1512 return failure();
1513 Location loc = op.getLoc();
1514 auto modeA = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeA());
1515 auto modeB = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeB());
1516 auto computeType = genConstInt32From(
1517 builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: adaptor.getComputeType()));
1518
1519 auto stream = adaptor.getAsyncDependencies().front();
1520
1521 // Lower to cusparseLt if applicable
1522 if (is2To4Sparsity(spMat: op.getSpmatA())) {
1523 SmallVector<Value> pBufs;
1524 for (Value buffer : adaptor.getBuffers()) {
1525 Value pBuf = MemRefDescriptor(buffer).allocatedPtr(builder&: rewriter, loc);
1526 pBufs.push_back(Elt: pBuf);
1527 }
1528 createCuSparseLtSpMMBuilder.create(
1529 loc, builder&: rewriter,
1530 arguments: {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
1531 pBufs[0], pBufs[1], pBufs[2], stream});
1532 } else {
1533 Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
1534 .allocatedPtr(builder&: rewriter, loc);
1535 createSpMMCallBuilder.create(loc, builder&: rewriter,
1536 arguments: {modeA, modeB, adaptor.getSpmatA(),
1537 adaptor.getDnmatB(), adaptor.getDnmatC(),
1538 computeType, pBuf, stream});
1539 }
1540 rewriter.replaceOp(op, newValues: {stream});
1541 return success();
1542}
1543
1544template <typename T>
1545static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
1546 converter.addConversion([&converter](T) -> Type {
1547 return LLVM::LLVMPointerType::get(context: &converter.getContext());
1548 });
1549}
1550
1551LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
1552 gpu::SDDMMOp op, OpAdaptor adaptor,
1553 ConversionPatternRewriter &rewriter) const {
1554 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1555 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1556 return failure();
1557 Location loc = op.getLoc();
1558 auto computeType = genConstInt32From(
1559 builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: adaptor.getComputeType()));
1560 auto modeA = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeA());
1561 auto modeB = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeB());
1562 auto stream = adaptor.getAsyncDependencies().front();
1563 Value pBuf =
1564 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(builder&: rewriter, loc);
1565 createSDDMMCallBuilder.create(loc, builder&: rewriter,
1566 arguments: {modeA, modeB, adaptor.getDnmatA(),
1567 adaptor.getDnmatB(), adaptor.getSpmatC(),
1568 computeType, pBuf, stream});
1569 rewriter.replaceOp(op, newValues: {stream});
1570 return success();
1571}
1572
1573LogicalResult
1574ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1575 gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
1576 ConversionPatternRewriter &rewriter) const {
1577 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1578 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1579 return failure();
1580 Location loc = op.getLoc();
1581 auto stream = adaptor.getAsyncDependencies().front();
1582 Value descr = createSpGEMMCreateDescrBuilder.create(loc, builder&: rewriter, arguments: {stream})
1583 .getResult();
1584 rewriter.replaceOp(op, newValues: {descr, stream});
1585 return success();
1586}
1587
1588LogicalResult
1589ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
1590 gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
1591 ConversionPatternRewriter &rewriter) const {
1592 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1593 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1594 return failure();
1595 Location loc = op.getLoc();
1596 auto stream = adaptor.getAsyncDependencies().front();
1597 createSpGEMMDestroyDescrBuilder.create(loc, builder&: rewriter,
1598 arguments: {adaptor.getDesc(), stream});
1599 rewriter.replaceOp(op, newValues: {stream});
1600 return success();
1601}
1602
1603LogicalResult
1604ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
1605 gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
1606 ConversionPatternRewriter &rewriter) const {
1607 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1608 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1609 return failure();
1610 Location loc = op.getLoc();
1611 auto computeType = genConstInt32From(
1612 builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: adaptor.getComputeType()));
1613 auto modeA = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeA());
1614 auto modeB = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeB());
1615 auto stream = adaptor.getAsyncDependencies().front();
1616
1617 Value pBuf =
1618 MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(builder&: rewriter, loc);
1619 Value bufferSizeNew;
1620
1621 if (adaptor.getKind() ==
1622 gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
1623 bufferSizeNew =
1624 createSpGEMMWorkEstimationBuilder
1625 .create(loc, builder&: rewriter,
1626 arguments: {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1627 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1628 adaptor.getBufferSz(), pBuf, stream})
1629 .getResult();
1630 } else {
1631 bufferSizeNew =
1632 createSpGEMMComputeBuilder
1633 .create(loc, builder&: rewriter,
1634 arguments: {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
1635 adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
1636 adaptor.getBufferSz(), pBuf, stream})
1637 .getResult();
1638 }
1639 rewriter.replaceOp(op, newValues: {bufferSizeNew, stream});
1640 return success();
1641}
1642
1643LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
1644 gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
1645 ConversionPatternRewriter &rewriter) const {
1646 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1647 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1648 return failure();
1649 Location loc = op.getLoc();
1650 auto computeType = genConstInt32From(
1651 builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: adaptor.getComputeType()));
1652 auto modeA = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeA());
1653 auto modeB = genConstInt32From(builder&: rewriter, loc, tValue: adaptor.getModeB());
1654 auto stream = adaptor.getAsyncDependencies().front();
1655 createSpGEMMCopyBuilder.create(loc, builder&: rewriter,
1656 arguments: {adaptor.getDesc(), modeA, modeB,
1657 adaptor.getSpmatA(), adaptor.getSpmatB(),
1658 adaptor.getSpmatC(), computeType, stream});
1659 rewriter.replaceOp(op, newValues: {stream});
1660 return success();
1661}
1662
1663LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
1664 gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
1665 ConversionPatternRewriter &rewriter) const {
1666 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1667 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1668 return failure();
1669 Location loc = op.getLoc();
1670 auto stream = adaptor.getAsyncDependencies().front();
1671
1672 auto three = rewriter.create<LLVM::ConstantOp>(location: loc, args: getIndexType(),
1673 args: rewriter.getIndexAttr(value: 3));
1674 auto buffer = rewriter.create<LLVM::AllocaOp>(
1675 location: loc, args: llvmPointerType, args: llvmInt64Type, args&: three, /*alignment=*/args: 16);
1676
1677 auto rowsPtr = rewriter.create<LLVM::GEPOp>(
1678 location: loc, args: llvmPointerType, args: llvmPointerType, args&: buffer,
1679 args: ValueRange{rewriter.create<LLVM::ConstantOp>(location: loc, args: getIndexType(),
1680 args: rewriter.getIndexAttr(value: 0))});
1681 auto colsPtr = rewriter.create<LLVM::GEPOp>(
1682 location: loc, args: llvmPointerType, args: llvmPointerType, args&: buffer,
1683 args: ValueRange{rewriter.create<LLVM::ConstantOp>(location: loc, args: getIndexType(),
1684 args: rewriter.getIndexAttr(value: 1))});
1685 auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
1686 location: loc, args: llvmPointerType, args: llvmPointerType, args&: buffer,
1687 args: ValueRange{rewriter.create<LLVM::ConstantOp>(location: loc, args: getIndexType(),
1688 args: rewriter.getIndexAttr(value: 2))});
1689 createSpMatGetSizeBuilder.create(
1690 loc, builder&: rewriter, arguments: {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
1691 auto rows = rewriter.create<LLVM::LoadOp>(location: loc, args: llvmInt64Type, args&: rowsPtr);
1692 auto cols = rewriter.create<LLVM::LoadOp>(location: loc, args: llvmInt64Type, args&: colsPtr);
1693 auto nnzs = rewriter.create<LLVM::LoadOp>(location: loc, args: llvmInt64Type, args&: nnzsPtr);
1694
1695 rewriter.replaceOp(op, newValues: {rows, cols, nnzs, stream});
1696 return success();
1697}
1698
1699LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1700 gpu::SetCsrPointersOp op, OpAdaptor adaptor,
1701 ConversionPatternRewriter &rewriter) const {
1702 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1703 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1704 return failure();
1705 Location loc = op.getLoc();
1706 auto stream = adaptor.getAsyncDependencies().front();
1707 Value pPos =
1708 MemRefDescriptor(adaptor.getPositions()).allocatedPtr(builder&: rewriter, loc);
1709 Value pCrd =
1710 MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(builder&: rewriter, loc);
1711 Value pVal =
1712 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1713 createSetCsrPointersBuilder.create(
1714 loc, builder&: rewriter, arguments: {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
1715 rewriter.replaceOp(op, newValues: {stream});
1716 return success();
1717}
1718
1719LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
1720 gpu::CreateCscOp op, OpAdaptor adaptor,
1721 ConversionPatternRewriter &rewriter) const {
1722 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1723 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1724 return failure();
1725 Location loc = op.getLoc();
1726 auto stream = adaptor.getAsyncDependencies().front();
1727 Value pColPos =
1728 MemRefDescriptor(adaptor.getColPos()).allocatedPtr(builder&: rewriter, loc);
1729 Value pRowIdxs =
1730 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(builder&: rewriter, loc);
1731 Value pValues =
1732 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1733 Type pType =
1734 llvm::cast<MemRefType>(Val: op.getColPos().getType()).getElementType();
1735 Type iType =
1736 llvm::cast<MemRefType>(Val: op.getRowIdxs().getType()).getElementType();
1737 Type dType =
1738 llvm::cast<MemRefType>(Val: op.getValues().getType()).getElementType();
1739 auto ptp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: pType));
1740 auto itp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: iType));
1741 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1742 auto handle =
1743 createCscCallBuilder
1744 .create(loc, builder&: rewriter,
1745 arguments: {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1746 pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1747 .getResult();
1748 rewriter.replaceOp(op, newValues: {handle, stream});
1749 return success();
1750}
1751
1752LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1753 gpu::CreateBsrOp op, OpAdaptor adaptor,
1754 ConversionPatternRewriter &rewriter) const {
1755 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)) ||
1756 failed(Result: isAsyncWithOneDependency(rewriter, op)))
1757 return failure();
1758 Location loc = op.getLoc();
1759 auto stream = adaptor.getAsyncDependencies().front();
1760 Value pRowPos =
1761 MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(builder&: rewriter, loc);
1762 Value pColIdxs =
1763 MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(builder&: rewriter, loc);
1764 Value pValues =
1765 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1766 Type pType =
1767 llvm::cast<MemRefType>(Val: op.getBRowPos().getType()).getElementType();
1768 Type iType =
1769 llvm::cast<MemRefType>(Val: op.getBColIdxs().getType()).getElementType();
1770 Type dType =
1771 llvm::cast<MemRefType>(Val: op.getValues().getType()).getElementType();
1772 auto ptp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: pType));
1773 auto itp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: iType));
1774 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1775 auto handle =
1776 createBsrCallBuilder
1777 .create(loc, builder&: rewriter,
1778 arguments: {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
1779 adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
1780 pColIdxs, pValues, ptp, itp, dtp, stream})
1781 .getResult();
1782 rewriter.replaceOp(op, newValues: {handle, stream});
1783 return success();
1784}
1785
1786void mlir::populateGpuToLLVMConversionPatterns(
1787 LLVMTypeConverter &converter, RewritePatternSet &patterns,
1788 bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
1789 addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
1790 addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
1791 addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
1792 addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
1793
1794 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
1795 ConvertDeallocOpToGpuRuntimeCallPattern,
1796 ConvertHostRegisterOpToGpuRuntimeCallPattern,
1797 ConvertHostUnregisterOpToGpuRuntimeCallPattern,
1798 ConvertMemcpyOpToGpuRuntimeCallPattern,
1799 ConvertMemsetOpToGpuRuntimeCallPattern,
1800 ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
1801 ConvertWaitAsyncOpToGpuRuntimeCallPattern,
1802 ConvertWaitOpToGpuRuntimeCallPattern,
1803 ConvertAsyncYieldToGpuRuntimeCallPattern,
1804 ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
1805 ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
1806 ConvertCreateCooOpToGpuRuntimeCallPattern,
1807 ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1808 ConvertCreateCsrOpToGpuRuntimeCallPattern,
1809 ConvertCreateCscOpToGpuRuntimeCallPattern,
1810 ConvertCreateBsrOpToGpuRuntimeCallPattern,
1811 ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1812 ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1813 ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
1814 ConvertSpMVOpToGpuRuntimeCallPattern,
1815 ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
1816 ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
1817 ConvertSpMMOpToGpuRuntimeCallPattern,
1818 ConvertSDDMMOpToGpuRuntimeCallPattern,
1819 ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
1820 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
1821 ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
1822 ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
1823 ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
1824 ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(arg&: converter);
1825 patterns.add<LegalizeLaunchFuncOpPattern>(arg&: converter, args&: kernelBarePtrCallConv,
1826 args&: kernelIntersperseSizeCallConv);
1827}
1828
1829//===----------------------------------------------------------------------===//
1830// GPUModuleOp convert to LLVM op interface
1831//===----------------------------------------------------------------------===//
1832
1833namespace {
1834struct GPUModuleOpConvertToLLVMInterface
1835 : public ConvertToLLVMOpInterface::ExternalModel<
1836 GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
1837 /// Get the conversion patterns from the target attribute.
1838 void getConvertToLLVMConversionAttrs(
1839 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
1840};
1841} // namespace
1842
1843void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
1844 Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
1845 auto module = cast<gpu::GPUModuleOp>(Val: op);
1846 ArrayAttr targetsAttr = module.getTargetsAttr();
1847 // Fail if there are no target attributes or there is more than one target.
1848 if (!targetsAttr || targetsAttr.size() != 1)
1849 return;
1850 if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(Val: targetsAttr[0]))
1851 attrs.push_back(Elt: patternAttr);
1852}
1853
1854void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
1855 registry.addExtension(extensionFn: +[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
1856 gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(context&: *ctx);
1857 });
1858}
1859

source code of mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp