//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// the NVVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

namespace {

/// Checks if all the operands of the op being lowered are of LLVM types. The
/// types are expected to have been converted by the `LLVMTypeConverter` before
/// the op is actually lowered. If the type of an operand is not already
/// converted, it hints at a missing type conversion, and failure is returned
/// in that case.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      })) {
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  }

  return success();
}

/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName == "AOp")
    return NVVM::MMAFrag::a;
  if (operandName == "BOp")
    return NVVM::MMAFrag::b;
  if (operandName == "COp")
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}

static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
  if (type.getElementType().isF16())
    return NVVM::MMATypes::f16;
  if (type.getElementType().isF32())
    return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                      : NVVM::MMATypes::tf32;
  if (type.getElementType().isSignedInteger(8))
    return NVVM::MMATypes::s8;
  if (type.getElementType().isUnsignedInteger(8))
    return NVVM::MMATypes::u8;
  // Accumulator type is signless and implies signed.
  if (type.getElementType().isInteger(32))
    return NVVM::MMATypes::s32;
  llvm_unreachable("Unsupported type");
}

/// This class implements the conversion of GPU MMA loadOp to the wmma.load op
/// in the NVVM dialect. Besides emitting the NVVM op, the conversion infers
/// the missing m/n/k dimension of the fragment, computes a strided element
/// pointer into the source memref, and materializes the leading dimension as
/// an i32 constant.
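///
/// Schematically (a hand-written example, attribute spellings abbreviated),
/// a load of a 16x16xf16 "AOp" fragment such as
///
///   %0 = gpu.subgroup_mma_load_matrix %src[%i, %j]
///          {leadDimension = 32 : index}
///          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// becomes a strided element pointer computation into %src, an i32 constant
/// holding the leading dimension, and an nvvm.wmma.load returning a struct
/// of vector<2xf16> values.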
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    // Get the shape of the MMAMatrix type being returned. The shape
    // determines which intrinsic this op will be lowered to.
    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    gpu::MMAMatrixType retType =
        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0;
    int64_t n = 0;
    int64_t k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // NVVM intrinsics require all of the m, n, and k dimensions; infer the
    // missing dimension based on the valid intrinsics available.
    if (retType.getOperand() == "AOp") {
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand() == "BOp") {
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand() == "COp") {
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Check that there is an existing instruction for the combination we need.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Create the nvvm.wmma.load op according to the operand types.
    Value dataPtr = getStridedElementPtr(
        rewriter, loc,
        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
        adaptor.getSrcMemref(), adaptor.getIndices());

    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};

/// This class implements the conversion of GPU MMA storeOp to the wmma.store
/// op in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits the code necessary to unpack the source fragment into the individual
/// values the NVVM op expects.
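///
/// Schematically (a hand-written example), a store of a 16x16xf16 "COp"
/// fragment such as
///
///   gpu.subgroup_mma_store_matrix %val, %dst[%i, %j]
///     {leadDimension = 32 : index}
///     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
///
/// becomes llvm.extractvalue ops unpacking the elements of %val's struct,
/// followed by an nvvm.wmma.store taking the element pointer, the unpacked
/// values, and the i32 leading dimension.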
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    SmallVector<Value, 4> storeOpOperands;
    // Get the shape of the MMAMatrix type being stored. The shape determines
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType srcType =
        cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
    NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    NVVM::MMATypes eltype = getElementType(srcType);
    int64_t m = srcTypeShape[0];
    int64_t n = srcTypeShape[1];
    int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
    if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse =
          rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
      storeOpOperands.push_back(toUse);
    }

    Value dataPtr = getStridedElementPtr(
        rewriter, loc,
        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
        adaptor.getDstMemref(), adaptor.getIndices());
    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
    return success();
  }
};

/// This class implements the conversion of GPU MMA computeOp to the wmma.mma
/// op in the NVVM dialect.
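///
/// Schematically (a hand-written example), a compute op such as
///
///   %d = gpu.subgroup_mma_compute %a, %b, %c
///          : !gpu.mma_matrix<16x16xf16, "AOp">,
///            !gpu.mma_matrix<16x16xf16, "BOp">
///          -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes llvm.extractvalue ops unpacking all three operand structs,
/// followed by an nvvm.wmma.mma op whose result struct has the converted
/// type of the "COp" operand.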
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in LLVM requires its operands as individual
    // values, so individual elements need to be extracted from the lowered
    // matrix structs and then passed on to the intrinsic call. Emit LLVM ops
    // to extract these individual values.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = cast<LLVM::LLVMStructType>(operand.getType());
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix types being used. The shapes determine
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType aType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    NVVM::MMATypes bElementType = getElementType(
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
    if (bElementType != sourceType)
      return rewriter.notifyMatchFailure(
          op, "WMMA compute op input matrix element types must match.");

    unpackOp(adaptor.getOpA());
    unpackOp(adaptor.getOpB());
    unpackOp(adaptor.getOpC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
        destType, unpackedOps);
    return success();
  }
};

/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
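///
/// Schematically (a hand-written example), a splat such as
///
///   %0 = gpu.subgroup_mma_constant_matrix %cst
///          : !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes a chain of llvm.insertvalue ops writing %cst (or, for fragments
/// whose struct elements are vectors, a vector splat of %cst built with
/// llvm.insertelement) into every element of an llvm.mlir.poison struct.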
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = adaptor.getOperands()[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
    // If the element type is a vector, splat the scalar operand into a vector
    // first.
    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
      Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = rewriter.create<LLVM::ConstantOp>(
            loc, rewriter.getI32Type(), vecEl);
        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
                                                        cst, idx);
      }
      cst = vecCst;
    }
    Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};
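
/// Emits a NaN-propagating floating-point min/max as compare + select: an
/// ordered olt/ogt compare selects the smaller/larger of `lhs` and `rhs`, and
/// an unordered (`uno`) compare overrides the result with a quiet NaN whenever
/// either input is NaN.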
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
  Type i1Type = builder.getI1Type();
  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

static Value createScalarOp(OpBuilder &builder, Location loc,
                            gpu::MMAElementwiseOp op,
                            ArrayRef<Value> operands) {
  switch (op) {
  case gpu::MMAElementwiseOp::ADDF:
    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MULF:
    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::DIVF:
    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MAXF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/false);
  case gpu::MMAElementwiseOp::MINF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/true);
  default:
    llvm_unreachable("unknown op");
  }
}

/// Convert GPU MMA elementwise ops to extract + op + insert.
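///
/// Schematically (a hand-written example), an elementwise op such as
///
///   %0 = gpu.subgroup_mma_elementwise addf %a, %b
///          : (!gpu.mma_matrix<16x16xf16, "COp">,
///             !gpu.mma_matrix<16x16xf16, "COp">)
///          -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// becomes, for each struct element, an llvm.extractvalue on each operand,
/// the scalar LLVM op (llvm.fadd here), and an llvm.insertvalue into the
/// result struct.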
struct WmmaElementwiseOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaElementwiseOp.getLoc();
    size_t numOperands = adaptor.getOperands().size();
    LLVM::LLVMStructType destType = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
    Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
      SmallVector<Value> extractedOperands;
      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, adaptor.getOperands()[opIdx], i));
      }
      Value element =
          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
                         extractedOperands);
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
    }
    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
    return success();
  }
};

} // namespace

/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
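///
/// For example (assuming the fragment shapes NVVM::inferMMAType reports for
/// 16x16x16 WMMA), a !gpu.mma_matrix<16x16xf16, "AOp"> maps to a literal
/// struct of eight vector<2xf16> members, while a 16x16xf32 "COp" maps to a
/// struct of eight f32 members.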
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
  NVVM::MMAFrag frag = convertOperand(type.getOperand());
  NVVM::MMATypes eltType = getElementType(type);
  auto nRow = type.getShape()[0];
  auto nCol = type.getShape()[1];
  std::pair<Type, unsigned> typeInfo =
      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
  return LLVM::LLVMStructType::getLiteral(
      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}

void mlir::populateGpuWMMAToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
               WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
               WmmaElementwiseOpToNVVMLowering>(converter, benefit);
}
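
// A typical use, sketched under assumptions (the pass boilerplate, conversion
// target setup, and variable names below are illustrative, not taken from
// this file):
//
//   LLVMTypeConverter converter(context);
//   RewritePatternSet patterns(context);
//   populateGpuWMMAToNVVMConversionPatterns(converter, patterns);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();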