1//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
10// SPIRV Cooperative Matrix ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
15#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
16#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
22#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
23#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
24#include "mlir/IR/BuiltinAttributes.h"
25#include "mlir/IR/BuiltinTypes.h"
26#include "mlir/IR/TypeUtilities.h"
27#include "mlir/IR/ValueRange.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/StringSwitch.h"
30
31#include <cassert>
32
33namespace mlir {
34//===----------------------------------------------------------------------===//
35// Patterns and helpers.
36//===----------------------------------------------------------------------===//
37
38/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
39/// when the elementwise op directly supports with cooperative matrix type.
40/// Returns false if cannot.
41///
42/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
43static bool createElementwiseOp(ConversionPatternRewriter &builder,
44 gpu::SubgroupMmaElementwiseOp op, Type coopType,
45 ValueRange operands) {
46 assert((isa<spirv::CooperativeMatrixType>(coopType)));
47
48 switch (op.getOpType()) {
49 case gpu::MMAElementwiseOp::ADDF:
50 builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
51 return true;
52 case gpu::MMAElementwiseOp::ADDI:
53 builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
54 return true;
55 case gpu::MMAElementwiseOp::SUBF:
56 builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
57 return true;
58 case gpu::MMAElementwiseOp::SUBI:
59 builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
60 return true;
61 case gpu::MMAElementwiseOp::DIVF:
62 builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
63 return true;
64 case gpu::MMAElementwiseOp::DIVS:
65 builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
66 return true;
67 case gpu::MMAElementwiseOp::DIVU:
68 builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
69 return true;
70 case gpu::MMAElementwiseOp::NEGATEF:
71 builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
72 return true;
73 case gpu::MMAElementwiseOp::NEGATES:
74 builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
75 return true;
76 case gpu::MMAElementwiseOp::EXTF:
77 builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
78 return true;
79 default:
80 break;
81 }
82 return false;
83}
84
85bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
86 assert(!operands.empty());
87 if (!llvm::all_equal(
88 Range: llvm::map_range(C&: operands, F: [](Value v) { return v.getType(); })))
89 return false;
90
91 return isa<spirv::CooperativeMatrixType>(Val: operands.front().getType());
92}
93
94namespace {
95/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
96/// matrix ops.
97struct WmmaConstantOpToSPIRVLowering final
98 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
99 using OpConversionPattern::OpConversionPattern;
100
101 LogicalResult
102 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103 ConversionPatternRewriter &rewriter) const override {
104 assert(adaptor.getOperands().size() == 1);
105 Value cst = adaptor.getOperands().front();
106 auto coopType = getTypeConverter()->convertType(op.getType());
107 if (!coopType)
108 return rewriter.notifyMatchFailure(op, "type conversion failed");
109
110 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
111 return success();
112 }
113};
114
115/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
116/// the default case.
117struct WmmaElementwiseOpToSPIRVDefaultLowering final
118 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
119 using OpConversionPattern::OpConversionPattern;
120
121 LogicalResult
122 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter) const override {
124 // All operands should be of cooperative matrix types.
125 if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
126 return rewriter.notifyMatchFailure(op,
127 "not all operands are coop matrices");
128 }
129
130 auto coopType = getTypeConverter()->convertType(op.getType());
131 if (!coopType)
132 return rewriter.notifyMatchFailure(op, "type conversion failed");
133
134 return success(
135 createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
136 }
137};
138
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// matrix times scalar case.
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only the binary (matrix, matrix) form can be a matrix-times-scalar.
    if (adaptor.getOperands().size() != 2)
      return failure();

    // All operands should be of cooperative matrix types.
    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
      return rewriter.notifyMatchFailure(op,
                                         "not all operands are coop matrices");
    }

    // Only floating-point multiplication maps onto MatrixTimesScalar.
    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
      return failure();

    // Use the original operands to check whether one of the operands is a splat
    // scalar value.
    Value lhs = op.getOperands().front();
    Value rhs = op.getOperands().back();
    Value splat = nullptr;
    Value matrix = nullptr;
    // Whichever side was produced by a constant (splat) MMA matrix op becomes
    // the scalar side; the corresponding *converted* values are taken from the
    // adaptor so they are valid in the target dialect.
    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      splat = adaptor.getOperands().front();
      matrix = adaptor.getOperands().back();
    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
      matrix = adaptor.getOperands().front();
      splat = adaptor.getOperands().back();
    }
    if (!splat || !matrix)
      return rewriter.notifyMatchFailure(op, "no splat operand");

    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
    Value scalar;
    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
    if (!cc) {
      return rewriter.notifyMatchFailure(op,
                                         "splat is not a composite construct");
    }

    // A splat composite construct carries exactly one scalar constituent.
    assert(cc.getConstituents().size() == 1);
    scalar = cc.getConstituents().front();

    auto coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    // matrix * splat(scalar) lowers directly to spirv.MatrixTimesScalar.
    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
        op, coopType, ValueRange{matrix, scalar});
    return success();
  }
};
195} // namespace
196
197//===----------------------------------------------------------------------===//
198// SPV_KHR_cooperative_matrix
199//===----------------------------------------------------------------------===//
200
201namespace khr {
202namespace {
203
204/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
205/// dialect.
206struct WmmaLoadOpToSPIRVLowering final
207 : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
208 using OpConversionPattern::OpConversionPattern;
209
210 LogicalResult
211 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
212 ConversionPatternRewriter &rewriter) const override {
213 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
214 Location loc = op->getLoc();
215
216 auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
217 MemRefType memrefType = op.getSrcMemref().getType();
218 Value bufferPtr =
219 spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getSrcMemref(),
220 indices: adaptor.getIndices(), loc, builder&: rewriter);
221
222 auto coopType =
223 typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
224 if (!coopType)
225 return rewriter.notifyMatchFailure(op, "type conversion failed");
226
227 int64_t stride = op.getLeadDimension().getSExtValue();
228 IntegerType i32Type = rewriter.getI32Type();
229 auto strideValue = rewriter.create<spirv::ConstantOp>(
230 loc, i32Type, IntegerAttr::get(i32Type, stride));
231
232 bool isColMajor = op.getTranspose().value_or(false);
233 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
234 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
235
236 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
237 op, coopType, bufferPtr, strideValue, layout);
238 return success();
239 }
240};
241
242/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
243/// dialect.
244struct WmmaStoreOpToSPIRVLowering final
245 : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
246 using OpConversionPattern::OpConversionPattern;
247
248 LogicalResult
249 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
250 ConversionPatternRewriter &rewriter) const override {
251 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
252 Location loc = op->getLoc();
253
254 auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
255 Value bufferPtr =
256 spirv::getElementPtr(typeConverter: typeConverter, baseType: memrefType, basePtr: adaptor.getDstMemref(),
257 indices: adaptor.getIndices(), loc, builder&: rewriter);
258
259 int64_t stride = op.getLeadDimension().getSExtValue();
260 IntegerType i32Type = rewriter.getI32Type();
261 auto strideValue = rewriter.create<spirv::ConstantOp>(
262 loc, i32Type, IntegerAttr::get(i32Type, stride));
263
264 bool isColMajor = op.getTranspose().value_or(false);
265 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
266 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
267
268 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
269 op, bufferPtr, adaptor.getSrc(), strideValue, layout);
270 return success();
271 }
272};
273
274/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
275/// dialect.
276struct WmmaMmaOpToSPIRVLowering final
277 : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
278 using OpConversionPattern::OpConversionPattern;
279
280 LogicalResult
281 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
282 OpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter) const override {
284 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
285 subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
286 adaptor.getOpC());
287 return success();
288 }
289};
290
291} // namespace
292} // namespace khr
293} // namespace mlir
294
295void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
296 SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
297 using namespace mlir;
298 MLIRContext *context = patterns.getContext();
299 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
300 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
301 WmmaElementwiseOpToSPIRVDefaultLowering>(arg&: converter, args&: context);
302 // Give the following patterns higher benefit to prevail over the default one.
303 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(arg&: converter, args&: context,
304 /*benefit=*/args: 2);
305}
306
307void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
308 mlir::SPIRVTypeConverter &typeConverter) {
309 typeConverter.addConversion(callback: [](gpu::MMAMatrixType type) {
310 ArrayRef<int64_t> retTypeShape = type.getShape();
311 Type elementType = type.getElementType();
312 auto use =
313 llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
314 .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
315 .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
316 .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
317
318 return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
319 retTypeShape[1],
320 spirv::Scope::Subgroup, use);
321 });
322}
323

source code of mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp