//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIRV Cooperative Matrix ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers.
//===----------------------------------------------------------------------===//

38 | /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op |
39 | /// when the elementwise op directly supports with cooperative matrix type. |
40 | /// Returns false if cannot. |
41 | /// |
42 | /// See SPV_KHR_cooperative_matrix for supported elementwise ops. |
43 | static bool createElementwiseOp(ConversionPatternRewriter &builder, |
44 | gpu::SubgroupMmaElementwiseOp op, Type coopType, |
45 | ValueRange operands) { |
46 | assert((isa<spirv::CooperativeMatrixType>(coopType))); |
47 | |
48 | switch (op.getOpType()) { |
49 | case gpu::MMAElementwiseOp::ADDF: |
50 | builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands); |
51 | return true; |
52 | case gpu::MMAElementwiseOp::ADDI: |
53 | builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands); |
54 | return true; |
55 | case gpu::MMAElementwiseOp::SUBF: |
56 | builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands); |
57 | return true; |
58 | case gpu::MMAElementwiseOp::SUBI: |
59 | builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands); |
60 | return true; |
61 | case gpu::MMAElementwiseOp::DIVF: |
62 | builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands); |
63 | return true; |
64 | case gpu::MMAElementwiseOp::DIVS: |
65 | builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands); |
66 | return true; |
67 | case gpu::MMAElementwiseOp::DIVU: |
68 | builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands); |
69 | return true; |
70 | case gpu::MMAElementwiseOp::NEGATEF: |
71 | builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands); |
72 | return true; |
73 | case gpu::MMAElementwiseOp::NEGATES: |
74 | builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands); |
75 | return true; |
76 | case gpu::MMAElementwiseOp::EXTF: |
77 | builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands); |
78 | return true; |
79 | default: |
80 | break; |
81 | } |
82 | return false; |
83 | } |
84 | |
85 | bool allOperandsHaveSameCoopMatrixType(ValueRange operands) { |
86 | assert(!operands.empty()); |
87 | if (!llvm::all_equal( |
88 | Range: llvm::map_range(C&: operands, F: [](Value v) { return v.getType(); }))) |
89 | return false; |
90 | |
91 | return isa<spirv::CooperativeMatrixType>(Val: operands.front().getType()); |
92 | } |

