1//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10// NVVM Dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
15#include "mlir/Conversion/LLVMCommon/Pattern.h"
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19#include "mlir/IR/TypeUtilities.h"
20
21using namespace mlir;
22
23namespace {
24
25/// Checks if all the operands of the op being lowered are of LLVM Types. The
26/// types are expected to be converted by the `LLVMTypeConverter` before the op
27/// is actually lowered. If the type of an operands is not already converted it
28/// hints a missing typeConversion and failure is returned in that case.
29static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
30 ConversionPatternRewriter &rewriter) {
31 if (!llvm::all_of(Range&: operands, P: [](Value value) {
32 return LLVM::isCompatibleType(type: value.getType());
33 })) {
34 return rewriter.notifyMatchFailure(
35 arg&: op, msg: "cannot convert if operands aren't of LLVM type.");
36 }
37
38 return success();
39}
40
41/// Error string to emit when an unimplemented WMMA variant is encountered.
42static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
43
44static NVVM::MMAFrag convertOperand(StringRef operandName) {
45 if (operandName.equals("AOp"))
46 return NVVM::MMAFrag::a;
47 if (operandName.equals("BOp"))
48 return NVVM::MMAFrag::b;
49 if (operandName.equals("COp"))
50 return NVVM::MMAFrag::c;
51 llvm_unreachable("Unknown operand name");
52}
53
54static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
55 if (type.getElementType().isF16())
56 return NVVM::MMATypes::f16;
57 if (type.getElementType().isF32())
58 return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
59 : NVVM::MMATypes::tf32;
60
61 if (type.getElementType().isSignedInteger(8))
62 return NVVM::MMATypes::s8;
63 if (type.getElementType().isUnsignedInteger(8))
64 return NVVM::MMATypes::u8;
65 // Accumulator type is signless and implies signed.
66 if (type.getElementType().isInteger(32))
67 return NVVM::MMATypes::s32;
68 llvm_unreachable("Unsupported type");
69}
70
71/// This class implements the conversion of GPU MMA loadOp to wmma.load op
72/// in the NVVM dialect. The conversion not only emits the NVVM op but also
73/// emits code that is necessary to store the data in the destination memref
74/// after it has been loaded.
75struct WmmaLoadOpToNVVMLowering
76 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
77 using ConvertOpToLLVMPattern<
78 gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
79
80 LogicalResult
81 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
82 OpAdaptor adaptor,
83 ConversionPatternRewriter &rewriter) const override {
84 Operation *op = subgroupMmaLoadMatrixOp.getOperation();
85 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
86 return failure();
87
88 // Get the shape of the MMAMatrix type being returned. The shape will
89 // choose which intrinsic this op will be lowered to.
90 NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
91 ? NVVM::MMALayout::col
92 : NVVM::MMALayout::row;
93 gpu::MMAMatrixType retType =
94 cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
95 ArrayRef<int64_t> retTypeShape = retType.getShape();
96 int64_t m = 0;
97 int64_t n = 0;
98 int64_t k = 0;
99 NVVM::MMATypes eltype = getElementType(retType);
100 // NVVM intrinsics require to give mxnxk dimensions, infer the missing
101 // dimension based on the valid intrinsics available.
102 if (retType.getOperand().equals(RHS: "AOp")) {
103 m = retTypeShape[0];
104 k = retTypeShape[1];
105 n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
106 } else if (retType.getOperand().equals(RHS: "BOp")) {
107 k = retTypeShape[0];
108 n = retTypeShape[1];
109 m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
110 } else if (retType.getOperand().equals(RHS: "COp")) {
111 m = retTypeShape[0];
112 n = retTypeShape[1];
113 k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
114 }
115 NVVM::MMAFrag frag = convertOperand(retType.getOperand());
116 // Check that there is an exisiting instruction for the combination we need.
117 if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
118 return rewriter.notifyMatchFailure(arg&: op, msg: kInvalidCaseStr);
119
120 Type resType = convertMMAToLLVMType(type: retType);
121 Location loc = op->getLoc();
122
123 // Create nvvm.mma_load op according to the operand types.
124 Value dataPtr = getStridedElementPtr(
125 loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
126 adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
127
128 Value leadingDim = rewriter.create<LLVM::ConstantOp>(
129 loc, rewriter.getI32Type(),
130 subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
131 rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
132 op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
133 return success();
134 }
135};
136
137/// This class implements the conversion of GPU MMA storeOp to wmma.store op
138/// in the NVVM dialect. The conversion not only emits the NVVM op but also
139/// emits code that is necessary to unpack the data in the source and
140/// convert the data in the format that is needed by the NVVM op.
141struct WmmaStoreOpToNVVMLowering
142 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
143 using ConvertOpToLLVMPattern<
144 gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
145
146 LogicalResult
147 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
148 OpAdaptor adaptor,
149 ConversionPatternRewriter &rewriter) const override {
150 Operation *op = subgroupMmaStoreMatrixOp.getOperation();
151 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
152 return failure();
153
154 Location loc = op->getLoc();
155
156 SmallVector<Value, 4> storeOpOperands;
157 // Get the shape of the MMAMatrix type being stored. The shape will
158 // choose which intrinsic this op will be lowered to.
159 gpu::MMAMatrixType srcType =
160 cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
161 ArrayRef<int64_t> srcTypeShape = srcType.getShape();
162 NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
163 ? NVVM::MMALayout::col
164 : NVVM::MMALayout::row;
165 NVVM::MMATypes eltype = getElementType(srcType);
166 int64_t m = srcTypeShape[0];
167 int64_t n = srcTypeShape[1];
168 int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
169 if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
170 return rewriter.notifyMatchFailure(arg&: op, msg: kInvalidCaseStr);
171
172 auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
173 for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
174 Value toUse =
175 rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
176 storeOpOperands.push_back(Elt: toUse);
177 }
178
179 Value dataPtr = getStridedElementPtr(
180 loc,
181 cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
182 adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
183 Value leadingDim = rewriter.create<LLVM::ConstantOp>(
184 loc, rewriter.getI32Type(),
185 subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
186 rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
187 op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
188 return success();
189 }
190};
191
192/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
193/// in the NVVM dialect.
194struct WmmaMmaOpToNVVMLowering
195 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
196 using ConvertOpToLLVMPattern<
197 gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
198
199 LogicalResult
200 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
201 OpAdaptor adaptor,
202 ConversionPatternRewriter &rewriter) const override {
203 Operation *op = subgroupMmaComputeOp.getOperation();
204 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
205 return failure();
206
207 Location loc = op->getLoc();
208
209 // The wmma.mma intrinsic in llvm requires the operands as individual
210 // values. So individual elements from the memrefs need to be extracted and
211 // then passed on to the intrinsic call. Emit llvm ops to extract individual
212 // values form lowered memrefs.
213 SmallVector<Value> unpackedOps;
214
215 auto unpackOp = [&](Value operand) {
216 auto structType = cast<LLVM::LLVMStructType>(Val: operand.getType());
217 for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
218 Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
219 unpackedOps.push_back(Elt: toUse);
220 }
221 };
222
223 // Get the shapes of the MMAMatrix type being used. The shapes will
224 // choose which intrinsic this op will be lowered to.
225 gpu::MMAMatrixType aType =
226 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
227 ArrayRef<int64_t> aTypeShape = aType.getShape();
228 gpu::MMAMatrixType cType =
229 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
230 ArrayRef<int64_t> cTypeShape = cType.getShape();
231 int64_t m = cTypeShape[0];
232 int64_t n = cTypeShape[1];
233 int64_t k = aTypeShape[1];
234 NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
235 ? NVVM::MMALayout::col
236 : NVVM::MMALayout::row;
237 NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
238 ? NVVM::MMALayout::col
239 : NVVM::MMALayout::row;
240 NVVM::MMATypes sourceType = getElementType(aType);
241 NVVM::MMATypes destType = getElementType(cType);
242 if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
243 destType) == 0)
244 return rewriter.notifyMatchFailure(arg&: op, msg: kInvalidCaseStr);
245
246 NVVM::MMATypes bElementType = getElementType(
247 cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
248 if (bElementType != sourceType)
249 return rewriter.notifyMatchFailure(
250 arg&: op, msg: "WMMA compute op input matrix element types must match.");
251
252 unpackOp(adaptor.getOpA());
253 unpackOp(adaptor.getOpB());
254 unpackOp(adaptor.getOpC());
255
256 rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
257 op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
258 destType, unpackedOps);
259 return success();
260 }
261};
262
263/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
264struct WmmaConstantOpToNVVMLowering
265 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
266 using ConvertOpToLLVMPattern<
267 gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
268
269 LogicalResult
270 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
271 OpAdaptor adaptor,
272 ConversionPatternRewriter &rewriter) const override {
273 if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
274 adaptor.getOperands(), rewriter)))
275 return failure();
276 Location loc = subgroupMmaConstantOp.getLoc();
277 Value cst = adaptor.getOperands()[0];
278 LLVM::LLVMStructType type = convertMMAToLLVMType(
279 cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
280 // If the element type is a vector create a vector from the operand.
281 if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
282 Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
283 for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
284 Value idx = rewriter.create<LLVM::ConstantOp>(
285 loc, rewriter.getI32Type(), vecEl);
286 vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
287 cst, idx);
288 }
289 cst = vecCst;
290 }
291 Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
292 for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
293 matrixStruct =
294 rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
295 }
296 rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
297 return success();
298 }
299};
300
301static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
302 Value rhs, bool isMin) {
303 auto floatType = cast<FloatType>(Val: getElementTypeOrSelf(type: lhs.getType()));
304 Type i1Type = builder.getI1Type();
305 if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
306 i1Type = VectorType::get(vecType.getShape(), i1Type);
307 Value cmp = builder.create<LLVM::FCmpOp>(
308 loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
309 lhs, rhs);
310 Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
311 Value isNan = builder.create<LLVM::FCmpOp>(
312 loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
313 Value nan = builder.create<LLVM::ConstantOp>(
314 loc, lhs.getType(),
315 builder.getFloatAttr(floatType,
316 APFloat::getQNaN(floatType.getFloatSemantics())));
317 return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
318}
319
320static Value createScalarOp(OpBuilder &builder, Location loc,
321 gpu::MMAElementwiseOp op,
322 ArrayRef<Value> operands) {
323 switch (op) {
324 case gpu::MMAElementwiseOp::ADDF:
325 return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
326 case gpu::MMAElementwiseOp::MULF:
327 return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
328 case gpu::MMAElementwiseOp::DIVF:
329 return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
330 case gpu::MMAElementwiseOp::MAXF:
331 return createMinMaxF(builder, loc, lhs: operands[0], rhs: operands[1],
332 /*isMin=*/false);
333 case gpu::MMAElementwiseOp::MINF:
334 return createMinMaxF(builder, loc, lhs: operands[0], rhs: operands[1],
335 /*isMin=*/true);
336 default:
337 llvm_unreachable("unknown op");
338 }
339}
340
341/// Convert GPU MMA elementwise ops to extract + op + insert.
342struct WmmaElementwiseOpToNVVMLowering
343 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
344 using ConvertOpToLLVMPattern<
345 gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
346
347 LogicalResult
348 matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
349 OpAdaptor adaptor,
350 ConversionPatternRewriter &rewriter) const override {
351 if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
352 adaptor.getOperands(), rewriter)))
353 return failure();
354 Location loc = subgroupMmaElementwiseOp.getLoc();
355 size_t numOperands = adaptor.getOperands().size();
356 LLVM::LLVMStructType destType = convertMMAToLLVMType(
357 cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
358 Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
359 for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
360 SmallVector<Value> extractedOperands;
361 for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
362 extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
363 loc, adaptor.getOperands()[opIdx], i));
364 }
365 Value element =
366 createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
367 extractedOperands);
368 matrixStruct =
369 rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
370 }
371 rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
372 return success();
373 }
374};
375
376} // namespace
377
378/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
379LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
380 NVVM::MMAFrag frag = convertOperand(type.getOperand());
381 NVVM::MMATypes eltType = getElementType(type);
382 auto nRow = type.getShape()[0];
383 auto nCol = type.getShape()[1];
384 std::pair<Type, unsigned> typeInfo =
385 NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
386 return LLVM::LLVMStructType::getLiteral(
387 context: type.getContext(), types: SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
388}
389
390void mlir::populateGpuWMMAToNVVMConversionPatterns(
391 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
392 patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
393 WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
394 WmmaElementwiseOpToNVVMLowering>(arg&: converter);
395}
396

source code of mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp