1 | //===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the NVGPU dialect and its operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/BuiltinAttributes.h" |
18 | #include "mlir/IR/BuiltinTypes.h" |
19 | #include "mlir/IR/Diagnostics.h" |
20 | #include "mlir/IR/DialectImplementation.h" |
21 | #include "mlir/IR/Matchers.h" |
22 | #include "mlir/IR/OpImplementation.h" |
23 | #include "mlir/IR/PatternMatch.h" |
24 | #include "mlir/IR/TypeUtilities.h" |
25 | #include "mlir/IR/Verifier.h" |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/ADT/StringExtras.h" |
28 | #include "llvm/ADT/TypeSwitch.h" |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::nvgpu; |
32 | |
33 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc" |
34 | |
35 | void nvgpu::NVGPUDialect::initialize() { |
36 | addTypes< |
37 | #define GET_TYPEDEF_LIST |
38 | #include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc" |
39 | >(); |
40 | addAttributes< |
41 | #define GET_ATTRDEF_LIST |
42 | #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" |
43 | >(); |
44 | addOperations< |
45 | #define GET_OP_LIST |
46 | #include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc" |
47 | >(); |
48 | } |
49 | |
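// Illustrative note (a sketch, assuming the dialect's shared-memory space
// value kSharedMemoryAddressSpace is 3): both of the following memref types
// are treated as shared memory by the helpers below.
//
//   memref<64x64xf16, 3>
//   memref<64x64xf16, #gpu.address_space<workgroup>>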
50 | bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { |
51 | if (!memorySpace) |
52 | return false; |
53 | if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace)) |
54 | return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace; |
55 | if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace)) |
56 | return gpuAttr.getValue() == gpu::AddressSpace::Workgroup; |
57 | return false; |
58 | } |
59 | |
60 | bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { |
61 | Attribute memorySpace = type.getMemorySpace(); |
62 | return isSharedMemoryAddressSpace(memorySpace); |
63 | } |
64 | |
65 | //===----------------------------------------------------------------------===// |
66 | // NVGPU_DeviceAsyncCopyOp |
67 | //===----------------------------------------------------------------------===// |
68 | |
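// Worked example (a sketch, not taken from the op documentation; the exact
// assembly syntax may differ): the verifier below requires each per-thread
// copy to move 4, 8, or 16 bytes, so for f16 elements (16 bits) the legal
// values of dstElements are 2, 4, or 8, and only the 16-byte case (8
// elements) is accepted when bypassL1 is set. A well-formed copy could look
// roughly like:
//
//   %token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l], 8
//       : memref<128x128xf16> to memref<64x64xf16, 3>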
69 | LogicalResult DeviceAsyncCopyOp::verify() { |
70 | auto srcMemref = llvm::cast<MemRefType>(getSrc().getType()); |
71 | auto dstMemref = llvm::cast<MemRefType>(getDst().getType()); |
72 | |
  if (!srcMemref.isLastDimUnitStride())
    return emitError("source memref most minor dim must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError("destination memref most minor dim must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
85 | if (size_t(srcMemref.getRank()) != getSrcIndices().size()) |
86 | return emitOpError() << "expected " << srcMemref.getRank() |
87 | << " source indices, got " << getSrcIndices().size(); |
88 | if (size_t(dstMemref.getRank()) != getDstIndices().size()) |
89 | return emitOpError() << "expected " << dstMemref.getRank() |
90 | << " destination indices, got " |
91 | << getDstIndices().size(); |
92 | int64_t dstElements = getDstElements().getZExtValue(); |
93 | int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8; |
94 | if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) { |
95 | unsigned dstWidth = dstMemref.getElementTypeBitWidth(); |
    InFlightDiagnostic diag = emitError();
    diag << "requested copy of " << dstElements << " elements of width "
         << dstWidth
         << " bits, but the number of copy elements must be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
107 | } |
108 | if (getBypassL1().has_value()) { |
109 | int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth(); |
110 | if (getBypassL1().value() && sizeInBytes != 16) { |
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set "
                              "destination element to "
                           << req;
117 | } |
118 | } |
119 | return success(); |
120 | } |
121 | |
122 | //===----------------------------------------------------------------------===// |
123 | // NVGPU_MmaSyncOp |
124 | //===----------------------------------------------------------------------===// |
125 | void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, |
126 | ::mlir::OperationState &odsState, Value matrixA, |
127 | Value matrixB, Value matrixC, ArrayAttr mmaShape) { |
128 | build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, |
129 | mmaShape, UnitAttr()); |
130 | } |
131 | |
132 | void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, |
133 | ::mlir::OperationState &odsState, Value matrixA, |
134 | Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape, |
135 | bool tf32Enabled) { |
136 | build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, |
137 | odsBuilder.getI64ArrayAttr(mmaShape), |
138 | tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr()); |
139 | } |
140 | |
141 | /// Performs verification for MmaSyncOp and MmaSparseSyncOp. |
142 | static LogicalResult verifyMmaSyncOp(Operation *op, |
143 | TypedValue<VectorType> matrixA, |
144 | TypedValue<VectorType> matrixB, |
145 | TypedValue<VectorType> matrixC, |
146 | const std::array<int64_t, 3> &mmaShape, |
147 | bool tf32Enabled, bool sparse = false) { |
148 | |
149 | // The verification for mma.sync covering various shapes and data types is |
150 | // based on the fundamental tensor core shape. |
151 | |
152 | // "Fundamental" tensor core shapes: |
153 | // - For F32 (TF32), F16, S8, and S4 data |
154 | // types the fundamental tensor core operation is of shape 8-by-8-by-128b. |
155 | // - F64 is an exception and is of shape 8-by-8-by-256b. |
156 | int64_t shapeM = 8; |
157 | int64_t shapeN = 8; |
158 | int64_t shapeK; // set based on data type (128b for all data types except F64) |
159 | |
160 | // Number of elements A, B, and C per thread per fundamental tensor core tile |
161 | int64_t numElementA; // set based on data type (32b except F64) |
162 | int64_t numElementB; // set based on data type (32b except F64) |
163 | int64_t numElementC{2}; // two accumulator elements per fundamental tile |
164 | |
165 | // nvgpu.mma.sync vector operands (per thread) |
166 | auto aVector = matrixA.getType(); |
167 | auto bVector = matrixB.getType(); |
168 | auto cVector = matrixC.getType(); |
169 | |
170 | // vector shapes |
171 | ArrayRef<int64_t> aShape = aVector.getShape(); |
172 | ArrayRef<int64_t> bShape = bVector.getShape(); |
173 | ArrayRef<int64_t> cShape = cVector.getShape(); |
174 | |
175 | // vector element type |
176 | Type aType = aVector.getElementType(); |
177 | |
178 | // Certain data types are not allowed in sparse mode. |
179 | if (sparse && aType.isF64()) |
    return op->emitError() << "f64 is not supported for sparse mode";
181 | |
182 | if (aType.isF64()) { |
183 | // exception to 8-by-8-128b fundamental tensor core tile size |
184 | shapeK = 4; |
185 | numElementA = 1; |
186 | numElementB = 1; |
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
189 | // 8-by-8-128b fundamental tensor core tile size |
190 | int operandBitwidth = aType.getIntOrFloatBitWidth(); |
191 | shapeK = 128 / operandBitwidth; // 128b wide shapeK |
192 | |
193 | numElementA = 32 / operandBitwidth; // 32b wide operand A |
194 | numElementB = 32 / operandBitwidth; // 32b wide operand B |
195 | } else { |
196 | return op->emitError() |
197 | << "expected input data type (i4,i8,f16,bf16,tf32,f64) " |
198 | "supported by " |
199 | << op->getName(); |
200 | } |
201 | |
202 | // |
203 | // Basic verification |
204 | // |
205 | |
  if (aShape.size() != 2) {
    return op->emitError() << "matrixA must be a 2-dimensional vector";
  }

  if (bShape.size() != 2) {
    return op->emitError() << "matrixB must be a 2-dimensional vector";
  }

  if (cShape.size() != 2) {
    return op->emitError() << "matrixC must be a 2-dimensional vector";
  }
217 | |
218 | auto [m, n, k] = mmaShape; |
219 | |
220 | // verify warp-wide size for vector a |
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError() << "expected " << m * k / sparseFactor
                             << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";
235 | |
236 | // verify tf32 tensor cores are enabled for only F32 datatype |
  if (tf32Enabled && !aType.isF32())
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";
240 | |
241 | // |
242 | // Extended verification |
243 | // |
244 | |
245 | // tiles of fundamental tensor core operations |
246 | int64_t mTile = m / shapeM; |
247 | int64_t nTile = n / shapeN; |
248 | int64_t kTile = k / shapeK; |
249 | |
250 | // verify shape of aVector |
  if ((aShape[0] != mTile * kTile / sparseFactor) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile / sparseFactor << " x "
                             << numElementA << ")";
255 | |
256 | // verify shape of bVector |
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";
260 | |
261 | // verify shape of cVector |
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";
265 | |
266 | return success(); |
267 | } |
268 | |
269 | LogicalResult MmaSyncOp::verify() { |
270 | return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), |
271 | getMatrixC(), getMmaShapeAsArray(), |
272 | getOperation()->hasAttr(getTf32EnabledAttrName())); |
273 | } |
274 | |
275 | //===----------------------------------------------------------------------===// |
276 | // NVGPU_MmaSparseSyncOp |
277 | //===----------------------------------------------------------------------===// |
278 | void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder, |
279 | ::mlir::OperationState &odsState, Value matrixA, |
280 | Value matrixB, Value matrixC, Value sparseMetadata, |
281 | ArrayRef<int64_t> mmaShape) { |
282 | build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, |
283 | sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr()); |
284 | } |
285 | |
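// Illustrative note (a sketch, not from the op documentation): in sparse mode
// matrix A carries only half of its elements, with the other half described
// by the sparsity metadata. For an f16 op with mma shape [16, 8, 32], the
// verifier below expects A as vector<4x2xf16> (half of the dense 8x2
// fragment), B as vector<4x2xf16>, and C as vector<2x2xf16>.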
286 | LogicalResult MmaSparseSyncOp::verify() { |
287 | unsigned sparsitySelector = getSparsitySelector(); |
288 | if (sparsitySelector > 1) |
    return emitOpError() << "sparsity selector should be 0 or 1";
290 | return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), |
291 | getMatrixC(), getMmaShapeAsArray(), |
292 | getOperation()->hasAttr(getTf32EnabledAttrName()), |
293 | true); |
294 | } |
295 | |
296 | //===----------------------------------------------------------------------===// |
297 | // NVGPU_LdMatrixOp |
298 | //===----------------------------------------------------------------------===// |
299 | LogicalResult LdMatrixOp::verify() { |
300 | |
301 | // ldmatrix reads data from source in shared memory |
302 | auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType()); |
303 | |
304 | // ldmatrix writes data to result/destination in vector registers |
305 | auto resVector = llvm::cast<VectorType>(getRes().getType()); |
306 | |
307 | // vector register shape, element type, and bitwidth |
308 | ArrayRef<int64_t> resShape = resVector.getShape(); |
309 | Type resType = resVector.getElementType(); |
310 | int64_t elementBitWidth = resType.getIntOrFloatBitWidth(); |
311 | |
312 | // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread |
313 | int64_t numElementsPer32b = 32 / elementBitWidth; |
314 | |
315 | // number of 8-by-8 tiles |
316 | int64_t numTiles = getNumTiles(); |
317 | |
318 | // transpose elements in vector registers at 16b granularity when true |
319 | bool isTranspose = getTranspose(); |
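  // Worked example (a sketch, not from the op documentation): for f16
  // elements, numElementsPer32b = 2, so a load with numTiles = 4 must
  // produce a vector<4x2xf16> result per thread.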
320 | |
321 | // |
322 | // verification |
323 | // |
324 | |
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref to have a memory space "
              "attribute of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && elementBitWidth != 16)
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "results must be a 2-dimensional vector";
  }
  if (resShape[1] != numElementsPer32b)
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (resShape[0] != numTiles)
    return emitError()
           << "expected vector register shape[0] and numTiles to match";
345 | |
346 | return success(); |
347 | } |
348 | |
349 | //===----------------------------------------------------------------------===// |
350 | // NVGPU_TmaAsyncLoadOp |
351 | //===----------------------------------------------------------------------===// |
352 | |
353 | std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( |
354 | Operation *op, nvgpu::TensorMapDescriptorType descType, |
355 | std::optional<MemRefType> memrefType = std::nullopt) { |
356 | MemRefType descMemref = descType.getTensor(); |
357 | // Limitation |
358 | if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE) |
    return op->emitError() << "Interleave options are not supported yet.";
360 | |
  // Check that the address space is shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
365 | } |
366 | // Support only static shape for the time being |
367 | if (!descMemref.hasStaticShape()) |
    return op->emitError() << "the tensor map descriptor must be static shaped";
369 | |
370 | for (auto dim : descMemref.getShape()) { |
371 | if (dim <= 0 || dim > kMaxTMADimension) { |
372 | return op->emitError() << "the tensor map descriptor must have " |
373 | "dimensions between 1 and " |
374 | << kMaxTMADimension << " but it is " << dim; |
375 | } |
376 | } |
377 | if (descMemref.getRank() > 1 && |
378 | descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) { |
379 | unsigned lastDimensionByte = |
380 | descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8; |
381 | if (lastDimensionByte != kMaxTMALastdimByte) |
      return op->emitError() << "the tensor map descriptor must have last "
                                "dimension of "
                             << kMaxTMALastdimByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
386 | } |
387 | |
388 | // No verification if memref type is not provided |
389 | if (!memrefType.has_value()) |
390 | return std::nullopt; |
391 | |
392 | MemRefType dstMemref = memrefType.value(); |
393 | |
394 | // Check element type |
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of tensor map descriptor and "
                              "memref must be the same";
  }
399 | |
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has incorrect address "
                              "space, it must be shared memory address space.";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";
406 | |
  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the tensor map descriptor and the memref "
                              "must have the same rank";
  }
411 | if (!descMemref.getShape().equals(dstMemref.getShape())) { |
412 | return op->emitError() << "memref and tensor map shapes mismatch " |
413 | << descMemref << " != " << dstMemref; |
414 | } |
415 | |
416 | return std::nullopt; |
417 | } |
418 | |
419 | LogicalResult TmaAsyncLoadOp::verify() { |
420 | std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref( |
421 | *this, getTensorMapDescriptor().getType(), getDst().getType()); |
422 | if (error.has_value()) |
423 | return error.value(); |
424 | |
  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor.";
  }
434 | |
435 | return success(); |
436 | } |
437 | |
438 | //===----------------------------------------------------------------------===// |
439 | // NVGPU_TmaAsyncStoreOp |
440 | //===----------------------------------------------------------------------===// |
441 | |
442 | LogicalResult TmaAsyncStoreOp::verify() { |
443 | std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref( |
444 | *this, getTensorMapDescriptor().getType(), getSrc().getType()); |
445 | if (error.has_value()) |
446 | return error.value(); |
447 | |
  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor.";
  }
457 | |
458 | return success(); |
459 | } |
460 | |
461 | LogicalResult TmaCreateDescriptorOp::verify() { |
  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " box dimensions are supported.";
  }
466 | |
467 | std::optional<InFlightDiagnostic> error = |
468 | verifyTmaDescriptorWithMemref(*this, getTensorMap().getType()); |
469 | if (error.has_value()) |
470 | return error.value(); |
471 | |
472 | return success(); |
473 | } |
474 | |
475 | //===----------------------------------------------------------------------===// |
476 | // NVGPU_WarpgroupGenerateDescriptorOp |
477 | //===----------------------------------------------------------------------===// |
478 | |
479 | LogicalResult WarpgroupGenerateDescriptorOp::verify() { |
480 | std::optional<InFlightDiagnostic> error = |
481 | verifyTmaDescriptorWithMemref(*this, getTensorMap().getType()); |
482 | if (error.has_value()) |
483 | return error.value(); |
484 | |
  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }
500 | |
501 | return success(); |
502 | } |
503 | |
504 | //===----------------------------------------------------------------------===// |
505 | // WarpgroupMmaOp |
506 | //===----------------------------------------------------------------------===// |
507 | |
508 | LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) { |
509 | // F32 += F16 + F16 |
510 | // F16 += F16 + F16 |
511 | if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16())) |
512 | return success(); |
513 | // F32 += TF32 + TF32 |
514 | if (typeA.isTF32() && typeD.isF32() && typeB.isTF32()) |
515 | return success(); |
516 | // s32 += i8 + i8 |
  if (typeA.isInteger(8) && typeB.isInteger(8) && typeD.isInteger(32))
518 | return success(); |
519 | // s32 += i1 + i1 |
  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
521 | return success(); |
522 | // F32 += BF16 + BF16 |
523 | // F16 += BF16 + BF16 |
524 | if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16())) |
525 | return success(); |
526 | // F16 += f8 + f8 |
527 | // F32 += f8 + f8 |
528 | if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) && |
529 | isa<Float8E5M2Type, Float8E4M3FNType>(typeB) && |
530 | (typeD.isF32() || typeD.isF16())) |
531 | return success(); |
532 | |
533 | return failure(); |
534 | } |
535 | |
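// Illustrative note (assuming kWgmmaSizeM is 64, the warpgroup-level M tile):
// sizeM must be a multiple of 64, e.g. 64, 128, 192, ...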
536 | LogicalResult isAllowedSizeM(int sizeM) { |
537 | if (sizeM % kWgmmaSizeM) |
538 | return failure(); |
539 | return success(); |
540 | } |
541 | |
542 | LogicalResult isAllowedSizeN(int sizeN, Type typeA) { |
543 | SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64, |
544 | 72, 80, 88, 96, 104, 112, 120, 128, |
545 | 136, 144, 152, 160, 168, 176, 184, 192, |
546 | 200, 208, 216, 224, 232, 240, 248, 256}; |
547 | SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64, |
548 | 80, 96, 112, 128, 144, 160, |
549 | 176, 192, 208, 224, 240, 256}; |
550 | if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() || |
551 | isa<Float8E5M2Type, Float8E4M3FNType>(typeA)) |
    if (llvm::is_contained(allowedN, sizeN))
553 | return success(); |
554 | |
  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
557 | return success(); |
558 | return failure(); |
559 | } |
560 | |
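// Illustrative sketch (not from the op documentation; the concrete types are
// only assumptions): descriptors built over memref<128x64xf16, 3> (A) and
// memref<64x128xf16, 3> (B) together with vector<128x128xf32> accumulator
// fragments satisfy the shape, N-size, and data-type checks below.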
561 | LogicalResult WarpgroupMmaOp::verify() { |
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports non-transpose A (Row Major) "
              "and transpose B (Column Major) for the time being";
566 | MemRefType matrixA = getDescriptorA().getType().getTensor(); |
567 | MemRefType matrixB = getDescriptorB().getType().getTensor(); |
568 | VectorType matrixC = getMatrixC().getType().getFragmented(); |
569 | VectorType matrixD = getMatrixD().getType().getFragmented(); |
570 | |
571 | if (matrixC != matrixD) |
    return emitOpError() << "type of matrix C and matrix D must be the same";
573 | |
  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "matrices A, B, C, and D must all be 2-dimensional";
  }
579 | |
  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";
592 | |
593 | if (failed(isAllowedWGMMADataType(matrixC.getElementType(), |
594 | matrixA.getElementType(), |
595 | matrixB.getElementType()))) |
596 | return emitOpError() << matrixC.getElementType() |
597 | << " += " << matrixA.getElementType() << " * " |
598 | << matrixB.getElementType() |
599 | << ", it is not supported." ; |
600 | // Check N |
601 | if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) { |
    return emitOpError() << "has input type " << matrixB << " with N set to "
                         << matrixB.getDimSize(1) << ", which is not supported";
604 | } |
605 | |
606 | // Currently, f16/bf16 supported |
607 | if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() && |
608 | !matrixA.getElementType().isBF16()) { |
609 | return emitOpError() << "hit a limitation: " << matrixC.getElementType() |
610 | << " += " << matrixA.getElementType() << " * " |
611 | << matrixB.getElementType() |
612 | << ", it is not supported yet" ; |
613 | } |
614 | |
615 | return success(); |
616 | } |
617 | |
618 | LogicalResult WarpgroupMmaStoreOp::verify() { |
619 | MemRefType dstMemrefType = getDstMemref().getType(); |
620 | VectorType vtype = getMatrixD().getType().getFragmented(); |
621 | |
622 | // Limitation |
623 | if (!vtype.getElementType().isF32()) { |
624 | return emitOpError() |
625 | << "hit a limitation: only f32 results for the time being" ; |
626 | } |
627 | if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) || |
628 | vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) { |
    return emitOpError() << "result vector " << vtype
                         << " does not match the shape of destination memref ["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1) << "]";
634 | } |
635 | return success(); |
636 | } |
637 | |
638 | //===----------------------------------------------------------------------===// |
639 | // WarpgroupMmaInitAccumulatorOp |
640 | //===----------------------------------------------------------------------===// |
641 | |
642 | LogicalResult WarpgroupMmaInitAccumulatorOp::verify() { |
643 | |
644 | nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType(); |
645 | int64_t sizeM = accType.getFragmented().getDimSize(0); |
646 | int64_t sizeN = accType.getFragmented().getDimSize(1); |
647 | Type elemType = accType.getFragmented().getElementType(); |
648 | |
649 | if (failed(isAllowedSizeM(sizeM)) || |
650 | failed(isAllowedSizeN(sizeN, elemType))) { |
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
655 | } |
656 | return success(); |
657 | } |
658 | |
659 | //===----------------------------------------------------------------------===// |
660 | // RcpOp |
661 | //===----------------------------------------------------------------------===// |
662 | |
663 | LogicalResult RcpOp::verify() { |
664 | RcpRoundingModeAttr rounding = getRoundingAttr(); |
665 | bool ftz = getFtz(); |
  // Currently, only `rcp_approx` with `ftz` is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation. " << rounding
                         << " or non-ftz is not supported yet.";
670 | } |
671 | return success(); |
672 | } |
673 | |
674 | //===----------------------------------------------------------------------===// |
675 | // TableGen'd dialect, type, and op definitions |
676 | //===----------------------------------------------------------------------===// |
677 | |
678 | #define GET_ATTRDEF_CLASSES |
679 | #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" |
680 | |
681 | #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc" |
682 | |
683 | #define GET_OP_CLASSES |
684 | #include "mlir/Dialect/NVGPU/IR/NVGPUOps.cpp.inc" |
685 | |
686 | #define GET_TYPEDEF_CLASSES |
687 | #include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc" |
688 | |