namespace {
95 | /// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative |
96 | /// matrix ops. |
97 | struct WmmaConstantOpToSPIRVLowering final |
98 | : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> { |
99 | using OpConversionPattern::OpConversionPattern; |
100 | |
101 | LogicalResult |
102 | matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor, |
103 | ConversionPatternRewriter &rewriter) const override { |
104 | Value cst = llvm::getSingleElement(adaptor.getOperands()); |
105 | auto coopType = getTypeConverter()->convertType(op.getType()); |
106 | if (!coopType) |
107 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
108 | |
109 | rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst); |
110 | return success(); |
111 | } |
112 | }; |
113 | |
114 | /// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative |
115 | /// matrix ops. |
116 | struct WmmaExtractOpToSPIRVLowering final |
117 | : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> { |
118 | using OpConversionPattern::OpConversionPattern; |
119 | |
120 | LogicalResult |
121 | matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor, |
122 | ConversionPatternRewriter &rewriter) const override { |
123 | Value matrix = adaptor.getMatrix(); |
124 | auto coopType = |
125 | getTypeConverter()->convertType<spirv::CooperativeMatrixType>( |
126 | matrix.getType()); |
127 | if (!coopType) |
128 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
129 | |
130 | SmallVector<int32_t> intValues; |
131 | for (Value val : op.getIndices()) { |
132 | if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) { |
133 | intValues.push_back(static_cast<int32_t>(constOp.value())); |
134 | } else { |
135 | return rewriter.notifyMatchFailure(op, "indices must be constants"); |
136 | } |
137 | } |
138 | |
139 | Type elementType = coopType.getElementType(); |
140 | rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( |
141 | op, elementType, matrix, rewriter.getI32ArrayAttr(intValues)); |
142 | return success(); |
143 | } |
144 | }; |
145 | |
146 | /// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative |
147 | /// matrix ops. |
148 | struct WmmaInsertOpToSPIRVLowering final |
149 | : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> { |
150 | using OpConversionPattern::OpConversionPattern; |
151 | |
152 | LogicalResult |
153 | matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor, |
154 | ConversionPatternRewriter &rewriter) const override { |
155 | Value value = adaptor.getValue(); |
156 | Value matrix = adaptor.getMatrix(); |
157 | auto coopType = getTypeConverter()->convertType(matrix.getType()); |
158 | if (!coopType) |
159 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
160 | |
161 | SmallVector<int32_t> intValues; |
162 | for (Value val : op.getIndices()) { |
163 | if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) { |
164 | intValues.push_back(static_cast<int32_t>(constOp.value())); |
165 | } else { |
166 | return rewriter.notifyMatchFailure(op, "indices must be constants"); |
167 | } |
168 | } |
169 | |
170 | rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( |
171 | op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues)); |
172 | return success(); |
173 | } |
174 | }; |
175 | |
176 | /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |
177 | /// the default case. |
178 | struct WmmaElementwiseOpToSPIRVDefaultLowering final |
179 | : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |
180 | using OpConversionPattern::OpConversionPattern; |
181 | |
182 | LogicalResult |
183 | matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor, |
184 | ConversionPatternRewriter &rewriter) const override { |
185 | // All operands should be of cooperative matrix types. |
186 | if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) { |
187 | return rewriter.notifyMatchFailure(op, |
188 | "not all operands are coop matrices"); |
189 | } |
190 | |
191 | auto coopType = getTypeConverter()->convertType(op.getType()); |
192 | if (!coopType) |
193 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
194 | |
195 | return success( |
196 | createElementwiseOp(rewriter, op, coopType, adaptor.getOperands())); |
197 | } |
198 | }; |
199 | |
200 | /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for |
201 | /// matrix times scalar case. |
202 | struct WmmaElementwiseOpToSPIRVScalarMulLowering final |
203 | : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { |
204 | using OpConversionPattern::OpConversionPattern; |
205 | |
206 | LogicalResult |
207 | matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor, |
208 | ConversionPatternRewriter &rewriter) const override { |
209 | if (adaptor.getOperands().size() != 2) |
210 | return failure(); |
211 | |
212 | // All operands should be of cooperative matrix types. |
213 | if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) { |
214 | return rewriter.notifyMatchFailure(op, |
215 | "not all operands are coop matrices"); |
216 | } |
217 | |
218 | if (op.getOpType() != gpu::MMAElementwiseOp::MULF) |
219 | return failure(); |
220 | |
221 | // Use the original operands to check whether one of the operands is a splat |
222 | // scalar value. |
223 | Value lhs = op.getOperands().front(); |
224 | Value rhs = op.getOperands().back(); |
225 | Value splat = nullptr; |
226 | Value matrix = nullptr; |
227 | if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |
228 | splat = adaptor.getOperands().front(); |
229 | matrix = adaptor.getOperands().back(); |
230 | } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) { |
231 | matrix = adaptor.getOperands().front(); |
232 | splat = adaptor.getOperands().back(); |
233 | } |
234 | if (!splat || !matrix) |
235 | return rewriter.notifyMatchFailure(op, "no splat operand"); |
236 | |
237 | // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops. |
238 | Value scalar; |
239 | auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>(); |
240 | if (!cc) { |
241 | return rewriter.notifyMatchFailure(op, |
242 | "splat is not a composite construct"); |
243 | } |
244 | |
245 | scalar = llvm::getSingleElement(cc.getConstituents()); |
246 | |
247 | auto coopType = getTypeConverter()->convertType(op.getType()); |
248 | if (!coopType) |
249 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
250 | rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>( |
251 | op, coopType, ValueRange{matrix, scalar}); |
252 | return success(); |
253 | } |
254 | }; |
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace khr {
namespace {

264 | /// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV |
265 | /// dialect. |
266 | struct WmmaLoadOpToSPIRVLowering final |
267 | : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> { |
268 | using OpConversionPattern::OpConversionPattern; |
269 | |
270 | LogicalResult |
271 | matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor, |
272 | ConversionPatternRewriter &rewriter) const override { |
273 | const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
274 | Location loc = op->getLoc(); |
275 | |
276 | auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType()); |
277 | MemRefType memrefType = op.getSrcMemref().getType(); |
278 | Value bufferPtr = |
279 | spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getSrcMemref(), |
280 | indices: adaptor.getIndices(), loc, builder&: rewriter); |
281 | |
282 | auto coopType = |
283 | typeConverter.convertType<spirv::CooperativeMatrixType>(retType); |
284 | if (!coopType) |
285 | return rewriter.notifyMatchFailure(op, "type conversion failed"); |
286 | |
287 | int64_t stride = op.getLeadDimension().getSExtValue(); |
288 | IntegerType i32Type = rewriter.getI32Type(); |
289 | auto strideValue = rewriter.create<spirv::ConstantOp>( |
290 | loc, i32Type, IntegerAttr::get(i32Type, stride)); |
291 | |
292 | bool isColMajor = op.getTranspose().value_or(false); |
293 | auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor |
294 | : spirv::CooperativeMatrixLayoutKHR::RowMajor; |
295 | |
296 | rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>( |
297 | op, coopType, bufferPtr, strideValue, layout); |
298 | return success(); |
299 | } |
300 | }; |
301 | |
302 | /// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV |
303 | /// dialect. |
304 | struct WmmaStoreOpToSPIRVLowering final |
305 | : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> { |
306 | using OpConversionPattern::OpConversionPattern; |
307 | |
308 | LogicalResult |
309 | matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor, |
310 | ConversionPatternRewriter &rewriter) const override { |
311 | const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
312 | Location loc = op->getLoc(); |
313 | |
314 | auto memrefType = cast<MemRefType>(op.getDstMemref().getType()); |
315 | Value bufferPtr = |
316 | spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getDstMemref(), |
317 | indices: adaptor.getIndices(), loc, builder&: rewriter); |
318 | |
319 | int64_t stride = op.getLeadDimension().getSExtValue(); |
320 | IntegerType i32Type = rewriter.getI32Type(); |
321 | auto strideValue = rewriter.create<spirv::ConstantOp>( |
322 | loc, i32Type, IntegerAttr::get(i32Type, stride)); |
323 | |
324 | bool isColMajor = op.getTranspose().value_or(false); |
325 | auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor |
326 | : spirv::CooperativeMatrixLayoutKHR::RowMajor; |
327 | |
328 | rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>( |
329 | op, bufferPtr, adaptor.getSrc(), strideValue, layout); |
330 | return success(); |
331 | } |
332 | }; |
333 | |
334 | /// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV |
335 | /// dialect. |
336 | struct WmmaMmaOpToSPIRVLowering final |
337 | : OpConversionPattern<gpu::SubgroupMmaComputeOp> { |
338 | using OpConversionPattern::OpConversionPattern; |
339 | |
340 | LogicalResult |
341 | matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp, |
342 | OpAdaptor adaptor, |
343 | ConversionPatternRewriter &rewriter) const override { |
344 | rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>( |
345 | subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(), |
346 | adaptor.getOpC()); |
347 | return success(); |
348 | } |
349 | }; |

} // namespace
} // namespace khr
} // namespace mlir

355 | void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( |
356 | const SPIRVTypeConverter &converter, RewritePatternSet &patterns) { |
357 | using namespace mlir; |
358 | MLIRContext *context = patterns.getContext(); |
359 | patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering, |
360 | khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering, |
361 | WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering, |
362 | WmmaElementwiseOpToSPIRVDefaultLowering>(arg: converter, args&: context); |
363 | // Give the following patterns higher benefit to prevail over the default one. |
364 | patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(arg: converter, args&: context, |
365 | /*benefit=*/args: 2); |
366 | } |
367 | |
368 | void mlir::populateMMAToSPIRVCoopMatrixTypeConversion( |
369 | mlir::SPIRVTypeConverter &typeConverter) { |
370 | typeConverter.addConversion(callback: [](gpu::MMAMatrixType type) { |
371 | ArrayRef<int64_t> retTypeShape = type.getShape(); |
372 | Type elementType = type.getElementType(); |
373 | auto use = |
374 | llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand()) |
375 | .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA) |
376 | .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB) |
377 | .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc); |
378 | |
379 | return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0], |
380 | retTypeShape[1], |
381 | spirv::Scope::Subgroup, use); |
382 | }); |
383 | } |