1//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the NVVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12// The NVVM dialect only contains GPU specific additions on top of the general
13// LLVM dialect.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18
19#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinAttributes.h"
24#include "mlir/IR/BuiltinTypes.h"
25#include "mlir/IR/Diagnostics.h"
26#include "mlir/IR/DialectImplementation.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/Operation.h"
29#include "mlir/IR/OperationSupport.h"
30#include "mlir/IR/Types.h"
31#include "mlir/Support/LogicalResult.h"
32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/AsmParser/Parser.h"
35#include "llvm/IR/Attributes.h"
36#include "llvm/IR/Function.h"
37#include "llvm/IR/Type.h"
38#include "llvm/Support/Casting.h"
39#include "llvm/Support/SourceMgr.h"
40#include "llvm/Support/raw_ostream.h"
41#include <cassert>
42#include <optional>
43#include <string>
44
45using namespace mlir;
46using namespace NVVM;
47
48#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
49#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
50
51//===----------------------------------------------------------------------===//
52// Printing/parsing for NVVM ops
53//===----------------------------------------------------------------------===//
54
55static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
56 p << " " << op->getOperands();
57 if (op->getNumResults() > 0)
58 p << " : " << op->getResultTypes();
59}
60
61// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
62ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
63 MLIRContext *context = parser.getContext();
64 auto int32Ty = IntegerType::get(context, 32);
65 auto int1Ty = IntegerType::get(context, 1);
66
67 SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
68 Type type;
69 return failure(parser.parseOperandList(ops) ||
70 parser.parseOptionalAttrDict(result.attributes) ||
71 parser.parseColonType(type) ||
72 parser.addTypeToList(type, result.types) ||
73 parser.resolveOperands(ops, {int32Ty, int1Ty},
74 parser.getNameLoc(), result.operands));
75}
76
77void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
78
79LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
80 if (getCoordinates().empty() || getCoordinates().size() > 5)
81 return emitError("expects coordinates between 1 to 5 dimension");
82
83 // Check for im2col mode
84 if (!getIm2colOffsets().empty()) {
85 if (getCoordinates().size() < 3)
86 return emitError(
87 "to use im2col mode, the tensor has to be at least 3-dimensional");
88 if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
89 return emitError(
90 "im2col offsets must be 2 less than number of coordinates");
91 }
92 return success();
93}
94
95LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
96 if (getCoordinates().size() > 5)
97 return emitError("Maximum 5 coordinates and dimension is supported.");
98 return success();
99}
100
101LogicalResult CpAsyncOp::verify() {
102 if (getModifier() != LoadCacheModifierKind::CG &&
103 getModifier() != LoadCacheModifierKind::CA)
104 return emitError("Only CG and CA cache modifiers are supported.");
105 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
106 return emitError("expected byte size to be either 4, 8 or 16.");
107 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
108 return emitError("CG cache modifier is only support for 16 bytes copy.");
109 return success();
110}
111
112// Given the element type of an operand and whether or not it is an accumulator,
113// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
114// operand's element type.
115std::optional<mlir::NVVM::MMATypes>
116MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
117 auto half2Type =
118 LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
119 if (operandElType.isF64())
120 return NVVM::MMATypes::f64;
121 if (operandElType.isF16() || operandElType == half2Type)
122 return NVVM::MMATypes::f16;
123 if (operandElType.isF32() && isAccumulator)
124 return NVVM::MMATypes::f32;
125 if (operandElType.isF32() && !isAccumulator)
126 return NVVM::MMATypes::tf32;
127 if (llvm::isa<IntegerType>(operandElType)) {
128 if (isAccumulator)
129 return NVVM::MMATypes::s32;
130 return std::nullopt;
131 }
132
133 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
134 if (structType.getBody().empty())
135 return std::nullopt;
136 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
137 }
138
139 return std::nullopt;
140}
141
142static bool isInt4PtxType(MMATypes type) {
143 return (type == MMATypes::u4 || type == MMATypes::s4);
144}
145
146static bool isInt8PtxType(MMATypes type) {
147 return (type == MMATypes::u8 || type == MMATypes::s8);
148}
149
150static bool isIntegerPtxType(MMATypes type) {
151 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
152 type == MMATypes::s32;
153}
154
155MMATypes MmaOp::accumPtxType() {
156 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
157 getODSOperands(2).getTypes().front(), /*isAccum=*/true);
158 assert(val.has_value() && "accumulator PTX type should always be inferrable");
159 return val.value();
160}
161
162MMATypes MmaOp::resultPtxType() {
163 std::optional<mlir::NVVM::MMATypes> val =
164 inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
165 assert(val.has_value() && "result PTX type should always be inferrable");
166 return val.value();
167}
168
/// Custom printer for MmaOp. Emits the form:
///   A[<regs>] B[<regs>] C[<regs>] attr-dict : (typeA, typeB, typeC) -> res
/// PTX-type attributes that can be re-inferred from the register types on
/// parse are omitted from the attribute dictionary.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Bundles the registers belonging to one operand segment (A, B or C)
  // together with the name of its PTX-type attribute.
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  // Attributes printed through dedicated syntax (or inferable) are excluded
  // from the printed attribute dictionary.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    // getODSOperandIndexAndLength yields {start, count} of this segment.
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      // NOTE(review): only the type at absolute operand index 0 is recorded,
      // so regTypes.back() below refers to that single recorded type for all
      // fragments — presumably all segments share one register type; confirm.
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    // If the PTX type is inferable from the register type, the parser can
    // reconstruct it, so do not print the attribute explicitly.
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  // Prints one segment as `<name>[<comma-separated regs>] `.
  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : " << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
225
/// Builder for MmaOp.
///
/// \p shape must contain exactly (m, n, k). \p multiplicandPtxTypes and
/// \p multiplicandLayouts are optional; when absent the PTX types are inferred
/// from the first A/B operand element types and the layouts default to
/// row-major A / column-major B. \p intOverflow and \p b1Op attach the
/// corresponding optional attributes only when provided.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  // The MMA shape (m, n, k) is stored as one dedicated attribute.
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  // Prefer explicitly supplied multiplicand PTX types; otherwise infer them
  // from the element type of the first operand in each segment.
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  // Layouts default to row-major A and column-major B when not specified.
  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  // Optional attributes are attached only when a value was provided.
  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  // Record the size of each variadic operand segment (A, B, C).
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
277
278// <operation> :=
279// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
280// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
281// `->` type($res)
282ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
283 struct OperandFragment {
284 std::optional<MMATypes> elemtype;
285 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
286 SmallVector<Type> regTypes;
287 };
288
289 Builder &builder = parser.getBuilder();
290 std::array<OperandFragment, 4> frags;
291
292 NamedAttrList namedAttributes;
293
294 // A helper to parse the operand segments.
295 auto parseMmaOperand = [&](StringRef operandName,
296 OperandFragment &frag) -> LogicalResult {
297 if (parser.parseKeyword(operandName).failed())
298 return failure();
299 if (parser
300 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
301 .failed())
302 return failure();
303 return success();
304 };
305
306 // Parse the operand segments.
307 if (parseMmaOperand("A", frags[0]).failed())
308 return failure();
309 if (parseMmaOperand("B", frags[1]).failed())
310 return failure();
311 if (parseMmaOperand("C", frags[2]).failed())
312 return failure();
313
314 if (parser.parseOptionalAttrDict(namedAttributes).failed())
315 return failure();
316
317 // Parse the type specification and resolve operands.
318 SmallVector<Type, 3> operandTypes;
319 if (failed(parser.parseColon()))
320 return failure();
321 if (failed(parser.parseLParen()))
322 return failure();
323 if (failed(parser.parseTypeList(operandTypes)))
324 return failure();
325 if (failed(parser.parseRParen()))
326 if (operandTypes.size() != 3)
327 return parser.emitError(
328 parser.getNameLoc(),
329 "expected one type for each operand segment but got " +
330 Twine(operandTypes.size()) + " types");
331 for (const auto &iter : llvm::enumerate(operandTypes)) {
332 auto &frag = frags[iter.index()];
333 frag.regTypes.resize(frag.regs.size(), iter.value());
334 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
335 parser.getNameLoc(), result.operands)))
336 return failure();
337 frag.elemtype =
338 inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
339 }
340
341 Type resultType;
342 if (parser.parseArrow() || parser.parseType(resultType))
343 return failure();
344 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
345
346 std::array<StringRef, 2> names{"multiplicandAPtxType",
347 "multiplicandBPtxType"};
348 for (unsigned idx = 0; idx < names.size(); idx++) {
349 const auto &frag = frags[idx];
350 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
351 if (!frag.elemtype.has_value() && !attr.has_value()) {
352 return parser.emitError(
353 parser.getNameLoc(),
354 "attribute " + names[idx] +
355 " is not provided explicitly and cannot be inferred");
356 }
357 if (!attr.has_value())
358 result.addAttribute(
359 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
360 }
361
362 result.addTypes(resultType);
363 if (!namedAttributes.empty())
364 result.addAttributes(namedAttributes);
365 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
366 builder.getDenseI32ArrayAttr({
367 static_cast<int32_t>(frags[0].regs.size()),
368 static_cast<int32_t>(frags[1].regs.size()),
369 static_cast<int32_t>(frags[2].regs.size()),
370 }));
371 return success();
372}
373
/// Verifier for MmaOp: checks that the (m, n, k) shape, the multiplicand PTX
/// type and the operand/result register types form one of the supported MMA
/// variants, and that variant-specific attributes (b1Op, intOverflowBehavior)
/// are present when required.
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  // Commonly used scalar/struct types for the allowed-type tables below.
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    // Integer variants accumulate in s32 and pack multiplicands into i32
    // registers (multiplicandFragType is set here for the int cases above).
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    // Number of fragment registers = (#rows / 8) * (k / kFactor).
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    // Slice out the actual types of this operand segment.
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorStream.str());
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorStream.str());
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
558
559LogicalResult ShflOp::verify() {
560 if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
561 return success();
562 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
563 auto elementType = (type && type.getBody().size() == 2)
564 ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
565 : nullptr;
566 if (!elementType || elementType.getWidth() != 1)
567 return emitError("expected return type to be a two-element struct with "
568 "i1 as the second element");
569 return success();
570}
571
572std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
573 NVVM::MMAFrag frag, int nRow,
574 int nCol,
575 MLIRContext *context) {
576 unsigned numberElements = 0;
577 Type elementType;
578 OpBuilder builder(context);
579 Type f16x2 = VectorType::get(2, builder.getF16Type());
580 if (type == NVVM::MMATypes::f16) {
581 elementType = f16x2;
582 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
583 numberElements = 8;
584 else
585 numberElements = 4;
586 } else if (type == NVVM::MMATypes::f32) {
587 elementType = builder.getF32Type();
588 numberElements = 8;
589 } else if (type == NVVM::MMATypes::tf32) {
590 elementType = builder.getI32Type();
591 numberElements = 4;
592 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
593 elementType = builder.getI32Type();
594 int parallelSize = 0;
595 if (frag == NVVM::MMAFrag::a)
596 parallelSize = nRow;
597 if (frag == NVVM::MMAFrag::b)
598 parallelSize = nCol;
599
600 // m == 16 && n == 16 && k == 16
601 if (parallelSize == 16)
602 numberElements = 2;
603 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
604 else if (parallelSize == 8)
605 numberElements = 1;
606 else if (parallelSize == 32)
607 numberElements = 4;
608 } else if (type == NVVM::MMATypes::s32) {
609 elementType = builder.getI32Type();
610 numberElements = 8;
611 }
612 assert(numberElements != 0 && elementType != nullptr);
613 return std::make_pair(x&: elementType, y&: numberElements);
614}
615
616static std::pair<mlir::Type, unsigned>
617inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
618 int k, MLIRContext *context) {
619 int nRow, nCol;
620 if (frag == NVVM::MMAFrag::a) {
621 nRow = m;
622 nCol = k;
623 } else if (frag == NVVM::MMAFrag::b) {
624 nRow = k;
625 nCol = n;
626 } else {
627 nRow = m;
628 nCol = n;
629 }
630 assert(nRow && nCol);
631 return inferMMAType(type, frag, nRow, nCol, context);
632}
633
634LogicalResult NVVM::WMMALoadOp::verify() {
635 unsigned addressSpace =
636 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
637 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
638 addressSpace != NVVM::kSharedMemorySpace)
639 return emitOpError("expected source pointer in memory "
640 "space 0, 1, 3");
641
642 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
643 getEltype(), getFrag()) == 0)
644 return emitOpError() << "invalid attribute combination";
645 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
646 getEltype(), getFrag(), getM(), getN(), getK(), getContext());
647 Type dstType = LLVM::LLVMStructType::getLiteral(
648 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
649 if (getType() != dstType)
650 return emitOpError("expected destination type is a structure of ")
651 << typeInfo.second << " elements of type " << typeInfo.first;
652 return success();
653}
654
655LogicalResult NVVM::WMMAStoreOp::verify() {
656 unsigned addressSpace =
657 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
658 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
659 addressSpace != NVVM::kSharedMemorySpace)
660 return emitOpError("expected operands to be a source pointer in memory "
661 "space 0, 1, 3");
662
663 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
664 getEltype()) == 0)
665 return emitOpError() << "invalid attribute combination";
666 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
667 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
668 if (getArgs().size() != typeInfo.second)
669 return emitOpError() << "expected " << typeInfo.second << " data operands";
670 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
671 return operands.getType() != typeInfo.first;
672 }))
673 return emitOpError() << "expected data operands of type " << typeInfo.first;
674 return success();
675}
676
/// Verifier for WMMAMmaOp: checks the attribute combination maps to a real
/// intrinsic, then validates the argument list (A, B then C fragments, in
/// order) and the result struct against the inferred fragment types.
LogicalResult NVVM::WMMAMmaOp::verify() {
  // An intrinsic ID of 0 means no intrinsic exists for this combination.
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  // NOTE(review): the B fragment's type info is derived from getEltypeA();
  // presumably the A and B fragments always share an element type — confirm.
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  // Expected flat argument list: A registers, then B, then C.
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  // The result must be a struct of the C-fragment element type.
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
707
708LogicalResult NVVM::LdMatrixOp::verify() {
709 unsigned addressSpace =
710 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
711 if (addressSpace != NVVM::kSharedMemorySpace)
712 return emitOpError("expected source pointer in memory space 3");
713
714 if (getNum() != 1 && getNum() != 2 && getNum() != 4)
715 return emitOpError("expected num attribute to be 1, 2 or 4");
716
717 Type i32 = IntegerType::get(getContext(), 32);
718 if (getNum() == 1 && getType() != i32)
719 return emitOpError("expected destination type is i32");
720 if (getNum() == 2 || getNum() == 4) {
721 Type dstType = LLVM::LLVMStructType::getLiteral(
722 getContext(), SmallVector<Type>(getNum(), i32));
723 if (getType() != dstType)
724 return emitOpError("expected destination type is a structure of ")
725 << getNum() << " elements of type i32";
726 }
727 return success();
728}
729
730LogicalResult NVVM::StMatrixOp::verify() {
731 unsigned addressSpace =
732 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
733 if (addressSpace != NVVM::kSharedMemorySpace)
734 return emitOpError("expected source pointer in memory space 3");
735
736 int numMatrix = getSources().size();
737 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
738 return emitOpError("expected num attribute to be 1, 2 or 4");
739
740 return success();
741}
742
743FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
744 if (typeA == NVVM::WGMMATypes::tf32)
745 return 8;
746 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
747 return 16;
748 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
749 return 32;
750 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
751 return 32;
752 if (typeA == NVVM::WGMMATypes::b1)
753 return 256;
754 return failure();
755}
756
/// Returns success iff (typeD += typeA * typeB) is a WGMMA type combination
/// supported by the hardware; typeA selects the row, typeB/typeD are checked
/// against it.
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    // f16 inputs accumulate into f16 or f32.
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    // tf32 inputs accumulate into f32 only.
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    // Integer inputs accumulate into s32; signedness of A and B may differ.
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    // The two fp8 formats may be mixed between A and B.
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    // f32/s32 are accumulator-only types and never appear as inputs.
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}
798
799LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
800 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
801 72, 80, 88, 96, 104, 112, 120, 128,
802 136, 144, 152, 160, 168, 176, 184, 192,
803 200, 208, 216, 224, 232, 240, 248, 256};
804 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
805 80, 96, 112, 128, 144, 160,
806 176, 192, 208, 224, 240, 256};
807 switch (typeA) {
808 case WGMMATypes::f16:
809 case WGMMATypes::tf32:
810 case WGMMATypes::bf16:
811 case WGMMATypes::e4m3:
812 case WGMMATypes::e5m2:
813 if (llvm::is_contained(Range&: allowedN, Element: sizeN))
814 return success();
815 break;
816 case WGMMATypes::u8:
817 case WGMMATypes::s8:
818 case WGMMATypes::b1:
819 if (llvm::is_contained(Range&: allowedNshort, Element: sizeN))
820 return success();
821 break;
822 case WGMMATypes::f32:
823 case WGMMATypes::s32:
824 llvm_unreachable("unsupported input types");
825 break;
826 }
827 return failure();
828}
829
830LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
831 Value outValue = getResults();
832 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
833 if (!stype)
834 return emitOpError() << "expected results to be struct";
835 int outputSize = stype.getBody().size();
836 WGMMATypes typeD = getTypeD();
837 WGMMATypes typeA = getTypeA();
838 WGMMATypes typeB = getTypeB();
839
840 for (Type t : stype.getBody()) {
841 if (t != stype.getBody().front())
842 return emitOpError()
843 << "all elements in struct must be same type but there is " << t;
844 }
845
846 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
847 typeD != WGMMATypes::s32) {
848 return emitOpError() << "does not support the given output type "
849 << NVVM::stringifyWGMMATypes(typeD);
850 }
851 if (typeD == WGMMATypes::s32 &&
852 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
853 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
854 }
855
856 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
857 return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
858 << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
859 << NVVM::stringifyWGMMATypes(typeB)
860 << ", it is not supported.";
861 }
862
863 // Check M
864 if (getShape().getM() != 64)
865 return emitOpError() << "shape 'm' must be 64";
866
867 // Check K
868 FailureOr<int> allowedK = getAllowedSizeK(typeA);
869 if (failed(allowedK) || allowedK.value() != getShape().getK())
870 return emitOpError() << "shape 'k' must be " << allowedK.value()
871 << " for input type "
872 << NVVM::stringifyWGMMATypes(typeA);
873
874 // Check N
875 if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
876 return emitOpError() << "has input type "
877 << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
878 << getShape().getN() << ", it is not supported.";
879 }
880
881 // Check transpose (only available for f16/bf16)
882 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
883 (getLayoutA() == mlir::NVVM::MMALayout::col ||
884 getLayoutB() == mlir::NVVM::MMALayout::col)) {
885 return emitOpError()
886 << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
887 << " and layout_b = " << stringifyMMALayout(getLayoutB())
888 << " for input types " << stringifyWGMMATypes(typeA) << " and "
889 << stringifyWGMMATypes(typeB)
890 << " requires transpose. However, this is only supported for: "
891 << stringifyMMATypes(MMATypes::f16) << " and "
892 << stringifyMMATypes(MMATypes::bf16);
893 }
894
895 // Check result registers
896 int expectedOutput = 0;
897 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
898 expectedOutput = getShape().getN() / 2;
899 if (typeD == WGMMATypes::f16)
900 expectedOutput = getShape().getN() / 4;
901 if (outputSize != expectedOutput) {
902 return emitOpError() << "results " << expectedOutput
903 << ", however output struct has " << outputSize
904 << " elements";
905 }
906 // Check satfinite (only available for s32 accumulator)
907 if (typeD != WGMMATypes::s32 &&
908 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
909 NVVM::MMAIntOverflow::satfinite) {
910 return emitOpError()
911 << " `satfinite` can be only used with s32 accumulator, however "
912 "the current accumulator is "
913 << NVVM::stringifyWGMMATypes(typeD);
914 }
915
916 return success();
917}
918
919std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
920
921 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
922 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
923
924 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
925
926 int expectedOutputRegisters = 0;
927 if (getTypeD() == WGMMATypes::f16)
928 expectedOutputRegisters = getShape().getN() / 4;
929 else
930 expectedOutputRegisters = getShape().getN() / 2;
931
932 std::string ptx;
933 llvm::raw_string_ostream ss(ptx);
934
935 ss << "{\n"
936 ".reg .pred p;\n"
937 "setp.ne.b32 p, $"
938 << ((expectedOutputRegisters * 2) + 2)
939 << ", 0;\n"
940 "wgmma.mma_async.sync.aligned.m"
941 << m << "n" << n << "k" << k << "." << outputTypeName << "."
942 << stringifyWGMMATypes(getTypeA()) << "."
943 << stringifyWGMMATypes(getTypeB());
944 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
945 NVVM::MMAIntOverflow::satfinite)
946 ss << ".satfinite";
947 ss << " {";
948 int regCnt = 0;
949 for (; regCnt < expectedOutputRegisters; ++regCnt) {
950 ss << "$" << regCnt;
951 if (regCnt != expectedOutputRegisters - 1)
952 ss << ", ";
953 }
954
955 ss << "},";
956 // Need to map read/write registers correctly.
957 regCnt = (regCnt * 2);
958 ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
959 if (getTypeD() != WGMMATypes::s32) {
960 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
961 }
962 // Don't add transpose parameters unless needed.
963 if (isF16) {
964 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
965 }
966 ss << ";\n"
967 << "}\n";
968 ss.flush();
969 return ptx;
970}
971
972void NVVM::WgmmaMmaAsyncOp::getAsmValues(
973 RewriterBase &rewriter,
974 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
975 &asmValues) {
976 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
977 if (getResults())
978 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
979 if (getInouts())
980 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
981 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
982 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
983 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
984 mlir::NVVM::PTXRegisterMod::Read});
985 if (getTypeD() != WGMMATypes::s32) {
986 asmValues.push_back(
987 {makeConstantI32(rewriter,
988 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
989 mlir::NVVM::PTXRegisterMod::Read});
990 asmValues.push_back(
991 {makeConstantI32(rewriter,
992 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
993 mlir::NVVM::PTXRegisterMod::Read});
994 }
995 if (isF16) {
996 asmValues.push_back(
997 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
998 mlir::NVVM::PTXRegisterMod::Read});
999 asmValues.push_back(
1000 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1001 mlir::NVVM::PTXRegisterMod::Read});
1002 }
1003}
1004LogicalResult NVVM::FenceProxyOp::verify() {
1005 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1006 return emitOpError() << "async_shared fence requires space attribute";
1007 }
1008 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1009 return emitOpError() << "only async_shared fence can have space attribute";
1010 }
1011 return success();
1012}
1013
1014LogicalResult NVVM::SetMaxRegisterOp::verify() {
1015 if (getRegCount() % 8)
1016 return emitOpError("new register size must be multiple of 8");
1017 if (getRegCount() < 24 || getRegCount() > 256)
1018 return emitOpError("new register size must be in between 24 to 256");
1019 return success();
1020}
1021
1022LogicalResult NVVM::BarrierOp::verify() {
1023 if (getNumberOfThreads() && !getBarrierId())
1024 return emitOpError(
1025 "barrier id is missing, it should be set between 0 to 15");
1026 return success();
1027}
1028
1029//===----------------------------------------------------------------------===//
1030// NVVMDialect initialization, type parsing, and registration.
1031//===----------------------------------------------------------------------===//
1032
1033// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  // Register the operations generated from NVVMOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  // Register the attributes generated from NVVMOpsAttributes.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  // These interfaces are implemented in separate libraries; declaring them as
  // promised lets this dialect load without those libraries linked in.
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
1050
1051LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1052 NamedAttribute attr) {
1053 StringAttr attrName = attr.getName();
1054 // Kernel function attribute should be attached to functions.
1055 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1056 if (!isa<LLVM::LLVMFuncOp>(op)) {
1057 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1058 << "' attribute attached to unexpected op";
1059 }
1060 }
1061 // If maxntid and reqntid exist, it must be an array with max 3 dim
1062 if (attrName == NVVMDialect::getMaxntidAttrName() ||
1063 attrName == NVVMDialect::getReqntidAttrName()) {
1064 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1065 if (!values || values.empty() || values.size() > 3)
1066 return op->emitError()
1067 << "'" << attrName
1068 << "' attribute must be integer array with maximum 3 index";
1069 }
1070 // If minctasm and maxnreg exist, it must be an integer attribute
1071 if (attrName == NVVMDialect::getMinctasmAttrName() ||
1072 attrName == NVVMDialect::getMaxnregAttrName()) {
1073 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
1074 return op->emitError()
1075 << "'" << attrName << "' attribute must be integer constant";
1076 }
1077
1078 return success();
1079}
1080
1081LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1082 unsigned regionIndex,
1083 unsigned argIndex,
1084 NamedAttribute argAttr) {
1085 auto funcOp = dyn_cast<FunctionOpInterface>(op);
1086 if (!funcOp)
1087 return success();
1088
1089 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1090 StringAttr attrName = argAttr.getName();
1091 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1092 if (!isKernel) {
1093 return op->emitError()
1094 << "'" << attrName
1095 << "' attribute must be present only on kernel arguments";
1096 }
1097 if (!isa<UnitAttr>(argAttr.getValue()))
1098 return op->emitError() << "'" << attrName << "' must be a unit attribute";
1099 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1100 return op->emitError()
1101 << "'" << attrName
1102 << "' attribute requires the argument to also have attribute '"
1103 << LLVM::LLVMDialect::getByValAttrName() << "'";
1104 }
1105 }
1106
1107 return success();
1108}
1109
1110//===----------------------------------------------------------------------===//
1111// NVVM target attribute.
1112//===----------------------------------------------------------------------===//
1113LogicalResult
1114NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1115 int optLevel, StringRef triple, StringRef chip,
1116 StringRef features, DictionaryAttr flags,
1117 ArrayAttr files) {
1118 if (optLevel < 0 || optLevel > 3) {
1119 emitError() << "The optimization level must be a number between 0 and 3.";
1120 return failure();
1121 }
1122 if (triple.empty()) {
1123 emitError() << "The target triple cannot be empty.";
1124 return failure();
1125 }
1126 if (chip.empty()) {
1127 emitError() << "The target chip cannot be empty.";
1128 return failure();
1129 }
1130 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1131 return attr && mlir::isa<StringAttr>(attr);
1132 })) {
1133 emitError() << "All the elements in the `link` array must be strings.";
1134 return failure();
1135 }
1136 return success();
1137}
1138
1139#define GET_OP_CLASSES
1140#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1141
1142#define GET_ATTRDEF_CLASSES
1143#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1144

// source code of mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp