1 | //===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the NVGPU dialect and its operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/BuiltinAttributes.h" |
18 | #include "mlir/IR/BuiltinTypes.h" |
19 | #include "mlir/IR/Diagnostics.h" |
20 | #include "mlir/IR/DialectImplementation.h" |
21 | #include "mlir/IR/Matchers.h" |
22 | #include "mlir/IR/OpImplementation.h" |
23 | #include "mlir/IR/PatternMatch.h" |
24 | #include "mlir/IR/TypeUtilities.h" |
25 | #include "mlir/IR/Verifier.h" |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/ADT/StringExtras.h" |
28 | #include "llvm/ADT/TypeSwitch.h" |
29 | |
30 | using namespace mlir; |
31 | using namespace mlir::nvgpu; |
32 | |
33 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc" |
34 | |
35 | void nvgpu::NVGPUDialect::initialize() { |
36 | addTypes< |
37 | #define GET_TYPEDEF_LIST |
38 | #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc" |
39 | >(); |
40 | addAttributes< |
41 | #define GET_ATTRDEF_LIST |
42 | #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" |
43 | >(); |
44 | addOperations< |
45 | #define GET_OP_LIST |
46 | #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc" |
47 | >(); |
48 | } |
49 | |
50 | bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { |
51 | if (!memorySpace) |
52 | return false; |
53 | if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace)) |
54 | return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace; |
55 | if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace)) |
56 | return gpuAttr.getValue() == gpu::AddressSpace::Workgroup; |
57 | return false; |
58 | } |
59 | |
60 | bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { |
61 | Attribute memorySpace = type.getMemorySpace(); |
62 | return isSharedMemoryAddressSpace(memorySpace); |
63 | } |
64 | |
65 | //===----------------------------------------------------------------------===// |
66 | // NVGPU_DeviceAsyncCopyOp |
67 | //===----------------------------------------------------------------------===// |
68 | |
69 | LogicalResult DeviceAsyncCopyOp::verify() { |
70 | auto srcMemref = llvm::cast<MemRefType>(getSrc().getType()); |
71 | auto dstMemref = llvm::cast<MemRefType>(getDst().getType()); |
72 | |
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
77 | if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) |
78 | return emitError() |
79 | << "destination memref must have a memory space attribute of " |
80 | "IntegerAttr(" |
81 | << NVGPUDialect::kSharedMemoryAddressSpace |
82 | << ") or gpu::AddressSpaceAttr(Workgroup)" ; |
83 | if (dstMemref.getElementType() != srcMemref.getElementType()) |
84 | return emitError("source and destination must have the same element type" ); |
85 | if (size_t(srcMemref.getRank()) != getSrcIndices().size()) |
86 | return emitOpError() << "expected " << srcMemref.getRank() |
87 | << " source indices, got " << getSrcIndices().size(); |
88 | if (size_t(dstMemref.getRank()) != getDstIndices().size()) |
89 | return emitOpError() << "expected " << dstMemref.getRank() |
90 | << " destination indices, got " |
91 | << getDstIndices().size(); |
92 | int64_t dstElements = getDstElements().getZExtValue(); |
93 | int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8; |
94 | if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) { |
95 | unsigned dstWidth = dstMemref.getElementTypeBitWidth(); |
96 | InFlightDiagnostic diag = emitError(); |
97 | diag << "Requested copy elements is " << dstElements << " with width " |
98 | << dstMemref.getElementTypeBitWidth() |
99 | << ". But copy elements could be one of " ; |
100 | if ((32 / dstWidth) > 0) |
101 | diag << (32 / dstWidth) << ", " ; |
102 | if ((64 / dstWidth) > 0) |
103 | diag << (64 / dstWidth) << ", " ; |
104 | if ((128 / dstWidth) > 0) |
105 | diag << (128 / dstWidth) << "." ; |
106 | return diag; |
107 | } |
108 | if (getBypassL1().has_value()) { |
109 | int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth(); |
110 | if (getBypassL1().value() && sizeInBytes != 16) { |
111 | return emitOpError() << "bypassL1 does not satify alignment for " |
112 | << dstMemref << " with destination element " |
113 | << dstElements |
114 | << ". Unset bypassL1, or set " |
115 | "destination element to " |
116 | << req; |
117 | } |
118 | } |
119 | return success(); |
120 | } |
121 | |
122 | //===----------------------------------------------------------------------===// |
123 | // NVGPU_MmaSyncOp |
124 | //===----------------------------------------------------------------------===// |
125 | void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, |
126 | ::mlir::OperationState &odsState, Value matrixA, |
127 | Value matrixB, Value matrixC, ArrayAttr mmaShape) { |
128 | build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, |
129 | mmaShape, UnitAttr()); |
130 | } |
131 | |
132 | void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, |
133 | ::mlir::OperationState &odsState, Value matrixA, |
134 | Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape, |
135 | bool tf32Enabled) { |
136 | build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, |
137 | odsBuilder.getI64ArrayAttr(mmaShape), |
138 | tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr()); |
139 | } |
140 | |
141 | /// Performs verification for MmaSyncOp and MmaSparseSyncOp. |
142 | static LogicalResult verifyMmaSyncOp(Operation *op, |
143 | TypedValue<VectorType> matrixA, |
144 | TypedValue<VectorType> matrixB, |
145 | TypedValue<VectorType> matrixC, |
146 | const std::array<int64_t, 3> &mmaShape, |
147 | bool tf32Enabled, bool sparse = false) { |
148 | |
149 | // The verification for mma.sync covering various shapes and data types is |
150 | // based on the fundamental tensor core shape. |
151 | |
152 | // "Fundamental" tensor core shapes: |
153 | // - For F32 (TF32), F16, S8, and S4 data |
154 | // types the fundamental tensor core operation is of shape 8-by-8-by-128b. |
155 | // - F64 is an exception and is of shape 8-by-8-by-256b. |
156 | int64_t shapeM = 8; |
157 | int64_t shapeN = 8; |
158 | int64_t shapeK; // set based on data type (128b for all data types except F64) |
159 | |
160 | // Number of elements A, B, and C per thread per fundamental tensor core tile |
161 | int64_t numElementA; // set based on data type (32b except F64) |
162 | int64_t numElementB; // set based on data type (32b except F64) |
163 | int64_t numElementC{2}; // two accumulator elements per fundamental tile |
164 | |
165 | // nvgpu.mma.sync vector operands (per thread) |
166 | auto aVector = matrixA.getType(); |
167 | auto bVector = matrixB.getType(); |
168 | auto cVector = matrixC.getType(); |
169 | |
170 | // vector shapes |
171 | ArrayRef<int64_t> aShape = aVector.getShape(); |
172 | ArrayRef<int64_t> bShape = bVector.getShape(); |
173 | ArrayRef<int64_t> cShape = cVector.getShape(); |
174 | |
175 | // vector element type |
176 | Type aType = aVector.getElementType(); |
177 | |
178 | // Certain data types are not allowed in sparse mode. |
179 | if (sparse && aType.isF64()) |
180 | return op->emitError() << "f64 is not supported for sparse mode" ; |
181 | |
182 | if (aType.isF64()) { |
183 | // exception to 8-by-8-128b fundamental tensor core tile size |
184 | shapeK = 4; |
185 | numElementA = 1; |
186 | numElementB = 1; |
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
189 | // 8-by-8-128b fundamental tensor core tile size |
190 | int operandBitwidth = aType.getIntOrFloatBitWidth(); |
191 | shapeK = 128 / operandBitwidth; // 128b wide shapeK |
192 | |
193 | numElementA = 32 / operandBitwidth; // 32b wide operand A |
194 | numElementB = 32 / operandBitwidth; // 32b wide operand B |
195 | } else { |
196 | return op->emitError() |
197 | << "expected input data type (i4,i8,f16,bf16,tf32,f64) " |
198 | "supported by " |
199 | << op->getName(); |
200 | } |
201 | |
202 | // |
203 | // Basic verification |
204 | // |
205 | |
206 | auto [m, n, k] = mmaShape; |
207 | |
208 | // verify warp-wide size for vector a |
209 | int64_t sparseFactor = sparse ? 2 : 1; |
210 | if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor) |
    return op->emitOpError() << "expected " << m * k / sparseFactor
                             << " warp-wide matrix A elements";
213 | |
214 | // verify warp-wide size for vector b |
215 | if (bShape[0] * bShape[1] * kWarpSize != k * n) |
216 | return op->emitOpError() |
217 | << "expected " << k * n << " warp-wide matrix B elements" ; |
218 | |
219 | // verify warp-wide size for vector c |
220 | if (cShape[0] * cShape[1] * kWarpSize != m * n) |
221 | return op->emitOpError() |
222 | << "expected " << m * n << " warp-wide matrix C elements" ; |
223 | |
224 | // verify tf32 tensor cores are enabled for only F32 datatype |
225 | if (tf32Enabled && !(aType.isF32())) |
226 | return op->emitOpError() |
227 | << "expected tf32 tensor cores only for F32 operands" ; |
228 | |
229 | // |
230 | // Extended verification |
231 | // |
232 | |
233 | // tiles of fundamental tensor core operations |
234 | int64_t mTile = m / shapeM; |
235 | int64_t nTile = n / shapeN; |
236 | int64_t kTile = k / shapeK; |
237 | |
238 | // verify shape of aVector |
  if ((aShape[0] != mTile * kTile / sparseFactor) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile / sparseFactor << " x "
                             << numElementA << ")";
243 | |
244 | // verify shape of bVector |
245 | if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB)) |
246 | return op->emitOpError() << "expected matrix B to be shaped (" |
                             << kTile * nTile << " x " << numElementB << ")";
248 | |
249 | // verify shape of cVector |
250 | if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC)) |
251 | return op->emitOpError() << "expected matrix C to be shaped (" |
                             << mTile * nTile << " x " << numElementC << ")";
253 | |
254 | return success(); |
255 | } |
256 | |
257 | LogicalResult MmaSyncOp::verify() { |
258 | return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), |
259 | getMatrixC(), getMmaShapeAsArray(), |
260 | getOperation()->hasAttr(getTf32EnabledAttrName())); |
261 | } |
262 | |
263 | //===----------------------------------------------------------------------===// |
264 | // NVGPU_MmaSparseSyncOp |
265 | //===----------------------------------------------------------------------===// |
266 | void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder, |
267 | ::mlir::OperationState &odsState, Value matrixA, |
268 | Value matrixB, Value matrixC, Value sparseMetadata, |
269 | ArrayRef<int64_t> mmaShape) { |
270 | build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, |
271 | sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr()); |
272 | } |
273 | |
274 | LogicalResult MmaSparseSyncOp::verify() { |
275 | unsigned sparsitySelector = getSparsitySelector(); |
276 | if (sparsitySelector > 1) |
277 | return emitOpError() << "sparsity selector should be 0 or 1" ; |
278 | return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), |
279 | getMatrixC(), getMmaShapeAsArray(), |
280 | getOperation()->hasAttr(getTf32EnabledAttrName()), |
281 | true); |
282 | } |
283 | |
284 | //===----------------------------------------------------------------------===// |
285 | // NVGPU_LdMatrixOp |
286 | //===----------------------------------------------------------------------===// |
287 | LogicalResult LdMatrixOp::verify() { |
288 | |
289 | // ldmatrix reads data from source in shared memory |
290 | auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType()); |
291 | |
292 | // ldmatrix writes data to result/destination in vector registers |
293 | auto resVector = llvm::cast<VectorType>(getRes().getType()); |
294 | |
295 | // vector register shape, element type, and bitwidth |
296 | ArrayRef<int64_t> resShape = resVector.getShape(); |
297 | Type resType = resVector.getElementType(); |
298 | int64_t elementBitWidth = resType.getIntOrFloatBitWidth(); |
299 | |
300 | // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread |
301 | int64_t numElementsPer32b = 32 / elementBitWidth; |
302 | |
303 | // number of 8-by-8 tiles |
304 | int64_t numTiles = getNumTiles(); |
305 | |
306 | // transpose elements in vector registers at 16b granularity when true |
307 | bool isTranspose = getTranspose(); |
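
  // Worked example (illustrative only; SSA names are hypothetical): loading
  // numTiles = 4 tiles of f16 data gives numElementsPer32b = 32 / 16 = 2, so
  // the checks below accept a result of type vector<4x2xf16>, e.g.
  //
  //   %r = nvgpu.ldmatrix %sm[%i, %j] {numTiles = 4 : i32, transpose = false}
  //     : memref<128x128xf16, 3> -> vector<4x2xf16>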
308 | |
309 | // |
310 | // verification |
311 | // |
312 | |
313 | if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref)) |
    return emitError()
           << "nvgpu.ldmatrix srcMemref must have a memory space attribute "
              "of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
319 | if (elementBitWidth > 32) |
320 | return emitError() << "nvgpu.ldmatrix works for 32b or lower" ; |
  if (isTranspose && elementBitWidth != 16)
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "result must be a 2-dimensional vector";
  }
  if (resShape[1] != numElementsPer32b)
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (resShape[0] != numTiles)
    return emitError()
           << "expected vector register shape[0] and numTiles to match";
333 | |
334 | return success(); |
335 | } |
336 | |
337 | //===----------------------------------------------------------------------===// |
338 | // NVGPU_TmaAsyncLoadOp |
339 | //===----------------------------------------------------------------------===// |
340 | |
341 | std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( |
342 | Operation *op, nvgpu::TensorMapDescriptorType descType, |
343 | std::optional<MemRefType> memrefType = std::nullopt) { |
344 | MemRefType descMemref = descType.getTensor(); |
345 | // Limitation |
346 | if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE) |
347 | return op->emitError() << "Interleave options are not supported yet." ; |
348 | |
  // The tensor map descriptor must describe shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has incorrect address "
                              "space, it must be shared memory address space.";
353 | } |
354 | // Support only static shape for the time being |
355 | if (!descMemref.hasStaticShape()) |
356 | return op->emitError() << "the tensor map descriptor must be static shaped" ; |
357 | |
358 | for (auto dim : descMemref.getShape()) { |
359 | if (dim <= 0 || dim > kMaxTMADimension) { |
360 | return op->emitError() << "the tensor map descriptor must have " |
361 | "dimensions between 1 and " |
362 | << kMaxTMADimension << " but it is " << dim; |
363 | } |
364 | } |
365 | if (descMemref.getRank() > 1 && |
366 | descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) { |
367 | unsigned lastDimensionByte = |
368 | descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8; |
369 | if (lastDimensionByte != kMaxTMALastdimByte) |
370 | return op->emitError() << "the tensormap descriptor must have last " |
371 | "dimension of " |
372 | << kMaxTMALastdimByte << " bytes but it is " |
373 | << lastDimensionByte << " bytes" ; |
374 | } |
375 | |
376 | // No verification if memref type is not provided |
377 | if (!memrefType.has_value()) |
378 | return std::nullopt; |
379 | |
380 | MemRefType dstMemref = memrefType.value(); |
381 | |
382 | // Check element type |
383 | if (descMemref.getElementType() != dstMemref.getElementType()) { |
384 | return op->emitError() << "the element type of tensor map descriptor and " |
385 | "memref must be same" ; |
386 | } |
387 | |
388 | if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) { |
389 | return op->emitError() << "the destination memref has incorrect address " |
390 | "space, it must be shared memory address space." ; |
391 | } |
392 | if (!dstMemref.hasStaticShape()) |
393 | return op->emitError() << "the destination memref must be static shaped" ; |
394 | |
395 | if (dstMemref.getRank() != descMemref.getRank()) { |
396 | return op->emitError() << "the shape of tensor map descriptor and " |
397 | "memref must have same rank" ; |
398 | } |
399 | if (!descMemref.getShape().equals(dstMemref.getShape())) { |
400 | return op->emitError() << "memref and tensor map shapes mismatch " |
401 | << descMemref << " != " << dstMemref; |
402 | } |
403 | |
404 | return std::nullopt; |
405 | } |
406 | |
407 | LogicalResult TmaAsyncLoadOp::verify() { |
408 | std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref( |
409 | *this, getTensorMapDescriptor().getType(), getDst().getType()); |
410 | if (error.has_value()) |
411 | return error.value(); |
412 | |
413 | if (getCoordinates().size() > kMaxTMATensorDimension) { |
414 | return emitError() << "Maximum " << kMaxTMATensorDimension |
415 | << " coordinates are supported." ; |
416 | } |
417 | if (getCoordinates().size() != |
418 | size_t(getTensorMapDescriptor().getType().getTensor().getRank())) { |
419 | return emitError() << "number of coordinates do not match with the rank of " |
420 | "tensor descriptor map." ; |
421 | } |
422 | |
423 | return success(); |
424 | } |
425 | |
426 | //===----------------------------------------------------------------------===// |
427 | // NVGPU_TmaAsyncStoreOp |
428 | //===----------------------------------------------------------------------===// |
429 | |
430 | LogicalResult TmaAsyncStoreOp::verify() { |
431 | std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref( |
432 | *this, getTensorMapDescriptor().getType(), getSrc().getType()); |
433 | if (error.has_value()) |
434 | return error.value(); |
435 | |
436 | if (getCoordinates().size() > kMaxTMATensorDimension) { |
437 | return emitError() << "Maximum " << kMaxTMATensorDimension |
438 | << " coordinates are supported." ; |
439 | } |
440 | if (getCoordinates().size() != |
441 | size_t(getTensorMapDescriptor().getType().getTensor().getRank())) { |
442 | return emitError() << "number of coordinates do not match with the rank of " |
443 | "tensor descriptor map." ; |
444 | } |
445 | |
446 | return success(); |
447 | } |
448 | |
449 | LogicalResult TmaCreateDescriptorOp::verify() { |
450 | if (getBoxDimensions().size() > kMaxTMATensorDimension) { |
451 | return emitError() << "Maximum " << kMaxTMATensorDimension |
452 | << " coordinates are supported." ; |
453 | } |
454 | |
455 | std::optional<InFlightDiagnostic> error = |
456 | verifyTmaDescriptorWithMemref(*this, getTensorMap().getType()); |
457 | if (error.has_value()) |
458 | return error.value(); |
459 | |
460 | return success(); |
461 | } |
462 | |
463 | //===----------------------------------------------------------------------===// |
464 | // NVGPU_WarpgroupGenerateDescriptorOp |
465 | //===----------------------------------------------------------------------===// |
466 | |
467 | LogicalResult WarpgroupGenerateDescriptorOp::verify() { |
468 | std::optional<InFlightDiagnostic> error = |
469 | verifyTmaDescriptorWithMemref(*this, getTensorMap().getType()); |
470 | if (error.has_value()) |
471 | return error.value(); |
472 | |
473 | if (getTensorMap().getType().getSwizzle() != |
474 | TensorMapSwizzleKind::SWIZZLE_128B) { |
475 | return emitError() << "supports only " |
476 | << stringifyTensorMapSwizzleKind( |
477 | TensorMapSwizzleKind::SWIZZLE_128B) |
478 | << " is supported for the time being" ; |
479 | } |
480 | |
481 | if (getTensorMap().getType().getInterleave() != |
482 | TensorMapInterleaveKind::INTERLEAVE_NONE) { |
483 | return emitError() << "supports only " |
484 | << stringifyTensorMapInterleaveKind( |
485 | TensorMapInterleaveKind::INTERLEAVE_NONE) |
486 | << " is supported for the time being" ; |
487 | } |
488 | |
489 | return success(); |
490 | } |
491 | |
492 | //===----------------------------------------------------------------------===// |
493 | // WarpgroupMmaOp |
494 | //===----------------------------------------------------------------------===// |
495 | |
496 | LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) { |
497 | // F32 += F16 + F16 |
498 | // F16 += F16 + F16 |
499 | if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16())) |
500 | return success(); |
501 | // F32 += TF32 + TF32 |
502 | if (typeA.isTF32() && typeD.isF32() && typeB.isTF32()) |
503 | return success(); |
504 | // s32 += i8 + i8 |
  if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32))
506 | return success(); |
507 | // s32 += i1 + i1 |
  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
509 | return success(); |
510 | // F32 += BF16 + BF16 |
511 | // F16 += BF16 + BF16 |
512 | if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16())) |
513 | return success(); |
514 | // F16 += f8 + f8 |
515 | // F32 += f8 + f8 |
516 | if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) && |
517 | (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) && |
518 | (typeD.isF32() || typeD.isF16())) |
519 | return success(); |
520 | |
521 | return failure(); |
522 | } |
523 | |
524 | LogicalResult isAllowedSizeM(int sizeM) { |
525 | if (sizeM % kWgmmaSizeM) |
526 | return failure(); |
527 | return success(); |
528 | } |
529 | |
530 | LogicalResult isAllowedSizeN(int sizeN, Type typeA) { |
531 | SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64, |
532 | 72, 80, 88, 96, 104, 112, 120, 128, |
533 | 136, 144, 152, 160, 168, 176, 184, 192, |
534 | 200, 208, 216, 224, 232, 240, 248, 256}; |
535 | SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64, |
536 | 80, 96, 112, 128, 144, 160, |
537 | 176, 192, 208, 224, 240, 256}; |
538 | if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() || |
539 | typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2()) |
    if (llvm::is_contained(allowedN, sizeN))
      return success();

  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
545 | return success(); |
546 | return failure(); |
547 | } |
548 | |
549 | LogicalResult WarpgroupMmaOp::verify() { |
550 | if (getTransposeA() && !getTransposeB()) |
551 | return emitOpError() |
552 | << "supports non-transpose A (Row Major) " |
553 | "and transpose B (Column Major) for the time being " ; |
554 | MemRefType matrixA = getDescriptorA().getType().getTensor(); |
555 | MemRefType matrixB = getDescriptorB().getType().getTensor(); |
556 | VectorType matrixC = getMatrixC().getType().getFragmented(); |
557 | VectorType matrixD = getMatrixD().getType().getFragmented(); |
558 | |
559 | if (matrixC != matrixD) |
560 | return emitOpError() << "type of matrix C and matrix D must be the same" ; |
561 | |
562 | if (matrixA.getRank() != 2 || matrixB.getRank() != 2 || |
563 | matrixC.getRank() != 2 || matrixD.getRank() != 2) { |
    return emitOpError()
           << "has matrices A, B, C, and D; they must all be 2 dimensional";
566 | } |
567 | |
  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";
580 | |
581 | if (failed(isAllowedWGMMADataType(matrixC.getElementType(), |
582 | matrixA.getElementType(), |
583 | matrixB.getElementType()))) |
584 | return emitOpError() << matrixC.getElementType() |
585 | << " += " << matrixA.getElementType() << " * " |
586 | << matrixB.getElementType() |
587 | << ", it is not supported." ; |
588 | // Check N |
589 | if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) { |
590 | return emitOpError() << "has input type " << matrixB << " n is set to " |
                         << matrixB.getDimSize(1) << ", it is not supported";
592 | } |
593 | |
594 | // Currently, f16/bf16 supported |
595 | if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() && |
596 | !matrixA.getElementType().isBF16()) { |
597 | return emitOpError() << "hit a limitation: " << matrixC.getElementType() |
598 | << " += " << matrixA.getElementType() << " * " |
599 | << matrixB.getElementType() |
600 | << ", it is not supported yet" ; |
601 | } |
602 | |
603 | return success(); |
604 | } |
605 | |
606 | LogicalResult WarpgroupMmaStoreOp::verify() { |
607 | MemRefType dstMemrefType = getDstMemref().getType(); |
608 | VectorType vtype = getMatrixD().getType().getFragmented(); |
609 | |
610 | // Limitation |
611 | if (!vtype.getElementType().isF32()) { |
612 | return emitOpError() |
613 | << "hit a limitation: only f32 results for the time being" ; |
614 | } |
615 | if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) || |
616 | vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) { |
617 | return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1) |
618 | << "] values. However, destination memref[" |
619 | << dstMemrefType.getDimSize(0) << "][" |
620 | << dstMemrefType.getDimSize(1) |
621 | << "] does not have same size as results" ; |
622 | } |
623 | return success(); |
624 | } |
625 | |
626 | //===----------------------------------------------------------------------===// |
627 | // WarpgroupMmaInitAccumulatorOp |
628 | //===----------------------------------------------------------------------===// |
629 | |
630 | LogicalResult WarpgroupMmaInitAccumulatorOp::verify() { |
631 | |
632 | nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType(); |
633 | int64_t sizeM = accType.getFragmented().getDimSize(0); |
634 | int64_t sizeN = accType.getFragmented().getDimSize(1); |
635 | Type elemType = accType.getFragmented().getElementType(); |
636 | |
637 | if (failed(isAllowedSizeM(sizeM)) || |
638 | failed(isAllowedSizeN(sizeN, elemType))) { |
639 | return emitOpError() << "has type " << accType.getFragmented() |
640 | << ". It does not fit into warp-group " |
641 | "level (wgmma) matrix multiplication instruction " |
642 | "(or not supported yet)" ; |
643 | } |
644 | return success(); |
645 | } |
646 | |
647 | //===----------------------------------------------------------------------===// |
648 | // TableGen'd dialect, type, and op definitions |
649 | //===----------------------------------------------------------------------===// |
650 | |
651 | #define GET_ATTRDEF_CLASSES |
652 | #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" |
653 | |
654 | #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc" |
655 | |
656 | #define GET_OP_CLASSES |
657 | #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc" |
658 | |
659 | #define GET_TYPEDEF_CLASSES |
660 | #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc" |
661 | |