1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18
19#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
21#include "mlir/Dialect/GPU/IR/GPUDialect.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinAttributes.h"
24#include "mlir/IR/BuiltinTypes.h"
25#include "mlir/IR/Diagnostics.h"
26#include "mlir/IR/DialectImplementation.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
29#include "mlir/IR/OperationSupport.h"
30#include "mlir/IR/Types.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/AsmParser/Parser.h"
34#include "llvm/IR/Attributes.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
37#include "llvm/IR/IntrinsicsNVPTX.h"
38#include "llvm/IR/Type.h"
39#include "llvm/Support/Casting.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/SourceMgr.h"
42#include "llvm/Support/raw_ostream.h"
43#include <cassert>
44#include <optional>
45#include <string>
46
47using namespace mlir;
48using namespace NVVM;
49
50#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
51#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
52
53//===----------------------------------------------------------------------===//
54// Verifier methods
55//===----------------------------------------------------------------------===//
56
57// This verifier is shared among the following Ops:
58// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
59// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
60// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
61static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
62 bool isIm2Col,
63 size_t numIm2ColOffsets,
64 Location loc) {
65 if (tensorDims < 1 || tensorDims > 5)
66 return emitError(loc, message: "expects coordinates between 1 to 5 dimension");
67
68 // For Im2Col mode, there are two constraints:
69 if (isIm2Col) {
70 // 1. Tensor must always be at least 3-d.
71 if (tensorDims < 3)
72 return emitError(
73 loc,
74 message: "to use im2col mode, the tensor has to be at least 3-dimensional");
75 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
76 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
77 return emitError(
78 loc, message: "im2col offsets must be 2 less than number of coordinates");
79 }
80 return success();
81}
82
83LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
84 size_t numIm2ColOffsets = getIm2colOffsets().size();
85 bool isIm2Col = numIm2ColOffsets > 0;
86 return cpAsyncBulkTensorCommonVerifier(tensorDims: getCoordinates().size(), isIm2Col,
87 numIm2ColOffsets, loc: getLoc());
88}
89
90LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
91 if (getCoordinates().size() > 5)
92 return emitError(message: "Maximum 5 coordinates and dimension is supported.");
93 return success();
94}
95
96LogicalResult CpAsyncOp::verify() {
97 if (getModifier() != LoadCacheModifierKind::CG &&
98 getModifier() != LoadCacheModifierKind::CA)
99 return emitError(message: "Only CG and CA cache modifiers are supported.");
100 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
101 return emitError(message: "expected byte size to be either 4, 8 or 16.");
102 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
103 return emitError(message: "CG cache modifier is only support for 16 bytes copy.");
104 return success();
105}
106
107LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
108 size_t numIm2ColOffsets = getIm2colOffsets().size();
109 bool isIm2Col = numIm2ColOffsets > 0;
110 return cpAsyncBulkTensorCommonVerifier(tensorDims: getCoordinates().size(), isIm2Col,
111 numIm2ColOffsets, loc: getLoc());
112}
113
114LogicalResult CpAsyncBulkTensorReduceOp::verify() {
115 bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
116 return cpAsyncBulkTensorCommonVerifier(tensorDims: getCoordinates().size(), isIm2Col, numIm2ColOffsets: 0,
117 loc: getLoc());
118}
119
120LogicalResult ConvertFloatToTF32Op::verify() {
121 using RndMode = NVVM::FPRoundingMode;
122 switch (getRnd()) {
123 case RndMode::RNA:
124 if (getRelu())
125 return emitError(message: "Relu not supported with rna rounding mode.");
126 break;
127 case RndMode::RN:
128 case RndMode::RZ:
129 break;
130 default:
131 return emitError(
132 message: "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
133 }
134 return success();
135}
136
137LogicalResult ConvertF32x2ToF8x2Op::verify() {
138 using RndMode = NVVM::FPRoundingMode;
139 using SatMode = NVVM::SaturationMode;
140
141 bool isRoundingModeRN = getRnd() == RndMode::RN;
142 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
143 bool isRoundingModeRP = getRnd() == RndMode::RP;
144 bool isSatFinite = getSat() == SatMode::SATFINITE;
145
146 bool hasRelu = getRelu();
147
148 switch (getType()) {
149 case ConvertFP8Type::E4M3:
150 case ConvertFP8Type::E5M2:
151 if (!isRoundingModeRN)
152 return emitOpError(message: "Only RN rounding mode is supported for conversions "
153 "from f32x2 to .e4m3x2 or .e5m2x2 types");
154 if (!isSatFinite)
155 return emitOpError(message: "Only SATFINITE saturation mode is supported for "
156 "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
157 break;
158 case ConvertFP8Type::UE8M0:
159 if (!(isRoundingModeRZ || isRoundingModeRP))
160 return emitOpError(message: "Only RZ or RP rounding modes are supported for "
161 "conversions from f32x2 to .ue8m0x2 type");
162 if (hasRelu)
163 return emitOpError(message: "relu not supported for conversions to .ue8m0x2 type");
164 break;
165 }
166 return success();
167}
168
169LogicalResult ConvertF16x2ToF8x2Op::verify() {
170 if (getType() == ConvertFP8Type::UE8M0)
171 return emitOpError(message: "Only .e4m3 or .e5m2 types are supported for "
172 "conversions from f16x2 to f8x2.");
173
174 return success();
175}
176
177LogicalResult ConvertBF16x2ToF8x2Op::verify() {
178 using RndMode = NVVM::FPRoundingMode;
179
180 if (getType() != ConvertFP8Type::UE8M0)
181 return emitOpError(
182 message: "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
183
184 auto rnd = getRnd();
185 if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
186 return emitOpError(message: "Only RZ and RP rounding modes are supported for "
187 "conversions from bf16x2 to f8x2.");
188
189 return success();
190}
191
192LogicalResult BulkStoreOp::verify() {
193 if (getInitVal() != 0)
194 return emitOpError(message: "only 0 is supported for initVal, got ") << getInitVal();
195 return success();
196}
197
198// Given the element type of an operand and whether or not it is an accumulator,
199// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
200// operand's element type.
201std::optional<mlir::NVVM::MMATypes>
202MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
203 auto half2Type =
204 VectorType::get(shape: 2, elementType: Float16Type::get(context: operandElType.getContext()));
205 if (operandElType.isF64())
206 return NVVM::MMATypes::f64;
207 if (operandElType.isF16() || operandElType == half2Type)
208 return NVVM::MMATypes::f16;
209 if (operandElType.isF32() && isAccumulator)
210 return NVVM::MMATypes::f32;
211 if (operandElType.isF32() && !isAccumulator)
212 return NVVM::MMATypes::tf32;
213 if (llvm::isa<IntegerType>(Val: operandElType)) {
214 if (isAccumulator)
215 return NVVM::MMATypes::s32;
216 return std::nullopt;
217 }
218
219 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(Val&: operandElType)) {
220 if (structType.getBody().empty())
221 return std::nullopt;
222 return inferOperandMMAType(operandElType: structType.getBody()[0], isAccumulator);
223 }
224
225 return std::nullopt;
226}
227
228static bool isInt4PtxType(MMATypes type) {
229 return (type == MMATypes::u4 || type == MMATypes::s4);
230}
231
232static bool isInt8PtxType(MMATypes type) {
233 return (type == MMATypes::u8 || type == MMATypes::s8);
234}
235
236static bool isIntegerPtxType(MMATypes type) {
237 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
238 type == MMATypes::s32;
239}
240
241MMATypes MmaOp::accumPtxType() {
242 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
243 operandElType: getODSOperands(index: 2).getTypes().front(), /*isAccumulator=*/true);
244 assert(val.has_value() && "accumulator PTX type should always be inferrable");
245 return val.value();
246}
247
248MMATypes MmaOp::resultPtxType() {
249 std::optional<mlir::NVVM::MMATypes> val =
250 inferOperandMMAType(operandElType: getResult().getType(), /*isAccumulator=*/true);
251 assert(val.has_value() && "result PTX type should always be inferrable");
252 return val.value();
253}
254
// Prints the custom assembly form: `A[...] B[...] C[...] attr-dict : (types)
// -> result-type`, eliding attributes that the parser can re-infer.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Groups one operand segment (A, B or C) with the name of the attribute
  // that records its PTX element type ("" when there is no such attribute).
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  // Attributes implied by the custom syntax; never printed in the attr-dict.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  // Collect the registers of each segment, and elide a fragment's ptx-type
  // attribute when the parser can re-infer it from the operand types.
  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(index: fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(Elt: this->getOperand(i: operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(Elt: this->getOperand(i: operandIdx).getType());
      }
    }
    // NOTE(review): only operand index 0 (the first A register) ever appends
    // to regTypes, so regTypes.back() is the A-operand type for every
    // fragment here — confirm this elision behavior is intentional.
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(operandElType: regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(Elt: frag.ptxTypeAttr);
  }

  // Prints one segment as ` <name>[<reg>, ...] `.
  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(container: frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(attrs: this->getOperation()->getAttrs(), elidedAttrs: ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : "
    << "(";
  llvm::interleaveComma(c: SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
                                              frags[2].regs[0].getType()},
                        os&: p);
  p << ")";
  p.printArrowTypeList(types: TypeRange{this->getRes().getType()});
}
312
// Builder for MmaOp. The (m, n, k) `shape` is mandatory; PTX element types
// and layouts for the multiplicands may be supplied explicitly or are
// defaulted (types inferred from the first register of each segment, layouts
// defaulting to row-major A / col-major B). `b1Op` and `intOverflow` are
// optional attributes for the binary and integer variants respectively.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  // The (m, n, k) GEMM shape is stored as a single MMAShapeAttr.
  result.addAttribute(
      name: "shape", attr: builder.getAttr<MMAShapeAttr>(args: shape[0], args: shape[1], args: shape[2]));

  // Operands are appended in segment order: A, then B, then C.
  result.addOperands(newOperands: operandA);
  result.addOperands(newOperands: operandB);
  result.addOperands(newOperands: operandC);

  // Multiplicand PTX types: caller-provided, or inferred from the first
  // register of each segment (inference may fail and add no attribute).
  if (multiplicandPtxTypes) {
    result.addAttribute(name: "multiplicandAPtxType",
                        attr: MMATypesAttr::get(context: ctx, value: (*multiplicandPtxTypes)[0]));
    result.addAttribute(name: "multiplicandBPtxType",
                        attr: MMATypesAttr::get(context: ctx, value: (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandElType: operandA[0].getType(), isAccumulator: false))
      result.addAttribute(name: "multiplicandAPtxType", attr: MMATypesAttr::get(context: ctx, value: *res));
    if (auto res = inferOperandMMAType(operandElType: operandB[0].getType(), isAccumulator: false))
      result.addAttribute(name: "multiplicandBPtxType", attr: MMATypesAttr::get(context: ctx, value: *res));
  }

  // Layouts: caller-provided, or the row-A / col-B default pairing.
  if (multiplicandLayouts) {
    result.addAttribute(name: "layoutA",
                        attr: MMALayoutAttr::get(context: ctx, value: (*multiplicandLayouts)[0]));
    result.addAttribute(name: "layoutB",
                        attr: MMALayoutAttr::get(context: ctx, value: (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute(name: "layoutA", attr: MMALayoutAttr::get(context: ctx, value: MMALayout::row));
    result.addAttribute(name: "layoutB", attr: MMALayoutAttr::get(context: ctx, value: MMALayout::col));
  }

  // Optional attributes for integer-overflow behavior and the b1 operation.
  if (intOverflow.has_value())
    result.addAttribute(name: "intOverflowBehavior",
                        attr: MMAIntOverflowAttr::get(context: ctx, value: *intOverflow));
  if (b1Op.has_value())
    result.addAttribute(name: "b1Op", attr: MMAB1OpAttr::get(context: ctx, value: *b1Op));

  result.addTypes(newTypes: resultType);
  // Record the size of each variadic operand segment (A, B, C).
  result.addAttribute(
      name: MmaOp::getOperandSegmentSizeAttr(),
      attr: builder.getDenseI32ArrayAttr(values: {static_cast<int32_t>(operandA.size()),
                                      static_cast<int32_t>(operandB.size()),
                                      static_cast<int32_t>(operandC.size())}));
}
364
365// <operation> :=
366// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
367// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
368// `->` type($res)
369ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
370 struct OperandFragment {
371 std::optional<MMATypes> elemtype;
372 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
373 SmallVector<Type> regTypes;
374 };
375
376 Builder &builder = parser.getBuilder();
377 std::array<OperandFragment, 4> frags;
378
379 NamedAttrList namedAttributes;
380
381 // A helper to parse the operand segments.
382 auto parseMmaOperand = [&](StringRef operandName,
383 OperandFragment &frag) -> LogicalResult {
384 if (parser.parseKeyword(keyword: operandName).failed())
385 return failure();
386 if (parser
387 .parseOperandList(result&: frag.regs, delimiter: OpAsmParser::Delimiter::OptionalSquare)
388 .failed())
389 return failure();
390 return success();
391 };
392
393 // Parse the operand segments.
394 if (parseMmaOperand("A", frags[0]).failed())
395 return failure();
396 if (parseMmaOperand("B", frags[1]).failed())
397 return failure();
398 if (parseMmaOperand("C", frags[2]).failed())
399 return failure();
400
401 if (parser.parseOptionalAttrDict(result&: namedAttributes).failed())
402 return failure();
403
404 // Parse the type specification and resolve operands.
405 SmallVector<Type, 3> operandTypes;
406 if (failed(Result: parser.parseColon()))
407 return failure();
408 if (failed(Result: parser.parseLParen()))
409 return failure();
410 if (failed(Result: parser.parseTypeList(result&: operandTypes)))
411 return failure();
412 if (failed(Result: parser.parseRParen()))
413 if (operandTypes.size() != 3)
414 return parser.emitError(
415 loc: parser.getNameLoc(),
416 message: "expected one type for each operand segment but got " +
417 Twine(operandTypes.size()) + " types");
418 for (const auto &iter : llvm::enumerate(First&: operandTypes)) {
419 auto &frag = frags[iter.index()];
420 frag.regTypes.resize(N: frag.regs.size(), NV: iter.value());
421 if (failed(Result: parser.resolveOperands(operands&: frag.regs, types&: frag.regTypes,
422 loc: parser.getNameLoc(), result&: result.operands)))
423 return failure();
424 frag.elemtype = inferOperandMMAType(operandElType: frag.regTypes[0],
425 /*isAccumulator*/ iter.index() < 2);
426 }
427
428 Type resultType;
429 if (parser.parseArrow() || parser.parseType(result&: resultType))
430 return failure();
431 frags[3].elemtype = inferOperandMMAType(operandElType: resultType, /*isAccumulator*/ true);
432
433 std::array<StringRef, 2> names{"multiplicandAPtxType",
434 "multiplicandBPtxType"};
435 for (unsigned idx = 0; idx < names.size(); idx++) {
436 const auto &frag = frags[idx];
437 std::optional<NamedAttribute> attr = namedAttributes.getNamed(name: names[idx]);
438 if (!frag.elemtype.has_value() && !attr.has_value()) {
439 return parser.emitError(
440 loc: parser.getNameLoc(),
441 message: "attribute " + names[idx] +
442 " is not provided explicitly and cannot be inferred");
443 }
444 if (!attr.has_value())
445 result.addAttribute(
446 name: names[idx], attr: MMATypesAttr::get(context: parser.getContext(), value: *frag.elemtype));
447 }
448
449 result.addTypes(newTypes: resultType);
450 if (!namedAttributes.empty())
451 result.addAttributes(newAttributes: namedAttributes);
452 result.addAttribute(name: MmaOp::getOperandSegmentSizeAttr(),
453 attr: builder.getDenseI32ArrayAttr(values: {
454 static_cast<int32_t>(frags[0].regs.size()),
455 static_cast<int32_t>(frags[1].regs.size()),
456 static_cast<int32_t>(frags[2].regs.size()),
457 }));
458 return success();
459}
460
// Verifies an MmaOp by building the set of legal (shape, operand-type,
// result-type) combinations for the declared multiplicand PTX type and then
// checking the op's actual operands/result against those tables.
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  // Frequently used scalar/vector/struct types for the allowed-type tables.
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, width: 32);
  auto f16x2Ty = VectorType::get(shape: 2, elementType: f16Ty);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, types: {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, types: {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, types: SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, types: {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, types: {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, types: {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(Elt: LLVM::LLVMStructType::getLiteral(
          context, types: {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(Elt: LLVM::LLVMStructType::getLiteral(
          context, types: {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(Elt: f16x2x2StructTy);
      expectedResult.push_back(Elt: f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError(message: "invalid shape or multiplicand type: " +
                       stringifyEnum(enumValue: getMultiplicandAPtxType().value()));
    }

    // Integer variants accumulate into s32 and pack multiplicands into i32
    // registers; the fragment type set above is overridden here.
    if (isIntegerPtxType(type: getMultiplicandAPtxType().value())) {
      expectedResult.push_back(Elt: s32x4StructTy);
      expectedC.emplace_back(Args: 4, Args&: i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(Args: 2, Args&: f16x2Ty);
      expectedC.emplace_back(Args: 4, Args&: f32Ty);
    }

    // Number of fragment registers per multiplicand = number of 8xk tiles.
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(Args&: unitA, Args&: multiplicandFragType);
    expectedB.emplace_back(Args&: unitB, Args&: multiplicandFragType);
    allowedShapes.push_back(Elt: {16, 8, kFactor});
    allowedShapes.push_back(Elt: {16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(Args: 2, Args&: f16x2Ty);
      expectedB.emplace_back(Args: 2, Args&: f16x2Ty);
      expectedResult.push_back(Elt: f16x2x4StructTy);
      expectedResult.push_back(Elt: f32x8StructTy);
      expectedC.emplace_back(Args: 4, Args&: f16x2Ty);
      expectedC.emplace_back(Args: 8, Args&: f32Ty);
      allowedShapes.push_back(Elt: {8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(Args: 1, Args&: f64Ty);
      expectedB.emplace_back(Args: 1, Args&: f64Ty);
      expectedC.emplace_back(Args: 2, Args&: f64Ty);
      expectedResult.emplace_back(Args: LLVM::LLVMStructType::getLiteral(
          context, types: SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back(Elt: {8, 8, 4});
    }
    if (isIntegerPtxType(type: getMultiplicandAPtxType().value())) {
      expectedA.push_back(Elt: {i32Ty});
      expectedB.push_back(Elt: {i32Ty});
      expectedC.push_back(Elt: {i32Ty, i32Ty});
      expectedResult.push_back(Elt: s32x2StructTy);
      // The allowed K depends on the integer width (s4/u4, s8/u8, b1).
      if (isInt4PtxType(type: getMultiplicandAPtxType().value()))
        allowedShapes.push_back(Elt: {8, 8, 32});
      if (isInt8PtxType(type: getMultiplicandAPtxType().value()))
        allowedShapes.push_back(Elt: {8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back(Elt: {8, 8, 128});
    }
  }

  // Diagnostics are accumulated into a string before being emitted.
  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(Range&: allowedShapes, Element: mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(c: mmaShape, os&: errorStream);
    errorStream << ">";
    return emitOpError(message: errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           First: SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(index: iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    // The segment's actual type list must equal one of the allowed lists.
    bool match = llvm::is_contained(Range&: iter.value(), Element: operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(c: operandTySeg, os&: errorStream);
      return emitOpError(message: errorMessage);
    }
  }

  // Check the result type
  if (!llvm::any_of(Range&: expectedResult, P: [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(c: expectedResult, os&: errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(message: errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError(message: "op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(type: *getMultiplicandAPtxType()) ||
      isInt8PtxType(type: *getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError(message: "op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
649
650LogicalResult ShflOp::verify() {
651 if (!(*this)->getAttrOfType<UnitAttr>(name: "return_value_and_is_valid"))
652 return success();
653 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(Val: getType());
654 auto elementType = (type && type.getBody().size() == 2)
655 ? llvm::dyn_cast<IntegerType>(Val: type.getBody()[1])
656 : nullptr;
657 if (!elementType || elementType.getWidth() != 1)
658 return emitError(message: "expected return type to be a two-element struct with "
659 "i1 as the second element");
660 return success();
661}
662
663std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
664 NVVM::MMAFrag frag, int nRow,
665 int nCol,
666 MLIRContext *context) {
667 unsigned numberElements = 0;
668 Type elementType;
669 OpBuilder builder(context);
670 Type f16x2 = VectorType::get(shape: 2, elementType: builder.getF16Type());
671 if (type == NVVM::MMATypes::f16) {
672 elementType = f16x2;
673 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
674 numberElements = 8;
675 else
676 numberElements = 4;
677 } else if (type == NVVM::MMATypes::f32) {
678 elementType = builder.getF32Type();
679 numberElements = 8;
680 } else if (type == NVVM::MMATypes::tf32) {
681 elementType = builder.getI32Type();
682 numberElements = 4;
683 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
684 elementType = builder.getI32Type();
685 int parallelSize = 0;
686 if (frag == NVVM::MMAFrag::a)
687 parallelSize = nRow;
688 if (frag == NVVM::MMAFrag::b)
689 parallelSize = nCol;
690
691 // m == 16 && n == 16 && k == 16
692 if (parallelSize == 16)
693 numberElements = 2;
694 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
695 else if (parallelSize == 8)
696 numberElements = 1;
697 else if (parallelSize == 32)
698 numberElements = 4;
699 } else if (type == NVVM::MMATypes::s32) {
700 elementType = builder.getI32Type();
701 numberElements = 8;
702 }
703 assert(numberElements != 0 && elementType != nullptr);
704 return std::make_pair(x&: elementType, y&: numberElements);
705}
706
707static std::pair<mlir::Type, unsigned>
708inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
709 int k, MLIRContext *context) {
710 int nRow, nCol;
711 if (frag == NVVM::MMAFrag::a) {
712 nRow = m;
713 nCol = k;
714 } else if (frag == NVVM::MMAFrag::b) {
715 nRow = k;
716 nCol = n;
717 } else {
718 nRow = m;
719 nCol = n;
720 }
721 assert(nRow && nCol);
722 return inferMMAType(type, frag, nRow, nCol, context);
723}
724
725LogicalResult NVVM::WMMALoadOp::verify() {
726 unsigned addressSpace =
727 llvm::cast<LLVM::LLVMPointerType>(Val: getPtr().getType()).getAddressSpace();
728 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
729 addressSpace != NVVM::kSharedMemorySpace)
730 return emitOpError(message: "expected source pointer in memory "
731 "space 0, 1, 3");
732
733 if (NVVM::WMMALoadOp::getIntrinsicID(m: getM(), n: getN(), k: getK(), layoutEnum: getLayout(),
734 eltypeEnum: getEltype(), fragEnum: getFrag()) == 0)
735 return emitOpError() << "invalid attribute combination";
736 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
737 type: getEltype(), frag: getFrag(), m: getM(), n: getN(), k: getK(), context: getContext());
738 Type dstType = LLVM::LLVMStructType::getLiteral(
739 context: getContext(), types: SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
740 if (getType() != dstType)
741 return emitOpError(message: "expected destination type is a structure of ")
742 << typeInfo.second << " elements of type " << typeInfo.first;
743 return success();
744}
745
746LogicalResult NVVM::WMMAStoreOp::verify() {
747 unsigned addressSpace =
748 llvm::cast<LLVM::LLVMPointerType>(Val: getPtr().getType()).getAddressSpace();
749 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
750 addressSpace != NVVM::kSharedMemorySpace)
751 return emitOpError(message: "expected operands to be a source pointer in memory "
752 "space 0, 1, 3");
753
754 if (NVVM::WMMAStoreOp::getIntrinsicID(m: getM(), n: getN(), k: getK(), layoutEnum: getLayout(),
755 eltypeEnum: getEltype()) == 0)
756 return emitOpError() << "invalid attribute combination";
757 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
758 type: getEltype(), frag: NVVM::MMAFrag::c, m: getM(), n: getN(), k: getK(), context: getContext());
759 if (getArgs().size() != typeInfo.second)
760 return emitOpError() << "expected " << typeInfo.second << " data operands";
761 if (llvm::any_of(Range: getArgs(), P: [&typeInfo](Value operands) {
762 return operands.getType() != typeInfo.first;
763 }))
764 return emitOpError() << "expected data operands of type " << typeInfo.first;
765 return success();
766}
767
// Verifies a wmma.mma op: the attribute combination must map to a real
// intrinsic, and the argument list must be the concatenation of the A, B
// and C fragment registers, with the result a struct of C-fragment type.
LogicalResult NVVM::WMMAMmaOp::verify() {
  // A zero intrinsic ID means no instruction exists for this combination.
  if (NVVM::WMMAMmaOp::getIntrinsicID(m: getM(), n: getN(), k: getK(), layoutAEnum: getLayoutA(),
                                      layoutBEnum: getLayoutB(), eltypeAEnum: getEltypeA(),
                                      eltypeBEnum: getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  // NOTE(review): eltypeA is used for both multiplicand fragments (A and B),
  // and eltypeB for the accumulator fragment — this asymmetry appears
  // deliberate (eltypeB describes the C/D type); confirm against the op
  // definition if touching this.
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      type: getEltypeA(), frag: NVVM::MMAFrag::a, m: getM(), n: getN(), k: getK(), context: getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      type: getEltypeA(), frag: NVVM::MMAFrag::b, m: getM(), n: getN(), k: getK(), context: getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      type: getEltypeB(), frag: NVVM::MMAFrag::c, m: getM(), n: getN(), k: getK(), context: getContext());
  // Expected argument list: A registers, then B, then C.
  SmallVector<Type, 32> arguments;
  arguments.append(NumInputs: typeInfoA.second, Elt: typeInfoA.first);
  arguments.append(NumInputs: typeInfoB.second, Elt: typeInfoB.first);
  arguments.append(NumInputs: typeInfoC.second, Elt: typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  // The result is a struct with one member per C-fragment register.
  Type dstType = LLVM::LLVMStructType::getLiteral(
      context: getContext(), types: SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError(message: "expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
798
799LogicalResult NVVM::LdMatrixOp::verify() {
800 unsigned addressSpace =
801 llvm::cast<LLVM::LLVMPointerType>(Val: getPtr().getType()).getAddressSpace();
802 if (addressSpace != NVVM::kSharedMemorySpace)
803 return emitOpError(message: "expected source pointer in memory space 3");
804
805 if (getNum() != 1 && getNum() != 2 && getNum() != 4)
806 return emitOpError(message: "expected num attribute to be 1, 2 or 4");
807
808 Type i32 = IntegerType::get(context: getContext(), width: 32);
809 if (getNum() == 1 && getType() != i32)
810 return emitOpError(message: "expected destination type is i32");
811 if (getNum() == 2 || getNum() == 4) {
812 Type dstType = LLVM::LLVMStructType::getLiteral(
813 context: getContext(), types: SmallVector<Type>(getNum(), i32));
814 if (getType() != dstType)
815 return emitOpError(message: "expected destination type is a structure of ")
816 << getNum() << " elements of type i32";
817 }
818 return success();
819}
820
821LogicalResult NVVM::StMatrixOp::verify() {
822 unsigned addressSpace =
823 llvm::cast<LLVM::LLVMPointerType>(Val: getPtr().getType()).getAddressSpace();
824 if (addressSpace != NVVM::kSharedMemorySpace)
825 return emitOpError(message: "expected source pointer in memory space 3");
826
827 int numMatrix = getSources().size();
828 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
829 return emitOpError(message: "expected num attribute to be 1, 2 or 4");
830
831 return success();
832}
833
834FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
835 if (typeA == NVVM::WGMMATypes::tf32)
836 return 8;
837 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
838 return 16;
839 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
840 return 32;
841 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
842 return 32;
843 if (typeA == NVVM::WGMMATypes::b1)
844 return 256;
845 return failure();
846}
847
848LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
849 NVVM::WGMMATypes typeA,
850 NVVM::WGMMATypes typeB) {
851 switch (typeA) {
852 case NVVM::WGMMATypes::f16:
853 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
854 typeB == NVVM::WGMMATypes::f16)
855 return success();
856 break;
857 case NVVM::WGMMATypes::tf32:
858 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
859 return success();
860 break;
861 case NVVM::WGMMATypes::u8:
862 case NVVM::WGMMATypes::s8:
863 if (typeD == NVVM::WGMMATypes::s32 &&
864 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
865 return success();
866 break;
867 case NVVM::WGMMATypes::b1:
868 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
869 return success();
870 break;
871 case NVVM::WGMMATypes::bf16:
872 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
873 typeB == NVVM::WGMMATypes::bf16)
874 return success();
875 break;
876 case NVVM::WGMMATypes::e4m3:
877 case NVVM::WGMMATypes::e5m2:
878 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
879 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
880 return success();
881 break;
882 case WGMMATypes::f32:
883 case WGMMATypes::s32:
884 llvm_unreachable("unsupported input types");
885 break;
886 }
887 return failure();
888}
889
890LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
891 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
892 72, 80, 88, 96, 104, 112, 120, 128,
893 136, 144, 152, 160, 168, 176, 184, 192,
894 200, 208, 216, 224, 232, 240, 248, 256};
895 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
896 80, 96, 112, 128, 144, 160,
897 176, 192, 208, 224, 240, 256};
898 switch (typeA) {
899 case WGMMATypes::f16:
900 case WGMMATypes::tf32:
901 case WGMMATypes::bf16:
902 case WGMMATypes::e4m3:
903 case WGMMATypes::e5m2:
904 if (llvm::is_contained(Range&: allowedN, Element: sizeN))
905 return success();
906 break;
907 case WGMMATypes::u8:
908 case WGMMATypes::s8:
909 case WGMMATypes::b1:
910 if (llvm::is_contained(Range&: allowedNshort, Element: sizeN))
911 return success();
912 break;
913 case WGMMATypes::f32:
914 case WGMMATypes::s32:
915 llvm_unreachable("unsupported input types");
916 break;
917 }
918 return failure();
919}
920
921LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
922 Value outValue = getResults();
923 auto stype = dyn_cast<LLVM::LLVMStructType>(Val: outValue.getType());
924 if (!stype)
925 return emitOpError() << "expected results to be struct";
926 int outputSize = stype.getBody().size();
927 WGMMATypes typeD = getTypeD();
928 WGMMATypes typeA = getTypeA();
929 WGMMATypes typeB = getTypeB();
930
931 for (Type t : stype.getBody()) {
932 if (t != stype.getBody().front())
933 return emitOpError()
934 << "all elements in struct must be same type but there is " << t;
935 }
936
937 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
938 typeD != WGMMATypes::s32) {
939 return emitOpError() << "does not support the given output type "
940 << NVVM::stringifyWGMMATypes(val: typeD);
941 }
942 if (typeD == WGMMATypes::s32 &&
943 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
944 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
945 }
946
947 if (failed(Result: isAllowedWGMMADataType(typeD, typeA, typeB))) {
948 return emitOpError() << NVVM::stringifyWGMMATypes(val: typeD)
949 << " += " << NVVM::stringifyWGMMATypes(val: typeA) << " * "
950 << NVVM::stringifyWGMMATypes(val: typeB)
951 << ", it is not supported.";
952 }
953
954 // Check M
955 if (getShape().getM() != 64)
956 return emitOpError() << "shape 'm' must be 64";
957
958 // Check K
959 FailureOr<int> allowedK = getAllowedSizeK(typeA);
960 if (failed(Result: allowedK) || allowedK.value() != getShape().getK())
961 return emitOpError() << "shape 'k' must be " << allowedK.value()
962 << " for input type "
963 << NVVM::stringifyWGMMATypes(val: typeA);
964
965 // Check N
966 if (failed(Result: isAllowedSizeN(sizeN: getShape().getN(), typeA))) {
967 return emitOpError() << "has input type "
968 << NVVM::stringifyWGMMATypes(val: typeA) << " n is set to "
969 << getShape().getN() << ", it is not supported.";
970 }
971
972 // Check transpose (only available for f16/bf16)
973 // Matrices A should be stored in row-major and B in column-major.
974 // Only f16/bf16 matrices can be stored in either column-major or row-major
975 // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
976 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
977 (getLayoutA() == mlir::NVVM::MMALayout::col ||
978 getLayoutB() == mlir::NVVM::MMALayout::row)) {
979 return emitOpError()
980 << "given layouts layout_a = " << stringifyMMALayout(val: getLayoutA())
981 << " and layout_b = " << stringifyMMALayout(val: getLayoutB())
982 << " for input types " << stringifyWGMMATypes(val: typeA) << " and "
983 << stringifyWGMMATypes(val: typeB)
984 << " requires transpose. However, this is only supported for: "
985 << stringifyMMATypes(val: MMATypes::f16) << " and "
986 << stringifyMMATypes(val: MMATypes::bf16);
987 }
988
989 // Check result registers
990 int expectedOutput = 0;
991 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
992 expectedOutput = getShape().getN() / 2;
993 if (typeD == WGMMATypes::f16)
994 expectedOutput = getShape().getN() / 4;
995 if (outputSize != expectedOutput) {
996 return emitOpError() << "results " << expectedOutput
997 << ", however output struct has " << outputSize
998 << " elements";
999 }
1000 // Check satfinite (only available for s32 accumulator)
1001 if (typeD != WGMMATypes::s32 &&
1002 getSatfinite().value_or(u: NVVM::MMAIntOverflow::wrapped) ==
1003 NVVM::MMAIntOverflow::satfinite) {
1004 return emitOpError()
1005 << " `satfinite` can be only used with s32 accumulator, however "
1006 "the current accumulator is "
1007 << NVVM::stringifyWGMMATypes(val: typeD);
1008 }
1009
1010 return success();
1011}
1012
/// Builds the inline-PTX string for this wgmma.mma_async op. Operand
/// placeholders ($0, $1, ...) are numbered to match the value list produced
/// by getAsmValues(): output registers first, then their read-write twins,
/// followed by the descriptors, scale-d, optional scale-a/b, and optional
/// layout immediates.
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  // Number of 32-bit output registers: N/4 for an f16 accumulator, N/2
  // otherwise (mirrors the count checked in verify()).
  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  // $<2*regs + 2> is the scale-d operand (after the write and read-write
  // register lists and the two descriptors); predicate p gates accumulation.
  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  // Emit the comma-separated output register list: {$0, $1, ...}.
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  // $regCnt / $regCnt+1 are the A and B matrix descriptors.
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  // Floating-point accumulators additionally take scale-a and scale-b.
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
1066
/// Collects the values bound to the inline-PTX operands, in the order the
/// placeholders in getPtx() expect: results (write), inouts (read-write),
/// the two matrix descriptors, scale-d, then — for non-s32 accumulators —
/// scale-a/scale-b, and — for f16/bf16 inputs — the layout immediates.
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  // Layout (transpose) immediates are only emitted for f16/bf16 inputs,
  // matching the string built in getPtx().
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    // Scale immediates: -1 when the input is negated, 1 otherwise.
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    // Note: B's layout immediate is inverted (1 - layout) relative to A's.
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
1099LogicalResult NVVM::FenceProxyOp::verify() {
1100 if (getKind() == NVVM::ProxyKind::TENSORMAP)
1101 return emitOpError() << "tensormap proxy is not a supported proxy kind";
1102 if (getKind() == NVVM::ProxyKind::GENERIC)
1103 return emitOpError() << "generic proxy not a supported proxy kind";
1104 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1105 return emitOpError() << "async_shared fence requires space attribute";
1106 }
1107 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1108 return emitOpError() << "only async_shared fence can have space attribute";
1109 }
1110 return success();
1111}
1112
1113LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1114 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1115 return emitOpError(message: "uni-directional proxies only support generic for "
1116 "from_proxy attribute");
1117
1118 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1119 return emitOpError(message: "uni-directional proxies only support tensormap "
1120 "for to_proxy attribute");
1121
1122 return success();
1123}
1124
1125LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1126 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1127 return emitOpError(message: "uni-directional proxies only support generic for "
1128 "from_proxy attribute");
1129
1130 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1131 return emitOpError(message: "uni-directional proxies only support tensormap "
1132 "for to_proxy attribute");
1133
1134 return success();
1135}
1136
1137LogicalResult NVVM::SetMaxRegisterOp::verify() {
1138 if (getRegCount() % 8)
1139 return emitOpError(message: "new register size must be multiple of 8");
1140 if (getRegCount() < 24 || getRegCount() > 256)
1141 return emitOpError(message: "new register size must be in between 24 to 256");
1142 return success();
1143}
1144
1145LogicalResult NVVM::BarrierOp::verify() {
1146 if (getNumberOfThreads() && !getBarrierId())
1147 return emitOpError(
1148 message: "barrier id is missing, it should be set between 0 to 15");
1149 return success();
1150}
1151
1152LogicalResult NVVM::Tcgen05CpOp::verify() {
1153 auto mc = getMulticast();
1154
1155 using SH = Tcgen05CpShape;
1156 using MC = Tcgen05CpMulticast;
1157 switch (getShape()) {
1158 case SH::SHAPE_128x256b:
1159 case SH::SHAPE_128x128b:
1160 case SH::SHAPE_4x256b:
1161 if (mc != MC::NONE)
1162 return emitError(message: "Invalid multicast type for tcgen05.cp Op");
1163 break;
1164 case SH::SHAPE_64x128b:
1165 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1166 return emitError(message: "Shape 64x128b requires multicast warpx2_01_23 or "
1167 "warpx2_02_13 for tcgen05.cp Op");
1168 break;
1169 case SH::SHAPE_32x128b:
1170 if (mc != MC::WARPX4)
1171 return emitError(
1172 message: "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1173 break;
1174 }
1175 return success();
1176}
1177
1178LogicalResult NVVM::MatchSyncOp::verify() {
1179 if (getKind() == NVVM::MatchSyncKind::all) {
1180 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(Val: getType());
1181 if (!type || type.getBody().size() != 2 ||
1182 !type.getBody()[0].isInteger(width: 32) || !type.getBody()[1].isInteger(width: 1)) {
1183 return emitOpError(message: "match.sync 'all' returns a two element struct with "
1184 "first element as i32 and second element as i1");
1185 }
1186 } else {
1187 if (!getType().isInteger(width: 32)) {
1188 return emitOpError(message: "match.sync 'any' returns an i32");
1189 }
1190 }
1191 return success();
1192}
1193
1194LogicalResult NVVM::VoteSyncOp::verify() {
1195 if (getKind() == NVVM::VoteSyncKind::ballot) {
1196 if (!getType().isInteger(width: 32)) {
1197 return emitOpError(message: "vote.sync 'ballot' returns an i32");
1198 }
1199 } else {
1200 if (!getType().isInteger(width: 1)) {
1201 return emitOpError(message: "vote.sync 'any', 'all' and 'uni' returns an i1");
1202 }
1203 }
1204 return success();
1205}
1206
1207LogicalResult NVVM::PrefetchOp::verify() {
1208 using MemSpace = NVVM::NVVMMemorySpace;
1209 using CacheLevel = NVVM::PrefetchCacheLevel;
1210
1211 unsigned addressSpace =
1212 llvm::cast<LLVM::LLVMPointerType>(Val: getAddr().getType()).getAddressSpace();
1213 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1214
1215 if (getUniform()) {
1216 if (getCacheLevel() != CacheLevel::L1)
1217 return emitOpError(message: "unsupported cache level, the only supported uniform "
1218 "cache level is L1");
1219
1220 if (addressSpace != MemSpace::kGenericMemorySpace)
1221 return emitOpError(
1222 message: "prefetch to uniform cache requires a generic pointer");
1223 }
1224
1225 if (evictPriority) {
1226 if (getCacheLevel() != CacheLevel::L2)
1227 return emitOpError(
1228 message: "cache eviction priority supported only for cache level L2");
1229
1230 if (addressSpace != MemSpace::kGlobalMemorySpace)
1231 return emitOpError(message: "cache eviction priority requires a global pointer");
1232
1233 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1234 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1235 return emitOpError(
1236 message: "unsupported cache eviction priority, only evict_last and "
1237 "evict_normal are supported");
1238 }
1239
1240 return success();
1241}
1242
1243/// Packs the given `field` into the `result`.
1244/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
1245static llvm::Value *
1246packValInto64Bits(llvm::IRBuilderBase &builder,
1247 llvm::Value *result, // the `result` (unset bits are zero)
1248 llvm::Value *field, // `field` to pack into `result`
1249 unsigned sizeInBits, // Size of `field` in bits
1250 unsigned start) { // Starting bit within `result`
1251 field = builder.CreateZExtOrBitCast(V: field, DestTy: builder.getInt32Ty());
1252
1253 unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
1254 if (mask != 0xffffffffu)
1255 field = builder.CreateAnd(LHS: field, RHS: builder.getInt32(C: mask));
1256
1257 field = builder.CreateZExtOrBitCast(V: field, DestTy: builder.getInt64Ty());
1258 field = builder.CreateShl(LHS: field, RHS: start);
1259
1260 return builder.CreateOr(LHS: result, RHS: field);
1261}
1262
/// Builds the 64-bit shared-memory descriptor for tcgen05 MMA by packing the
/// op's operands into fixed bit fields:
///   bits  0-13 : start address
///   bits 16-29 : leading-dimension offset
///   bits 32-45 : stride-dimension offset
///   bits 46-48 : constant 1 (fixed field)
///   bits 49-51 : base offset
///   bit  52    : leading-dimension mode
///   bits 61-63 : swizzle mode
/// The resulting i64 is mapped as the op's result.
void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
                                                LLVM::ModuleTranslation &mt,
                                                llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
  llvm::Value *smemDesc = builder.getInt64(0);

  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getStartAddr()), 14, 0);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);

  smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
  smemDesc = packValInto64Bits(
      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
  smemDesc = packValInto64Bits(builder, smemDesc,
                               mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);

  mt.mapValue(thisOp.getRes()) = smemDesc;
}
1286
1287//===----------------------------------------------------------------------===//
1288// getIntrinsicID/getIntrinsicIDAndArgs methods
1289//===----------------------------------------------------------------------===//
1290
// Helpers composing the nvvm_cp_async_<mod>_shared_global_<size> intrinsic
// identifier from the cache modifier (ca/cg), the copy size, and whether the
// explicit cp-size operand variant (`_s` suffix) is requested.
#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
1296
1297llvm::Intrinsic::ID
1298CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1299 llvm::SmallVector<llvm::Value *> &args) {
1300 llvm::Intrinsic::ID id;
1301
1302 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(Val&: op);
1303 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
1304 switch (cpAsyncOp.getSize()) {
1305 case 4:
1306 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1307 break;
1308 case 8:
1309 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1310 break;
1311 case 16:
1312 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1313 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1314 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1315 break;
1316 default:
1317 llvm_unreachable("Invalid copy size in CpAsyncOp.");
1318 }
1319
1320 // Fill the Intrinsic Args
1321 args.push_back(Elt: mt.lookupValue(value: cpAsyncOp.getDst()));
1322 args.push_back(Elt: mt.lookupValue(value: cpAsyncOp.getSrc()));
1323 if (hasCpSize)
1324 args.push_back(Elt: mt.lookupValue(value: cpAsyncOp.getCpSize()));
1325
1326 return id;
1327}
1328
1329mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
1330 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1331 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(Val&: op);
1332 llvm::SmallVector<llvm::Value *> args;
1333 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
1334
1335 // Fill the Intrinsic Args
1336 args.push_back(Elt: mt.lookupValue(value: thisOp.getSrcMem()));
1337 args.push_back(Elt: mt.lookupValue(value: thisOp.getSize()));
1338
1339 mlir::Value cacheHint = thisOp.getL2CacheHint();
1340 const bool hasCacheHint = static_cast<bool>(cacheHint);
1341 llvm::Value *i64Unused =
1342 llvm::ConstantInt::get(Ty: llvm::Type::getInt64Ty(C&: mt.getLLVMContext()), V: 0);
1343 args.push_back(Elt: hasCacheHint ? mt.lookupValue(value: cacheHint) : i64Unused);
1344 args.push_back(Elt: builder.getInt1(V: hasCacheHint));
1345
1346 return {id, std::move(args)};
1347}
1348
1349mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1350 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1351 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(Val&: op);
1352 llvm::SmallVector<llvm::Value *> args;
1353 llvm::Intrinsic::ID id =
1354 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
1355
1356 // Fill the Intrinsic Args
1357 args.push_back(Elt: mt.lookupValue(value: thisOp.getDstMem()));
1358 args.push_back(Elt: mt.lookupValue(value: thisOp.getSrcMem()));
1359 args.push_back(Elt: mt.lookupValue(value: thisOp.getSize()));
1360
1361 mlir::Value cacheHint = thisOp.getL2CacheHint();
1362 const bool hasCacheHint = static_cast<bool>(cacheHint);
1363 llvm::Value *i64Unused =
1364 llvm::ConstantInt::get(Ty: llvm::Type::getInt64Ty(C&: mt.getLLVMContext()), V: 0);
1365 args.push_back(Elt: hasCacheHint ? mt.lookupValue(value: cacheHint) : i64Unused);
1366 args.push_back(Elt: builder.getInt1(V: hasCacheHint));
1367
1368 // Choose the bytemask variant
1369 if (mlir::Value byteMask = thisOp.getByteMask()) {
1370 args.push_back(Elt: mt.lookupValue(value: byteMask));
1371 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
1372 }
1373
1374 return {id, std::move(args)};
1375}
1376
1377llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
1378 bool isIm2Col) {
1379 switch (tensorDims) {
1380 case 1:
1381 return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
1382 case 2:
1383 return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
1384 case 3:
1385 return isIm2Col
1386 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
1387 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
1388 case 4:
1389 return isIm2Col
1390 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
1391 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
1392 case 5:
1393 return isIm2Col
1394 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
1395 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
1396 default:
1397 llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
1398 }
1399}
1400
// Helpers composing nvvm_cp_async_bulk_tensor_<op>_<mode>_<dim>d intrinsic
// identifiers. Ranks 1 and 2 only exist in tile mode; ranks 3-5 also have an
// im2col variant.
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

// Immediately-invoked lambda so the rank switch can be used as a single
// expression in the per-kind switch below.
#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()

/// Selects the tensor-reduce intrinsic from the reduction kind, tensor rank,
/// and tile/im2col mode.
llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  case RedTy::MIN:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
  case RedTy::MAX:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
  case RedTy::INC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
  case RedTy::DEC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
  case RedTy::AND:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
  case RedTy::OR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
  case RedTy::XOR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
1449
// Empty token: passing `_none` as the `relu` macro argument yields no suffix
// in the concatenated intrinsic name below.
#define _none

// These helpers read `hasRelu` / `hasSatFinite` from the enclosing scope.
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)                                      \
  hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf                       \
          : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

#define GET_CVT_F2TF32_ID(rnd, relu, sf)                                       \
  hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf)                             \
               : CVT_F2TF32_ID_IMPL(rnd, relu, )

/// Selects the f32 -> tf32 conversion intrinsic from the rounding mode and
/// the relu / satfinite flags. RNA mode has no relu variant (`_none`).
llvm::Intrinsic::ID
ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case RndMode::RN:
    return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
  case RndMode::RZ:
    return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
  case RndMode::RNA:
    return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}
1476
// All f32x2 -> f6x2 variants use rn rounding with satfinite; relu is the only
// selectable option.
#define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

/// Selects the f32x2 -> fp6x2 conversion intrinsic for the target FP6 format.
llvm::Intrinsic::ID
ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP6Type::E2M3:
    return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
  case NVVM::ConvertFP6Type::E3M2:
    return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
  }
  llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
}
1491
// Unsigned (ue8m0) variant: selectable rounding mode, optional satfinite.
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

// Signed (e4m3/e5m2) variant: always rn rounding, optional relu.
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)                                 \
  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
           : llvm::Intrinsic::nvvm_ff_to_##type##_rn

/// Selects the f32x2 -> fp8x2 conversion intrinsic. For e4m3/e5m2 only
/// `hasRelu` matters; for ue8m0 only the rounding mode (RZ or RP) and
/// `sat` matter — any other rounding mode for ue8m0 is unreachable.
llvm::Intrinsic::ID
ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
                                     NVVM::FPRoundingMode rnd,
                                     NVVM::SaturationMode sat, bool hasRelu) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
  bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);

  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
  case NVVM::ConvertFP8Type::UE8M0:
    if (hasRoundingModeRZ)
      return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
    else if (hasRoundingModeRP)
      return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
  }
  // Reached only for UE8M0 with a rounding mode other than RZ/RP.
  llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
}
1521
// f16x2 -> fp8x2 variants always use rn rounding; relu is optional.
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

/// Selects the f16x2 -> fp8x2 conversion intrinsic for the target FP8 format
/// (only e4m3 and e5m2 are supported here).
llvm::Intrinsic::ID
ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
  default:
    llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
  }
}
1537
// bf16x2 -> ue8m0x2 variants: selectable rounding mode, optional satfinite.
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

