//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Number of bits that need to be excluded when building the matrix descriptor
/// for wgmma operations.
constexpr int exclude4LSB = 4;

/// GPUs have 32-bit registers; this function truncates values when a larger
/// width is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}

/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
    // need to extract them from the struct and pack them into the 64-bit wide
    // rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}

/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(b.create<LLVM::ExtractElementOp>(
            toUse,
            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

/// Returns whether the mbarrier object has a shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType)) {
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  }
  return memorySpace;
}

/// Returns the memref type of the mbarrier object. The type is defined by the
/// MBarrierGroupType.
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}

namespace {

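/// Lowers `nvgpu.ldmatrix` to `nvvm.ldmatrix`. The source pointer is computed
/// from the converted memref descriptor, and the packed i32 results returned
/// by the intrinsic are bitcast back into the 2D vector type produced by the
/// nvgpu op.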
struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result type of ldmatrix will always be a struct of 32-bit integer
    // registers if more than one 32-bit value is returned. Otherwise, the
    // result is a single i32. The result type of the GPU operation is always a
    // vector of shape (NumRegisters, VectorRegister) where VectorRegister is
    // the vector type of the result and always 32 bits long. We bitcast the
    // result of the NVVM::LdMatrix to this vector type.
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
        ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // struct.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = b.create<LLVM::UndefOp>(finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
      result = b.create<LLVM::InsertValueOp>(result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
/// enum).
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}

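/// Lowers `nvgpu.mma.sync` to `nvvm.mma.sync`. The operand fragments are
/// unpacked into scalars with `unpackOperandVector`, and the struct returned
/// by the intrinsic is repacked with `convertIntrinsicResult`.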
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = b.create<NVVM::MmaOp>(
        intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};

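/// Pass that converts NVGPU dialect operations (and the NVGPU type system) to
/// the NVVM/LLVM dialects.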
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
                    arith::ArithDialect>();
  }

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    IRRewriter rewriter(&getContext());
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });
    /// Device-side async tokens cannot be materialized in NVVM. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

/// Returns the constraints for the sparse MMA inline assembly instruction.
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is for the sparsity metadata.
  // The sparsity selector appears as a direct literal.
  ss << "r";
  ss.flush();
  return str;
}

/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. Note that this function doesn't do any validation;
/// it's expected that the provided parameters correspond to a valid
/// instruction.
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections `{matC elements...},
  // {matA elements...}, {matB elements...}, {matC elements}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  ss.flush();
  return asmStr;
}

/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return b.create<LLVM::InlineAsmOp>(
      /*resultTypes=*/intrinsicResultType,
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/constraintStr,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}

/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Same as `mma.sync`: F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() !=
        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};

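/// Lowers `nvgpu.device_async_copy` to `nvvm.cp.async` and drops the async
/// token by replacing it with a dummy i32 constant.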
struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
                             adaptor.getDstIndices(), rewriter);
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, srcPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // When the optional SrcElements argument is *not* present, the regular
    // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
    // memory) to fill DstElements number of elements in the destination
    // (shared memory).
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      // When the optional SrcElements argument is present, the source (global
      // memory) of CpAsyncOp is read only for SrcElements number of elements.
      // The rest of the DstElements in the destination (shared memory) are
      // filled with zeros.
      Value c3I32 =
          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = b.create<LLVM::ConstantOp>(
          b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
      srcBytes = b.create<LLVM::LShrOp>(
          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
    }
    // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
    // 16 dst bytes.
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    b.create<NVVM::CpAsyncOp>(
        dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token.
    Value zero = b.create<LLVM::ConstantOp>(
        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

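/// Lowers `nvgpu.device_async_create_group` to `nvvm.cp.async.commit.group`
/// and drops the group token.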
struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

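/// Lowers `nvgpu.device_async_wait` to `nvvm.cp.async.wait.group`.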
struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservative correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Creates the mbarrier object in shared memory.
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = rewriter.create<memref::GlobalOp>(
        funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};

/// Base class for lowering mbarrier operations to nvvm intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  /// Returns the base pointer of the mbarrier object.
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
  }
};

/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`.
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(mbarrierType)) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
          op, barrier, count, adaptor.getPredicate());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(
          op, barrier, count, adaptor.getPredicate());
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`.
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
                                                                barrier);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
                                                          barrier);
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
/// `nvvm.mbarrier.arrive.nocomplete`.
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
          op, tokenType, barrier, count);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
          op, tokenType, barrier, count);
    }
    return success();
  }
};

/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`.
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
          op, retType, barrier, adaptor.getToken());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
          op, retType, barrier, adaptor.getToken());
    }
    return success();
  }
};

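/// Lowers `nvgpu.mbarrier.arrive.expect_tx` to
/// `nvvm.mbarrier.arrive.expect_tx` (shared or generic variant).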
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
          op, barrier, txcount, adaptor.getPredicate());
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount, adaptor.getPredicate());
    return success();
  }
};

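/// Lowers `nvgpu.mbarrier.try_wait.parity` to `nvvm.mbarrier.try_wait.parity`
/// (shared or generic variant), zero-extending the phase parity to i32.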
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
          op, barrier, phase, ticks);
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};

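/// Lowers `nvgpu.tma.async.load` to
/// `NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp`, truncating the tensor
/// coordinates to i32 and wiring in the mbarrier that tracks completion.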
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getDst(), {}, rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }
    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        ValueRange{}, adaptor.getMulticastMask(), Value{},
        adaptor.getPredicate());
    return success();
  }
};

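/// Lowers `nvgpu.tma.async.store` to
/// `NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp`, truncating the tensor
/// coordinates to i32.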
struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {}, rewriter);
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), dest, coords,
        adaptor.getPredicate());
    return success();
  }
};

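/// Lowers `nvgpu.warpgroup.generate.descriptor` to an i64 wgmma matrix
/// descriptor, packing the swizzle mode, base offset, stride, leading
/// dimension, and 14-bit base address into the bit ranges documented below.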
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
        : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
        : 0;

    auto ti64 = b.getIntegerType(64);
    auto makeConst = [&](uint64_t index) -> Value {
      return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {}, rewriter);
    Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
    // Just use 14 bits for the base address.
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // [62,64) swizzle type
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // [49,52) base_offset
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // [32,46) stride
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // [16,30) leading dimension
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // [0,14) start_address
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                      << "leading_off:" << leadDimVal << "\t"
                      << "stride_off :" << strideDimVal << "\t"
                      << "base_offset:" << offsetVal << "\t"
                      << "layout_type:" << swizzle << " ("
                      << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                      << ")\n start_addr : " << baseAddr << "\n");

    rewriter.replaceOp(op, dsc);
    return success();
  }
};

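/// Creates an LLVM i64 constant holding the given index.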
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
                                    b.getI32IntegerAttr(index));
}

/// Returns a Value that holds the data type enum that is expected by the CUDA
/// driver.
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  // Enum is from the CUDA driver API
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
  enum CUtensorMapDataTypeEnum {
    CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
    CU_TENSOR_MAP_DATA_TYPE_UINT16,
    CU_TENSOR_MAP_DATA_TYPE_UINT32,
    CU_TENSOR_MAP_DATA_TYPE_INT32,
    CU_TENSOR_MAP_DATA_TYPE_UINT64,
    CU_TENSOR_MAP_DATA_TYPE_INT64,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
  };

  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

  llvm_unreachable("Not supported data type");
}

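/// Lowers `nvgpu.tma.create.descriptor` to a call to the
/// `mgpuTensorMapEncodeTiledMemref` runtime function, which encodes the TMA
/// descriptor on the host side and returns a pointer to it.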
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType =
        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                 makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
                                        boxArrayPtr, makeI64Const(b, index));
      b.create<LLVM::StoreOp>(value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set arguments for the function call.
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]); // rank
    arguments.push_back(promotedOperands[1]); // descriptor
    arguments.push_back(tensorElementType);   // data type
    arguments.push_back(
        makeI64Const(b, (int)desc.getInterleave()));              // interleave
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
    arguments.push_back(boxArrayPtr); // box dimensions

    // Set data types of the arguments.
    SmallVector<Type> argTypes = {
        llvmInt64Type,   /* int64_t tensorRank */
        llvmPointerType, /* ptr */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmPointerType  /* ptr */
    };
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};

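/// Lowers `nvgpu.warpgroup.mma` to one or more `NVVM::WgmmaMmaAsyncOp` ops via
/// the `WarpgroupGemm` helper class below.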
1210struct NVGPUWarpgroupMmaOpLowering
1211 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1212 using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1213
1214 /// This is a helper class to generate required NVVM Ops for warp-group level
1215 /// matrix multiplication.
1216 /// When the given GEMM shape is larger than the shape of
1217 /// a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1218 /// Op(s), group and execute them asynchronously. The class also handles
1219 /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1220 /// create descriptors for each instruction.
1221 ///
1222 /// For example this is the case when the shape of GEMM is 128x128x128
1223 ///
1224 /// nvvm.wgmma.fence.aligned
1225 ///
1226 /// nvvm.wgmma.mma.async descA, descB
1227 /// iterate(descA, descB)
1228 /// nvvm.wgmma.mma.async descA, descB
1229 /// [6x times more]
1230 ///
1231 /// nvvm.wgmma.group.sync.aligned
1232 /// nvvm.wgmma.wait.group.sync [groupId]
1233 ///
1234 class WarpgroupGemm {
1235 nvgpu::WarpgroupMmaOp op;
1236 ImplicitLocOpBuilder b;
1237 OpAdaptor adaptor;
1238
1239 // Entire shape of the given Op
1240 int64_t totalM, totalN, totalK;
1241
1242 // Shape of one wgmma instruction
1243 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1244
1245 // Iteration counts for GEMM
1246 int iterationM = 0, iterationN = 0, iterationK = 0;
1247
1248 /// The function returns the shape of wgmma instruction that is defined in
1249 /// PTX programming guide.
1250 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1251 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1252 wgmmaM = 64;
1253 wgmmaN = sizeN;
1254 if (inputElemType.isTF32()) {
1255 wgmmaK = 8;
1256 } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1257 wgmmaK = 16;
1258 } else if (inputElemType.isFloat8E4M3FN() ||
1259 inputElemType.isFloat8E5M2() || inputElemType.isInteger(width: 16)) {
1260 wgmmaK = 32;
1261 } else if (inputElemType.isInteger(width: 1)) {
1262 wgmmaK = 256;
1263 } else {
1264 llvm_unreachable("msg: not supported K shape");
1265 }
1266 LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1267 << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1268 }
1269
1270 /// Generates WGMMATypesAttr from MLIR Type
1271 NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1272 bool useF32 = false) const {
1273 auto getWgmmaType = [=](Type elemType) {
1274 if (elemType.isF32() || elemType.isTF32())
1275 return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1276 if (elemType.isF16())
1277 return NVVM::WGMMATypes::f16;
1278 if (elemType.isBF16())
1279 return NVVM::WGMMATypes::bf16;
1280 if (elemType.isFloat8E4M3FN())
1281 return NVVM::WGMMATypes::e4m3;
1282 if (elemType.isFloat8E5M2())
1283 return NVVM::WGMMATypes::e5m2;
1284 if (elemType.isInteger(1))
1285 return NVVM::WGMMATypes::b1;
1286 if (elemType.isInteger(8))
1287 return NVVM::WGMMATypes::s8;
1288 if (elemType.isUnsignedInteger(8))
1289 return NVVM::WGMMATypes::u8;
1290 if (elemType.isInteger(32))
1291 return NVVM::WGMMATypes::s32;
1292 llvm_unreachable("unsupported type");
1293 };
1294 return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1295 }
1296
1297 /// Generates layout attribute for the input matrix for wgmma instruction
1298 NVVM::MMALayoutAttr
1299 generateWgmmaLayout(std::optional<bool> transpose) const {
1300 if (transpose.value_or(false))
1301 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1302 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1303 }
1304
1305 /// Generates shape attribute for wgmma instruction
1306 NVVM::MMAShapeAttr generateWgmmaShape() const {
1307 return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1308 }
1309
1310 /// Generates scale attributes of output matrix for wgmma instruction
1311 NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1312 return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1313 NVVM::WGMMAScaleOut::one);
1314 }
1315 /// Generates scale attributes of input matrix for wgmma instruction
1316 NVVM::WGMMAScaleInAttr generateScaleIn() const {
1317 return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1318 NVVM::WGMMAScaleIn::one);
1319 }
1320
1321 /// Basic function to generate Add
1322 Value makeAdd(Value lhs, Value rhs) {
1323 return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1324 };
1325
1326 /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1327 /// Currently, it only handles row-major.
1328 ///
1329 /// It moves the pointer like below for [128][64] size:
1330 /// +2 +4 +6
1331 /// ↓ ↓ ↓
1332 /// descA ---> +--+--+--+--+
1333 /// |->|->|->|->|
1334 /// | | | | |
1335 /// | | | | |
1336 /// | | | | |
1337 /// descA+512---> +-----------+
1338 /// | | | | |
1339 /// | | | | |
1340 /// | | | | |
1341 /// | | | | |
1342 /// +-----------+
1343 ///
1344 Value iterateDescriptorA(Value desc, int i, int j, int k) {
1345 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1346 Type elemA = matrixTypeA.getElementType();
1347 int byte = elemA.getIntOrFloatBitWidth() / 8;
1348 int tileShapeA = matrixTypeA.getDimSize(1);
1349 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1350 incrementVal = incrementVal >> exclude4LSB;
1351 LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1352 << "] [wgmma descriptors] Descriptor A + "
1353 << incrementVal << " | \t ");
1354 if (!incrementVal)
1355 return desc;
1356 return makeAdd(lhs: desc, rhs: makeI64Const(b, index: incrementVal));
1357 }
1358
1359 /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1360 /// Currently, it only handles column-major.
1361 ///
1362 /// It moves the pointer like below for [128][64] size:
1363 /// descB ---> +--+--+--+--+--+--+--+--+
1364 /// |↓ | | | | | | | |
1365 /// |↓ | | | | | | | |
1366 /// |↓ | | | | | | | |
1367 /// |↓ | | | | | | | |
1368 /// +--+--+--+--+--+--+--+--+
1369 ///
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
    /// descriptors and arranges them based on induction variables: i, j, and k.
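    ///
    /// For illustration (assumed element types, not derived from a specific
    /// input): with f16 operands, an f32 accumulator, and an m64n128k16 tile,
    /// the op is built with shape m64n128k16, input types f16/f16, result type
    /// f32, row layout for A, column layout for B (when untransposed), scale
    /// factors of one, and wrapped integer-overflow semantics.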
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LLVM_DEBUG(DBGS() << "\t wgmma."
                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                        << "(A[" << (i * wgmmaM) << ":"
                        << (i * wgmmaM) + wgmmaM << "]["
                        << (k * wgmmaK) << ":"
                        << (k * wgmmaK + wgmmaK) << "] * "
                        << " B[" << (k * wgmmaK) << ":"
                        << (k * wgmmaK + wgmmaK) << "][" << 0 << ":"
                        << wgmmaN << "])\n");

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return b.create<NVVM::WgmmaMmaAsyncOp>(
          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }

    /// Generates multiple wgmma instructions to complete the given GEMM shape
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());

      // Perform GEMM
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
                                                    wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }

  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM Shape
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");

      // Find the shape for one wgmma instruction
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts needed to cover the given shape with the wgmma shape
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
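      // For example (illustrative): an f16 GEMM with totalM = 128,
      // totalN = 128, totalK = 64, assuming findWgmmaShape picks an
      // m64n128k16 tile, yields iterationM = 2, iterationN = 1,
      // iterationK = 4, i.e. 8 wgmma instructions in total.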
    }

    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. A
    /// fence (WgmmaFenceAlignedOp) is emitted before the wgmma instructions;
    /// after them, the group is committed (WgmmaGroupSyncAlignedOp) and then
    /// waited on (WgmmaWaitGroupSyncOp).
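    ///
    /// The emitted sequence is, in order (sketch):
    ///   1. NVVM::WgmmaFenceAlignedOp
    ///   2. iterationM * iterationN * iterationK NVVM::WgmmaMmaAsyncOp ops
    ///   3. NVVM::WgmmaGroupSyncAlignedOp
    ///   4. NVVM::WgmmaWaitGroupSyncOp (using the op's waitGroup value)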
    Value generateWarpgroupMma() {
      b.create<NVVM::WgmmaFenceAlignedOp>();
      Value wgmmaResult = generateWgmmaGroup();
      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper class
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate wgmma instructions covering the entire GEMM shape
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace fragmented result struct with the op results
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};

struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// This function stores a fragmented register matrix owned by a warp group
  /// (128 threads) into a memref. Each thread holds 64 registers that
  /// together form a struct value.
  /// Below is what each thread (T) holds; each `d` is a struct element
  /// (register) with its index.
  ///
  /// Threads in the warp-group (128 threads) and what they own in MatrixD:
  /// 0-31   Warp-0 -> MatrixD[0:15 ][0:N]
  /// 32-63  Warp-1 -> MatrixD[16:31][0:N]
  /// 64-95  Warp-2 -> MatrixD[32:47][0:N]
  /// 96-127 Warp-3 -> MatrixD[48:63][0:N]
  ///
  /// Matrix-D:
  /// +______________________________________________________________________+
  /// | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
  /// 16| T32:d0-d1|T33:d0-d1|T34:d0-d1|T35:d0-d1|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 32| T64:d0-d1|T65:d0-d1|T66:d0-d1|T67:d0-d1|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 48| T96:d0-d1|T97:d0-d1|T98:d0-d1|T99:d0-d1|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// +______________________________________________________________________+
  ///
  /// \param b: The builder used to create the store ops.
  /// \param matrixD: Result of the warp-group MMA operation (fragmented
  /// matrix). Each thread holds it as a struct with 64 elements.
  /// \param dstMemref: The memref where the registers will be stored.
  /// \param offset: The offset within the memref where the registers will be
  /// stored.
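  ///
  /// Worked example (illustrative, with offset == 0): thread 37 has
  /// laneId = 5, warpId = 1, lane4Id = 1, lane4modId = 1, so ti = 1 + 16 = 17
  /// and tj = 2; its d0/d1 are stored to [17][2] and [17][3], and its d2/d3
  /// to [25][2] and [25][3].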
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = b.create<arith::IndexCastOp>(it, x);
      Value idy0 = b.create<arith::IndexCastOp>(it, y);
      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of adjacent 32-bit registers owned per thread
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below another per warp
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};

struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
                        .getBody()
                        .front();
    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the structs and set all values to zero
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = b.create<LLVM::InsertValueOp>(
            structType, structValue, zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs into a single struct
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
                                                 packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};

struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
    return success();
  }
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<
      NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
      NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
      NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
      NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
      NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
      NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
      NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
      NVGPUTmaAsyncStoreOpLowering,          // nvgpu.tma.async.store
      NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
      NVGPUTmaPrefetchOpLowering,            // nvgpu.tma.prefetch.descriptor
      NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
      NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
      NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
      NVGPUWarpgroupMmaStoreOpLowering,         // nvgpu.warpgroup.mma.store
      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering>(converter);
}
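
// A minimal usage sketch (illustrative; `MyLowerNVGPUPass` is a hypothetical
// pass): a conversion pass typically wires these patterns up roughly as
// follows.
//
//   void MyLowerNVGPUPass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     LLVMTypeConverter converter(ctx);
//     RewritePatternSet patterns(ctx);
//     populateNVGPUToNVVMConversionPatterns(converter, patterns);
//     LLVMConversionTarget target(*ctx);
//     target.addLegalDialect<NVVM::NVVMDialect>();
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }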