//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// the NVVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

namespace {

/// Checks if all the operands of the op being lowered are of LLVM types. The
/// types are expected to have been converted by the `LLVMTypeConverter` before
/// the op is actually lowered. If the type of an operand is not already
/// converted, it hints at a missing type conversion, and failure is returned
/// in that case.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      })) {
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  }

  return success();
}

/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

/// Maps the operand name of an MMA matrix ("AOp", "BOp", or "COp") to the
/// corresponding NVVM fragment kind.
static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName == "AOp")
    return NVVM::MMAFrag::a;
  if (operandName == "BOp")
    return NVVM::MMAFrag::b;
  if (operandName == "COp")
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}

/// Returns the NVVM element type corresponding to the element type and
/// operand role of the given MMA matrix type.
static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
  if (type.getElementType().isF16())
    return NVVM::MMATypes::f16;
  if (type.getElementType().isF32())
    return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                      : NVVM::MMATypes::tf32;

  if (type.getElementType().isSignedInteger(8))
    return NVVM::MMATypes::s8;
  if (type.getElementType().isUnsignedInteger(8))
    return NVVM::MMATypes::u8;
  // Accumulator type is signless and implies signed.
  if (type.getElementType().isInteger(32))
    return NVVM::MMATypes::s32;
  llvm_unreachable("Unsupported type");
}

/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits code that is necessary to store the data in the destination memref
/// after it has been loaded.
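///
/// For example (illustrative only; the attributes and the result struct
/// depend on the fragment shape and element type), a load such as
///
///   %0 = gpu.subgroup_mma_load_matrix %src[%i, %j]
///       {leadDimension = 32 : index}
///       : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// lowers to a strided-element GEP computing the start address, followed by
/// an NVVM load of roughly the form
///
///   %1 = nvvm.wmma.load %ptr, %ld
///       {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>,
///        k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32,
///        n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, ...)>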
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    // Get the shape of the MMAMatrix type being returned. The shape
    // determines which intrinsic this op will be lowered to.
    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    gpu::MMAMatrixType retType =
        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0;
    int64_t n = 0;
    int64_t k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // NVVM intrinsics require all of the m x n x k dimensions; infer the
    // missing dimension from the valid intrinsics available.
    if (retType.getOperand() == "AOp") {
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand() == "BOp") {
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand() == "COp") {
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Check that there is an existing intrinsic for the combination we need.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Create the nvvm.wmma.load op according to the operand types.
    Value dataPtr = getStridedElementPtr(
        rewriter, loc,
        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
        adaptor.getSrcMemref(), adaptor.getIndices());

    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};

/// This class implements the conversion of GPU MMA storeOp to wmma.store op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits code that is necessary to unpack the data in the source and
/// convert the data into the format that is needed by the NVVM op.
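///
/// For example (illustrative only), a store such as
///
///   gpu.subgroup_mma_store_matrix %val, %dst[%i, %j]
///       {leadDimension = 32 : index}
///       : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
///
/// becomes a sequence of llvm.extractvalue ops that unpack the accumulator
/// struct, followed by an nvvm.wmma.store that takes the destination
/// pointer, the unpacked elements, and the leading dimension.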
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    SmallVector<Value, 4> storeOpOperands;
    // Get the shape of the MMAMatrix type being stored. The shape determines
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType srcType =
        cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
    NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    NVVM::MMATypes eltype = getElementType(srcType);
    int64_t m = srcTypeShape[0];
    int64_t n = srcTypeShape[1];
    int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
    if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse =
          rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
      storeOpOperands.push_back(toUse);
    }

    Value dataPtr = getStridedElementPtr(
        rewriter, loc,
        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
        adaptor.getDstMemref(), adaptor.getIndices());
    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
    return success();
  }
};

/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
/// in the NVVM dialect.
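///
/// For example (illustrative only), a compute op such as
///
///   %d = gpu.subgroup_mma_compute %a, %b, %c
///       : !gpu.mma_matrix<16x16xf16, "AOp">,
///         !gpu.mma_matrix<16x16xf16, "BOp">
///       -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// unpacks all three fragment structs with llvm.extractvalue and feeds the
/// individual values to a single nvvm.wmma.mma, whose struct result replaces
/// the op.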
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in LLVM requires the operands as individual
    // values. So individual elements from the source operands need to be
    // extracted and then passed on to the intrinsic call. Emit llvm ops to
    // extract individual values from the lowered operands.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = cast<LLVM::LLVMStructType>(operand.getType());
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix types being used. The shapes determine
    // which intrinsic this op will be lowered to.
    gpu::MMAMatrixType aType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    NVVM::MMATypes bElementType = getElementType(
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
    if (bElementType != sourceType)
      return rewriter.notifyMatchFailure(
          op, "WMMA compute op input matrix element types must match.");

    unpackOp(adaptor.getOpA());
    unpackOp(adaptor.getOpB());
    unpackOp(adaptor.getOpC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
        destType, unpackedOps);
    return success();
  }
};

/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
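///
/// For example (illustrative only), splatting a scalar constant
///
///   %0 = gpu.subgroup_mma_constant_matrix %cst
///       : !gpu.mma_matrix<16x16xf16, "COp">
///
/// builds a poison struct and fills every field with the constant via
/// llvm.insertvalue; when the struct fields are vectors, the scalar is first
/// broadcast into a vector with a chain of llvm.insertelement ops.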
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = adaptor.getOperands()[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
    // If the element type is a vector, create a vector from the operand.
    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
      Value vecCst = rewriter.create<LLVM::PoisonOp>(loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = rewriter.create<LLVM::ConstantOp>(
            loc, rewriter.getI32Type(), vecEl);
        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
                                                        cst, idx);
      }
      cst = vecCst;
    }
    Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};

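/// Creates an LLVM IR min/max that propagates NaN: an ordered comparison
/// (olt/ogt) selects the smaller or larger operand, and a second, unordered
/// comparison replaces the result with a quiet NaN whenever either input is
/// NaN.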
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
  Type i1Type = builder.getI1Type();
  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

/// Creates the LLVM scalar (or vector) operation corresponding to the given
/// GPU MMA elementwise op kind.
static Value createScalarOp(OpBuilder &builder, Location loc,
                            gpu::MMAElementwiseOp op,
                            ArrayRef<Value> operands) {
  switch (op) {
  case gpu::MMAElementwiseOp::ADDF:
    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MULF:
    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::DIVF:
    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MAXF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/false);
  case gpu::MMAElementwiseOp::MINF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/true);
  default:
    llvm_unreachable("unknown op");
  }
}

/// Convert GPU MMA elementwise ops to extract + op + insert.
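///
/// For example (illustrative only; the exact textual form of the op is
/// assumed here), an elementwise op such as
///
///   %0 = gpu.subgroup_mma_elementwise addf %a, %b
///       : (!gpu.mma_matrix<16x16xf16, "COp">,
///          !gpu.mma_matrix<16x16xf16, "COp">)
///       -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// is lowered field by field: for each element of the result struct, the
/// matching element is extracted from every operand, the scalar op is
/// applied, and the result is inserted into the output struct.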
struct WmmaElementwiseOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();
    Location loc = subgroupMmaElementwiseOp.getLoc();
    size_t numOperands = adaptor.getOperands().size();
    LLVM::LLVMStructType destType = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
    Value matrixStruct = rewriter.create<LLVM::PoisonOp>(loc, destType);
    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
      SmallVector<Value> extractedOperands;
      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, adaptor.getOperands()[opIdx], i));
      }
      Value element =
          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
                         extractedOperands);
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
    }
    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
    return success();
  }
};

} // namespace

/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
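///
/// For example (illustrative), a 16x16 f16 "COp" fragment maps to a struct
/// of four 2-element f16 vectors:
///
///   !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>,
///                 vector<2xf16>)>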
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
  NVVM::MMAFrag frag = convertOperand(type.getOperand());
  NVVM::MMATypes eltType = getElementType(type);
  auto nRow = type.getShape()[0];
  auto nCol = type.getShape()[1];
  std::pair<Type, unsigned> typeInfo =
      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
  return LLVM::LLVMStructType::getLiteral(
      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}

void mlir::populateGpuWMMAToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
               WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
               WmmaElementwiseOpToNVVMLowering>(converter, benefit);
}