//===- JointMatrixOps.cpp - MLIR SPIR-V Intel Joint Matrix Ops -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the Intel Joint Matrix operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15namespace mlir {
16//===----------------------------------------------------------------------===//
17// spirv.INTEL.JointMatrixLoad
18//===----------------------------------------------------------------------===//
19
20static LogicalResult
21verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
22 Type pointeeType = llvm::cast<spirv::PointerType>(Val&: pointer).getPointeeType();
23 if (!llvm::isa<spirv::ScalarType>(Val: pointeeType) &&
24 !llvm::isa<VectorType>(Val: pointeeType))
25 return op->emitError(
26 message: "Pointer must point to a scalar or vector type but provided ")
27 << pointeeType;
28 spirv::StorageClass storage =
29 llvm::cast<spirv::PointerType>(pointer).getStorageClass();
30 if (storage != spirv::StorageClass::Workgroup &&
31 storage != spirv::StorageClass::CrossWorkgroup &&
32 storage != spirv::StorageClass::UniformConstant &&
33 storage != spirv::StorageClass::Generic)
34 return op->emitError(message: "Pointer storage class must be Workgroup or "
35 "CrossWorkgroup but provided ")
36 << stringifyStorageClass(storage);
37 return success();
38}
39
40LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
41 return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
42 getResult().getType());
43}
44
45//===----------------------------------------------------------------------===//
46// spirv.INTEL.JointMatrixStore
47//===----------------------------------------------------------------------===//
48
49LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
50 return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
51 getObject().getType());
52}
53
54//===----------------------------------------------------------------------===//
55// spirv.INTEL.JointMatrixMad
56//===----------------------------------------------------------------------===//
57
58static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
59 if (op.getC().getType() != op.getResult().getType())
60 return op.emitOpError("result and third operand must have the same type");
61 auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType());
62 auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType());
63 auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType());
64 auto typeR =
65 llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType());
66 if (typeA.getRows() != typeR.getRows() ||
67 typeA.getColumns() != typeB.getRows() ||
68 typeB.getColumns() != typeR.getColumns())
69 return op.emitOpError("matrix size must match");
70 if (typeR.getScope() != typeA.getScope() ||
71 typeR.getScope() != typeB.getScope() ||
72 typeR.getScope() != typeC.getScope())
73 return op.emitOpError("matrix scope must match");
74 if (typeA.getElementType() != typeB.getElementType() ||
75 typeR.getElementType() != typeC.getElementType())
76 return op.emitOpError("matrix element type must match");
77 return success();
78}
79
80LogicalResult spirv::INTELJointMatrixMadOp::verify() {
81 return verifyJointMatrixMad(*this);
82}
83
84} // namespace mlir
85

source code of mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp