| 1 | //===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Defines the Cooperative Matrix operations in the SPIR-V dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "SPIRVParsingUtils.h" |
| 14 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| 15 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| 16 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| 17 | #include "llvm/ADT/STLExtras.h" |
| 18 | #include <cstdint> |
| 19 | |
| 20 | using namespace mlir::spirv::AttrNames; |
| 21 | |
| 22 | namespace mlir::spirv { |
| 23 | |
| 24 | static LogicalResult |
| 25 | verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, |
| 26 | spirv::MemoryAccessAttr memoryOperand) { |
| 27 | auto pointerType = cast<PointerType>(Val&: pointer); |
| 28 | Type pointeeType = pointerType.getPointeeType(); |
| 29 | if (!isa<ScalarType, VectorType>(Val: pointeeType)) { |
| 30 | return op->emitOpError( |
| 31 | message: "Pointer must point to a scalar or vector type but provided " ) |
| 32 | << pointeeType; |
| 33 | } |
| 34 | |
| 35 | if (memoryOperand) { |
| 36 | spirv::MemoryAccess operandSet = memoryOperand.getValue(); |
| 37 | |
| 38 | if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) && |
| 39 | spirv::bitEnumContainsAll(operandSet, |
| 40 | spirv::MemoryAccess::MakePointerAvailable)) { |
| 41 | return op->emitOpError( |
| 42 | message: "not compatible with memory operand 'MakePointerAvailable'" ); |
| 43 | } |
| 44 | |
| 45 | if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) && |
| 46 | spirv::bitEnumContainsAll(operandSet, |
| 47 | spirv::MemoryAccess::MakePointerVisible)) { |
| 48 | return op->emitOpError( |
| 49 | message: "not compatible with memory operand 'MakePointerVisible'" ); |
| 50 | } |
| 51 | |
| 52 | // The 'Aligned' memory operand requires an alignment literal to follow, |
| 53 | // which needs to be implemented on the level of op parsing and |
| 54 | // (de-)serialization. |
| 55 | // TODO: Consider adding support for this attribute value. |
| 56 | if (spirv::bitEnumContainsAll(memoryOperand.getValue(), |
| 57 | spirv::MemoryAccess::Aligned)) { |
| 58 | return op->emitOpError(message: "has unhandled memory operand 'Aligned'" ); |
| 59 | } |
| 60 | } |
| 61 | |
| 62 | // TODO: Verify the memory object behind the pointer: |
| 63 | // > If the Shader capability was declared, Pointer must point into an array |
| 64 | // > and any ArrayStride decoration on Pointer is ignored. |
| 65 | |
| 66 | return success(); |
| 67 | } |
| 68 | |
| 69 | //===----------------------------------------------------------------------===// |
| 70 | // spirv.KHR.CooperativeMatrixLoad |
| 71 | //===----------------------------------------------------------------------===// |
| 72 | |
| 73 | LogicalResult KHRCooperativeMatrixLoadOp::verify() { |
| 74 | return verifyCoopMatrixAccess(*this, getPointer().getType(), |
| 75 | getResult().getType(), getMemoryOperandAttr()); |
| 76 | } |
| 77 | |
| 78 | //===----------------------------------------------------------------------===// |
| 79 | // spirv.KHR.CooperativeMatrixStore |
| 80 | //===----------------------------------------------------------------------===// |
| 81 | |
| 82 | LogicalResult KHRCooperativeMatrixStoreOp::verify() { |
| 83 | return verifyCoopMatrixAccess(*this, getPointer().getType(), |
| 84 | getObject().getType(), getMemoryOperandAttr()); |
| 85 | } |
| 86 | |
| 87 | //===----------------------------------------------------------------------===// |
| 88 | // spirv.KHR.CooperativeMatrixMulAdd |
| 89 | //===----------------------------------------------------------------------===// |
| 90 | |
| 91 | LogicalResult KHRCooperativeMatrixMulAddOp::verify() { |
| 92 | auto typeA = cast<spirv::CooperativeMatrixType>(getA().getType()); |
| 93 | auto typeB = cast<spirv::CooperativeMatrixType>(getB().getType()); |
| 94 | auto typeC = cast<spirv::CooperativeMatrixType>(getC().getType()); |
| 95 | |
| 96 | // Check element types. ODS enforces that `type(c) == type(result)`, so no |
| 97 | // need to check it here. |
| 98 | |
| 99 | // Check the 'use' part of the type against the operands and the result. |
| 100 | if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA) |
| 101 | return emitOpError("operand #0 must be of use 'MatrixA'" ); |
| 102 | if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB) |
| 103 | return emitOpError("operand #1 must be of use 'MatrixB'" ); |
| 104 | if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc) |
| 105 | return emitOpError("operand #2 must be of use 'MatrixAcc'" ); |
| 106 | |
| 107 | // Check the 'scope' part of the type. |
| 108 | if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()})) |
| 109 | return emitOpError("matrix scope mismatch" ); |
| 110 | |
| 111 | // Check dimension sizes. We expect 'MxK * KxN + MxN -> MxN'. |
| 112 | if (typeA.getRows() != typeC.getRows()) |
| 113 | return emitOpError("matrix size mismatch on dimension 'M'" ); |
| 114 | if (typeB.getColumns() != typeC.getColumns()) |
| 115 | return emitOpError("matrix size mismatch on dimension 'N'" ); |
| 116 | if (typeA.getColumns() != typeB.getRows()) |
| 117 | return emitOpError("matrix size mismatch on dimension 'K'" ); |
| 118 | |
| 119 | // The spec does not restrict the element types: |
| 120 | // > A, B, C, and Result Type need not necessarily have the same component |
| 121 | // > type, this is defined by the client API. |
| 122 | |
| 123 | // Check that if Cooperative Matrix Operands are provided, the element type |
| 124 | // is integer. |
| 125 | if (getMatrixOperands()) { |
| 126 | Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(), |
| 127 | typeC.getElementType()}; |
| 128 | if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) { |
| 129 | return emitOpError("Matrix Operands require all matrix element types to " |
| 130 | "be Integer Types" ); |
| 131 | } |
| 132 | } |
| 133 | |
| 134 | // Any further requirements need to be checked against VCE. |
| 135 | return success(); |
| 136 | } |
| 137 | |
| 138 | } // namespace mlir::spirv |
| 139 | |