/// Selects the bf16x2 -> fp8x2 conversion intrinsic. Only RZ and RP rounding
/// modes are supported.
llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}
1555
1556llvm::Intrinsic::ID
1557Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
1558 LLVM::ModuleTranslation &mt,
1559 llvm::SmallVector<llvm::Value *> &args) {
1560 auto curOp = cast<NVVM::Tcgen05AllocOp>(Val&: op);
1561 unsigned as = llvm::cast<LLVM::LLVMPointerType>(Val: curOp.getAddr().getType())
1562 .getAddressSpace();
1563 bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
1564 bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
1565
1566 llvm::Intrinsic::ID id;
1567 if (isShared) {
1568 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
1569 : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
1570 } else {
1571 id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
1572 : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
1573 }
1574
1575 // Fill the Intrinsic Args
1576 args.push_back(Elt: mt.lookupValue(value: curOp.getAddr()));
1577 args.push_back(Elt: mt.lookupValue(value: curOp.getNCols()));
1578
1579 return id;
1580}
1581
1582llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
1583 Operation &op, LLVM::ModuleTranslation &mt,
1584 llvm::SmallVector<llvm::Value *> &args) {
1585 auto curOp = cast<NVVM::Tcgen05DeallocOp>(Val&: op);
1586 auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
1587 ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
1588 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;
1589
1590 // Fill the Intrinsic Args
1591 args.push_back(Elt: mt.lookupValue(value: curOp.getTaddr()));
1592 args.push_back(Elt: mt.lookupValue(value: curOp.getNCols()));
1593
1594 return id;
1595}
1596
// Helpers composing nvvm_tcgen05_commit[_mc][_shared]_cg{1,2} intrinsic
// identifiers from the CTA group, the address space, and whether a multicast
// mask is present.
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

