//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Number of bits that need to be excluded when building the matrix descriptor
/// for wgmma operations.
constexpr int exclude4LSB = 4;

/// GPUs have 32-bit registers; this function truncates values when a larger
/// width is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
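/// For example, an `!llvm.array<2 x vector<2xf16>>` result maps to
/// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while an
/// `!llvm.array<2 x vector<2xf32>>` result maps to
/// `!llvm.struct<(f32, f32, f32, f32)>`, since the intrinsic returns f32
/// elements as individual scalars.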
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
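/// For example, an `!llvm.struct<(f32, f32, f32, f32)>` intrinsic result is
/// repacked into an `!llvm.array<2 x vector<2xf32>>` value by extracting the
/// four scalars and inserting them pairwise into two `vector<2xf32>` rows.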
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2).
    // We need to extract them from the struct and pack them into the 64-bit
    // wide rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}

/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
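/// For example, an `!llvm.array<2 x vector<2xf16>>` operand is unpacked into
/// two `vector<2xf16>` values, while an `!llvm.array<2 x vector<1xf32>>`
/// operand with the `tf32` PTX type is unpacked and bitcast into two `i32`
/// scalars.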
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(b.create<LLVM::ExtractElementOp>(
            toUse,
            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

/// Returns whether the mbarrier object has a shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType)) {
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  }
  return memorySpace;
}

/// Returns the memref type of the mbarrier object. The type is defined in the
/// MBarrierGroupType.
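/// For example, a barrier group of 4 mbarriers in shared (workgroup) memory
/// is represented as a `memref<4xi64>` carrying the shared memory space
/// attribute, i.e. one i64 word per mbarrier object.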
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}

namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result type of ldmatrix will always be a struct of 32-bit integer
    // registers if more than one 32-bit value is returned. Otherwise, the
    // result is a single i32. The result type of the GPU operation is always
    // a vector of shape (NumRegisters, VectorRegister) where VectorRegister is
    // the vector type of the result and always 32 bits long. We bitcast the
    // result of the NVVM::LdMatrix to this vector type.
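    // For example, a `vector<4x2xf16>` result is lowered to an ldmatrix
    // returning `!llvm.struct<(i32, i32, i32, i32)>`; each i32 register is
    // then bitcast to `vector<2xf16>` and inserted into the final
    // `!llvm.array<4 x vector<2xf16>>` value.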
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
        ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = b.create<LLVM::UndefOp>(finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
      result = b.create<LLVM::InsertValueOp>(result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
/// enum).
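/// Note that f32 maps to `tf32`: `mma.sync` consumes f32 operands only in
/// their TensorFloat32 form.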
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = b.create<NVVM::MmaOp>(
        intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
                    arith::ArithDialect>();
  }

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    IRRewriter rewriter(&getContext());
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });
    /// Device-side async tokens cannot be materialized in nvvm. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

/// Returns the constraints for the sparse MMA inline assembly instruction.
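/// For example, with matASize = 2, matBSize = 1, and matCSize = 2 the result
/// is "=r,=r,r,r,r,r,r,r": one "=r" output per accumulator register, one "r"
/// input per operand register, and a final "r" for the sparsity metadata.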
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is for the sparsity metadata.
  // The sparsity selector appears as a direct literal.
  ss << "r";
  ss.flush();
  return str;
}

/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. Note that this function doesn't do any validation,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
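/// For example, shape {16, 8, 16} with 2 registers for A, 1 for B, and 2 for
/// C (all f16), no overflow attribute, and metadata selector 0 yields the
/// single-line string:
///   "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16
///    {$0,$1},{$2,$3},{$4},{$5,$6},$7,0x0;"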
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections `{matC elements...},
  // {matA elements...}, {matB elements...}, {matC elements}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  ss.flush();
  return asmStr;
}

/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return b.create<LLVM::InlineAsmOp>(
      /*resultTypes=*/intrinsicResultType,
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/constraintStr,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}

/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() !=
        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
                             adaptor.getDstIndices(), rewriter);
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, srcPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // When the optional SrcElements argument is *not* present, the regular
    // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
    // memory) to fill DstElements number of elements in the destination
    // (shared memory).
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      // When the optional SrcElements argument is present, the source (global
      // memory) of CpAsyncOp is read only for SrcElements number of elements.
      // The rest of the DstElements in the destination (shared memory) are
      // filled with zeros.
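      // The byte count is (elementBitwidth * srcElements) / 8, computed below
      // as a multiply followed by a right shift by 3.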
      Value c3I32 =
          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = b.create<LLVM::ConstantOp>(
          b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
      srcBytes = b.create<LLVM::LShrOp>(
          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
    }
    // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
    // 16 dst bytes.
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    b.create<NVVM::CpAsyncOp>(
        dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token.
    Value zero = b.create<LLVM::ConstantOp>(
        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservatively correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Creates mbarrier object in shared memory
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = rewriter.create<memref::GlobalOp>(
        funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};

/// Base class for lowering mbarrier operations to nvvm intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  /// Returns the base pointer of the mbarrier object.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
  }
};

/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(mbarrierType)) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
          op, barrier, count, adaptor.getPredicate());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
                                                        adaptor.getPredicate());
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
                                                                barrier);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
                                                          barrier);
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
/// `nvvm.mbarrier.arrive.nocomplete`
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
          op, tokenType, barrier, count);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
          op, tokenType, barrier, count);
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
          op, retType, barrier, adaptor.getToken());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
          op, retType, barrier, adaptor.getToken());
    }
    return success();
  }
};

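/// Lowers `nvgpu.mbarrier.arrive.expect_tx` to
/// `NVVM::MBarrierArriveExpectTxOp` (or its shared-memory variant).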
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
          op, barrier, txcount, adaptor.getPredicate());
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount, adaptor.getPredicate());
    return success();
  }
};

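/// Lowers `nvgpu.mbarrier.try_wait.parity` to
/// `NVVM::MBarrierTryWaitParityOp` (or its shared-memory variant).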
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
          op, barrier, phase, ticks);
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};

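/// Lowers `nvgpu.tma.async.load` to
/// `NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp`.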
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(op->getLoc(), dstMemrefType,
                                      adaptor.getDst(), {}, rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }
    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        ValueRange{}, adaptor.getMulticastMask(), Value{},
        adaptor.getPredicate());
    return success();
  }
};

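/// Lowers `nvgpu.tma.async.store` to
/// `NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp`.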
struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {}, rewriter);
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), dest, coords,
        adaptor.getPredicate());
    return success();
  }
};

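/// Lowers `nvgpu.warpgroup.generate.descriptor` to an i64 wgmma matrix
/// descriptor assembled from the base address, leading/stride dimensions, and
/// swizzle kind via shift/or operations.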
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
                                                                    : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                    : 0;

    auto ti64 = b.getIntegerType(64);
    auto makeConst = [&](uint64_t index) -> Value {
      return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {}, rewriter);
    Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
    // Just use 14 bits for the base address.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // [62,64) swizzle type
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // [49,52) base_offset
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // [32,46) stride
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // [16,30) leading dimension
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // [0,14) start_address
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                      << "leading_off:" << leadDimVal << "\t"
                      << "stride_off :" << strideDimVal << "\t"
                      << "base_offset:" << offsetVal << "\t"
                      << "layout_type:" << swizzle << " ("
                      << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                      << ")\n start_addr : " << baseAddr << "\n");

    rewriter.replaceOp(op, dsc);
    return success();
  }
};

static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
                                    b.getI32IntegerAttr(index));
}

/// Returns a Value that holds the data type enum that is expected by the CUDA
/// driver.
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  // The enum is from the CUDA driver API:
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
  enum CUtensorMapDataTypeEnum {
    CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
    CU_TENSOR_MAP_DATA_TYPE_UINT16,
    CU_TENSOR_MAP_DATA_TYPE_UINT32,
    CU_TENSOR_MAP_DATA_TYPE_INT32,
    CU_TENSOR_MAP_DATA_TYPE_UINT64,
    CU_TENSOR_MAP_DATA_TYPE_INT64,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
  };

  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

  llvm_unreachable("Not supported data type");
}

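/// Lowers `nvgpu.tma.create.descriptor` to a call to the host runtime helper
/// `mgpuTensorMapEncodeTiledMemref`, which builds and returns the TMA
/// descriptor pointer.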
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                 makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
                                        boxArrayPtr, makeI64Const(b, index));
      b.create<LLVM::StoreOp>(value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set arguments for the function call.
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]); // rank
    arguments.push_back(promotedOperands[1]); // descriptor
    arguments.push_back(tensorElementType);   // data type
    arguments.push_back(
        makeI64Const(b, (int)desc.getInterleave()));              // interleave
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
    arguments.push_back(boxArrayPtr); // box dimensions

    // Set data types of the arguments.
    SmallVector<Type> argTypes = {
        llvmInt64Type,   /* int64_t tensorRank */
        llvmPointerType, /* ptr */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmPointerType  /* ptr */
    };
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};
1209 | |
1210 | struct NVGPUWarpgroupMmaOpLowering |
1211 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> { |
1212 | using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern; |
1213 | |
1214 | /// This is a helper class to generate required NVVM Ops for warp-group level |
1215 | /// matrix multiplication. |
1216 | /// When the given GEMM shape is larger than the shape of |
1217 | /// a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp |
1218 | /// Op(s), group and execute them asynchronously. The class also handles |
1219 | /// waiting for completion and iterates through WarpgroupMatrixDescriptor to |
1220 | /// create descriptors for each instruction. |
1221 | /// |
1222 | /// For example this is the case when the shape of GEMM is 128x128x128 |
1223 | /// |
1224 | /// nvvm.wgmma.fence.aligned |
1225 | /// |
1226 | /// nvvm.wgmma.mma.async descA, descB |
1227 | /// iterate(descA, descB) |
1228 | /// nvvm.wgmma.mma.async descA, descB |
1229 | /// [6x times more] |
1230 | /// |
1231 | /// nvvm.wgmma.group.sync.aligned |
1232 | /// nvvm.wgmma.wait.group.sync [groupId] |
1233 | /// |
1234 | class WarpgroupGemm { |
1235 | nvgpu::WarpgroupMmaOp op; |
1236 | ImplicitLocOpBuilder b; |
1237 | OpAdaptor adaptor; |
1238 | |
1239 | // Entire shape of the given Op |
1240 | int64_t totalM, totalN, totalK; |
1241 | |
1242 | // Shape of one wgmma instruction |
1243 | int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0; |
1244 | |
1245 | // Iteration counts for GEMM |
1246 | int iterationM = 0, iterationN = 0, iterationK = 0; |
1247 | |
1248 | /// Computes the shape of a single wgmma instruction, as defined in the PTX |
1249 | /// programming guide, and stores it in wgmmaM/wgmmaN/wgmmaK. |
1250 | /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape |
1251 | void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) { |
1252 | wgmmaM = 64; |
1253 | wgmmaN = sizeN; |
1254 | if (inputElemType.isTF32()) { |
1255 | wgmmaK = 8; |
1256 | } else if (inputElemType.isF16() || inputElemType.isBF16()) { |
1257 | wgmmaK = 16; |
1258 | } else if (inputElemType.isFloat8E4M3FN() || |
1259 | inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) { |
1260 | wgmmaK = 32; |
1261 | } else if (inputElemType.isInteger(1)) { |
1262 | wgmmaK = 256; |
1263 | } else { |
1264 | llvm_unreachable("msg: not supported K shape" ); |
1265 | } |
1266 | LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM |
1267 | << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n" ); |
1268 | } |
1269 | |
1270 | /// Generates WGMMATypesAttr from MLIR Type |
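/// For example, f16 operands map to WGMMATypes::f16; an f32 element type maps
/// to tf32 for the A/B inputs but to f32 for the accumulator, which is why the
/// call site for matrix D passes useF32 = true.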
1271 | NVVM::WGMMATypesAttr generateWgmmaType(Type type, |
1272 | bool useF32 = false) const { |
1273 | auto getWgmmaType = [=](Type elemType) { |
1274 | if (elemType.isF32() || elemType.isTF32()) |
1275 | return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32; |
1276 | if (elemType.isF16()) |
1277 | return NVVM::WGMMATypes::f16; |
1278 | if (elemType.isBF16()) |
1279 | return NVVM::WGMMATypes::bf16; |
1280 | if (elemType.isFloat8E4M3FN()) |
1281 | return NVVM::WGMMATypes::e4m3; |
1282 | if (elemType.isFloat8E5M2()) |
1283 | return NVVM::WGMMATypes::e5m2; |
1284 | if (elemType.isInteger(1)) |
1285 | return NVVM::WGMMATypes::b1; |
1286 | if (elemType.isInteger(8)) |
1287 | return NVVM::WGMMATypes::s8; |
1288 | if (elemType.isUnsignedInteger(8)) |
1289 | return NVVM::WGMMATypes::u8; |
1290 | if (elemType.isInteger(32)) |
1291 | return NVVM::WGMMATypes::s32; |
1292 | llvm_unreachable("unsupported type" ); |
1293 | }; |
1294 | return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type)); |
1295 | } |
1296 | |
1297 | /// Generates layout attribute for the input matrix for wgmma instruction |
1298 | NVVM::MMALayoutAttr |
1299 | generateWgmmaLayout(std::optional<bool> transpose) const { |
1300 | if (transpose.value_or(false)) |
1301 | return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col); |
1302 | return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row); |
1303 | } |
1304 | |
1305 | /// Generates shape attribute for wgmma instruction |
1306 | NVVM::MMAShapeAttr generateWgmmaShape() const { |
1307 | return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK); |
1308 | } |
1309 | |
1310 | /// Generates scale attributes of output matrix for wgmma instruction |
1311 | NVVM::WGMMAScaleOutAttr generateScaleOut() const { |
1312 | return NVVM::WGMMAScaleOutAttr::get(op->getContext(), |
1313 | NVVM::WGMMAScaleOut::one); |
1314 | } |
1315 | /// Generates scale attributes of input matrix for wgmma instruction |
1316 | NVVM::WGMMAScaleInAttr generateScaleIn() const { |
1317 | return NVVM::WGMMAScaleInAttr::get(op->getContext(), |
1318 | NVVM::WGMMAScaleIn::one); |
1319 | } |
1320 | |
1321 | /// Basic function to generate Add |
1322 | Value makeAdd(Value lhs, Value rhs) { |
1323 | return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); |
1324 | }; |
1325 | |
1326 | /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. |
1327 | /// Currently, it only handles row-major. |
1328 | /// |
1329 | /// It moves the pointer like below for [128][64] size: |
1330 | /// +2 +4 +6 |
1331 | /// ↓ ↓ ↓ |
1332 | /// descA ---> +--+--+--+--+ |
1333 | /// |->|->|->|->| |
1334 | /// | | | | | |
1335 | /// | | | | | |
1336 | /// | | | | | |
1337 | /// descA+512---> +-----------+ |
1338 | /// | | | | | |
1339 | /// | | | | | |
1340 | /// | | | | | |
1341 | /// | | | | | |
1342 | /// +-----------+ |
1343 | /// |
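/// As a worked example (assuming f16 elements, so 2 bytes each, with the
/// 128x64 tile of A shown above and totalK = 64): advancing along k by one
/// step adds (16 * 1) * 2 = 32 bytes, i.e. +2 once the 4 LSBs are dropped
/// (matching the +2/+4/+6 arrows), while advancing along m by one step adds
/// (64 * 64) * 2 = 8192 bytes, i.e. +512 (matching descA+512).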
1344 | Value iterateDescriptorA(Value desc, int i, int j, int k) { |
1345 | MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor(); |
1346 | Type elemA = matrixTypeA.getElementType(); |
1347 | int byte = elemA.getIntOrFloatBitWidth() / 8; |
1348 | int tileShapeA = matrixTypeA.getDimSize(1); |
1349 | int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; |
1350 | incrementVal = incrementVal >> exclude4LSB; |
1351 | LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k |
1352 | << "] [wgmma descriptors] Descriptor A + " |
1353 | << incrementVal << " | \t " ); |
1354 | if (!incrementVal) |
1355 | return desc; |
1356 | return makeAdd(desc, makeI64Const(b, incrementVal)); |
1357 | } |
1358 | |
1359 | /// Moves the descriptor pointer of matrix-B for the next wgmma instruction. |
1360 | /// Currently, it only handles column-major. |
1361 | /// |
1362 | /// It moves the pointer like below for [128][64] size: |
1363 | /// descB ---> +--+--+--+--+--+--+--+--+ |
1364 | /// |↓ | | | | | | | | |
1365 | /// |↓ | | | | | | | | |
1366 | /// |↓ | | | | | | | | |
1367 | /// |↓ | | | | | | | | |
1368 | /// +--+--+--+--+--+--+--+--+ |
1369 | /// |
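/// As a worked example (assuming f16 elements and the [128][64] B tile shown
/// above): one step along k adds 128 * 16 * 2 = 4096 bytes, which is encoded
/// as 4096 >> 4 = 256 in the descriptor.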
1370 | Value iterateDescriptorB(Value desc, int i, int j, int k) { |
1371 | MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor(); |
1372 | Type elemB = matrixTypeB.getElementType(); |
1373 | int byte = elemB.getIntOrFloatBitWidth() / 8; |
1374 | int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; |
1375 | incrementVal = incrementVal >> exclude4LSB; |
1376 | LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n" ); |
1377 | if (!incrementVal) |
1378 | return desc; |
1379 | return makeAdd(desc, makeI64Const(b, incrementVal)); |
1380 | } |
1381 | |
1382 | /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix |
1383 | /// descriptors and arranges them based on induction variables: i, j, and k. |
1384 | Value generateWgmma(int i, int j, int k, Value matrixC) { |
1385 | LLVM_DEBUG(DBGS() << "\t wgmma." |
1386 | << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK |
1387 | << "(A[" << (iterationM * wgmmaM) << ":" |
1388 | << (iterationM * wgmmaM) + wgmmaM << "][" |
1389 | << (iterationK * wgmmaK) << ":" |
1390 | << (iterationK * wgmmaK + wgmmaK) << "] * " |
1391 | << " B[" << (iterationK * wgmmaK) << ":" |
1392 | << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" |
1393 | << wgmmaN << "])\n" ); |
1394 | |
1395 | Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); |
1396 | Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); |
1397 | |
1398 | Type elemA = op.getDescriptorA().getType().getTensor().getElementType(); |
1399 | NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA); |
1400 | |
1401 | Type elemB = op.getDescriptorB().getType().getTensor().getElementType(); |
1402 | NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB); |
1403 | |
1404 | Type elemD = op.getMatrixC().getType().getFragmented().getElementType(); |
1405 | NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true); |
1406 | |
1407 | NVVM::MMAShapeAttr shape = generateWgmmaShape(); |
1408 | NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut(); |
1409 | NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn(); |
1410 | NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA()); |
1411 | NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB()); |
1412 | |
1413 | auto overflow = NVVM::MMAIntOverflowAttr::get( |
1414 | op->getContext(), NVVM::MMAIntOverflow::wrapped); |
1415 | |
1416 | return b.create<NVVM::WgmmaMmaAsyncOp>( |
1417 | matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, |
1418 | itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, |
1419 | overflow); |
1420 | } |
1421 | |
1422 | /// Generates multiple wgmma instructions to complete the given GEMM shape |
1423 | Value generateWgmmaGroup() { |
1424 | Value wgmmaResult = |
1425 | b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType()); |
1426 | |
1427 | // Perform GEMM |
1428 | SmallVector<Value> wgmmaResults; |
1429 | for (int i = 0; i < iterationM; ++i) { |
1430 | Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i); |
1431 | for (int j = 0; j < iterationN; ++j) |
1432 | for (int k = 0; k < iterationK; ++k) |
1433 | matrixC = generateWgmma(i, j, k, matrixC); |
1434 | wgmmaResults.push_back(matrixC); |
1435 | } |
1436 | for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) { |
1437 | wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(), |
1438 | wgmmaResult, matrix, idx); |
1439 | } |
1440 | return wgmmaResult; |
1441 | } |
1442 | |
1443 | public: |
1444 | WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b, |
1445 | OpAdaptor adaptor) |
1446 | : op(op), b(b), adaptor(adaptor) { |
1447 | // Find the entire GEMM Shape |
1448 | totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); |
1449 | totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); |
1450 | totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); |
1451 | LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN |
1452 | << "] += A[" << totalM << "][" << totalK << "] * B[" |
1453 | << totalK << "][" << totalN << "] ---===\n" ); |
1454 | |
1455 | // Find the shape for one wgmma instruction |
1456 | findWgmmaShape( |
1457 | totalM, totalN, |
1458 | op.getDescriptorA().getType().getTensor().getElementType()); |
1459 | |
1460 | // Iteration counts needed to cover the given shape with the wgmma shape |
1461 | iterationM = totalM / wgmmaM; |
1462 | iterationN = totalN / wgmmaN; |
1463 | iterationK = totalK / wgmmaK; |
1464 | } |
1465 | |
1466 | /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. It |
1467 | /// emits a fence op (WgmmaFenceAlignedOp) before the instructions, then |
1468 | /// commits them as a group (WgmmaGroupSyncAlignedOp) and waits for the |
1469 | /// group to finish (WgmmaWaitGroupSyncOp) after the instructions have been |
1470 | /// issued. |
1471 | Value generateWarpgroupMma() { |
1472 | b.create<NVVM::WgmmaFenceAlignedOp>(); |
1473 | Value wgmmaResult = generateWgmmaGroup(); |
1474 | b.create<NVVM::WgmmaGroupSyncAlignedOp>(); |
1475 | b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup()); |
1476 | return wgmmaResult; |
1477 | } |
1478 | }; |
1479 | LogicalResult |
1480 | matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, |
1481 | ConversionPatternRewriter &rewriter) const override { |
1482 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1483 | |
1484 | // Step 1. Build a helper class |
1485 | WarpgroupGemm warpgroupGemm(op, b, adaptor); |
1486 | |
1487 | // Step 2. Generate the wgmma instructions covering the entire GEMM shape |
1488 | Value wgmmaResult = warpgroupGemm.generateWarpgroupMma(); |
1489 | |
1490 | // Step 3. Replace fragmented result struct with the op results |
1491 | rewriter.replaceOp(op, wgmmaResult); |
1492 | return success(); |
1493 | } |
1494 | }; |
1495 | |
1496 | struct NVGPUWarpgroupMmaStoreOpLowering |
1497 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> { |
1498 | using ConvertOpToLLVMPattern< |
1499 | nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern; |
1500 | |
1501 | /// This function stores a fragmented register matrix owned by a warp group |
1502 | /// (128 threads) into a memref. Each thread holds 64 registers, one per |
1503 | /// element of its result struct. |
1504 | /// Below is what each thread (T) holds; each `d` is a numbered struct |
1505 | /// element. |
1506 | /// |
1507 | /// Threads in the warp-group (128 threads) and what they own in matrix D: |
1508 | /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N] |
1509 | /// 32-63 Warp-1 -> MatrixD[16:31][0:N] |
1510 | /// 64-95 Warp-2 -> MatrixD[32:47][0:N] |
1511 | /// 96-127 Warp-3 -> MatrixD[48:63][0:N] |
1512 | /// |
1513 | /// Matrix-D: |
1514 | /// +______________________________________________________________________+ |
1515 | /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 | |
1516 | /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY| |
1517 | /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY| |
1518 | /// ..| .........|.........|.........|.........|........|...........|........| |
1519 | /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW| |
1520 | /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW| |
1521 | /// ..| .........|.........|.........|.........|........|...........|........| |
1522 | /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........| |
1523 | /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........| |
1524 | /// ..| .........|.........|.........|.........|........|...........|........| |
1525 | /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........| |
1526 | /// ..| .........|.........|.........|.........|........|...........|........| |
1527 | /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........| |
1528 | /// ..| .........|.........|.........|.........|........|...........|........| |
1529 | /// +______________________________________________________________________+ |
1530 | /// |
1531 | /// \param b: The builder used to emit the store ops. |
1532 | /// \param matrixD: Result of the warp-group MMA operation (fragmented |
1533 | /// matrix). It is held per thread as a struct with 64 elements. |
1534 | /// \param dstMemref: The memref where the registers will be stored. |
1535 | /// \param offset: The offset within the memref where the registers will be |
1536 | /// stored. |
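///
/// As a concrete illustration of the index math below: thread 5 has
/// laneId = 5 and warpId = 0, hence lane4Id = 1 and lane4modId = 1, giving
/// ti = 1 and tj = 2; its first register pair d0/d1 therefore lands in
/// MatrixD[1][2:3], matching the table above.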
1537 | void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD, |
1538 | TypedValue<MemRefType> dstMemref, |
1539 | int offset) const { |
1540 | Type i32 = b.getI32Type(); |
1541 | |
1542 | auto makeConst = [&](int32_t index) -> Value { |
1543 | return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index)); |
1544 | }; |
1545 | Value c1 = makeConst(1); |
1546 | Value c2 = makeConst(2); |
1547 | Value c4 = makeConst(4); |
1548 | Value c8 = makeConst(8); |
1549 | Value c16 = makeConst(16); |
1550 | Value warpSize = makeConst(kWarpSize); |
1551 | |
1552 | auto makeMul = [&](Value lhs, Value rhs) -> Value { |
1553 | return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs); |
1554 | }; |
1555 | auto makeAdd = [&](Value lhs, Value rhs) -> Value { |
1556 | return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); |
1557 | }; |
1558 | |
1559 | auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, |
1560 | TypedValue<::mlir::MemRefType> memref) { |
1561 | Type it = b.getIndexType(); |
1562 | Value idx = b.create<arith::IndexCastOp>(it, x); |
1563 | Value idy0 = b.create<arith::IndexCastOp>(it, y); |
1564 | Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1)); |
1565 | Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i); |
1566 | Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1); |
1567 | b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0}); |
1568 | b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1}); |
1569 | }; |
1570 | |
1571 | Value tidx = b.create<NVVM::ThreadIdXOp>(i32); |
1572 | Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize); |
1573 | Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize); |
1574 | Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4); |
1575 | Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4); |
1576 | |
1577 | Value tj = makeMul(lane4modId, c2); |
1578 | Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); |
1579 | if (offset) |
1580 | ti = makeAdd(ti, makeConst(offset)); |
1581 | |
1582 | auto structType = cast<LLVM::LLVMStructType>(matrixD.getType()); |
1583 | |
1584 | // Number of adjacent 32-bit registers each thread owns |
1585 | constexpr unsigned numAdjacentRegisters = 2; |
1586 | // Number of 8x8 matrices one below another per warp |
1587 | constexpr unsigned numStackedMatrices = 2; |
1588 | |
1589 | size_t storeCount = (structType.getBody().size() / |
1590 | (numStackedMatrices * numAdjacentRegisters)); |
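// For example, with the 64-element fragment described in the comment above,
// storeCount = 64 / (2 * 2) = 16, i.e. 16 register pairs are stored per
// stacked 8x8 matrix.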
1591 | |
1592 | for (size_t i = 0; i < numStackedMatrices; ++i) { |
1593 | Value idx = makeAdd(ti, makeMul(makeConst(i), c8)); |
1594 | for (size_t j = 0; j < storeCount; ++j) { |
1595 | Value idy = makeAdd(tj, makeMul(makeConst(j), c8)); |
1596 | size_t structIndex = (i * numAdjacentRegisters) + |
1597 | (j * (numStackedMatrices * numAdjacentRegisters)); |
1598 | makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref); |
1599 | } |
1600 | } |
1601 | } |
1602 | |
1603 | LogicalResult |
1604 | matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor, |
1605 | ConversionPatternRewriter &rewriter) const override { |
1606 | int offset = 0; |
1607 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1608 | Value matrixDValue = adaptor.getMatrixD(); |
1609 | auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType()); |
1610 | for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) { |
1611 | auto structType = cast<LLVM::LLVMStructType>(matrixD); |
1612 | Value innerStructValue = b.create<LLVM::ExtractValueOp>(matrixDValue, idx); |
1613 | storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); |
1614 | offset += structType.getBody().size(); |
1615 | } |
1616 | rewriter.eraseOp(op); |
1617 | return success(); |
1618 | } |
1619 | }; |
1620 | |
1621 | struct NVGPUWarpgroupMmaInitAccumulatorOpLowering |
1622 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> { |
1623 | using ConvertOpToLLVMPattern< |
1624 | nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern; |
1625 | LogicalResult |
1626 | matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor, |
1627 | ConversionPatternRewriter &rewriter) const override { |
1628 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1629 | LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>( |
1630 | getTypeConverter()->convertType(op.getMatrixC().getType())); |
1631 | Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front()) |
1632 | .getBody() |
1633 | .front(); |
1634 | Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType)); |
1635 | Value packStruct = b.create<LLVM::UndefOp>(packStructType); |
1636 | SmallVector<Value> innerStructs; |
1637 | // Unpack the structs and set all values to zero |
1638 | for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { |
1639 | auto structType = cast<LLVM::LLVMStructType>(s); |
1640 | Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx); |
1641 | for (unsigned i = 0; i < structType.getBody().size(); ++i) { |
1642 | structValue = b.create<LLVM::InsertValueOp>( |
1643 | structType, structValue, zero, ArrayRef<int64_t>({i})); |
1644 | } |
1645 | innerStructs.push_back(structValue); |
1646 | } |
1647 | // Pack the inner structs into a single struct |
1648 | for (auto [idx, matrix] : llvm::enumerate(innerStructs)) { |
1649 | packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(), |
1650 | packStruct, matrix, idx); |
1651 | } |
1652 | rewriter.replaceOp(op, packStruct); |
1653 | return success(); |
1654 | } |
1655 | }; |
1656 | |
1657 | struct NVGPUTmaPrefetchOpLowering |
1658 | : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> { |
1659 | using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern; |
1660 | LogicalResult |
1661 | matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor, |
1662 | ConversionPatternRewriter &rewriter) const override { |
1663 | rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>( |
1664 | op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate()); |
1665 | return success(); |
1666 | } |
1667 | }; |
1668 | |
1669 | } // namespace |
1670 | |
1671 | void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, |
1672 | RewritePatternSet &patterns) { |
1673 | patterns.add< |
1674 | NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create |
1675 | NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init |
1676 | NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive |
1677 | NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete |
1678 | NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity |
1679 | NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity |
1680 | NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load |
1681 | NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store |
1682 | NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor |
1683 | NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor |
1684 | NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx |
1685 | NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor |
1686 | NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma |
1687 | NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store |
1688 | NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator |
1689 | MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, |
1690 | NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, |
1691 | NVGPUMmaSparseSyncLowering>(converter); |
1692 | } |
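// A minimal usage sketch (assumed driver code, not part of this file): a
// conversion pass would typically wire these patterns into the dialect
// conversion framework roughly as follows.
//
//   LLVMTypeConverter converter(module.getContext());
//   RewritePatternSet patterns(module.getContext());
//   populateNVGPUToNVVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*module.getContext());
//   target.addLegalDialect<NVVM::NVVMDialect>();
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();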
1693 | |