1//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10// SPIRV Cooperative Matrix ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
15#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
22#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
23#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
24#include "mlir/IR/BuiltinAttributes.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/TypeUtilities.h"
27#include "mlir/IR/ValueRange.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/StringSwitch.h"
30
31#include <cassert>
32
33namespace mlir {
34//===----------------------------------------------------------------------===//
35// Patterns and helpers.
36//===----------------------------------------------------------------------===//
37
38/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
39/// when the elementwise op directly supports with cooperative matrix type.
40/// Returns false if cannot.
41///
42/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
43static bool createElementwiseOp(ConversionPatternRewriter &builder,
44 gpu::SubgroupMmaElementwiseOp op, Type coopType,
45 ValueRange operands) {
46 assert((isa<spirv::CooperativeMatrixType>(coopType)));
47
48 switch (op.getOpType()) {
49 case gpu::MMAElementwiseOp::ADDF:
50 builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
51 return true;
52 case gpu::MMAElementwiseOp::ADDI:
53 builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
54 return true;
55 case gpu::MMAElementwiseOp::SUBF:
56 builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
57 return true;
58 case gpu::MMAElementwiseOp::SUBI:
59 builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
60 return true;
61 case gpu::MMAElementwiseOp::DIVF:
62 builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
63 return true;
64 case gpu::MMAElementwiseOp::DIVS:
65 builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
66 return true;
67 case gpu::MMAElementwiseOp::DIVU:
68 builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
69 return true;
70 case gpu::MMAElementwiseOp::NEGATEF:
71 builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
72 return true;
73 case gpu::MMAElementwiseOp::NEGATES:
74 builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
75 return true;
76 case gpu::MMAElementwiseOp::EXTF:
77 builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
78 return true;
79 default:
80 break;
81 }
82 return false;
83}
84
85bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
86 assert(!operands.empty());
87 if (!llvm::all_equal(
88 Range: llvm::map_range(C&: operands, F: [](Value v) { return v.getType(); })))
89 return false;
90
91 return isa<spirv::CooperativeMatrixType>(Val: operands.front().getType());
92}
93
94namespace {
95/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
96/// matrix ops.
97struct WmmaConstantOpToSPIRVLowering final
98 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
99 using OpConversionPattern::OpConversionPattern;
100
101 LogicalResult
102 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103 ConversionPatternRewriter &rewriter) const override {
104 Value cst = llvm::getSingleElement(adaptor.getOperands());
105 auto coopType = getTypeConverter()->convertType(op.getType());
106 if (!coopType)
107 return rewriter.notifyMatchFailure(op, "type conversion failed");
108
109 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
110 return success();
111 }
112};
113
114/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
115/// matrix ops.
116struct WmmaExtractOpToSPIRVLowering final
117 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
118 using OpConversionPattern::OpConversionPattern;
119
120 LogicalResult
121 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
122 ConversionPatternRewriter &rewriter) const override {
123 Value matrix = adaptor.getMatrix();
124 auto coopType =
125 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
126 matrix.getType());
127 if (!coopType)
128 return rewriter.notifyMatchFailure(op, "type conversion failed");
129
130 SmallVector<int32_t> intValues;
131 for (Value val : op.getIndices()) {
132 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
133 intValues.push_back(static_cast<int32_t>(constOp.value()));
134 } else {
135 return rewriter.notifyMatchFailure(op, "indices must be constants");
136 }
137 }
138
139 Type elementType = coopType.getElementType();
140 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
141 op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
142 return success();
143 }
144};
145
146/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
147/// matrix ops.
148struct WmmaInsertOpToSPIRVLowering final
149 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
150 using OpConversionPattern::OpConversionPattern;
151
152 LogicalResult
153 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter) const override {
155 Value value = adaptor.getValue();
156 Value matrix = adaptor.getMatrix();
157 auto coopType = getTypeConverter()->convertType(matrix.getType());
158 if (!coopType)
159 return rewriter.notifyMatchFailure(op, "type conversion failed");
160
161 SmallVector<int32_t> intValues;
162 for (Value val : op.getIndices()) {
163 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
164 intValues.push_back(static_cast<int32_t>(constOp.value()));
165 } else {
166 return rewriter.notifyMatchFailure(op, "indices must be constants");
167 }
168 }
169
170 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
171 op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
172 return success();
173 }
174};
175
176/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
177/// the default case.
178struct WmmaElementwiseOpToSPIRVDefaultLowering final
179 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
180 using OpConversionPattern::OpConversionPattern;
181
182 LogicalResult
183 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter) const override {
185 // All operands should be of cooperative matrix types.
186 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
187 return rewriter.notifyMatchFailure(op,
188 "not all operands are coop matrices");
189 }
190
191 auto coopType = getTypeConverter()->convertType(op.getType());
192 if (!coopType)
193 return rewriter.notifyMatchFailure(op, "type conversion failed");
194
195 return success(
196 createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
197 }
198};
199
200/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
201/// matrix times scalar case.
202struct WmmaElementwiseOpToSPIRVScalarMulLowering final
203 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
204 using OpConversionPattern::OpConversionPattern;
205
206 LogicalResult
207 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
208 ConversionPatternRewriter &rewriter) const override {
209 if (adaptor.getOperands().size() != 2)
210 return failure();
211
212 // All operands should be of cooperative matrix types.
213 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
214 return rewriter.notifyMatchFailure(op,
215 "not all operands are coop matrices");
216 }
217
218 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
219 return failure();
220
221 // Use the original operands to check whether one of the operands is a splat
222 // scalar value.
223 Value lhs = op.getOperands().front();
224 Value rhs = op.getOperands().back();
225 Value splat = nullptr;
226 Value matrix = nullptr;
227 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
228 splat = adaptor.getOperands().front();
229 matrix = adaptor.getOperands().back();
230 } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
231 matrix = adaptor.getOperands().front();
232 splat = adaptor.getOperands().back();
233 }
234 if (!splat || !matrix)
235 return rewriter.notifyMatchFailure(op, "no splat operand");
236
237 // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
238 Value scalar;
239 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
240 if (!cc) {
241 return rewriter.notifyMatchFailure(op,
242 "splat is not a composite construct");
243 }
244
245 scalar = llvm::getSingleElement(cc.getConstituents());
246
247 auto coopType = getTypeConverter()->convertType(op.getType());
248 if (!coopType)
249 return rewriter.notifyMatchFailure(op, "type conversion failed");
250 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
251 op, coopType, ValueRange{matrix, scalar});
252 return success();
253 }
254};
255} // namespace
256
257//===----------------------------------------------------------------------===//
258// SPV_KHR_cooperative_matrix
259//===----------------------------------------------------------------------===//
260
261namespace khr {
262namespace {
263
264/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
265/// dialect.
266struct WmmaLoadOpToSPIRVLowering final
267 : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
268 using OpConversionPattern::OpConversionPattern;
269
270 LogicalResult
271 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
272 ConversionPatternRewriter &rewriter) const override {
273 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
274 Location loc = op->getLoc();
275
276 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
277 MemRefType memrefType = op.getSrcMemref().getType();
278 Value bufferPtr =
279 spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getSrcMemref(),
280 indices: adaptor.getIndices(), loc, builder&: rewriter);
281
282 auto coopType =
283 typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
284 if (!coopType)
285 return rewriter.notifyMatchFailure(op, "type conversion failed");
286
287 int64_t stride = op.getLeadDimension().getSExtValue();
288 IntegerType i32Type = rewriter.getI32Type();
289 auto strideValue = rewriter.create<spirv::ConstantOp>(
290 loc, i32Type, IntegerAttr::get(i32Type, stride));
291
292 bool isColMajor = op.getTranspose().value_or(false);
293 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
294 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
295
296 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
297 op, coopType, bufferPtr, strideValue, layout);
298 return success();
299 }
300};
301
302/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
303/// dialect.
304struct WmmaStoreOpToSPIRVLowering final
305 : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
306 using OpConversionPattern::OpConversionPattern;
307
308 LogicalResult
309 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
310 ConversionPatternRewriter &rewriter) const override {
311 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
312 Location loc = op->getLoc();
313
314 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
315 Value bufferPtr =
316 spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getDstMemref(),
317 indices: adaptor.getIndices(), loc, builder&: rewriter);
318
319 int64_t stride = op.getLeadDimension().getSExtValue();
320 IntegerType i32Type = rewriter.getI32Type();
321 auto strideValue = rewriter.create<spirv::ConstantOp>(
322 loc, i32Type, IntegerAttr::get(i32Type, stride));
323
324 bool isColMajor = op.getTranspose().value_or(false);
325 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
326 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
327
328 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
329 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
330 return success();
331 }
332};
333
334/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
335/// dialect.
336struct WmmaMmaOpToSPIRVLowering final
337 : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
338 using OpConversionPattern::OpConversionPattern;
339
340 LogicalResult
341 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
342 OpAdaptor adaptor,
343 ConversionPatternRewriter &rewriter) const override {
344 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
345 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
346 adaptor.getOpC());
347 return success();
348 }
349};
350
351} // namespace
352} // namespace khr
353} // namespace mlir
354
355void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
356 const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
357 using namespace mlir;
358 MLIRContext *context = patterns.getContext();
359 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
360 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
361 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
362 WmmaElementwiseOpToSPIRVDefaultLowering>(arg: converter, args&: context);
363 // Give the following patterns higher benefit to prevail over the default one.
364 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(arg: converter, args&: context,
365 /*benefit=*/args: 2);
366}
367
368void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
369 mlir::SPIRVTypeConverter &typeConverter) {
370 typeConverter.addConversion(callback: [](gpu::MMAMatrixType type) {
371 ArrayRef<int64_t> retTypeShape = type.getShape();
372 Type elementType = type.getElementType();
373 auto use =
374 llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
375 .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
376 .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
377 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
378
379 return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
380 retTypeShape[1],
381 spirv::Scope::Subgroup, use);
382 });
383}
384

// Original file: mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp