1 | //===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Defines the Cooperative Matrix operations in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "SPIRVParsingUtils.h" |
14 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
17 | #include "llvm/ADT/STLExtras.h" |
18 | #include <cstdint> |
19 | |
20 | using namespace mlir::spirv::AttrNames; |
21 | |
22 | namespace mlir::spirv { |
23 | |
24 | static LogicalResult |
25 | verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, |
26 | spirv::MemoryAccessAttr memoryOperand) { |
27 | auto pointerType = cast<PointerType>(Val&: pointer); |
28 | Type pointeeType = pointerType.getPointeeType(); |
29 | if (!isa<ScalarType, VectorType>(Val: pointeeType)) { |
30 | return op->emitOpError( |
31 | message: "Pointer must point to a scalar or vector type but provided " ) |
32 | << pointeeType; |
33 | } |
34 | |
35 | if (memoryOperand) { |
36 | spirv::MemoryAccess operandSet = memoryOperand.getValue(); |
37 | |
38 | if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) && |
39 | spirv::bitEnumContainsAll(operandSet, |
40 | spirv::MemoryAccess::MakePointerAvailable)) { |
41 | return op->emitOpError( |
42 | message: "not compatible with memory operand 'MakePointerAvailable'" ); |
43 | } |
44 | |
45 | if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) && |
46 | spirv::bitEnumContainsAll(operandSet, |
47 | spirv::MemoryAccess::MakePointerVisible)) { |
48 | return op->emitOpError( |
49 | message: "not compatible with memory operand 'MakePointerVisible'" ); |
50 | } |
51 | |
52 | // The 'Aligned' memory operand requires an alignment literal to follow, |
53 | // which needs to be implemented on the level of op parsing and |
54 | // (de-)serialization. |
55 | // TODO: Consider adding support for this attribute value. |
56 | if (spirv::bitEnumContainsAll(memoryOperand.getValue(), |
57 | spirv::MemoryAccess::Aligned)) { |
58 | return op->emitOpError(message: "has unhandled memory operand 'Aligned'" ); |
59 | } |
60 | } |
61 | |
62 | // TODO: Verify the memory object behind the pointer: |
63 | // > If the Shader capability was declared, Pointer must point into an array |
64 | // > and any ArrayStride decoration on Pointer is ignored. |
65 | |
66 | return success(); |
67 | } |
68 | |
69 | //===----------------------------------------------------------------------===// |
70 | // spirv.KHR.CooperativeMatrixLoad |
71 | //===----------------------------------------------------------------------===// |
72 | |
73 | LogicalResult KHRCooperativeMatrixLoadOp::verify() { |
74 | return verifyCoopMatrixAccess(*this, getPointer().getType(), |
75 | getResult().getType(), getMemoryOperandAttr()); |
76 | } |
77 | |
78 | //===----------------------------------------------------------------------===// |
79 | // spirv.KHR.CooperativeMatrixStore |
80 | //===----------------------------------------------------------------------===// |
81 | |
82 | LogicalResult KHRCooperativeMatrixStoreOp::verify() { |
83 | return verifyCoopMatrixAccess(*this, getPointer().getType(), |
84 | getObject().getType(), getMemoryOperandAttr()); |
85 | } |
86 | |
87 | //===----------------------------------------------------------------------===// |
88 | // spirv.KHR.CooperativeMatrixMulAdd |
89 | //===----------------------------------------------------------------------===// |
90 | |
91 | LogicalResult KHRCooperativeMatrixMulAddOp::verify() { |
92 | auto typeA = cast<spirv::CooperativeMatrixType>(getA().getType()); |
93 | auto typeB = cast<spirv::CooperativeMatrixType>(getB().getType()); |
94 | auto typeC = cast<spirv::CooperativeMatrixType>(getC().getType()); |
95 | |
96 | // Check element types. ODS enforces that `type(c) == type(result)`, so no |
97 | // need to check it here. |
98 | |
99 | // Check the 'use' part of the type against the operands and the result. |
100 | if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA) |
101 | return emitOpError("operand #0 must be of use 'MatrixA'" ); |
102 | if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB) |
103 | return emitOpError("operand #1 must be of use 'MatrixB'" ); |
104 | if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc) |
105 | return emitOpError("operand #2 must be of use 'MatrixAcc'" ); |
106 | |
107 | // Check the 'scope' part of the type. |
108 | if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()})) |
109 | return emitOpError("matrix scope mismatch" ); |
110 | |
111 | // Check dimension sizes. We expect 'MxK * KxN + MxN -> MxN'. |
112 | if (typeA.getRows() != typeC.getRows()) |
113 | return emitOpError("matrix size mismatch on dimension 'M'" ); |
114 | if (typeB.getColumns() != typeC.getColumns()) |
115 | return emitOpError("matrix size mismatch on dimension 'N'" ); |
116 | if (typeA.getColumns() != typeB.getRows()) |
117 | return emitOpError("matrix size mismatch on dimension 'K'" ); |
118 | |
119 | // The spec does not restrict the element types: |
120 | // > A, B, C, and Result Type need not necessarily have the same component |
121 | // > type, this is defined by the client API. |
122 | |
123 | // Check that if Cooperative Matrix Operands are provided, the element type |
124 | // is integer. |
125 | if (getMatrixOperands()) { |
126 | Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(), |
127 | typeC.getElementType()}; |
128 | if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) { |
129 | return emitOpError("Matrix Operands require all matrix element types to " |
130 | "be Integer Types" ); |
131 | } |
132 | } |
133 | |
134 | // Any further requirements need to be checked against VCE. |
135 | return success(); |
136 | } |
137 | |
138 | } // namespace mlir::spirv |
139 | |