//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}
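
// Worked example (a sketch, assuming a 16x16 f16 "A" operand fragment): the
// tile width is 128 bits, so the fragment occupies
// (16 / 8) * (16 * 16) / 128 = 4 registers per thread, i.e. each of the 32
// lanes holds 8 f16 elements of the fragment.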

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
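
// For example, with a 16x16 f16 operand and 128-bit lines this returns
// {16 / 8, (16 * 16) / 128} = {2, 2}, i.e. two tile rows by two tile columns.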

/// Returns the first user of the `op` that is a vector.contract op. Returns
/// failure if no such user exists.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}

FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at warp-level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // the value is directly consumed by a `vector.contract` op as the LHS or RHS
  // operand.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}
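
// For example, a vector.transfer_read producing vector<16x16xf16> whose only
// user is the LHS of a vector.contract is reported as {vector<16x16xf16>,
// role A}; the accumulator operand of that same contraction keeps the default
// role C.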

int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}
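
// Summary of the cases above: 32-bit accumulators get 256-bit tile lines,
// 64-bit element types get 512-bit (accumulator) or 256-bit (A/B) lines, and
// every other element type (f16, f32 A/B, i8, i4, ...) gets 128-bit lines.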

FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
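
// Worked example (a sketch, assuming a 16x16 f16 "A" operand): each register
// is modeled as vector<2xf16> (2 elements per 32-bit register), and
// inferNumRegistersPerMatrixFragment reports 4 such registers per thread.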

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}
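
// Worked example (assuming a 16x16 f16 operand with 128-bit lines and 2
// elements per register): elementsPerLine = 8, num8x128bTiles = {2, 2}, and
// register r = logicalValueId floordiv 2 maps to the tile offset
// ((r mod 2) * 8, (r floordiv 2) * 8), i.e. registers walk the tile rows
// first and then the tile columns.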

FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
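
// Worked example (a sketch, assuming a 16x16 f16 "A" operand): with
// d0 = laneId and d1 = logicalValueId, the returned map is
//   (d0, d1) -> ((d1 floordiv 2 mod 2) * 8 + d0 floordiv 4,
//                (d1 floordiv 4) * 8 + (d0 mod 4) * 2 + d1 mod 2)
// which should line up with the PTX m16n8k16 f16 A-fragment layout, where
// each group of 4 consecutive lanes owns a 128-bit row segment.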

FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}
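
// Worked example (assuming a non-transposed 16x16 f16 "A" operand): the target
// layout is row-major, the contiguous dimension is the reduction dimension,
// and numTiles = (16 / 8) * ((16 * 16) / 128) = 4, i.e. the whole fragment
// could be covered by a single x4 ldmatrix.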

FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
  // the `srcMemref` memory of the LdMatrixOp.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // Affine expressions in the strided and contiguous dimensions encode the
  // coordinate a thread points to for the warp-wide LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // This case corresponds to row-major matrixA, col-major matrixB, or
  // row-major matrixC. This is when the memory layout in `srcMemref` matches
  // the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA, row-major matrixB, or
  // col-major matrixC. This is when the memory layout in `srcMemref` does not
  // match the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}
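
// Worked example (a sketch, assuming a 16x16 f16 operand whose contiguous
// dimension is the reduction dimension): kElementsPer128b = 8 and
// operandShape[0] = 16, so lane d0 maps to (d0 mod 16, (d0 floordiv 16) * 8):
// lanes 0-15 point at the start of rows 0-15 and lanes 16-31 point 8 elements
// further along the same rows, one 128-bit segment per lane.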

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure the
  // canonical 2D form is used when targeting the `nvgpu.mma.sync` lowering
  // path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
  return strides.back() == 1;
}
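
// In short, a transfer_read qualifies only if it is unmasked and in-bounds,
// produces a statically shaped 2-D vector, reads from a memref (not a
// tensor), and the memref's innermost dimension has unit stride.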

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We
  // could support the transposed write case by lowering to scalarized
  // `memref.store` operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
  return strides.back() == 1;
}
