| 1 | //===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Defines the Cooperative Matrix operations in the SPIR-V dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "SPIRVParsingUtils.h" |
| 14 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| 15 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| 16 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 17 | #include "llvm/ADT/STLExtras.h" |
| 18 | |
| 19 | using namespace mlir::spirv::AttrNames; |
| 20 | |
| 21 | namespace mlir::spirv { |
| 22 | |
| 23 | static LogicalResult |
| 24 | verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, |
| 25 | spirv::MemoryAccessAttr memoryOperand, |
| 26 | IntegerAttr alignment) { |
| 27 | auto pointerType = cast<PointerType>(Val&: pointer); |
| 28 | Type pointeeType = pointerType.getPointeeType(); |
| 29 | if (!isa<ScalarType, VectorType>(Val: pointeeType)) { |
| 30 | return op->emitOpError( |
| 31 | message: "Pointer must point to a scalar or vector type but provided " ) |
| 32 | << pointeeType; |
| 33 | } |
| 34 | |
| 35 | if (memoryOperand) { |
| 36 | spirv::MemoryAccess operandSet = memoryOperand.getValue(); |
| 37 | |
| 38 | if (isa<spirv::KHRCooperativeMatrixLoadOp>(Val: op) && |
| 39 | spirv::bitEnumContainsAll(bits: operandSet, |
| 40 | bit: spirv::MemoryAccess::MakePointerAvailable)) { |
| 41 | return op->emitOpError( |
| 42 | message: "not compatible with memory operand 'MakePointerAvailable'" ); |
| 43 | } |
| 44 | |
| 45 | if (isa<spirv::KHRCooperativeMatrixStoreOp>(Val: op) && |
| 46 | spirv::bitEnumContainsAll(bits: operandSet, |
| 47 | bit: spirv::MemoryAccess::MakePointerVisible)) { |
| 48 | return op->emitOpError( |
| 49 | message: "not compatible with memory operand 'MakePointerVisible'" ); |
| 50 | } |
| 51 | |
| 52 | // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See |
| 53 | // #145485. |
| 54 | |
| 55 | if (spirv::bitEnumContainsAll(bits: operandSet, bit: spirv::MemoryAccess::Aligned) && |
| 56 | !alignment) { |
| 57 | return op->emitOpError(message: "missing value for the 'Aligned' memory operand" ); |
| 58 | } |
| 59 | |
| 60 | if (!spirv::bitEnumContainsAll(bits: operandSet, bit: spirv::MemoryAccess::Aligned) && |
| 61 | alignment) { |
| 62 | return op->emitOpError( |
| 63 | message: "found alignment attribute for non-'Aligned' memory operand" ); |
| 64 | } |
| 65 | } |
| 66 | |
| 67 | // TODO: Verify the memory object behind the pointer: |
| 68 | // > If the Shader capability was declared, Pointer must point into an array |
| 69 | // > and any ArrayStride decoration on Pointer is ignored. |
| 70 | |
| 71 | return success(); |
| 72 | } |
| 73 | |
| 74 | //===----------------------------------------------------------------------===// |
| 75 | // spirv.KHR.CooperativeMatrixLoad |
| 76 | //===----------------------------------------------------------------------===// |
| 77 | |
| 78 | LogicalResult KHRCooperativeMatrixLoadOp::verify() { |
| 79 | return verifyCoopMatrixAccess(op: *this, pointer: getPointer().getType(), |
| 80 | coopMatrix: getResult().getType(), memoryOperand: getMemoryOperandAttr(), |
| 81 | alignment: getAlignmentAttr()); |
| 82 | } |
| 83 | |
| 84 | //===----------------------------------------------------------------------===// |
| 85 | // spirv.KHR.CooperativeMatrixStore |
| 86 | //===----------------------------------------------------------------------===// |
| 87 | |
| 88 | LogicalResult KHRCooperativeMatrixStoreOp::verify() { |
| 89 | return verifyCoopMatrixAccess(op: *this, pointer: getPointer().getType(), |
| 90 | coopMatrix: getObject().getType(), memoryOperand: getMemoryOperandAttr(), |
| 91 | alignment: getAlignmentAttr()); |
| 92 | } |
| 93 | |
| 94 | //===----------------------------------------------------------------------===// |
| 95 | // spirv.KHR.CooperativeMatrixMulAdd |
| 96 | //===----------------------------------------------------------------------===// |
| 97 | |
| 98 | LogicalResult KHRCooperativeMatrixMulAddOp::verify() { |
| 99 | auto typeA = cast<spirv::CooperativeMatrixType>(Val: getA().getType()); |
| 100 | auto typeB = cast<spirv::CooperativeMatrixType>(Val: getB().getType()); |
| 101 | auto typeC = cast<spirv::CooperativeMatrixType>(Val: getC().getType()); |
| 102 | |
| 103 | // Check element types. ODS enforces that `type(c) == type(result)`, so no |
| 104 | // need to check it here. |
| 105 | |
| 106 | // Check the 'use' part of the type against the operands and the result. |
| 107 | if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA) |
| 108 | return emitOpError(message: "operand #0 must be of use 'MatrixA'" ); |
| 109 | if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB) |
| 110 | return emitOpError(message: "operand #1 must be of use 'MatrixB'" ); |
| 111 | if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc) |
| 112 | return emitOpError(message: "operand #2 must be of use 'MatrixAcc'" ); |
| 113 | |
| 114 | // Check the 'scope' part of the type. |
| 115 | if (!llvm::all_equal(Values: {typeA.getScope(), typeB.getScope(), typeC.getScope()})) |
| 116 | return emitOpError(message: "matrix scope mismatch" ); |
| 117 | |
| 118 | // Check dimension sizes. We expect 'MxK * KxN + MxN -> MxN'. |
| 119 | if (typeA.getRows() != typeC.getRows()) |
| 120 | return emitOpError(message: "matrix size mismatch on dimension 'M'" ); |
| 121 | if (typeB.getColumns() != typeC.getColumns()) |
| 122 | return emitOpError(message: "matrix size mismatch on dimension 'N'" ); |
| 123 | if (typeA.getColumns() != typeB.getRows()) |
| 124 | return emitOpError(message: "matrix size mismatch on dimension 'K'" ); |
| 125 | |
| 126 | // The spec does not restrict the element types: |
| 127 | // > A, B, C, and Result Type need not necessarily have the same component |
| 128 | // > type, this is defined by the client API. |
| 129 | |
| 130 | // Check that if Cooperative Matrix Operands are provided, the element type |
| 131 | // is integer. |
| 132 | if (getMatrixOperands()) { |
| 133 | Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(), |
| 134 | typeC.getElementType()}; |
| 135 | if (!llvm::all_of(Range&: elementTypes, P: llvm::IsaPred<IntegerType>)) { |
| 136 | return emitOpError(message: "Matrix Operands require all matrix element types to " |
| 137 | "be Integer Types" ); |
| 138 | } |
| 139 | } |
| 140 | |
| 141 | // Any further requirements need to be checked against VCE. |
| 142 | return success(); |
| 143 | } |
| 144 | |
| 145 | } // namespace mlir::spirv |
| 146 | |