1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18
19#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
21#include "mlir/Dialect/GPU/IR/GPUDialect.h"
22#include "mlir/Dialect/Utils/StaticValueUtils.h"
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinAttributes.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/Diagnostics.h"
27#include "mlir/IR/DialectImplementation.h"
28#include "mlir/IR/MLIRContext.h"
29#include "mlir/IR/Operation.h"
30#include "mlir/IR/OperationSupport.h"
31#include "mlir/IR/Types.h"
32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/AsmParser/Parser.h"
35#include "llvm/IR/Attributes.h"
36#include "llvm/IR/Function.h"
37#include "llvm/IR/IRBuilder.h"
38#include "llvm/IR/IntrinsicsNVPTX.h"
39#include "llvm/IR/Type.h"
40#include "llvm/Support/Casting.h"
41#include "llvm/Support/FormatVariadic.h"
42#include "llvm/Support/SourceMgr.h"
43#include "llvm/Support/raw_ostream.h"
44#include <cassert>
45#include <optional>
46#include <string>
47
48using namespace mlir;
49using namespace NVVM;
50
51#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
52#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
53
54//===----------------------------------------------------------------------===//
55// Verifier methods
56//===----------------------------------------------------------------------===//
57
58// This verifier is shared among the following Ops:
59// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
60// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
61// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
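//
// For example (illustrative): a 5-D im2col access must carry exactly 3 im2col
// offsets (dims - 2), and im2col mode is rejected outright for tensors with
// fewer than 3 dimensions.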
62static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
63 bool isIm2Col,
64 size_t numIm2ColOffsets,
65 Location loc) {
66 if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");
68
69 // For Im2Col mode, there are two constraints:
70 if (isIm2Col) {
71 // 1. Tensor must always be at least 3-d.
72 if (tensorDims < 3)
73 return emitError(
74 loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
76 // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
77 if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
78 return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
80 }
81 return success();
82}
83
84LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
85 size_t numIm2ColOffsets = getIm2colOffsets().size();
86 bool isIm2Col = numIm2ColOffsets > 0;
87 return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
88 numIm2ColOffsets, getLoc());
89}
90
91LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
92 if (getCoordinates().size() > 5)
93 return emitError("Maximum 5 coordinates and dimension is supported.");
94 return success();
95}
96
97LogicalResult CpAsyncOp::verify() {
98 if (getModifier() != LoadCacheModifierKind::CG &&
99 getModifier() != LoadCacheModifierKind::CA)
100 return emitError("Only CG and CA cache modifiers are supported.");
101 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
102 return emitError("expected byte size to be either 4, 8 or 16.");
103 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
104 return emitError("CG cache modifier is only support for 16 bytes copy.");
105 return success();
106}
107
108LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
109 size_t numIm2ColOffsets = getIm2colOffsets().size();
110 bool isIm2Col = numIm2ColOffsets > 0;
111 return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
112 numIm2ColOffsets, getLoc());
113}
114
115LogicalResult CpAsyncBulkTensorReduceOp::verify() {
116 bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
117 return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
118 getLoc());
119}
120
121LogicalResult ConvertFloatToTF32Op::verify() {
122 using RndMode = NVVM::FPRoundingMode;
123 switch (getRnd()) {
124 case RndMode::RNA:
125 if (getRelu())
126 return emitError("Relu not supported with rna rounding mode.");
127 break;
128 case RndMode::RN:
129 case RndMode::RZ:
130 break;
131 default:
132 return emitError(
133 "Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.");
134 }
135 return success();
136}
137
138LogicalResult ConvertF32x2ToF8x2Op::verify() {
139 using RndMode = NVVM::FPRoundingMode;
140 using SatMode = NVVM::SaturationMode;
141
142 bool isRoundingModeRN = getRnd() == RndMode::RN;
143 bool isRoundingModeRZ = getRnd() == RndMode::RZ;
144 bool isRoundingModeRP = getRnd() == RndMode::RP;
145 bool isSatFinite = getSat() == SatMode::SATFINITE;
146
147 bool hasRelu = getRelu();
148
149 switch (getType()) {
150 case ConvertFP8Type::E4M3:
151 case ConvertFP8Type::E5M2:
152 if (!isRoundingModeRN)
153 return emitOpError("Only RN rounding mode is supported for conversions "
154 "from f32x2 to .e4m3x2 or .e5m2x2 types");
155 if (!isSatFinite)
156 return emitOpError("Only SATFINITE saturation mode is supported for "
157 "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
158 break;
159 case ConvertFP8Type::UE8M0:
160 if (!(isRoundingModeRZ || isRoundingModeRP))
161 return emitOpError("Only RZ or RP rounding modes are supported for "
162 "conversions from f32x2 to .ue8m0x2 type");
163 if (hasRelu)
164 return emitOpError("relu not supported for conversions to .ue8m0x2 type");
165 break;
166 }
167 return success();
168}
169
170LogicalResult ConvertF16x2ToF8x2Op::verify() {
171 if (getType() == ConvertFP8Type::UE8M0)
172 return emitOpError("Only .e4m3 or .e5m2 types are supported for "
173 "conversions from f16x2 to f8x2.");
174
175 return success();
176}
177
178LogicalResult ConvertBF16x2ToF8x2Op::verify() {
179 using RndMode = NVVM::FPRoundingMode;
180
181 if (getType() != ConvertFP8Type::UE8M0)
182 return emitOpError(
183 "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
184
185 auto rnd = getRnd();
186 if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
187 return emitOpError("Only RZ and RP rounding modes are supported for "
188 "conversions from bf16x2 to f8x2.");
189
190 return success();
191}
192
193LogicalResult BulkStoreOp::verify() {
194 if (getInitVal() != 0)
195 return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
196 return success();
197}
198
199// Given the element type of an operand and whether or not it is an accumulator,
200// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
201// operand's element type.
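//
// For example (illustrative): an f32 element infers to f32 when used as the
// accumulator and to tf32 when used as a multiplicand, while an integer
// element infers to s32 only for accumulators; integer multiplicand types
// cannot be inferred and must be specified explicitly.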
202std::optional<mlir::NVVM::MMATypes>
203MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
204 auto half2Type =
205 VectorType::get(2, Float16Type::get(operandElType.getContext()));
206 if (operandElType.isF64())
207 return NVVM::MMATypes::f64;
208 if (operandElType.isF16() || operandElType == half2Type)
209 return NVVM::MMATypes::f16;
210 if (operandElType.isF32() && isAccumulator)
211 return NVVM::MMATypes::f32;
212 if (operandElType.isF32() && !isAccumulator)
213 return NVVM::MMATypes::tf32;
214 if (llvm::isa<IntegerType>(operandElType)) {
215 if (isAccumulator)
216 return NVVM::MMATypes::s32;
217 return std::nullopt;
218 }
219
220 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
221 if (structType.getBody().empty())
222 return std::nullopt;
223 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
224 }
225
226 return std::nullopt;
227}
228
229static bool isInt4PtxType(MMATypes type) {
230 return (type == MMATypes::u4 || type == MMATypes::s4);
231}
232
233static bool isInt8PtxType(MMATypes type) {
234 return (type == MMATypes::u8 || type == MMATypes::s8);
235}
236
237static bool isIntegerPtxType(MMATypes type) {
238 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
239 type == MMATypes::s32;
240}
241
242MMATypes MmaOp::accumPtxType() {
243 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
244 getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
245 assert(val.has_value() && "accumulator PTX type should always be inferrable");
246 return val.value();
247}
248
249MMATypes MmaOp::resultPtxType() {
250 std::optional<mlir::NVVM::MMATypes> val =
251 inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
252 assert(val.has_value() && "result PTX type should always be inferrable");
253 return val.value();
254}
255
256void MmaOp::print(OpAsmPrinter &p) {
257 SmallVector<Type, 4> regTypes;
258 struct OperandFragment {
259 StringRef operandName;
260 StringRef ptxTypeAttr;
261 SmallVector<Value, 4> regs;
262 explicit OperandFragment(StringRef name, StringRef ptxTypeName)
263 : operandName(name), ptxTypeAttr(ptxTypeName) {}
264 };
265
266 std::array<OperandFragment, 3> frags{
267 OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
268 OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
269 OperandFragment("C", "")};
270 SmallVector<StringRef, 4> ignoreAttrNames{
271 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
272
273 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
274 auto &frag = frags[fragIdx];
275 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
276 for (auto operandIdx = varOperandSpec.first;
277 operandIdx < varOperandSpec.first + varOperandSpec.second;
278 operandIdx++) {
279 frag.regs.push_back(this->getOperand(operandIdx));
280 if (operandIdx == 0) {
281 regTypes.push_back(this->getOperand(operandIdx).getType());
282 }
283 }
284 std::optional<MMATypes> inferredType =
285 inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
286 if (inferredType)
287 ignoreAttrNames.push_back(frag.ptxTypeAttr);
288 }
289
290 auto printMmaOperand = [&](const OperandFragment &frag) -> void {
291 p << " " << frag.operandName;
292 p << "[";
293 p.printOperands(frag.regs);
294 p << "] ";
295 };
296
297 for (const auto &frag : frags) {
298 printMmaOperand(frag);
299 }
300
301 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
302
303 // Print the types of the operands and result.
304 p << " : "
305 << "(";
306 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
307 frags[1].regs[0].getType(),
308 frags[2].regs[0].getType()},
309 p);
310 p << ")";
311 p.printArrowTypeList(TypeRange{this->getRes().getType()});
312}
313
314void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
315 ValueRange operandA, ValueRange operandB, ValueRange operandC,
316 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
317 std::optional<MMAIntOverflow> intOverflow,
318 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
319 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
320
321 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
322 MLIRContext *ctx = builder.getContext();
323 result.addAttribute(
324 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
325
326 result.addOperands(operandA);
327 result.addOperands(operandB);
328 result.addOperands(operandC);
329
330 if (multiplicandPtxTypes) {
331 result.addAttribute("multiplicandAPtxType",
332 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
333 result.addAttribute("multiplicandBPtxType",
334 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
335 } else {
336 if (auto res = inferOperandMMAType(operandA[0].getType(), false))
337 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
338 if (auto res = inferOperandMMAType(operandB[0].getType(), false))
339 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
340 }
341
342 if (multiplicandLayouts) {
343 result.addAttribute("layoutA",
344 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
345 result.addAttribute("layoutB",
346 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
347 } else {
348 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
349 result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
350 }
351
352 if (intOverflow.has_value())
353 result.addAttribute("intOverflowBehavior",
354 MMAIntOverflowAttr::get(ctx, *intOverflow));
355 if (b1Op.has_value())
356 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
357
358 result.addTypes(resultType);
359 result.addAttribute(
360 MmaOp::getOperandSegmentSizeAttr(),
361 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
362 static_cast<int32_t>(operandB.size()),
363 static_cast<int32_t>(operandC.size())}));
364}
365
366// <operation> :=
367// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
368// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
369// `->` type($res)
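//
// For example (a sketch; the attribute dictionary is abbreviated and assumed,
// shown here for an m16n8k8 f16 variant):
//   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1] {shape = ...}
//     : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//       -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>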
370ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
371 struct OperandFragment {
372 std::optional<MMATypes> elemtype;
373 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
374 SmallVector<Type> regTypes;
375 };
376
377 Builder &builder = parser.getBuilder();
378 std::array<OperandFragment, 4> frags;
379
380 NamedAttrList namedAttributes;
381
382 // A helper to parse the operand segments.
383 auto parseMmaOperand = [&](StringRef operandName,
384 OperandFragment &frag) -> LogicalResult {
385 if (parser.parseKeyword(operandName).failed())
386 return failure();
387 if (parser
388 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
389 .failed())
390 return failure();
391 return success();
392 };
393
394 // Parse the operand segments.
395 if (parseMmaOperand("A", frags[0]).failed())
396 return failure();
397 if (parseMmaOperand("B", frags[1]).failed())
398 return failure();
399 if (parseMmaOperand("C", frags[2]).failed())
400 return failure();
401
402 if (parser.parseOptionalAttrDict(namedAttributes).failed())
403 return failure();
404
405 // Parse the type specification and resolve operands.
406 SmallVector<Type, 3> operandTypes;
407 if (failed(parser.parseColon()))
408 return failure();
409 if (failed(parser.parseLParen()))
410 return failure();
411 if (failed(parser.parseTypeList(operandTypes)))
412 return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
415 return parser.emitError(
416 parser.getNameLoc(),
417 "expected one type for each operand segment but got " +
418 Twine(operandTypes.size()) + " types");
419 for (const auto &iter : llvm::enumerate(operandTypes)) {
420 auto &frag = frags[iter.index()];
421 frag.regTypes.resize(frag.regs.size(), iter.value());
422 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
423 parser.getNameLoc(), result.operands)))
424 return failure();
425 frag.elemtype = inferOperandMMAType(frag.regTypes[0],
426 /*isAccumulator*/ iter.index() < 2);
427 }
428
429 Type resultType;
430 if (parser.parseArrow() || parser.parseType(resultType))
431 return failure();
432 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator*/ true);
433
434 std::array<StringRef, 2> names{"multiplicandAPtxType",
435 "multiplicandBPtxType"};
436 for (unsigned idx = 0; idx < names.size(); idx++) {
437 const auto &frag = frags[idx];
438 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
439 if (!frag.elemtype.has_value() && !attr.has_value()) {
440 return parser.emitError(
441 parser.getNameLoc(),
442 "attribute " + names[idx] +
443 " is not provided explicitly and cannot be inferred");
444 }
445 if (!attr.has_value())
446 result.addAttribute(
447 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
448 }
449
450 result.addTypes(resultType);
451 if (!namedAttributes.empty())
452 result.addAttributes(namedAttributes);
453 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
454 builder.getDenseI32ArrayAttr({
455 static_cast<int32_t>(frags[0].regs.size()),
456 static_cast<int32_t>(frags[1].regs.size()),
457 static_cast<int32_t>(frags[2].regs.size()),
458 }));
459 return success();
460}
461
462LogicalResult MmaOp::verify() {
463 MLIRContext *context = getContext();
464 auto f16Ty = Float16Type::get(context);
465 auto i32Ty = IntegerType::get(context, 32);
466 auto f16x2Ty = VectorType::get(2, f16Ty);
467 auto f32Ty = Float32Type::get(context);
468 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
469 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
470
471 auto s32x4StructTy =
472 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
473 auto f32x8StructTy =
474 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
475 auto f16x2x2StructTy =
476 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
477 auto f32x4StructTy =
478 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
479 auto s32x2StructTy =
480 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
481
482 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
483 getShapeAttr().getK()};
484
485 // These variables define the set of allowed data types for matrices A, B, C,
486 // and result.
487 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
488 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
489 AllowedShapes allowedShapes;
490 AllowedTypes expectedA;
491 AllowedTypes expectedB;
492 AllowedTypes expectedC;
493 SmallVector<Type> expectedResult;
494
495 // When M = 16, we just need to calculate the number of 8xk tiles, where
496 // k is a factor that depends on the data type.
497 if (mmaShape[0] == 16) {
498 int64_t kFactor;
499 Type multiplicandFragType;
500 switch (*getMultiplicandAPtxType()) {
501 case MMATypes::tf32:
502 kFactor = 4;
503 multiplicandFragType = i32Ty;
504 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
505 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
506 break;
507 case MMATypes::bf16:
508 kFactor = 8;
509 multiplicandFragType = i32Ty;
510 expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
511 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
512 break;
513 case MMATypes::f16:
514 kFactor = 8;
515 multiplicandFragType = f16x2Ty;
516 expectedResult.push_back(f16x2x2StructTy);
517 expectedResult.push_back(f32x4StructTy);
518 break;
519 case MMATypes::s4:
520 case MMATypes::u4:
521 kFactor = 32;
522 break;
523 case MMATypes::b1:
524 kFactor = 128;
525 break;
526 case MMATypes::s8:
527 case MMATypes::u8:
528 kFactor = 16;
529 break;
530 default:
531 return emitError("invalid shape or multiplicand type: " +
532 stringifyEnum(getMultiplicandAPtxType().value()));
533 }
534
535 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
536 expectedResult.push_back(s32x4StructTy);
537 expectedC.emplace_back(4, i32Ty);
538 multiplicandFragType = i32Ty;
539 } else {
540 expectedC.emplace_back(2, f16x2Ty);
541 expectedC.emplace_back(4, f32Ty);
542 }
543
544 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
545 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
546 expectedA.emplace_back(unitA, multiplicandFragType);
547 expectedB.emplace_back(unitB, multiplicandFragType);
548 allowedShapes.push_back({16, 8, kFactor});
549 allowedShapes.push_back({16, 8, kFactor * 2});
550 }
551
552 // In the M=8 case, there is only 1 possible case per data type.
553 if (mmaShape[0] == 8) {
554 if (*getMultiplicandAPtxType() == MMATypes::f16) {
555 expectedA.emplace_back(2, f16x2Ty);
556 expectedB.emplace_back(2, f16x2Ty);
557 expectedResult.push_back(f16x2x4StructTy);
558 expectedResult.push_back(f32x8StructTy);
559 expectedC.emplace_back(4, f16x2Ty);
560 expectedC.emplace_back(8, f32Ty);
561 allowedShapes.push_back({8, 8, 4});
562 }
563 if (*getMultiplicandAPtxType() == MMATypes::f64) {
564 Type f64Ty = Float64Type::get(context);
565 expectedA.emplace_back(1, f64Ty);
566 expectedB.emplace_back(1, f64Ty);
567 expectedC.emplace_back(2, f64Ty);
568 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
569 context, SmallVector<Type>(2, f64Ty)));
570 allowedShapes.push_back({8, 8, 4});
571 }
572 if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
573 expectedA.push_back({i32Ty});
574 expectedB.push_back({i32Ty});
575 expectedC.push_back({i32Ty, i32Ty});
576 expectedResult.push_back(s32x2StructTy);
577 if (isInt4PtxType(getMultiplicandAPtxType().value()))
578 allowedShapes.push_back({8, 8, 32});
579 if (isInt8PtxType(getMultiplicandAPtxType().value()))
580 allowedShapes.push_back({8, 8, 16});
581 if (getMultiplicandAPtxType().value() == MMATypes::b1)
582 allowedShapes.push_back({8, 8, 128});
583 }
584 }
585
586 std::string errorMessage;
587 llvm::raw_string_ostream errorStream(errorMessage);
588
589 // Check that we matched an existing shape/dtype combination.
590 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
591 !llvm::is_contained(allowedShapes, mmaShape)) {
592 errorStream << "unimplemented variant for MMA shape <";
593 llvm::interleaveComma(mmaShape, errorStream);
594 errorStream << ">";
595 return emitOpError(errorMessage);
596 }
597
598 // Verify the operand types for segments of A, B, and C operands.
599 std::array<StringRef, 3> operandNames{"A", "B", "C"};
600 for (const auto &iter : llvm::enumerate(
601 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
602 auto spec = this->getODSOperandIndexAndLength(iter.index());
603 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
604 operand_type_begin() + spec.first +
605 spec.second);
606 bool match = llvm::is_contained(iter.value(), operandTySeg);
607
608 if (!match) {
609 errorStream << "Could not match types for the "
610 << operandNames[iter.index()]
611 << " operands; expected one of ";
612 for (const auto &x : iter.value()) {
613 errorStream << x.size() << "x" << x[0] << " ";
614 }
615 errorStream << "but got ";
616 llvm::interleaveComma(operandTySeg, errorStream);
617 return emitOpError(errorMessage);
618 }
619 }
620
621 // Check the result type
622 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
623 return expectedResultType == getResult().getType();
624 })) {
625 errorStream
626 << "Could not match allowed types for the result; expected one of ";
627 llvm::interleaveComma(expectedResult, errorStream);
628 errorStream << " but got " << getResult().getType();
629 return emitOpError(errorMessage);
630 }
631
632 // Ensure that binary MMA variants have a b1 MMA operation defined.
633 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
634 return emitOpError("op requires " + getB1OpAttrName().strref() +
635 " attribute");
636 }
637
638 // Ensure int4/int8 MMA variants specify the accum overflow behavior
639 // attribute.
640 if (isInt4PtxType(*getMultiplicandAPtxType()) ||
641 isInt8PtxType(*getMultiplicandAPtxType())) {
642 if (!getIntOverflowBehavior())
643 return emitOpError("op requires " +
644 getIntOverflowBehaviorAttrName().strref() +
645 " attribute");
646 }
647
648 return success();
649}
650
651LogicalResult ShflOp::verify() {
652 if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
653 return success();
654 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
655 auto elementType = (type && type.getBody().size() == 2)
656 ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
657 : nullptr;
658 if (!elementType || elementType.getWidth() != 1)
659 return emitError("expected return type to be a two-element struct with "
660 "i1 as the second element");
661 return success();
662}
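
// For example (illustrative): with the "return_value_and_is_valid" attribute
// set, a valid result type would be !llvm.struct<(i32, i1)>, i.e. the shuffled
// value plus an i1 validity flag.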
663
664std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
665 NVVM::MMAFrag frag, int nRow,
666 int nCol,
667 MLIRContext *context) {
668 unsigned numberElements = 0;
669 Type elementType;
670 OpBuilder builder(context);
671 Type f16x2 = VectorType::get(2, builder.getF16Type());
672 if (type == NVVM::MMATypes::f16) {
673 elementType = f16x2;
674 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
675 numberElements = 8;
676 else
677 numberElements = 4;
678 } else if (type == NVVM::MMATypes::f32) {
679 elementType = builder.getF32Type();
680 numberElements = 8;
681 } else if (type == NVVM::MMATypes::tf32) {
682 elementType = builder.getI32Type();
683 numberElements = 4;
684 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
685 elementType = builder.getI32Type();
686 int parallelSize = 0;
687 if (frag == NVVM::MMAFrag::a)
688 parallelSize = nRow;
689 if (frag == NVVM::MMAFrag::b)
690 parallelSize = nCol;
691
692 // m == 16 && n == 16 && k == 16
693 if (parallelSize == 16)
694 numberElements = 2;
695 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
696 else if (parallelSize == 8)
697 numberElements = 1;
698 else if (parallelSize == 32)
699 numberElements = 4;
700 } else if (type == NVVM::MMATypes::s32) {
701 elementType = builder.getI32Type();
702 numberElements = 8;
703 }
704 assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
706}
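
// For example (illustrative): an f16 "a" or "b" fragment holds 8 vector<2xf16>
// values while an f16 accumulator fragment holds 4; an s8 fragment whose
// parallel dimension is 16 holds 2 i32 values, and an s32 accumulator fragment
// holds 8 i32 values.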
707
708static std::pair<mlir::Type, unsigned>
709inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
710 int k, MLIRContext *context) {
711 int nRow, nCol;
712 if (frag == NVVM::MMAFrag::a) {
713 nRow = m;
714 nCol = k;
715 } else if (frag == NVVM::MMAFrag::b) {
716 nRow = k;
717 nCol = n;
718 } else {
719 nRow = m;
720 nCol = n;
721 }
722 assert(nRow && nCol);
723 return inferMMAType(type, frag, nRow, nCol, context);
724}
725
726LogicalResult NVVM::WMMALoadOp::verify() {
727 unsigned addressSpace =
728 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
729 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
730 addressSpace != NVVM::kSharedMemorySpace)
731 return emitOpError("expected source pointer in memory "
732 "space 0, 1, 3");
733
734 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
735 getEltype(), getFrag()) == 0)
736 return emitOpError() << "invalid attribute combination";
737 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
738 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
739 Type dstType = LLVM::LLVMStructType::getLiteral(
740 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
741 if (getType() != dstType)
742 return emitOpError("expected destination type is a structure of ")
743 << typeInfo.second << " elements of type " << typeInfo.first;
744 return success();
745}
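
// For example (illustrative): a 16x16x16 f16 load of the "a" fragment must
// produce a struct of 8 vector<2xf16> elements, as computed by
// inferMMATypeFromMNK above.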
746
747LogicalResult NVVM::WMMAStoreOp::verify() {
748 unsigned addressSpace =
749 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
750 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
751 addressSpace != NVVM::kSharedMemorySpace)
752 return emitOpError("expected operands to be a source pointer in memory "
753 "space 0, 1, 3");
754
755 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
756 getEltype()) == 0)
757 return emitOpError() << "invalid attribute combination";
758 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
759 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
760 if (getArgs().size() != typeInfo.second)
761 return emitOpError() << "expected " << typeInfo.second << " data operands";
762 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
763 return operands.getType() != typeInfo.first;
764 }))
765 return emitOpError() << "expected data operands of type " << typeInfo.first;
766 return success();
767}
768
769LogicalResult NVVM::WMMAMmaOp::verify() {
770 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
771 getLayoutB(), getEltypeA(),
772 getEltypeB()) == 0)
773 return emitOpError() << "invalid attribute combination";
774 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
775 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
776 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
777 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
778 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
779 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
780 SmallVector<Type, 32> arguments;
781 arguments.append(typeInfoA.second, typeInfoA.first);
782 arguments.append(typeInfoB.second, typeInfoB.first);
783 arguments.append(typeInfoC.second, typeInfoC.first);
784 unsigned numArgs = arguments.size();
785 if (getArgs().size() != numArgs)
786 return emitOpError() << "expected " << numArgs << " arguments";
787 for (unsigned i = 0; i < numArgs; i++) {
788 if (getArgs()[i].getType() != arguments[i])
789 return emitOpError() << "expected argument " << i << " to be of type "
790 << arguments[i];
791 }
792 Type dstType = LLVM::LLVMStructType::getLiteral(
793 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
794 if (getType() != dstType)
795 return emitOpError("expected destination type is a structure of ")
796 << typeInfoC.second << " elements of type " << typeInfoC.first;
797 return success();
798}
799
800LogicalResult NVVM::LdMatrixOp::verify() {
801 unsigned addressSpace =
802 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
803 if (addressSpace != NVVM::kSharedMemorySpace)
804 return emitOpError("expected source pointer in memory space 3");
805
806 if (getNum() != 1 && getNum() != 2 && getNum() != 4)
807 return emitOpError("expected num attribute to be 1, 2 or 4");
808
809 Type i32 = IntegerType::get(getContext(), 32);
810 if (getNum() == 1 && getType() != i32)
811 return emitOpError("expected destination type is i32");
812 if (getNum() == 2 || getNum() == 4) {
813 Type dstType = LLVM::LLVMStructType::getLiteral(
814 getContext(), SmallVector<Type>(getNum(), i32));
815 if (getType() != dstType)
816 return emitOpError("expected destination type is a structure of ")
817 << getNum() << " elements of type i32";
818 }
819 return success();
820}
821
822LogicalResult NVVM::StMatrixOp::verify() {
823 unsigned addressSpace =
824 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
825 if (addressSpace != NVVM::kSharedMemorySpace)
826 return emitOpError("expected source pointer in memory space 3");
827
828 int numMatrix = getSources().size();
829 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
830 return emitOpError("expected num attribute to be 1, 2 or 4");
831
832 return success();
833}
834
835FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
836 if (typeA == NVVM::WGMMATypes::tf32)
837 return 8;
838 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
839 return 16;
840 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
841 return 32;
842 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
843 return 32;
844 if (typeA == NVVM::WGMMATypes::b1)
845 return 256;
846 return failure();
847}
848
849LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
850 NVVM::WGMMATypes typeA,
851 NVVM::WGMMATypes typeB) {
852 switch (typeA) {
853 case NVVM::WGMMATypes::f16:
854 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
855 typeB == NVVM::WGMMATypes::f16)
856 return success();
857 break;
858 case NVVM::WGMMATypes::tf32:
859 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
860 return success();
861 break;
862 case NVVM::WGMMATypes::u8:
863 case NVVM::WGMMATypes::s8:
864 if (typeD == NVVM::WGMMATypes::s32 &&
865 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
866 return success();
867 break;
868 case NVVM::WGMMATypes::b1:
869 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
870 return success();
871 break;
872 case NVVM::WGMMATypes::bf16:
873 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
874 typeB == NVVM::WGMMATypes::bf16)
875 return success();
876 break;
877 case NVVM::WGMMATypes::e4m3:
878 case NVVM::WGMMATypes::e5m2:
879 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
880 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
881 return success();
882 break;
883 case WGMMATypes::f32:
884 case WGMMATypes::s32:
885 llvm_unreachable("unsupported input types");
886 break;
887 }
888 return failure();
889}
890
891LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
892 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
893 72, 80, 88, 96, 104, 112, 120, 128,
894 136, 144, 152, 160, 168, 176, 184, 192,
895 200, 208, 216, 224, 232, 240, 248, 256};
896 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
897 80, 96, 112, 128, 144, 160,
898 176, 192, 208, 224, 240, 256};
899 switch (typeA) {
900 case WGMMATypes::f16:
901 case WGMMATypes::tf32:
902 case WGMMATypes::bf16:
903 case WGMMATypes::e4m3:
904 case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
906 return success();
907 break;
908 case WGMMATypes::u8:
909 case WGMMATypes::s8:
910 case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
912 return success();
913 break;
914 case WGMMATypes::f32:
915 case WGMMATypes::s32:
916 llvm_unreachable("unsupported input types");
917 break;
918 }
919 return failure();
920}
921
922LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
923 Value outValue = getResults();
924 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
925 if (!stype)
926 return emitOpError() << "expected results to be struct";
927 int outputSize = stype.getBody().size();
928 WGMMATypes typeD = getTypeD();
929 WGMMATypes typeA = getTypeA();
930 WGMMATypes typeB = getTypeB();
931
932 for (Type t : stype.getBody()) {
933 if (t != stype.getBody().front())
934 return emitOpError()
935 << "all elements in struct must be same type but there is " << t;
936 }
937
938 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
939 typeD != WGMMATypes::s32) {
940 return emitOpError() << "does not support the given output type "
941 << NVVM::stringifyWGMMATypes(typeD);
942 }
943 if (typeD == WGMMATypes::s32 &&
944 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
945 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
946 }
947
948 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
949 return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
950 << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
951 << NVVM::stringifyWGMMATypes(typeB)
952 << ", it is not supported.";
953 }
954
955 // Check M
956 if (getShape().getM() != 64)
957 return emitOpError() << "shape 'm' must be 64";
958
959 // Check K
960 FailureOr<int> allowedK = getAllowedSizeK(typeA);
961 if (failed(allowedK) || allowedK.value() != getShape().getK())
962 return emitOpError() << "shape 'k' must be " << allowedK.value()
963 << " for input type "
964 << NVVM::stringifyWGMMATypes(typeA);
965
966 // Check N
967 if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
968 return emitOpError() << "has input type "
969 << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
970 << getShape().getN() << ", it is not supported.";
971 }
972
973 // Check transpose (only available for f16/bf16)
  // Matrix A should be stored in row-major and matrix B in column-major order.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // order, by setting the transpose values (imm-trans-a, imm-trans-b) in PTX code.
977 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
978 (getLayoutA() == mlir::NVVM::MMALayout::col ||
979 getLayoutB() == mlir::NVVM::MMALayout::row)) {
980 return emitOpError()
981 << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
982 << " and layout_b = " << stringifyMMALayout(getLayoutB())
983 << " for input types " << stringifyWGMMATypes(typeA) << " and "
984 << stringifyWGMMATypes(typeB)
985 << " requires transpose. However, this is only supported for: "
986 << stringifyMMATypes(MMATypes::f16) << " and "
987 << stringifyMMATypes(MMATypes::bf16);
988 }
989
990 // Check result registers
991 int expectedOutput = 0;
992 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
993 expectedOutput = getShape().getN() / 2;
994 if (typeD == WGMMATypes::f16)
995 expectedOutput = getShape().getN() / 4;
996 if (outputSize != expectedOutput) {
997 return emitOpError() << "results " << expectedOutput
998 << ", however output struct has " << outputSize
999 << " elements";
1000 }
1001 // Check satfinite (only available for s32 accumulator)
1002 if (typeD != WGMMATypes::s32 &&
1003 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1004 NVVM::MMAIntOverflow::satfinite) {
1005 return emitOpError()
1006 << " `satfinite` can be only used with s32 accumulator, however "
1007 "the current accumulator is "
1008 << NVVM::stringifyWGMMATypes(typeD);
1009 }
1010
1011 return success();
1012}
1013
1014std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
1015
1016 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
1017 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1018
1019 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
1020
1021 int expectedOutputRegisters = 0;
1022 if (getTypeD() == WGMMATypes::f16)
1023 expectedOutputRegisters = getShape().getN() / 4;
1024 else
1025 expectedOutputRegisters = getShape().getN() / 2;
1026
1027 std::string ptx;
1028 llvm::raw_string_ostream ss(ptx);
1029
1030 ss << "{\n"
1031 ".reg .pred p;\n"
1032 "setp.ne.b32 p, $"
1033 << ((expectedOutputRegisters * 2) + 2)
1034 << ", 0;\n"
1035 "wgmma.mma_async.sync.aligned.m"
1036 << m << "n" << n << "k" << k << "." << outputTypeName << "."
1037 << stringifyWGMMATypes(getTypeA()) << "."
1038 << stringifyWGMMATypes(getTypeB());
1039 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
1040 NVVM::MMAIntOverflow::satfinite)
1041 ss << ".satfinite";
1042 ss << " {";
1043 int regCnt = 0;
1044 for (; regCnt < expectedOutputRegisters; ++regCnt) {
1045 ss << "$" << regCnt;
1046 if (regCnt != expectedOutputRegisters - 1)
1047 ss << ", ";
1048 }
1049
1050 ss << "},";
1051 // Need to map read/write registers correctly.
1052 regCnt = (regCnt * 2);
1053 ss << " $" << (regCnt) << ","
1054 << " $" << (regCnt + 1) << ","
1055 << " p";
1056 if (getTypeD() != WGMMATypes::s32) {
1057 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
1058 }
1059 // Don't add transpose parameters unless needed.
1060 if (isF16) {
1061 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
1062 }
1063 ss << ";\n"
1064 << "}\n";
1065 return ptx;
1066}
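
// For example (derived from the logic above, shown wrapped for readability):
// an f16 x f16 -> f32 wgmma with shape m64n16k16 and no satfinite yields:
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $18, 0;
//   wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16
//     {$0, $1, $2, $3, $4, $5, $6, $7}, $16, $17, p, $19, $20, $21, $22;
//   }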
1067
1068void NVVM::WgmmaMmaAsyncOp::getAsmValues(
1069 RewriterBase &rewriter,
1070 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
1071 &asmValues) {
1072 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
1073 if (getResults())
1074 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
1075 if (getInouts())
1076 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
1077 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
1078 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
1079 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
1080 mlir::NVVM::PTXRegisterMod::Read});
1081 if (getTypeD() != WGMMATypes::s32) {
1082 asmValues.push_back(
1083 {makeConstantI32(rewriter,
1084 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1085 mlir::NVVM::PTXRegisterMod::Read});
1086 asmValues.push_back(
1087 {makeConstantI32(rewriter,
1088 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
1089 mlir::NVVM::PTXRegisterMod::Read});
1090 }
1091 if (isF16) {
1092 asmValues.push_back(
1093 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
1094 mlir::NVVM::PTXRegisterMod::Read});
1095 asmValues.push_back(
1096 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1097 mlir::NVVM::PTXRegisterMod::Read});
1098 }
}

LogicalResult NVVM::FenceProxyOp::verify() {
1101 if (getKind() == NVVM::ProxyKind::TENSORMAP)
1102 return emitOpError() << "tensormap proxy is not a supported proxy kind";
1103 if (getKind() == NVVM::ProxyKind::GENERIC)
1104 return emitOpError() << "generic proxy not a supported proxy kind";
1105 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1106 return emitOpError() << "async_shared fence requires space attribute";
1107 }
1108 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1109 return emitOpError() << "only async_shared fence can have space attribute";
1110 }
1111 return success();
1112}
1113
1114LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1115 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1116 return emitOpError("uni-directional proxies only support generic for "
1117 "from_proxy attribute");
1118
1119 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1120 return emitOpError("uni-directional proxies only support tensormap "
1121 "for to_proxy attribute");
1122
1123 return success();
1124}
1125
1126LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1127 if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1128 return emitOpError("uni-directional proxies only support generic for "
1129 "from_proxy attribute");
1130
1131 if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1132 return emitOpError("uni-directional proxies only support tensormap "
1133 "for to_proxy attribute");
1134
1135 return success();
1136}
1137
1138LogicalResult NVVM::SetMaxRegisterOp::verify() {
1139 if (getRegCount() % 8)
1140 return emitOpError("new register size must be multiple of 8");
1141 if (getRegCount() < 24 || getRegCount() > 256)
1142 return emitOpError("new register size must be in between 24 to 256");
1143 return success();
1144}
1145
1146LogicalResult NVVM::BarrierOp::verify() {
1147 if (getNumberOfThreads() && !getBarrierId())
1148 return emitOpError(
1149 "barrier id is missing, it should be set between 0 to 15");
1150 return success();
1151}
1152
1153LogicalResult NVVM::Tcgen05CpOp::verify() {
1154 auto mc = getMulticast();
1155
1156 using SH = Tcgen05CpShape;
1157 using MC = Tcgen05CpMulticast;
1158 switch (getShape()) {
1159 case SH::SHAPE_128x256b:
1160 case SH::SHAPE_128x128b:
1161 case SH::SHAPE_4x256b:
1162 if (mc != MC::NONE)
1163 return emitError("Invalid multicast type for tcgen05.cp Op");
1164 break;
1165 case SH::SHAPE_64x128b:
1166 if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
1167 return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
1168 "warpx2_02_13 for tcgen05.cp Op");
1169 break;
1170 case SH::SHAPE_32x128b:
1171 if (mc != MC::WARPX4)
1172 return emitError(
1173 "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
1174 break;
1175 }
1176 return success();
1177}
1178
1179LogicalResult NVVM::MatchSyncOp::verify() {
1180 if (getKind() == NVVM::MatchSyncKind::all) {
1181 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1182 if (!type || type.getBody().size() != 2 ||
1183 !type.getBody()[0].isInteger(32) || !type.getBody()[1].isInteger(1)) {
1184 return emitOpError("match.sync 'all' returns a two element struct with "
1185 "first element as i32 and second element as i1");
1186 }
1187 } else {
1188 if (!getType().isInteger(32)) {
1189 return emitOpError("match.sync 'any' returns an i32");
1190 }
1191 }
1192 return success();
1193}
1194
1195LogicalResult NVVM::VoteSyncOp::verify() {
1196 if (getKind() == NVVM::VoteSyncKind::ballot) {
1197 if (!getType().isInteger(32)) {
1198 return emitOpError("vote.sync 'ballot' returns an i32");
1199 }
1200 } else {
1201 if (!getType().isInteger(1)) {
1202 return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1");
1203 }
1204 }
1205 return success();
1206}
1207
1208LogicalResult NVVM::PrefetchOp::verify() {
1209 using MemSpace = NVVM::NVVMMemorySpace;
1210 using CacheLevel = NVVM::PrefetchCacheLevel;
1211
1212 unsigned addressSpace =
1213 llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
1214 std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1215
1216 if (getUniform()) {
1217 if (getCacheLevel() != CacheLevel::L1)
1218 return emitOpError("unsupported cache level, the only supported uniform "
1219 "cache level is L1");
1220
1221 if (addressSpace != MemSpace::kGenericMemorySpace)
1222 return emitOpError(
1223 "prefetch to uniform cache requires a generic pointer");
1224 }
1225
1226 if (evictPriority) {
1227 if (getCacheLevel() != CacheLevel::L2)
1228 return emitOpError(
1229 "cache eviction priority supported only for cache level L2");
1230
1231 if (addressSpace != MemSpace::kGlobalMemorySpace)
1232 return emitOpError("cache eviction priority requires a global pointer");
1233
1234 if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1235 *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1236 return emitOpError(
1237 "unsupported cache eviction priority, only evict_last and "
1238 "evict_normal are supported");
1239 }
1240
1241 return success();
1242}
1243
/// Packs the given `field` into the `result`.
/// The `result` is 64 bits wide and each `field` can be 32 bits or narrower.
1246static llvm::Value *
1247packValInto64Bits(llvm::IRBuilderBase &builder,
1248 llvm::Value *result, // the `result` (unset bits are zero)
1249 llvm::Value *field, // `field` to pack into `result`
1250 unsigned sizeInBits, // Size of `field` in bits
1251 unsigned start) { // Starting bit within `result`
  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());

  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
  if (mask != 0xffffffffu)
    field = builder.CreateAnd(field, builder.getInt32(mask));

  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
  field = builder.CreateShl(field, start);

  return builder.CreateOr(result, field);
1262}
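
// For example (illustrative): packing a 3-bit field into bits [49, 52) of a
// zero-initialized 64-bit descriptor word:
//   llvm::Value *desc = builder.getInt64(0);
//   desc = packValInto64Bits(builder, desc, field, /*sizeInBits=*/3,
//                            /*start=*/49);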
1263
1264void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
1265 LLVM::ModuleTranslation &mt,
1266 llvm::IRBuilderBase &builder) {
1267 auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
1268 llvm::Value *smemDesc = builder.getInt64(0);
1269
1270 smemDesc = packValInto64Bits(builder, smemDesc,
1271 mt.lookupValue(thisOp.getStartAddr()), 14, 0);
1272 smemDesc = packValInto64Bits(
1273 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
1274 smemDesc = packValInto64Bits(
1275 builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
1276
1277 smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
1278 smemDesc = packValInto64Bits(builder, smemDesc,
1279 mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
1280 smemDesc = packValInto64Bits(
1281 builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
1282 smemDesc = packValInto64Bits(builder, smemDesc,
1283 mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
1284
1285 mt.mapValue(thisOp.getRes()) = smemDesc;
1286}
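
// The shared-memory descriptor assembled above has the following layout
// (bit positions taken from the packValInto64Bits calls):
//   [0, 14)   start address
//   [16, 30)  leading dimension byte offset
//   [32, 46)  stride dimension byte offset
//   [46, 49)  constant 0b001
//   [49, 52)  matrix base offset
//   [52, 53)  leading dimension mode
//   [61, 64)  swizzle mode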
1287
1288//===----------------------------------------------------------------------===//
1289// getIntrinsicID/getIntrinsicIDAndArgs methods
1290//===----------------------------------------------------------------------===//
1291
1292#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
1293 llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
1294
1295#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \
1296 has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
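
// For example: GET_CP_ASYNC_ID(cg, 16, hasCpSize) selects
// llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16_s when hasCpSize is true
// and llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16 otherwise.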
1297
1298llvm::Intrinsic::ID
1299CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
1300 llvm::SmallVector<llvm::Value *> &args) {
1301 llvm::Intrinsic::ID id;
1302
1303 auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
1304 bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
1305 switch (cpAsyncOp.getSize()) {
1306 case 4:
1307 id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
1308 break;
1309 case 8:
1310 id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
1311 break;
1312 case 16:
1313 id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
1314 ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
1315 : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
1316 break;
1317 default:
1318 llvm_unreachable("Invalid copy size in CpAsyncOp.");
1319 }
1320
1321 // Fill the Intrinsic Args
1322 args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
1323 args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
1324 if (hasCpSize)
1325 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));
1326
1327 return id;
1328}
1329
1330mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
1331 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1332 auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
1333 llvm::SmallVector<llvm::Value *> args;
1334 llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;
1335
1336 // Fill the Intrinsic Args
1337 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1338 args.push_back(mt.lookupValue(thisOp.getSize()));
1339
1340 mlir::Value cacheHint = thisOp.getL2CacheHint();
1341 const bool hasCacheHint = static_cast<bool>(cacheHint);
1342 llvm::Value *i64Unused =
1343 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1344 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1345 args.push_back(builder.getInt1(hasCacheHint));
1346
1347 return {id, std::move(args)};
1348}
1349
1350mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1351 Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1352 auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
1353 llvm::SmallVector<llvm::Value *> args;
1354 llvm::Intrinsic::ID id =
1355 llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
1356
1357 // Fill the Intrinsic Args
1358 args.push_back(mt.lookupValue(thisOp.getDstMem()));
1359 args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1360 args.push_back(mt.lookupValue(thisOp.getSize()));
1361
1362 mlir::Value cacheHint = thisOp.getL2CacheHint();
1363 const bool hasCacheHint = static_cast<bool>(cacheHint);
1364 llvm::Value *i64Unused =
1365 llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1366 args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1367 args.push_back(builder.getInt1(hasCacheHint));
1368
1369 // Choose the bytemask variant
1370 if (mlir::Value byteMask = thisOp.getByteMask()) {
1371 args.push_back(mt.lookupValue(byteMask));
1372 id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
1373 }
1374
1375 return {id, std::move(args)};
1376}
1377
1378llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
1379 bool isIm2Col) {
1380 switch (tensorDims) {
1381 case 1:
1382 return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
1383 case 2:
1384 return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
1385 case 3:
1386 return isIm2Col
1387 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
1388 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
1389 case 4:
1390 return isIm2Col
1391 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
1392 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
1393 case 5:
1394 return isIm2Col
1395 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
1396 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
1397 default:
1398 llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
1399 }
1400}
1401
1402#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
1403 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
1404
1405#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col) \
1406 is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \
1407 : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
1408
1409#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
1410 [&]() -> auto { \
1411 switch (dims) { \
1412 case 1: \
1413 return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
1414 case 2: \
1415 return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \
1416 case 3: \
1417 return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \
1418 case 4: \
1419 return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \
1420 case 5: \
1421 return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \
1422 default: \
1423 llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
1424 } \
1425 }()
1426
1427llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
1428 int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
1429 using RedTy = NVVM::TMAReduxKind;
1430 switch (kind) {
1431 case RedTy::ADD:
1432 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
1433 case RedTy::MIN:
1434 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
1435 case RedTy::MAX:
1436 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
1437 case RedTy::INC:
1438 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
1439 case RedTy::DEC:
1440 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
1441 case RedTy::AND:
1442 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
1443 case RedTy::OR:
1444 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
1445 case RedTy::XOR:
1446 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
1447 }
1448 llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
1449}
1450
1451#define _none
1452
1453#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1454 hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
1455 : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
1456
1457#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
1458 hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
1459 : CVT_F2TF32_ID_IMPL(rnd, relu, )
1460
1461llvm::Intrinsic::ID
1462ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1463 NVVM::SaturationMode sat, bool hasRelu) {
1464 using RndMode = NVVM::FPRoundingMode;
1465 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1466 switch (rnd) {
1467 case RndMode::RN:
1468 return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
1469 case RndMode::RZ:
1470 return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
1471 case RndMode::RNA:
1472 return GET_CVT_F2TF32_ID(rna, _none, _satfinite);
1473 default:
1474 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
1475 }
1476}
1477
1478#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
1479 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
1480 : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
1481
1482llvm::Intrinsic::ID
1483ConvertF32x2ToF6x2Op::getIntrinsicID(NVVM::ConvertFP6Type type, bool hasRelu) {
1484 switch (type) {
1485 case NVVM::ConvertFP6Type::E2M3:
1486 return GET_F32x2_TO_F6x2_ID(e2m3x2, hasRelu);
1487 case NVVM::ConvertFP6Type::E3M2:
1488 return GET_F32x2_TO_F6x2_ID(e3m2x2, hasRelu);
1489 }
1490 llvm_unreachable("Invalid conversion in ConvertF32x2ToF6x2Op");
1491}
1492
1493#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
1494 has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
1495 : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
1496
1497#define GET_F32x2_TO_F8X2_S_ID(type, has_relu) \
1498 has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
1499 : llvm::Intrinsic::nvvm_ff_to_##type##_rn
1500
1501llvm::Intrinsic::ID
1502ConvertF32x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type,
1503 NVVM::FPRoundingMode rnd,
1504 NVVM::SaturationMode sat, bool hasRelu) {
1505 bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1506 bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
1507 bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
1508
1509 switch (type) {
1510 case NVVM::ConvertFP8Type::E4M3:
1511 return GET_F32x2_TO_F8X2_S_ID(e4m3x2, hasRelu);
1512 case NVVM::ConvertFP8Type::E5M2:
1513 return GET_F32x2_TO_F8X2_S_ID(e5m2x2, hasRelu);
1514 case NVVM::ConvertFP8Type::UE8M0:
1515 if (hasRoundingModeRZ)
1516 return GET_F32x2_TO_F8X2_US_ID(rz, hasSatFinite);
1517 else if (hasRoundingModeRP)
1518 return GET_F32x2_TO_F8X2_US_ID(rp, hasSatFinite);
1519 }
1520 llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
1521}
1522
#define GET_F16x2_TO_F8X2_ID(type, has_relu)                                   \
  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

llvm::Intrinsic::ID
ConvertF16x2ToF8x2Op::getIntrinsicID(NVVM::ConvertFP8Type type, bool hasRelu) {
  switch (type) {
  case NVVM::ConvertFP8Type::E4M3:
    return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
  case NVVM::ConvertFP8Type::E5M2:
    return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
  default:
    llvm_unreachable("Invalid ConvertFP8Type for CvtF16x2ToF8x2Op");
  }
}

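// bf16x2 -> ue8m0x2 conversions vary on the rounding mode (rz/rp) and on
// satfinite.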
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                      NVVM::SaturationMode sat) {
  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
  switch (rnd) {
  case NVVM::FPRoundingMode::RZ:
    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
  case NVVM::FPRoundingMode::RP:
    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
  default:
    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
  }
}

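// tcgen05.alloc: the intrinsic variant is chosen from whether the destination
// address lives in shared memory and from the CTA group (cg1 vs. cg2); the
// arguments are the address and the number of columns to allocate.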
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                      LLVM::ModuleTranslation &mt,
                                      llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05AllocOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id;
  if (isShared) {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_shared_cg1;
  } else {
    id = is2CTAMode ? llvm::Intrinsic::nvvm_tcgen05_alloc_cg2
                    : llvm::Intrinsic::nvvm_tcgen05_alloc_cg1;
  }

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(curOp.getAddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return id;
}

llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt,
    llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05DeallocOp>(op);
  auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1)
                ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1
                : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(curOp.getTaddr()));
  args.push_back(mt.lookupValue(curOp.getNCols()));

  return id;
}

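// tcgen05.commit: compose the intrinsic name from the CTA group (cg1/cg2),
// whether the address operand is in shared memory, and whether a multicast
// mask is supplied.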
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)                                 \
  is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg         \
            : llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)                    \
  has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc)                      \
         : TCGEN05_COMMIT_IMPL(cta_group, is_shared, )

llvm::Intrinsic::ID
Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
                                       llvm::SmallVector<llvm::Value *> &args) {
  auto curOp = cast<NVVM::Tcgen05CommitOp>(op);
  unsigned as = llvm::cast<LLVM::LLVMPointerType>(curOp.getAddr().getType())
                    .getAddressSpace();
  bool isShared = as == NVVMMemorySpace::kSharedMemorySpace;
  bool hasMulticast = static_cast<bool>(curOp.getMulticastMask());
  bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2;

  llvm::Intrinsic::ID id =
      is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast)
                 : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast);

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(curOp.getAddr()));
  if (hasMulticast)
    args.push_back(mt.lookupValue(curOp.getMulticastMask()));

  return id;
}

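// tcgen05.cp: GET_TCGEN05_CP_ID wraps the source-format dispatch in an
// immediately-invoked lambda so the whole selection can be used as a single
// expression in the return statements below.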
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> auto {                                                              \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()

llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}


// Checks whether the given vector length is valid for a tcgen05.{ld, st}
// shape; models the table in the tcgen05.{ld, st} Op description.
static bool isValidVectorLength(NVVM::Tcgen05LdStShape shape,
                                unsigned vecLen) {
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X128B)
    return vecLen >= 2;
  if (shape == NVVM::Tcgen05LdStShape::SHAPE_16X256B)
    return vecLen >= 4;
  return true;
}

LogicalResult Tcgen05LdOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto resTy = getRes().getType();
  unsigned resLen = isa<VectorType>(resTy)
                        ? llvm::cast<VectorType>(resTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), resLen))
    result = emitError(llvm::formatv("invalid result type length {0} for shape "
                                     "{1} in tcgen05.ld Op",
                                     resLen, stringifyEnum(getShape())));

  return result;
}

LogicalResult Tcgen05StOp::verify() {
  LogicalResult result = success();
  if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
    result = emitError("shape 16x32bx2 requires offset argument");

  auto valTy = getVal().getType();
  unsigned valLen = isa<VectorType>(valTy)
                        ? llvm::cast<VectorType>(valTy).getNumElements()
                        : 1;
  if (!isValidVectorLength(getShape(), valLen))
    result = emitError(llvm::formatv("invalid input length {0} for shape "
                                     "{1} in tcgen05.st Op",
                                     valLen, stringifyEnum(getShape())));

  return result;
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}

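// Bitcast a packed vector operand to i32; the dp4a/dp2a intrinsics below take
// their packed inputs as i32 values.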
static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                   llvm::IRBuilderBase &builder) {
  return builder.CreateBitCast(arg,
                               llvm::Type::getInt32Ty(builder.getContext()));
}

NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
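  // ids[] is indexed by (isASigned << 1) | isBSigned:
  // 0 = u.u, 1 = u.s, 2 = s.u, 3 = s.s.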
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp4a_u_u,
      llvm::Intrinsic::nvvm_idp4a_u_s,
      llvm::Intrinsic::nvvm_idp4a_s_u,
      llvm::Intrinsic::nvvm_idp4a_s_s,
  };
  return {ids[type], args};
}

NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto curOp = cast<NVVM::DotAccumulate2WayOp>(op);

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
  args.push_back(builder.getInt1(curOp.getBHi()));
  args.push_back(mt.lookupValue(curOp.getC()));

  bool isASigned = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
  bool isBSigned = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
  unsigned type = (isASigned << 1) | isBSigned;
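  // ids[] is indexed by (isASigned << 1) | isBSigned:
  // 0 = u.u, 1 = u.s, 2 = s.u, 3 = s.s.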
  const llvm::Intrinsic::ID ids[] = {
      llvm::Intrinsic::nvvm_idp2a_u_u,
      llvm::Intrinsic::nvvm_idp2a_u_s,
      llvm::Intrinsic::nvvm_idp2a_s_u,
      llvm::Intrinsic::nvvm_idp2a_s_s,
  };
  return {ids[type], args};
}

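// prefetch: the uniform flag on L1 selects prefetchu.L1; an eviction priority
// on L2 selects the global L2 evict_{last,normal} variants; otherwise the
// intrinsic is picked from the pointer address space and the cache level.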
llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
  using MemSpace = NVVM::NVVMMemorySpace;
  using CacheLevel = NVVM::PrefetchCacheLevel;

  NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
  std::optional<NVVM::CacheEvictionPriority> evictPriority =
      op.getEvictPriority();
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
          .getAddressSpace();

  if (op.getUniform() && cacheLevel == CacheLevel::L1)
    return llvm::Intrinsic::nvvm_prefetchu_L1;

  if (evictPriority && cacheLevel == CacheLevel::L2) {
    switch (*evictPriority) {
    case NVVM::CacheEvictionPriority::EvictLast:
      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
    case NVVM::CacheEvictionPriority::EvictNormal:
      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
    default:
      llvm_unreachable("Invalid cache eviction priority");
    }
  }

  switch (addressSpace) {
  case MemSpace::kGenericMemorySpace:
    return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
                                        : llvm::Intrinsic::nvvm_prefetch_L2;
  case MemSpace::kGlobalMemorySpace:
    return cacheLevel == CacheLevel::L1
               ? llvm::Intrinsic::nvvm_prefetch_global_L1
               : llvm::Intrinsic::nvvm_prefetch_global_L2;
  case MemSpace::kLocalMemorySpace:
    return cacheLevel == CacheLevel::L1
               ? llvm::Intrinsic::nvvm_prefetch_local_L1
               : llvm::Intrinsic::nvvm_prefetch_local_L2;
  default:
    llvm_unreachable("Invalid pointer address space");
  }
}

1815//===----------------------------------------------------------------------===//
1816// NVVMDialect initialization, type parsing, and registration.
1817//===----------------------------------------------------------------------===//
1818
1819// TODO: This should be the llvm.nvvm dialect once this is supported.
1820void NVVMDialect::initialize() {
1821 addOperations<
1822#define GET_OP_LIST
1823#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1824 >();
1825 addAttributes<
1826#define GET_ATTRDEF_LIST
1827#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1828 >();
1829
1830 // Support unknown operations because not all NVVM operations are
1831 // registered.
1832 allowUnknownOperations();
1833 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
1834 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
1835}
1836
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // The kernel function attribute may only be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim is present, it must be an integer array
  // with at most 3 elements.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }
  // If minctasm / maxnreg / cluster_max_blocks is present, it must be an
  // integer attribute.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }

  return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files, bool verifyTarget) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return mlir::isa_and_nonnull<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
  if (!getVerifyTarget())
    return success();

  auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
  if (!gpuModuleOp) {
    return emitError(gpuModule->getLoc(),
                     "NVVM target attribute must be attached to a GPU module");
  }

  const NVVMCheckSMVersion targetSMVersion =
      NVVMCheckSMVersion::getTargetSMVersionFromStr(getChip());
  if (!targetSMVersion.isMinimumSMVersion()) {
    return emitError(gpuModule->getLoc(),
                     "Minimum NVVM target SM version is sm_20");
  }

  WalkResult walkResult = gpuModuleOp->walk([&](Operation *op) {
    if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
      const NVVMCheckSMVersion requirement = reqOp.getRequiredMinSMVersion();
      if (!requirement.isCompatibleWith(targetSMVersion)) {
        op->emitOpError() << "is not supported on " << getChip();
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  // Propagate a failure if the walk was interrupted by an unsupported op.
  return failure(walkResult.wasInterrupted());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