/// Selects the tcgen05.commit intrinsic from the barrier's address space, the
/// CTA group, and the presence of a multicast mask, and fills `args` with the
/// barrier address plus the optional mask.
llvm::Intrinsic::ID
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return id;
}
1627
// Helpers composing nvvm_tcgen05_cp<shape/multicast><src_fmt>_cg{1,2}
// intrinsic identifiers.
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

// Immediately-invoked lambda selecting the source-format suffix (or none) in
// an expression context.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()

/// Selects the tcgen05.cp intrinsic from the shape, source format, CTA group
/// and — for 64x128b — the multicast variant (the shape/multicast pairing is
/// enforced by Tcgen05CpOp::verify()).
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
1666
1667// Returns the valid vector length for a given shape and vector length, the
1668// function models the table mentioned in the tcgen05.{ld, st} Op description
1669static unsigned isValidVectorLength(NVVM::Tcgen05LdStShape shape,
1670 unsigned vecLen) {
1671 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
1672 return vecLen >= 2;
1673 if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
1674 return vecLen >= 4;
1675 return true;
1676}
1677
1678LogicalResult Tcgen05LdOp::verify() {
1679 LogicalResult result = success();
1680 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1681 result = emitError(message: "shape 16x32bx2 requires offset argument");
1682
1683 auto resTy = getRes().getType();
1684 unsigned resLen = isa<VectorType>(Val: resTy)
1685 ? llvm::cast<VectorType>(Val&: resTy).getNumElements()
1686 : 1;
1687 if (!isValidVectorLength(shape: getShape(), vecLen: resLen))
1688 result = emitError(message: llvm::formatv(Fmt: "invalid result type length {0} for shape "
1689 "{1} in tcgen05.ld Op",
1690 Vals&: resLen, Vals: stringifyEnum(enumValue: getShape())));
1691
1692 return result;
1693}
1694
1695LogicalResult Tcgen05StOp::verify() {
1696 LogicalResult result = success();
1697 if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
1698 result = emitError(message: "shape 16x32bx2 requires offset argument");
1699
1700 auto valTy = getVal().getType();
1701 unsigned valLen = isa<VectorType>(Val: valTy)
1702 ? llvm::cast<VectorType>(Val&: valTy).getNumElements()
1703 : 1;
1704 if (!isValidVectorLength(shape: getShape(), vecLen: valLen))
1705 result = emitError(message: llvm::formatv(Fmt: "invalid input length {0} for shape "
1706 "{1} in tcgen05.st Op",
1707 Vals&: valLen, Vals: stringifyEnum(enumValue: getShape())));
1708
1709 return result;
1710}
1711
1712/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
1713/// have ConstantRangeAttr.
1714static void nvvmInferResultRanges(Operation *op, Value result,
1715 ArrayRef<::mlir::ConstantIntRanges> argRanges,
1716 SetIntRangeFn setResultRanges) {
1717 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>(name: "range")) {
1718 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
1719 rangeAttr.getLower(), rangeAttr.getUpper()});
1720 }
1721}
1722
1723static llvm::Value *getAsPackedI32(llvm::Value *arg,
1724 llvm::IRBuilderBase &builder) {
1725 return builder.CreateBitCast(V: arg,
1726 DestTy: llvm::Type::getInt32Ty(C&: builder.getContext()));
1727}
1728
1729NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
1730 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1731 auto curOp = cast<NVVM::DotAccumulate4WayOp>(Val&: op);
1732
1733 llvm::SmallVector<llvm::Value *> args;
1734 args.push_back(Elt: getAsPackedI32(arg: mt.lookupValue(value: curOp.getA()), builder));
1735 args.push_back(Elt: getAsPackedI32(arg: mt.lookupValue(value: curOp.getB()), builder));
1736 args.push_back(Elt: mt.lookupValue(value: curOp.getC()));
1737
1738 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1739 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1740 unsigned type = (isASigned << 1) | isBSigned;
1741 const llvm::Intrinsic::ID ids[] = {
1742 llvm::Intrinsic::nvvm_idp4a_u_u,
1743 llvm::Intrinsic::nvvm_idp4a_u_s,
1744 llvm::Intrinsic::nvvm_idp4a_s_u,
1745 llvm::Intrinsic::nvvm_idp4a_s_s,
1746 };
1747 return {ids[type], args};
1748}
1749
1750NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
1751 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1752 auto curOp = cast<NVVM::DotAccumulate2WayOp>(Val&: op);
1753
1754 llvm::SmallVector<llvm::Value *> args;
1755 args.push_back(Elt: getAsPackedI32(arg: mt.lookupValue(value: curOp.getA()), builder));
1756 args.push_back(Elt: getAsPackedI32(arg: mt.lookupValue(value: curOp.getB()), builder));
1757 args.push_back(Elt: builder.getInt1(V: curOp.getBHi()));
1758 args.push_back(Elt: mt.lookupValue(value: curOp.getC()));
1759
1760 bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
1761 bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
1762 unsigned type = (isASigned << 1) | isBSigned;
1763 const llvm::Intrinsic::ID ids[] = {
1764 llvm::Intrinsic::nvvm_idp2a_u_u,
1765 llvm::Intrinsic::nvvm_idp2a_u_s,
1766 llvm::Intrinsic::nvvm_idp2a_s_u,
1767 llvm::Intrinsic::nvvm_idp2a_s_s,
1768 };
1769 return {ids[type], args};
1770}
1771
/// Map a prefetch op to its NVVM intrinsic. Dispatch order matters:
/// 1. uniform prefetch at L1 has its own intrinsic,
/// 2. an eviction priority at L2 selects the global-L2 evict variants,
/// 3. otherwise the intrinsic is chosen by address space and cache level.
llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(Val: op.getAddr().getType())
          .getAddressSpace();

  // Uniform prefetch only exists for L1.
  if (op.getUniform() && cacheLevel == CacheLevel::L1)
    return llvm::Intrinsic::nvvm_prefetchu_L1;

  // Eviction priorities are only used with the L2 (global) variants; only
  // EvictLast / EvictNormal have intrinsics.
  if (evictPriority && cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
    case NVVM::CacheEvictionPriority::EvictNormal:
      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  // Plain prefetch: pick by pointer address space, then cache level.
  switch (addressSpace) {
  case MemSpace::kGenericMemorySpace:
    return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
                                        : llvm::Intrinsic::nvvm_prefetch_L2;
  case MemSpace::kGlobalMemorySpace:
    return cacheLevel == CacheLevel::L1
               ? llvm::Intrinsic::nvvm_prefetch_global_L1
               : llvm::Intrinsic::nvvm_prefetch_global_L2;
  case MemSpace::kLocalMemorySpace:
    return cacheLevel == CacheLevel::L1
               ? llvm::Intrinsic::nvvm_prefetch_local_L1
               : llvm::Intrinsic::nvvm_prefetch_local_L2;
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}
1813
1814//===----------------------------------------------------------------------===//
1815// NVVMDialect initialization, type parsing, and registration.
1816//===----------------------------------------------------------------------===//
1817
1818// TODO: This should be the llvm.nvvm dialect once this is supported.
// Registers the tablegen-generated NVVM operations and attributes, and
// declares the promised interfaces whose implementations live in other
// libraries (conversion-to-LLVM and the GPU target attribute).
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
1835
1836LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1837 NamedAttribute attr) {
1838 StringAttr attrName = attr.getName();
1839 // Kernel function attribute should be attached to functions.
1840 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1841 if (!isa<LLVM::LLVMFuncOp>(Val: op)) {
1842 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1843 << "' attribute attached to unexpected op";
1844 }
1845 }
1846 // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
1847 // dim
1848 if (attrName == NVVMDialect::getMaxntidAttrName() ||
1849 attrName == NVVMDialect::getReqntidAttrName() ||
1850 attrName == NVVMDialect::getClusterDimAttrName()) {
1851 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(Val: attr.getValue());
1852 if (!values || values.empty() || values.size() > 3)
1853 return op->emitError()
1854 << "'" << attrName
1855 << "' attribute must be integer array with maximum 3 index";
1856 }
1857 // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
1858 // attribute
1859 if (attrName == NVVMDialect::getMinctasmAttrName() ||
1860 attrName == NVVMDialect::getMaxnregAttrName() ||
1861 attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
1862 if (!llvm::dyn_cast<IntegerAttr>(Val: attr.getValue()))
1863 return op->emitError()
1864 << "'" << attrName << "' attribute must be integer constant";
1865 }
1866
1867 return success();
1868}
1869
1870LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1871 unsigned regionIndex,
1872 unsigned argIndex,
1873 NamedAttribute argAttr) {
1874 auto funcOp = dyn_cast<FunctionOpInterface>(Val: op);
1875 if (!funcOp)
1876 return success();
1877
1878 bool isKernel = op->hasAttr(name: NVVMDialect::getKernelFuncAttrName());
1879 StringAttr attrName = argAttr.getName();
1880 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1881 if (!isKernel) {
1882 return op->emitError()
1883 << "'" << attrName
1884 << "' attribute must be present only on kernel arguments";
1885 }
1886 if (!isa<UnitAttr>(Val: argAttr.getValue()))
1887 return op->emitError() << "'" << attrName << "' must be a unit attribute";
1888 if (!funcOp.getArgAttr(index: argIndex, name: LLVM::LLVMDialect::getByValAttrName())) {
1889 return op->emitError()
1890 << "'" << attrName
1891 << "' attribute requires the argument to also have attribute '"
1892 << LLVM::LLVMDialect::getByValAttrName() << "'";
1893 }
1894 }
1895
1896 return success();
1897}
1898
1899//===----------------------------------------------------------------------===//
1900// NVVM target attribute.
1901//===----------------------------------------------------------------------===//
1902LogicalResult
1903NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1904 int optLevel, StringRef triple, StringRef chip,
1905 StringRef features, DictionaryAttr flags,
1906 ArrayAttr files, bool verifyTarget) {
1907 if (optLevel < 0 || optLevel > 3) {
1908 emitError() << "The optimization level must be a number between 0 and 3.";
1909 return failure();
1910 }
1911 if (triple.empty()) {
1912 emitError() << "The target triple cannot be empty.";
1913 return failure();
1914 }
1915 if (chip.empty()) {
1916 emitError() << "The target chip cannot be empty.";
1917 return failure();
1918 }
1919 if (files && !llvm::all_of(Range&: files, P: [](::mlir::Attribute attr) {
1920 return mlir::isa_and_nonnull<StringAttr>(Val: attr);
1921 })) {
1922 emitError() << "All the elements in the `link` array must be strings.";
1923 return failure();
1924 }
1925 return success();
1926}
1927
1928LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
1929 if (!getVerifyTarget())
1930 return success();
1931
1932 auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(Val: gpuModule);
1933 if (!gpuModuleOp) {
1934 return emitError(loc: gpuModule->getLoc(),
1935 message: "NVVM target attribute must be attached to a GPU module");
1936 }
1937
1938 const NVVMCheckSMVersion targetSMVersion =
1939 NVVMCheckSMVersion::getTargetSMVersionFromStr(smVersionString: getChip());
1940 if (!targetSMVersion.isMinimumSMVersion()) {
1941 return emitError(loc: gpuModule->getLoc(),
1942 message: "Minimum NVVM target SM version is sm_20");
1943 }
1944
1945 gpuModuleOp->walk(callback: [&](Operation *op) {
1946 if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(Val: op)) {
1947 const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
1948 if (!requirement.isCompatibleWith(targetSM: targetSMVersion)) {
1949 op->emitOpError() << "is not supported on " << getChip();
1950 return WalkResult::interrupt();
1951 }
1952 }
1953 return WalkResult::advance();
1954 });
1955
1956 return success();
1957}
1958
1959#define GET_OP_CLASSES
1960#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1961
1962#define GET_ATTRDEF_CLASSES
1963#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1964

source code of mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp