1//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
10
11#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
12#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13#include "mlir/Conversion/LLVMCommon/Pattern.h"
14#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
19#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
22#include "mlir/Dialect/SCF/Transforms/Patterns.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/ImplicitLocOpBuilder.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/TypeUtilities.h"
27#include "mlir/IR/Value.h"
28#include "mlir/Pass/Pass.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
32#include <optional>
33
34#define DEBUG_TYPE "nvgpu-to-nvvm"
35#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36#define DBGSE() (llvm::dbgs())
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44
45/// Number of bits that needs to be excluded when building matrix descriptor for
46/// wgmma operations.
47constexpr int exclude4LSB = 4;
48
49/// GPU has 32 bit registers, this function truncates values when larger width
50/// is not needed.
51static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
52 Type type = value.getType();
53 assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
54 if (type.getIntOrFloatBitWidth() <= 32)
55 return value;
56 return b.create<LLVM::TruncOp>(b.getI32Type(), value);
57}
58
59/// Returns the type for the intrinsic given the vectorResultType of the
60/// `gpu.mma.sync` operation.
61static Type inferIntrinsicResultType(Type vectorResultType) {
62 MLIRContext *ctx = vectorResultType.getContext();
63 auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
64 auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
65 auto i32Ty = IntegerType::get(ctx, 32);
66 auto i32x2Ty = VectorType::get(2, i32Ty);
67 Type f64Ty = Float64Type::get(ctx);
68 Type f64x2Ty = VectorType::get(2, f64Ty);
69 Type f32Ty = Float32Type::get(ctx);
70 Type f32x2Ty = VectorType::get(2, f32Ty);
71 if (a.getElementType() == f16x2Ty) {
72 return LLVM::LLVMStructType::getLiteral(
73 ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
74 }
75 if (a.getElementType() == i32x2Ty) {
76 return LLVM::LLVMStructType::getLiteral(
77 ctx,
78 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
79 }
80 if (a.getElementType() == f64x2Ty) {
81 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
82 }
83 if (a.getElementType() == f32x2Ty) {
84 return LLVM::LLVMStructType::getLiteral(
85 ctx,
86 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
87 }
88 if (a.getElementType() == VectorType::get(1, f32Ty)) {
89 return LLVM::LLVMStructType::getLiteral(
90 ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
91 }
92 return vectorResultType;
93}
94
95/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
96/// always an LLVM struct) into a fragment that is compatible with the vector
97/// type of this operation. This involves extracting elements from the struct
98/// and inserting them into an LLVM array. These extra data-movement
99/// operations should be canonicalized away by the LLVM backend.
100static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
101 Type resultType, Value intrinsicResult,
102 RewriterBase &rewriter) {
103 MLIRContext *ctx = rewriter.getContext();
104 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
105 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
106 Type i32Ty = rewriter.getI32Type();
107 Type f32Ty = rewriter.getF32Type();
108 Type f64Ty = rewriter.getF64Type();
109 Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
110 Type i32x2Ty = VectorType::get(2, i32Ty);
111 Type f64x2Ty = VectorType::get(2, f64Ty);
112 Type f32x2Ty = VectorType::get(2, f32Ty);
113 Type f32x1Ty = VectorType::get(1, f32Ty);
114
115 auto makeConst = [&](int32_t index) -> Value {
116 return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
117 rewriter.getI32IntegerAttr(index));
118 };
119
120 if (arrayType) {
121 SmallVector<Value, 4> elements;
122
123 // The intrinsic returns 32-bit wide elements in a form which can be
124 // directly bitcasted and inserted into the result vector.
125 if (arrayType.getElementType() == f16x2Ty ||
126 arrayType.getElementType() == f32x1Ty) {
127 for (unsigned i = 0; i < structType.getBody().size(); i++) {
128 Value el =
129 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
130 el = rewriter.createOrFold<LLVM::BitcastOp>(
131 loc, arrayType.getElementType(), el);
132 elements.push_back(Elt: el);
133 }
134 }
135
136 // The intrinsic returns i32, f64, and f32 values as individual scalars,
137 // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
138 // need to extract them from the struct and pack them into the 64-bit wide
139 // rows of the vector result.
140 if (arrayType.getElementType() == i32x2Ty ||
141 arrayType.getElementType() == f64x2Ty ||
142 arrayType.getElementType() == f32x2Ty) {
143
144 for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
145 Value vec =
146 rewriter.create<LLVM::PoisonOp>(loc, arrayType.getElementType());
147 Value x1 =
148 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
149 Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
150 i * 2 + 1);
151 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
152 x1, makeConst(0));
153 vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
154 x2, makeConst(1));
155 elements.push_back(Elt: vec);
156 }
157 }
158
159 // Create the final vectorized result.
160 Value result = rewriter.create<LLVM::PoisonOp>(loc, arrayType);
161 for (const auto &el : llvm::enumerate(First&: elements)) {
162 result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
163 el.index());
164 }
165 return result;
166 }
167
168 return intrinsicResult;
169}
170
171/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
172/// given as 2D `vectors` where the rows are 32b or 64b wide. The
173/// `nvvm.mma.sync` op expects these argments to be a given in a long list of
174/// scalars of certain types. This function helps unpack the `vector` arguments
175/// and cast them to the types expected by `nvvm.mma.sync`.
176static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
177 Value operand,
178 NVVM::MMATypes operandPtxType) {
179 SmallVector<Value> result;
180 Type i32Ty = b.getI32Type();
181 Type f64Ty = b.getF64Type();
182 Type f32Ty = b.getF32Type();
183 Type i64Ty = b.getI64Type();
184 Type i8x4Ty = VectorType::get(4, b.getI8Type());
185 Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
186 Type f32x1Ty = VectorType::get(1, f32Ty);
187 auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
188
189 for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
190 Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
191
192 // For 4xi8 vectors, the intrinsic expects these to be provided as i32
193 // scalar types.
194 if (arrayTy.getElementType() == i8x4Ty ||
195 arrayTy.getElementType() == i4x8Ty ||
196 (arrayTy.getElementType() == f32x1Ty &&
197 operandPtxType == NVVM::MMATypes::tf32)) {
198 result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
199 continue;
200 }
201
202 // For some element types (i32, f32, f64), we need to unpack the inner
203 // vector/array type as well because the intrinsic expects individual
204 // scalars to be provided.
205 VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
206 if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
207 innerArrayTy.getElementType() == f64Ty ||
208 innerArrayTy.getElementType() == f32Ty)) {
209 for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
210 idx < innerSize; idx++) {
211 result.push_back(b.create<LLVM::ExtractElementOp>(
212 toUse,
213 b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
214 }
215 continue;
216 }
217 result.push_back(Elt: toUse);
218 }
219 return result;
220}
221
222/// Returns whether mbarrier object has shared memory address space.
223static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
224 return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
225 barrierType.getMemorySpace()));
226}
227
228/// Returns the memory space attribute of the mbarrier object.
229Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
230 nvgpu::MBarrierGroupType barrierType) {
231 Attribute memorySpace = {};
232 if (isMbarrierShared(barrierType)) {
233 memorySpace =
234 IntegerAttr::get(IntegerType::get(context, 64),
235 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
236 }
237 return memorySpace;
238}
239
240/// Returns memref type of the mbarrier object. The type is defined in the
241/// MBarrierGroupType.
242MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
243 nvgpu::MBarrierGroupType barrierType) {
244 Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType: barrierType);
245 MemRefLayoutAttrInterface layout;
246 return MemRefType::get({barrierType.getNumBarriers()},
247 IntegerType::get(context, 64), layout, memorySpace);
248}
249
250namespace {
251
252struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
253 using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
254
255 LogicalResult
256 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
257 ConversionPatternRewriter &rewriter) const override {
258 MLIRContext *ctx = getContext();
259 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
260
261 // The result type of ldmatrix will always be a struct of 32bit integer
262 // registers if more than one 32bit value is returned. Otherwise, the result
263 // is a single i32. The result type of the GPU operation is always a vector
264 // of shape (NumRegisters, VectorRegister) where VectorRegister is the
265 // vector type of the result and always 32 bits long. We bitcast the result
266 // of the NVVM::LdMatrix to this vector type.
267 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
268 if (!vectorResultType) {
269 return failure();
270 }
271 Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
272 vectorResultType.getElementType());
273
274 int64_t num32BitRegs = vectorResultType.getDimSize(0);
275
276 Type ldMatrixResultType;
277 if (num32BitRegs > 1) {
278 ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
279 ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
280 } else {
281 ldMatrixResultType = rewriter.getI32Type();
282 }
283
284 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
285 Value srcPtr =
286 getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
287 adaptor.getSrcMemref(), adaptor.getIndices());
288 Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
289 ldMatrixResultType, srcPtr,
290 /*num=*/op.getNumTiles(),
291 /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
292 : NVVM::MMALayout::row);
293
294 // The ldmatrix operation returns either a single i32 value or a struct of
295 // i32 values. Here we unpack those values and cast them back to their
296 // actual vector type (still of width 32b) and repack them into a result
297 // struct.
298 Type finalResultType = typeConverter->convertType(vectorResultType);
299 Value result = b.create<LLVM::PoisonOp>(finalResultType);
300 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
301 Value i32Register =
302 num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
303 : ldMatrixResult;
304 Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
305 result = b.create<LLVM::InsertValueOp>(result, casted, i);
306 }
307
308 rewriter.replaceOp(op, result);
309 return success();
310 }
311};
312
313/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
314/// enum).
315static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
316 Type elType = getElementTypeOrSelf(type: t);
317 if (elType.isInteger(8))
318 return NVVM::MMATypes::s8;
319 if (elType.isInteger(4))
320 return NVVM::MMATypes::s4;
321 if (elType.isF16())
322 return NVVM::MMATypes::f16;
323 if (elType.isF64())
324 return NVVM::MMATypes::f64;
325 if (elType.isF32())
326 return NVVM::MMATypes::tf32;
327 return failure();
328}
329
330struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
331 using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
332
333 LogicalResult
334 matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const override {
336 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
337 // Get the shapes of the MMAMatrix type being used. The shapes will
338 // choose which intrinsic this op will be lowered to.
339 VectorType aType = op.getMatrixA().getType();
340 VectorType bType = op.getMatrixA().getType();
341 VectorType cType = op.getMatrixC().getType();
342
343 std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
344
345 // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
346 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
347 if (aType.getElementType().isF32() && !tf32Enabled)
348 return failure();
349
350 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
351 if (failed(ptxTypeA))
352 return op->emitOpError("failed to deduce operand PTX types");
353 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
354 if (failed(ptxTypeB))
355 return op->emitOpError("failed to deduce operand PTX types");
356 std::optional<NVVM::MMATypes> ptxTypeC =
357 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
358 /*isAccumulator=*/true);
359 if (!ptxTypeC)
360 return op->emitError(
361 "could not infer the PTX type for the accumulator/result");
362
363 // TODO: add an attribute to the op to customize this behavior.
364 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
365 if (isa<IntegerType>(aType.getElementType()))
366 overflow = NVVM::MMAIntOverflow::satfinite;
367
368 SmallVector<Value> matA =
369 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
370 SmallVector<Value> matB =
371 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
372 SmallVector<Value> matC =
373 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
374
375 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
376 Type intrinsicResTy = inferIntrinsicResultType(
377 typeConverter->convertType(op->getResultTypes()[0]));
378 Value intrinsicResult = b.create<NVVM::MmaOp>(
379 intrinsicResTy, matA, matB, matC,
380 /*shape=*/gemmShape,
381 /*b1Op=*/std::nullopt,
382 /*intOverflow=*/overflow,
383 /*multiplicandPtxTypes=*/
384 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
385 /*multiplicandLayouts=*/
386 std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
387 NVVM::MMALayout::col});
388 rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
389 desiredRetTy, intrinsicResult,
390 rewriter));
391 return success();
392 }
393};
394
395struct ConvertNVGPUToNVVMPass
396 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
397 using Base::Base;
398
399 void getDependentDialects(DialectRegistry &registry) const override {
400 registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
401 arith::ArithDialect>();
402 }
403
404 void runOnOperation() override {
405 LowerToLLVMOptions options(&getContext());
406 RewritePatternSet patterns(&getContext());
407 LLVMTypeConverter converter(&getContext(), options);
408 IRRewriter rewriter(&getContext());
409 populateGpuMemorySpaceAttributeConversions(
410 typeConverter&: converter, mapping: [](gpu::AddressSpace space) -> unsigned {
411 switch (space) {
412 case gpu::AddressSpace::Global:
413 return static_cast<unsigned>(
414 NVVM::NVVMMemorySpace::kGlobalMemorySpace);
415 case gpu::AddressSpace::Workgroup:
416 return static_cast<unsigned>(
417 NVVM::NVVMMemorySpace::kSharedMemorySpace);
418 case gpu::AddressSpace::Private:
419 return 0;
420 }
421 llvm_unreachable("unknown address space enum value");
422 return 0;
423 });
424 /// device-side async tokens cannot be materialized in nvvm. We just
425 /// convert them to a dummy i32 type in order to easily drop them during
426 /// conversion.
427 converter.addConversion(callback: [&](nvgpu::DeviceAsyncTokenType type) -> Type {
428 return converter.convertType(IntegerType::get(type.getContext(), 32));
429 });
430 converter.addConversion(callback: [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
431 Type elemType = type.getFragmented().getElementType();
432 int64_t sizeM = type.getFragmented().getDimSize(0);
433 int64_t sizeN = type.getFragmented().getDimSize(1);
434
435 unsigned numMembers;
436 if (elemType.isF32() || elemType.isInteger(width: 32))
437 numMembers = sizeN / 2;
438 else if (elemType.isF16())
439 numMembers = sizeN / 4;
440 else
441 llvm_unreachable("unsupported type for warpgroup accumulator");
442
443 SmallVector<Type> innerStructBody;
444 for (unsigned i = 0; i < numMembers; i++)
445 innerStructBody.push_back(Elt: elemType);
446 auto innerStructType =
447 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
448
449 SmallVector<Type> structBody;
450 for (int i = 0; i < sizeM; i += kWgmmaSizeM)
451 structBody.push_back(Elt: innerStructType);
452
453 auto convertedType =
454 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
455 return converter.convertType(convertedType);
456 });
457 converter.addConversion(callback: [&](nvgpu::MBarrierTokenType type) -> Type {
458 return converter.convertType(IntegerType::get(type.getContext(), 64));
459 });
460 converter.addConversion(
461 callback: [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
462 return converter.convertType(IntegerType::get(type.getContext(), 64));
463 });
464 converter.addConversion(callback: [&](nvgpu::MBarrierGroupType type) -> Type {
465 return converter.convertType(
466 nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
467 });
468 converter.addConversion(callback: [&](nvgpu::TensorMapDescriptorType type) -> Type {
469 return LLVM::LLVMPointerType::get(type.getContext());
470 });
471 populateNVGPUToNVVMConversionPatterns(converter, patterns);
472 LLVMConversionTarget target(getContext());
473 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
474 target.addLegalDialect<::mlir::arith::ArithDialect>();
475 target.addLegalDialect<::mlir::memref::MemRefDialect>();
476 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
477 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
478 typeConverter: converter, patterns, target);
479 if (failed(applyPartialConversion(getOperation(), target,
480 std::move(patterns))))
481 signalPassFailure();
482 }
483};
484
485/// Returns the constraints for the sparse MMA inline assembly instruction.
486static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
487 unsigned matBSize,
488 unsigned matCSize) {
489 std::string str;
490 llvm::raw_string_ostream ss(str);
491 for (unsigned i = 0; i < matCSize; i++)
492 ss << "=r,";
493 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
494 ss << "r,";
495 // The final operand is for the sparsity metadata.
496 // The sparsity selector appears as direct literal.
497 ss << "r";
498 return str;
499}
500
501/// Returns the string for the `mma.sp.sync` instruction that corresponds to
502/// the given parameters. Note that this function doesn't do any validation,
503/// it's expected that the provided parameters correspond to a valid
504/// instruction.
505static std::string buildMmaSparseAsmString(
506 const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
507 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
508 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
509 std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
510 auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
511 return NVVM::stringifyMMATypes(ptxType);
512 };
513
514 std::string asmStr;
515 llvm::raw_string_ostream ss(asmStr);
516 ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
517 << shape[2] << ".row.col.";
518
519 if (overflow)
520 ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
521
522 ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
523 << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
524 unsigned asmArgIdx = 0;
525
526 // The operand string is structured into sections `{matC elements...},
527 // {matA elements...}, {matB elements...}, {matC elements}`.
528 for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
529 ss << "{";
530 for (unsigned i = 0; i < arrSize; i++)
531 ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
532 ss << "},";
533 }
534 ss << "$" << asmArgIdx++ << ",";
535 assert(metaDataSelector <= 1);
536 ss << "0x" << metaDataSelector << ";";
537 return asmStr;
538}
539
540/// Builds an inline assembly operation corresponding to the specified MMA
541/// sparse sync operation.
542static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
543 ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
544 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
545 std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
546 ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
547 int64_t metadataSelector, const std::array<int64_t, 3> &shape,
548 Type intrinsicResultType) {
549 auto asmDialectAttr =
550 LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
551
552 const unsigned matASize = unpackedAData.size();
553 const unsigned matBSize = unpackedB.size();
554 const unsigned matCSize = unpackedC.size();
555
556 std::string asmStr = buildMmaSparseAsmString(
557 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
558 ptxTypeD, overflow, metadataSelector);
559 std::string constraintStr =
560 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
561
562 SmallVector<Value> asmVals;
563 asmVals.reserve(N: matASize + matBSize + matCSize + 1);
564 for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
565 llvm::append_range(C&: asmVals, R&: args);
566 asmVals.push_back(Elt: indexData);
567
568 return b.create<LLVM::InlineAsmOp>(
569 /*resultTypes=*/intrinsicResultType,
570 /*operands=*/asmVals,
571 /*asm_string=*/asmStr,
572 /*constraints=*/constraintStr,
573 /*has_side_effects=*/true,
574 /*is_align_stack=*/false, LLVM::TailCallKind::None,
575 /*asm_dialect=*/asmDialectAttr,
576 /*operand_attrs=*/ArrayAttr());
577}
578
579/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
580struct NVGPUMmaSparseSyncLowering
581 : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
582 using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
583
584 LogicalResult
585 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter) const override {
587 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
588 // Get the shapes of the MMAMatrix type being used. The shapes will
589 // choose which intrinsic this op will be lowered to.
590 VectorType aType = op.getMatrixA().getType();
591 VectorType bType = op.getMatrixB().getType();
592 VectorType cType = op.getMatrixC().getType();
593
594 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
595 if (failed(ptxTypeA))
596 return op->emitOpError("failed to deduce operand PTX types");
597 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
598 if (failed(ptxTypeB))
599 return op->emitOpError("failed to deduce operand PTX types");
600 std::optional<NVVM::MMATypes> ptxTypeC =
601 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
602 /*isAccumulator=*/true);
603 if (!ptxTypeC)
604 return op->emitError(
605 "could not infer the PTX type for the accumulator/result");
606
607 // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
608 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
609 if (aType.getElementType().isF32() && !tf32Enabled)
610 return failure();
611
612 // TODO: add an attribute to the op to customize this behavior.
613 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
614 if (isa<IntegerType>(aType.getElementType()))
615 overflow = NVVM::MMAIntOverflow::satfinite;
616
617 SmallVector<Value> matA =
618 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
619 SmallVector<Value> matB =
620 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
621 SmallVector<Value> matC =
622 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
623
624 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
625 Type intrinsicResTy = inferIntrinsicResultType(
626 typeConverter->convertType(op->getResultTypes()[0]));
627
628 // Bitcast the sparse metadata from vector<2xf16> to an i32.
629 Value sparseMetadata = adaptor.getSparseMetadata();
630 if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
631 return op->emitOpError() << "Expected metadata type to be LLVM "
632 "VectorType of 2 i16 elements";
633 sparseMetadata =
634 b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
635
636 FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
637 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
638 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
639 intrinsicResTy);
640 if (failed(intrinsicResult))
641 return failure();
642
643 assert((*intrinsicResult).getNumResults() == 1 &&
644 "expected inline asm op returns a single LLVM struct type");
645 rewriter.replaceOp(
646 op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
647 (*intrinsicResult)->getResult(0), rewriter));
648 return success();
649 }
650};
651
652struct NVGPUAsyncCopyLowering
653 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
654 using ConvertOpToLLVMPattern<
655 nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
656
657 LogicalResult
658 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
659 ConversionPatternRewriter &rewriter) const override {
660 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
661 Location loc = op.getLoc();
662 auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
663 Value dstPtr =
664 getStridedElementPtr(rewriter, b.getLoc(), dstMemrefType,
665 adaptor.getDst(), adaptor.getDstIndices());
666 FailureOr<unsigned> dstAddressSpace =
667 getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
668 if (failed(Result: dstAddressSpace))
669 return rewriter.notifyMatchFailure(
670 arg&: loc, msg: "destination memref address space not convertible to integer");
671
672 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
673 FailureOr<unsigned> srcAddressSpace =
674 getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
675 if (failed(Result: srcAddressSpace))
676 return rewriter.notifyMatchFailure(
677 arg&: loc, msg: "source memref address space not convertible to integer");
678
679 Value scrPtr =
680 getStridedElementPtr(rewriter, loc, srcMemrefType, adaptor.getSrc(),
681 adaptor.getSrcIndices());
682 // Intrinsics takes a global pointer so we need an address space cast.
683 auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
684 op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
685 scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
686 int64_t dstElements = adaptor.getDstElements().getZExtValue();
687 int64_t sizeInBytes =
688 (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
689 // When the optional SrcElements argument is *not* present, the regular
690 // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
691 // memory) to fill DstElements number of elements in the destination
692 // (shared memory).
693 Value srcBytes = adaptor.getSrcElements();
694 if (srcBytes) {
695 // When the optional SrcElements argument is present, the source (global
696 // memory) of CpAsyncOp is read only for SrcElements number of elements.
697 // The rest of the DstElements in the destination (shared memory) are
698 // filled with zeros.
699 Value c3I32 =
700 b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
701 Value bitwidth = b.create<LLVM::ConstantOp>(
702 b.getI32Type(),
703 b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
704 Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
705 srcBytes = b.create<LLVM::LShrOp>(
706 b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
707 }
708 // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
709 // 16 dst bytes.
710 NVVM::LoadCacheModifierKind cacheModifier =
711 (op.getBypassL1().value_or(false) && sizeInBytes == 16)
712 ? NVVM::LoadCacheModifierKind::CG
713 : NVVM::LoadCacheModifierKind::CA;
714
715 b.create<NVVM::CpAsyncOp>(
716 dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
717 NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
718 srcBytes);
719
720 // Drop the result token.
721 Value zero = b.create<LLVM::ConstantOp>(
722 IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
723 rewriter.replaceOp(op, zero);
724 return success();
725 }
726};
727
728struct NVGPUAsyncCreateGroupLowering
729 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
730 using ConvertOpToLLVMPattern<
731 nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
732
733 LogicalResult
734 matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
735 ConversionPatternRewriter &rewriter) const override {
736 rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
737 // Drop the result token.
738 Value zero = rewriter.create<LLVM::ConstantOp>(
739 op->getLoc(), IntegerType::get(op.getContext(), 32),
740 rewriter.getI32IntegerAttr(0));
741 rewriter.replaceOp(op, zero);
742 return success();
743 }
744};
745
746struct NVGPUAsyncWaitLowering
747 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
748 using ConvertOpToLLVMPattern<
749 nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;
750
751 LogicalResult
752 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
753 ConversionPatternRewriter &rewriter) const override {
754 // If numGroup is not present pick 0 as a conservative correct value.
755 int32_t numGroups = adaptor.getNumGroups().value_or(0);
756 rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
757 rewriter.eraseOp(op: op);
758 return success();
759 }
760};
761
762/// Creates mbarrier object in shared memory
763struct NVGPUMBarrierCreateLowering
764 : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
765 using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;
766
767 template <typename moduleT>
768 memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
769 Operation *funcOp, moduleT moduleOp,
770 MemRefType barrierType) const {
771 SymbolTable symbolTable(moduleOp);
772 OpBuilder::InsertionGuard guard(rewriter);
773 rewriter.setInsertionPoint(&moduleOp.front());
774 auto global = rewriter.create<memref::GlobalOp>(
775 funcOp->getLoc(), "__mbarrier",
776 /*sym_visibility=*/rewriter.getStringAttr("private"),
777 /*type=*/barrierType,
778 /*initial_value=*/ElementsAttr(),
779 /*constant=*/false,
780 /*alignment=*/rewriter.getI64IntegerAttr(8));
781 symbolTable.insert(symbol: global);
782 return global;
783 }
784
785 LogicalResult
786 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
787 ConversionPatternRewriter &rewriter) const override {
788 Operation *funcOp = op->getParentOp();
789 MemRefType barrierType = nvgpu::getMBarrierMemrefType(
790 rewriter.getContext(), op.getBarriers().getType());
791
792 memref::GlobalOp global;
793 if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
794 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
795 else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
796 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
797
798 rewriter.setInsertionPoint(op);
799 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
800 global.getName());
801 return success();
802 }
803};
804
805/// Base class for lowering mbarrier operations to nvvm intrinsics.
806template <typename SourceOp>
807struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
808public:
809 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
810 /// Returns the base pointer of the mbarrier object.
811 Value getMbarrierPtr(ImplicitLocOpBuilder &b,
812 nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
813 Value mbarId,
814 ConversionPatternRewriter &rewriter) const {
815 MemRefType mbarrierMemrefType =
816 nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
817 return ConvertToLLVMPattern::getStridedElementPtr(
818 rewriter, b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId});
819 }
820};
821
822struct NVGPUMBarrierGetLowering
823 : public MBarrierBasePattern<nvgpu::MBarrierGetOp> {
824 using MBarrierBasePattern<nvgpu::MBarrierGetOp>::MBarrierBasePattern;
825
826 LogicalResult
827 matchAndRewrite(nvgpu::MBarrierGetOp op, OpAdaptor adaptor,
828 ConversionPatternRewriter &rewriter) const override {
829 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
830 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
831 rewriter.setInsertionPoint(op);
832 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
833 adaptor.getMbarId(), rewriter);
834 Type resType = op.getMbarrierPointer().getType();
835 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resType, barrier);
836 return success();
837 }
838};
839
840/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
841struct NVGPUMBarrierInitLowering
842 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
843 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;
844
845 LogicalResult
846 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
847 ConversionPatternRewriter &rewriter) const override {
848 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
849 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
850 rewriter.setInsertionPoint(op);
851 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
852 adaptor.getMbarId(), rewriter);
853 Value count = truncToI32(b, adaptor.getCount());
854 if (isMbarrierShared(mbarrierType)) {
855 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
856 op, barrier, count, adaptor.getPredicate());
857 } else {
858 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
859 adaptor.getPredicate());
860 }
861 return success();
862 }
863};
864
865/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
866struct NVGPUMBarrierArriveLowering
867 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
868 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
869 LogicalResult
870 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
871 ConversionPatternRewriter &rewriter) const override {
872 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
873 Value barrier =
874 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
875 adaptor.getMbarId(), rewriter);
876 Type tokenType = getTypeConverter()->convertType(
877 nvgpu::MBarrierTokenType::get(op->getContext()));
878 if (isMbarrierShared(op.getBarriers().getType())) {
879 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
880 barrier);
881 } else {
882 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
883 barrier);
884 }
885 return success();
886 }
887};
888
889/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
890/// `nvvm.mbarrier.arrive.nocomplete`
891struct NVGPUMBarrierArriveNoCompleteLowering
892 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
893 using MBarrierBasePattern<
894 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
895 LogicalResult
896 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
897 ConversionPatternRewriter &rewriter) const override {
898 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
899 Value barrier =
900 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
901 adaptor.getMbarId(), rewriter);
902 Type tokenType = getTypeConverter()->convertType(
903 nvgpu::MBarrierTokenType::get(op->getContext()));
904 Value count = truncToI32(b, adaptor.getCount());
905 if (isMbarrierShared(op.getBarriers().getType())) {
906 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
907 op, tokenType, barrier, count);
908 } else {
909 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
910 op, tokenType, barrier, count);
911 }
912 return success();
913 }
914};
915
916/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
917struct NVGPUMBarrierTestWaitLowering
918 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
919 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
920 LogicalResult
921 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
922 ConversionPatternRewriter &rewriter) const override {
923 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
924 Value barrier =
925 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
926 adaptor.getMbarId(), rewriter);
927 Type retType = rewriter.getI1Type();
928 if (isMbarrierShared(op.getBarriers().getType())) {
929 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
930 op, retType, barrier, adaptor.getToken());
931 } else {
932 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
933 op, retType, barrier, adaptor.getToken());
934 }
935 return success();
936 }
937};
938
939struct NVGPUMBarrierArriveExpectTxLowering
940 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
941 using MBarrierBasePattern<
942 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
943 LogicalResult
944 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
945 ConversionPatternRewriter &rewriter) const override {
946 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
947 Value barrier =
948 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
949 adaptor.getMbarId(), rewriter);
950 Value txcount = truncToI32(b, adaptor.getTxcount());
951
952 if (isMbarrierShared(op.getBarriers().getType())) {
953 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
954 op, barrier, txcount, adaptor.getPredicate());
955 return success();
956 }
957
958 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
959 op, barrier, txcount, adaptor.getPredicate());
960 return success();
961 }
962};
963
964struct NVGPUMBarrierTryWaitParityLowering
965 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
966 using MBarrierBasePattern<
967 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
968 LogicalResult
969 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
970 ConversionPatternRewriter &rewriter) const override {
971 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
972 Value barrier =
973 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
974 adaptor.getMbarId(), rewriter);
975 Value ticks = truncToI32(b, adaptor.getTicks());
976 Value phase =
977 b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
978
979 if (isMbarrierShared(op.getBarriers().getType())) {
980 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
981 op, barrier, phase, ticks);
982 return success();
983 }
984
985 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
986 phase, ticks);
987 return success();
988 }
989};
990
991struct NVGPUTmaAsyncLoadOpLowering
992 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
993 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
994 LogicalResult
995 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
996 ConversionPatternRewriter &rewriter) const override {
997 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
998 auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
999 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1000 adaptor.getDst(), {});
1001 Value barrier =
1002 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
1003 adaptor.getMbarId(), rewriter);
1004
1005 SmallVector<Value> coords = adaptor.getCoordinates();
1006 for (auto [index, value] : llvm::enumerate(coords)) {
1007 coords[index] = truncToI32(b, value);
1008 }
1009 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
1010 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
1011 ValueRange{}, adaptor.getMulticastMask(), Value{},
1012 adaptor.getPredicate());
1013 return success();
1014 }
1015};
1016
1017struct NVGPUTmaAsyncStoreOpLowering
1018 : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
1019 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
1020 LogicalResult
1021 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
1022 ConversionPatternRewriter &rewriter) const override {
1023 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1024 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
1025 Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
1026 adaptor.getSrc(), {});
1027 SmallVector<Value> coords = adaptor.getCoordinates();
1028 for (auto [index, value] : llvm::enumerate(coords)) {
1029 coords[index] = truncToI32(b, value);
1030 }
1031
1032 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
1033 op, adaptor.getTensorMapDescriptor(), dest, coords,
1034 adaptor.getPredicate());
1035 return success();
1036 }
1037};
1038
1039struct NVGPUGenerateWarpgroupDescriptorLowering
1040 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
1041 using ConvertOpToLLVMPattern<
1042 nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;
1043
1044 LogicalResult
1045 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
1046 ConversionPatternRewriter &rewriter) const override {
1047
1048 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1049
1050 nvgpu::TensorMapSwizzleKind swizzleKind =
1051 op.getTensorMap().getType().getSwizzle();
1052
1053 unsigned layout =
1054 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
1055 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
1056 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
1057 : 1;
1058 unsigned swizzle =
1059 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
1060 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
1061 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
1062 : 0;
1063
1064 auto ti64 = b.getIntegerType(64);
1065 auto makeConst = [&](uint64_t index) -> Value {
1066 return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
1067 };
1068 auto shiftLeft = [&](Value value, unsigned shift) -> Value {
1069 return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
1070 };
1071 auto shiftRight = [&](Value value, unsigned shift) -> Value {
1072 return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
1073 };
1074 auto insertBit = [&](Value desc, Value val, int startBit) {
1075 return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
1076 };
1077
1078 int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
1079 uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
1080 uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
1081 uint64_t offsetVal = 0;
1082
1083 Value strideDim = makeConst(strideDimVal);
1084 Value leadDim = makeConst(leadDimVal);
1085
1086 Value baseAddr = getStridedElementPtr(
1087 rewriter, op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
1088 adaptor.getTensor(), {});
1089 Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
1090 // Just use 14 bits for base address
1091 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
1092
1093 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
1094 startLeadBit = 16, startBaseAddrBit = 0;
1095 Value dsc = makeConst(0);
1096 // // [62,64) swizzle type
1097 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
1098 // // [49,52) base_offset
1099 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
1100 // // [32,46) stride
1101 dsc = insertBit(dsc, strideDim, startStrideBit);
1102 // // [16,30) leading dimension
1103 dsc = insertBit(dsc, leadDim, startLeadBit);
1104 // // [0,14) start_address
1105 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
1106
1107 LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
1108 << "leading_off:" << leadDimVal << "\t"
1109 << "stride_off :" << strideDimVal << "\t"
1110 << "base_offset:" << offsetVal << "\t"
1111 << "layout_type:" << swizzle << " ("
1112 << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
1113 << ")\n start_addr : " << baseAddr << "\n");
1114
1115 rewriter.replaceOp(op, dsc);
1116 return success();
1117 }
1118};
1119
1120static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
1121 return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
1122 b.getI32IntegerAttr(index));
1123}
1124
1125/// Returns a Value that holds data type enum that is expected by CUDA driver.
1126static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
1127 // Enum is from CUDA driver API
1128 // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
1129 enum CUtensorMapDataTypeEnum {
1130 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
1131 CU_TENSOR_MAP_DATA_TYPE_UINT16,
1132 CU_TENSOR_MAP_DATA_TYPE_UINT32,
1133 CU_TENSOR_MAP_DATA_TYPE_INT32,
1134 CU_TENSOR_MAP_DATA_TYPE_UINT64,
1135 CU_TENSOR_MAP_DATA_TYPE_INT64,
1136 CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
1137 CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
1138 CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
1139 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
1140 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
1141 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
1142 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
1143 };
1144
1145 if (type.isUnsignedInteger(width: 8))
1146 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_UINT8);
1147 if (type.isUnsignedInteger(width: 16))
1148 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_UINT16);
1149 if (type.isUnsignedInteger(width: 32))
1150 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_UINT32);
1151 if (type.isUnsignedInteger(width: 64))
1152 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_UINT64);
1153 if (type.isSignlessInteger(width: 32))
1154 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_INT32);
1155 if (type.isSignlessInteger(width: 64))
1156 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_INT64);
1157 if (type.isF16())
1158 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
1159 if (type.isF32())
1160 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
1161 if (type.isF64())
1162 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
1163 if (type.isBF16())
1164 return makeI64Const(b, index: CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
1165
1166 llvm_unreachable("Not supported data type");
1167}
1168
1169struct NVGPUTmaCreateDescriptorOpLowering
1170 : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
1171 using ConvertOpToLLVMPattern<
1172 nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
1173 LogicalResult
1174 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
1175 ConversionPatternRewriter &rewriter) const override {
1176 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1177 auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
1178 Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
1179
1180 Value tensorElementType =
1181 elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
1182 auto promotedOperands = getTypeConverter()->promoteOperands(
1183 b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
1184
1185 Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
1186 makeI64Const(b, 5));
1187 for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
1188 Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
1189 boxArrayPtr, makeI64Const(b, index));
1190 b.create<LLVM::StoreOp>(value, gep);
1191 }
1192
1193 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
1194 // Set Arguments for the function call
1195 SmallVector<Value> arguments;
1196 arguments.push_back(Elt: promotedOperands[0]); // rank
1197 arguments.push_back(Elt: promotedOperands[1]); // descriptor
1198 arguments.push_back(Elt: tensorElementType); // data type
1199 arguments.push_back(
1200 Elt: makeI64Const(b, index: (int)desc.getInterleave())); // interleave
1201 arguments.push_back(Elt: makeI64Const(b, index: (int)desc.getSwizzle())); // swizzle
1202 arguments.push_back(Elt: makeI64Const(b, index: (int)desc.getL2promo())); // l2promo
1203 arguments.push_back(Elt: makeI64Const(b, index: (int)desc.getOob())); // oob
1204 arguments.push_back(Elt: boxArrayPtr); // box dimensions
1205
1206 // Set data types of the arguments
1207 SmallVector<Type> argTypes = {
1208 llvmInt64Type, /* int64_t tensorRank */
1209 llvmPointerType, /* ptr */
1210 llvmInt64Type, /* int64_t */
1211 llvmInt64Type, /* int64_t */
1212 llvmInt64Type, /* int64_t */
1213 llvmInt64Type, /* int64_t */
1214 llvmInt64Type, /* int64_t */
1215 llvmPointerType /* ptr */
1216 };
1217 FunctionCallBuilder hostRegisterCallBuilder = {
1218 "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
1219 Value tensorMap =
1220 hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
1221
1222 rewriter.replaceOp(op, tensorMap);
1223 return success();
1224 }
1225};
1226
1227struct NVGPUWarpgroupMmaOpLowering
1228 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
1229 using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
1230
1231 /// This is a helper class to generate required NVVM Ops for warp-group level
1232 /// matrix multiplication.
1233 /// When the given GEMM shape is larger than the shape of
1234 /// a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
1235 /// Op(s), group and execute them asynchronously. The class also handles
1236 /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
1237 /// create descriptors for each instruction.
1238 ///
1239 /// For example this is the case when the shape of GEMM is 128x128x128
1240 ///
1241 /// nvvm.wgmma.fence.aligned
1242 ///
1243 /// nvvm.wgmma.mma.async descA, descB
1244 /// iterate(descA, descB)
1245 /// nvvm.wgmma.mma.async descA, descB
1246 /// [6x times more]
1247 ///
1248 /// nvvm.wgmma.group.sync.aligned
1249 /// nvvm.wgmma.wait.group.sync [groupId]
1250 ///
1251 class WarpgroupGemm {
1252 nvgpu::WarpgroupMmaOp op;
1253 ImplicitLocOpBuilder b;
1254 OpAdaptor adaptor;
1255
1256 // Entire shape of the given Op
1257 int64_t totalM, totalN, totalK;
1258
1259 // Shape of one wgmma instruction
1260 int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
1261
1262 // Iteration counts for GEMM
1263 int iterationM = 0, iterationN = 0, iterationK = 0;
1264
1265 /// The function returns the shape of wgmma instruction that is defined in
1266 /// PTX programming guide.
1267 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
1268 void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
1269 wgmmaM = 64;
1270 wgmmaN = sizeN;
1271 if (inputElemType.isTF32()) {
1272 wgmmaK = 8;
1273 } else if (inputElemType.isF16() || inputElemType.isBF16()) {
1274 wgmmaK = 16;
1275 } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1276 inputElemType.isInteger(16)) {
1277 wgmmaK = 32;
1278 } else if (inputElemType.isInteger(width: 1)) {
1279 wgmmaK = 256;
1280 } else {
1281 llvm_unreachable("msg: not supported K shape");
1282 }
1283 LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
1284 << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
1285 }
1286
1287 /// Generates WGMMATypesAttr from MLIR Type
1288 NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1289 bool useF32 = false) const {
1290 auto getWgmmaType = [=](Type elemType) {
1291 if (elemType.isF32() || elemType.isTF32())
1292 return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
1293 if (elemType.isF16())
1294 return NVVM::WGMMATypes::f16;
1295 if (elemType.isBF16())
1296 return NVVM::WGMMATypes::bf16;
1297 if (isa<Float8E4M3FNType>(elemType))
1298 return NVVM::WGMMATypes::e4m3;
1299 if (isa<Float8E5M2Type>(elemType))
1300 return NVVM::WGMMATypes::e5m2;
1301 if (elemType.isInteger(1))
1302 return NVVM::WGMMATypes::b1;
1303 if (elemType.isInteger(8))
1304 return NVVM::WGMMATypes::s8;
1305 if (elemType.isUnsignedInteger(8))
1306 return NVVM::WGMMATypes::u8;
1307 if (elemType.isInteger(32))
1308 return NVVM::WGMMATypes::s32;
1309 llvm_unreachable("unsupported type");
1310 };
1311 return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
1312 }
1313
1314 /// Generates layout attribute for the input matrix for wgmma instruction
1315 NVVM::MMALayoutAttr
1316 generateWgmmaLayout(std::optional<bool> transpose) const {
1317 if (transpose.value_or(false))
1318 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
1319 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
1320 }
1321
1322 /// Generates shape attribute for wgmma instruction
1323 NVVM::MMAShapeAttr generateWgmmaShape() const {
1324 return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
1325 }
1326
1327 /// Generates scale attributes of output matrix for wgmma instruction
1328 NVVM::WGMMAScaleOutAttr generateScaleOut() const {
1329 return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
1330 NVVM::WGMMAScaleOut::one);
1331 }
1332 /// Generates scale attributes of input matrix for wgmma instruction
1333 NVVM::WGMMAScaleInAttr generateScaleIn() const {
1334 return NVVM::WGMMAScaleInAttr::get(op->getContext(),
1335 NVVM::WGMMAScaleIn::one);
1336 }
1337
1338 /// Basic function to generate Add
1339 Value makeAdd(Value lhs, Value rhs) {
1340 return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
1341 };
1342
1343 /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
1344 /// Currently, it only handles row-major.
1345 ///
1346 /// It moves the pointer like below for [128][64] size:
1347 /// +2 +4 +6
1348 /// ↓ ↓ ↓
1349 /// descA ---> +--+--+--+--+
1350 /// |->|->|->|->|
1351 /// | | | | |
1352 /// | | | | |
1353 /// | | | | |
1354 /// descA+512---> +-----------+
1355 /// | | | | |
1356 /// | | | | |
1357 /// | | | | |
1358 /// | | | | |
1359 /// +-----------+
1360 ///
1361 Value iterateDescriptorA(Value desc, int i, int j, int k) {
1362 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
1363 Type elemA = matrixTypeA.getElementType();
1364 int byte = elemA.getIntOrFloatBitWidth() / 8;
1365 int tileShapeA = matrixTypeA.getDimSize(1);
1366 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
1367 incrementVal = incrementVal >> exclude4LSB;
1368 LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
1369 << "] [wgmma descriptors] Descriptor A + "
1370 << incrementVal << " | \t ");
1371 if (!incrementVal)
1372 return desc;
1373 return makeAdd(lhs: desc, rhs: makeI64Const(b, index: incrementVal));
1374 }
1375
1376 /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
1377 /// Currently, it only handles column-major.
1378 ///
1379 /// It moves the pointer like below for [128][64] size:
1380 /// descB ---> +--+--+--+--+--+--+--+--+
1381 /// |↓ | | | | | | | |
1382 /// |↓ | | | | | | | |
1383 /// |↓ | | | | | | | |
1384 /// |↓ | | | | | | | |
1385 /// +--+--+--+--+--+--+--+--+
1386 ///
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Generates a WgmmaMmaAsyncOp using the provided GMMA matrix descriptors
    /// and arranges them based on the induction variables i, j, and k.
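    /// For induction variables (i, j, k), the emitted op multiplies the A tile
    /// covering rows [i * wgmmaM, (i + 1) * wgmmaM) and the K slice
    /// [k * wgmmaK, (k + 1) * wgmmaK) with the corresponding B tile, and
    /// accumulates into `matrixC`.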
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LLVM_DEBUG(DBGS() << "\t wgmma."
                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                        << "(A[" << (i * wgmmaM) << ":"
                        << (i * wgmmaM) + wgmmaM << "]["
                        << (k * wgmmaK) << ":" << (k * wgmmaK + wgmmaK)
                        << "] * B[" << (k * wgmmaK) << ":"
                        << (k * wgmmaK + wgmmaK) << "][" << 0 << ":"
                        << wgmmaN << "])\n");

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return b.create<NVVM::WgmmaMmaAsyncOp>(
          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }

    /// Generates multiple wgmma instructions to complete the given GEMM shape.
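    /// For illustration (assuming findWgmmaShape picked m64n128k16 for a
    /// 128x128x64 f16 GEMM): iterationM = 2, iterationN = 1 and iterationK = 4,
    /// so 2 * 1 * 4 = 8 WgmmaMmaAsyncOps are emitted, and each row of wgmma
    /// tiles accumulates into its own slot of the result struct.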
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          b.create<LLVM::PoisonOp>(adaptor.getMatrixC().getType());

      // Perform the GEMM.
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
                                                    wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }

  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM shape.
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");

      // Find the shape of a single wgmma instruction.
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts needed to cover the given shape with the wgmma shape.
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }

    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
    /// emits a fence (WgmmaFenceAlignedOp) before the wgmma instructions,
    /// commits them as a group (WgmmaGroupSyncAlignedOp) afterwards, and then
    /// waits for the group to complete (WgmmaWaitGroupSyncOp).
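    /// The emitted sequence is therefore, roughly:
    ///   WgmmaFenceAlignedOp
    ///   iterationM * iterationN * iterationK x WgmmaMmaAsyncOp
    ///   WgmmaGroupSyncAlignedOp
    ///   WgmmaWaitGroupSyncOp(op.getWaitGroup())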
    Value generateWarpgroupMma() {
      b.create<NVVM::WgmmaFenceAlignedOp>();
      Value wgmmaResult = generateWgmmaGroup();
      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper class that knows how to split the GEMM.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate the wgmma instructions covering the entire GEMM shape.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op with the fragmented result struct.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};

struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// This function stores a fragmented register matrix owned by a warp group
  /// (128 threads) into a memref. Each thread holds 64 registers, packed
  /// together as one struct value.
  /// Below is what each thread (T) holds; each `d` is one element of that
  /// struct, identified by its index.
  ///
  /// Threads in the warp group (128 threads) and what they own in MatrixD:
  ///   0-31   Warp-0 -> MatrixD[0:15] [0:N]
  ///   32-63  Warp-1 -> MatrixD[16:31][0:N]
  ///   64-95  Warp-2 -> MatrixD[32:47][0:N]
  ///   96-127 Warp-3 -> MatrixD[48:63][0:N]
  ///
  /// Matrix-D:
  ///   +______________________________________________________________________+
  ///   |   0-1    |  2-3    |  4-5    |  6-7    |  8-9   |  10-11 |..|N-8,N-7 |
  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  ///   +______________________________________________________________________+
  ///
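  /// For example, thread 0 (warp 0, lane 0) computes ti = 0 and tj = 0 below,
  /// and stores d0,d1 to MatrixD[0][0:1], d2,d3 to MatrixD[8][0:1], d4,d5 to
  /// MatrixD[0][8:9], d6,d7 to MatrixD[8][8:9], and so on every 8 columns.
  ///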
  /// \param b: The builder (carrying the op's location) used to create the ops.
  /// \param matrixD: Result of the warp-group MMA operation (fragmented
  /// matrix). Each thread holds it as a struct with 64 elements.
  /// \param dstMemref: The memref where the registers will be stored.
  /// \param offset: The row offset within the memref where the registers will
  /// be stored.
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = b.create<arith::IndexCastOp>(it, x);
      Value idy0 = b.create<arith::IndexCastOp>(it, y);
      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of adjacent 32-bit registers owned per thread.
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below another per warp.
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));
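    // For example, a 64-element accumulator struct gives storeCount =
    // 64 / (2 * 2) = 16, so the loops below issue 2 * 16 extract-and-store
    // steps of 2 values each, covering all 64 registers.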

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};

struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType =
        cast<LLVM::LLVMStructType>(packStructType.getBody().front())
            .getBody()
            .front();
    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
    Value packStruct = b.create<LLVM::PoisonOp>(packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the inner structs and set all of their values to zero.
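    // For illustration (assuming the accumulator converts to a struct of two
    // 64 x f32 inner structs): each inner struct is extracted, every one of
    // its 64 fields is overwritten with the f32 zero constant above, and the
    // zero-filled inner structs are re-packed into the outer struct below.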
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = b.create<LLVM::InsertValueOp>(
            structType, structValue, zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs into a single struct.
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
                                                 packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};

struct NVGPUTmaFenceOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaFenceOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaFenceOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaFenceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = op.getContext();
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i32Ty = b.getI32Type();
    Value tensormapSize =
        b.create<LLVM::ConstantOp>(i32Ty, rewriter.getI32IntegerAttr(128));
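    // A TMA tensor map (CUtensorMap) descriptor is 128 bytes, so the proxy
    // fence below is issued over the full descriptor size.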

    auto memscope =
        NVVM::MemScopeKindAttr::get(ctx, ::mlir::NVVM::MemScopeKind::SYS);

    rewriter.replaceOpWithNewOp<NVVM::FenceProxyAcquireOp>(
        op, memscope, adaptor.getTensorMapDescriptor(), tensormapSize);

    return success();
  }
};

struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
    return success();
  }
};

struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i64Ty = b.getI64Type();
    auto f32Ty = b.getF32Type();
    VectorType inTy = op.getIn().getType();
    // Apply rcp.approx.ftz.f32 to each element of the vector.
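    // For example (illustrative), a `vector<4xf32>` operand is lowered to four
    // extractelement / NVVM::RcpApproxFtzF32Op / insertelement triples;
    // higher-rank vectors are flattened to 1-D vectors further below and
    // handled the same way.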
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = b.create<LLVM::PoisonOp>(llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
        Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
        Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
        ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
      }
      return ret1DVec;
    };
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
      return success();
    }
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
        },
        rewriter);
  }
};
} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<
      NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
      NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
      NVGPUMBarrierGetLowering,              // nvgpu.mbarrier.get
      NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
      NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
      NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
      NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
      NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
      NVGPUTmaAsyncStoreOpLowering,          // nvgpu.tma.async.store
      NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
      NVGPUTmaPrefetchOpLowering,            // nvgpu.tma.prefetch.descriptor
      NVGPUTmaFenceOpLowering,               // nvgpu.tma.fence.descriptor
      NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
      NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
      NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
      NVGPUWarpgroupMmaStoreOpLowering,         // nvgpu.warpgroup.mma.store
      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}