//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of
// the public API from GPU runtime headers.
//
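// For example, a `gpu.alloc` is rewritten into a call to the `mgpuMemAlloc`
// wrapper, and a synchronous `gpu.wait` into `mgpuStreamSynchronize` /
// `mgpuStreamDestroy` calls (see the FunctionCallBuilder members below).
//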
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
class GpuToLLVMConversionPass
    : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  using Base::Base;
  void getDependentDialects(DialectRegistry &registry) const final {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }
  // Run the dialect converter on the module.
  void runOnOperation() override;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
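  /// Returns the number of elements of `type` as an index-typed value, using
  /// the static shape when available and otherwise multiplying the dynamic
  /// sizes stored in the memref descriptor `desc`.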
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    Type indexType = ConvertToLLVMPattern::getIndexType();
    if (type.hasStaticShape())
      return ConvertToLLVMPattern::createIndexAttrConstant(
          rewriter, loc, indexType, type.getNumElements());
    // Compute the number of elements by multiplying all the dim sizes.
    uint64_t rank = type.getRank();
    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
    for (unsigned i = 1; i < rank; i++)
      numElements = rewriter.create<LLVM::MulOp>(
          loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
    return numElements;
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt16Type = IntegerType::get(context, 16);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmFloat32Type = Float32Type::get(context);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

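  // Builders for calls to the `mgpu*` wrapper functions. These thin wrappers
  // are expected to be provided by a runtime support library (e.g. the CUDA
  // or ROCm runtime wrappers shipped with the MLIR execution engine).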
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder hostUnregisterCallBuilder = {
      "mgpuMemHostUnregisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */,
       llvmInt8Type /* bool isHostShared */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset16CallBuilder = {
      "mgpuMemset16",
      llvmVoidType,
      {llvmPointerType /* void *dst */,
       llvmInt16Type /* unsigned short value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memset32CallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
  FunctionCallBuilder createDnVecCallBuilder = {
      "mgpuCreateDnVec",
      llvmPointerType,
      {llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnVecCallBuilder = {
      "mgpuDestroyDnVec",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createDnMatCallBuilder = {
      "mgpuCreateDnMat",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyDnMatCallBuilder = {
      "mgpuDestroyDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooCallBuilder = {
      "mgpuCreateCoo",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCooAoSCallBuilder = {
      "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCsrCallBuilder = {
      "mgpuCreateCsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCscCallBuilder = {
      "mgpuCreateCsc",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createBsrCallBuilder = {
      "mgpuCreateBsr",
      llvmPointerType,
      {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
       llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmInt32Type, llvmInt32Type,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroySpMatCallBuilder = {
      "mgpuDestroySpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVBufferSizeCallBuilder = {
      "mgpuSpMVBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder spMVCallBuilder = {
      "mgpuSpMV",
      llvmVoidType,
      {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
      "mgpuSpMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSpMMCallBuilder = {
      "mgpuSpMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
      "mgpuSDDMMBufferSize",
      llvmIntPtrType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createSDDMMCallBuilder = {
      "mgpuSDDMM",
      llvmVoidType,
      {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmInt32Type, llvmPointerType,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder createLtDnMatCallBuilder = {
      "mgpuCreateCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtSpMatBuilder = {
      "mgpuDestroyCuSparseLtSpMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder destroyCuSparseLtDnMatBuilder = {
      "mgpuDestroyCuSparseLtDnMat",
      llvmVoidType,
      {llvmPointerType, llvmPointerType /* void *stream */}};
  FunctionCallBuilder create2To4SpMatCallBuilder = {
      "mgpuCusparseLtCreate2To4SpMat",
      llvmVoidType,
      {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
       llvmInt32Type, llvmPointerType /* void *stream */}};
  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
      "mgpuCuSparseLtSpMMBufferSize",
      llvmVoidType,
      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
      "mgpuCuSparseLtSpMM",
      llvmVoidType,
      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
       llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
      "mgpuSpGEMMCreateDescr",
      llvmPointerType,
      {llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
      "mgpuSpGEMMDestroyDescr",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
      "mgpuSpGEMMWorkEstimation",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMComputeBuilder = {
      "mgpuSpGEMMCompute",
      llvmIntPtrType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
       llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpGEMMCopyBuilder = {
      "mgpuSpGEMMCopy",
      llvmVoidType,
      {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
       llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSpMatGetSizeBuilder = {
      "mgpuSpMatGetSize",
      llvmVoidType,
      {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
       llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
  FunctionCallBuilder createSetCsrPointersBuilder = {
      "mgpuSetCsrPointers",
      llvmVoidType,
      {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
       llvmPointerType /*crd*/, llvmPointerType /*val*/,
       llvmPointerType /*void *stream*/}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertHostUnregisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> {
public:
  ConvertHostUnregisterOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) {
  }

private:
  LogicalResult
  matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to legalize gpu.launch_func with LLVM types.
class LegalizeLaunchFuncOpPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                              bool kernelBarePtrCallConv,
                              bool kernelIntersperseSizeCallConv)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        kernelBarePtrCallConv(kernelBarePtrCallConv),
        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}

private:
  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  bool kernelBarePtrCallConv;
  bool kernelIntersperseSizeCallConv;
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently supports CUDA and ROCm (HIP).
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      const LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Generic rewriting rule for operations on sparse matrices.
/// Currently supports CUDA (by means of cuSparse and cuSparseLt).
#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)               \
  class Convert##op_name##ToGpuRuntimeCallPattern                             \
      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {               \
  public:                                                                     \
    Convert##op_name##ToGpuRuntimeCallPattern(                                \
        const LLVMTypeConverter &typeConverter)                               \
        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}    \
                                                                              \
  private:                                                                    \
    LogicalResult                                                             \
    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                       \
                    ConversionPatternRewriter &rewriter) const override;      \
  };

DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp)
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)

} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  MLIRContext *context = &getContext();

  // Perform progressive lowering of vector transfer operations.
  {
    RewritePatternSet patterns(&getContext());
    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
    vector::populateVectorTransferLoweringPatterns(patterns,
                                                   /*maxTransferRank=*/1);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }

  LowerToLLVMOptions options(context);
  options.useBarePtrCallConv = hostBarePtrCallConv;
  RewritePatternSet patterns(context);
  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  LLVMTypeConverter converter(context, options);

  // Populate all patterns from all dialects that implement the
  // `ConvertToLLVMPatternInterface` interface.
  for (Dialect *dialect : context->getLoadedDialects()) {
    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
    if (!iface)
      continue;
    iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
  }

  // Preserve GPU modules and binaries. Modules are preserved as they can be
  // converted later by `gpu-module-to-binary`.
  target.addLegalOp<gpu::GPUModuleOp, gpu::BinaryOp>();
  // Accept LaunchFuncOps as legal once their operands have been lowered.
  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
      [&](gpu::LaunchFuncOp op) -> bool { return converter.isLegal(op); });

  // These aren't covered by the ConvertToLLVMPatternInterface right now.
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns,
                                      kernelBarePtrCallConv,
                                      kernelIntersperseSizeCallConv);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

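// Returns a call to the runtime wrapper function `functionName`, declaring the
// function at the end of the module first if it has not been declared yet.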
LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}

// Corresponding to cusparseIndexType_t defined in cusparse.h.
static int32_t getCuSparseIndexTypeFrom(Type type) {
  if (type.isInteger(16))
    return 1; // CUSPARSE_INDEX_16U
  if (type.isInteger(32))
    return 2; // CUSPARSE_INDEX_32I
  return 3; // CUSPARSE_INDEX_64I
}

static int32_t getCuSparseLtDataTypeFrom(Type type) {
  if (type.isF16())
    return 0; // CUSPARSE_COMPUTE_16F
  if (type.isInteger(32))
    return 1; // CUSPARSE_COMPUTE_32I
  llvm_unreachable("unsupported type");
  // TODO: add support for TF32.
}

// Corresponding to cudaDataType_t defined in CUDA library_types.h.
static int32_t getCuSparseDataTypeFrom(Type type) {
  if (llvm::isa<ComplexType>(type)) {
    // Get the element type.
    auto elementType = cast<ComplexType>(type).getElementType();
    if (elementType.isBF16())
      return 15; // CUDA_C_16BF
    if (elementType.isF16())
      return 6; // CUDA_C_16F
    if (elementType.isF32())
      return 4; // CUDA_C_32F
    if (elementType.isF64())
      return 5; // CUDA_C_64F
    if (elementType.isInteger(8))
      return 7; // CUDA_C_8I
    if (elementType.isInteger(16))
      return 21; // CUDA_C_16I
    if (elementType.isInteger(32))
      return 11; // CUDA_C_32I
  }
  if (type.isBF16())
    return 14; // CUDA_R_16BF
  if (type.isF16())
    return 2; // CUDA_R_16F
  if (type.isF32())
    return 0; // CUDA_R_32F
  if (type.isF64())
    return 1; // CUDA_R_64F
  if (type.isInteger(8))
    return 3; // CUDA_R_8I
  if (type.isInteger(16))
    return 20; // CUDA_R_16I
  if (type.isInteger(32))
    return 10; // CUDA_R_32I

  llvm_unreachable("unsupported element type");
}

static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) {
  return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag();
}

// TODO: We may want a compile-time (i.e., in the MLIR compiler) check or
// warning: cuSparseLt currently does not work for CUDA architectures below
// 8.0 and triggers a runtime error in the generated CUDA program. It would be
// good to at least emit a warning when the target architecture is below 8.0
// and the user still requests cuSparseLt, or to disable the cuSparseLt calls
// for such architectures when lowering the GPU sparse dialect to LLVM calls.
static bool is2To4Sparsity(Value spMat) {
  if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>())
    return true;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>())
    return false;
  if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>())
    return false;
  // Print the spMat defining op.
  spMat.getDefiningOp()->print(llvm::errs());
  llvm_unreachable("cannot find spmat def");
}

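// Returns true if the given sparse-tensor handle is consumed by a gpu.spmm op
// whose A operand has 2:4 (50%) sparsity; in that case the cuSparseLt path is
// also used for the surrounding create/destroy calls.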
static bool isSpMMCusparseLtOp(Value op) {
  for (Operation *user : op.getUsers()) {
    auto spmmOp = dyn_cast<gpu::SpMMOp>(user);
    // If the other operand has 2:4 (50%) sparsity, then we should use
    // cuSparseLt.
    if (!spmmOp)
      continue;
    if (is2To4Sparsity(spmmOp.getSpmatA()))
      return true;
  }
  return false;
}

// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

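// Returns success if `op` is the asynchronous form of the operation and has
// exactly one async dependency; the patterns below only handle that case.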
static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Operation *op = hostUnregisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostUnregisterOp.getValue().getType();
  auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostUnregisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType))
    return failure();

  auto loc = allocOp.getLoc();

  bool isShared = allocOp.getHostShared();

  if (isShared && allocOp.getAsyncToken())
    return rewriter.notifyMatchFailure(
        allocOp, "Host Shared allocation cannot be done async");
  if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType);
  Value stream = adaptor.getAsyncDependencies().empty()
                     ? nullPtr
                     : adaptor.getAsyncDependencies().front();

  auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared));

  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared})
          .getResult();

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  if (allocOp.getAsyncToken()) {
    // Async alloc: make dependent ops use the same stream.
    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
  } else {
    rewriter.replaceOp(allocOp, {memRefDescriptor});
  }

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Value stream = adaptor.getAsyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {pointer, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

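// Returns true if `value` has the !gpu.async.token type.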
static bool isGpuAsyncTokenType(Value value) {
  return isa<gpu::AsyncTokenType>(value.getType());
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls.
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(isa<LLVM::LLVMPointerType>(value.getType()));
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return *defOp.getCallee() == functionName;
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed, i.e. it is
// assumed that they are not used afterwards or elsewhere; otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
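// For example (a rough sketch):
//   gpu.wait [%token]
// becomes, when the converted %token was produced by mgpuStreamCreate:
//   llvm.call @mgpuStreamSynchronize(%stream) : (!llvm.ptr) -> ()
//   llvm.call @mgpuStreamDestroy(%stream) : (!llvm.ptr) -> ()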
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands are
// destroyed, i.e. it is assumed that they are not used afterwards or
// elsewhere; otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.getAsyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult();
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Legalize the op's operands.
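// The gpu.launch_func op itself is kept (it becomes legal once its operands
// are LLVM-compatible, see the conversion target above); only its operands
// are rewritten, and the async dependencies are replaced by a stream.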
LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.getAsyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  Value stream = Value();
  if (!adaptor.getAsyncDependencies().empty())
    stream = adaptor.getAsyncDependencies().front();
  // If the async keyword is present and there are no dependencies, then a
  // stream must be created to pass to subsequent operations.
  else if (launchOp.getAsyncToken())
    stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();

  // Lower the kernel operands to match kernel parameters.
  // Note: If `useBarePtrCallConv` is set in the type converter's options,
  // the value of `kernelBarePtrCallConv` will be ignored.
  OperandRange origArguments = launchOp.getKernelOperands();
  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
      loc, origArguments, adaptor.getKernelOperands(), rewriter,
      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
  SmallVector<Value, 8> llvmArgumentsWithSizes;

  // Intersperse size information if requested.
  if (kernelIntersperseSizeCallConv) {
    if (origArguments.size() != llvmArguments.size()) {
      // This shouldn't happen if the bare-pointer calling convention is used.
      return rewriter.notifyMatchFailure(
          launchOp,
          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
    }

    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
      if (!memrefTy) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref.");
      }

      if (!memrefTy.hasStaticShape() ||
          !memrefTy.getElementType().isIntOrFloat()) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a static "
                      "shape and an integer or float element type.");
      }

      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
      if (bitwidth % 8 != 0) {
        return rewriter.notifyMatchFailure(
            launchOp, "Operand to launch op is not a memref with a "
                      "byte-aligned element type.");
      }

      uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
                            static_cast<uint64_t>(memrefTy.getNumElements());

      Value sizeArg = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(staticSize));
      llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
      llvmArgumentsWithSizes.push_back(sizeArg);
    }
  }

  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
  if (launchOp.hasClusterSize()) {
    clusterSize =
        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                        adaptor.getClusterSizeZ()};
  }
  rewriter.create<gpu::LaunchFuncOp>(
      launchOp.getLoc(), launchOp.getKernelAttr(),
      gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                      adaptor.getGridSizeZ()},
      gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                      adaptor.getBlockSizeZ()},
      adaptor.getDynamicSharedMemorySize(),
      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
      stream, clusterSize);
  if (launchOp.getAsyncToken())
    rewriter.replaceOp(launchOp, {stream});
  else
    rewriter.eraseOp(launchOp);
  return success();
}

static Value bitAndAddrspaceCast(Location loc,
                                 ConversionPatternRewriter &rewriter,
                                 LLVM::LLVMPointerType destinationType,
                                 Value sourcePtr,
                                 const LLVMTypeConverter &typeConverter) {
  auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
    sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(),
                                   destinationType.getAddressSpace()),
        sourcePtr);
  return sourcePtr;
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.getSrc());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType,
      typeConverter->convertType(memRefType.getElementType()), nullPtr,
      numElements);
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 srcDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());
  auto dst = bitAndAddrspaceCast(
      loc, rewriter, llvmPointerType,
      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
      *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.getValue().getType();
  unsigned bitWidth = valueType.getIntOrFloatBitWidth();
  // Ints and floats of 16 or 32 bit width are allowed.
  if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
    return rewriter.notifyMatchFailure(
        memsetOp, "value must be a 16 or 32 bit int or float");
  }

  unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
  Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;

  MemRefDescriptor dstDesc(adaptor.getDst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
                                 dstDesc.alignedPtr(rewriter, loc),
                                 *getTypeConverter());

  auto stream = adaptor.getAsyncDependencies().front();
  FunctionCallBuilder builder =
      valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
  builder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
                                                 {adaptor.getDevIndex()});
  rewriter.replaceOp(op, call);
  return success();
}

template <typename T>
static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmInt32Type = builder.getIntegerType(32);
  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                          static_cast<int32_t>(tValue));
}

template <typename T>
static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) {
  Type llvmFloat32Type = builder.getF32Type();
  return builder.create<LLVM::ConstantOp>(
      loc, llvmFloat32Type,
      builder.getF32FloatAttr(static_cast<float>(tValue)));
}

LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateDnTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pTensor =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType = op.getMemref().getType().getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  SmallVector<Value, 4> dims;
  for (Value dim : adaptor.getDims()) {
    dims.push_back(dim);
  }

  Value handle;
  // TODO: For now, we track the use of the handle and lower it to cusparse /
  // cusparseLt accordingly. If both cusparse and cusparseLt are used in the
  // same block, two separate creation ops are required for the logic to be
  // correct. In the future, we may add support for using one handle for both
  // cusparse and cusparseLt in the sparse tensor / GPU dialects.
  // Use the cusparseLt create call if the dnmat is used with a spmat with
  // 2:4 sparsity.
  if (dims.size() == 2) {
    if (isSpMMCusparseLtOp(op.getDnTensor())) {
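      // Allocate a buffer for the cusparseLt handle (11032 bytes, matching
      // the constant used below).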
1196 auto handleSz = rewriter.create<LLVM::ConstantOp>(
1197 loc, getIndexType(), rewriter.getIndexAttr(11032));
1198 handle = rewriter.create<LLVM::AllocaOp>(
1199 loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
1200 handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);
1201
1202 createLtDnMatCallBuilder
1203 .create(loc, rewriter,
1204 {handle, dims[0], dims[1], pTensor, dtp, stream})
1205 .getResult();
1206 } else {
1207 handle =
1208 createDnMatCallBuilder
1209 .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream})
1210 .getResult();
1211 }
1212 } else {
1213 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1214 handle = createDnVecCallBuilder
1215 .create(loc, rewriter, {dims[0], pTensor, dtp, stream})
1216 .getResult();
1217 }
1218 rewriter.replaceOp(op, {handle, stream});
1219 return success();
1220}
1221
1222LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
1223 gpu::DestroyDnTensorOp op, OpAdaptor adaptor,
1224 ConversionPatternRewriter &rewriter) const {
1225 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1226 failed(isAsyncWithOneDependency(rewriter, op)))
1227 return failure();
1228 Location loc = op.getLoc();
1229 auto stream = adaptor.getAsyncDependencies().front();
1230 auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>();
1231 SmallVector<Value, 4> dims;
1232 for (Value dim : definingOp.getDims()) {
1233 dims.push_back(dim);
1234 }
1235 if (dims.size() == 2) {
1236 // Use the cusparseLt destroy call if the dnmat is used with spmat with
1237 // 2:4 sparsity
1238 if (isSpMMCusparseLtOp(op.getDnTensor())) {
1239 destroyCuSparseLtDnMatBuilder.create(loc, rewriter,
1240 {adaptor.getDnTensor(), stream});
1241 } else {
1242 destroyDnMatCallBuilder.create(loc, rewriter,
1243 {adaptor.getDnTensor(), stream});
1244 }
1245 } else {
1246 assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1247 destroyDnVecCallBuilder.create(loc, rewriter,
1248 {adaptor.getDnTensor(), stream});
1249 }
1250 rewriter.replaceOp(op, {stream});
1251 return success();
1252}
1253
1254LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
1255 gpu::CreateCooOp op, OpAdaptor adaptor,
1256 ConversionPatternRewriter &rewriter) const {
1257 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1258 failed(isAsyncWithOneDependency(rewriter, op)))
1259 return failure();
1260 Location loc = op.getLoc();
1261 auto stream = adaptor.getAsyncDependencies().front();
1262 Value pRowIdxs =
1263 MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(builder&: rewriter, loc);
1264 Value pColIdxs =
1265 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(builder&: rewriter, loc);
1266 Value pValues =
1267 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1268 Type iType =
1269 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1270 Type dType =
1271 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1272 auto itp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: iType));
1273 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1274 auto handle =
1275 createCooCallBuilder
1276 .create(loc, rewriter,
1277 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1278 pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
1279 .getResult();
1280 rewriter.replaceOp(op, {handle, stream});
1281 return success();
1282}
1283
1284LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite(
1285 gpu::CreateCooAoSOp op, OpAdaptor adaptor,
1286 ConversionPatternRewriter &rewriter) const {
1287 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1288 failed(isAsyncWithOneDependency(rewriter, op)))
1289 return failure();
1290 Location loc = op.getLoc();
1291 auto stream = adaptor.getAsyncDependencies().front();
1292 Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(builder&: rewriter, loc);
1293 Value pValues =
1294 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1295 Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType();
1296 Type dType =
1297 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1298 auto itp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: iType));
1299 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1300 auto handle =
1301 createCooAoSCallBuilder
1302 .create(loc, rewriter,
1303 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1304 pIdxs, pValues, itp, dtp, stream})
1305 .getResult();
1306 rewriter.replaceOp(op, {handle, stream});
1307 return success();
1308}
1309
1310LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
1311 gpu::CreateCsrOp op, OpAdaptor adaptor,
1312 ConversionPatternRewriter &rewriter) const {
1313 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
1314 failed(isAsyncWithOneDependency(rewriter, op)))
1315 return failure();
1316 Location loc = op.getLoc();
1317 auto stream = adaptor.getAsyncDependencies().front();
1318 Value pRowPos =
1319 MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(builder&: rewriter, loc);
1320 Value pColIdxs =
1321 MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(builder&: rewriter, loc);
1322 Value pValues =
1323 MemRefDescriptor(adaptor.getValues()).allocatedPtr(builder&: rewriter, loc);
1324 Type pType =
1325 llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
1326 Type iType =
1327 llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
1328 Type dType =
1329 llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
1330 auto ptp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: pType));
1331 auto itp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseIndexTypeFrom(type: iType));
1332 auto dtp = genConstInt32From(builder&: rewriter, loc, tValue: getCuSparseDataTypeFrom(type: dType));
1333 auto handle =
1334 createCsrCallBuilder
1335 .create(loc, rewriter,
1336 {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
1337 pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
1338 .getResult();
1339 rewriter.replaceOp(op, {handle, stream});
1340 return success();
1341}
1342
LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::Create2To4SpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pMat =
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
  Type dType =
      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));

  // CUDA runner asserts the size is 44104 bytes.
  auto handleSz = rewriter.create<LLVM::ConstantOp>(
      loc, getIndexType(), rewriter.getIndexAttr(44104));
  Value handle = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16);
  handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle);

  create2To4SpMatCallBuilder
      .create(loc, rewriter,
              {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
      .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DestroySpMatOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  // Use the cusparseLt destroy call if the spmat uses 2:4 sparsity.
  if (is2To4Sparsity(op.getSpmat())) {
    destroyCuSparseLtSpMatBuilder.create(loc, rewriter,
                                         {adaptor.getSpmat(), stream});
  } else {
    destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

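// Lowers gpu::SpMVBufferSizeOp to a runtime call that queries the workspace
// size needed by a subsequent SpMV; the returned size and the stream replace
// the op's results.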
LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize = spMVBufferSizeCallBuilder
                        .create(loc, rewriter,
                                {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                                 adaptor.getDnY(), computeType, stream})
                        .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}

LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMVOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  spMVCallBuilder.create(loc, rewriter,
                         {modeA, adaptor.getSpmatA(), adaptor.getDnX(),
                          adaptor.getDnY(), computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

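// Lowers gpu::SpMMBufferSizeOp. For 2:4 sparse operands the cusparseLt path
// returns three workspace sizes, which are read back from a stack-allocated
// array filled in by the runtime call; otherwise a single buffer size is
// obtained from the regular call builder.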
LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value bufferSize;
  if (is2To4Sparsity(op.getSpmatA())) {
    auto pruneFlag =
        genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA()));
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
    auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(3));
    auto bufferSize = rewriter.create<LLVM::AllocaOp>(
        loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16);
    createCuSparseLtSpMMBufferSizeBuilder
        .create(loc, rewriter,
                {bufferSize, modeA, modeB, adaptor.getSpmatA(),
                 adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
                 pruneFlag, stream})
        .getResult();

    auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(1))});
    auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
        loc, llvmPointerType, llvmPointerType, bufferSize,
        ValueRange{rewriter.create<LLVM::ConstantOp>(
            loc, getIndexType(), rewriter.getIndexAttr(2))});
    auto bufferSize0 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize);
    auto bufferSize1 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1);
    auto bufferSize2 =
        rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);

    rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream});
  } else {
    auto computeType = genConstInt32From(
        rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
    bufferSize =
        createSpMMBufferSizeCallBuilder
            .create(loc, rewriter,
                    {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(),
                     adaptor.getDnmatC(), computeType, stream})
            .getResult();
    rewriter.replaceOp(op, {bufferSize, stream});
  }
  return success();
}

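// Lowers gpu::SDDMMBufferSizeOp to a runtime call that queries the workspace
// size needed by a subsequent SDDMM.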
LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto stream = adaptor.getAsyncDependencies().front();
  auto bufferSize =
      createSDDMMBufferSizeCallBuilder
          .create(loc, rewriter,
                  {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
                   adaptor.getSpmatC(), computeType, stream})
          .getResult();
  rewriter.replaceOp(op, {bufferSize, stream});
  return success();
}

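// Lowers gpu::SpMMOp. The 2:4 sparse case dispatches to the cusparseLt call
// builder and forwards three workspace buffers; the generic case uses the
// regular call builder with a single workspace buffer.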
LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));

  auto stream = adaptor.getAsyncDependencies().front();

  // Lower to cusparseLt if applicable.
  if (is2To4Sparsity(op.getSpmatA())) {
    SmallVector<Value> pBufs;
    for (Value buffer : adaptor.getBuffers()) {
      Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc);
      pBufs.push_back(pBuf);
    }
    createCuSparseLtSpMMBuilder.create(
        loc, rewriter,
        {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
         pBufs[0], pBufs[1], pBufs[2], stream});
  } else {
    Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
                     .allocatedPtr(rewriter, loc);
    createSpMMCallBuilder.create(loc, rewriter,
                                 {modeA, modeB, adaptor.getSpmatA(),
                                  adaptor.getDnmatB(), adaptor.getDnmatC(),
                                  computeType, pBuf, stream});
  }
  rewriter.replaceOp(op, {stream});
  return success();
}

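// Converts the GPU dialect handle type T to an opaque LLVM pointer during
// type conversion.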
template <typename T>
static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
  converter.addConversion([&converter](T) -> Type {
    return LLVM::LLVMPointerType::get(&converter.getContext());
  });
}

LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SDDMMOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  createSDDMMCallBuilder.create(loc, rewriter,
                                {modeA, modeB, adaptor.getDnmatA(),
                                 adaptor.getDnmatB(), adaptor.getSpmatC(),
                                 computeType, pBuf, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

LogicalResult
ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream})
                    .getResult();
  rewriter.replaceOp(op, {descr, stream});
  return success();
}

LogicalResult
ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
                                         {adaptor.getDesc(), stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

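// Lowers gpu::SpGEMMWorkEstimationOrComputeOp. One pattern serves both
// phases: the op's kind attribute selects between the work-estimation and the
// compute call builders, and the (possibly updated) buffer size is returned
// together with the stream.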
LogicalResult
ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();

  Value pBuf =
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
  Value bufferSizeNew;

  if (adaptor.getKind() ==
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) {
    bufferSizeNew =
        createSpGEMMWorkEstimationBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  } else {
    bufferSizeNew =
        createSpGEMMComputeBuilder
            .create(loc, rewriter,
                    {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                     adaptor.getBufferSz(), pBuf, stream})
            .getResult();
  }
  rewriter.replaceOp(op, {bufferSizeNew, stream});
  return success();
}

LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto computeType = genConstInt32From(
      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
  auto stream = adaptor.getAsyncDependencies().front();
  createSpGEMMCopyBuilder.create(loc, rewriter,
                                 {adaptor.getDesc(), modeA, modeB,
                                  adaptor.getSpmatA(), adaptor.getSpmatB(),
                                  adaptor.getSpmatC(), computeType, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

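// Lowers gpu::SpMatGetSizeOp by allocating a stack array of three i64 slots,
// letting the runtime call fill in rows, cols, and nnz, and loading the values
// back to replace the op's results.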
LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SpMatGetSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();

  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(3));
  auto buffer = rewriter.create<LLVM::AllocaOp>(
      loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16);

  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(0))});
  auto colsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(1))});
  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
      loc, llvmPointerType, llvmPointerType, buffer,
      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                   rewriter.getIndexAttr(2))});
  createSpMatGetSizeBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);

  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
  return success();
}

LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetCsrPointersOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pPos =
      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
  Value pCrd =
      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
  Value pVal =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  createSetCsrPointersBuilder.create(
      loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
  rewriter.replaceOp(op, {stream});
  return success();
}

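// Lowers gpu::CreateCscOp analogously to the CSR case, with column positions
// and row indices describing the CSC layout.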
LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateCscOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pColPos =
      MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getColPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createCscCallBuilder
          .create(loc, rewriter,
                  {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
                   pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

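// Lowers gpu::CreateBsrOp to a runtime call that builds a BSR sparse-matrix
// handle; block counts and block sizes are forwarded in addition to the
// position, index, and value buffers.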
LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::CreateBsrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, op)))
    return failure();
  Location loc = op.getLoc();
  auto stream = adaptor.getAsyncDependencies().front();
  Value pRowPos =
      MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
  Type pType =
      llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType();
  Type iType =
      llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType();
  Type dType =
      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
  auto handle =
      createBsrCallBuilder
          .create(loc, rewriter,
                  {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(),
                   adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos,
                   pColIdxs, pValues, ptp, itp, dtp, stream})
          .getResult();
  rewriter.replaceOp(op, {handle, stream});
  return success();
}

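// Populates the type conversions and rewrite patterns that lower GPU runtime
// and sparse ops to runtime calls, and registers the launch-func legalization
// pattern with the requested calling conventions.
//
// A typical driver (sketch only, not tied to any particular pass; `ctx` and
// `module` stand for the MLIRContext and the module being converted):
//
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateGpuToLLVMConversionPatterns(
//       converter, patterns,
//       /*kernelBarePtrCallConv=*/false,
//       /*kernelIntersperseSizeCallConv=*/false);
//   LLVMConversionTarget target(*ctx);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();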
void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
  addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);

  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertHostUnregisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern,
               ConvertCreateDnTensorOpToGpuRuntimeCallPattern,
               ConvertDestroyDnTensorOpToGpuRuntimeCallPattern,
               ConvertCreateCooOpToGpuRuntimeCallPattern,
               ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
               ConvertCreateCsrOpToGpuRuntimeCallPattern,
               ConvertCreateCscOpToGpuRuntimeCallPattern,
               ConvertCreateBsrOpToGpuRuntimeCallPattern,
               ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
               ConvertDestroySpMatOpToGpuRuntimeCallPattern,
               ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMVOpToGpuRuntimeCallPattern,
               ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
               ConvertSpMMOpToGpuRuntimeCallPattern,
               ConvertSDDMMOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
               ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
               ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
               ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
                                            kernelIntersperseSizeCallConv);
}

//===----------------------------------------------------------------------===//
// GPUModuleOp convert to LLVM op interface
//===----------------------------------------------------------------------===//

namespace {
struct GPUModuleOpConvertToLLVMInterface
    : public ConvertToLLVMOpInterface::ExternalModel<
          GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> {
  /// Get the conversion patterns from the target attribute.
  void getConvertToLLVMConversionAttrs(
      Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const;
};
} // namespace

void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs(
    Operation *op, SmallVectorImpl<ConvertToLLVMAttrInterface> &attrs) const {
  auto module = cast<gpu::GPUModuleOp>(op);
  ArrayAttr targetsAttr = module.getTargetsAttr();
  // Fail if there are no target attributes or there is more than one target.
  if (!targetsAttr || targetsAttr.size() != 1)
    return;
  if (auto patternAttr = dyn_cast<ConvertToLLVMAttrInterface>(targetsAttr[0]))
    attrs.push_back(patternAttr);
}

void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    gpu::GPUModuleOp::attachInterface<GPUModuleOpConvertToLLVMInterface>(*ctx);
  });
}
