//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// NVVM Dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

namespace {

/// Checks if all the operands of the op being lowered are of LLVM types. The
/// types are expected to be converted by the `LLVMTypeConverter` before the op
/// is actually lowered. If the type of an operand is not already converted, it
/// hints at a missing type conversion and failure is returned in that case.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      })) {
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  }

  return success();
}

/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

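/// Maps the `MMAMatrixType` operand name ("AOp", "BOp", or "COp") to the
/// corresponding NVVM matrix fragment kind.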
static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName.equals("AOp"))
    return NVVM::MMAFrag::a;
  if (operandName.equals("BOp"))
    return NVVM::MMAFrag::b;
  if (operandName.equals("COp"))
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}

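/// Returns the NVVM MMA element type for the element type of `type`. Note that
/// f32 elements map to f32 for the "COp" accumulator fragment but to tf32 for
/// the "AOp"/"BOp" input fragments.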
static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
  if (type.getElementType().isF16())
    return NVVM::MMATypes::f16;
  if (type.getElementType().isF32())
    return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
                                           : NVVM::MMATypes::tf32;

  if (type.getElementType().isSignedInteger(8))
    return NVVM::MMATypes::s8;
  if (type.getElementType().isUnsignedInteger(8))
    return NVVM::MMATypes::u8;
  // Accumulator type is signless and implies signed.
  if (type.getElementType().isInteger(32))
    return NVVM::MMATypes::s32;
  llvm_unreachable("Unsupported type");
}

/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. The shape, operand kind ("AOp"/"BOp"/"COp"), and
/// element type of the loaded matrix determine which NVVM load intrinsic the
/// op is lowered to.
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    // Get the shape of the MMAMatrix type being returned. The shape
    // determines which intrinsic this op will be lowered to.
    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    gpu::MMAMatrixType retType =
        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0;
    int64_t n = 0;
    int64_t k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // NVVM intrinsics require all of the m, n, and k dimensions; infer the
    // missing dimension based on the valid intrinsics available.
    if (retType.getOperand().equals("AOp")) {
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand().equals("BOp")) {
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand().equals("COp")) {
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Check that there is an existing instruction for the combination we need.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

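    // The result of the load is the LLVM struct type that the MMA matrix type
    // converts to.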
    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Create the nvvm.wmma.load op according to the operand types.
    Value dataPtr = getStridedElementPtr(
        loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
        adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);

    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};

/// This class implements the conversion of GPU MMA storeOp to wmma.store op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits the code that is necessary to unpack the data in the source matrix
/// into the individual values the NVVM op expects.
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    SmallVector<Value, 4> storeOpOperands;
    // Get the shape of the MMAMatrix type being stored. The shape determines
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType srcType =
        cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
    NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    NVVM::MMATypes eltype = getElementType(srcType);
    int64_t m = srcTypeShape[0];
    int64_t n = srcTypeShape[1];
    int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
    if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

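    // Unpack the source matrix: each element of the lowered struct becomes an
    // individual operand of the store intrinsic.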
    auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse =
          rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
      storeOpOperands.push_back(toUse);
    }

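    // Compute the address of the first element to store and materialize the
    // leading dimension as an i32 constant.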
    Value dataPtr = getStridedElementPtr(
        loc,
        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
        adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
    return success();
  }
};

/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
/// in the NVVM dialect.
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in LLVM requires its operands as individual
    // values, so the elements of the struct-typed matrix operands need to be
    // extracted and then passed on to the intrinsic call. Emit LLVM ops to
    // extract the individual values from the lowered operands.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = cast<LLVM::LLVMStructType>(operand.getType());
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix types being used. The shapes determine
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType aType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

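    // The A and B operands must have matching element types; there is no WMMA
    // intrinsic for mixed A/B element types.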
    NVVM::MMATypes bElementType = getElementType(
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
    if (bElementType != sourceType)
      return rewriter.notifyMatchFailure(
          op, "WMMA compute op input matrix element types must match.");

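    // Unpack the A, B, and C operands; the intrinsic expects their elements
    // flattened in that order.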
    unpackOp(adaptor.getOpA());
    unpackOp(adaptor.getOpB());
    unpackOp(adaptor.getOpC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
        destType, unpackedOps);
    return success();
  }
};

/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = adaptor.getOperands()[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
    // If the element type is a vector, create a vector from the operand.
    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
      Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = rewriter.create<LLVM::ConstantOp>(
            loc, rewriter.getI32Type(), vecEl);
        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
                                                        cst, idx);
      }
      cst = vecCst;
    }
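    // Splat the (scalar or vector) constant across every element of the
    // result struct.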
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};

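/// Creates an LLVM compare + select sequence computing the minimum or maximum
/// of `lhs` and `rhs`. If either input is NaN, a quiet NaN of the element type
/// is returned instead.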
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
  Type i1Type = builder.getI1Type();
  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

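/// Creates the scalar (or vector) LLVM op corresponding to the given
/// elementwise MMA op kind, operating on already-extracted struct elements.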
static Value createScalarOp(OpBuilder &builder, Location loc,
                            gpu::MMAElementwiseOp op,
                            ArrayRef<Value> operands) {
  switch (op) {
  case gpu::MMAElementwiseOp::ADDF:
    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MULF:
    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::DIVF:
    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MAXF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/false);
  case gpu::MMAElementwiseOp::MINF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/true);
  default:
    llvm_unreachable("unknown op");
  }
}

/// Convert GPU MMA elementwise ops to extract + op + insert.
struct WmmaElementwiseOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaElementwiseOp.getLoc();
    size_t numOperands = adaptor.getOperands().size();
    LLVM::LLVMStructType destType = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
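    // For each element of the destination struct, extract the matching element
    // from every operand, apply the scalar op, and insert the result.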
    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
      SmallVector<Value> extractedOperands;
      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, adaptor.getOperands()[opIdx], i));
      }
      Value element =
          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
                         extractedOperands);
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
    }
    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
    return success();
  }
};

} // namespace

/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
  NVVM::MMAFrag frag = convertOperand(type.getOperand());
  NVVM::MMATypes eltType = getElementType(type);
  auto nRow = type.getShape()[0];
  auto nCol = type.getShape()[1];
  std::pair<Type, unsigned> typeInfo =
      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
  return LLVM::LLVMStructType::getLiteral(
      type.getContext(),
      SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}

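/// Appends the WMMA lowering patterns defined above to `patterns`.
///
/// A minimal usage sketch (assuming a standard dialect-conversion driver and a
/// conversion target configured by the caller; names are illustrative):
///
///   LLVMTypeConverter converter(module.getContext());
///   RewritePatternSet patterns(module.getContext());
///   populateGpuWMMAToNVVMConversionPatterns(converter, patterns);
///   if (failed(applyPartialConversion(module, target, std::move(patterns))))
///     signalPassFailure();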
void mlir::populateGpuWMMAToNVVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
               WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
               WmmaElementwiseOpToNVVMLowering>(converter);
}