//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIRV Cooperative Matrix ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers.
//===----------------------------------------------------------------------===//

/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op maps directly to a SPIR-V op that supports
/// cooperative matrix operands. Returns false otherwise.
///
/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
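///
/// For example, an `addf` elementwise op lowers to `spirv.FAdd` over the
/// converted operands (illustrative IR; the coop matrix type is a
/// placeholder):
///
///   %r = spirv.FAdd %a, %b : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>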
static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                ValueRange operands) {
  assert(isa<spirv::CooperativeMatrixType>(coopType));

  switch (op.getOpType()) {
  case gpu::MMAElementwiseOp::ADDF:
    builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::ADDI:
    builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::SUBF:
    builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::SUBI:
    builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVF:
    builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVS:
    builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::DIVU:
    builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::NEGATEF:
    builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::NEGATES:
    builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
    return true;
  case gpu::MMAElementwiseOp::EXTF:
    builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
    return true;
  default:
    break;
  }
  return false;
}

/// Returns true if all operands have the same type and that type is a SPIR-V
/// cooperative matrix type.
static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
  assert(!operands.empty());
  if (!llvm::all_equal(
          llvm::map_range(operands, [](Value v) { return v.getType(); })))
    return false;

  return isa<spirv::CooperativeMatrixType>(operands.front().getType());
}

namespace {
/// Converts a GPU MMA ConstantMatrixOp to a splat `spirv.CompositeConstruct`
/// over the corresponding SPIR-V KHR cooperative matrix type.
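///
/// For example (illustrative IR; `%cst` is the splat scalar):
///
///   %m = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
/// lowers to:
///   %m = spirv.CompositeConstruct %cst
///          : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>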
struct WmmaConstantOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(adaptor.getOperands().size() == 1);
    Value cst = adaptor.getOperands().front();
    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
    return success();
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
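///
/// For example (illustrative IR, types abbreviated), a unary `negatef`
/// becomes:
///
///   %r = spirv.FNegate %a : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>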
struct WmmaElementwiseOpToSPIRVDefaultLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    return success(
        createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
  }
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the matrix-times-scalar case.
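///
/// This handles `mulf` where one operand is a splat constant matrix; the
/// rewrite uses the splat's scalar directly (illustrative IR, types
/// abbreviated):
///
///   %r = spirv.MatrixTimesScalar %matrix, %scalar
///          : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16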
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getOperands().size() != 2)
      return failure();

    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
      return failure();

    // Use the original operands to check whether one of the operands is a
    // splat scalar value.
    Value lhs = op.getOperands().front();
    Value rhs = op.getOperands().back();
    Value splat = nullptr;
    Value matrix = nullptr;
    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      splat = adaptor.getOperands().front();
      matrix = adaptor.getOperands().back();
    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      matrix = adaptor.getOperands().front();
      splat = adaptor.getOperands().back();
    }
    if (!splat || !matrix)
      return rewriter.notifyMatchFailure(op, "no splat operand");

    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
    Value scalar;
    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
    if (!cc) {
      return rewriter.notifyMatchFailure(op,
                                         "splat is not a composite construct");
    }

    assert(cc.getConstituents().size() == 1);
    scalar = cc.getConstituents().front();

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
        op, coopType, ValueRange{matrix, scalar});
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace khr {
namespace {

/// Converts the GPU MMA load op to the KHRCooperativeMatrixLoad op in the
/// SPIR-V dialect.
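///
/// For example (illustrative IR; the pointer and stride are derived from the
/// memref, indices, and `leadDimension`):
///
///   %m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
///          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
/// lowers to:
///   %m = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>
///          : !spirv.ptr<f16, StorageBuffer>, i32
///            -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>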
struct WmmaLoadOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = op->getLoc();

    auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
    MemRefType memrefType = op.getSrcMemref().getType();
    Value bufferPtr =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), loc, rewriter);

    auto coopType =
        typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    int64_t stride = op.getLeadDimension().getSExtValue();
    IntegerType i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));

    bool isColMajor = op.getTranspose().value_or(false);
    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;

    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
        op, coopType, bufferPtr, strideValue, layout);
    return success();
  }
};

/// Converts the GPU MMA store op to the KHRCooperativeMatrixStore op in the
/// SPIR-V dialect.
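///
/// For example (illustrative IR, mirroring the load lowering above):
///
///   gpu.subgroup_mma_store_matrix %val, %dst[%i, %j] {leadDimension = 32 : index}
///     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
/// lowers to:
///   spirv.KHR.CooperativeMatrixStore %ptr, %val, %stride, <RowMajor>
///     : !spirv.ptr<f16, StorageBuffer>,
///       !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, i32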
struct WmmaStoreOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = op->getLoc();

    auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
    Value bufferPtr =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
                             adaptor.getIndices(), loc, rewriter);

    int64_t stride = op.getLeadDimension().getSExtValue();
    IntegerType i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));

    bool isColMajor = op.getTranspose().value_or(false);
    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;

    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
        op, bufferPtr, adaptor.getSrc(), strideValue, layout);
    return success();
  }
};

/// Converts the GPU MMA compute op to the KHRCooperativeMatrixMulAdd op in
/// the SPIR-V dialect.
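///
/// For example (illustrative IR, types abbreviated):
///
///   %d = gpu.subgroup_mma_compute %a, %b, %c
/// lowers to:
///   %d = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c
///          : !spirv.coopmatrix<..., MatrixA>, !spirv.coopmatrix<..., MatrixB>
///            -> !spirv.coopmatrix<..., MatrixAcc>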
struct WmmaMmaOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
        subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
        adaptor.getOpC());
    return success();
  }
};

} // namespace
} // namespace khr
} // namespace mlir

void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
  using namespace mlir;
  MLIRContext *context = patterns.getContext();
  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
  // Give the following pattern a higher benefit so that it prevails over the
  // default elementwise lowering.
  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
                                                          /*benefit=*/2);
}
306 | |
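// Registers the `!gpu.mma_matrix` to `!spirv.coopmatrix` type conversion. The
// operand string selects the cooperative matrix use, e.g. (illustrative):
//
//   !gpu.mma_matrix<16x16xf16, "AOp">
//     maps to
//   !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>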
void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
    mlir::SPIRVTypeConverter &typeConverter) {
  typeConverter.addConversion([](gpu::MMAMatrixType type) {
    ArrayRef<int64_t> retTypeShape = type.getShape();
    Type elementType = type.getElementType();
    auto use =
        llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
            .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
            .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
            .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);

    return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
                                             retTypeShape[1],
                                             spirv::Scope::Subgroup, use);
  });
}