1 | //===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" |
10 | |
11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
12 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
13 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
14 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
17 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
19 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
20 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
21 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
22 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
23 | #include "mlir/IR/BuiltinTypes.h" |
24 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "mlir/IR/TypeUtilities.h" |
27 | #include "mlir/IR/Value.h" |
28 | #include "mlir/Pass/Pass.h" |
29 | #include "llvm/Support/Debug.h" |
30 | #include "llvm/Support/ErrorHandling.h" |
31 | #include "llvm/Support/raw_ostream.h" |
32 | #include <optional> |
33 | |
34 | #define DEBUG_TYPE "nvgpu-to-nvvm" |
35 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
36 | #define DBGSE() (llvm::dbgs()) |
37 | |
38 | namespace mlir { |
39 | #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS |
40 | #include "mlir/Conversion/Passes.h.inc" |
41 | } // namespace mlir |
42 | |
43 | using namespace mlir; |
44 | |
45 | /// Number of low-order bits that need to be excluded when building the matrix |
46 | /// descriptor for wgmma operations. |
47 | constexpr int exclude4LSB = 4; |
48 | |
49 | /// GPUs have 32-bit registers; this function truncates values when a larger |
50 | /// width is not needed. |
51 | static Value truncToI32(ImplicitLocOpBuilder &b, Value value) { |
52 | Type type = value.getType(); |
53 | assert(llvm::isa<IntegerType>(type) && "expected an integer Value"); |
54 | if (type.getIntOrFloatBitWidth() <= 32) |
55 | return value; |
56 | return b.create<LLVM::TruncOp>(b.getI32Type(), value); |
57 | } |
58 | |
59 | /// Returns the type for the intrinsic given the vectorResultType of the |
60 | /// `gpu.mma.sync` operation. |
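| /// For example, `!llvm.array<4 x vector<2xf16>>` maps to a struct of four |
| /// `vector<2xf16>` members, and `!llvm.array<2 x vector<2xi32>>` maps to a |
| /// struct of four i32 members. |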
61 | static Type inferIntrinsicResultType(Type vectorResultType) { |
62 | MLIRContext *ctx = vectorResultType.getContext(); |
63 | auto a = cast<LLVM::LLVMArrayType>(vectorResultType); |
64 | auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx)); |
65 | auto i32Ty = IntegerType::get(ctx, 32); |
66 | auto i32x2Ty = VectorType::get(2, i32Ty); |
67 | Type f64Ty = Float64Type::get(ctx); |
68 | Type f64x2Ty = VectorType::get(2, f64Ty); |
69 | Type f32Ty = Float32Type::get(ctx); |
70 | Type f32x2Ty = VectorType::get(2, f32Ty); |
71 | if (a.getElementType() == f16x2Ty) { |
72 | return LLVM::LLVMStructType::getLiteral( |
73 | ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); |
74 | } |
75 | if (a.getElementType() == i32x2Ty) { |
76 | return LLVM::LLVMStructType::getLiteral( |
77 | ctx, |
78 | SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty)); |
79 | } |
80 | if (a.getElementType() == f64x2Ty) { |
81 | return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); |
82 | } |
83 | if (a.getElementType() == f32x2Ty) { |
84 | return LLVM::LLVMStructType::getLiteral( |
85 | ctx, |
86 | SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty)); |
87 | } |
88 | if (a.getElementType() == VectorType::get(1, f32Ty)) { |
89 | return LLVM::LLVMStructType::getLiteral( |
90 | ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty)); |
91 | } |
92 | return vectorResultType; |
93 | } |
94 | |
95 | /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is |
96 | /// always an LLVM struct) into a fragment that is compatible with the vector |
97 | /// type of this operation. This involves extracting elements from the struct |
98 | /// and inserting them into an LLVM array. These extra data-movement |
99 | /// operations should be canonicalized away by the LLVM backend. |
100 | static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, |
101 | Type resultType, Value intrinsicResult, |
102 | RewriterBase &rewriter) { |
103 | MLIRContext *ctx = rewriter.getContext(); |
104 | auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType); |
105 | auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType); |
106 | Type i32Ty = rewriter.getI32Type(); |
107 | Type f32Ty = rewriter.getF32Type(); |
108 | Type f64Ty = rewriter.getF64Type(); |
109 | Type f16x2Ty = VectorType::get(2, rewriter.getF16Type()); |
110 | Type i32x2Ty = VectorType::get(2, i32Ty); |
111 | Type f64x2Ty = VectorType::get(2, f64Ty); |
112 | Type f32x2Ty = VectorType::get(2, f32Ty); |
113 | Type f32x1Ty = VectorType::get(1, f32Ty); |
114 | |
115 | auto makeConst = [&](int32_t index) -> Value { |
116 | return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), |
117 | rewriter.getI32IntegerAttr(index)); |
118 | }; |
119 | |
120 | if (arrayType) { |
121 | SmallVector<Value, 4> elements; |
122 | |
123 | // The intrinsic returns 32-bit wide elements in a form which can be |
124 | // directly bitcasted and inserted into the result vector. |
125 | if (arrayType.getElementType() == f16x2Ty || |
126 | arrayType.getElementType() == f32x1Ty) { |
127 | for (unsigned i = 0; i < structType.getBody().size(); i++) { |
128 | Value el = |
129 | rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i); |
130 | el = rewriter.createOrFold<LLVM::BitcastOp>( |
131 | loc, arrayType.getElementType(), el); |
132 | elements.push_back(el); |
133 | } |
134 | } |
135 | |
136 | // The intrinsic returns i32, f64, and f32 values as individual scalars, |
137 | // even when the result is notionally a 64-bit wide element (e.g. f32x2). We |
138 | // need to extract them from the struct and pack them into the 64-bit wide |
139 | // rows of the vector result. |
140 | if (arrayType.getElementType() == i32x2Ty || |
141 | arrayType.getElementType() == f64x2Ty || |
142 | arrayType.getElementType() == f32x2Ty) { |
143 | |
144 | for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { |
145 | Value vec = |
146 | rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType()); |
147 | Value x1 = |
148 | rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2); |
149 | Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, |
150 | i * 2 + 1); |
151 | vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, |
152 | x1, makeConst(0)); |
153 | vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec, |
154 | x2, makeConst(1)); |
155 | elements.push_back(vec); |
156 | } |
157 | } |
158 | |
159 | // Create the final vectorized result. |
160 | Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType); |
161 | for (const auto &el : llvm::enumerate(elements)) { |
162 | result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(), |
163 | el.index()); |
164 | } |
165 | return result; |
166 | } |
167 | |
168 | return intrinsicResult; |
169 | } |
170 | |
171 | /// The `gpu.mma.sync` converter below expects matrix fragment operands to be |
172 | /// given as 2D `vectors` where the rows are 32b or 64b wide. The |
173 | /// `nvvm.mma.sync` op expects these arguments to be given as a long list of |
174 | /// scalars of certain types. This function helps unpack the `vector` arguments |
175 | /// and cast them to the types expected by `nvvm.mma.sync`. |
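| /// For example, an operand of type `!llvm.array<2 x vector<2xf16>>` is passed |
| /// through as two `vector<2xf16>` values, while a tf32 operand of type |
| /// `!llvm.array<4 x vector<1xf32>>` is bitcast to four i32 scalars. |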
176 | static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b, |
177 | Value operand, |
178 | NVVM::MMATypes operandPtxType) { |
179 | SmallVector<Value> result; |
180 | Type i32Ty = b.getI32Type(); |
181 | Type f64Ty = b.getF64Type(); |
182 | Type f32Ty = b.getF32Type(); |
183 | Type i64Ty = b.getI64Type(); |
184 | Type i8x4Ty = VectorType::get(4, b.getI8Type()); |
185 | Type i4x8Ty = VectorType::get(8, b.getIntegerType(4)); |
186 | Type f32x1Ty = VectorType::get(1, f32Ty); |
187 | auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType()); |
188 | |
189 | for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { |
190 | Value toUse = b.create<LLVM::ExtractValueOp>(operand, i); |
191 | |
192 | // For 4xi8 vectors, the intrinsic expects these to be provided as i32 |
193 | // scalar types. |
194 | if (arrayTy.getElementType() == i8x4Ty || |
195 | arrayTy.getElementType() == i4x8Ty || |
196 | (arrayTy.getElementType() == f32x1Ty && |
197 | operandPtxType == NVVM::MMATypes::tf32)) { |
198 | result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse)); |
199 | continue; |
200 | } |
201 | |
202 | // For some element types (i32, f32, f64), we need to unpack the inner |
203 | // vector/array type as well because the intrinsic expects individual |
204 | // scalars to be provided. |
205 | VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType()); |
206 | if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || |
207 | innerArrayTy.getElementType() == f64Ty || |
208 | innerArrayTy.getElementType() == f32Ty)) { |
209 | for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); |
210 | idx < innerSize; idx++) { |
211 | result.push_back(b.create<LLVM::ExtractElementOp>( |
212 | toUse, |
213 | b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx)))); |
214 | } |
215 | continue; |
216 | } |
217 | result.push_back(toUse); |
218 | } |
219 | return result; |
220 | } |
221 | |
222 | /// Returns whether mbarrier object has shared memory address space. |
223 | static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) { |
224 | return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace( |
225 | barrierType.getMemorySpace())); |
226 | } |
227 | |
228 | /// Returns the memory space attribute of the mbarrier object. |
229 | Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context, |
230 | nvgpu::MBarrierGroupType barrierType) { |
231 | Attribute memorySpace = {}; |
232 | if (isMbarrierShared(barrierType)) { |
233 | memorySpace = |
234 | IntegerAttr::get(IntegerType::get(context, 64), |
235 | nvgpu::NVGPUDialect::kSharedMemoryAddressSpace); |
236 | } |
237 | return memorySpace; |
238 | } |
239 | |
240 | /// Returns memref type of the mbarrier object. The type is defined in the |
241 | /// MBarrierGroupType. |
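| /// For example, a group of 2 mbarriers in shared (workgroup) memory converts |
| /// to `memref<2xi64, 3>`, i.e. one i64 per barrier in address space 3. |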
242 | MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context, |
243 | nvgpu::MBarrierGroupType barrierType) { |
244 | Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType); |
245 | MemRefLayoutAttrInterface layout; |
246 | return MemRefType::get({barrierType.getNumBarriers()}, |
247 | IntegerType::get(context, 64), layout, memorySpace); |
248 | } |
249 | |
250 | namespace { |
251 | |
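| /// Lowers `nvgpu.ldmatrix` to `nvvm.ldmatrix`. The NVVM intrinsic returns |
| /// either a single i32 or a struct of i32 registers; each register is bitcast |
| /// back to a 32-bit vector and repacked into the 2-D vector result. |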
252 | struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { |
253 | using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern; |
254 | |
255 | LogicalResult |
256 | matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor, |
257 | ConversionPatternRewriter &rewriter) const override { |
258 | MLIRContext *ctx = getContext(); |
259 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
260 | |
261 | // The result type of ldmatrix will always be a struct of 32bit integer |
262 | // registers if more than one 32bit value is returned. Otherwise, the result |
263 | // is a single i32. The result type of the GPU operation is always a vector |
264 | // of shape (NumRegisters, VectorRegister) where VectorRegister is the |
265 | // vector type of the result and always 32 bits long. We bitcast the result |
266 | // of the NVVM::LdMatrix to this vector type. |
267 | auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]); |
268 | if (!vectorResultType) { |
269 | return failure(); |
270 | } |
271 | Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1), |
272 | vectorResultType.getElementType()); |
273 | |
274 | int64_t num32BitRegs = vectorResultType.getDimSize(0); |
275 | |
276 | Type ldMatrixResultType; |
277 | if (num32BitRegs > 1) { |
278 | ldMatrixResultType = LLVM::LLVMStructType::getLiteral( |
279 | ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type())); |
280 | } else { |
281 | ldMatrixResultType = rewriter.getI32Type(); |
282 | } |
283 | |
284 | auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType()); |
285 | Value srcPtr = |
286 | getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, |
287 | adaptor.getSrcMemref(), adaptor.getIndices()); |
288 | Value ldMatrixResult = b.create<NVVM::LdMatrixOp>( |
289 | ldMatrixResultType, srcPtr, |
290 | /*num=*/op.getNumTiles(), |
291 | /*layout=*/op.getTranspose() ? NVVM::MMALayout::col |
292 | : NVVM::MMALayout::row); |
293 | |
294 | // The ldmatrix operation returns either a single i32 value or a struct of |
295 | // i32 values. Here we unpack those values and cast them back to their |
296 | // actual vector type (still of width 32b) and repack them into a result |
297 | // struct. |
298 | Type finalResultType = typeConverter->convertType(vectorResultType); |
299 | Value result = b.create<LLVM::PoisonOp>(finalResultType); |
300 | for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { |
301 | Value i32Register = |
302 | num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i) |
303 | : ldMatrixResult; |
304 | Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register); |
305 | result = b.create<LLVM::InsertValueOp>(result, casted, i); |
306 | } |
307 | |
308 | rewriter.replaceOp(op, result); |
309 | return success(); |
310 | } |
311 | }; |
312 | |
313 | /// Convert the given type into the corresponding PTX type (NVVM::MMATypes |
314 | /// enum). |
315 | static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) { |
316 | Type elType = getElementTypeOrSelf(t); |
317 | if (elType.isInteger(8)) |
318 | return NVVM::MMATypes::s8; |
319 | if (elType.isInteger(4)) |
320 | return NVVM::MMATypes::s4; |
321 | if (elType.isF16()) |
322 | return NVVM::MMATypes::f16; |
323 | if (elType.isF64()) |
324 | return NVVM::MMATypes::f64; |
325 | if (elType.isF32()) |
326 | return NVVM::MMATypes::tf32; |
327 | return failure(); |
328 | } |
329 | |
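| /// Lowers `nvgpu.mma.sync` to `nvvm.mma.sync`, unpacking the 2-D vector |
| /// operands into the scalar/vector fragments expected by the NVVM intrinsic |
| /// and repacking the returned struct into the vector result type. |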
330 | struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> { |
331 | using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern; |
332 | |
333 | LogicalResult |
334 | matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor, |
335 | ConversionPatternRewriter &rewriter) const override { |
336 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
337 | // Get the shapes of the MMAMatrix type being used. The shapes will |
338 | // choose which intrinsic this op will be lowered to. |
339 | VectorType aType = op.getMatrixA().getType(); |
340 | VectorType bType = op.getMatrixB().getType(); |
341 | VectorType cType = op.getMatrixC().getType(); |
342 | |
343 | std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray(); |
344 | |
345 | // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32). |
346 | bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); |
347 | if (aType.getElementType().isF32() && !tf32Enabled) |
348 | return failure(); |
349 | |
350 | FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType); |
351 | if (failed(ptxTypeA)) |
352 | return op->emitOpError("failed to deduce operand PTX types"); |
353 | FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType); |
354 | if (failed(ptxTypeB)) |
355 | return op->emitOpError("failed to deduce operand PTX types"); |
356 | std::optional<NVVM::MMATypes> ptxTypeC = |
357 | NVVM::MmaOp::inferOperandMMAType(cType.getElementType(), |
358 | /*isAccumulator=*/true); |
359 | if (!ptxTypeC) |
360 | return op->emitError( |
361 | "could not infer the PTX type for the accumulator/result"); |
362 | |
363 | // TODO: add an attribute to the op to customize this behavior. |
364 | std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt); |
365 | if (isa<IntegerType>(aType.getElementType())) |
366 | overflow = NVVM::MMAIntOverflow::satfinite; |
367 | |
368 | SmallVector<Value> matA = |
369 | unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA); |
370 | SmallVector<Value> matB = |
371 | unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB); |
372 | SmallVector<Value> matC = |
373 | unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC); |
374 | |
375 | Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); |
376 | Type intrinsicResTy = inferIntrinsicResultType( |
377 | typeConverter->convertType(op->getResultTypes()[0])); |
378 | Value intrinsicResult = b.create<NVVM::MmaOp>( |
379 | intrinsicResTy, matA, matB, matC, |
380 | /*shape=*/gemmShape, |
381 | /*b1Op=*/std::nullopt, |
382 | /*intOverflow=*/overflow, |
383 | /*multiplicandPtxTypes=*/ |
384 | std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB}, |
385 | /*multiplicandLayouts=*/ |
386 | std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row, |
387 | NVVM::MMALayout::col}); |
388 | rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, |
389 | desiredRetTy, intrinsicResult, |
390 | rewriter)); |
391 | return success(); |
392 | } |
393 | }; |
394 | |
395 | struct ConvertNVGPUToNVVMPass |
396 | : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> { |
397 | using Base::Base; |
398 | |
399 | void getDependentDialects(DialectRegistry ®istry) const override { |
400 | registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect, |
401 | arith::ArithDialect>(); |
402 | } |
403 | |
404 | void runOnOperation() override { |
405 | LowerToLLVMOptions options(&getContext()); |
406 | RewritePatternSet patterns(&getContext()); |
407 | LLVMTypeConverter converter(&getContext(), options); |
408 | IRRewriter rewriter(&getContext()); |
409 | populateGpuMemorySpaceAttributeConversions( |
410 | converter, [](gpu::AddressSpace space) -> unsigned { |
411 | switch (space) { |
412 | case gpu::AddressSpace::Global: |
413 | return static_cast<unsigned>( |
414 | NVVM::NVVMMemorySpace::kGlobalMemorySpace); |
415 | case gpu::AddressSpace::Workgroup: |
416 | return static_cast<unsigned>( |
417 | NVVM::NVVMMemorySpace::kSharedMemorySpace); |
418 | case gpu::AddressSpace::Private: |
419 | return 0; |
420 | } |
421 | llvm_unreachable("unknown address space enum value"); |
422 | return 0; |
423 | }); |
424 | /// device-side async tokens cannot be materialized in nvvm. We just |
425 | /// convert them to a dummy i32 type in order to easily drop them during |
426 | /// conversion. |
427 | converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type { |
428 | return converter.convertType(IntegerType::get(type.getContext(), 32)); |
429 | }); |
430 | converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type { |
431 | Type elemType = type.getFragmented().getElementType(); |
432 | int64_t sizeM = type.getFragmented().getDimSize(0); |
433 | int64_t sizeN = type.getFragmented().getDimSize(1); |
434 | |
435 | unsigned numMembers; |
436 | if (elemType.isF32() || elemType.isInteger(32)) |
437 | numMembers = sizeN / 2; |
438 | else if (elemType.isF16()) |
439 | numMembers = sizeN / 4; |
440 | else |
441 | llvm_unreachable("unsupported type for warpgroup accumulator"); |
442 | |
443 | SmallVector<Type> innerStructBody; |
444 | for (unsigned i = 0; i < numMembers; i++) |
445 | innerStructBody.push_back(elemType); |
446 | auto innerStructType = |
447 | LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody); |
448 | |
449 | SmallVector<Type> structBody; |
450 | for (int i = 0; i < sizeM; i += kWgmmaSizeM) |
451 | structBody.push_back(innerStructType); |
452 | |
453 | auto convertedType = |
454 | LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); |
455 | return converter.convertType(convertedType); |
456 | }); |
457 | converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { |
458 | return converter.convertType(IntegerType::get(type.getContext(), 64)); |
459 | }); |
460 | converter.addConversion( |
461 | [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { |
462 | return converter.convertType(IntegerType::get(type.getContext(), 64)); |
463 | }); |
464 | converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { |
465 | return converter.convertType( |
466 | nvgpu::getMBarrierMemrefType(rewriter.getContext(), type)); |
467 | }); |
468 | converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type { |
469 | return LLVM::LLVMPointerType::get(type.getContext()); |
470 | }); |
471 | populateNVGPUToNVVMConversionPatterns(converter, patterns); |
472 | LLVMConversionTarget target(getContext()); |
473 | target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); |
474 | target.addLegalDialect<::mlir::arith::ArithDialect>(); |
475 | target.addLegalDialect<::mlir::memref::MemRefDialect>(); |
476 | target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); |
477 | mlir::scf::populateSCFStructuralTypeConversionsAndLegality( |
478 | converter, patterns, target); |
479 | if (failed(applyPartialConversion(getOperation(), target, |
480 | std::move(patterns)))) |
481 | signalPassFailure(); |
482 | } |
483 | }; |
484 | |
485 | /// Returns the constraints for the sparse MMA inline assembly instruction. |
486 | static std::string buildMmaSparseAsmConstraintString(unsigned matASize, |
487 | unsigned matBSize, |
488 | unsigned matCSize) { |
489 | std::string str; |
490 | llvm::raw_string_ostream ss(str); |
491 | for (unsigned i = 0; i < matCSize; i++) |
492 | ss << "=r,"; |
493 | for (unsigned i = 0; i < matASize + matBSize + matCSize; i++) |
494 | ss << "r,"; |
495 | // The final operand is for the sparsity metadata. |
496 | // The sparsity selector appears as a direct literal. |
497 | ss << "r"; |
498 | return str; |
499 | } |
500 | |
501 | /// Returns the string for the `mma.sp.sync` instruction that corresponds to |
502 | /// the given parameters. Note that this function doesn't do any validation, |
503 | /// it's expected that the provided parameters correspond to a valid |
504 | /// instruction. |
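| /// For example, a 16x8x32 f16 GEMM with an f32 accumulator and no overflow |
| /// modifier produces a string of the form |
| ///   mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 |
| ///     {...},{...},{...},{...},$N,0x0; |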
505 | static std::string buildMmaSparseAsmString( |
506 | const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize, |
507 | unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, |
508 | NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, |
509 | std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) { |
510 | auto ptxTypeStr = [](NVVM::MMATypes ptxType) { |
511 | return NVVM::stringifyMMATypes(ptxType); |
512 | }; |
513 | |
514 | std::string asmStr; |
515 | llvm::raw_string_ostream ss(asmStr); |
516 | ss << "mma.sp.sync.aligned.m"<< shape[0] << "n"<< shape[1] << "k" |
517 | << shape[2] << ".row.col."; |
518 | |
519 | if (overflow) |
520 | ss << NVVM::stringifyMMAIntOverflow(*overflow) << "."; |
521 | |
522 | ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "." |
523 | << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " "; |
524 | unsigned asmArgIdx = 0; |
525 | |
526 | // The operand string is structured into sections `{matC elements...}, |
527 | // {matA elements...}, {matB elements...}, {matC elements}`. |
528 | for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) { |
529 | ss << "{"; |
530 | for (unsigned i = 0; i < arrSize; i++) |
531 | ss << "$"<< asmArgIdx++ << (i < arrSize - 1 ? ",": ""); |
532 | ss << "},"; |
533 | } |
534 | ss << "$"<< asmArgIdx++ << ","; |
535 | assert(metaDataSelector <= 1); |
536 | ss << "0x"<< metaDataSelector << ";"; |
537 | return asmStr; |
538 | } |
539 | |
540 | /// Builds an inline assembly operation corresponding to the specified MMA |
541 | /// sparse sync operation. |
542 | static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm( |
543 | ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, |
544 | NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, |
545 | std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData, |
546 | ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData, |
547 | int64_t metadataSelector, const std::array<int64_t, 3> &shape, |
548 | Type intrinsicResultType) { |
549 | auto asmDialectAttr = |
550 | LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT); |
551 | |
552 | const unsigned matASize = unpackedAData.size(); |
553 | const unsigned matBSize = unpackedB.size(); |
554 | const unsigned matCSize = unpackedC.size(); |
555 | |
556 | std::string asmStr = buildMmaSparseAsmString( |
557 | shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC, |
558 | ptxTypeD, overflow, metadataSelector); |
559 | std::string constraintStr = |
560 | buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize); |
561 | |
562 | SmallVector<Value> asmVals; |
563 | asmVals.reserve(matASize + matBSize + matCSize + 1); |
564 | for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC}) |
565 | llvm::append_range(asmVals, args); |
566 | asmVals.push_back(indexData); |
567 | |
568 | return b.create<LLVM::InlineAsmOp>( |
569 | /*resultTypes=*/intrinsicResultType, |
570 | /*operands=*/asmVals, |
571 | /*asm_string=*/asmStr, |
572 | /*constraints=*/constraintStr, |
573 | /*has_side_effects=*/true, |
574 | /*is_align_stack=*/false, LLVM::TailCallKind::None, |
575 | /*asm_dialect=*/asmDialectAttr, |
576 | /*operand_attrs=*/ArrayAttr()); |
577 | } |
578 | |
579 | /// Lowers `nvgpu.mma.sp.sync` to inline assembly. |
580 | struct NVGPUMmaSparseSyncLowering |
581 | : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> { |
582 | using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern; |
583 | |
584 | LogicalResult |
585 | matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor, |
586 | ConversionPatternRewriter &rewriter) const override { |
587 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
588 | // Get the shapes of the MMAMatrix type being used. The shapes will |
589 | // choose which intrinsic this op will be lowered to. |
590 | VectorType aType = op.getMatrixA().getType(); |
591 | VectorType bType = op.getMatrixB().getType(); |
592 | VectorType cType = op.getMatrixC().getType(); |
593 | |
594 | FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType); |
595 | if (failed(ptxTypeA)) |
596 | return op->emitOpError("failed to deduce operand PTX types"); |
597 | FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType); |
598 | if (failed(ptxTypeB)) |
599 | return op->emitOpError("failed to deduce operand PTX types"); |
600 | std::optional<NVVM::MMATypes> ptxTypeC = |
601 | NVVM::MmaOp::inferOperandMMAType(cType.getElementType(), |
602 | /*isAccumulator=*/true); |
603 | if (!ptxTypeC) |
604 | return op->emitError( |
605 | "could not infer the PTX type for the accumulator/result"); |
606 | |
607 | // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32). |
608 | bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); |
609 | if (aType.getElementType().isF32() && !tf32Enabled) |
610 | return failure(); |
611 | |
612 | // TODO: add an attribute to the op to customize this behavior. |
613 | std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt); |
614 | if (isa<IntegerType>(aType.getElementType())) |
615 | overflow = NVVM::MMAIntOverflow::satfinite; |
616 | |
617 | SmallVector<Value> matA = |
618 | unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA); |
619 | SmallVector<Value> matB = |
620 | unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB); |
621 | SmallVector<Value> matC = |
622 | unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC); |
623 | |
624 | Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); |
625 | Type intrinsicResTy = inferIntrinsicResultType( |
626 | typeConverter->convertType(op->getResultTypes()[0])); |
627 | |
628 | // Bitcast the sparse metadata from vector<2xi16> to an i32. |
629 | Value sparseMetadata = adaptor.getSparseMetadata(); |
630 | if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type())) |
631 | return op->emitOpError() << "Expected metadata type to be LLVM " |
632 | "VectorType of 2 i16 elements"; |
633 | sparseMetadata = |
634 | b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata); |
635 | |
636 | FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm( |
637 | b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, |
638 | matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(), |
639 | intrinsicResTy); |
640 | if (failed(intrinsicResult)) |
641 | return failure(); |
642 | |
643 | assert((*intrinsicResult).getNumResults() == 1 && |
644 | "expected inline asm op returns a single LLVM struct type"); |
645 | rewriter.replaceOp( |
646 | op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, |
647 | (*intrinsicResult)->getResult(0), rewriter)); |
648 | return success(); |
649 | } |
650 | }; |
651 | |
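| /// Lowers `nvgpu.device_async_copy` to `nvvm.cp.async`, computing the copy |
| /// size in bytes from the destination element count and element bit-width. |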
652 | struct NVGPUAsyncCopyLowering |
653 | : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> { |
654 | using ConvertOpToLLVMPattern< |
655 | nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern; |
656 | |
657 | LogicalResult |
658 | matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor, |
659 | ConversionPatternRewriter &rewriter) const override { |
660 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
661 | Location loc = op.getLoc(); |
662 | auto dstMemrefType = cast<MemRefType>(op.getDst().getType()); |
663 | Value dstPtr = |
664 | getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType, |
665 | adaptor.getDst(), adaptor.getDstIndices()); |
666 | FailureOr<unsigned> dstAddressSpace = |
667 | getTypeConverter()->getMemRefAddressSpace(dstMemrefType); |
668 | if (failed(dstAddressSpace)) |
669 | return rewriter.notifyMatchFailure( |
670 | loc, "destination memref address space not convertible to integer"); |
671 | |
672 | auto srcMemrefType = cast<MemRefType>(op.getSrc().getType()); |
673 | FailureOr<unsigned> srcAddressSpace = |
674 | getTypeConverter()->getMemRefAddressSpace(srcMemrefType); |
675 | if (failed(srcAddressSpace)) |
676 | return rewriter.notifyMatchFailure( |
677 | loc, "source memref address space not convertible to integer"); |
678 | |
679 | Value scrPtr = |
680 | getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(), |
681 | adaptor.getSrcIndices()); |
682 | // Intrinsics takes a global pointer so we need an address space cast. |
683 | auto srcPointerGlobalType = LLVM::LLVMPointerType::get( |
684 | op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace); |
685 | scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr); |
686 | int64_t dstElements = adaptor.getDstElements().getZExtValue(); |
687 | int64_t sizeInBytes = |
688 | (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8; |
689 | // When the optional SrcElements argument is *not* present, the regular |
690 | // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global |
691 | // memory) to fill DstElements number of elements in the destination |
692 | // (shared memory). |
693 | Value srcBytes = adaptor.getSrcElements(); |
694 | if (srcBytes) { |
695 | // When the optional SrcElements argument is present, the source (global |
696 | // memory) of CpAsyncOp is read only for SrcElements number of elements. |
697 | // The rest of the DstElements in the destination (shared memory) are |
698 | // filled with zeros. |
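| // srcBytes = (elementBitWidth * srcElements) / 8, computed below with a |
| // multiply and a right shift by 3. |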
699 | Value c3I32 = |
700 | b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3)); |
701 | Value bitwidth = b.create<LLVM::ConstantOp>( |
702 | b.getI32Type(), |
703 | b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth())); |
704 | Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes); |
705 | srcBytes = b.create<LLVM::LShrOp>( |
706 | b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32); |
707 | } |
708 | // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than |
709 | // 16 dst bytes. |
710 | NVVM::LoadCacheModifierKind cacheModifier = |
711 | (op.getBypassL1().value_or(false) && sizeInBytes == 16) |
712 | ? NVVM::LoadCacheModifierKind::CG |
713 | : NVVM::LoadCacheModifierKind::CA; |
714 | |
715 | b.create<NVVM::CpAsyncOp>( |
716 | dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), |
717 | NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier), |
718 | srcBytes); |
719 | |
720 | // Drop the result token. |
721 | Value zero = b.create<LLVM::ConstantOp>( |
722 | IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0)); |
723 | rewriter.replaceOp(op, zero); |
724 | return success(); |
725 | } |
726 | }; |
727 | |
728 | struct NVGPUAsyncCreateGroupLowering |
729 | : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> { |
730 | using ConvertOpToLLVMPattern< |
731 | nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern; |
732 | |
733 | LogicalResult |
734 | matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor, |
735 | ConversionPatternRewriter &rewriter) const override { |
736 | rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc()); |
737 | // Drop the result token. |
738 | Value zero = rewriter.create<LLVM::ConstantOp>( |
739 | op->getLoc(), IntegerType::get(op.getContext(), 32), |
740 | rewriter.getI32IntegerAttr(0)); |
741 | rewriter.replaceOp(op, zero); |
742 | return success(); |
743 | } |
744 | }; |
745 | |
746 | struct NVGPUAsyncWaitLowering |
747 | : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> { |
748 | using ConvertOpToLLVMPattern< |
749 | nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern; |
750 | |
751 | LogicalResult |
752 | matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor, |
753 | ConversionPatternRewriter &rewriter) const override { |
754 | // If numGroups is not present, pick 0 as a conservatively correct value. |
755 | int32_t numGroups = adaptor.getNumGroups().value_or(0); |
756 | rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups); |
757 | rewriter.eraseOp(op); |
758 | return success(); |
759 | } |
760 | }; |
761 | |
762 | /// Creates mbarrier object in shared memory |
763 | struct NVGPUMBarrierCreateLowering |
764 | : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> { |
765 | using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern; |
766 | |
767 | template <typename moduleT> |
768 | memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter, |
769 | Operation *funcOp, moduleT moduleOp, |
770 | MemRefType barrierType) const { |
771 | SymbolTable symbolTable(moduleOp); |
772 | OpBuilder::InsertionGuard guard(rewriter); |
773 | rewriter.setInsertionPoint(&moduleOp.front()); |
774 | auto global = rewriter.create<memref::GlobalOp>( |
775 | funcOp->getLoc(), "__mbarrier", |
776 | /*sym_visibility=*/rewriter.getStringAttr("private"), |
777 | /*type=*/barrierType, |
778 | /*initial_value=*/ElementsAttr(), |
779 | /*constant=*/false, |
780 | /*alignment=*/rewriter.getI64IntegerAttr(8)); |
781 | symbolTable.insert(global); |
782 | return global; |
783 | } |
784 | |
785 | LogicalResult |
786 | matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor, |
787 | ConversionPatternRewriter &rewriter) const override { |
788 | Operation *funcOp = op->getParentOp(); |
789 | MemRefType barrierType = nvgpu::getMBarrierMemrefType( |
790 | rewriter.getContext(), op.getBarriers().getType()); |
791 | |
792 | memref::GlobalOp global; |
793 | if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>()) |
794 | global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType); |
795 | else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>()) |
796 | global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType); |
797 | |
798 | rewriter.setInsertionPoint(op); |
799 | rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType, |
800 | global.getName()); |
801 | return success(); |
802 | } |
803 | }; |
804 | |
805 | /// Base class for lowering mbarrier operations to nvvm intrinsics. |
806 | template <typename SourceOp> |
807 | struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> { |
808 | public: |
809 | using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; |
810 | /// Returns the base pointer of the mbarrier object. |
811 | Value getMbarrierPtr(ImplicitLocOpBuilder &b, |
812 | nvgpu::MBarrierGroupType mbarType, Value memrefDesc, |
813 | Value mbarId, |
814 | ConversionPatternRewriter &rewriter) const { |
815 | MemRefType mbarrierMemrefType = |
816 | nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType); |
817 | return ConvertToLLVMPattern::getStridedElementPtr( |
818 | rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}); |
819 | } |
820 | }; |
821 | |
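| /// Lowers nvgpu::MBarrierGetOp: computes the address of the requested |
| /// mbarrier and exposes it as an integer via `llvm.ptrtoint`. |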
822 | struct NVGPUMBarrierGetLowering |
823 | : public MBarrierBasePattern<nvgpu::MBarrierGetOp> { |
824 | using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern; |
825 | |
826 | LogicalResult |
827 | matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor, |
828 | ConversionPatternRewriter &rewriter) const override { |
829 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
830 | nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType(); |
831 | rewriter.setInsertionPoint(op); |
832 | Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(), |
833 | adaptor.getMbarId(), rewriter); |
834 | Type resType = op.getMbarrierPointer().getType(); |
835 | rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier); |
836 | return success(); |
837 | } |
838 | }; |
839 | |
840 | /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init` |
841 | struct NVGPUMBarrierInitLowering |
842 | : public MBarrierBasePattern<nvgpu::MBarrierInitOp> { |
843 | using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern; |
844 | |
845 | LogicalResult |
846 | matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor, |
847 | ConversionPatternRewriter &rewriter) const override { |
848 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
849 | nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType(); |
850 | rewriter.setInsertionPoint(op); |
851 | Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(), |
852 | adaptor.getMbarId(), rewriter); |
853 | Value count = truncToI32(b, adaptor.getCount()); |
854 | if (isMbarrierShared(mbarrierType)) { |
855 | rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>( |
856 | op, barrier, count, adaptor.getPredicate()); |
857 | } else { |
858 | rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, |
859 | adaptor.getPredicate()); |
860 | } |
861 | return success(); |
862 | } |
863 | }; |
864 | |
865 | /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive` |
866 | struct NVGPUMBarrierArriveLowering |
867 | : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> { |
868 | using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern; |
869 | LogicalResult |
870 | matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor, |
871 | ConversionPatternRewriter &rewriter) const override { |
872 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
873 | Value barrier = |
874 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
875 | adaptor.getMbarId(), rewriter); |
876 | Type tokenType = getTypeConverter()->convertType( |
877 | nvgpu::MBarrierTokenType::get(op->getContext())); |
878 | if (isMbarrierShared(op.getBarriers().getType())) { |
879 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType, |
880 | barrier); |
881 | } else { |
882 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, |
883 | barrier); |
884 | } |
885 | return success(); |
886 | } |
887 | }; |
888 | |
889 | /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to |
890 | /// `nvvm.mbarrier.arrive.nocomplete` |
891 | struct NVGPUMBarrierArriveNoCompleteLowering |
892 | : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> { |
893 | using MBarrierBasePattern< |
894 | nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern; |
895 | LogicalResult |
896 | matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor, |
897 | ConversionPatternRewriter &rewriter) const override { |
898 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
899 | Value barrier = |
900 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
901 | adaptor.getMbarId(), rewriter); |
902 | Type tokenType = getTypeConverter()->convertType( |
903 | nvgpu::MBarrierTokenType::get(op->getContext())); |
904 | Value count = truncToI32(b, adaptor.getCount()); |
905 | if (isMbarrierShared(op.getBarriers().getType())) { |
906 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>( |
907 | op, tokenType, barrier, count); |
908 | } else { |
909 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( |
910 | op, tokenType, barrier, count); |
911 | } |
912 | return success(); |
913 | } |
914 | }; |
915 | |
916 | /// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait` |
917 | struct NVGPUMBarrierTestWaitLowering |
918 | : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> { |
919 | using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern; |
920 | LogicalResult |
921 | matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor, |
922 | ConversionPatternRewriter &rewriter) const override { |
923 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
924 | Value barrier = |
925 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
926 | adaptor.getMbarId(), rewriter); |
927 | Type retType = rewriter.getI1Type(); |
928 | if (isMbarrierShared(op.getBarriers().getType())) { |
929 | rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>( |
930 | op, retType, barrier, adaptor.getToken()); |
931 | } else { |
932 | rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>( |
933 | op, retType, barrier, adaptor.getToken()); |
934 | } |
935 | return success(); |
936 | } |
937 | }; |
938 | |
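| /// Lowers `nvgpu.mbarrier.arrive.expect_tx` to the NVVM |
| /// mbarrier.arrive.expect_tx op (or its shared-memory variant). |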
939 | struct NVGPUMBarrierArriveExpectTxLowering |
940 | : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> { |
941 | using MBarrierBasePattern< |
942 | nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern; |
943 | LogicalResult |
944 | matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor, |
945 | ConversionPatternRewriter &rewriter) const override { |
946 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
947 | Value barrier = |
948 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
949 | adaptor.getMbarId(), rewriter); |
950 | Value txcount = truncToI32(b, adaptor.getTxcount()); |
951 | |
952 | if (isMbarrierShared(op.getBarriers().getType())) { |
953 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>( |
954 | op, barrier, txcount, adaptor.getPredicate()); |
955 | return success(); |
956 | } |
957 | |
958 | rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>( |
959 | op, barrier, txcount, adaptor.getPredicate()); |
960 | return success(); |
961 | } |
962 | }; |
963 | |
964 | struct NVGPUMBarrierTryWaitParityLowering |
965 | : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> { |
966 | using MBarrierBasePattern< |
967 | nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern; |
968 | LogicalResult |
969 | matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor, |
970 | ConversionPatternRewriter &rewriter) const override { |
971 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
972 | Value barrier = |
973 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
974 | adaptor.getMbarId(), rewriter); |
975 | Value ticks = truncToI32(b, adaptor.getTicks()); |
976 | Value phase = |
977 | b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity()); |
978 | |
979 | if (isMbarrierShared(op.getBarriers().getType())) { |
980 | rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( |
981 | op, barrier, phase, ticks); |
982 | return success(); |
983 | } |
984 | |
985 | rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier, |
986 | phase, ticks); |
987 | return success(); |
988 | } |
989 | }; |
990 | |
991 | struct NVGPUTmaAsyncLoadOpLowering |
992 | : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> { |
993 | using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern; |
994 | LogicalResult |
995 | matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor, |
996 | ConversionPatternRewriter &rewriter) const override { |
997 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
998 | auto srcMemrefType = cast<MemRefType>(op.getDst().getType()); |
999 | Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType, |
1000 | adaptor.getDst(), {}); |
1001 | Value barrier = |
1002 | getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), |
1003 | adaptor.getMbarId(), rewriter); |
1004 | |
1005 | SmallVector<Value> coords = adaptor.getCoordinates(); |
1006 | for (auto [index, value] : llvm::enumerate(coords)) { |
1007 | coords[index] = truncToI32(b, value); |
1008 | } |
1009 | rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>( |
1010 | op, dest, adaptor.getTensorMapDescriptor(), coords, barrier, |
1011 | ValueRange{}, adaptor.getMulticastMask(), Value{}, |
1012 | adaptor.getPredicate()); |
1013 | return success(); |
1014 | } |
1015 | }; |
1016 | |
1017 | struct NVGPUTmaAsyncStoreOpLowering |
1018 | : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> { |
1019 | using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern; |
1020 | LogicalResult |
1021 | matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor, |
1022 | ConversionPatternRewriter &rewriter) const override { |
1023 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1024 | auto srcMemrefType = cast<MemRefType>(op.getSrc().getType()); |
1025 | Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType, |
1026 | adaptor.getSrc(), {}); |
1027 | SmallVector<Value> coords = adaptor.getCoordinates(); |
1028 | for (auto [index, value] : llvm::enumerate(coords)) { |
1029 | coords[index] = truncToI32(b, value); |
1030 | } |
1031 | |
1032 | rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>( |
1033 | op, adaptor.getTensorMapDescriptor(), dest, coords, |
1034 | adaptor.getPredicate()); |
1035 | return success(); |
1036 | } |
1037 | }; |
1038 | |
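| /// Lowers `nvgpu.warpgroup.generate.descriptor` to an i64 wgmma matrix |
| /// descriptor whose bit layout (built below) is: [0,14) start address, |
| /// [16,30) leading dimension, [32,46) stride, [49,52) base offset, and |
| /// [62,64) swizzle type. |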
1039 | struct NVGPUGenerateWarpgroupDescriptorLowering |
1040 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> { |
1041 | using ConvertOpToLLVMPattern< |
1042 | nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern; |
1043 | |
1044 | LogicalResult |
1045 | matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor, |
1046 | ConversionPatternRewriter &rewriter) const override { |
1047 | |
1048 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1049 | |
1050 | nvgpu::TensorMapSwizzleKind swizzleKind = |
1051 | op.getTensorMap().getType().getSwizzle(); |
1052 | |
1053 | unsigned layout = |
1054 | (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128 |
1055 | : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64 |
1056 | : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32 |
1057 | : 1; |
1058 | unsigned swizzle = |
1059 | (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1 |
1060 | : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2 |
1061 | : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3 |
1062 | : 0; |
1063 | |
1064 | auto ti64 = b.getIntegerType(64); |
1065 | auto makeConst = [&](uint64_t index) -> Value { |
1066 | return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index)); |
1067 | }; |
1068 | auto shiftLeft = [&](Value value, unsigned shift) -> Value { |
1069 | return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift)); |
1070 | }; |
1071 | auto shiftRight = [&](Value value, unsigned shift) -> Value { |
1072 | return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift)); |
1073 | }; |
1074 | auto insertBit = [&](Value desc, Value val, int startBit) { |
1075 | return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit)); |
1076 | }; |
1077 | |
1078 | int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); |
1079 | uint64_t strideDimVal = (layout << 3) >> exclude4LSB; |
1080 | uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB; |
1081 | uint64_t offsetVal = 0; |
1082 | |
1083 | Value strideDim = makeConst(strideDimVal); |
1084 | Value leadDim = makeConst(leadDimVal); |
1085 | |
1086 | Value baseAddr = getStridedElementPtr( |
1087 | rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()), |
1088 | adaptor.getTensor(), {}); |
1089 | Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr); |
1090 | // Just use 14 bits for base address |
1091 | Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); |
1092 | |
1093 | int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32, |
1094 | startLeadBit = 16, startBaseAddrBit = 0; |
1095 | Value dsc = makeConst(0); |
1096 | // [62,64) swizzle type |
1097 | dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit); |
1098 | // [49,52) base_offset |
1099 | dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit); |
1100 | // [32,46) stride |
1101 | dsc = insertBit(dsc, strideDim, startStrideBit); |
1102 | // [16,30) leading dimension |
1103 | dsc = insertBit(dsc, leadDim, startLeadBit); |
1104 | // [0,14) start_address |
1105 | dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); |
1106 | |
1107 | LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " |
1108 | << "leading_off:"<< leadDimVal << "\t" |
1109 | << "stride_off :"<< strideDimVal << "\t" |
1110 | << "base_offset:"<< offsetVal << "\t" |
1111 | << "layout_type:"<< swizzle << " (" |
1112 | << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) |
1113 | << ")\n start_addr : "<< baseAddr << "\n"); |
1114 | |
1115 | rewriter.replaceOp(op, dsc); |
1116 | return success(); |
1117 | } |
1118 | }; |
1119 | |
1120 | static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) { |
1121 | return b.create<LLVM::ConstantOp>(b.getIntegerType(64), |
1122 | b.getI32IntegerAttr(index)); |
1123 | } |
1124 | |
1125 | /// Returns a Value that holds data type enum that is expected by CUDA driver. |
1126 | static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) { |
1127 | // Enum is from CUDA driver API |
1128 | // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html |
1129 | enum CUtensorMapDataTypeEnum { |
1130 | CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, |
1131 | CU_TENSOR_MAP_DATA_TYPE_UINT16, |
1132 | CU_TENSOR_MAP_DATA_TYPE_UINT32, |
1133 | CU_TENSOR_MAP_DATA_TYPE_INT32, |
1134 | CU_TENSOR_MAP_DATA_TYPE_UINT64, |
1135 | CU_TENSOR_MAP_DATA_TYPE_INT64, |
1136 | CU_TENSOR_MAP_DATA_TYPE_FLOAT16, |
1137 | CU_TENSOR_MAP_DATA_TYPE_FLOAT32, |
1138 | CU_TENSOR_MAP_DATA_TYPE_FLOAT64, |
1139 | CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, |
1140 | CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, |
1141 | CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, |
1142 | CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ |
1143 | }; |
1144 | |
1145 | if (type.isUnsignedInteger(8)) |
1146 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8); |
1147 | if (type.isUnsignedInteger(16)) |
1148 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16); |
1149 | if (type.isUnsignedInteger(32)) |
1150 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32); |
1151 | if (type.isUnsignedInteger(64)) |
1152 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64); |
1153 | if (type.isSignlessInteger(32)) |
1154 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32); |
1155 | if (type.isSignlessInteger(64)) |
1156 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64); |
1157 | if (type.isF16()) |
1158 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16); |
1159 | if (type.isF32()) |
1160 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32); |
1161 | if (type.isF64()) |
1162 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64); |
1163 | if (type.isBF16()) |
1164 | return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16); |
1165 | |
1166 | llvm_unreachable("Not supported data type"); |
1167 | } |
1168 | |
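| /// Lowers `nvgpu.tma.create.descriptor` to a call to the runtime helper |
| /// `mgpuTensorMapEncodeTiledMemref`, which encodes the TMA descriptor on the |
| /// host. |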
1169 | struct NVGPUTmaCreateDescriptorOpLowering |
1170 | : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> { |
1171 | using ConvertOpToLLVMPattern< |
1172 | nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern; |
1173 | LogicalResult |
1174 | matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor, |
1175 | ConversionPatternRewriter &rewriter) const override { |
1176 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1177 | auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext()); |
1178 | Type llvmInt64Type = IntegerType::get(op->getContext(), 64); |
1179 | |
1180 | Value tensorElementType = |
1181 | elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType()); |
1182 | auto promotedOperands = getTypeConverter()->promoteOperands( |
1183 | b.getLoc(), op->getOperands(), adaptor.getOperands(), b); |
1184 | |
1185 | Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type, |
1186 | makeI64Const(b, 5)); |
1187 | for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) { |
1188 | Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType, |
1189 | boxArrayPtr, makeI64Const(b, index)); |
1190 | b.create<LLVM::StoreOp>(value, gep); |
1191 | } |
1192 | |
1193 | nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType(); |
1194 | // Set Arguments for the function call |
1195 | SmallVector<Value> arguments; |
1196 | arguments.push_back(promotedOperands[0]); // rank |
1197 | arguments.push_back(promotedOperands[1]); // descriptor |
1198 | arguments.push_back(tensorElementType); // data type |
1199 | arguments.push_back( |
1200 | makeI64Const(b, (int)desc.getInterleave())); // interleave |
1201 | arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle |
1202 | arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo |
1203 | arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob |
1204 | arguments.push_back(boxArrayPtr); // box dimensions |
1205 | |
1206 | // Set data types of the arguments |
1207 | SmallVector<Type> argTypes = { |
1208 | llvmInt64Type, /* int64_t tensorRank */ |
1209 | llvmPointerType, /* ptr */ |
1210 | llvmInt64Type, /* int64_t */ |
1211 | llvmInt64Type, /* int64_t */ |
1212 | llvmInt64Type, /* int64_t */ |
1213 | llvmInt64Type, /* int64_t */ |
1214 | llvmInt64Type, /* int64_t */ |
1215 | llvmPointerType /* ptr */ |
1216 | }; |
1217 | FunctionCallBuilder hostRegisterCallBuilder = { |
1218 | "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes}; |
1219 | Value tensorMap = |
1220 | hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult(); |
1221 | |
1222 | rewriter.replaceOp(op, tensorMap); |
1223 | return success(); |
1224 | } |
1225 | }; |
1226 | |
1227 | struct NVGPUWarpgroupMmaOpLowering |
1228 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> { |
1229 | using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern; |
1230 | |
1231 | /// This is a helper class to generate required NVVM Ops for warp-group level |
1232 | /// matrix multiplication. |
1233 | /// When the given GEMM shape is larger than the shape of |
1234 | /// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp |
1235 | /// Op(s), group and execute them asynchronously. The class also handles |
1236 | /// waiting for completion and iterates through WarpgroupMatrixDescriptor to |
1237 | /// create descriptors for each instruction. |
1238 | /// |
1239 | /// For example this is the case when the shape of GEMM is 128x128x128 |
1240 | /// |
1241 | /// nvvm.wgmma.fence.aligned |
1242 | /// |
1243 | /// nvvm.wgmma.mma.async descA, descB |
1244 | /// iterate(descA, descB) |
1245 | /// nvvm.wgmma.mma.async descA, descB |
1246 | /// [6x times more] |
1247 | /// |
1248 | /// nvvm.wgmma.group.sync.aligned |
1249 | /// nvvm.wgmma.wait.group.sync [groupId] |
1250 | /// |
1251 | class WarpgroupGemm { |
1252 | nvgpu::WarpgroupMmaOp op; |
1253 | ImplicitLocOpBuilder b; |
1254 | OpAdaptor adaptor; |
1255 | |
1256 | // Entire shape of the given Op |
1257 | int64_t totalM, totalN, totalK; |
1258 | |
1259 | // Shape of one wgmma instruction |
1260 | int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0; |
1261 | |
1262 | // Iteration counts for GEMM |
1263 | int iterationM = 0, iterationN = 0, iterationK = 0; |
1264 | |
1265 | /// Computes the shape of a single wgmma instruction, as defined in the PTX |
1266 | /// programming guide: |
1267 | /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape |
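| /// For example (illustrative): for f16 or bf16 inputs and sizeN = 128 this |
| /// selects a wgmma tile of m = 64, n = 128, k = 16. |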
1268 | void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) { |
1269 | wgmmaM = 64; |
1270 | wgmmaN = sizeN; |
1271 | if (inputElemType.isTF32()) { |
1272 | wgmmaK = 8; |
1273 | } else if (inputElemType.isF16() || inputElemType.isBF16()) { |
1274 | wgmmaK = 16; |
1275 | } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) || |
1276 | inputElemType.isInteger(16)) { |
1277 | wgmmaK = 32; |
1278 | } else if (inputElemType.isInteger(1)) { |
1279 | wgmmaK = 256; |
1280 | } else { |
1281 | llvm_unreachable("unsupported K shape"); |
1282 | } |
1283 | LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM |
1284 | << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); |
1285 | } |
1286 | |
1287 | /// Generates a WGMMATypesAttr from an MLIR type. When `useF32` is set (as it |
| /// is for the accumulator type), f32/tf32 map to f32 rather than tf32. |
1288 | NVVM::WGMMATypesAttr generateWgmmaType(Type type, |
1289 | bool useF32 = false) const { |
1290 | auto getWgmmaType = [=](Type elemType) { |
1291 | if (elemType.isF32() || elemType.isTF32()) |
1292 | return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32; |
1293 | if (elemType.isF16()) |
1294 | return NVVM::WGMMATypes::f16; |
1295 | if (elemType.isBF16()) |
1296 | return NVVM::WGMMATypes::bf16; |
1297 | if (isa<Float8E4M3FNType>(elemType)) |
1298 | return NVVM::WGMMATypes::e4m3; |
1299 | if (isa<Float8E5M2Type>(elemType)) |
1300 | return NVVM::WGMMATypes::e5m2; |
1301 | if (elemType.isInteger(1)) |
1302 | return NVVM::WGMMATypes::b1; |
1303 | if (elemType.isInteger(8)) |
1304 | return NVVM::WGMMATypes::s8; |
1305 | if (elemType.isUnsignedInteger(8)) |
1306 | return NVVM::WGMMATypes::u8; |
1307 | if (elemType.isInteger(32)) |
1308 | return NVVM::WGMMATypes::s32; |
1309 | llvm_unreachable("unsupported type"); |
1310 | }; |
1311 | return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type)); |
1312 | } |
1313 | |
1314 | /// Generates layout attribute for the input matrix for wgmma instruction |
1315 | NVVM::MMALayoutAttr |
1316 | generateWgmmaLayout(std::optional<bool> transpose) const { |
1317 | if (transpose.value_or(false)) |
1318 | return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col); |
1319 | return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row); |
1320 | } |
1321 | |
1322 | /// Generates shape attribute for wgmma instruction |
1323 | NVVM::MMAShapeAttr generateWgmmaShape() const { |
1324 | return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK); |
1325 | } |
1326 | |
1327 | /// Generates scale attributes of output matrix for wgmma instruction |
1328 | NVVM::WGMMAScaleOutAttr generateScaleOut() const { |
1329 | return NVVM::WGMMAScaleOutAttr::get(op->getContext(), |
1330 | NVVM::WGMMAScaleOut::one); |
1331 | } |
1332 | /// Generates scale attributes of input matrix for wgmma instruction |
1333 | NVVM::WGMMAScaleInAttr generateScaleIn() const { |
1334 | return NVVM::WGMMAScaleInAttr::get(op->getContext(), |
1335 | NVVM::WGMMAScaleIn::one); |
1336 | } |
1337 | |
1338 | /// Basic function to generate Add |
1339 | Value makeAdd(Value lhs, Value rhs) { |
1340 | return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); |
1341 | }; |
1342 | |
1343 | /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. |
1344 | /// Currently, it only handles row-major. |
1345 | /// |
1346 | /// It moves the pointer like below for [128][64] size: |
1347 | /// +2 +4 +6 |
1348 | /// ↓ ↓ ↓ |
1349 | /// descA ---> +--+--+--+--+ |
1350 | /// |->|->|->|->| |
1351 | /// | | | | | |
1352 | /// | | | | | |
1353 | /// | | | | | |
1354 | /// descA+512---> +-----------+ |
1355 | /// | | | | | |
1356 | /// | | | | | |
1357 | /// | | | | | |
1358 | /// | | | | | |
1359 | /// +-----------+ |
1360 | /// |
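| /// Worked example (illustrative, assuming f16 elements, i.e. 2 bytes, for the |
| /// [128][64] A tile drawn above, so totalK = tileShapeA = 64 and wgmmaK = 16): |
| ///   i = 0, k = 1 -> (16 * 1) * 2 = 32 bytes,  32 >> exclude4LSB = +2 |
| ///   i = 1, k = 0 -> (64 * 64) * 2 = 8192 bytes, 8192 >> exclude4LSB = +512 |
| /// which matches the "+2" and "descA+512" annotations in the sketch. |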
1361 | Value iterateDescriptorA(Value desc, int i, int j, int k) { |
1362 | MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor(); |
1363 | Type elemA = matrixTypeA.getElementType(); |
1364 | int byte = elemA.getIntOrFloatBitWidth() / 8; |
1365 | int tileShapeA = matrixTypeA.getDimSize(1); |
1366 | int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; |
1367 | incrementVal = incrementVal >> exclude4LSB; |
1368 | LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k |
1369 | << "] [wgmma descriptors] Descriptor A + " |
1370 | << incrementVal << " | \t "); |
1371 | if (!incrementVal) |
1372 | return desc; |
1373 | return makeAdd(desc, makeI64Const(b, incrementVal)); |
1374 | } |
1375 | |
1376 | /// Moves the descriptor pointer of matrix-B for the next wgmma instruction. |
1377 | /// Currently, it only handles column-major. |
1378 | /// |
1379 | /// It moves the pointer like below for [128][64] size: |
1380 | /// descB ---> +--+--+--+--+--+--+--+--+ |
1381 | /// |↓ | | | | | | | | |
1382 | /// |↓ | | | | | | | | |
1383 | /// |↓ | | | | | | | | |
1384 | /// |↓ | | | | | | | | |
1385 | /// +--+--+--+--+--+--+--+--+ |
1386 | /// |
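| /// Worked example (illustrative, assuming f16 elements, i.e. 2 bytes, with |
| /// dimSize(0) = 64 and wgmmaK = 16): k = 1 gives 64 * 16 * 2 = 2048 bytes, |
| /// i.e. a descriptor increment of 2048 >> exclude4LSB = +128. |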
1387 | Value iterateDescriptorB(Value desc, int i, int j, int k) { |
1388 | MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor(); |
1389 | Type elemB = matrixTypeB.getElementType(); |
1390 | int byte = elemB.getIntOrFloatBitWidth() / 8; |
1391 | int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; |
1392 | incrementVal = incrementVal >> exclude4LSB; |
1393 | LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); |
1394 | if (!incrementVal) |
1395 | return desc; |
1396 | return makeAdd(desc, makeI64Const(b, incrementVal)); |
1397 | } |
1398 | |
1399 | /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix |
1400 | /// descriptors and arranges them based on induction variables: i, j, and k. |
1401 | Value generateWgmma(int i, int j, int k, Value matrixC) { |
1402 | LLVM_DEBUG(DBGS() << "\t wgmma." |
1403 | << "m"<< wgmmaM << "n"<< wgmmaN << "k"<< wgmmaK |
1404 | << "(A["<< (iterationM * wgmmaM) << ":" |
1405 | << (iterationM * wgmmaM) + wgmmaM << "][" |
1406 | << (iterationK * wgmmaK) << ":" |
1407 | << (iterationK * wgmmaK + wgmmaK) << "] * " |
1408 | << " B["<< (iterationK * wgmmaK) << ":" |
1409 | << (iterationK * wgmmaK + wgmmaK) << "]["<< 0 << ":" |
1410 | << wgmmaN << "])\n"); |
1411 | |
1412 | Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); |
1413 | Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); |
1414 | |
1415 | Type elemA = op.getDescriptorA().getType().getTensor().getElementType(); |
1416 | NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA); |
1417 | |
1418 | Type elemB = op.getDescriptorB().getType().getTensor().getElementType(); |
1419 | NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB); |
1420 | |
1421 | Type elemD = op.getMatrixC().getType().getFragmented().getElementType(); |
1422 | NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true); |
1423 | |
1424 | NVVM::MMAShapeAttr shape = generateWgmmaShape(); |
1425 | NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut(); |
1426 | NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn(); |
1427 | NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA()); |
1428 | NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB()); |
1429 | |
1430 | auto overflow = NVVM::MMAIntOverflowAttr::get( |
1431 | op->getContext(), NVVM::MMAIntOverflow::wrapped); |
1432 | |
1433 | return b.create<NVVM::WgmmaMmaAsyncOp>( |
1434 | matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, |
1435 | itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, |
1436 | overflow); |
1437 | } |
1438 | |
1439 | /// Generates multiple wgmma instructions to complete the given GEMM shape |
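| /// (Illustrative) for the 128x128x128 example above, assuming f16 inputs, |
| /// this emits iterationM x iterationN x iterationK = 2 x 1 x 8 = 16 |
| /// WgmmaMmaAsyncOp ops in total. |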
1440 | Value generateWgmmaGroup() { |
1441 | Value wgmmaResult = |
1442 | b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType()); |
1443 | |
1444 | // Perform GEMM |
1445 | SmallVector<Value> wgmmaResults; |
1446 | for (int i = 0; i < iterationM; ++i) { |
1447 | Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i); |
1448 | for (int j = 0; j < iterationN; ++j) |
1449 | for (int k = 0; k < iterationK; ++k) |
1450 | matrixC = generateWgmma(i, j, k, matrixC); |
1451 | wgmmaResults.push_back(matrixC); |
1452 | } |
1453 | for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) { |
1454 | wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(), |
1455 | wgmmaResult, matrix, idx); |
1456 | } |
1457 | return wgmmaResult; |
1458 | } |
1459 | |
1460 | public: |
1461 | WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b, |
1462 | OpAdaptor adaptor) |
1463 | : op(op), b(b), adaptor(adaptor) { |
1464 | // Find the entire GEMM Shape |
1465 | totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); |
1466 | totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); |
1467 | totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); |
1468 | LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN |
1469 | << "] += A[" << totalM << "][" << totalK << "] * B[" |
1470 | << totalK << "][" << totalN << "] ---===\n"); |
1471 | |
1472 | // Find the shape for one wgmma instruction |
1473 | findWgmmaShape( |
1474 | totalM, totalN, |
1475 | op.getDescriptorA().getType().getTensor().getElementType()); |
1476 | |
1477 | // Iterations counts to complete the given shape with wgmma shape |
1478 | iterationM = totalM / wgmmaM; |
1479 | iterationN = totalN / wgmmaN; |
1480 | iterationK = totalK / wgmmaK; |
1481 | } |
1482 | |
1483 | /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. It |
1484 | /// first emits a fence (WgmmaFenceAlignedOp) before the instructions, then |
1485 | /// commits the generated instructions as a group |
1486 | /// (WgmmaGroupSyncAlignedOp), and finally waits for the group to complete |
1487 | /// (WgmmaWaitGroupSyncOp). |
1488 | Value generateWarpgroupMma() { |
1489 | b.create<NVVM::WgmmaFenceAlignedOp>(); |
1490 | Value wgmmaResult = generateWgmmaGroup(); |
1491 | b.create<NVVM::WgmmaGroupSyncAlignedOp>(); |
1492 | b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup()); |
1493 | return wgmmaResult; |
1494 | } |
1495 | }; |
1496 | LogicalResult |
1497 | matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, |
1498 | ConversionPatternRewriter &rewriter) const override { |
1499 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1500 | |
1501 | // Step 1. Build a helper class |
1502 | WarpgroupGemm warpgroupGemm(op, b, adaptor); |
1503 | |
1504 | // Step 2. Generate wgmma instructions for the entire GEMM shape |
1505 | Value wgmmaResult = warpgroupGemm.generateWarpgroupMma(); |
1506 | |
1507 | // Step 3. Replace fragmented result struct with the op results |
1508 | rewriter.replaceOp(op, wgmmaResult); |
1509 | return success(); |
1510 | } |
1511 | }; |
1512 | |
1513 | struct NVGPUWarpgroupMmaStoreOpLowering |
1514 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> { |
1515 | using ConvertOpToLLVMPattern< |
1516 | nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern; |
1517 | |
1518 | /// This function stores a fragmented register matrix owned by a warp group |
1519 | /// (128 threads) into a memref. Each thread holds 64 registers, packed into |
1520 | /// a single struct value. |
1521 | /// Here is what each thread (T) holds; each `d` is a struct element with a |
1522 | /// number. |
1523 | /// |
1524 | /// Threads in the warp-group (128 threads) and what they own in matrix D: |
1525 | /// 0-31 Warp-0 -> MatrixD[0:15 ][0:N] |
1526 | /// 32-63 Warp-1 -> MatrixD[16:31][0:N] |
1527 | /// 64-95 Warp-2 -> MatrixD[32:47][0:N] |
1528 | /// 96-127 Warp-3 -> MatrixD[48:63][0:N] |
1529 | /// |
1530 | /// Matrix-D: |
1531 | /// +______________________________________________________________________+ |
1532 | /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 | |
1533 | /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY| |
1534 | /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY| |
1535 | /// ..| .........|.........|.........|.........|........|...........|........| |
1536 | /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW| |
1537 | /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW| |
1538 | /// ..| .........|.........|.........|.........|........|...........|........| |
1539 | /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........| |
1540 | /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........| |
1541 | /// ..| .........|.........|.........|.........|........|...........|........| |
1542 | /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........| |
1543 | /// ..| .........|.........|.........|.........|........|...........|........| |
1544 | /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........| |
1545 | /// ..| .........|.........|.........|.........|........|...........|........| |
1546 | /// +______________________________________________________________________+ |
1547 | /// |
1548 | /// \param b: The builder used to create the store ops. |
1549 | /// \param matrixD: Result of the warp-group MMA operation (fragmented |
1550 | /// matrix). It is held by each thread as a struct with 64 elements. |
1551 | /// \param dstMemref: The memref where the registers will be stored. |
1552 | /// \param offset: The offset within the memref where the registers will be |
1553 | /// stored. |
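| /// Worked example (illustrative): thread 1 (lane 1 of warp 0) computes |
| /// laneId = 1, warpId = 0, lane4Id = 0, lane4modId = 1, so ti = 0 and tj = 2; |
| /// its first two struct elements therefore land at matrixD[0][2:3], matching |
| /// the `T1:d0-d1` cell in the diagram above. |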
1554 | void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD, |
1555 | TypedValue<MemRefType> dstMemref, |
1556 | int offset) const { |
1557 | Type i32 = b.getI32Type(); |
1558 | |
1559 | auto makeConst = [&](int32_t index) -> Value { |
1560 | return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index)); |
1561 | }; |
1562 | Value c1 = makeConst(1); |
1563 | Value c2 = makeConst(2); |
1564 | Value c4 = makeConst(4); |
1565 | Value c8 = makeConst(8); |
1566 | Value c16 = makeConst(16); |
1567 | Value warpSize = makeConst(kWarpSize); |
1568 | |
1569 | auto makeMul = [&](Value lhs, Value rhs) -> Value { |
1570 | return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs); |
1571 | }; |
1572 | auto makeAdd = [&](Value lhs, Value rhs) -> Value { |
1573 | return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); |
1574 | }; |
1575 | |
1576 | auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, |
1577 | TypedValue<::mlir::MemRefType> memref) { |
1578 | Type it = b.getIndexType(); |
1579 | Value idx = b.create<arith::IndexCastOp>(it, x); |
1580 | Value idy0 = b.create<arith::IndexCastOp>(it, y); |
1581 | Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1)); |
1582 | Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i); |
1583 | Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1); |
1584 | b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0}); |
1585 | b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1}); |
1586 | }; |
1587 | |
1588 | Value tidx = b.create<NVVM::ThreadIdXOp>(i32); |
1589 | Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize); |
1590 | Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize); |
1591 | Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4); |
1592 | Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4); |
1593 | |
1594 | Value tj = makeMul(lane4modId, c2); |
1595 | Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); |
1596 | if (offset) |
1597 | ti = makeAdd(ti, makeConst(offset)); |
1598 | |
1599 | auto structType = cast<LLVM::LLVMStructType>(matrixD.getType()); |
1600 | |
1601 | // Number of adjacent 32-bit registers owned by each thread |
1602 | constexpr unsigned numAdjacentRegisters = 2; |
1603 | // Number of 8x8 matrices one below another per warp |
1604 | constexpr unsigned numStackedMatrices = 2; |
1605 | |
1606 | size_t storeCount = (structType.getBody().size() / |
1607 | (numStackedMatrices * numAdjacentRegisters)); |
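| // (Illustrative) with 64 struct elements per thread this yields |
| // storeCount = 64 / (2 * 2) = 16 iterations of the inner `j` loop below. |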
1608 | |
1609 | for (size_t i = 0; i < numStackedMatrices; ++i) { |
1610 | Value idx = makeAdd(ti, makeMul(makeConst(i), c8)); |
1611 | for (size_t j = 0; j < storeCount; ++j) { |
1612 | Value idy = makeAdd(tj, makeMul(makeConst(j), c8)); |
1613 | size_t structIndex = (i * numAdjacentRegisters) + |
1614 | (j * (numStackedMatrices * numAdjacentRegisters)); |
1615 | makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref); |
1616 | } |
1617 | } |
1618 | } |
1619 | |
1620 | LogicalResult |
1621 | matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor, |
1622 | ConversionPatternRewriter &rewriter) const override { |
1623 | int offset = 0; |
1624 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1625 | Value matrixDValue = adaptor.getMatrixD(); |
1626 | auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType()); |
1627 | for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) { |
1628 | auto structType = cast<LLVM::LLVMStructType>(matrixD); |
1629 | Value innerStructValue = b.create<LLVM::ExtractValueOp>(matrixDValue, idx); |
1630 | storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset); |
1631 | offset += structType.getBody().size(); |
1632 | } |
1633 | rewriter.eraseOp(op); |
1634 | return success(); |
1635 | } |
1636 | }; |
1637 | |
1638 | struct NVGPUWarpgroupMmaInitAccumulatorOpLowering |
1639 | : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> { |
1640 | using ConvertOpToLLVMPattern< |
1641 | nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern; |
1642 | LogicalResult |
1643 | matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor, |
1644 | ConversionPatternRewriter &rewriter) const override { |
1645 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1646 | LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>( |
1647 | getTypeConverter()->convertType(op.getMatrixC().getType())); |
1648 | Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front()) |
1649 | .getBody() |
1650 | .front(); |
1651 | Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType)); |
1652 | Value packStruct = b.create<LLVM::PoisonOp>(packStructType); |
1653 | SmallVector<Value> innerStructs; |
1654 | // Unpack the structs and set all values to zero |
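| // (Illustrative) e.g. if the converted accumulator is a struct of two inner |
| // structs of f32 values, every f32 field of both inner structs is zeroed here. |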
1655 | for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) { |
1656 | auto structType = cast<LLVM::LLVMStructType>(s); |
1657 | Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx); |
1658 | for (unsigned i = 0; i < structType.getBody().size(); ++i) { |
1659 | structValue = b.create<LLVM::InsertValueOp>( |
1660 | structType, structValue, zero, ArrayRef<int64_t>({i})); |
1661 | } |
1662 | innerStructs.push_back(structValue); |
1663 | } |
1664 | // Pack the inner structs into a single struct |
1665 | for (auto [idx, matrix] : llvm::enumerate(innerStructs)) { |
1666 | packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(), |
1667 | packStruct, matrix, idx); |
1668 | } |
1669 | rewriter.replaceOp(op, packStruct); |
1670 | return success(); |
1671 | } |
1672 | }; |
1673 | |
1674 | struct NVGPUTmaFenceOpLowering |
1675 | : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> { |
1676 | using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern; |
1677 | LogicalResult |
1678 | matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor, |
1679 | ConversionPatternRewriter &rewriter) const override { |
1680 | MLIRContext *ctx = op.getContext(); |
1681 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1682 | auto i32Ty = b.getI32Type(); |
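| // A TMA tensor map (CUtensorMap) is 128 bytes wide; the proxy fence below |
| // must cover the whole descriptor. |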
1683 | Value tensormapSize = |
1684 | b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128)); |
1685 | |
1686 | auto memscope = |
1687 | NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS); |
1688 | |
1689 | rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>( |
1690 | op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize); |
1691 | |
1692 | return success(); |
1693 | } |
1694 | }; |
1695 | |
1696 | struct NVGPUTmaPrefetchOpLowering |
1697 | : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> { |
1698 | using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern; |
1699 | LogicalResult |
1700 | matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor, |
1701 | ConversionPatternRewriter &rewriter) const override { |
1702 | rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>( |
1703 | op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate()); |
1704 | return success(); |
1705 | } |
1706 | }; |
1707 | |
1708 | struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> { |
1709 | using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern; |
1710 | LogicalResult |
1711 | matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor, |
1712 | ConversionPatternRewriter &rewriter) const override { |
1713 | ImplicitLocOpBuilder b(op->getLoc(), rewriter); |
1714 | auto i64Ty = b.getI64Type(); |
1715 | auto f32Ty = b.getF32Type(); |
1716 | VectorType inTy = op.getIn().getType(); |
1717 | // Apply rcp.approx.ftz.f32 to each element of the vector. |
1718 | auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) { |
1719 | Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy); |
1720 | int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements(); |
1721 | for (int i = 0; i < numElems; i++) { |
1722 | Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i)); |
1723 | Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx); |
1724 | Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem); |
1725 | ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx); |
1726 | } |
1727 | return ret1DVec; |
1728 | }; |
1729 | if (inTy.getRank() == 1) { |
1730 | rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn())); |
1731 | return success(); |
1732 | } |
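| // For higher-rank vectors, handleMultidimensionalVectors unrolls the op over |
| // the outer dimensions and applies `convert1DVec` to each innermost 1-D |
| // vector (e.g. a vector<2x4xf32> input becomes two vector<4xf32> conversions). |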
1733 | return LLVM::detail::handleMultidimensionalVectors( |
1734 | op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()), |
1735 | [&](Type llvm1DVectorTy, ValueRange operands) -> Value { |
1736 | OpAdaptor adaptor(operands); |
1737 | return convert1DVec(llvm1DVectorTy, adaptor.getIn()); |
1738 | }, |
1739 | rewriter); |
1740 | } |
1741 | }; |
1742 | } // namespace |
1743 | |
1744 | void mlir::populateNVGPUToNVVMConversionPatterns( |
1745 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
1746 | patterns.add< |
1747 | NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create |
1748 | NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init |
1749 | NVGPUMBarrierGetLowering, // nvgpu.mbarrier.get |
1750 | NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive |
1751 | NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete |
1752 | NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity |
1753 | NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity |
1754 | NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load |
1755 | NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store |
1756 | NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor |
1757 | NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor |
1758 | NVGPUTmaFenceOpLowering, // nvgpu.tma.fence.descriptor |
1759 | NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx |
1760 | NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor |
1761 | NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma |
1762 | NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store |
1763 | NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator |
1764 | MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, |
1765 | NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, |
1766 | NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter); |
1767 | } |
1768 |