1 | //===- JointMatrixOps.cpp - MLIR SPIR-V Intel Joint Matrix Ops -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Defines the Intel Joint Matrix operations in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
14 | |
15 | namespace mlir { |
16 | //===----------------------------------------------------------------------===// |
17 | // spirv.INTEL.JointMatrixLoad |
18 | //===----------------------------------------------------------------------===// |
19 | |
20 | static LogicalResult |
21 | verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { |
22 | Type pointeeType = llvm::cast<spirv::PointerType>(Val&: pointer).getPointeeType(); |
23 | if (!llvm::isa<spirv::ScalarType>(Val: pointeeType) && |
24 | !llvm::isa<VectorType>(Val: pointeeType)) |
25 | return op->emitError( |
26 | message: "Pointer must point to a scalar or vector type but provided " ) |
27 | << pointeeType; |
28 | spirv::StorageClass storage = |
29 | llvm::cast<spirv::PointerType>(pointer).getStorageClass(); |
30 | if (storage != spirv::StorageClass::Workgroup && |
31 | storage != spirv::StorageClass::CrossWorkgroup && |
32 | storage != spirv::StorageClass::UniformConstant && |
33 | storage != spirv::StorageClass::Generic) |
34 | return op->emitError(message: "Pointer storage class must be Workgroup or " |
35 | "CrossWorkgroup but provided " ) |
36 | << stringifyStorageClass(storage); |
37 | return success(); |
38 | } |
39 | |
40 | LogicalResult spirv::INTELJointMatrixLoadOp::verify() { |
41 | return verifyPointerAndJointMatrixType(*this, getPointer().getType(), |
42 | getResult().getType()); |
43 | } |
44 | |
45 | //===----------------------------------------------------------------------===// |
46 | // spirv.INTEL.JointMatrixStore |
47 | //===----------------------------------------------------------------------===// |
48 | |
49 | LogicalResult spirv::INTELJointMatrixStoreOp::verify() { |
50 | return verifyPointerAndJointMatrixType(*this, getPointer().getType(), |
51 | getObject().getType()); |
52 | } |
53 | |
54 | //===----------------------------------------------------------------------===// |
55 | // spirv.INTEL.JointMatrixMad |
56 | //===----------------------------------------------------------------------===// |
57 | |
58 | static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { |
59 | if (op.getC().getType() != op.getResult().getType()) |
60 | return op.emitOpError("result and third operand must have the same type" ); |
61 | auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType()); |
62 | auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType()); |
63 | auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType()); |
64 | auto typeR = |
65 | llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType()); |
66 | if (typeA.getRows() != typeR.getRows() || |
67 | typeA.getColumns() != typeB.getRows() || |
68 | typeB.getColumns() != typeR.getColumns()) |
69 | return op.emitOpError("matrix size must match" ); |
70 | if (typeR.getScope() != typeA.getScope() || |
71 | typeR.getScope() != typeB.getScope() || |
72 | typeR.getScope() != typeC.getScope()) |
73 | return op.emitOpError("matrix scope must match" ); |
74 | if (typeA.getElementType() != typeB.getElementType() || |
75 | typeR.getElementType() != typeC.getElementType()) |
76 | return op.emitOpError("matrix element type must match" ); |
77 | return success(); |
78 | } |
79 | |
80 | LogicalResult spirv::INTELJointMatrixMadOp::verify() { |
81 | return verifyJointMatrixMad(*this); |
82 | } |
83 | |
84 | } // namespace mlir |
85 | |