1//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the Integer Dot Product operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "llvm/Support/FormatVariadic.h"
19
20using namespace mlir::spirv::AttrNames;
21
22namespace mlir::spirv {
23
24//===----------------------------------------------------------------------===//
25// Integer Dot Product ops
26//===----------------------------------------------------------------------===//
27
28template <typename IntegerDotProductOpTy>
29static LogicalResult verifyIntegerDotProduct(Operation *op) {
30 assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
31 "Not an integer dot product op?");
32 assert(op->getNumResults() == 1 && "Expected a single result");
33
34 // ODS enforces that vector 1 and vector 2, and result and the accumulator
35 // have the same types.
36 Type factorTy = op->getOperand(idx: 0).getType();
37 StringAttr packedVectorFormatAttrName =
38 IntegerDotProductOpTy::getFormatAttrName(op->getName());
39 if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
40 auto packedVectorFormat =
41 llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
42 op->getAttr(packedVectorFormatAttrName));
43 if (!packedVectorFormat)
44 return op->emitOpError(message: "requires Packed Vector Format attribute for "
45 "integer vector operands");
46
47 assert(packedVectorFormat.getValue() ==
48 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
49 "Unknown Packed Vector Format");
50 if (intTy.getWidth() != 32)
51 return op->emitOpError(
52 message: llvm::formatv("with specified Packed Vector Format ({0}) requires "
53 "integer vector operands to be 32-bits wide",
54 packedVectorFormat.getValue()));
55 } else {
56 if (op->hasAttr(packedVectorFormatAttrName))
57 return op->emitOpError(message: llvm::formatv(
58 Fmt: "with invalid format attribute for vector operands of type '{0}'",
59 Vals&: factorTy));
60 }
61
62 Type resultTy = op->getResultTypes().front();
63 unsigned factorBitWidth = getBitWidth(type: factorTy);
64 unsigned resultBitWidth = getBitWidth(type: resultTy);
65 if (factorBitWidth > resultBitWidth)
66 return op->emitOpError(
67 message: llvm::formatv(Fmt: "result type has insufficient bit-width ({0} bits) "
68 "for the specified vector operand type ({1} bits)",
69 Vals&: resultBitWidth, Vals&: factorBitWidth));
70
71 return success();
72}
73
74static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
75 return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
76}
77
78static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
79 return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
80}
81
82static SmallVector<ArrayRef<spirv::Extension>, 1>
83getIntegerDotProductExtensions() {
84 // Requires the SPV_KHR_integer_dot_product extension, specified either
85 // explicitly or implied by target env's SPIR-V version >= 1.6.
86 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
87 return {extension};
88}
89
90template <typename IntegerDotProductOpTy>
91static SmallVector<ArrayRef<spirv::Capability>, 1>
92getIntegerDotProductCapabilities(Operation *op) {
93 // Requires the the DotProduct capability and capabilities that depend on
94 // exact op types.
95 static const auto dotProductCap = spirv::Capability::DotProduct;
96 static const auto dotProductInput4x8BitPackedCap =
97 spirv::Capability::DotProductInput4x8BitPacked;
98 static const auto dotProductInput4x8BitCap =
99 spirv::Capability::DotProductInput4x8Bit;
100 static const auto dotProductInputAllCap =
101 spirv::Capability::DotProductInputAll;
102
103 SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
104
105 Type factorTy = op->getOperand(idx: 0).getType();
106 StringAttr packedVectorFormatAttrName =
107 IntegerDotProductOpTy::getFormatAttrName(op->getName());
108 if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
109 auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
110 op->getAttr(packedVectorFormatAttrName));
111 if (formatAttr.getValue() ==
112 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
113 capabilities.push_back(Elt: dotProductInput4x8BitPackedCap);
114
115 return capabilities;
116 }
117
118 auto vecTy = llvm::cast<VectorType>(factorTy);
119 if (vecTy.getElementTypeBitWidth() == 8) {
120 capabilities.push_back(Elt: dotProductInput4x8BitCap);
121 return capabilities;
122 }
123
124 capabilities.push_back(Elt: dotProductInputAllCap);
125 return capabilities;
126}
127
128#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
129 LogicalResult OpName::verify() { \
130 return verifyIntegerDotProduct<OpName>(*this); \
131 } \
132 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
133 return getIntegerDotProductExtensions(); \
134 } \
135 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
136 return getIntegerDotProductCapabilities<OpName>(*this); \
137 } \
138 std::optional<spirv::Version> OpName::getMinVersion() { \
139 return getIntegerDotProductMinVersion(); \
140 } \
141 std::optional<spirv::Version> OpName::getMaxVersion() { \
142 return getIntegerDotProductMaxVersion(); \
143 }
144
145SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
146SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp)
147SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp)
148SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp)
149SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp)
150SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp)
151
152#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
153
154} // namespace mlir::spirv
155

source code of mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp