//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"
      >();
}

bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
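
// Example (illustrative): isSharedMemoryAddressSpace returns true for a memref
// whose memory space attribute is an IntegerAttr equal to
// NVGPUDialect::kSharedMemoryAddressSpace or the GPU dialect attribute
// #gpu.address_space<workgroup>, and false for a memref with no memory space
// attribute at all.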

bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
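  // Example (illustrative): for f16 elements (16 bits), dstElements of 2, 4,
  // or 8 gives sizeInBytes of 4, 8, or 16, all of which are valid cp.async
  // copy sizes accepted by the check below.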
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "requested " << dstElements << " copy elements with element width "
         << dstWidth
         << " bits, but the number of copy elements must be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
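    // e.g., for f16 elements (16 bits), req = 16 * 8 / 16 = 8 elements,
    // i.e. a full 16-byte (128-bit) transfer as required by bypassL1.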
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                      bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}

/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {

  // The verification for mma.sync covering various shapes and data types is
  // based on the fundamental tensor core shape.

  // "Fundamental" tensor core shapes:
  //   - For F32 (TF32), F16, S8, and S4 data types the fundamental tensor core
  //     operation is of shape 8-by-8-by-128b.
  //   - F64 is an exception and is of shape 8-by-8-by-256b.
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile
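  // Worked example (illustrative): for f16 operands (16-bit) the fundamental
  // tile is 8x8x8 (shapeK = 128/16) and each thread holds 32/16 = 2 elements
  // of A and of B; for tf32 (32-bit) it is 8x8x4 with 1 element of each.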

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // Certain data types are not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // exception to the 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth; // 128b wide shapeK

    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return op->emitError()
           << "expected an input data type supported by " << op->getName()
           << " (i4, i8, f16, bf16, tf32, f64)";
  }

  //
  // Basic verification
  //

  if (aShape.size() != 2) {
    return op->emitError() << "matrixA must be a 2-dimensional vector";
  }

  if (bShape.size() != 2) {
    return op->emitError() << "matrixB must be a 2-dimensional vector";
  }

  if (cShape.size() != 2) {
    return op->emitError() << "matrixC must be a 2-dimensional vector";
  }

  auto [m, n, k] = mmaShape;

  // verify warp-wide size for vector a
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError() << "expected " << m * k / sparseFactor
                             << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // verify tf32 tensor cores are enabled for only F32 datatype
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;
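  // e.g., a 16x8x16 f16 mma.sync decomposes into mTile = 16/8 = 2,
  // nTile = 8/8 = 1, and kTile = 16/8 = 2 fundamental 8x8x8 tiles.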

  // verify shape of aVector
  if ((aShape[0] != mTile * kTile / sparseFactor) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile / sparseFactor << " x "
                             << numElementA << ")";

  // verify shape of bVector
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // verify shape of cVector
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}

LogicalResult MmaSyncOp::verify() {
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSparseSyncOp
//===----------------------------------------------------------------------===//
void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                            ::mlir::OperationState &odsState, Value matrixA,
                            Value matrixB, Value matrixC, Value sparseMetadata,
                            ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}

LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         true);
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {

  // ldmatrix reads data from source in shared memory
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();
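
  // Example (illustrative): loading numTiles = 4 tiles of f16 gives
  // numElementsPer32b = 32/16 = 2, so the expected per-thread result type is
  // vector<4x2xf16>, matching the shape checks below.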

  //
  // verification
  //

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "nvgpu.ldmatrix srcMemref must have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix only supports element bitwidths of "
                          "32b or lower";
  if (isTranspose && elementBitWidth != 16)
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "result must be a 2-dimensional vector";
  }
  if (resShape[1] != numElementsPer32b)
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (resShape[0] != numTiles)
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//

std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, nvgpu::TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "interleave options are not supported yet";

  // Address space check for shared memory
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has an incorrect "
                              "address space; it must be the shared memory "
                              "address space";
  }
  // Support only static shapes for the time being
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    if (lastDimensionByte != kMaxTMALastdimByte)
      return op->emitError() << "the tensor map descriptor must have a last "
                                "dimension of "
                             << kMaxTMALastdimByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }
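  // Illustrative example: with a non-NONE swizzle kind and f16 elements
  // (2 bytes each), the innermost descriptor dimension must span exactly
  // kMaxTMALastdimByte / 2 elements to pass the check above.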

  // No verification if memref type is not provided
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check element type
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element types of the tensor map descriptor "
                              "and the memref must be the same";
  }

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has an incorrect address "
                              "space; it must be the shared memory address "
                              "space";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the tensor map descriptor and the memref must "
                              "have the same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  return std::nullopt;
}

LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncStoreOp
//===----------------------------------------------------------------------===//

LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor";
  }

  return success();
}

LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }

  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_WarpgroupGenerateDescriptorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaOp
//===----------------------------------------------------------------------===//

LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // F32 += F16 + F16
  // F16 += F16 + F16
  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F32 += TF32 + TF32
  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
    return success();
  // s32 += i8 + i8
  if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32))
    return success();
  // s32 += i1 + i1
  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
    return success();
  // F32 += BF16 + BF16
  // F16 += BF16 + BF16
  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F16 += f8 + f8
  // F32 += f8 + f8
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF32() || typeD.isF16()))
    return success();

  return failure();
}

LogicalResult isAllowedSizeM(int sizeM) {
  if (sizeM % kWgmmaSizeM)
    return failure();
  return success();
}

LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();

  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}

LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();
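
  // Shape contract (illustrative): A is M x K, B is K x N, and C/D are M x N,
  // so A's second dim must equal B's first dim, C takes M from A and N from B,
  // and C and D must have identical types.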

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D; they must all be 2 dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";

  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported.";
  // Check N
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << ", n is set to "
                         << matrixB.getDimSize(1) << ", it is not supported";
  }

  // Currently, f16/bf16 supported
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << ", it is not supported yet";
  }

  return success();
}

LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "results [" << vtype.getDimSize(0) << "]["
                         << vtype.getDimSize(1)
                         << "] values. However, destination memref["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1)
                         << "] does not have the same size as results";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaInitAccumulatorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {

  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// RcpOp
//===----------------------------------------------------------------------===//

LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only `rcp_approx` with `ftz` is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation: " << rounding
                         << " or non-ftz is not supported yet";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
