1 | //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the types and operation details for the NVVM IR dialect in |
10 | // MLIR, and the LLVM IR dialect. It also registers the dialect. |
11 | // |
12 | // The NVVM dialect only contains GPU specific additions on top of the general |
13 | // LLVM dialect. |
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
18 | |
19 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
20 | #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" |
21 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
22 | #include "mlir/IR/Builders.h" |
23 | #include "mlir/IR/BuiltinAttributes.h" |
24 | #include "mlir/IR/BuiltinTypes.h" |
25 | #include "mlir/IR/Diagnostics.h" |
26 | #include "mlir/IR/DialectImplementation.h" |
27 | #include "mlir/IR/MLIRContext.h" |
28 | #include "mlir/IR/Operation.h" |
29 | #include "mlir/IR/OperationSupport.h" |
30 | #include "mlir/IR/Types.h" |
31 | #include "mlir/Support/LogicalResult.h" |
32 | #include "llvm/ADT/STLExtras.h" |
33 | #include "llvm/ADT/TypeSwitch.h" |
34 | #include "llvm/AsmParser/Parser.h" |
35 | #include "llvm/IR/Attributes.h" |
36 | #include "llvm/IR/Function.h" |
37 | #include "llvm/IR/Type.h" |
38 | #include "llvm/Support/Casting.h" |
39 | #include "llvm/Support/SourceMgr.h" |
40 | #include "llvm/Support/raw_ostream.h" |
41 | #include <cassert> |
42 | #include <optional> |
43 | #include <string> |
44 | |
45 | using namespace mlir; |
46 | using namespace NVVM; |
47 | |
48 | #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc" |
49 | #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc" |
50 | |
51 | //===----------------------------------------------------------------------===// |
52 | // Printing/parsing for NVVM ops |
53 | //===----------------------------------------------------------------------===// |
54 | |
55 | static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { |
56 | p << " " << op->getOperands(); |
57 | if (op->getNumResults() > 0) |
58 | p << " : " << op->getResultTypes(); |
59 | } |
60 | |
61 | // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type |
62 | ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) { |
63 | MLIRContext *context = parser.getContext(); |
64 | auto int32Ty = IntegerType::get(context, 32); |
65 | auto int1Ty = IntegerType::get(context, 1); |
66 | |
67 | SmallVector<OpAsmParser::UnresolvedOperand, 8> ops; |
68 | Type type; |
69 | return failure(parser.parseOperandList(ops) || |
70 | parser.parseOptionalAttrDict(result.attributes) || |
71 | parser.parseColonType(type) || |
72 | parser.addTypeToList(type, result.types) || |
73 | parser.resolveOperands(ops, {int32Ty, int1Ty}, |
74 | parser.getNameLoc(), result.operands)); |
75 | } |
76 | |
77 | void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); } |
78 | |
79 | LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() { |
80 | if (getCoordinates().empty() || getCoordinates().size() > 5) |
81 | return emitError("expects coordinates between 1 to 5 dimension" ); |
82 | |
83 | // Check for im2col mode |
84 | if (!getIm2colOffsets().empty()) { |
85 | if (getCoordinates().size() < 3) |
86 | return emitError( |
87 | "to use im2col mode, the tensor has to be at least 3-dimensional" ); |
88 | if (getCoordinates().size() != (getIm2colOffsets().size() + 2)) |
89 | return emitError( |
90 | "im2col offsets must be 2 less than number of coordinates" ); |
91 | } |
92 | return success(); |
93 | } |
94 | |
95 | LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() { |
96 | if (getCoordinates().size() > 5) |
97 | return emitError("Maximum 5 coordinates and dimension is supported." ); |
98 | return success(); |
99 | } |
100 | |
101 | LogicalResult CpAsyncOp::verify() { |
102 | if (getModifier() != LoadCacheModifierKind::CG && |
103 | getModifier() != LoadCacheModifierKind::CA) |
104 | return emitError("Only CG and CA cache modifiers are supported." ); |
105 | if (getSize() != 4 && getSize() != 8 && getSize() != 16) |
106 | return emitError("expected byte size to be either 4, 8 or 16." ); |
107 | if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16) |
108 | return emitError("CG cache modifier is only support for 16 bytes copy." ); |
109 | return success(); |
110 | } |
111 | |
112 | // Given the element type of an operand and whether or not it is an accumulator, |
113 | // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the |
114 | // operand's element type. |
115 | std::optional<mlir::NVVM::MMATypes> |
116 | MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) { |
117 | auto half2Type = |
118 | LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2); |
119 | if (operandElType.isF64()) |
120 | return NVVM::MMATypes::f64; |
121 | if (operandElType.isF16() || operandElType == half2Type) |
122 | return NVVM::MMATypes::f16; |
123 | if (operandElType.isF32() && isAccumulator) |
124 | return NVVM::MMATypes::f32; |
125 | if (operandElType.isF32() && !isAccumulator) |
126 | return NVVM::MMATypes::tf32; |
127 | if (llvm::isa<IntegerType>(operandElType)) { |
128 | if (isAccumulator) |
129 | return NVVM::MMATypes::s32; |
130 | return std::nullopt; |
131 | } |
132 | |
133 | if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) { |
134 | if (structType.getBody().empty()) |
135 | return std::nullopt; |
136 | return inferOperandMMAType(structType.getBody()[0], isAccumulator); |
137 | } |
138 | |
139 | return std::nullopt; |
140 | } |
141 | |
142 | static bool isInt4PtxType(MMATypes type) { |
143 | return (type == MMATypes::u4 || type == MMATypes::s4); |
144 | } |
145 | |
146 | static bool isInt8PtxType(MMATypes type) { |
147 | return (type == MMATypes::u8 || type == MMATypes::s8); |
148 | } |
149 | |
150 | static bool isIntegerPtxType(MMATypes type) { |
151 | return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 || |
152 | type == MMATypes::s32; |
153 | } |
154 | |
/// Returns the PTX type of the accumulator (operand segment 2), inferred from
/// the first register's type. Asserts rather than failing: accumulator types
/// are always inferable.
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccum=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}
161 | |
/// Returns the PTX type of the result, inferred from the result type with
/// accumulator rules. Asserts rather than failing: result types are always
/// inferable.
MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}
168 | |
/// Prints an mma.sync op in the custom form:
///   A[regs] B[regs] C[regs] attr-dict : (typeA, typeB, typeC) -> resType
/// PTX-type attributes that can be re-inferred from the register types are
/// omitted from the printed attr-dict.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Groups the name, the associated PTX-type attribute name and the registers
  // of one operand segment (A, B or C).
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  // operandSegmentSizes is implied by the printed register lists, so it is
  // never printed.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      // NOTE(review): this records a register type only for global operand
      // index 0 (the first register of fragment A), so `regTypes.back()`
      // below is the same type for every fragment — confirm whether the
      // intent was `operandIdx == varOperandSpec.first`.
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    // If the PTX type is inferable from the register type, skip printing the
    // corresponding attribute.
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : " << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
225 | |
/// Builds an mma.sync op.
///
/// `shape` must be the (m, n, k) triple. Multiplicand PTX types and layouts
/// are optional: types are inferred from the first A/B register type when not
/// given, and layouts default to row-major A / column-major B. `b1Op` and
/// `intOverflow` are only attached when provided.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  // Multiplicand PTX types: explicit if given, otherwise inferred from the
  // first register of each segment (inference can fail for integer types, in
  // which case no attribute is attached).
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  // Layouts: explicit if given, otherwise row-major A / column-major B.
  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  // Record the sizes of the three variadic operand segments.
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
277 | |
278 | // <operation> := |
279 | // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]` |
280 | // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0])) |
281 | // `->` type($res) |
282 | ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) { |
283 | struct OperandFragment { |
284 | std::optional<MMATypes> elemtype; |
285 | SmallVector<OpAsmParser::UnresolvedOperand, 4> regs; |
286 | SmallVector<Type> regTypes; |
287 | }; |
288 | |
289 | Builder &builder = parser.getBuilder(); |
290 | std::array<OperandFragment, 4> frags; |
291 | |
292 | NamedAttrList namedAttributes; |
293 | |
294 | // A helper to parse the operand segments. |
295 | auto parseMmaOperand = [&](StringRef operandName, |
296 | OperandFragment &frag) -> LogicalResult { |
297 | if (parser.parseKeyword(operandName).failed()) |
298 | return failure(); |
299 | if (parser |
300 | .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare) |
301 | .failed()) |
302 | return failure(); |
303 | return success(); |
304 | }; |
305 | |
306 | // Parse the operand segments. |
307 | if (parseMmaOperand("A" , frags[0]).failed()) |
308 | return failure(); |
309 | if (parseMmaOperand("B" , frags[1]).failed()) |
310 | return failure(); |
311 | if (parseMmaOperand("C" , frags[2]).failed()) |
312 | return failure(); |
313 | |
314 | if (parser.parseOptionalAttrDict(namedAttributes).failed()) |
315 | return failure(); |
316 | |
317 | // Parse the type specification and resolve operands. |
318 | SmallVector<Type, 3> operandTypes; |
319 | if (failed(parser.parseColon())) |
320 | return failure(); |
321 | if (failed(parser.parseLParen())) |
322 | return failure(); |
323 | if (failed(parser.parseTypeList(operandTypes))) |
324 | return failure(); |
325 | if (failed(parser.parseRParen())) |
326 | if (operandTypes.size() != 3) |
327 | return parser.emitError( |
328 | parser.getNameLoc(), |
329 | "expected one type for each operand segment but got " + |
330 | Twine(operandTypes.size()) + " types" ); |
331 | for (const auto &iter : llvm::enumerate(operandTypes)) { |
332 | auto &frag = frags[iter.index()]; |
333 | frag.regTypes.resize(frag.regs.size(), iter.value()); |
334 | if (failed(parser.resolveOperands(frag.regs, frag.regTypes, |
335 | parser.getNameLoc(), result.operands))) |
336 | return failure(); |
337 | frag.elemtype = |
338 | inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2); |
339 | } |
340 | |
341 | Type resultType; |
342 | if (parser.parseArrow() || parser.parseType(resultType)) |
343 | return failure(); |
344 | frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true); |
345 | |
346 | std::array<StringRef, 2> names{"multiplicandAPtxType" , |
347 | "multiplicandBPtxType" }; |
348 | for (unsigned idx = 0; idx < names.size(); idx++) { |
349 | const auto &frag = frags[idx]; |
350 | std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]); |
351 | if (!frag.elemtype.has_value() && !attr.has_value()) { |
352 | return parser.emitError( |
353 | parser.getNameLoc(), |
354 | "attribute " + names[idx] + |
355 | " is not provided explicitly and cannot be inferred" ); |
356 | } |
357 | if (!attr.has_value()) |
358 | result.addAttribute( |
359 | names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype)); |
360 | } |
361 | |
362 | result.addTypes(resultType); |
363 | if (!namedAttributes.empty()) |
364 | result.addAttributes(namedAttributes); |
365 | result.addAttribute(MmaOp::getOperandSegmentSizeAttr(), |
366 | builder.getDenseI32ArrayAttr({ |
367 | static_cast<int32_t>(frags[0].regs.size()), |
368 | static_cast<int32_t>(frags[1].regs.size()), |
369 | static_cast<int32_t>(frags[2].regs.size()), |
370 | })); |
371 | return success(); |
372 | } |
373 | |
/// Verifies an mma.sync op: checks the (shape, PTX-type) combination against
/// the supported variants and validates operand-segment and result types and
/// the required b1Op / intOverflowBehavior attributes.
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    // Integer variants accumulate into s32 and pack multiplicands into i32
    // registers (this also covers the s4/u4/s8/u8/b1 cases above, which left
    // multiplicandFragType unset).
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    // Number of registers per multiplicand = (rows/8) * (k/kFactor).
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    // The whole segment must match one of the allowed type sequences.
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorStream.str());
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorStream.str());
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
558 | |
559 | LogicalResult ShflOp::verify() { |
560 | if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid" )) |
561 | return success(); |
562 | auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); |
563 | auto elementType = (type && type.getBody().size() == 2) |
564 | ? llvm::dyn_cast<IntegerType>(type.getBody()[1]) |
565 | : nullptr; |
566 | if (!elementType || elementType.getWidth() != 1) |
567 | return emitError("expected return type to be a two-element struct with " |
568 | "i1 as the second element" ); |
569 | return success(); |
570 | } |
571 | |
572 | std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type, |
573 | NVVM::MMAFrag frag, int nRow, |
574 | int nCol, |
575 | MLIRContext *context) { |
576 | unsigned numberElements = 0; |
577 | Type elementType; |
578 | OpBuilder builder(context); |
579 | Type f16x2 = VectorType::get(2, builder.getF16Type()); |
580 | if (type == NVVM::MMATypes::f16) { |
581 | elementType = f16x2; |
582 | if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) |
583 | numberElements = 8; |
584 | else |
585 | numberElements = 4; |
586 | } else if (type == NVVM::MMATypes::f32) { |
587 | elementType = builder.getF32Type(); |
588 | numberElements = 8; |
589 | } else if (type == NVVM::MMATypes::tf32) { |
590 | elementType = builder.getI32Type(); |
591 | numberElements = 4; |
592 | } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) { |
593 | elementType = builder.getI32Type(); |
594 | int parallelSize = 0; |
595 | if (frag == NVVM::MMAFrag::a) |
596 | parallelSize = nRow; |
597 | if (frag == NVVM::MMAFrag::b) |
598 | parallelSize = nCol; |
599 | |
600 | // m == 16 && n == 16 && k == 16 |
601 | if (parallelSize == 16) |
602 | numberElements = 2; |
603 | // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16 |
604 | else if (parallelSize == 8) |
605 | numberElements = 1; |
606 | else if (parallelSize == 32) |
607 | numberElements = 4; |
608 | } else if (type == NVVM::MMATypes::s32) { |
609 | elementType = builder.getI32Type(); |
610 | numberElements = 8; |
611 | } |
612 | assert(numberElements != 0 && elementType != nullptr); |
613 | return std::make_pair(x&: elementType, y&: numberElements); |
614 | } |
615 | |
616 | static std::pair<mlir::Type, unsigned> |
617 | inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, |
618 | int k, MLIRContext *context) { |
619 | int nRow, nCol; |
620 | if (frag == NVVM::MMAFrag::a) { |
621 | nRow = m; |
622 | nCol = k; |
623 | } else if (frag == NVVM::MMAFrag::b) { |
624 | nRow = k; |
625 | nCol = n; |
626 | } else { |
627 | nRow = m; |
628 | nCol = n; |
629 | } |
630 | assert(nRow && nCol); |
631 | return inferMMAType(type, frag, nRow, nCol, context); |
632 | } |
633 | |
634 | LogicalResult NVVM::WMMALoadOp::verify() { |
635 | unsigned addressSpace = |
636 | llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); |
637 | if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace && |
638 | addressSpace != NVVM::kSharedMemorySpace) |
639 | return emitOpError("expected source pointer in memory " |
640 | "space 0, 1, 3" ); |
641 | |
642 | if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), |
643 | getEltype(), getFrag()) == 0) |
644 | return emitOpError() << "invalid attribute combination" ; |
645 | std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( |
646 | getEltype(), getFrag(), getM(), getN(), getK(), getContext()); |
647 | Type dstType = LLVM::LLVMStructType::getLiteral( |
648 | getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); |
649 | if (getType() != dstType) |
650 | return emitOpError("expected destination type is a structure of " ) |
651 | << typeInfo.second << " elements of type " << typeInfo.first; |
652 | return success(); |
653 | } |
654 | |
655 | LogicalResult NVVM::WMMAStoreOp::verify() { |
656 | unsigned addressSpace = |
657 | llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); |
658 | if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace && |
659 | addressSpace != NVVM::kSharedMemorySpace) |
660 | return emitOpError("expected operands to be a source pointer in memory " |
661 | "space 0, 1, 3" ); |
662 | |
663 | if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), |
664 | getEltype()) == 0) |
665 | return emitOpError() << "invalid attribute combination" ; |
666 | std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( |
667 | getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); |
668 | if (getArgs().size() != typeInfo.second) |
669 | return emitOpError() << "expected " << typeInfo.second << " data operands" ; |
670 | if (llvm::any_of(getArgs(), [&typeInfo](Value operands) { |
671 | return operands.getType() != typeInfo.first; |
672 | })) |
673 | return emitOpError() << "expected data operands of type " << typeInfo.first; |
674 | return success(); |
675 | } |
676 | |
/// Verifies a WMMA mma op: checks the attribute combination maps to a real
/// intrinsic, then that the flattened A/B/C argument list and the result
/// struct match the fragment shapes inferred from (m, n, k).
LogicalResult NVVM::WMMAMmaOp::verify() {
  // An intrinsic ID of 0 means no intrinsic exists for this combination.
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  // NOTE(review): eltypeA is used for both the A and B fragments and eltypeB
  // for the C fragment, which suggests eltypeA/eltypeB denote the
  // multiplicand/accumulator element types rather than per-operand types —
  // confirm against the op's tablegen definition.
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  // Expected argument list is the concatenation of the A, B and C fragments.
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  // The result is a struct with one member per C-fragment register.
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
707 | |
708 | LogicalResult NVVM::LdMatrixOp::verify() { |
709 | unsigned addressSpace = |
710 | llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); |
711 | if (addressSpace != NVVM::kSharedMemorySpace) |
712 | return emitOpError("expected source pointer in memory space 3" ); |
713 | |
714 | if (getNum() != 1 && getNum() != 2 && getNum() != 4) |
715 | return emitOpError("expected num attribute to be 1, 2 or 4" ); |
716 | |
717 | Type i32 = IntegerType::get(getContext(), 32); |
718 | if (getNum() == 1 && getType() != i32) |
719 | return emitOpError("expected destination type is i32" ); |
720 | if (getNum() == 2 || getNum() == 4) { |
721 | Type dstType = LLVM::LLVMStructType::getLiteral( |
722 | getContext(), SmallVector<Type>(getNum(), i32)); |
723 | if (getType() != dstType) |
724 | return emitOpError("expected destination type is a structure of " ) |
725 | << getNum() << " elements of type i32" ; |
726 | } |
727 | return success(); |
728 | } |
729 | |
730 | LogicalResult NVVM::StMatrixOp::verify() { |
731 | unsigned addressSpace = |
732 | llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); |
733 | if (addressSpace != NVVM::kSharedMemorySpace) |
734 | return emitOpError("expected source pointer in memory space 3" ); |
735 | |
736 | int numMatrix = getSources().size(); |
737 | if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) |
738 | return emitOpError("expected num attribute to be 1, 2 or 4" ); |
739 | |
740 | return success(); |
741 | } |
742 | |
743 | FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) { |
744 | if (typeA == NVVM::WGMMATypes::tf32) |
745 | return 8; |
746 | if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16) |
747 | return 16; |
748 | if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8) |
749 | return 32; |
750 | if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2) |
751 | return 32; |
752 | if (typeA == NVVM::WGMMATypes::b1) |
753 | return 256; |
754 | return failure(); |
755 | } |
756 | |
757 | LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, |
758 | NVVM::WGMMATypes typeA, |
759 | NVVM::WGMMATypes typeB) { |
760 | switch (typeA) { |
761 | case NVVM::WGMMATypes::f16: |
762 | if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && |
763 | typeB == NVVM::WGMMATypes::f16) |
764 | return success(); |
765 | break; |
766 | case NVVM::WGMMATypes::tf32: |
767 | if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32) |
768 | return success(); |
769 | break; |
770 | case NVVM::WGMMATypes::u8: |
771 | case NVVM::WGMMATypes::s8: |
772 | if (typeD == NVVM::WGMMATypes::s32 && |
773 | (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8)) |
774 | return success(); |
775 | break; |
776 | case NVVM::WGMMATypes::b1: |
777 | if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1) |
778 | return success(); |
779 | break; |
780 | case NVVM::WGMMATypes::bf16: |
781 | if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && |
782 | typeB == NVVM::WGMMATypes::bf16) |
783 | return success(); |
784 | break; |
785 | case NVVM::WGMMATypes::e4m3: |
786 | case NVVM::WGMMATypes::e5m2: |
787 | if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && |
788 | (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3)) |
789 | return success(); |
790 | break; |
791 | case WGMMATypes::f32: |
792 | case WGMMATypes::s32: |
793 | llvm_unreachable("unsupported input types" ); |
794 | break; |
795 | } |
796 | return failure(); |
797 | } |
798 | |
799 | LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) { |
800 | SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64, |
801 | 72, 80, 88, 96, 104, 112, 120, 128, |
802 | 136, 144, 152, 160, 168, 176, 184, 192, |
803 | 200, 208, 216, 224, 232, 240, 248, 256}; |
804 | SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64, |
805 | 80, 96, 112, 128, 144, 160, |
806 | 176, 192, 208, 224, 240, 256}; |
807 | switch (typeA) { |
808 | case WGMMATypes::f16: |
809 | case WGMMATypes::tf32: |
810 | case WGMMATypes::bf16: |
811 | case WGMMATypes::e4m3: |
812 | case WGMMATypes::e5m2: |
813 | if (llvm::is_contained(Range&: allowedN, Element: sizeN)) |
814 | return success(); |
815 | break; |
816 | case WGMMATypes::u8: |
817 | case WGMMATypes::s8: |
818 | case WGMMATypes::b1: |
819 | if (llvm::is_contained(Range&: allowedNshort, Element: sizeN)) |
820 | return success(); |
821 | break; |
822 | case WGMMATypes::f32: |
823 | case WGMMATypes::s32: |
824 | llvm_unreachable("unsupported input types" ); |
825 | break; |
826 | } |
827 | return failure(); |
828 | } |
829 | |
/// Verifier for `nvvm.wgmma.mma_async`. Checks, in order: the result struct
/// shape and homogeneity, the accumulator type, scale restrictions, the
/// (D, A, B) type combination, the M/K/N shape constraints, transpose
/// availability, the expected number of result registers, and `satfinite`.
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  // Results must be an LLVM struct of output registers.
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  // Every struct member must match the first member's type.
  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  // The accumulator/output type is restricted to f32, f16 or s32.
  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  // Input negation is unavailable when accumulating into s32.
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  // The (D, A, B) triple must be a supported WGMMA variant.
  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K: K is fixed by the element type of A.
  // NOTE(review): when getAllowedSizeK fails, the diagnostic below still
  // reads allowedK.value() on a failed FailureOr — confirm this path is
  // unreachable (isAllowedWGMMADataType rejects f32/s32 A earlier).
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16)
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::col)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers: N/2 for f32/s32 accumulators, N/4 for f16
  // (mirrors the register-count computation in getPtx).
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // Check satfinite (only available for s32 accumulator)
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}
918 | |
919 | std::string NVVM::WgmmaMmaAsyncOp::getPtx() { |
920 | |
921 | int m = getShape().getM(), n = getShape().getN(), k = getShape().getK(); |
922 | bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; |
923 | |
924 | StringRef outputTypeName = stringifyWGMMATypes(getTypeD()); |
925 | |
926 | int expectedOutputRegisters = 0; |
927 | if (getTypeD() == WGMMATypes::f16) |
928 | expectedOutputRegisters = getShape().getN() / 4; |
929 | else |
930 | expectedOutputRegisters = getShape().getN() / 2; |
931 | |
932 | std::string ptx; |
933 | llvm::raw_string_ostream ss(ptx); |
934 | |
935 | ss << "{\n" |
936 | ".reg .pred p;\n" |
937 | "setp.ne.b32 p, $" |
938 | << ((expectedOutputRegisters * 2) + 2) |
939 | << ", 0;\n" |
940 | "wgmma.mma_async.sync.aligned.m" |
941 | << m << "n" << n << "k" << k << "." << outputTypeName << "." |
942 | << stringifyWGMMATypes(getTypeA()) << "." |
943 | << stringifyWGMMATypes(getTypeB()); |
944 | if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == |
945 | NVVM::MMAIntOverflow::satfinite) |
946 | ss << ".satfinite" ; |
947 | ss << " {" ; |
948 | int regCnt = 0; |
949 | for (; regCnt < expectedOutputRegisters; ++regCnt) { |
950 | ss << "$" << regCnt; |
951 | if (regCnt != expectedOutputRegisters - 1) |
952 | ss << ", " ; |
953 | } |
954 | |
955 | ss << "}," ; |
956 | // Need to map read/write registers correctly. |
957 | regCnt = (regCnt * 2); |
958 | ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p" ; |
959 | if (getTypeD() != WGMMATypes::s32) { |
960 | ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); |
961 | } |
962 | // Don't add transpose parameters unless needed. |
963 | if (isF16) { |
964 | ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6); |
965 | } |
966 | ss << ";\n" |
967 | << "}\n" ; |
968 | ss.flush(); |
969 | return ptx; |
970 | } |
971 | |
972 | void NVVM::WgmmaMmaAsyncOp::getAsmValues( |
973 | RewriterBase &rewriter, |
974 | llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> |
975 | &asmValues) { |
976 | bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; |
977 | if (getResults()) |
978 | asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write}); |
979 | if (getInouts()) |
980 | asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite}); |
981 | asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read}); |
982 | asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read}); |
983 | asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())), |
984 | mlir::NVVM::PTXRegisterMod::Read}); |
985 | if (getTypeD() != WGMMATypes::s32) { |
986 | asmValues.push_back( |
987 | {makeConstantI32(rewriter, |
988 | getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1), |
989 | mlir::NVVM::PTXRegisterMod::Read}); |
990 | asmValues.push_back( |
991 | {makeConstantI32(rewriter, |
992 | getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1), |
993 | mlir::NVVM::PTXRegisterMod::Read}); |
994 | } |
995 | if (isF16) { |
996 | asmValues.push_back( |
997 | {makeConstantI32(rewriter, static_cast<int>(getLayoutA())), |
998 | mlir::NVVM::PTXRegisterMod::Read}); |
999 | asmValues.push_back( |
1000 | {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())), |
1001 | mlir::NVVM::PTXRegisterMod::Read}); |
1002 | } |
1003 | } |
1004 | LogicalResult NVVM::FenceProxyOp::verify() { |
1005 | if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) { |
1006 | return emitOpError() << "async_shared fence requires space attribute" ; |
1007 | } |
1008 | if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) { |
1009 | return emitOpError() << "only async_shared fence can have space attribute" ; |
1010 | } |
1011 | return success(); |
1012 | } |
1013 | |
1014 | LogicalResult NVVM::SetMaxRegisterOp::verify() { |
1015 | if (getRegCount() % 8) |
1016 | return emitOpError("new register size must be multiple of 8" ); |
1017 | if (getRegCount() < 24 || getRegCount() > 256) |
1018 | return emitOpError("new register size must be in between 24 to 256" ); |
1019 | return success(); |
1020 | } |
1021 | |
1022 | LogicalResult NVVM::BarrierOp::verify() { |
1023 | if (getNumberOfThreads() && !getBarrierId()) |
1024 | return emitOpError( |
1025 | "barrier id is missing, it should be set between 0 to 15" ); |
1026 | return success(); |
1027 | } |
1028 | |
1029 | //===----------------------------------------------------------------------===// |
1030 | // NVVMDialect initialization, type parsing, and registration. |
1031 | //===----------------------------------------------------------------------===// |
1032 | |
1033 | // TODO: This should be the llvm.nvvm dialect once this is supported. |
/// Registers all tablegen-generated NVVM operations and attributes with the
/// dialect and declares the interfaces implemented in other libraries.
void NVVMDialect::initialize() {
  // Ops generated from NVVMOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  // Attributes generated from the NVVM tablegen attribute definitions.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  // Promised interfaces: the implementations live in other libraries and
  // are registered lazily, avoiding a hard link-time dependency here.
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
1050 | |
1051 | LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, |
1052 | NamedAttribute attr) { |
1053 | StringAttr attrName = attr.getName(); |
1054 | // Kernel function attribute should be attached to functions. |
1055 | if (attrName == NVVMDialect::getKernelFuncAttrName()) { |
1056 | if (!isa<LLVM::LLVMFuncOp>(op)) { |
1057 | return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() |
1058 | << "' attribute attached to unexpected op" ; |
1059 | } |
1060 | } |
1061 | // If maxntid and reqntid exist, it must be an array with max 3 dim |
1062 | if (attrName == NVVMDialect::getMaxntidAttrName() || |
1063 | attrName == NVVMDialect::getReqntidAttrName()) { |
1064 | auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue()); |
1065 | if (!values || values.empty() || values.size() > 3) |
1066 | return op->emitError() |
1067 | << "'" << attrName |
1068 | << "' attribute must be integer array with maximum 3 index" ; |
1069 | } |
1070 | // If minctasm and maxnreg exist, it must be an integer attribute |
1071 | if (attrName == NVVMDialect::getMinctasmAttrName() || |
1072 | attrName == NVVMDialect::getMaxnregAttrName()) { |
1073 | if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) |
1074 | return op->emitError() |
1075 | << "'" << attrName << "' attribute must be integer constant" ; |
1076 | } |
1077 | |
1078 | return success(); |
1079 | } |
1080 | |
1081 | LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op, |
1082 | unsigned regionIndex, |
1083 | unsigned argIndex, |
1084 | NamedAttribute argAttr) { |
1085 | auto funcOp = dyn_cast<FunctionOpInterface>(op); |
1086 | if (!funcOp) |
1087 | return success(); |
1088 | |
1089 | bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName()); |
1090 | StringAttr attrName = argAttr.getName(); |
1091 | if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) { |
1092 | if (!isKernel) { |
1093 | return op->emitError() |
1094 | << "'" << attrName |
1095 | << "' attribute must be present only on kernel arguments" ; |
1096 | } |
1097 | if (!isa<UnitAttr>(argAttr.getValue())) |
1098 | return op->emitError() << "'" << attrName << "' must be a unit attribute" ; |
1099 | if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) { |
1100 | return op->emitError() |
1101 | << "'" << attrName |
1102 | << "' attribute requires the argument to also have attribute '" |
1103 | << LLVM::LLVMDialect::getByValAttrName() << "'" ; |
1104 | } |
1105 | } |
1106 | |
1107 | return success(); |
1108 | } |
1109 | |
1110 | //===----------------------------------------------------------------------===// |
1111 | // NVVM target attribute. |
1112 | //===----------------------------------------------------------------------===// |
1113 | LogicalResult |
1114 | NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
1115 | int optLevel, StringRef triple, StringRef chip, |
1116 | StringRef features, DictionaryAttr flags, |
1117 | ArrayAttr files) { |
1118 | if (optLevel < 0 || optLevel > 3) { |
1119 | emitError() << "The optimization level must be a number between 0 and 3." ; |
1120 | return failure(); |
1121 | } |
1122 | if (triple.empty()) { |
1123 | emitError() << "The target triple cannot be empty." ; |
1124 | return failure(); |
1125 | } |
1126 | if (chip.empty()) { |
1127 | emitError() << "The target chip cannot be empty." ; |
1128 | return failure(); |
1129 | } |
1130 | if (files && !llvm::all_of(files, [](::mlir::Attribute attr) { |
1131 | return attr && mlir::isa<StringAttr>(attr); |
1132 | })) { |
1133 | emitError() << "All the elements in the `link` array must be strings." ; |
1134 | return failure(); |
1135 | } |
1136 | return success(); |
1137 | } |
1138 | |
1139 | #define GET_OP_CLASSES |
1140 | #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" |
1141 | |
1142 | #define GET_ATTRDEF_CLASSES |
1143 | #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc" |
1144 | |