//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

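/// Returns true if the operand plays the accumulator/result (C) role in the
/// matmul.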
static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
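/// For example, a 16x16 f16 operand with a 128-bit line size occupies
/// (16 / 8) * (16 * 16) / 128 = 4 registers of 32 bits each per thread.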
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
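/// For example, a 16x16 f16 operand with 128-bit lines decomposes into a
/// {16 / 8, (16 * 16) / 128} = {2, 2} grid of tiles.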
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}

/// Returns the first user of `op` that is a `vector.contract` op. If no such
/// user exists, returns failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}
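
// A minimal usage sketch (hypothetical caller code, not part of this file):
// given an `Operation *readOp` whose result feeds a `vector.contract`, its
// operand role can be recovered as follows:
//
//   FailureOr<vector::ContractionOp> user = nvgpu::getUserContract(readOp);
//   if (succeeded(user) && user->getLhs() == readOp->getResult(0))
//     ...; // `readOp` supplies the A operand.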

FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at warp-level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result (C)
  // unless it feeds the LHS or RHS of a user `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}

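/// A "line" is 256 bits wide for 32-bit accumulators, 512 bits for 64-bit
/// accumulators, 256 bits for 64-bit A/B operands, and 128 bits otherwise.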
int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}

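/// Maps the operand's element type and role to the per-thread register layout.
/// For example, f16 operands are held as vector<2xf16> values in 32-bit
/// registers, while non-accumulator f64 operands occupy one full 64-bit
/// register per element.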
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
                               inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
                               64, inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}

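/// Returns an affine map taking the register index implied by
/// `logicalValueId` to the (row, col) element offset of the 8x[128|256|512]bit
/// tile holding that register. Consecutive registers walk the tile grid in
/// column-major order: all row tiles of one tile column are visited before
/// moving to the next column.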
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}

FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
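  // The resulting map is
  //   (laneId, logicalValueId) ->
  //       (tileRow + laneId floordiv 4,
  //        tileCol + (laneId mod 4) * elementsPerRegister
  //                + logicalValueId mod elementsPerRegister),
  // so each group of 4 consecutive lanes shares a row and covers adjacent
  // register-sized chunks of that row.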
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}

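/// Collects the parameters for a warp-wide LdMatrixOp of the given operand:
/// the target register layout (row-major for A/C, col-major for B), whether
/// the contiguous memory dimension corresponds to the parallel or reduction
/// iterator of the fragment, and the number of 8x128b tiles to load.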
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}

FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` into the vector type's `operandShape` selects the dimension
  // that maps to the strided dimension of the LdMatrixOp's `srcMemref`.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // The affine expressions for the strided and contiguous dimensions encode
  // the coordinate of the element each thread points to in the warp-wide
  // LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);
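  // For example, for a 16x16 f16 fragment (kElementsPer128b = 8) in the
  // non-transposed (reduction) case below, lane l maps to element
  // (l mod 16, (l floordiv 16) * 8), i.e. one 128-bit row chunk per lane.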

  // This case corresponds to row-major matrixA or col-major matrixB or
  // row-major matrixC. This is when the memory layout in `srcMemref`
  // matches the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA or row-major matrixB or
  // col-major matrixC. This is when the memory layout in `srcMemref` does not
  // match the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}

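/// A `vector.transfer_read` can only be lowered to a warp-level matrix
/// operation if it is unmasked, in-bounds, produces a static 2-D vector, and
/// reads from a memref whose innermost dimension is contiguous.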
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure canonical
  // 2D form is used when targeting the `nvgpu.mma.sync` lowering path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}

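/// A `vector.transfer_write` can only be lowered to a warp-level matrix
/// operation if it is unmasked, in-bounds, has a non-zero transfer rank and a
/// minor-identity (non-transposing) permutation map, writes a static 2-D
/// vector, and targets a memref whose innermost dimension is contiguous.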
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We
  // could support the transposed write case by lowering to scalarized
  // `memref.store` operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
