//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIRV Cooperative Matrix ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers.
//===----------------------------------------------------------------------===//

35/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
36/// when the elementwise op directly supports with cooperative matrix type.
37/// Returns false if cannot.
38///
39/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
40static bool createElementwiseOp(ConversionPatternRewriter &builder,
41 gpu::SubgroupMmaElementwiseOp op, Type coopType,
42 ValueRange operands) {
43 assert((isa<spirv::CooperativeMatrixType>(coopType)));
44
45 switch (op.getOpType()) {
46 case gpu::MMAElementwiseOp::ADDF:
47 builder.replaceOpWithNewOp<spirv::FAddOp>(op, args&: coopType, args&: operands);
48 return true;
49 case gpu::MMAElementwiseOp::ADDI:
50 builder.replaceOpWithNewOp<spirv::IAddOp>(op, args&: coopType, args&: operands);
51 return true;
52 case gpu::MMAElementwiseOp::SUBF:
53 builder.replaceOpWithNewOp<spirv::FSubOp>(op, args&: coopType, args&: operands);
54 return true;
55 case gpu::MMAElementwiseOp::SUBI:
56 builder.replaceOpWithNewOp<spirv::ISubOp>(op, args&: coopType, args&: operands);
57 return true;
58 case gpu::MMAElementwiseOp::DIVF:
59 builder.replaceOpWithNewOp<spirv::FDivOp>(op, args&: coopType, args&: operands);
60 return true;
61 case gpu::MMAElementwiseOp::DIVS:
62 builder.replaceOpWithNewOp<spirv::SDivOp>(op, args&: coopType, args&: operands);
63 return true;
64 case gpu::MMAElementwiseOp::DIVU:
65 builder.replaceOpWithNewOp<spirv::UDivOp>(op, args&: coopType, args&: operands);
66 return true;
67 case gpu::MMAElementwiseOp::NEGATEF:
68 builder.replaceOpWithNewOp<spirv::FNegateOp>(op, args&: coopType, args&: operands);
69 return true;
70 case gpu::MMAElementwiseOp::NEGATES:
71 builder.replaceOpWithNewOp<spirv::SNegateOp>(op, args&: coopType, args&: operands);
72 return true;
73 case gpu::MMAElementwiseOp::EXTF:
74 builder.replaceOpWithNewOp<spirv::FConvertOp>(op, args&: coopType, args&: operands);
75 return true;
76 default:
77 break;
78 }
79 return false;
80}
81
82bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
83 assert(!operands.empty());
84 if (!llvm::all_equal(
85 Range: llvm::map_range(C&: operands, F: [](Value v) { return v.getType(); })))
86 return false;
87
88 return isa<spirv::CooperativeMatrixType>(Val: operands.front().getType());
89}

namespace {
92/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
93/// matrix ops.
94struct WmmaConstantOpToSPIRVLowering final
95 : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
96 using OpConversionPattern::OpConversionPattern;
97
98 LogicalResult
99 matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
100 ConversionPatternRewriter &rewriter) const override {
101 Value cst = llvm::getSingleElement(C: adaptor.getOperands());
102 auto coopType = getTypeConverter()->convertType(t: op.getType());
103 if (!coopType)
104 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
105
106 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, args&: coopType, args&: cst);
107 return success();
108 }
109};
110
111/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
112/// matrix ops.
113struct WmmaExtractOpToSPIRVLowering final
114 : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
115 using OpConversionPattern::OpConversionPattern;
116
117 LogicalResult
118 matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
119 ConversionPatternRewriter &rewriter) const override {
120 Value matrix = adaptor.getMatrix();
121 auto coopType =
122 getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
123 t: matrix.getType());
124 if (!coopType)
125 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
126
127 SmallVector<int32_t> intValues;
128 for (Value val : op.getIndices()) {
129 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
130 intValues.push_back(Elt: static_cast<int32_t>(constOp.value()));
131 } else {
132 return rewriter.notifyMatchFailure(arg&: op, msg: "indices must be constants");
133 }
134 }
135
136 Type elementType = coopType.getElementType();
137 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
138 op, args&: elementType, args&: matrix, args: rewriter.getI32ArrayAttr(values: intValues));
139 return success();
140 }
141};
142
143/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
144/// matrix ops.
145struct WmmaInsertOpToSPIRVLowering final
146 : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
147 using OpConversionPattern::OpConversionPattern;
148
149 LogicalResult
150 matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
151 ConversionPatternRewriter &rewriter) const override {
152 Value value = adaptor.getValue();
153 Value matrix = adaptor.getMatrix();
154 auto coopType = getTypeConverter()->convertType(t: matrix.getType());
155 if (!coopType)
156 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
157
158 SmallVector<int32_t> intValues;
159 for (Value val : op.getIndices()) {
160 if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
161 intValues.push_back(Elt: static_cast<int32_t>(constOp.value()));
162 } else {
163 return rewriter.notifyMatchFailure(arg&: op, msg: "indices must be constants");
164 }
165 }
166
167 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
168 op, args&: coopType, args&: value, args&: matrix, args: rewriter.getI32ArrayAttr(values: intValues));
169 return success();
170 }
171};
172
173/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
174/// the default case.
175struct WmmaElementwiseOpToSPIRVDefaultLowering final
176 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
177 using OpConversionPattern::OpConversionPattern;
178
179 LogicalResult
180 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter) const override {
182 // All operands should be of cooperative matrix types.
183 if (!allOperandsHaveSameCoopMatrixType(operands: adaptor.getOperands())) {
184 return rewriter.notifyMatchFailure(arg&: op,
185 msg: "not all operands are coop matrices");
186 }
187
188 auto coopType = getTypeConverter()->convertType(t: op.getType());
189 if (!coopType)
190 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
191
192 return success(
193 IsSuccess: createElementwiseOp(builder&: rewriter, op, coopType, operands: adaptor.getOperands()));
194 }
195};
196
197/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
198/// matrix times scalar case.
199struct WmmaElementwiseOpToSPIRVScalarMulLowering final
200 : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
201 using OpConversionPattern::OpConversionPattern;
202
203 LogicalResult
204 matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const override {
206 if (adaptor.getOperands().size() != 2)
207 return failure();
208
209 // All operands should be of cooperative matrix types.
210 if (!allOperandsHaveSameCoopMatrixType(operands: adaptor.getOperands())) {
211 return rewriter.notifyMatchFailure(arg&: op,
212 msg: "not all operands are coop matrices");
213 }
214
215 if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
216 return failure();
217
218 // Use the original operands to check whether one of the operands is a splat
219 // scalar value.
220 Value lhs = op.getOperands().front();
221 Value rhs = op.getOperands().back();
222 Value splat = nullptr;
223 Value matrix = nullptr;
224 if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
225 splat = adaptor.getOperands().front();
226 matrix = adaptor.getOperands().back();
227 } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
228 matrix = adaptor.getOperands().front();
229 splat = adaptor.getOperands().back();
230 }
231 if (!splat || !matrix)
232 return rewriter.notifyMatchFailure(arg&: op, msg: "no splat operand");
233
234 // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
235 Value scalar;
236 auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
237 if (!cc) {
238 return rewriter.notifyMatchFailure(arg&: op,
239 msg: "splat is not a composite construct");
240 }
241
242 scalar = llvm::getSingleElement(C: cc.getConstituents());
243
244 auto coopType = getTypeConverter()->convertType(t: op.getType());
245 if (!coopType)
246 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
247 rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
248 op, args&: coopType, args: ValueRange{matrix, scalar});
249 return success();
250 }
251};
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace khr {
namespace {

261/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
262/// dialect.
263struct WmmaLoadOpToSPIRVLowering final
264 : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
265 using OpConversionPattern::OpConversionPattern;
266
267 LogicalResult
268 matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
269 ConversionPatternRewriter &rewriter) const override {
270 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
271 Location loc = op->getLoc();
272
273 auto retType = cast<gpu::MMAMatrixType>(Val: op.getRes().getType());
274 MemRefType memrefType = op.getSrcMemref().getType();
275 Value bufferPtr =
276 spirv::getElementPtr(typeConverter, baseType: memrefType, basePtr: adaptor.getSrcMemref(),
277 indices: adaptor.getIndices(), loc, builder&: rewriter);
278
279 auto coopType =
280 typeConverter.convertType<spirv::CooperativeMatrixType>(t: retType);
281 if (!coopType)
282 return rewriter.notifyMatchFailure(arg&: op, msg: "type conversion failed");
283
284 int64_t stride = op.getLeadDimension().getSExtValue();
285 IntegerType i32Type = rewriter.getI32Type();
286 auto strideValue = rewriter.create<spirv::ConstantOp>(
287 location: loc, args&: i32Type, args: IntegerAttr::get(type: i32Type, value: stride));
288
289 bool isColMajor = op.getTranspose().value_or(u: false);
290 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
291 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
292
293 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
294 op, args&: coopType, args&: bufferPtr, args&: strideValue, args&: layout);
295 return success();
296 }
297};
298
299/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
300/// dialect.
301struct WmmaStoreOpToSPIRVLowering final
302 : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
303 using OpConversionPattern::OpConversionPattern;
304
305 LogicalResult
306 matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
307 ConversionPatternRewriter &rewriter) const override {
308 const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
309 Location loc = op->getLoc();
310
311 auto memrefType = cast<MemRefType>(Val: op.getDstMemref().getType());
312 Value bufferPtr =
313 spirv::getElementPtr(typeConverter, baseType: memrefType, basePtr: adaptor.getDstMemref(),
314 indices: adaptor.getIndices(), loc, builder&: rewriter);
315
316 int64_t stride = op.getLeadDimension().getSExtValue();
317 IntegerType i32Type = rewriter.getI32Type();
318 auto strideValue = rewriter.create<spirv::ConstantOp>(
319 location: loc, args&: i32Type, args: IntegerAttr::get(type: i32Type, value: stride));
320
321 bool isColMajor = op.getTranspose().value_or(u: false);
322 auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
323 : spirv::CooperativeMatrixLayoutKHR::RowMajor;
324
325 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
326 op, args&: bufferPtr, args: adaptor.getSrc(), args&: strideValue, args&: layout);
327 return success();
328 }
329};
330
331/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
332/// dialect.
333struct WmmaMmaOpToSPIRVLowering final
334 : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
335 using OpConversionPattern::OpConversionPattern;
336
337 LogicalResult
338 matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
339 OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const override {
341 rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
342 op: subgroupMmaComputeOp, args: adaptor.getOpA(), args: adaptor.getOpB(),
343 args: adaptor.getOpC());
344 return success();
345 }
346};

} // namespace
} // namespace khr
} // namespace mlir

352void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
353 const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
354 using namespace mlir;
355 MLIRContext *context = patterns.getContext();
356 patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
357 khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
358 WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
359 WmmaElementwiseOpToSPIRVDefaultLowering>(arg: converter, args&: context);
360 // Give the following patterns higher benefit to prevail over the default one.
361 patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(arg: converter, args&: context,
362 /*benefit=*/args: 2);
363}
365void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
366 mlir::SPIRVTypeConverter &typeConverter) {
367 typeConverter.addConversion(callback: [](gpu::MMAMatrixType type) {
368 ArrayRef<int64_t> retTypeShape = type.getShape();
369 Type elementType = type.getElementType();
370 auto use =
371 llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
372 .Case(S: "AOp", Value: spirv::CooperativeMatrixUseKHR::MatrixA)
373 .Case(S: "BOp", Value: spirv::CooperativeMatrixUseKHR::MatrixB)
374 .Default(Value: spirv::CooperativeMatrixUseKHR::MatrixAcc);
375
376 return spirv::CooperativeMatrixType::get(elementType, rows: retTypeShape[0],
377 columns: retTypeShape[1],
378 scope: spirv::Scope::Subgroup, use);
379 });
380}

// Source: mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp