1 | //===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Defines the Integer Dot Product operations in the SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
14 | |
15 | #include "SPIRVOpUtils.h" |
16 | #include "SPIRVParsingUtils.h" |
17 | |
18 | #include "llvm/Support/FormatVariadic.h" |
19 | |
20 | using namespace mlir::spirv::AttrNames; |
21 | |
22 | namespace mlir::spirv { |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // Integer Dot Product ops |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | template <typename IntegerDotProductOpTy> |
29 | static LogicalResult verifyIntegerDotProduct(Operation *op) { |
30 | assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) && |
31 | "Not an integer dot product op?" ); |
32 | assert(op->getNumResults() == 1 && "Expected a single result" ); |
33 | |
34 | // ODS enforces that vector 1 and vector 2, and result and the accumulator |
35 | // have the same types. |
36 | Type factorTy = op->getOperand(idx: 0).getType(); |
37 | StringAttr packedVectorFormatAttrName = |
38 | IntegerDotProductOpTy::getFormatAttrName(op->getName()); |
39 | if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) { |
40 | auto packedVectorFormat = |
41 | llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>( |
42 | op->getAttr(packedVectorFormatAttrName)); |
43 | if (!packedVectorFormat) |
44 | return op->emitOpError(message: "requires Packed Vector Format attribute for " |
45 | "integer vector operands" ); |
46 | |
47 | assert(packedVectorFormat.getValue() == |
48 | spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && |
49 | "Unknown Packed Vector Format" ); |
50 | if (intTy.getWidth() != 32) |
51 | return op->emitOpError( |
52 | message: llvm::formatv("with specified Packed Vector Format ({0}) requires " |
53 | "integer vector operands to be 32-bits wide" , |
54 | packedVectorFormat.getValue())); |
55 | } else { |
56 | if (op->hasAttr(packedVectorFormatAttrName)) |
57 | return op->emitOpError(message: llvm::formatv( |
58 | Fmt: "with invalid format attribute for vector operands of type '{0}'" , |
59 | Vals&: factorTy)); |
60 | } |
61 | |
62 | Type resultTy = op->getResultTypes().front(); |
63 | unsigned factorBitWidth = getBitWidth(type: factorTy); |
64 | unsigned resultBitWidth = getBitWidth(type: resultTy); |
65 | if (factorBitWidth > resultBitWidth) |
66 | return op->emitOpError( |
67 | message: llvm::formatv(Fmt: "result type has insufficient bit-width ({0} bits) " |
68 | "for the specified vector operand type ({1} bits)" , |
69 | Vals&: resultBitWidth, Vals&: factorBitWidth)); |
70 | |
71 | return success(); |
72 | } |
73 | |
74 | static std::optional<spirv::Version> getIntegerDotProductMinVersion() { |
75 | return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. |
76 | } |
77 | |
78 | static std::optional<spirv::Version> getIntegerDotProductMaxVersion() { |
79 | return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. |
80 | } |
81 | |
82 | static SmallVector<ArrayRef<spirv::Extension>, 1> |
83 | getIntegerDotProductExtensions() { |
84 | // Requires the SPV_KHR_integer_dot_product extension, specified either |
85 | // explicitly or implied by target env's SPIR-V version >= 1.6. |
86 | static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product; |
87 | return {extension}; |
88 | } |
89 | |
90 | template <typename IntegerDotProductOpTy> |
91 | static SmallVector<ArrayRef<spirv::Capability>, 1> |
92 | getIntegerDotProductCapabilities(Operation *op) { |
93 | // Requires the the DotProduct capability and capabilities that depend on |
94 | // exact op types. |
95 | static const auto dotProductCap = spirv::Capability::DotProduct; |
96 | static const auto dotProductInput4x8BitPackedCap = |
97 | spirv::Capability::DotProductInput4x8BitPacked; |
98 | static const auto dotProductInput4x8BitCap = |
99 | spirv::Capability::DotProductInput4x8Bit; |
100 | static const auto dotProductInputAllCap = |
101 | spirv::Capability::DotProductInputAll; |
102 | |
103 | SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap}; |
104 | |
105 | Type factorTy = op->getOperand(idx: 0).getType(); |
106 | StringAttr packedVectorFormatAttrName = |
107 | IntegerDotProductOpTy::getFormatAttrName(op->getName()); |
108 | if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) { |
109 | auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>( |
110 | op->getAttr(packedVectorFormatAttrName)); |
111 | if (formatAttr.getValue() == |
112 | spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) |
113 | capabilities.push_back(Elt: dotProductInput4x8BitPackedCap); |
114 | |
115 | return capabilities; |
116 | } |
117 | |
118 | auto vecTy = llvm::cast<VectorType>(factorTy); |
119 | if (vecTy.getElementTypeBitWidth() == 8) { |
120 | capabilities.push_back(Elt: dotProductInput4x8BitCap); |
121 | return capabilities; |
122 | } |
123 | |
124 | capabilities.push_back(Elt: dotProductInputAllCap); |
125 | return capabilities; |
126 | } |
127 | |
128 | #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \ |
129 | LogicalResult OpName::verify() { \ |
130 | return verifyIntegerDotProduct<OpName>(*this); \ |
131 | } \ |
132 | SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \ |
133 | return getIntegerDotProductExtensions(); \ |
134 | } \ |
135 | SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \ |
136 | return getIntegerDotProductCapabilities<OpName>(*this); \ |
137 | } \ |
138 | std::optional<spirv::Version> OpName::getMinVersion() { \ |
139 | return getIntegerDotProductMinVersion(); \ |
140 | } \ |
141 | std::optional<spirv::Version> OpName::getMaxVersion() { \ |
142 | return getIntegerDotProductMaxVersion(); \ |
143 | } |
144 | |
145 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp) |
146 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp) |
147 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp) |
148 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp) |
149 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp) |
150 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp) |
151 | |
152 | #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP |
153 | |
154 | } // namespace mlir::spirv |
155 | |