1//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
9
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
12#include "mlir/Dialect/Vector/IR/VectorOps.h"
13
14using namespace mlir;
15using namespace mlir::nvgpu;
16
/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
/// Each operand tile is 8 rows tall; shape[0] is divided by this when
/// counting tiles and per-thread registers below.
static constexpr int64_t kNumRowsPerTile = 8;
20
21static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
22 return operandType == MatMulOperandRole::C;
23}
24
25/// Returns the number of registers which compose a matrix fragment held by a
26/// single thread.
27static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
28 int64_t lineSize = inferTileWidthInBits(type);
29 auto shape = type.vectorType.getShape();
30 return (shape[0] / kNumRowsPerTile) *
31 (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
32 lineSize;
33}
34
35/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
36/// operand shape.
37static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
38 Type elementType,
39 int64_t lineSizeBits) {
40 // For each 8x128bit square, a thread is responsible for one 32bit register.
41 return {operandShape[0] / kNumRowsPerTile,
42 (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
43 lineSizeBits};
44}
45
46/// Returns the first user of the `op` that is vector.contract. If no
47/// vector.contract user exists, return failure.
48FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
49 for (Operation *user : op->getUsers()) {
50 if (auto contractOp = dyn_cast<vector::ContractionOp>(Val: user))
51 return contractOp;
52 }
53 return failure();
54}
55
56FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
57 WarpMatrixInfo info;
58
59 // Determine the vector type at warp-level.
60 if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(Val: op)) {
61 info.vectorType = writeOp.getVectorType();
62 } else if (isa<vector::TransferReadOp, vector::ContractionOp,
63 vector::ExtractStridedSliceOp, arith::ConstantOp>(Val: op)) {
64 info.vectorType = cast<VectorType>(Val: op->getResult(idx: 0).getType());
65 } else {
66 return op->emitError()
67 << "unhandled operation type in nvgpu.mma.sync conversion path";
68 }
69
70 // Determine the operand role. We assume it is an accumulator/result unless it
71 // is directly consumed by a `vector.contract` op.
72 info.operandRole = MatMulOperandRole::C;
73 FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
74 if (failed(Result: contractOp))
75 return info;
76
77 if ((*contractOp).getLhs() == op->getResult(idx: 0))
78 info.operandRole = MatMulOperandRole::A;
79 else if ((*contractOp).getRhs() == op->getResult(idx: 0))
80 info.operandRole = MatMulOperandRole::B;
81
82 return info;
83}
84
85int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
86 bool isAcc = isAccumulatorOrResult(operandType: type.operandRole);
87 Type elType = type.vectorType.getElementType();
88 if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
89 return 256;
90 }
91 if (elType.getIntOrFloatBitWidth() == 64) {
92 return isAcc ? 512 : 256;
93 }
94 return 128;
95}
96
97FailureOr<FragmentElementInfo>
98nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
99 MLIRContext *ctx = type.vectorType.getContext();
100 const bool isAccum = isAccumulatorOrResult(operandType: type.operandRole);
101
102 Type elType = type.vectorType.getElementType();
103 if (elType.isF16()) {
104 return FragmentElementInfo{.registerLLVMType: VectorType::get(shape: 2, elementType: Float16Type::get(context: ctx)), .elementsPerRegister: 2, .registerWidthBits: 32,
105 .numRegistersPerFragment: inferNumRegistersPerMatrixFragment(type)};
106 }
107
108 // f64 operand
109 Type f64Ty = Float64Type::get(context: ctx);
110 if (elType.isF64()) {
111 return isAccum
112 ? FragmentElementInfo{.registerLLVMType: VectorType::get(shape: 2, elementType: f64Ty), .elementsPerRegister: 2, .registerWidthBits: 128,
113 .numRegistersPerFragment: inferNumRegistersPerMatrixFragment(type)}
114 : FragmentElementInfo{.registerLLVMType: f64Ty, .elementsPerRegister: 1, .registerWidthBits: 64,
115 .numRegistersPerFragment: inferNumRegistersPerMatrixFragment(type)};
116 }
117
118 // int8 operand
119 if (elType.isInteger(width: 8)) {
120 return FragmentElementInfo{.registerLLVMType: VectorType::get(shape: 4, elementType: IntegerType::get(context: ctx, width: 8)), .elementsPerRegister: 4,
121 .registerWidthBits: 32, .numRegistersPerFragment: inferNumRegistersPerMatrixFragment(type)};
122 }
123
124 // int4 operand
125 if (elType.isInteger(width: 4)) {
126 return FragmentElementInfo{.registerLLVMType: VectorType::get(shape: 8, elementType: IntegerType::get(context: ctx, width: 4)), .elementsPerRegister: 8,
127 .registerWidthBits: 32, .numRegistersPerFragment: inferNumRegistersPerMatrixFragment(type)};
128 }
129
130 // Integer 32bit acc operands
131 if (elType.isInteger(width: 32)) {
132 return FragmentElementInfo{.registerLLVMType: VectorType::get(shape: 2, elementType: IntegerType::get(context: ctx, width: 32)), .elementsPerRegister: 2,
133 .registerWidthBits: 64, .numRegistersPerFragment: inferNumRegistersPerMatrixFragment(type)};
134 }
135
136 // Floating point 32bit operands
137 if (elType.isF32()) {
138 Type f32Ty = Float32Type::get(context: ctx);
139 return isAccum
140 ? FragmentElementInfo{.registerLLVMType: VectorType::get(shape: 2, elementType: f32Ty), .elementsPerRegister: 2, .registerWidthBits: 64,
141 .numRegistersPerFragment: inferNumRegistersPerMatrixFragment(type)}
142 : FragmentElementInfo{.registerLLVMType: f32Ty, .elementsPerRegister: 1, .registerWidthBits: 32,
143 .numRegistersPerFragment: inferNumRegistersPerMatrixFragment(type)};
144 }
145 return failure();
146}
147
148static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
149 Type elementType,
150 ArrayRef<int64_t> operandShape,
151 bool isAccumulator,
152 int64_t elementsPerRegister,
153 AffineExpr logicalValueId) {
154 const int64_t elementsPerLine =
155 lineSize / elementType.getIntOrFloatBitWidth();
156 const std::array<int64_t, 2> num8x128bTiles =
157 getTileShape(operandShape, elementType, lineSizeBits: lineSize);
158 AffineExpr registerIdx = logicalValueId.floorDiv(v: elementsPerRegister);
159 return AffineMap::get(
160 dimCount: 2, symbolCount: 0,
161 results: {(registerIdx % num8x128bTiles[0]) * 8,
162 (registerIdx.floorDiv(v: num8x128bTiles[0])) * elementsPerLine},
163 context: elementType.getContext());
164}
165
/// Returns an affine map taking (laneId, logicalValueId) to the (row, col)
/// coordinate of the corresponding element within the warp-level operand.
/// Fails when the fragment's element type has no mma.sync register layout.
FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  // Per-thread register layout; unsupported element types propagate failure.
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  // Number of elements packed into a single register (e.g. 2 f16 per 32 bits).
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  // Within a tile: kThreadsPerRow (4) threads share a row, each owning
  // `elementsPerRegister` consecutive elements; the element within a register
  // is selected by `logicalValueIdDim % elementsPerRegister`.
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
201
202FailureOr<nvgpu::LdMatrixParams>
203nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
204 LdMatrixParams params;
205 Type elType = type.vectorType.getElementType();
206 params.fragmentType = type.vectorType;
207 if (type.operandRole == MatMulOperandRole::A ||
208 type.operandRole == MatMulOperandRole::C) {
209 params.targetLayout = NVVM::MMALayout::row;
210 } else {
211 params.targetLayout = NVVM::MMALayout::col;
212 }
213 ArrayRef<int64_t> shape = type.vectorType.getShape();
214 params.contiguousDimType = transpose ? vector::IteratorType::parallel
215 : vector::IteratorType::reduction;
216
217 if (params.contiguousDimType == vector::IteratorType::reduction) {
218 params.numTiles = (shape[0] / kNumRowsPerTile) *
219 ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
220 } else {
221 params.numTiles = (shape[1] / kNumRowsPerTile) *
222 ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
223 }
224
225 if (params.numTiles == 0)
226 return failure();
227
228 return params;
229}
230
231FailureOr<AffineMap>
232nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
233 const LdMatrixParams &params) {
234 // One thread per 128b row.
235 const int bitsPerElement = static_cast<int>(
236 params.fragmentType.getElementType().getIntOrFloatBitWidth());
237 const int kElementsPer128b = (128 / bitsPerElement);
238 ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
239 AffineExpr d0 = getAffineDimExpr(position: 0, context: builder.getContext());
240
241 auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
242 return AffineMap::get(dimCount: 1, symbolCount: 0, results: dimExprs, context: builder.getContext());
243 };
244
245 // Index `idx` in vectorType `operandShape` maps to the strided dimension of
246 // the `srcMemref` memory of the LdMatrixOp.
247 int idx =
248 (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
249
250 // Affine expr in strided and contiguous dimension encodes the coordinate
251 // mapping for the element a thread points to for warp-wide LdMatrixOp.
252 AffineExpr strided = d0 % (operandShape[idx]);
253 AffineExpr contiguous = d0.floorDiv(v: operandShape[idx]) * (kElementsPer128b);
254
255 // This case corresponds to row-major matrixA or col-major matrixB or
256 // row-major matrixC. This is when the memory layout in `srcMemref`
257 // match mma.sync hardware vector register operand layout.
258 if (params.contiguousDimType == vector::IteratorType::reduction)
259 return makeMap({strided, contiguous});
260
261 // This case corresponds to col-major matrixA or row-major matrixB or
262 // col-major matrixC. This is when the memory layout in `srcMemref` does not
263 // match mma.sync hardware vector register operand layout.
264 if (params.contiguousDimType == vector::IteratorType::parallel)
265 return makeMap({contiguous, strided});
266
267 return failure();
268}
269
270bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
271 if (op.getMask() || op.hasOutOfBoundsDim())
272 return false;
273 VectorType type = op.getType();
274 // The result type should be 2D. Note that it is possible to expand support so
275 // that we are robust to extra unit dimensions that failed to fold, but that
276 // would significantly increase downstream code complexity in the conversion
277 // step. For now, we rely on other patterns to ensure canonical 2D form is
278 // used when targeting the `nvgpu.mma.sync` lowering path.
279 if (!type.hasStaticShape() || type.getRank() != 2)
280 return false;
281
282 // Currently we can't support reads on tensor types because we need stride
283 // information to ensure correctness of downstream assumptions. It is possible
284 // to enable this if caller can assert that tensor will be lowered in a
285 // particular manner.
286 auto sourceType = dyn_cast<MemRefType>(Val: op.getBase().getType());
287 if (!sourceType)
288 return false;
289
290 // Check that the last dimension of the read is contiguous. Note that it is
291 // possible to expand support for this by scalarizing all the loads during
292 // conversion.
293 auto [strides, offset] = sourceType.getStridesAndOffset();
294 return strides.back() == 1;
295}
296
297bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
298 if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
299 return false;
300 VectorType type = op.getVectorType();
301 if (!type.hasStaticShape() || type.getRank() != 2)
302 return false;
303 // TODO: Currently we rely on lowering to a `vector.store` operation. We could
304 // support the transposed write case by lowering to scalarized `memref.store`
305 // operations.
306 if (!op.getPermutationMap().isMinorIdentity())
307 return false;
308 // Currently we can't support reads on tensor types because we need stride
309 // information to ensure correctness of downstream assumptions.
310 auto sourceType = dyn_cast<MemRefType>(Val: op.getBase().getType());
311 if (!sourceType)
312 return false;
313
314 // Check that the last dimension of the target memref is contiguous. Note that
315 // it is possible to expand support for this by scalarizing all the stores
316 // during conversion.
317 auto [strides, offset] = sourceType.getStridesAndOffset();
318 return strides.back() == 1;
319}
320

source code of mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp