1//===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the Cooperative Matrix operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVParsingUtils.h"
14#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "llvm/ADT/STLExtras.h"
18
19using namespace mlir::spirv::AttrNames;
20
21namespace mlir::spirv {
22
23static LogicalResult
24verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
25 spirv::MemoryAccessAttr memoryOperand,
26 IntegerAttr alignment) {
27 auto pointerType = cast<PointerType>(Val&: pointer);
28 Type pointeeType = pointerType.getPointeeType();
29 if (!isa<ScalarType, VectorType>(Val: pointeeType)) {
30 return op->emitOpError(
31 message: "Pointer must point to a scalar or vector type but provided ")
32 << pointeeType;
33 }
34
35 if (memoryOperand) {
36 spirv::MemoryAccess operandSet = memoryOperand.getValue();
37
38 if (isa<spirv::KHRCooperativeMatrixLoadOp>(Val: op) &&
39 spirv::bitEnumContainsAll(bits: operandSet,
40 bit: spirv::MemoryAccess::MakePointerAvailable)) {
41 return op->emitOpError(
42 message: "not compatible with memory operand 'MakePointerAvailable'");
43 }
44
45 if (isa<spirv::KHRCooperativeMatrixStoreOp>(Val: op) &&
46 spirv::bitEnumContainsAll(bits: operandSet,
47 bit: spirv::MemoryAccess::MakePointerVisible)) {
48 return op->emitOpError(
49 message: "not compatible with memory operand 'MakePointerVisible'");
50 }
51
52 // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See
53 // #145485.
54
55 if (spirv::bitEnumContainsAll(bits: operandSet, bit: spirv::MemoryAccess::Aligned) &&
56 !alignment) {
57 return op->emitOpError(message: "missing value for the 'Aligned' memory operand");
58 }
59
60 if (!spirv::bitEnumContainsAll(bits: operandSet, bit: spirv::MemoryAccess::Aligned) &&
61 alignment) {
62 return op->emitOpError(
63 message: "found alignment attribute for non-'Aligned' memory operand");
64 }
65 }
66
67 // TODO: Verify the memory object behind the pointer:
68 // > If the Shader capability was declared, Pointer must point into an array
69 // > and any ArrayStride decoration on Pointer is ignored.
70
71 return success();
72}
73
74//===----------------------------------------------------------------------===//
75// spirv.KHR.CooperativeMatrixLoad
76//===----------------------------------------------------------------------===//
77
78LogicalResult KHRCooperativeMatrixLoadOp::verify() {
79 return verifyCoopMatrixAccess(op: *this, pointer: getPointer().getType(),
80 coopMatrix: getResult().getType(), memoryOperand: getMemoryOperandAttr(),
81 alignment: getAlignmentAttr());
82}
83
84//===----------------------------------------------------------------------===//
85// spirv.KHR.CooperativeMatrixStore
86//===----------------------------------------------------------------------===//
87
88LogicalResult KHRCooperativeMatrixStoreOp::verify() {
89 return verifyCoopMatrixAccess(op: *this, pointer: getPointer().getType(),
90 coopMatrix: getObject().getType(), memoryOperand: getMemoryOperandAttr(),
91 alignment: getAlignmentAttr());
92}
93
94//===----------------------------------------------------------------------===//
95// spirv.KHR.CooperativeMatrixMulAdd
96//===----------------------------------------------------------------------===//
97
98LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
99 auto typeA = cast<spirv::CooperativeMatrixType>(Val: getA().getType());
100 auto typeB = cast<spirv::CooperativeMatrixType>(Val: getB().getType());
101 auto typeC = cast<spirv::CooperativeMatrixType>(Val: getC().getType());
102
103 // Check element types. ODS enforces that `type(c) == type(result)`, so no
104 // need to check it here.
105
106 // Check the 'use' part of the type against the operands and the result.
107 if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
108 return emitOpError(message: "operand #0 must be of use 'MatrixA'");
109 if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB)
110 return emitOpError(message: "operand #1 must be of use 'MatrixB'");
111 if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
112 return emitOpError(message: "operand #2 must be of use 'MatrixAcc'");
113
114 // Check the 'scope' part of the type.
115 if (!llvm::all_equal(Values: {typeA.getScope(), typeB.getScope(), typeC.getScope()}))
116 return emitOpError(message: "matrix scope mismatch");
117
118 // Check dimension sizes. We expect 'MxK * KxN + MxN -> MxN'.
119 if (typeA.getRows() != typeC.getRows())
120 return emitOpError(message: "matrix size mismatch on dimension 'M'");
121 if (typeB.getColumns() != typeC.getColumns())
122 return emitOpError(message: "matrix size mismatch on dimension 'N'");
123 if (typeA.getColumns() != typeB.getRows())
124 return emitOpError(message: "matrix size mismatch on dimension 'K'");
125
126 // The spec does not restrict the element types:
127 // > A, B, C, and Result Type need not necessarily have the same component
128 // > type, this is defined by the client API.
129
130 // Check that if Cooperative Matrix Operands are provided, the element type
131 // is integer.
132 if (getMatrixOperands()) {
133 Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
134 typeC.getElementType()};
135 if (!llvm::all_of(Range&: elementTypes, P: llvm::IsaPred<IntegerType>)) {
136 return emitOpError(message: "Matrix Operands require all matrix element types to "
137 "be Integer Types");
138 }
139 }
140
141 // Any further requirements need to be checked against VCE.
142 return success();
143}
144
145} // namespace mlir::spirv
146

source code of mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp