1 | //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to convert gpu.launch_func op into a sequence of |
10 | // GPU runtime calls. As most GPU runtimes do not have a stable published |
11 | // ABI, this pass uses a slim runtime layer that builds on top of the public |
12 | // API from GPU runtime headers. |
13 | // |
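// For example (schematically; the exact IR produced differs), a launch such as
//
//   gpu.launch_func @kernels::@kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32)
//
// becomes a sequence of calls to runtime wrapper functions such as
// mgpuModuleLoad, mgpuModuleGetFunction, mgpuStreamCreate, mgpuLaunchKernel,
// mgpuStreamSynchronize, mgpuStreamDestroy and mgpuModuleUnload.
//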
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
17 | |
18 | #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
19 | #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" |
20 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
21 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
22 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" |
23 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
24 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
25 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
26 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
27 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
28 | #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
29 | #include "mlir/Dialect/Async/IR/Async.h" |
30 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
31 | #include "mlir/Dialect/GPU/Transforms/Passes.h" |
32 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
33 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
34 | #include "mlir/IR/Attributes.h" |
35 | #include "mlir/IR/Builders.h" |
36 | #include "mlir/IR/BuiltinOps.h" |
37 | #include "mlir/IR/BuiltinTypes.h" |
38 | |
39 | #include "llvm/ADT/STLExtras.h" |
40 | #include "llvm/Support/Error.h" |
41 | #include "llvm/Support/FormatVariadic.h" |
42 | |
43 | #define DEBUG_TYPE "gpu-to-llvm" |
44 | |
45 | namespace mlir { |
46 | #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS |
47 | #include "mlir/Conversion/Passes.h.inc" |
48 | } // namespace mlir |
49 | |
50 | using namespace mlir; |
51 | |
52 | static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst"; |
53 | |
54 | namespace { |
55 | class GpuToLLVMConversionPass |
56 | : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> { |
57 | public: |
58 | using Base::Base; |
59 | void getDependentDialects(DialectRegistry &registry) const final { |
60 | Base::getDependentDialects(registry); |
61 | registerConvertToLLVMDependentDialectLoading(registry); |
62 | } |
63 | // Run the dialect converter on the module. |
64 | void runOnOperation() override; |
65 | }; |
66 | |
67 | template <typename OpTy> |
68 | class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> { |
69 | public: |
70 | explicit ConvertOpToGpuRuntimeCallPattern( |
71 | const LLVMTypeConverter &typeConverter) |
72 | : ConvertOpToLLVMPattern<OpTy>(typeConverter) {} |
73 | |
74 | protected: |
75 | Value getNumElements(ConversionPatternRewriter &rewriter, Location loc, |
76 | MemRefType type, MemRefDescriptor desc) const { |
77 | Type indexType = ConvertToLLVMPattern::getIndexType(); |
78 | return type.hasStaticShape() |
79 | ? ConvertToLLVMPattern::createIndexAttrConstant( |
80 | rewriter, loc, indexType, type.getNumElements()) |
81 | // For identity maps (verified by caller), the number of |
82 | // elements is stride[0] * size[0]. |
83 | : rewriter.create<LLVM::MulOp>(loc, |
84 | desc.stride(rewriter, loc, 0), |
85 | desc.size(rewriter, loc, 0)); |
86 | } |
87 | |
88 | MLIRContext *context = &this->getTypeConverter()->getContext(); |
89 | |
90 | Type llvmVoidType = LLVM::LLVMVoidType::get(context); |
91 | LLVM::LLVMPointerType llvmPointerType = LLVM::LLVMPointerType::get(context); |
92 | Type llvmInt8Type = IntegerType::get(context, 8); |
93 | Type llvmInt16Type = IntegerType::get(context, 16); |
94 | Type llvmInt32Type = IntegerType::get(context, 32); |
95 | Type llvmInt64Type = IntegerType::get(context, 64); |
96 | Type llvmFloat32Type = Float32Type::get(context); |
97 | Type llvmIntPtrType = IntegerType::get( |
98 | context, this->getTypeConverter()->getPointerBitwidth(0)); |
99 | |
100 | FunctionCallBuilder moduleLoadCallBuilder = { |
101 | "mgpuModuleLoad", |
102 | llvmPointerType /* void *module */, |
103 | {llvmPointerType /* void *cubin */, llvmInt64Type /* size_t size */}}; |
104 | FunctionCallBuilder moduleUnloadCallBuilder = { |
105 | "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}}; |
106 | FunctionCallBuilder moduleGetFunctionCallBuilder = { |
107 | "mgpuModuleGetFunction", |
108 | llvmPointerType /* void *function */, |
109 | { |
110 | llvmPointerType, /* void *module */ |
111 | llvmPointerType /* char *name */ |
112 | }}; |
113 | FunctionCallBuilder launchKernelCallBuilder = { |
114 | "mgpuLaunchKernel", |
115 | llvmVoidType, |
116 | { |
117 | llvmPointerType, /* void* f */ |
118 | llvmIntPtrType, /* intptr_t gridXDim */ |
119 | llvmIntPtrType, /* intptr_t gridYDim */ |
120 | llvmIntPtrType, /* intptr_t gridZDim */ |
121 | llvmIntPtrType, /* intptr_t blockXDim */ |
122 | llvmIntPtrType, /* intptr_t blockYDim */ |
123 | llvmIntPtrType, /* intptr_t blockZDim */ |
124 | llvmInt32Type, /* unsigned int sharedMemBytes */ |
125 | llvmPointerType, /* void *hstream */ |
126 | llvmPointerType, /* void **kernelParams */ |
127 | llvmPointerType, /* void **extra */ |
128 | llvmInt64Type /* size_t paramsCount */ |
129 | }}; |
130 | FunctionCallBuilder streamCreateCallBuilder = { |
131 | "mgpuStreamCreate", llvmPointerType /* void *stream */, {}}; |
132 | FunctionCallBuilder streamDestroyCallBuilder = { |
133 | "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}}; |
134 | FunctionCallBuilder streamSynchronizeCallBuilder = { |
135 | "mgpuStreamSynchronize", |
136 | llvmVoidType, |
137 | {llvmPointerType /* void *stream */}}; |
138 | FunctionCallBuilder streamWaitEventCallBuilder = { |
139 | "mgpuStreamWaitEvent", |
140 | llvmVoidType, |
141 | {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}}; |
142 | FunctionCallBuilder eventCreateCallBuilder = { |
143 | "mgpuEventCreate", llvmPointerType /* void *event */, {}}; |
144 | FunctionCallBuilder eventDestroyCallBuilder = { |
145 | "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}}; |
146 | FunctionCallBuilder eventSynchronizeCallBuilder = { |
147 | "mgpuEventSynchronize", |
148 | llvmVoidType, |
149 | {llvmPointerType /* void *event */}}; |
150 | FunctionCallBuilder eventRecordCallBuilder = { |
151 | "mgpuEventRecord", |
152 | llvmVoidType, |
153 | {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}}; |
154 | FunctionCallBuilder hostRegisterCallBuilder = { |
155 | "mgpuMemHostRegisterMemRef", |
156 | llvmVoidType, |
157 | {llvmIntPtrType /* intptr_t rank */, |
158 | llvmPointerType /* void *memrefDesc */, |
159 | llvmIntPtrType /* intptr_t elementSizeBytes */}}; |
160 | FunctionCallBuilder hostUnregisterCallBuilder = { |
161 | "mgpuMemHostUnregisterMemRef", |
162 | llvmVoidType, |
163 | {llvmIntPtrType /* intptr_t rank */, |
164 | llvmPointerType /* void *memrefDesc */, |
165 | llvmIntPtrType /* intptr_t elementSizeBytes */}}; |
166 | FunctionCallBuilder allocCallBuilder = { |
167 | "mgpuMemAlloc", |
168 | llvmPointerType /* void * */, |
169 | {llvmIntPtrType /* intptr_t sizeBytes */, |
170 | llvmPointerType /* void *stream */, |
171 | llvmInt8Type /* bool isHostShared */}}; |
172 | FunctionCallBuilder deallocCallBuilder = { |
173 | "mgpuMemFree", |
174 | llvmVoidType, |
175 | {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}}; |
176 | FunctionCallBuilder memcpyCallBuilder = { |
177 | "mgpuMemcpy", |
178 | llvmVoidType, |
179 | {llvmPointerType /* void *dst */, llvmPointerType /* void *src */, |
180 | llvmIntPtrType /* intptr_t sizeBytes */, |
181 | llvmPointerType /* void *stream */}}; |
182 | FunctionCallBuilder memset16CallBuilder = { |
183 | "mgpuMemset16", |
184 | llvmVoidType, |
185 | {llvmPointerType /* void *dst */, |
186 | llvmInt16Type /* unsigned short value */, |
187 | llvmIntPtrType /* intptr_t sizeBytes */, |
188 | llvmPointerType /* void *stream */}}; |
189 | FunctionCallBuilder memset32CallBuilder = { |
190 | "mgpuMemset32", |
191 | llvmVoidType, |
192 | {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */, |
193 | llvmIntPtrType /* intptr_t sizeBytes */, |
194 | llvmPointerType /* void *stream */}}; |
195 | FunctionCallBuilder setDefaultDeviceCallBuilder = { |
196 | "mgpuSetDefaultDevice", |
197 | llvmVoidType, |
198 | {llvmInt32Type /* uint32_t devIndex */}}; |
199 | FunctionCallBuilder createDnVecCallBuilder = { |
200 | "mgpuCreateDnVec", |
201 | llvmPointerType, |
202 | {llvmIntPtrType, llvmPointerType, llvmInt32Type, |
203 | llvmPointerType /* void *stream */}}; |
204 | FunctionCallBuilder destroyDnVecCallBuilder = { |
205 | "mgpuDestroyDnVec", |
206 | llvmVoidType, |
207 | {llvmPointerType, llvmPointerType /* void *stream */}}; |
208 | FunctionCallBuilder createDnMatCallBuilder = { |
209 | "mgpuCreateDnMat", |
210 | llvmPointerType, |
211 | {llvmIntPtrType, llvmIntPtrType, llvmPointerType, llvmInt32Type, |
212 | llvmPointerType /* void *stream */}}; |
213 | FunctionCallBuilder destroyDnMatCallBuilder = { |
214 | "mgpuDestroyDnMat", |
215 | llvmVoidType, |
216 | {llvmPointerType, llvmPointerType /* void *stream */}}; |
217 | FunctionCallBuilder createCooCallBuilder = { |
218 | "mgpuCreateCoo", |
219 | llvmPointerType, |
220 | {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, |
221 | llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, |
222 | llvmPointerType /* void *stream */}}; |
223 | FunctionCallBuilder createCooAoSCallBuilder = { |
224 | "mgpuCreateCooAoS", // deprecated in cuSPARSE 11.2 |
225 | llvmPointerType, |
226 | {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, |
227 | llvmPointerType, llvmInt32Type, llvmInt32Type, |
228 | llvmPointerType /* void *stream */}}; |
229 | FunctionCallBuilder createCsrCallBuilder = { |
230 | "mgpuCreateCsr", |
231 | llvmPointerType, |
232 | {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, |
233 | llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, |
234 | llvmInt32Type, llvmPointerType /* void *stream */}}; |
235 | FunctionCallBuilder createCscCallBuilder = { |
236 | "mgpuCreateCsc", |
237 | llvmPointerType, |
238 | {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, |
239 | llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, |
240 | llvmInt32Type, llvmPointerType /* void *stream */}}; |
241 | FunctionCallBuilder createBsrCallBuilder = { |
242 | "mgpuCreateBsr", |
243 | llvmPointerType, |
244 | {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, |
245 | llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType, |
246 | llvmInt32Type, llvmInt32Type, llvmInt32Type, |
247 | llvmPointerType /* void *stream */}}; |
248 | FunctionCallBuilder destroySpMatCallBuilder = { |
249 | "mgpuDestroySpMat", |
250 | llvmVoidType, |
251 | {llvmPointerType, llvmPointerType /* void *stream */}}; |
252 | FunctionCallBuilder spMVBufferSizeCallBuilder = { |
253 | "mgpuSpMVBufferSize", |
254 | llvmIntPtrType, |
255 | {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, |
256 | llvmInt32Type, llvmPointerType /* void *stream */}}; |
257 | FunctionCallBuilder spMVCallBuilder = { |
258 | "mgpuSpMV", |
259 | llvmVoidType, |
260 | {llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType, |
261 | llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}}; |
262 | FunctionCallBuilder createSpMMBufferSizeCallBuilder = { |
263 | "mgpuSpMMBufferSize", |
264 | llvmIntPtrType, |
265 | {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, |
266 | llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}}; |
267 | FunctionCallBuilder createSpMMCallBuilder = { |
268 | "mgpuSpMM", |
269 | llvmVoidType, |
270 | {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, |
271 | llvmPointerType, llvmInt32Type, llvmPointerType, |
272 | llvmPointerType /* void *stream */}}; |
273 | FunctionCallBuilder createSDDMMBufferSizeCallBuilder = { |
274 | "mgpuSDDMMBufferSize", |
275 | llvmIntPtrType, |
276 | {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, |
277 | llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}}; |
278 | FunctionCallBuilder createSDDMMCallBuilder = { |
279 | "mgpuSDDMM", |
280 | llvmVoidType, |
281 | {llvmInt32Type, llvmInt32Type, llvmPointerType, llvmPointerType, |
282 | llvmPointerType, llvmInt32Type, llvmPointerType, |
283 | llvmPointerType /* void *stream */}}; |
284 | FunctionCallBuilder createLtDnMatCallBuilder = { |
285 | "mgpuCreateCuSparseLtDnMat", |
286 | llvmVoidType, |
287 | {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, |
288 | llvmInt32Type, llvmPointerType /* void *stream */}}; |
289 | FunctionCallBuilder destroyCuSparseLtSpMatBuilder = { |
290 | "mgpuDestroyCuSparseLtSpMat", |
291 | llvmVoidType, |
292 | {llvmPointerType, llvmPointerType /* void *stream */}}; |
293 | FunctionCallBuilder destroyCuSparseLtDnMatBuilder = { |
294 | "mgpuDestroyCuSparseLtDnMat", |
295 | llvmVoidType, |
296 | {llvmPointerType, llvmPointerType /* void *stream */}}; |
297 | FunctionCallBuilder create2To4SpMatCallBuilder = { |
298 | "mgpuCusparseLtCreate2To4SpMat", |
299 | llvmVoidType, |
300 | {llvmPointerType, llvmIntPtrType, llvmIntPtrType, llvmPointerType, |
301 | llvmInt32Type, llvmPointerType /* void *stream */}}; |
302 | FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = { |
303 | "mgpuCuSparseLtSpMMBufferSize", |
304 | llvmVoidType, |
305 | {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType, |
306 | llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type, |
307 | llvmPointerType /*void *stream*/}}; |
308 | FunctionCallBuilder createCuSparseLtSpMMBuilder = { |
309 | "mgpuCuSparseLtSpMM", |
310 | llvmVoidType, |
311 | {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType, |
312 | llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}}; |
313 | FunctionCallBuilder createSpGEMMCreateDescrBuilder = { |
314 | "mgpuSpGEMMCreateDescr", |
315 | llvmPointerType, |
316 | {llvmPointerType /*void *stream*/}}; |
317 | FunctionCallBuilder createSpGEMMDestroyDescrBuilder = { |
318 | "mgpuSpGEMMDestroyDescr", |
319 | llvmVoidType, |
320 | {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}}; |
321 | FunctionCallBuilder createSpGEMMWorkEstimationBuilder = { |
322 | "mgpuSpGEMMWorkEstimation", |
323 | llvmIntPtrType, |
324 | {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, |
325 | llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, |
326 | llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/, |
327 | llvmPointerType /*void *stream*/}}; |
328 | FunctionCallBuilder createSpGEMMComputeBuilder = { |
329 | "mgpuSpGEMMCompute", |
330 | llvmIntPtrType, |
331 | {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, |
332 | llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, |
333 | llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/, |
334 | llvmPointerType /*void *stream*/}}; |
335 | FunctionCallBuilder createSpGEMMCopyBuilder = { |
336 | "mgpuSpGEMMCopy", |
337 | llvmVoidType, |
338 | {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/, |
339 | llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/, |
340 | llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}}; |
341 | FunctionCallBuilder createSpMatGetSizeBuilder = { |
342 | "mgpuSpMatGetSize", |
343 | llvmVoidType, |
344 | {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/, |
345 | llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}}; |
346 | FunctionCallBuilder createSetCsrPointersBuilder = { |
347 | "mgpuSetCsrPointers", |
348 | llvmVoidType, |
349 | {llvmPointerType /*spmat*/, llvmPointerType /*pos*/, |
350 | llvmPointerType /*crd*/, llvmPointerType /*val*/, |
351 | llvmPointerType /*void *stream*/}}; |
352 | }; |
353 | |
354 | /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime |
355 | /// call. Currently it supports CUDA and ROCm (HIP). |
356 | class ConvertHostRegisterOpToGpuRuntimeCallPattern |
357 | : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> { |
358 | public: |
359 | ConvertHostRegisterOpToGpuRuntimeCallPattern( |
360 | const LLVMTypeConverter &typeConverter) |
361 | : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {} |
362 | |
363 | private: |
364 | LogicalResult |
365 | matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, |
366 | ConversionPatternRewriter &rewriter) const override; |
367 | }; |
368 | |
369 | class ConvertHostUnregisterOpToGpuRuntimeCallPattern |
370 | : public ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp> { |
371 | public: |
372 | ConvertHostUnregisterOpToGpuRuntimeCallPattern( |
373 | const LLVMTypeConverter &typeConverter) |
374 | : ConvertOpToGpuRuntimeCallPattern<gpu::HostUnregisterOp>(typeConverter) { |
375 | } |
376 | |
377 | private: |
378 | LogicalResult |
379 | matchAndRewrite(gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, |
380 | ConversionPatternRewriter &rewriter) const override; |
381 | }; |
382 | |
383 | /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime |
384 | /// call. Currently it supports CUDA and ROCm (HIP). |
385 | class ConvertAllocOpToGpuRuntimeCallPattern |
386 | : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> { |
387 | public: |
388 | ConvertAllocOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) |
389 | : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {} |
390 | |
391 | private: |
392 | LogicalResult |
393 | matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor, |
394 | ConversionPatternRewriter &rewriter) const override; |
395 | }; |
396 | |
397 | /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime |
398 | /// call. Currently it supports CUDA and ROCm (HIP). |
399 | class ConvertDeallocOpToGpuRuntimeCallPattern |
400 | : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> { |
401 | public: |
402 | ConvertDeallocOpToGpuRuntimeCallPattern( |
403 | const LLVMTypeConverter &typeConverter) |
404 | : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {} |
405 | |
406 | private: |
407 | LogicalResult |
408 | matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor, |
409 | ConversionPatternRewriter &rewriter) const override; |
410 | }; |
411 | |
412 | class ConvertAsyncYieldToGpuRuntimeCallPattern |
413 | : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> { |
414 | public: |
415 | ConvertAsyncYieldToGpuRuntimeCallPattern( |
416 | const LLVMTypeConverter &typeConverter) |
417 | : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {} |
418 | |
419 | private: |
420 | LogicalResult |
421 | matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor, |
422 | ConversionPatternRewriter &rewriter) const override; |
423 | }; |
424 | |
425 | /// A rewrite pattern to convert gpu.wait operations into a GPU runtime |
426 | /// call. Currently it supports CUDA and ROCm (HIP). |
427 | class ConvertWaitOpToGpuRuntimeCallPattern |
428 | : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> { |
429 | public: |
430 | ConvertWaitOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) |
431 | : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {} |
432 | |
433 | private: |
434 | LogicalResult |
435 | matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, |
436 | ConversionPatternRewriter &rewriter) const override; |
437 | }; |
438 | |
439 | /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime |
440 | /// call. Currently it supports CUDA and ROCm (HIP). |
441 | class ConvertWaitAsyncOpToGpuRuntimeCallPattern |
442 | : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> { |
443 | public: |
444 | ConvertWaitAsyncOpToGpuRuntimeCallPattern( |
445 | const LLVMTypeConverter &typeConverter) |
446 | : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {} |
447 | |
448 | private: |
449 | LogicalResult |
450 | matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, |
451 | ConversionPatternRewriter &rewriter) const override; |
452 | }; |
453 | |
454 | /// A rewrite pattern to convert gpu.launch_func operations into a sequence of |
455 | /// GPU runtime calls. Currently it supports CUDA and ROCm (HIP). |
456 | /// |
457 | /// In essence, a gpu.launch_func operation gets compiled into the following |
458 | /// sequence of runtime calls: |
459 | /// |
460 | /// * moduleLoad -- loads the module given the cubin / hsaco data |
461 | /// * moduleGetFunction -- gets a handle to the actual kernel function |
462 | /// * getStreamHelper -- initializes a new compute stream on GPU |
463 | /// * launchKernel -- launches the kernel on a stream |
464 | /// * streamSynchronize -- waits for operations on the stream to finish |
465 | /// |
466 | /// Intermediate data structures are allocated on the stack. |
467 | class ConvertLaunchFuncOpToGpuRuntimeCallPattern |
468 | : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> { |
469 | public: |
470 | ConvertLaunchFuncOpToGpuRuntimeCallPattern( |
471 | const LLVMTypeConverter &typeConverter, StringRef gpuBinaryAnnotation, |
472 | bool kernelBarePtrCallConv, SymbolTable *cachedModuleTable) |
473 | : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter), |
474 | gpuBinaryAnnotation(gpuBinaryAnnotation), |
475 | kernelBarePtrCallConv(kernelBarePtrCallConv), |
476 | cachedModuleTable(cachedModuleTable) {} |
477 | |
478 | private: |
479 | Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, |
480 | OpBuilder &builder) const; |
481 | Value generateKernelNameConstant(StringRef moduleName, StringRef name, |
482 | Location loc, OpBuilder &builder) const; |
483 | |
484 | LogicalResult |
485 | matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, |
486 | ConversionPatternRewriter &rewriter) const override; |
487 | |
488 | llvm::SmallString<32> gpuBinaryAnnotation; |
489 | bool kernelBarePtrCallConv; |
490 | SymbolTable *cachedModuleTable; |
491 | }; |
492 | |
493 | class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> { |
494 | using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern; |
495 | |
496 | LogicalResult matchAndRewrite(gpu::GPUModuleOp op, |
497 | PatternRewriter &rewriter) const override { |
498 | // GPU kernel modules are no longer necessary since we have a global |
499 | // constant with the CUBIN or HSACO data. |
500 | rewriter.eraseOp(op); |
501 | return success(); |
502 | } |
503 | }; |
504 | |
505 | /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime |
506 | /// call. Currently it supports CUDA and ROCm (HIP). |
507 | class ConvertMemcpyOpToGpuRuntimeCallPattern |
508 | : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> { |
509 | public: |
510 | ConvertMemcpyOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) |
511 | : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {} |
512 | |
513 | private: |
514 | LogicalResult |
515 | matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, |
516 | ConversionPatternRewriter &rewriter) const override; |
517 | }; |
518 | |
519 | /// A rewrite pattern to convert gpu.memset operations into a GPU runtime |
520 | /// call. Currently it supports CUDA and ROCm (HIP). |
521 | class ConvertMemsetOpToGpuRuntimeCallPattern |
522 | : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> { |
523 | public: |
524 | ConvertMemsetOpToGpuRuntimeCallPattern(const LLVMTypeConverter &typeConverter) |
525 | : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {} |
526 | |
527 | private: |
528 | LogicalResult |
529 | matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor, |
530 | ConversionPatternRewriter &rewriter) const override; |
531 | }; |
532 | |
533 | /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call. |
534 | /// Currently supports CUDA and ROCm (HIP). |
535 | class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern |
536 | : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> { |
537 | public: |
538 | ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern( |
539 | const LLVMTypeConverter &typeConverter) |
540 | : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>( |
541 | typeConverter) {} |
542 | |
543 | LogicalResult |
544 | matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor, |
545 | ConversionPatternRewriter &rewriter) const override; |
546 | }; |
547 | |
548 | /// Generic rewriting rule for operations on sparse matrices. |
549 | /// Currently supports CUDA (by means of cuSparse and cuSparseLt). |
550 | #define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name) \ |
551 | class Convert##op_name##ToGpuRuntimeCallPattern \ |
552 | : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> { \ |
553 | public: \ |
554 | Convert##op_name##ToGpuRuntimeCallPattern( \ |
555 | const LLVMTypeConverter &typeConverter) \ |
556 | : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {} \ |
557 | \ |
558 | private: \ |
559 | LogicalResult \ |
560 | matchAndRewrite(gpu::op_name op, OpAdaptor adaptor, \ |
561 | ConversionPatternRewriter &rewriter) const override; \ |
562 | }; |
563 | |
564 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateDnTensorOp) |
565 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp) |
566 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooOp) |
567 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCooAoSOp) |
568 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCsrOp) |
569 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateCscOp) |
570 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(CreateBsrOp) |
571 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(Create2To4SpMatOp) |
572 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroySpMatOp) |
573 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVBufferSizeOp) |
574 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMVOp) |
575 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMBufferSizeOp) |
576 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMBufferSizeOp) |
577 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMMOp) |
578 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp) |
579 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp) |
580 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp) |
581 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp) |
582 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp) |
583 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpMatGetSizeOp) |
584 | DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp) |
585 | |
586 | } // namespace |
587 | |
588 | void GpuToLLVMConversionPass::runOnOperation() { |
589 | MLIRContext *context = &getContext(); |
590 | SymbolTable symbolTable = SymbolTable(getOperation()); |
591 | LowerToLLVMOptions options(context); |
592 | options.useBarePtrCallConv = hostBarePtrCallConv; |
593 | RewritePatternSet patterns(context); |
594 | ConversionTarget target(*context); |
595 | target.addLegalDialect<LLVM::LLVMDialect>(); |
596 | LLVMTypeConverter converter(context, options); |
597 | |
598 | // Populate all patterns from all dialects that implement the |
599 | // `ConvertToLLVMPatternInterface` interface. |
600 | for (Dialect *dialect : context->getLoadedDialects()) { |
601 | auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); |
602 | if (!iface) |
603 | continue; |
604 | iface->populateConvertToLLVMConversionPatterns(target, converter, patterns); |
605 | } |
606 | |
607 | // Preserve GPU modules if they have target attributes. |
608 | target.addDynamicallyLegalOp<gpu::GPUModuleOp>( |
609 | [](gpu::GPUModuleOp module) -> bool { |
610 | return module.getTargetsAttr() != nullptr; |
611 | }); |
612 | // Accept LaunchFuncOps as legal if they refer to GPU modules with targets and |
613 | // their operands have been lowered. |
614 | target.addDynamicallyLegalOp<gpu::LaunchFuncOp>( |
615 | [&](gpu::LaunchFuncOp op) -> bool { |
616 | auto module = |
617 | symbolTable.lookup<gpu::GPUModuleOp>(op.getKernelModuleName()); |
618 | return converter.isLegal(op->getOperandTypes()) && |
619 | converter.isLegal(op->getResultTypes()) && |
620 | (module && module.getTargetsAttr() && |
621 | !module.getTargetsAttr().empty()); |
622 | }); |
623 | |
624 | // These aren't covered by the ConvertToLLVMPatternInterface right now. |
625 | populateVectorToLLVMConversionPatterns(converter, patterns); |
626 | populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); |
627 | populateAsyncStructuralTypeConversionsAndLegality(converter, patterns, |
628 | target); |
629 | populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation, |
630 | kernelBarePtrCallConv, &symbolTable); |
631 | |
632 | if (failed( |
633 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
634 | signalPassFailure(); |
635 | } |
636 | |
637 | LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, |
638 | ArrayRef<Value> arguments) const { |
639 | auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>(); |
640 | auto function = [&] { |
641 | if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) |
642 | return function; |
643 | return OpBuilder::atBlockEnd(module.getBody()) |
644 | .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); |
645 | }(); |
646 | return builder.create<LLVM::CallOp>(loc, function, arguments); |
647 | } |
648 | |
649 | // Corresponding to cusparseIndexType_t defined in cusparse.h. |
650 | static int32_t getCuSparseIndexTypeFrom(Type type) { |
651 | if (type.isInteger(16)) |
652 | return 1; // CUSPARSE_INDEX_16U |
653 | if (type.isInteger(32)) |
654 | return 2; // CUSPARSE_INDEX_32I |
655 | return 3; // CUSPARSE_INDEX_64I |
656 | } |
657 | |
658 | static int32_t getCuSparseLtDataTypeFrom(Type type) { |
659 | if (type.isF16()) |
660 | return 0; // CUSPARSE_COMPUTE_16F, |
661 | if (type.isInteger(32)) |
662 | return 1; // CUSPARSE_COMPUTE_32I |
663 | llvm_unreachable("unsupported type"); |
664 | // TODO: add support for TF32 |
665 | } |
666 | |
667 | // Corresponding to cudaDataType_t defined in CUDA library_types.h. |
668 | static int32_t getCuSparseDataTypeFrom(Type type) { |
669 | if (llvm::isa<ComplexType>(type)) { |
670 | // get the element type |
671 | auto elementType = cast<ComplexType>(type).getElementType(); |
672 | if (elementType.isBF16()) |
673 | return 15; // CUDA_C_16BF |
674 | if (elementType.isF16()) |
675 | return 6; // CUDA_C_16F |
676 | if (elementType.isF32()) |
677 | return 4; // CUDA_C_32F |
678 | if (elementType.isF64()) |
679 | return 5; // CUDA_C_64F |
680 | if (elementType.isInteger(8)) |
681 | return 7; // CUDA_C_8I |
682 | if (elementType.isInteger(16)) |
683 | return 21; // CUDA_C_16I |
684 | if (elementType.isInteger(32)) |
685 | return 11; // CUDA_C_32I |
686 | } |
687 | if (type.isBF16()) |
688 | return 14; // CUDA_R_16BF |
689 | if (type.isF16()) |
690 | return 2; // CUDA_R_16F |
691 | if (type.isF32()) |
692 | return 0; // CUDA_R_32F |
693 | if (type.isF64()) |
694 | return 1; // CUDA_R_64F |
695 | if (type.isInteger(8)) |
696 | return 3; // CUDA_R_8I |
697 | if (type.isInteger(16)) |
698 | return 20; // CUDA_R_16I |
699 | if (type.isInteger(32)) |
700 | return 10; // CUDA_R_32I |
701 | |
702 | llvm_unreachable("unsupported element type"); |
703 | } |
704 | |
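// Returns the pruning flag of the gpu::Create2To4SpMatOp that defines `spMat`.
// The caller must ensure `spMat` is indeed defined by such an op.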
705 | static gpu::Prune2To4SpMatFlag get2To4PruneFlag(Value spMat) { |
706 | return spMat.getDefiningOp<gpu::Create2To4SpMatOp>().getPruneFlag(); |
707 | } |
708 | |
709 | // TODO: We may want a compile-time (i.e., in the mlir compiler) |
710 | // disablement/warning: cusparseLt currently does not work for CUDA |
711 | // architectures < 8.0 and triggers an error at run time of the CUDA program. |
712 | // It would be good to at least emit a warning when the target architecture is |
713 | // < 8.0 and the user still wants to use cusparseLt, and to disable the |
714 | // cusparseLt calls when lowering the gpu sparse dialect to LLVM calls for such |
715 | // architectures. |
716 | static bool is2To4Sparsity(Value spMat) { |
717 | if (auto op = spMat.getDefiningOp<gpu::Create2To4SpMatOp>()) |
718 | return true; |
719 | if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>()) |
720 | return false; |
721 | if (auto op = spMat.getDefiningOp<gpu::CreateCooAoSOp>()) |
722 | return false; |
723 | if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>()) |
724 | return false; |
725 | if (auto op = spMat.getDefiningOp<gpu::CreateCscOp>()) |
726 | return false; |
727 | if (auto op = spMat.getDefiningOp<gpu::CreateBsrOp>()) |
728 | return false; |
729 | // Print the spMat defining op |
730 | spMat.getDefiningOp()->print(llvm::errs()); |
731 | llvm_unreachable("cannot find spmat def"); |
732 | } |
733 | |
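// Returns true if the sparse matrix handle `op` is consumed by a gpu::SpMMOp
// whose A operand has 2:4 sparsity, in which case the cusparseLt path should
// be used.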
734 | static bool isSpMMCusparseLtOp(Value op) { |
735 | for (Operation *user : op.getUsers()) { |
736 | auto spmmOp = dyn_cast<gpu::SpMMOp>(user); |
737 | // If the sparse A operand has 2:4 (50%) sparsity, we should use cusparseLt. |
738 | if (!spmmOp) |
739 | continue; |
740 | if (is2To4Sparsity(spmmOp.getSpmatA())) |
741 | return true; |
742 | } |
743 | return false; |
744 | } |
745 | |
746 | // Returns success if all operands are of LLVM type, and a match failure otherwise. |
747 | static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, |
748 | ConversionPatternRewriter &rewriter) { |
749 | if (!llvm::all_of(operands, [](Value value) { |
750 | return LLVM::isCompatibleType(value.getType()); |
751 | })) |
752 | return rewriter.notifyMatchFailure( |
753 | op, "Cannot convert if operands aren't of LLVM type."); |
754 | return success(); |
755 | } |
756 | |
757 | static LogicalResult |
758 | isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, |
759 | gpu::AsyncOpInterface op) { |
760 | if (op.getAsyncDependencies().size() != 1) |
761 | return rewriter.notifyMatchFailure( |
762 | op, "Can only convert with exactly one async dependency."); |
763 | |
764 | if (!op.getAsyncToken()) |
765 | return rewriter.notifyMatchFailure(op, "Can convert only async version."); |
766 | |
767 | return success(); |
768 | } |
769 | |
770 | LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( |
771 | gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, |
772 | ConversionPatternRewriter &rewriter) const { |
773 | auto *op = hostRegisterOp.getOperation(); |
774 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) |
775 | return failure(); |
776 | |
777 | Location loc = op->getLoc(); |
778 | |
779 | auto memRefType = hostRegisterOp.getValue().getType(); |
780 | auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType(); |
781 | auto elementSize = getSizeInBytes(loc, elementType, rewriter); |
782 | |
783 | auto arguments = getTypeConverter()->promoteOperands( |
784 | loc, op->getOperands(), adaptor.getOperands(), rewriter); |
785 | arguments.push_back(elementSize); |
786 | hostRegisterCallBuilder.create(loc, rewriter, arguments); |
787 | |
788 | rewriter.eraseOp(op); |
789 | return success(); |
790 | } |
791 | |
792 | LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite( |
793 | gpu::HostUnregisterOp hostUnregisterOp, OpAdaptor adaptor, |
794 | ConversionPatternRewriter &rewriter) const { |
795 | Operation *op = hostUnregisterOp.getOperation(); |
796 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) |
797 | return failure(); |
798 | |
799 | Location loc = op->getLoc(); |
800 | |
801 | auto memRefType = hostUnregisterOp.getValue().getType(); |
802 | auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType(); |
803 | auto elementSize = getSizeInBytes(loc, elementType, rewriter); |
804 | |
805 | auto arguments = getTypeConverter()->promoteOperands( |
806 | loc, op->getOperands(), adaptor.getOperands(), rewriter); |
807 | arguments.push_back(elementSize); |
808 | hostUnregisterCallBuilder.create(loc, rewriter, arguments); |
809 | |
810 | rewriter.eraseOp(op); |
811 | return success(); |
812 | } |
813 | |
814 | LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( |
815 | gpu::AllocOp allocOp, OpAdaptor adaptor, |
816 | ConversionPatternRewriter &rewriter) const { |
817 | |
818 | MemRefType memRefType = allocOp.getType(); |
819 | |
820 | if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) || |
821 | !isConvertibleAndHasIdentityMaps(memRefType)) |
822 | return failure(); |
823 | |
824 | auto loc = allocOp.getLoc(); |
825 | |
826 | bool isShared = allocOp.getHostShared(); |
827 | |
828 | if (isShared && allocOp.getAsyncToken()) |
829 | return rewriter.notifyMatchFailure( |
830 | allocOp, "Host Shared allocation cannot be done async"); |
831 | if (!isShared && failed(isAsyncWithOneDependency(rewriter, allocOp))) |
832 | return failure(); |
833 | |
834 | // Get shape of the memref as values: static sizes are constant |
835 | // values and dynamic sizes are passed to 'alloc' as operands. |
836 | SmallVector<Value, 4> shape; |
837 | SmallVector<Value, 4> strides; |
838 | Value sizeBytes; |
839 | getMemRefDescriptorSizes(loc, memRefType, adaptor.getDynamicSizes(), rewriter, |
840 | shape, strides, sizeBytes); |
841 | |
842 | // Allocate the underlying buffer and store a pointer to it in the MemRef |
843 | // descriptor. |
844 | auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType); |
845 | Value stream = adaptor.getAsyncDependencies().empty() |
846 | ? nullPtr |
847 | : adaptor.getAsyncDependencies().front(); |
848 | |
849 | auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>( |
850 | loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); |
851 | |
852 | Value allocatedPtr = |
853 | allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared}) |
854 | .getResult(); |
855 | |
856 | // No alignment. |
857 | Value alignedPtr = allocatedPtr; |
858 | |
859 | // Create the MemRef descriptor. |
860 | auto memRefDescriptor = this->createMemRefDescriptor( |
861 | loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter); |
862 | |
863 | if (allocOp.getAsyncToken()) { |
864 | // Async alloc: make dependent ops use the same stream. |
865 | rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); |
866 | } else { |
867 | rewriter.replaceOp(allocOp, {memRefDescriptor}); |
868 | } |
869 | |
870 | return success(); |
871 | } |
872 | |
873 | LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( |
874 | gpu::DeallocOp deallocOp, OpAdaptor adaptor, |
875 | ConversionPatternRewriter &rewriter) const { |
876 | if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) || |
877 | failed(isAsyncWithOneDependency(rewriter, deallocOp))) |
878 | return failure(); |
879 | |
880 | Location loc = deallocOp.getLoc(); |
881 | |
882 | Value pointer = |
883 | MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc); |
884 | Value stream = adaptor.getAsyncDependencies().front(); |
885 | deallocCallBuilder.create(loc, rewriter, {pointer, stream}); |
886 | |
887 | rewriter.replaceOp(deallocOp, {stream}); |
888 | return success(); |
889 | } |
890 | |
891 | static bool isGpuAsyncTokenType(Value value) { |
892 | return isa<gpu::AsyncTokenType>(value.getType()); |
893 | } |
894 | |
895 | // Converts !gpu.async.token operands of `async.yield` to runtime calls. The |
896 | // !gpu.async.token values are lowered to streams within the async.execute |
897 | // region, but are passed as events between regions. For each !gpu.async.token |
898 | // operand, we create an event and record it on the stream. |
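//
// For example (schematically; names are illustrative), a yield of a token that
// was lowered to a stream
//
//   async.yield %stream : !gpu.async.token
//
// becomes, roughly,
//
//   %event = llvm.call @mgpuEventCreate()
//   llvm.call @mgpuEventRecord(%event, %stream)
//   llvm.call @mgpuStreamDestroy(%stream)
//   async.yield %event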
899 | LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( |
900 | async::YieldOp yieldOp, OpAdaptor adaptor, |
901 | ConversionPatternRewriter &rewriter) const { |
902 | if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType)) |
903 | return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand"); |
904 | |
905 | Location loc = yieldOp.getLoc(); |
906 | SmallVector<Value, 4> newOperands(adaptor.getOperands()); |
907 | llvm::SmallDenseSet<Value> streams; |
908 | for (auto &operand : yieldOp->getOpOperands()) { |
909 | if (!isGpuAsyncTokenType(operand.get())) |
910 | continue; |
911 | auto idx = operand.getOperandNumber(); |
912 | auto stream = adaptor.getOperands()[idx]; |
913 | auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(); |
914 | eventRecordCallBuilder.create(loc, rewriter, {event, stream}); |
915 | newOperands[idx] = event; |
916 | streams.insert(stream); |
917 | } |
918 | for (auto stream : streams) |
919 | streamDestroyCallBuilder.create(loc, rewriter, {stream}); |
920 | |
921 | rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); }); |
922 | return success(); |
923 | } |
924 | |
925 | // Returns whether `value` is the result of an LLVM::CallOp to `functionName`. |
926 | static bool isDefinedByCallTo(Value value, StringRef functionName) { |
927 | assert(isa<LLVM::LLVMPointerType>(value.getType())); |
928 | if (auto defOp = value.getDefiningOp<LLVM::CallOp>()) |
929 | return defOp.getCallee()->equals(functionName); |
930 | return false; |
931 | } |
932 | |
933 | // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host |
934 | // with the stream/event operands, which are then destroyed. That is, it assumes |
935 | // that they are not used afterwards or elsewhere. Otherwise we will get a |
936 | // runtime error. Eventually, we should guarantee this property. |
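//
// For example (schematically), when the operand was lowered to a stream,
//
//   gpu.wait [%token]
//
// becomes, roughly,
//
//   llvm.call @mgpuStreamSynchronize(%stream)
//   llvm.call @mgpuStreamDestroy(%stream)
//
// Event operands are handled analogously with mgpuEventSynchronize and
// mgpuEventDestroy.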
937 | LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( |
938 | gpu::WaitOp waitOp, OpAdaptor adaptor, |
939 | ConversionPatternRewriter &rewriter) const { |
940 | if (waitOp.getAsyncToken()) |
941 | return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op."); |
942 | |
943 | Location loc = waitOp.getLoc(); |
944 | |
945 | for (auto operand : adaptor.getOperands()) { |
946 | if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { |
947 | // The converted operand's definition created a stream. |
948 | streamSynchronizeCallBuilder.create(loc, rewriter, {operand}); |
949 | streamDestroyCallBuilder.create(loc, rewriter, {operand}); |
950 | } else { |
951 | // Otherwise the converted operand is an event. This assumes that we use |
952 | // events in control flow code as well. |
953 | eventSynchronizeCallBuilder.create(loc, rewriter, {operand}); |
954 | eventDestroyCallBuilder.create(loc, rewriter, {operand}); |
955 | } |
956 | } |
957 | |
958 | rewriter.eraseOp(waitOp); |
959 | return success(); |
960 | } |
961 | |
962 | // Converts `gpu.wait async` to runtime calls. The converted op creates a new |
963 | // stream that is synchronized with the stream/event operands, which are then |
964 | // destroyed. That is, it assumes that they are not used afterwards or |
965 | // elsewhere. Otherwise we will get a runtime error. Eventually, we should |
966 | // guarantee this property. |
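//
// For example (schematically), when the operand was lowered to a stream,
//
//   %t = gpu.wait async [%token]
//
// becomes, roughly,
//
//   %event = llvm.call @mgpuEventCreate()
//   llvm.call @mgpuEventRecord(%event, %stream)  // after the last use of %token
//   %newStream = llvm.call @mgpuStreamCreate()
//   llvm.call @mgpuStreamWaitEvent(%newStream, %event)
//   llvm.call @mgpuEventDestroy(%event)
//
// and %t is replaced by %newStream.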
967 | LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( |
968 | gpu::WaitOp waitOp, OpAdaptor adaptor, |
969 | ConversionPatternRewriter &rewriter) const { |
970 | if (!waitOp.getAsyncToken()) |
971 | return rewriter.notifyMatchFailure(waitOp, "Can only convert async op."); |
972 | |
973 | Location loc = waitOp.getLoc(); |
974 | |
975 | auto insertionPoint = rewriter.saveInsertionPoint(); |
976 | SmallVector<Value, 1> events; |
977 | for (auto pair : |
978 | llvm::zip(waitOp.getAsyncDependencies(), adaptor.getOperands())) { |
979 | auto operand = std::get<1>(pair); |
980 | if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { |
981 | // The converted operand's definition created a stream. Insert an event |
982 | // into the stream just after the last use of the original token operand. |
983 | auto *defOp = std::get<0>(pair).getDefiningOp(); |
984 | rewriter.setInsertionPointAfter(defOp); |
985 | auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(); |
986 | eventRecordCallBuilder.create(loc, rewriter, {event, operand}); |
987 | events.push_back(event); |
988 | } else { |
989 | // Otherwise the converted operand is an event. This assumes that we use |
990 | // events in control flow code as well. |
991 | events.push_back(operand); |
992 | } |
993 | } |
994 | rewriter.restoreInsertionPoint(insertionPoint); |
995 | auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); |
996 | for (auto event : events) |
997 | streamWaitEventCallBuilder.create(loc, rewriter, {stream, event}); |
998 | for (auto event : events) |
999 | eventDestroyCallBuilder.create(loc, rewriter, {event}); |
1000 | rewriter.replaceOp(waitOp, {stream}); |
1001 | |
1002 | return success(); |
1003 | } |
1004 | |
1005 | // Creates a struct containing all kernel parameters on the stack and returns |
1006 | // an array of type-erased pointers to the fields of the struct. The array can |
1007 | // then be passed to the CUDA / ROCm (HIP) kernel launch calls. |
1008 | // The generated code is essentially as follows: |
1009 | // |
1010 | // %struct = alloca(sizeof(struct { Parameters... })) |
1011 | // %array = alloca(NumParameters * sizeof(void *)) |
1012 | // for (i : [0, NumParameters)) |
1013 | // %fieldPtr = llvm.getelementptr %struct[0, i] |
1014 | // llvm.store parameters[i], %fieldPtr |
1015 | // %elementPtr = llvm.getelementptr %array[i] |
1016 | // llvm.store %fieldPtr, %elementPtr |
1017 | // return %array |
1018 | Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( |
1019 | gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const { |
1020 | auto loc = launchOp.getLoc(); |
1021 | auto numKernelOperands = launchOp.getNumKernelOperands(); |
1022 | // Note: If `useBarePtrCallConv` is set in the type converter's options, |
1023 | // the value of `kernelBarePtrCallConv` will be ignored. |
1024 | SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands( |
1025 | loc, launchOp.getOperands().take_back(numKernelOperands), |
1026 | adaptor.getOperands().take_back(numKernelOperands), builder, |
1027 | /*useBarePtrCallConv=*/kernelBarePtrCallConv); |
1028 | auto numArguments = arguments.size(); |
1029 | SmallVector<Type, 4> argumentTypes; |
1030 | argumentTypes.reserve(numArguments); |
1031 | for (auto argument : arguments) |
1032 | argumentTypes.push_back(argument.getType()); |
1033 | auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(), |
1034 | argumentTypes); |
1035 | auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 1); |
1036 | auto structPtr = |
1037 | builder.create<LLVM::AllocaOp>(loc, llvmPointerType, structType, one, |
1038 | /*alignment=*/0); |
1039 | auto arraySize = |
1040 | builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments); |
1041 | auto arrayPtr = builder.create<LLVM::AllocaOp>( |
1042 | loc, llvmPointerType, llvmPointerType, arraySize, /*alignment=*/0); |
1043 | for (const auto &en : llvm::enumerate(arguments)) { |
1044 | const auto index = static_cast<int32_t>(en.index()); |
1045 | Value fieldPtr = |
1046 | builder.create<LLVM::GEPOp>(loc, llvmPointerType, structType, structPtr, |
1047 | ArrayRef<LLVM::GEPArg>{0, index}); |
1048 | builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr); |
1049 | auto elementPtr = |
1050 | builder.create<LLVM::GEPOp>(loc, llvmPointerType, llvmPointerType, |
1051 | arrayPtr, ArrayRef<LLVM::GEPArg>{index}); |
1052 | builder.create<LLVM::StoreOp>(loc, fieldPtr, elementPtr); |
1053 | } |
1054 | return arrayPtr; |
1055 | } |
1056 | |
1057 | // Generates an LLVM IR dialect global that contains the name of the given |
1058 | // kernel function as a C string, and returns a pointer to its beginning. |
1059 | // The code is essentially: |
1060 | // |
1061 | // llvm.global constant @kernel_name("function_name\00") |
1062 | // func(...) { |
1063 | // %0 = llvm.addressof @kernel_name |
1064 | // %1 = llvm.constant (0 : index) |
1065 | // %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*"> |
1066 | // } |
1067 | Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant( |
1068 | StringRef moduleName, StringRef name, Location loc, |
1069 | OpBuilder &builder) const { |
1070 | // Make sure the trailing zero is included in the constant. |
1071 | std::vector<char> kernelName(name.begin(), name.end()); |
1072 | kernelName.push_back('\0'); |
1073 | |
1074 | std::string globalName = |
1075 | std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name)); |
1076 | return LLVM::createGlobalString( |
1077 | loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()), |
1078 | LLVM::Linkage::Internal); |
1079 | } |
1080 | |
1081 | // Emits LLVM IR to launch a kernel function. Expects the module that contains |
1082 | // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a |
1083 | // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR. |
1084 | // |
1085 | // %0 = call %binarygetter |
1086 | // %1 = call %moduleLoad(%0) |
1087 | // %2 = <see generateKernelNameConstant> |
1088 | // %3 = call %moduleGetFunction(%1, %2) |
1089 | // %4 = call %streamCreate() |
1090 | // %5 = <see generateParamsArray> |
1091 | // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr) |
1092 | // call %streamSynchronize(%4) |
1093 | // call %streamDestroy(%4) |
1094 | // call %moduleUnload(%1) |
1095 | // |
1096 | // If the op is async, the stream corresponds to the (single) async dependency |
1097 | // as well as the async token the op produces. |
1098 | LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( |
1099 | gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, |
1100 | ConversionPatternRewriter &rewriter) const { |
1101 | if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter))) |
1102 | return failure(); |
1103 | |
1104 | if (launchOp.getAsyncDependencies().size() > 1) |
1105 | return rewriter.notifyMatchFailure( |
1106 | launchOp, "Cannot convert with more than one async dependency."); |
1107 | |
1108 | // Fail when the synchronous version of the op has async dependencies. The |
1109 | // lowering destroys the stream, and we do not want to check that there is no |
1110 | // use of the stream after this op. |
1111 | if (!launchOp.getAsyncToken() && !launchOp.getAsyncDependencies().empty()) |
1112 | return rewriter.notifyMatchFailure( |
1113 | launchOp, "Cannot convert non-async op with async dependencies."); |
1114 | |
1115 | Location loc = launchOp.getLoc(); |
1116 | |
1117 | // Create an LLVM global with CUBIN extracted from the kernel annotation and |
1118 | // obtain a pointer to the first byte in it. |
1119 | gpu::GPUModuleOp kernelModule; |
1120 | if (cachedModuleTable) |
1121 | kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>( |
1122 | launchOp.getKernelModuleName()); |
1123 | else |
1124 | kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>( |
1125 | launchOp, launchOp.getKernelModuleName()); |
1126 | assert(kernelModule && "expected a kernel module" ); |
1127 | |
1128 | // If the module has Targets then just update the op operands. |
1129 | if (ArrayAttr targets = kernelModule.getTargetsAttr()) { |
1130 | Value stream = Value(); |
1131 | if (!adaptor.getAsyncDependencies().empty()) |
1132 | stream = adaptor.getAsyncDependencies().front(); |
1133 | // If the async keyword is present and there are no dependencies, then a |
1134 | // stream must be created to pass to subsequent operations. |
1135 | else if (launchOp.getAsyncToken()) |
1136 | stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(); |
1137 | |
1138 | // Lower the kernel operands to match kernel parameters. |
1139 | // Note: If `useBarePtrCallConv` is set in the type converter's options, |
1140 | // the value of `kernelBarePtrCallConv` will be ignored. |
1141 | SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands( |
1142 | loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), |
1143 | rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv); |
1144 | |
1145 | std::optional<gpu::KernelDim3> clusterSize = std::nullopt; |
1146 | if (launchOp.hasClusterSize()) { |
1147 | clusterSize = |
1148 | gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), |
1149 | adaptor.getClusterSizeZ()}; |
1150 | } |
1151 | rewriter.create<gpu::LaunchFuncOp>( |
1152 | launchOp.getLoc(), launchOp.getKernelAttr(), |
1153 | gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), |
1154 | adaptor.getGridSizeZ()}, |
1155 | gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), |
1156 | adaptor.getBlockSizeZ()}, |
1157 | adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize); |
1158 | if (launchOp.getAsyncToken()) |
1159 | rewriter.replaceOp(launchOp, {stream}); |
1160 | else |
1161 | rewriter.eraseOp(launchOp); |
1162 | return success(); |
1163 | } |
1164 | |
1165 | auto binaryAttr = |
1166 | kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation); |
1167 | if (!binaryAttr) { |
1168 | kernelModule.emitOpError() |
1169 | << "missing " << gpuBinaryAnnotation << " attribute"; |
1170 | return failure(); |
1171 | } |
1172 | |
1173 | SmallString<128> nameBuffer(kernelModule.getName()); |
1174 | nameBuffer.append(kGpuBinaryStorageSuffix); |
1175 | Value data = |
1176 | LLVM::createGlobalString(loc, rewriter, nameBuffer.str(), |
1177 | binaryAttr.getValue(), LLVM::Linkage::Internal); |
1178 | |
1179 | // Pass the binary size; the SPIR-V path requires it. |
1180 | auto gpuBlob = binaryAttr.getValue(); |
1181 | auto gpuBlobSize = rewriter.create<mlir::LLVM::ConstantOp>( |
1182 | loc, llvmInt64Type, |
1183 | mlir::IntegerAttr::get(llvmInt64Type, |
1184 | static_cast<int64_t>(gpuBlob.size()))); |
1185 | |
1186 | auto module = |
1187 | moduleLoadCallBuilder.create(loc, rewriter, {data, gpuBlobSize}); |
1188 | |
1189 | // Pass the number of kernel parameters to the runtime wrappers. |
1190 | auto paramsCount = rewriter.create<mlir::LLVM::ConstantOp>( |
1191 | loc, llvmInt64Type, |
1192 | mlir::IntegerAttr::get( |
1193 | llvmInt64Type, |
1194 | static_cast<int64_t>(launchOp.getNumKernelOperands()))); |
1195 | |
1196 | // Get the function from the module. The name corresponds to the name of |
1197 | // the kernel function. |
1198 | auto kernelName = generateKernelNameConstant( |
1199 | launchOp.getKernelModuleName().getValue(), |
1200 | launchOp.getKernelName().getValue(), loc, rewriter); |
1201 | auto function = moduleGetFunctionCallBuilder.create( |
1202 | loc, rewriter, {module.getResult(), kernelName}); |
1203 | Value zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0); |
1204 | Value stream = |
1205 | adaptor.getAsyncDependencies().empty() |
1206 | ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult() |
1207 | : adaptor.getAsyncDependencies().front(); |
1208 | // Create array of pointers to kernel arguments. |
1209 | auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter); |
1210 | auto nullpointer = rewriter.create<LLVM::ZeroOp>(loc, llvmPointerType); |
1211 | Value dynamicSharedMemorySize = launchOp.getDynamicSharedMemorySize() |
1212 | ? launchOp.getDynamicSharedMemorySize() |
1213 | : zero; |
1214 | launchKernelCallBuilder.create( |
1215 | loc, rewriter, |
1216 | {function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(), |
1217 | adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), |
1218 | adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams, |
1219 | /*extra=*/nullpointer, paramsCount}); |
1220 | |
1221 | if (launchOp.getAsyncToken()) { |
1222 | // Async launch: make dependent ops use the same stream. |
1223 | rewriter.replaceOp(launchOp, {stream}); |
1224 | } else { |
1225 | // Synchronize with host and destroy stream. This must be the stream created |
1226 | // above (with no other uses) because we check that the synchronous version |
1227 | // does not have any async dependencies. |
1228 | streamSynchronizeCallBuilder.create(loc, rewriter, stream); |
1229 | streamDestroyCallBuilder.create(loc, rewriter, stream); |
    rewriter.eraseOp(launchOp);
1231 | } |
1232 | moduleUnloadCallBuilder.create(loc, rewriter, module.getResult()); |
1233 | |
1234 | return success(); |
1235 | } |
1236 | |
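// Returns the given pointer value, inserting an addrspacecast if its address
// space differs from that of the destination pointer type.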
1237 | static Value bitAndAddrspaceCast(Location loc, |
1238 | ConversionPatternRewriter &rewriter, |
1239 | LLVM::LLVMPointerType destinationType, |
1240 | Value sourcePtr, |
1241 | const LLVMTypeConverter &typeConverter) { |
1242 | auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType()); |
1243 | if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) |
1244 | sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>( |
1245 | loc, |
1246 | LLVM::LLVMPointerType::get(rewriter.getContext(), |
1247 | destinationType.getAddressSpace()), |
1248 | sourcePtr); |
1249 | return sourcePtr; |
1250 | } |
1251 | |
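// Lowers gpu.memcpy to a single call to the runtime memcpy wrapper declared by
// memcpyCallBuilder. The byte size of the copy is computed with the usual
// "GEP from a null pointer" idiom: offset a null pointer of the element type
// by the number of elements and convert the result to an integer. Roughly,
//   %t = gpu.memcpy async [%dep] %dst, %src : memref<?xf32>, memref<?xf32>
// becomes a size computation followed by one runtime call on %dep's stream.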
1252 | LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( |
1253 | gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, |
1254 | ConversionPatternRewriter &rewriter) const { |
1255 | auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType()); |
1256 | |
1257 | if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || |
1258 | !isConvertibleAndHasIdentityMaps(memRefType) || |
1259 | failed(isAsyncWithOneDependency(rewriter, memcpyOp))) |
1260 | return failure(); |
1261 | |
1262 | auto loc = memcpyOp.getLoc(); |
1263 | |
1264 | MemRefDescriptor srcDesc(adaptor.getSrc()); |
1265 | Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); |
1266 | |
1267 | Type elementPtrType = getElementPtrType(memRefType); |
1268 | Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); |
1269 | Value gepPtr = rewriter.create<LLVM::GEPOp>( |
1270 | loc, elementPtrType, |
1271 | typeConverter->convertType(memRefType.getElementType()), nullPtr, |
1272 | numElements); |
1273 | auto sizeBytes = |
1274 | rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); |
1275 | |
1276 | auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, |
1277 | srcDesc.alignedPtr(rewriter, loc), |
1278 | *getTypeConverter()); |
1279 | auto dst = bitAndAddrspaceCast( |
1280 | loc, rewriter, llvmPointerType, |
1281 | MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc), |
1282 | *getTypeConverter()); |
1283 | |
1284 | auto stream = adaptor.getAsyncDependencies().front(); |
1285 | memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream}); |
1286 | |
1287 | rewriter.replaceOp(memcpyOp, {stream}); |
1288 | |
1289 | return success(); |
1290 | } |
1291 | |
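// Lowers gpu.memset to the 16-bit or 32-bit runtime memset wrapper, bitcasting
// the fill value to an integer of the matching width first.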
1292 | LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( |
1293 | gpu::MemsetOp memsetOp, OpAdaptor adaptor, |
1294 | ConversionPatternRewriter &rewriter) const { |
1295 | auto memRefType = cast<MemRefType>(memsetOp.getDst().getType()); |
1296 | |
1297 | if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || |
1298 | !isConvertibleAndHasIdentityMaps(memRefType) || |
1299 | failed(isAsyncWithOneDependency(rewriter, memsetOp))) |
1300 | return failure(); |
1301 | |
1302 | auto loc = memsetOp.getLoc(); |
1303 | |
1304 | Type valueType = adaptor.getValue().getType(); |
1305 | unsigned bitWidth = valueType.getIntOrFloatBitWidth(); |
1306 | // Ints and floats of 16 or 32 bit width are allowed. |
1307 | if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) { |
1308 | return rewriter.notifyMatchFailure( |
1309 | memsetOp, "value must be a 16 or 32 bit int or float" ); |
1310 | } |
1311 | |
  Type bitCastType = bitWidth == 32 ? llvmInt32Type : llvmInt16Type;
1314 | |
1315 | MemRefDescriptor dstDesc(adaptor.getDst()); |
1316 | Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); |
1317 | |
1318 | auto value = |
1319 | rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue()); |
1320 | auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, |
1321 | dstDesc.alignedPtr(rewriter, loc), |
1322 | *getTypeConverter()); |
1323 | |
1324 | auto stream = adaptor.getAsyncDependencies().front(); |
  FunctionCallBuilder builder =
      bitWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
1327 | builder.create(loc, rewriter, {dst, value, numElements, stream}); |
1328 | |
1329 | rewriter.replaceOp(memsetOp, {stream}); |
1330 | return success(); |
1331 | } |
1332 | |
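// Lowers gpu.set_default_device to the corresponding runtime call, forwarding
// the device index unchanged.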
1333 | LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( |
1334 | gpu::SetDefaultDeviceOp op, OpAdaptor adaptor, |
1335 | ConversionPatternRewriter &rewriter) const { |
1336 | Location loc = op.getLoc(); |
1337 | auto call = setDefaultDeviceCallBuilder.create(loc, rewriter, |
1338 | {adaptor.getDevIndex()}); |
1339 | rewriter.replaceOp(op, call); |
1340 | return success(); |
1341 | } |
1342 | |
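// Materializes an LLVM i32 constant from any value convertible to int32_t.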
1343 | template <typename T> |
1344 | static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) { |
1345 | Type llvmInt32Type = builder.getIntegerType(32); |
1346 | return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, |
1347 | static_cast<int32_t>(tValue)); |
1348 | } |
1349 | |
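// Materializes an LLVM f32 constant from any value convertible to float.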
1350 | template <typename T> |
1351 | static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) { |
1352 | Type llvmFloat32Type = builder.getF32Type(); |
1353 | return builder.create<LLVM::ConstantOp>( |
1354 | loc, llvmFloat32Type, |
1355 | builder.getF32FloatAttr(static_cast<float>(tValue))); |
1356 | } |
1357 | |
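// Lowers gpu.create_dn_tensor to a dense-matrix or dense-vector creation call,
// dispatching on the rank of the underlying memref (2-D vs. 1-D).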
1358 | LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( |
1359 | gpu::CreateDnTensorOp op, OpAdaptor adaptor, |
1360 | ConversionPatternRewriter &rewriter) const { |
1361 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1362 | failed(isAsyncWithOneDependency(rewriter, op))) |
1363 | return failure(); |
1364 | Location loc = op.getLoc(); |
1365 | auto stream = adaptor.getAsyncDependencies().front(); |
1366 | Value pTensor = |
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1368 | Type dType = op.getMemref().getType().getElementType(); |
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1370 | |
1371 | SmallVector<Value, 4> dims; |
1372 | for (Value dim : adaptor.getDims()) { |
1373 | dims.push_back(dim); |
1374 | } |
1375 | |
1376 | Value handle; |
  // TODO: For now, we track the use of the handle and lower it to cusparse or
  // cusparseLt accordingly. If both cusparse and cusparseLt are used within
  // the same block, two separate creation ops are required for the lowering
  // to be correct. In the future, we may support using one handle with both
  // cusparse and cusparseLt in the sparse tensor / GPU dialects. Use the
  // cusparseLt create call if the dnmat is used with a spmat with 2:4
  // sparsity.
1383 | if (dims.size() == 2) { |
1384 | if (isSpMMCusparseLtOp(op.getDnTensor())) { |
1385 | auto handleSz = rewriter.create<LLVM::ConstantOp>( |
1386 | loc, getIndexType(), rewriter.getIndexAttr(11032)); |
1387 | handle = rewriter.create<LLVM::AllocaOp>( |
1388 | loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); |
1389 | handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); |
1390 | |
1391 | createLtDnMatCallBuilder |
1392 | .create(loc, rewriter, |
1393 | {handle, dims[0], dims[1], pTensor, dtp, stream}) |
1394 | .getResult(); |
1395 | } else { |
1396 | handle = |
1397 | createDnMatCallBuilder |
1398 | .create(loc, rewriter, {dims[0], dims[1], pTensor, dtp, stream}) |
1399 | .getResult(); |
1400 | } |
1401 | } else { |
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1403 | handle = createDnVecCallBuilder |
1404 | .create(loc, rewriter, {dims[0], pTensor, dtp, stream}) |
1405 | .getResult(); |
1406 | } |
1407 | rewriter.replaceOp(op, {handle, stream}); |
1408 | return success(); |
1409 | } |
1410 | |
1411 | LogicalResult ConvertDestroyDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( |
1412 | gpu::DestroyDnTensorOp op, OpAdaptor adaptor, |
1413 | ConversionPatternRewriter &rewriter) const { |
1414 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1415 | failed(isAsyncWithOneDependency(rewriter, op))) |
1416 | return failure(); |
1417 | Location loc = op.getLoc(); |
1418 | auto stream = adaptor.getAsyncDependencies().front(); |
1419 | auto definingOp = op.getDnTensor().getDefiningOp<gpu::CreateDnTensorOp>(); |
1420 | SmallVector<Value, 4> dims; |
1421 | for (Value dim : definingOp.getDims()) { |
1422 | dims.push_back(dim); |
1423 | } |
1424 | if (dims.size() == 2) { |
    // Use the cusparseLt destroy call if the dnmat is used with a spmat with
    // 2:4 sparsity.
1427 | if (isSpMMCusparseLtOp(op.getDnTensor())) { |
1428 | destroyCuSparseLtDnMatBuilder.create(loc, rewriter, |
1429 | {adaptor.getDnTensor(), stream}); |
1430 | } else { |
1431 | destroyDnMatCallBuilder.create(loc, rewriter, |
1432 | {adaptor.getDnTensor(), stream}); |
1433 | } |
1434 | } else { |
    assert(dims.size() == 1 && "Only 1D and 2D tensors are supported");
1436 | destroyDnVecCallBuilder.create(loc, rewriter, |
1437 | {adaptor.getDnTensor(), stream}); |
1438 | } |
1439 | rewriter.replaceOp(op, {stream}); |
1440 | return success(); |
1441 | } |
1442 | |
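// The sparse-matrix creation patterns below all follow the same scheme:
// extract the allocated pointers from the memref descriptors, encode the index
// and data element types as i32 constants, and forward everything to the
// matching runtime creation call on the stream of the single async dependency.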
1443 | LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite( |
1444 | gpu::CreateCooOp op, OpAdaptor adaptor, |
1445 | ConversionPatternRewriter &rewriter) const { |
1446 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1447 | failed(isAsyncWithOneDependency(rewriter, op))) |
1448 | return failure(); |
1449 | Location loc = op.getLoc(); |
1450 | auto stream = adaptor.getAsyncDependencies().front(); |
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1457 | Type iType = |
1458 | llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType(); |
1459 | Type dType = |
1460 | llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); |
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1463 | auto handle = |
1464 | createCooCallBuilder |
1465 | .create(loc, rewriter, |
1466 | {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), |
1467 | pRowIdxs, pColIdxs, pValues, itp, dtp, stream}) |
1468 | .getResult(); |
1469 | rewriter.replaceOp(op, {handle, stream}); |
1470 | return success(); |
1471 | } |
1472 | |
1473 | LogicalResult ConvertCreateCooAoSOpToGpuRuntimeCallPattern::matchAndRewrite( |
1474 | gpu::CreateCooAoSOp op, OpAdaptor adaptor, |
1475 | ConversionPatternRewriter &rewriter) const { |
1476 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1477 | failed(isAsyncWithOneDependency(rewriter, op))) |
1478 | return failure(); |
1479 | Location loc = op.getLoc(); |
1480 | auto stream = adaptor.getAsyncDependencies().front(); |
  Value pIdxs = MemRefDescriptor(adaptor.getIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1484 | Type iType = llvm::cast<MemRefType>(op.getIdxs().getType()).getElementType(); |
1485 | Type dType = |
1486 | llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); |
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1489 | auto handle = |
1490 | createCooAoSCallBuilder |
1491 | .create(loc, rewriter, |
1492 | {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), |
1493 | pIdxs, pValues, itp, dtp, stream}) |
1494 | .getResult(); |
1495 | rewriter.replaceOp(op, {handle, stream}); |
1496 | return success(); |
1497 | } |
1498 | |
1499 | LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite( |
1500 | gpu::CreateCsrOp op, OpAdaptor adaptor, |
1501 | ConversionPatternRewriter &rewriter) const { |
1502 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1503 | failed(isAsyncWithOneDependency(rewriter, op))) |
1504 | return failure(); |
1505 | Location loc = op.getLoc(); |
1506 | auto stream = adaptor.getAsyncDependencies().front(); |
  Value pRowPos =
      MemRefDescriptor(adaptor.getRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1513 | Type pType = |
1514 | llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType(); |
1515 | Type iType = |
1516 | llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType(); |
1517 | Type dType = |
1518 | llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); |
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1522 | auto handle = |
1523 | createCsrCallBuilder |
1524 | .create(loc, rewriter, |
1525 | {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), |
1526 | pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream}) |
1527 | .getResult(); |
1528 | rewriter.replaceOp(op, {handle, stream}); |
1529 | return success(); |
1530 | } |
1531 | |
1532 | LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( |
1533 | gpu::Create2To4SpMatOp op, OpAdaptor adaptor, |
1534 | ConversionPatternRewriter &rewriter) const { |
1535 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1536 | failed(isAsyncWithOneDependency(rewriter, op))) |
1537 | return failure(); |
1538 | Location loc = op.getLoc(); |
1539 | auto stream = adaptor.getAsyncDependencies().front(); |
1540 | Value pMat = |
      MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
1542 | Type dType = |
1543 | llvm::cast<MemRefType>(op.getMemref().getType()).getElementType(); |
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1545 | |
1546 | // CUDA runner asserts the size is 44104 bytes. |
1547 | auto handleSz = rewriter.create<LLVM::ConstantOp>( |
1548 | loc, getIndexType(), rewriter.getIndexAttr(44104)); |
1549 | Value handle = rewriter.create<LLVM::AllocaOp>( |
1550 | loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); |
1551 | handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); |
1552 | |
1553 | create2To4SpMatCallBuilder |
1554 | .create(loc, rewriter, |
1555 | {handle, adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream}) |
1556 | .getResult(); |
1557 | rewriter.replaceOp(op, {handle, stream}); |
1558 | return success(); |
1559 | } |
1560 | |
1561 | LogicalResult ConvertDestroySpMatOpToGpuRuntimeCallPattern::matchAndRewrite( |
1562 | gpu::DestroySpMatOp op, OpAdaptor adaptor, |
1563 | ConversionPatternRewriter &rewriter) const { |
1564 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1565 | failed(isAsyncWithOneDependency(rewriter, op))) |
1566 | return failure(); |
1567 | Location loc = op.getLoc(); |
1568 | auto stream = adaptor.getAsyncDependencies().front(); |
  // Use the cusparseLt destroy call if the spmat uses 2:4 sparsity.
1570 | if (is2To4Sparsity(op.getSpmat())) { |
1571 | destroyCuSparseLtSpMatBuilder.create(loc, rewriter, |
1572 | {adaptor.getSpmat(), stream}); |
1573 | |
1574 | } else { |
1575 | destroySpMatCallBuilder.create(loc, rewriter, {adaptor.getSpmat(), stream}); |
1576 | } |
1577 | rewriter.replaceOp(op, {stream}); |
1578 | return success(); |
1579 | } |
1580 | |
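// The SpMV/SpMM/SDDMM/SpGEMM patterns below lower each sparse buffer-size
// query or compute op to its runtime counterpart, passing transpose modes and
// the compute type as i32 constants and workspace buffers as raw pointers.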
1581 | LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( |
1582 | gpu::SpMVBufferSizeOp op, OpAdaptor adaptor, |
1583 | ConversionPatternRewriter &rewriter) const { |
1584 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1585 | failed(isAsyncWithOneDependency(rewriter, op))) |
1586 | return failure(); |
1587 | Location loc = op.getLoc(); |
1588 | auto modeA = genConstInt32From(rewriter, loc, op.getModeA()); |
1589 | auto computeType = genConstInt32From( |
1590 | rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); |
1591 | auto stream = adaptor.getAsyncDependencies().front(); |
1592 | auto bufferSize = spMVBufferSizeCallBuilder |
1593 | .create(loc, rewriter, |
1594 | {modeA, adaptor.getSpmatA(), adaptor.getDnX(), |
1595 | adaptor.getDnY(), computeType, stream}) |
1596 | .getResult(); |
1597 | rewriter.replaceOp(op, {bufferSize, stream}); |
1598 | return success(); |
1599 | } |
1600 | |
1601 | LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite( |
1602 | gpu::SpMVOp op, OpAdaptor adaptor, |
1603 | ConversionPatternRewriter &rewriter) const { |
1604 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1605 | failed(isAsyncWithOneDependency(rewriter, op))) |
1606 | return failure(); |
1607 | Location loc = op.getLoc(); |
1608 | auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); |
1609 | auto computeType = genConstInt32From( |
1610 | rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); |
1611 | auto stream = adaptor.getAsyncDependencies().front(); |
1612 | Value pBuf = |
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1614 | spMVCallBuilder.create(loc, rewriter, |
1615 | {modeA, adaptor.getSpmatA(), adaptor.getDnX(), |
1616 | adaptor.getDnY(), computeType, pBuf, stream}); |
1617 | rewriter.replaceOp(op, {stream}); |
1618 | return success(); |
1619 | } |
1620 | |
1621 | LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( |
1622 | gpu::SpMMBufferSizeOp op, OpAdaptor adaptor, |
1623 | ConversionPatternRewriter &rewriter) const { |
1624 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1625 | failed(isAsyncWithOneDependency(rewriter, op))) |
1626 | return failure(); |
1627 | Location loc = op.getLoc(); |
1628 | auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); |
1629 | auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); |
1630 | auto stream = adaptor.getAsyncDependencies().front(); |
1631 | Value bufferSize; |
1632 | if (is2To4Sparsity(op.getSpmatA())) { |
1633 | auto pruneFlag = |
1634 | genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA())); |
1635 | auto computeType = genConstInt32From( |
1636 | rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType())); |
1637 | auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), |
1638 | rewriter.getIndexAttr(3)); |
1639 | auto bufferSize = rewriter.create<LLVM::AllocaOp>( |
1640 | loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16); |
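    // cusparseLt reports three separate workspace sizes; the runtime call
    // writes them into this stack buffer, and they are loaded back as i64
    // values below.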
1641 | createCuSparseLtSpMMBufferSizeBuilder |
1642 | .create(loc, rewriter, |
1643 | {bufferSize, modeA, modeB, adaptor.getSpmatA(), |
1644 | adaptor.getDnmatB(), adaptor.getDnmatC(), computeType, |
1645 | pruneFlag, stream}) |
1646 | .getResult(); |
1647 | |
1648 | auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>( |
1649 | loc, llvmPointerType, llvmPointerType, bufferSize, |
1650 | ValueRange{rewriter.create<LLVM::ConstantOp>( |
1651 | loc, getIndexType(), rewriter.getIndexAttr(1))}); |
1652 | auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>( |
1653 | loc, llvmPointerType, llvmPointerType, bufferSize, |
1654 | ValueRange{rewriter.create<LLVM::ConstantOp>( |
1655 | loc, getIndexType(), rewriter.getIndexAttr(2))}); |
1656 | auto bufferSize0 = |
1657 | rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize); |
1658 | auto bufferSize1 = |
1659 | rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1); |
1660 | auto bufferSize2 = |
1661 | rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2); |
1662 | |
1663 | rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream}); |
1664 | } else { |
1665 | auto computeType = genConstInt32From( |
1666 | rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); |
1667 | bufferSize = |
1668 | createSpMMBufferSizeCallBuilder |
1669 | .create(loc, rewriter, |
1670 | {modeA, modeB, adaptor.getSpmatA(), adaptor.getDnmatB(), |
1671 | adaptor.getDnmatC(), computeType, stream}) |
1672 | .getResult(); |
1673 | rewriter.replaceOp(op, {bufferSize, stream}); |
1674 | } |
1675 | return success(); |
1676 | } |
1677 | |
1678 | LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( |
1679 | gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor, |
1680 | ConversionPatternRewriter &rewriter) const { |
1681 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1682 | failed(isAsyncWithOneDependency(rewriter, op))) |
1683 | return failure(); |
1684 | Location loc = op.getLoc(); |
1685 | auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); |
1686 | auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); |
1687 | auto computeType = genConstInt32From( |
1688 | rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); |
1689 | auto stream = adaptor.getAsyncDependencies().front(); |
1690 | auto bufferSize = |
1691 | createSDDMMBufferSizeCallBuilder |
1692 | .create(loc, rewriter, |
1693 | {modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(), |
1694 | adaptor.getSpmatC(), computeType, stream}) |
1695 | .getResult(); |
1696 | rewriter.replaceOp(op, {bufferSize, stream}); |
1697 | return success(); |
1698 | } |
1699 | |
1700 | LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite( |
1701 | gpu::SpMMOp op, OpAdaptor adaptor, |
1702 | ConversionPatternRewriter &rewriter) const { |
1703 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1704 | failed(isAsyncWithOneDependency(rewriter, op))) |
1705 | return failure(); |
1706 | Location loc = op.getLoc(); |
1707 | auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); |
1708 | auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); |
1709 | auto computeType = genConstInt32From( |
1710 | rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); |
1711 | |
1712 | auto stream = adaptor.getAsyncDependencies().front(); |
1713 | |
1714 | // Lower to cusparseLt if applicable |
1715 | if (is2To4Sparsity(op.getSpmatA())) { |
1716 | SmallVector<Value> pBufs; |
1717 | for (Value buffer : adaptor.getBuffers()) { |
1718 | Value pBuf = MemRefDescriptor(buffer).allocatedPtr(rewriter, loc); |
1719 | pBufs.push_back(pBuf); |
1720 | } |
1721 | createCuSparseLtSpMMBuilder.create( |
1722 | loc, rewriter, |
1723 | {adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(), |
1724 | pBufs[0], pBufs[1], pBufs[2], stream}); |
1725 | } else { |
1726 | Value pBuf = MemRefDescriptor(adaptor.getBuffers().front()) |
                     .allocatedPtr(rewriter, loc);
1728 | createSpMMCallBuilder.create(loc, rewriter, |
1729 | {modeA, modeB, adaptor.getSpmatA(), |
1730 | adaptor.getDnmatB(), adaptor.getDnmatC(), |
1731 | computeType, pBuf, stream}); |
1732 | } |
1733 | rewriter.replaceOp(op, {stream}); |
1734 | return success(); |
1735 | } |
1736 | |
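// Converts the given GPU dialect handle type to an opaque LLVM pointer type.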
1737 | template <typename T> |
1738 | static void addOpaquePointerConversion(LLVMTypeConverter &converter) { |
1739 | converter.addConversion([&converter](T) -> Type { |
1740 | return LLVM::LLVMPointerType::get(&converter.getContext()); |
1741 | }); |
1742 | } |
1743 | |
1744 | LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite( |
1745 | gpu::SDDMMOp op, OpAdaptor adaptor, |
1746 | ConversionPatternRewriter &rewriter) const { |
1747 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1748 | failed(isAsyncWithOneDependency(rewriter, op))) |
1749 | return failure(); |
1750 | Location loc = op.getLoc(); |
1751 | auto computeType = genConstInt32From( |
1752 | rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); |
1753 | auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); |
1754 | auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); |
1755 | auto stream = adaptor.getAsyncDependencies().front(); |
1756 | Value pBuf = |
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1758 | createSDDMMCallBuilder.create(loc, rewriter, |
1759 | {modeA, modeB, adaptor.getDnmatA(), |
1760 | adaptor.getDnmatB(), adaptor.getSpmatC(), |
1761 | computeType, pBuf, stream}); |
1762 | rewriter.replaceOp(op, {stream}); |
1763 | return success(); |
1764 | } |
1765 | |
1766 | LogicalResult |
1767 | ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite( |
1768 | gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor, |
1769 | ConversionPatternRewriter &rewriter) const { |
1770 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1771 | failed(isAsyncWithOneDependency(rewriter, op))) |
1772 | return failure(); |
1773 | Location loc = op.getLoc(); |
1774 | auto stream = adaptor.getAsyncDependencies().front(); |
1775 | Value descr = createSpGEMMCreateDescrBuilder.create(loc, rewriter, {stream}) |
1776 | .getResult(); |
1777 | rewriter.replaceOp(op, {descr, stream}); |
1778 | return success(); |
1779 | } |
1780 | |
1781 | LogicalResult |
1782 | ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite( |
1783 | gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor, |
1784 | ConversionPatternRewriter &rewriter) const { |
1785 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1786 | failed(isAsyncWithOneDependency(rewriter, op))) |
1787 | return failure(); |
1788 | Location loc = op.getLoc(); |
1789 | auto stream = adaptor.getAsyncDependencies().front(); |
1790 | createSpGEMMDestroyDescrBuilder.create(loc, rewriter, |
1791 | {adaptor.getDesc(), stream}); |
1792 | rewriter.replaceOp(op, {stream}); |
1793 | return success(); |
1794 | } |
1795 | |
1796 | LogicalResult |
1797 | ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite( |
1798 | gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor, |
1799 | ConversionPatternRewriter &rewriter) const { |
1800 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1801 | failed(isAsyncWithOneDependency(rewriter, op))) |
1802 | return failure(); |
1803 | Location loc = op.getLoc(); |
1804 | auto computeType = genConstInt32From( |
1805 | rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); |
1806 | auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); |
1807 | auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); |
1808 | auto stream = adaptor.getAsyncDependencies().front(); |
1809 | |
1810 | Value pBuf = |
      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
1812 | Value bufferSizeNew; |
1813 | |
1814 | if (adaptor.getKind() == |
1815 | gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION) { |
1816 | bufferSizeNew = |
1817 | createSpGEMMWorkEstimationBuilder |
1818 | .create(loc, rewriter, |
1819 | {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(), |
1820 | adaptor.getSpmatB(), adaptor.getSpmatC(), computeType, |
1821 | adaptor.getBufferSz(), pBuf, stream}) |
1822 | .getResult(); |
1823 | } else { |
1824 | bufferSizeNew = |
1825 | createSpGEMMComputeBuilder |
1826 | .create(loc, rewriter, |
1827 | {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(), |
1828 | adaptor.getSpmatB(), adaptor.getSpmatC(), computeType, |
1829 | adaptor.getBufferSz(), pBuf, stream}) |
1830 | .getResult(); |
1831 | } |
1832 | rewriter.replaceOp(op, {bufferSizeNew, stream}); |
1833 | return success(); |
1834 | } |
1835 | |
1836 | LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite( |
1837 | gpu::SpGEMMCopyOp op, OpAdaptor adaptor, |
1838 | ConversionPatternRewriter &rewriter) const { |
1839 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1840 | failed(isAsyncWithOneDependency(rewriter, op))) |
1841 | return failure(); |
1842 | Location loc = op.getLoc(); |
1843 | auto computeType = genConstInt32From( |
1844 | rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType())); |
1845 | auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA()); |
1846 | auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB()); |
1847 | auto stream = adaptor.getAsyncDependencies().front(); |
1848 | createSpGEMMCopyBuilder.create(loc, rewriter, |
1849 | {adaptor.getDesc(), modeA, modeB, |
1850 | adaptor.getSpmatA(), adaptor.getSpmatB(), |
1851 | adaptor.getSpmatC(), computeType, stream}); |
1852 | rewriter.replaceOp(op, {stream}); |
1853 | return success(); |
1854 | } |
1855 | |
1856 | LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite( |
1857 | gpu::SpMatGetSizeOp op, OpAdaptor adaptor, |
1858 | ConversionPatternRewriter &rewriter) const { |
1859 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1860 | failed(isAsyncWithOneDependency(rewriter, op))) |
1861 | return failure(); |
1862 | Location loc = op.getLoc(); |
1863 | auto stream = adaptor.getAsyncDependencies().front(); |
1864 | |
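  // Stack-allocate three i64 slots that the runtime call fills with the number
  // of rows, columns, and non-zeros, respectively.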
1865 | auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), |
1866 | rewriter.getIndexAttr(3)); |
1867 | auto buffer = rewriter.create<LLVM::AllocaOp>( |
1868 | loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16); |
1869 | |
1870 | auto rowsPtr = rewriter.create<LLVM::GEPOp>( |
1871 | loc, llvmPointerType, llvmPointerType, buffer, |
1872 | ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), |
1873 | rewriter.getIndexAttr(0))}); |
1874 | auto colsPtr = rewriter.create<LLVM::GEPOp>( |
1875 | loc, llvmPointerType, llvmPointerType, buffer, |
1876 | ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), |
1877 | rewriter.getIndexAttr(1))}); |
1878 | auto nnzsPtr = rewriter.create<LLVM::GEPOp>( |
1879 | loc, llvmPointerType, llvmPointerType, buffer, |
1880 | ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), |
1881 | rewriter.getIndexAttr(2))}); |
1882 | createSpMatGetSizeBuilder.create( |
1883 | loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream}); |
1884 | auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr); |
1885 | auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr); |
1886 | auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr); |
1887 | |
1888 | rewriter.replaceOp(op, {rows, cols, nnzs, stream}); |
1889 | return success(); |
1890 | } |
1891 | |
1892 | LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite( |
1893 | gpu::SetCsrPointersOp op, OpAdaptor adaptor, |
1894 | ConversionPatternRewriter &rewriter) const { |
1895 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1896 | failed(isAsyncWithOneDependency(rewriter, op))) |
1897 | return failure(); |
1898 | Location loc = op.getLoc(); |
1899 | auto stream = adaptor.getAsyncDependencies().front(); |
  Value pPos =
      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
  Value pCrd =
      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
  Value pVal =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1906 | createSetCsrPointersBuilder.create( |
1907 | loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream}); |
1908 | rewriter.replaceOp(op, {stream}); |
1909 | return success(); |
1910 | } |
1911 | |
1912 | LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite( |
1913 | gpu::CreateCscOp op, OpAdaptor adaptor, |
1914 | ConversionPatternRewriter &rewriter) const { |
1915 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1916 | failed(isAsyncWithOneDependency(rewriter, op))) |
1917 | return failure(); |
1918 | Location loc = op.getLoc(); |
1919 | auto stream = adaptor.getAsyncDependencies().front(); |
  Value pColPos =
      MemRefDescriptor(adaptor.getColPos()).allocatedPtr(rewriter, loc);
  Value pRowIdxs =
      MemRefDescriptor(adaptor.getRowIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1926 | Type pType = |
1927 | llvm::cast<MemRefType>(op.getColPos().getType()).getElementType(); |
1928 | Type iType = |
1929 | llvm::cast<MemRefType>(op.getRowIdxs().getType()).getElementType(); |
1930 | Type dType = |
1931 | llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); |
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1935 | auto handle = |
1936 | createCscCallBuilder |
1937 | .create(loc, rewriter, |
1938 | {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(), |
1939 | pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream}) |
1940 | .getResult(); |
1941 | rewriter.replaceOp(op, {handle, stream}); |
1942 | return success(); |
1943 | } |
1944 | |
1945 | LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite( |
1946 | gpu::CreateBsrOp op, OpAdaptor adaptor, |
1947 | ConversionPatternRewriter &rewriter) const { |
1948 | if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) || |
1949 | failed(isAsyncWithOneDependency(rewriter, op))) |
1950 | return failure(); |
1951 | Location loc = op.getLoc(); |
1952 | auto stream = adaptor.getAsyncDependencies().front(); |
  Value pRowPos =
      MemRefDescriptor(adaptor.getBRowPos()).allocatedPtr(rewriter, loc);
  Value pColIdxs =
      MemRefDescriptor(adaptor.getBColIdxs()).allocatedPtr(rewriter, loc);
  Value pValues =
      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
1959 | Type pType = |
1960 | llvm::cast<MemRefType>(op.getBRowPos().getType()).getElementType(); |
1961 | Type iType = |
1962 | llvm::cast<MemRefType>(op.getBColIdxs().getType()).getElementType(); |
1963 | Type dType = |
1964 | llvm::cast<MemRefType>(op.getValues().getType()).getElementType(); |
  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
1968 | auto handle = |
1969 | createBsrCallBuilder |
1970 | .create(loc, rewriter, |
1971 | {adaptor.getBrows(), adaptor.getBcols(), adaptor.getBnnz(), |
1972 | adaptor.getRBlockSize(), adaptor.getCBlockSize(), pRowPos, |
1973 | pColIdxs, pValues, ptp, itp, dtp, stream}) |
1974 | .getResult(); |
1975 | rewriter.replaceOp(op, {handle, stream}); |
1976 | return success(); |
1977 | } |
1978 | |
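// Populates the pattern set with all GPU-op-to-runtime-call lowerings above
// and registers opaque pointer conversions for the GPU dialect handle types.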
1979 | void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, |
1980 | RewritePatternSet &patterns, |
1981 | StringRef gpuBinaryAnnotation, |
1982 | bool kernelBarePtrCallConv, |
1983 | SymbolTable *cachedModuleTable) { |
1984 | addOpaquePointerConversion<gpu::AsyncTokenType>(converter); |
1985 | addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter); |
1986 | addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter); |
1987 | addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter); |
1988 | |
1989 | patterns.add<ConvertAllocOpToGpuRuntimeCallPattern, |
1990 | ConvertDeallocOpToGpuRuntimeCallPattern, |
1991 | ConvertHostRegisterOpToGpuRuntimeCallPattern, |
1992 | ConvertHostUnregisterOpToGpuRuntimeCallPattern, |
1993 | ConvertMemcpyOpToGpuRuntimeCallPattern, |
1994 | ConvertMemsetOpToGpuRuntimeCallPattern, |
1995 | ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern, |
1996 | ConvertWaitAsyncOpToGpuRuntimeCallPattern, |
1997 | ConvertWaitOpToGpuRuntimeCallPattern, |
1998 | ConvertAsyncYieldToGpuRuntimeCallPattern, |
1999 | ConvertCreateDnTensorOpToGpuRuntimeCallPattern, |
2000 | ConvertDestroyDnTensorOpToGpuRuntimeCallPattern, |
2001 | ConvertCreateCooOpToGpuRuntimeCallPattern, |
2002 | ConvertCreateCooAoSOpToGpuRuntimeCallPattern, |
2003 | ConvertCreateCsrOpToGpuRuntimeCallPattern, |
2004 | ConvertCreateCscOpToGpuRuntimeCallPattern, |
2005 | ConvertCreateBsrOpToGpuRuntimeCallPattern, |
2006 | ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern, |
2007 | ConvertDestroySpMatOpToGpuRuntimeCallPattern, |
2008 | ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern, |
2009 | ConvertSpMVOpToGpuRuntimeCallPattern, |
2010 | ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern, |
2011 | ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern, |
2012 | ConvertSpMMOpToGpuRuntimeCallPattern, |
2013 | ConvertSDDMMOpToGpuRuntimeCallPattern, |
2014 | ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern, |
2015 | ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern, |
2016 | ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern, |
2017 | ConvertSpGEMMCopyOpToGpuRuntimeCallPattern, |
2018 | ConvertSpMatGetSizeOpToGpuRuntimeCallPattern, |
               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
      converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
2023 | } |
2024 | |