1//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the NVGPU dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
14#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/DialectImplementation.h"
20#include "mlir/IR/TypeUtilities.h"
21#include "mlir/IR/Verifier.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
24
25using namespace mlir;
26using namespace mlir::nvgpu;
27
28#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
29
// Registers all TableGen-generated NVGPU types, attributes, and operations
// with the dialect. The GET_*_LIST macros expand to the comma-separated
// class lists emitted by mlir-tblgen.
void nvgpu::NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
}
44
45bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
46 if (!memorySpace)
47 return false;
48 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(Val&: memorySpace))
49 return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
50 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(Val&: memorySpace))
51 return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
52 return false;
53}
54
55bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
56 Attribute memorySpace = type.getMemorySpace();
57 return isSharedMemoryAddressSpace(memorySpace);
58}
59
60//===----------------------------------------------------------------------===//
61// NVGPU_DeviceAsyncCopyOp
62//===----------------------------------------------------------------------===//
63
// Verifies the device-async-copy op: contiguous innermost dimensions,
// shared-memory destination, matching element types, one index per memref
// dimension, and a total copy size of 4, 8, or 16 bytes.
LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  // Both sides must be contiguous in their innermost dimension.
  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  // The destination must live in workgroup/shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  // One index is required per memref dimension on each side.
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  // Total transfer size must be exactly 4, 8, or 16 bytes.
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "Requested copy elements is " << dstElements << " with width "
         << dstMemref.getElementTypeBitWidth()
         << ". But copy elements could be one of ";
    // Suggest the element counts that make the copy 4, 8, or 16 bytes wide
    // for this element width (skipping counts that truncate to zero).
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  // When bypassL1 is set, the transfer must be the full 16 bytes; `req` is
  // the element count corresponding to 16 bytes at this element width.
  // NOTE(review): "satify" in the diagnostic below is a typo ("satisfy");
  // fixing it requires updating any tests that match this text.
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satify alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}
116
117//===----------------------------------------------------------------------===//
118// NVGPU_MmaSyncOp
119//===----------------------------------------------------------------------===//
120void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
121 ::mlir::OperationState &odsState, Value matrixA,
122 Value matrixB, Value matrixC, ArrayAttr mmaShape) {
123 build(odsBuilder, odsState, res: matrixC.getType(), matrixA, matrixB, matrixC,
124 mmaShape, tf32Enabled: UnitAttr());
125}
126
127void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
128 ::mlir::OperationState &odsState, Value matrixA,
129 Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
130 bool tf32Enabled) {
131 build(odsBuilder, odsState, res: matrixC.getType(), matrixA, matrixB, matrixC,
132 mmaShape: odsBuilder.getI64ArrayAttr(values: mmaShape),
133 tf32Enabled: tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
134}
135
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
///
/// `mmaShape` is the warp-wide {m, n, k} tile; `matrixA/B/C` are the
/// per-thread 2-D vector fragments. When `sparse` is set (MmaSparseSyncOp),
/// matrix A carries only half of its elements (sparseFactor == 2 below).
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {

  // The verification for mma.sync covering various shapes and data types is
  // based on the fundamental tensor core shape.

  // "Fundamental" tensor core shapes:
  // - For F32 (TF32), F16, S8, and S4 data types the fundamental tensor core
  //   operation is of shape 8-by-8-by-128b.
  // - F64 is an exception and is of shape 8-by-8-by-256b.
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // Certain data types are not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // exception to 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth; // 128b wide shapeK

    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return op->emitError()
           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
              "supported by "
           << op->getName();
  }

  //
  // Basic verification
  //

  // All three fragments must be rank-2 vectors.
  if (aShape.size() != 2) {
    return op->emitError() << "matrixA must be 2 dimensional vector";
  }

  if (bShape.size() != 2) {
    return op->emitError() << "matrixB must be 2 dimensional vector";
  }

  if (cShape.size() != 2) {
    return op->emitError() << "matrixC must be 2 dimensional vector";
  }

  auto [m, n, k] = mmaShape;

  // verify warp-wide size for vector a
  // NOTE(review): in sparse mode the check divides by sparseFactor but the
  // diagnostic still reports m * k — the printed count looks 2x too large
  // for sparse ops; confirm before changing the message (tests match it).
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError()
           << "expected " << m * k << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // verify tf32 tensor cores are enabled for only F32 datatype
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";

  //
  // Extended verification
  //

  // Number of fundamental tensor-core tiles covered by the warp-wide shape.
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile << " x " << numElementA << ")";

  // verify shape of bVector
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // verify shape of cVector
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}
263
264LogicalResult MmaSyncOp::verify() {
265 return verifyMmaSyncOp(op: this->getOperation(), matrixA: getMatrixA(), matrixB: getMatrixB(),
266 matrixC: getMatrixC(), mmaShape: getMmaShapeAsArray(),
267 tf32Enabled: getOperation()->hasAttr(name: getTf32EnabledAttrName()));
268}
269
270//===----------------------------------------------------------------------===//
271// NVGPU_MmaSparseSyncOp
272//===----------------------------------------------------------------------===//
273void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
274 ::mlir::OperationState &odsState, Value matrixA,
275 Value matrixB, Value matrixC, Value sparseMetadata,
276 ArrayRef<int64_t> mmaShape) {
277 build(odsBuilder, odsState, res: matrixC.getType(), matrixA, matrixB, matrixC,
278 sparseMetadata, mmaShape: odsBuilder.getI64ArrayAttr(values: mmaShape), sparsitySelector: 0, tf32Enabled: UnitAttr());
279}
280
281LogicalResult MmaSparseSyncOp::verify() {
282 unsigned sparsitySelector = getSparsitySelector();
283 if (sparsitySelector > 1)
284 return emitOpError() << "sparsity selector should be 0 or 1";
285 return verifyMmaSyncOp(op: this->getOperation(), matrixA: getMatrixA(), matrixB: getMatrixB(),
286 matrixC: getMatrixC(), mmaShape: getMmaShapeAsArray(),
287 tf32Enabled: getOperation()->hasAttr(name: getTf32EnabledAttrName()),
288 sparse: true);
289}
290
291//===----------------------------------------------------------------------===//
292// NVGPU_LdMatrixOp
293//===----------------------------------------------------------------------===//
294LogicalResult LdMatrixOp::verify() {
295
296 // ldmatrix reads data from source in shared memory
297 auto srcMemref = llvm::cast<MemRefType>(Val: getSrcMemref().getType());
298
299 // ldmatrix writes data to result/destination in vector registers
300 auto resVector = llvm::cast<VectorType>(Val: getRes().getType());
301
302 // vector register shape, element type, and bitwidth
303 ArrayRef<int64_t> resShape = resVector.getShape();
304 Type resType = resVector.getElementType();
305 int64_t elementBitWidth = resType.getIntOrFloatBitWidth();
306
307 // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
308 int64_t numElementsPer32b = 32 / elementBitWidth;
309
310 // number of 8-by-8 tiles
311 int64_t numTiles = getNumTiles();
312
313 // transpose elements in vector registers at 16b granularity when true
314 bool isTranspose = getTranspose();
315
316 //
317 // verification
318 //
319
320 if (!NVGPUDialect::hasSharedMemoryAddressSpace(type: srcMemref))
321 return emitError()
322 << "expected nvgpu.ldmatrix srcMemref must have a memory space "
323 "attribute of IntegerAttr("
324 << NVGPUDialect::kSharedMemoryAddressSpace
325 << ") or gpu::AddressSpaceAttr(Workgroup)";
326 if (elementBitWidth > 32)
327 return emitError() << "nvgpu.ldmatrix works for 32b or lower";
328 if (isTranspose && !(elementBitWidth == 16))
329 return emitError()
330 << "nvgpu.ldmatrix transpose works only at 16b granularity";
331 if (resShape.size() != 2) {
332 return emitError() << "results must be 2 dimensional vector";
333 }
334 if (!(resShape[1] == numElementsPer32b))
335 return emitError() << "expected vector register shape[1] = "
336 << numElementsPer32b;
337 if (!(resShape[0] == numTiles))
338 return emitError()
339 << "expected vector register shape[0] and numTiles to match";
340
341 return success();
342}
343
344//===----------------------------------------------------------------------===//
345// NVGPU_TmaAsyncLoadOp
346//===----------------------------------------------------------------------===//
347
/// Verifies a TMA tensor-map descriptor type and, when `memrefType` is
/// provided, its compatibility with that memref (element type, address
/// space, rank, and shape). Returns std::nullopt on success, or the
/// diagnostic that was emitted on failure.
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, nvgpu::TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation: interleaved tensor maps are not implemented yet.
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // The descriptor's tensor must live in shared (workgroup) memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
  }
  // Support only static shape for the time being
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  // Every dimension must be within the hardware TMA limit.
  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  // With swizzling enabled on a multi-dimensional tensor, the innermost
  // dimension must span exactly kMaxTMALastdimByte bytes.
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    if (lastDimensionByte != kMaxTMALastdimByte)
      return op->emitError() << "the tensormap descriptor must have last "
                                "dimension of "
                             << kMaxTMALastdimByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No verification if memref type is not provided
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check element type
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be same";
  }

  // The memref side must also be in shared memory and static shaped.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  // Descriptor and memref must agree on rank and exact shape.
  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the shape of tensor map descriptor and "
                              "memref must have same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  return std::nullopt;
}
413
414LogicalResult TmaAsyncLoadOp::verify() {
415 std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
416 op: *this, descType: getTensorMapDescriptor().getType(), memrefType: getDst().getType());
417 if (error.has_value())
418 return error.value();
419
420 if (getCoordinates().size() > kMaxTMATensorDimension) {
421 return emitError() << "Maximum " << kMaxTMATensorDimension
422 << " coordinates are supported.";
423 }
424 if (getCoordinates().size() !=
425 size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
426 return emitError() << "number of coordinates do not match with the rank of "
427 "tensor descriptor map.";
428 }
429
430 return success();
431}
432
433//===----------------------------------------------------------------------===//
434// NVGPU_TmaAsyncStoreOp
435//===----------------------------------------------------------------------===//
436
437LogicalResult TmaAsyncStoreOp::verify() {
438 std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
439 op: *this, descType: getTensorMapDescriptor().getType(), memrefType: getSrc().getType());
440 if (error.has_value())
441 return error.value();
442
443 if (getCoordinates().size() > kMaxTMATensorDimension) {
444 return emitError() << "Maximum " << kMaxTMATensorDimension
445 << " coordinates are supported.";
446 }
447 if (getCoordinates().size() !=
448 size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
449 return emitError() << "number of coordinates do not match with the rank of "
450 "tensor descriptor map.";
451 }
452
453 return success();
454}
455
456LogicalResult TmaCreateDescriptorOp::verify() {
457 if (getBoxDimensions().size() > kMaxTMATensorDimension) {
458 return emitError() << "Maximum " << kMaxTMATensorDimension
459 << " coordinates are supported.";
460 }
461
462 std::optional<InFlightDiagnostic> error =
463 verifyTmaDescriptorWithMemref(op: *this, descType: getTensorMap().getType());
464 if (error.has_value())
465 return error.value();
466
467 return success();
468}
469
470//===----------------------------------------------------------------------===//
471// NVGPU_WarpgroupGenerateDescriptorOp
472//===----------------------------------------------------------------------===//
473
474LogicalResult WarpgroupGenerateDescriptorOp::verify() {
475 std::optional<InFlightDiagnostic> error =
476 verifyTmaDescriptorWithMemref(op: *this, descType: getTensorMap().getType());
477 if (error.has_value())
478 return error.value();
479
480 if (getTensorMap().getType().getSwizzle() !=
481 TensorMapSwizzleKind::SWIZZLE_128B) {
482 return emitError() << "supports only "
483 << stringifyTensorMapSwizzleKind(
484 TensorMapSwizzleKind::SWIZZLE_128B)
485 << " is supported for the time being";
486 }
487
488 if (getTensorMap().getType().getInterleave() !=
489 TensorMapInterleaveKind::INTERLEAVE_NONE) {
490 return emitError() << "supports only "
491 << stringifyTensorMapInterleaveKind(
492 TensorMapInterleaveKind::INTERLEAVE_NONE)
493 << " is supported for the time being";
494 }
495
496 return success();
497}
498
499//===----------------------------------------------------------------------===//
500// WarpgroupMmaOp
501//===----------------------------------------------------------------------===//
502
/// Returns success when the D += A * B element-type combination is one this
/// verifier accepts for the warp-group (wgmma) matrix-multiply op.
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // F32 += F16 + F16
  // F16 += F16 + F16
  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F32 += TF32 + TF32
  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
    return success();
  // s32 += i8 + i8
  // NOTE(review): the comment above says i8 but the code checks 16-bit
  // integers — one of the two looks wrong. Confirm against the PTX wgmma
  // data-type table before changing behavior.
  if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32))
    return success();
  // s32 += i1 + i1
  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
    return success();
  // F32 += BF16 + BF16
  // F16 += BF16 + BF16
  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F16 += f8 + f8
  // F32 += f8 + f8
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF32() || typeD.isF16()))
    return success();

  return failure();
}
530
531LogicalResult isAllowedSizeM(int sizeM) {
532 if (sizeM % kWgmmaSizeM)
533 return failure();
534 return success();
535}
536
537LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
538 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
539 72, 80, 88, 96, 104, 112, 120, 128,
540 136, 144, 152, 160, 168, 176, 184, 192,
541 200, 208, 216, 224, 232, 240, 248, 256};
542 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
543 80, 96, 112, 128, 144, 160,
544 176, 192, 208, 224, 240, 256};
545 if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
546 isa<Float8E5M2Type, Float8E4M3FNType>(Val: typeA))
547 if (llvm::is_contained(Range&: allowedN, Element: sizeN))
548 return success();
549
550 if (typeA.isInteger(width: 8) || typeA.isInteger(width: 1))
551 if (llvm::is_contained(Range&: allowedNshort, Element: sizeN))
552 return success();
553 return failure();
554}
555
// Verifies the warp-group MMA op: layout restriction, matching C/D fragment
// types, rank-2 matrices, D[m][n] = A[m][k] * B[k][n] + C[m][n] dimension
// agreement, and supported element-type / size combinations.
LogicalResult WarpgroupMmaOp::verify() {
  // NOTE(review): this only rejects (transposeA && !transposeB); a
  // both-transposed combination passes even though the message suggests only
  // one layout is supported — confirm that is intended.
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being ";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  // The accumulator is read as C and written as D; both fragments must match.
  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D, they must be 2 dimensional";
  }

  // Dimension agreement for D[m][n] = A[m][k] * B[k][n] + C[m][n].
  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
                         << " )";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
                         << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
                         << " )";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1]
                         << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1]
                         << " )";

  // The element-type combination must be one the wgmma verifier accepts.
  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";
  // Check N
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << " n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  }

  // Currently, f16/bf16 supported
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }

  return success();
}
612
// Verifies the wgmma store: an f32 accumulator fragment whose 2-D shape
// matches the destination memref extent-for-extent.
LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation: only f32 result fragments are handled currently.
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    // NOTE(review): the message prints the whole vector type followed by
    // dim 1; printing getDimSize(0) for the first bracket was possibly
    // intended — confirm before changing (tests may match this text).
    return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have same size as results";
  }
  return success();
}
632
633//===----------------------------------------------------------------------===//
634// WarpgroupMmaInitAccumulatorOp
635//===----------------------------------------------------------------------===//
636
637LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
638
639 nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
640 int64_t sizeM = accType.getFragmented().getDimSize(idx: 0);
641 int64_t sizeN = accType.getFragmented().getDimSize(idx: 1);
642 Type elemType = accType.getFragmented().getElementType();
643
644 if (failed(Result: isAllowedSizeM(sizeM)) ||
645 failed(Result: isAllowedSizeN(sizeN, typeA: elemType))) {
646 return emitOpError() << "has type " << accType.getFragmented()
647 << ". It does not fit into warp-group "
648 "level (wgmma) matrix multiplication instruction "
649 "(or not supported yet)";
650 }
651 return success();
652}
653
654//===----------------------------------------------------------------------===//
655// RcpOp
656//===----------------------------------------------------------------------===//
657
658LogicalResult RcpOp::verify() {
659 RcpRoundingModeAttr rounding = getRoundingAttr();
660 bool ftz = getFtz();
661 // Currently, only `rcp_approx` and `ftz` is supported.
662 if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
663 return emitOpError() << "has a limitation. " << rounding
664 << " or non-ftz is not supported yet.";
665 }
666 return success();
667}
668
669//===----------------------------------------------------------------------===//
670// TableGen'd dialect, type, and op definitions
671//===----------------------------------------------------------------------===//
672
673#define GET_ATTRDEF_CLASSES
674#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
675
676#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"
677
678#define GET_OP_CLASSES
679#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
680
681#define GET_TYPEDEF_CLASSES
682#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
683

source code of mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp