//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// NVVM Dialect.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
15#include "mlir/Conversion/LLVMCommon/Pattern.h"
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19#include "mlir/IR/TypeUtilities.h"
20
21using namespace mlir;
22
23namespace {
24
25/// Checks if all the operands of the op being lowered are of LLVM Types. The
26/// types are expected to be converted by the `LLVMTypeConverter` before the op
27/// is actually lowered. If the type of an operands is not already converted it
28/// hints a missing typeConversion and failure is returned in that case.
29static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
30 ConversionPatternRewriter &rewriter) {
31 if (!llvm::all_of(Range&: operands, P: [](Value value) {
32 return LLVM::isCompatibleType(type: value.getType());
33 })) {
34 return rewriter.notifyMatchFailure(
35 arg&: op, msg: "cannot convert if operands aren't of LLVM type.");
36 }
37
38 return success();
39}
40
/// Error string to emit when an unimplemented WMMA variant is encountered.
/// Shared by the load/store/mma lowerings below so that every unsupported
/// m/n/k/layout/type combination produces the same match-failure diagnostic.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
43
44static NVVM::MMAFrag convertOperand(StringRef operandName) {
45 if (operandName == "AOp")
46 return NVVM::MMAFrag::a;
47 if (operandName == "BOp")
48 return NVVM::MMAFrag::b;
49 if (operandName == "COp")
50 return NVVM::MMAFrag::c;
51 llvm_unreachable("Unknown operand name");
52}
53
54static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
55 if (type.getElementType().isF16())
56 return NVVM::MMATypes::f16;
57 if (type.getElementType().isF32())
58 return type.getOperand() == "COp" ? NVVM::MMATypes::f32
59 : NVVM::MMATypes::tf32;
60
61 if (type.getElementType().isSignedInteger(width: 8))
62 return NVVM::MMATypes::s8;
63 if (type.getElementType().isUnsignedInteger(width: 8))
64 return NVVM::MMATypes::u8;
65 // Accumulator type is signless and implies signed.
66 if (type.getElementType().isInteger(width: 32))
67 return NVVM::MMATypes::s32;
68 llvm_unreachable("Unsupported type");
69}
70
71/// This class implements the conversion of GPU MMA loadOp to wmma.load op
72/// in the NVVM dialect. The conversion not only emits the NVVM op but also
73/// emits code that is necessary to store the data in the destination memref
74/// after it has been loaded.
75struct WmmaLoadOpToNVVMLowering
76 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
77 using ConvertOpToLLVMPattern<
78 gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
79
80 LogicalResult
81 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
82 OpAdaptor adaptor,
83 ConversionPatternRewriter &rewriter) const override {
84 Operation *op = subgroupMmaLoadMatrixOp.getOperation();
85 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)))
86 return failure();
87
88 // Get the shape of the MMAMatrix type being returned. The shape will
89 // choose which intrinsic this op will be lowered to.
90 NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
91 ? NVVM::MMALayout::col
92 : NVVM::MMALayout::row;
93 gpu::MMAMatrixType retType =
94 cast<gpu::MMAMatrixType>(Val: subgroupMmaLoadMatrixOp.getRes().getType());
95 ArrayRef<int64_t> retTypeShape = retType.getShape();
96 int64_t m = 0;
97 int64_t n = 0;
98 int64_t k = 0;
99 NVVM::MMATypes eltype = getElementType(type: retType);
100 // NVVM intrinsics require to give mxnxk dimensions, infer the missing
101 // dimension based on the valid intrinsics available.
102 if (retType.getOperand() == "AOp") {
103 m = retTypeShape[0];
104 k = retTypeShape[1];
105 n = NVVM::WMMALoadOp::inferNDimension(m, k, eltypeEnum: eltype);
106 } else if (retType.getOperand() == "BOp") {
107 k = retTypeShape[0];
108 n = retTypeShape[1];
109 m = NVVM::WMMALoadOp::inferMDimension(k, n, eltypeEnum: eltype);
110 } else if (retType.getOperand() == "COp") {
111 m = retTypeShape[0];
112 n = retTypeShape[1];
113 k = NVVM::WMMALoadOp::inferKDimension(m, n, eltypeEnum: eltype);
114 }
115 NVVM::MMAFrag frag = convertOperand(operandName: retType.getOperand());
116 // Check that there is an exisiting instruction for the combination we need.
117 if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layoutEnum: layout, eltypeEnum: eltype, fragEnum: frag) == 0)
118 return rewriter.notifyMatchFailure(arg&: op, msg: kInvalidCaseStr);
119
120 Type resType = convertMMAToLLVMType(type: retType);
121 Location loc = op->getLoc();
122
123 // Create nvvm.mma_load op according to the operand types.
124 Value dataPtr = getStridedElementPtr(
125 rewriter, loc,
126 type: cast<MemRefType>(Val: subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
127 memRefDesc: adaptor.getSrcMemref(), indices: adaptor.getIndices());
128
129 Value leadingDim = rewriter.create<LLVM::ConstantOp>(
130 location: loc, args: rewriter.getI32Type(),
131 args: subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
132 rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
133 op, args&: resType, args&: dataPtr, args&: leadingDim, args&: m, args&: n, args&: k, args&: layout, args&: eltype, args&: frag);
134 return success();
135 }
136};
137
138/// This class implements the conversion of GPU MMA storeOp to wmma.store op
139/// in the NVVM dialect. The conversion not only emits the NVVM op but also
140/// emits code that is necessary to unpack the data in the source and
141/// convert the data in the format that is needed by the NVVM op.
142struct WmmaStoreOpToNVVMLowering
143 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
144 using ConvertOpToLLVMPattern<
145 gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
146
147 LogicalResult
148 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
149 OpAdaptor adaptor,
150 ConversionPatternRewriter &rewriter) const override {
151 Operation *op = subgroupMmaStoreMatrixOp.getOperation();
152 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)))
153 return failure();
154
155 Location loc = op->getLoc();
156
157 SmallVector<Value, 4> storeOpOperands;
158 // Get the shape of the MMAMatrix type being stored. The shape will
159 // choose which intrinsic this op will be lowered to.
160 gpu::MMAMatrixType srcType =
161 cast<gpu::MMAMatrixType>(Val: subgroupMmaStoreMatrixOp.getSrc().getType());
162 ArrayRef<int64_t> srcTypeShape = srcType.getShape();
163 NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
164 ? NVVM::MMALayout::col
165 : NVVM::MMALayout::row;
166 NVVM::MMATypes eltype = getElementType(type: srcType);
167 int64_t m = srcTypeShape[0];
168 int64_t n = srcTypeShape[1];
169 int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltypeEnum: eltype);
170 if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layoutEnum: layout, eltypeEnum: eltype) == 0)
171 return rewriter.notifyMatchFailure(arg&: op, msg: kInvalidCaseStr);
172
173 auto matrixType = cast<LLVM::LLVMStructType>(Val: adaptor.getSrc().getType());
174 for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
175 Value toUse =
176 rewriter.create<LLVM::ExtractValueOp>(location: loc, args: adaptor.getSrc(), args&: i);
177 storeOpOperands.push_back(Elt: toUse);
178 }
179
180 Value dataPtr = getStridedElementPtr(
181 rewriter, loc,
182 type: cast<MemRefType>(Val: subgroupMmaStoreMatrixOp.getDstMemref().getType()),
183 memRefDesc: adaptor.getDstMemref(), indices: adaptor.getIndices());
184 Value leadingDim = rewriter.create<LLVM::ConstantOp>(
185 location: loc, args: rewriter.getI32Type(),
186 args: subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
187 rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
188 op, args&: dataPtr, args&: m, args&: n, args&: k, args&: layout, args&: eltype, args&: storeOpOperands, args&: leadingDim);
189 return success();
190 }
191};
192
193/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
194/// in the NVVM dialect.
195struct WmmaMmaOpToNVVMLowering
196 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
197 using ConvertOpToLLVMPattern<
198 gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
199
200 LogicalResult
201 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
202 OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter) const override {
204 Operation *op = subgroupMmaComputeOp.getOperation();
205 if (failed(Result: areAllLLVMTypes(op, operands: adaptor.getOperands(), rewriter)))
206 return failure();
207
208 Location loc = op->getLoc();
209
210 // The wmma.mma intrinsic in llvm requires the operands as individual
211 // values. So individual elements from the memrefs need to be extracted and
212 // then passed on to the intrinsic call. Emit llvm ops to extract individual
213 // values form lowered memrefs.
214 SmallVector<Value> unpackedOps;
215
216 auto unpackOp = [&](Value operand) {
217 auto structType = cast<LLVM::LLVMStructType>(Val: operand.getType());
218 for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
219 Value toUse = rewriter.create<LLVM::ExtractValueOp>(location: loc, args&: operand, args&: i);
220 unpackedOps.push_back(Elt: toUse);
221 }
222 };
223
224 // Get the shapes of the MMAMatrix type being used. The shapes will
225 // choose which intrinsic this op will be lowered to.
226 gpu::MMAMatrixType aType =
227 cast<gpu::MMAMatrixType>(Val: subgroupMmaComputeOp.getOpA().getType());
228 ArrayRef<int64_t> aTypeShape = aType.getShape();
229 gpu::MMAMatrixType cType =
230 cast<gpu::MMAMatrixType>(Val: subgroupMmaComputeOp.getOpC().getType());
231 ArrayRef<int64_t> cTypeShape = cType.getShape();
232 int64_t m = cTypeShape[0];
233 int64_t n = cTypeShape[1];
234 int64_t k = aTypeShape[1];
235 NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
236 ? NVVM::MMALayout::col
237 : NVVM::MMALayout::row;
238 NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
239 ? NVVM::MMALayout::col
240 : NVVM::MMALayout::row;
241 NVVM::MMATypes sourceType = getElementType(type: aType);
242 NVVM::MMATypes destType = getElementType(type: cType);
243 if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layoutAEnum: aLayout, layoutBEnum: bLayout, eltypeAEnum: sourceType,
244 eltypeBEnum: destType) == 0)
245 return rewriter.notifyMatchFailure(arg&: op, msg: kInvalidCaseStr);
246
247 NVVM::MMATypes bElementType = getElementType(
248 type: cast<gpu::MMAMatrixType>(Val: subgroupMmaComputeOp.getOpB().getType()));
249 if (bElementType != sourceType)
250 return rewriter.notifyMatchFailure(
251 arg&: op, msg: "WMMA compute op input matrix element types must match.");
252
253 unpackOp(adaptor.getOpA());
254 unpackOp(adaptor.getOpB());
255 unpackOp(adaptor.getOpC());
256
257 rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
258 op, args: adaptor.getOpC().getType(), args&: m, args&: n, args&: k, args&: aLayout, args&: bLayout, args&: sourceType,
259 args&: destType, args&: unpackedOps);
260 return success();
261 }
262};
263
264/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
265struct WmmaConstantOpToNVVMLowering
266 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
267 using ConvertOpToLLVMPattern<
268 gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
269
270 LogicalResult
271 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
272 OpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter) const override {
274 if (failed(Result: areAllLLVMTypes(op: subgroupMmaConstantOp.getOperation(),
275 operands: adaptor.getOperands(), rewriter)))
276 return failure();
277 Location loc = subgroupMmaConstantOp.getLoc();
278 Value cst = adaptor.getOperands()[0];
279 LLVM::LLVMStructType type = convertMMAToLLVMType(
280 type: cast<gpu::MMAMatrixType>(Val: subgroupMmaConstantOp.getType()));
281 // If the element type is a vector create a vector from the operand.
282 if (auto vecType = dyn_cast<VectorType>(Val: type.getBody()[0])) {
283 Value vecCst = rewriter.create<LLVM::PoisonOp>(location: loc, args&: vecType);
284 for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
285 Value idx = rewriter.create<LLVM::ConstantOp>(
286 location: loc, args: rewriter.getI32Type(), args&: vecEl);
287 vecCst = rewriter.create<LLVM::InsertElementOp>(location: loc, args&: vecType, args&: vecCst,
288 args&: cst, args&: idx);
289 }
290 cst = vecCst;
291 }
292 Value matrixStruct = rewriter.create<LLVM::PoisonOp>(location: loc, args&: type);
293 for (size_t i : llvm::seq(Begin: size_t(0), End: type.getBody().size())) {
294 matrixStruct =
295 rewriter.create<LLVM::InsertValueOp>(location: loc, args&: matrixStruct, args&: cst, args&: i);
296 }
297 rewriter.replaceOp(op: subgroupMmaConstantOp, newValues: matrixStruct);
298 return success();
299 }
300};
301
302static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
303 Value rhs, bool isMin) {
304 auto floatType = cast<FloatType>(Val: getElementTypeOrSelf(type: lhs.getType()));
305 Type i1Type = builder.getI1Type();
306 if (auto vecType = dyn_cast<VectorType>(Val: lhs.getType()))
307 i1Type = VectorType::get(shape: vecType.getShape(), elementType: i1Type);
308 Value cmp = builder.create<LLVM::FCmpOp>(
309 location: loc, args&: i1Type, args: isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
310 args&: lhs, args&: rhs);
311 Value sel = builder.create<LLVM::SelectOp>(location: loc, args&: cmp, args&: lhs, args&: rhs);
312 Value isNan = builder.create<LLVM::FCmpOp>(
313 location: loc, args&: i1Type, args: LLVM::FCmpPredicate::uno, args&: lhs, args&: rhs);
314 Value nan = builder.create<LLVM::ConstantOp>(
315 location: loc, args: lhs.getType(),
316 args: builder.getFloatAttr(type: floatType,
317 value: APFloat::getQNaN(Sem: floatType.getFloatSemantics())));
318 return builder.create<LLVM::SelectOp>(location: loc, args&: isNan, args&: nan, args&: sel);
319}
320
321static Value createScalarOp(OpBuilder &builder, Location loc,
322 gpu::MMAElementwiseOp op,
323 ArrayRef<Value> operands) {
324 switch (op) {
325 case gpu::MMAElementwiseOp::ADDF:
326 return builder.create<LLVM::FAddOp>(location: loc, args: operands[0].getType(), args&: operands);
327 case gpu::MMAElementwiseOp::MULF:
328 return builder.create<LLVM::FMulOp>(location: loc, args: operands[0].getType(), args&: operands);
329 case gpu::MMAElementwiseOp::DIVF:
330 return builder.create<LLVM::FDivOp>(location: loc, args: operands[0].getType(), args&: operands);
331 case gpu::MMAElementwiseOp::MAXF:
332 return createMinMaxF(builder, loc, lhs: operands[0], rhs: operands[1],
333 /*isMin=*/false);
334 case gpu::MMAElementwiseOp::MINF:
335 return createMinMaxF(builder, loc, lhs: operands[0], rhs: operands[1],
336 /*isMin=*/true);
337 default:
338 llvm_unreachable("unknown op");
339 }
340}
341
342/// Convert GPU MMA elementwise ops to extract + op + insert.
343struct WmmaElementwiseOpToNVVMLowering
344 : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
345 using ConvertOpToLLVMPattern<
346 gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
347
348 LogicalResult
349 matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
350 OpAdaptor adaptor,
351 ConversionPatternRewriter &rewriter) const override {
352 if (failed(Result: areAllLLVMTypes(op: subgroupMmaElementwiseOp.getOperation(),
353 operands: adaptor.getOperands(), rewriter)))
354 return failure();
355 Location loc = subgroupMmaElementwiseOp.getLoc();
356 size_t numOperands = adaptor.getOperands().size();
357 LLVM::LLVMStructType destType = convertMMAToLLVMType(
358 type: cast<gpu::MMAMatrixType>(Val: subgroupMmaElementwiseOp.getType()));
359 Value matrixStruct = rewriter.create<LLVM::PoisonOp>(location: loc, args&: destType);
360 for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
361 SmallVector<Value> extractedOperands;
362 for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
363 extractedOperands.push_back(Elt: rewriter.create<LLVM::ExtractValueOp>(
364 location: loc, args: adaptor.getOperands()[opIdx], args&: i));
365 }
366 Value element =
367 createScalarOp(builder&: rewriter, loc, op: subgroupMmaElementwiseOp.getOpType(),
368 operands: extractedOperands);
369 matrixStruct =
370 rewriter.create<LLVM::InsertValueOp>(location: loc, args&: matrixStruct, args&: element, args&: i);
371 }
372 rewriter.replaceOp(op: subgroupMmaElementwiseOp, newValues: matrixStruct);
373 return success();
374 }
375};
376
377} // namespace
378
379/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
380LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
381 NVVM::MMAFrag frag = convertOperand(operandName: type.getOperand());
382 NVVM::MMATypes eltType = getElementType(type);
383 auto nRow = type.getShape()[0];
384 auto nCol = type.getShape()[1];
385 std::pair<Type, unsigned> typeInfo =
386 NVVM::inferMMAType(type: eltType, frag, nRow, nCol, context: type.getContext());
387 return LLVM::LLVMStructType::getLiteral(
388 context: type.getContext(), types: SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
389}
390
391void mlir::populateGpuWMMAToNVVMConversionPatterns(
392 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
393 PatternBenefit benefit) {
394 patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
395 WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
396 WmmaElementwiseOpToNVVMLowering>(arg: converter, args&: benefit);
397}
398

source code of mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp