1//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the Dot Product operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "llvm/Support/FormatVariadic.h"
19
20using namespace mlir::spirv::AttrNames;
21
22namespace mlir::spirv {
23
24//===----------------------------------------------------------------------===//
25// Dot Product ops
26//===----------------------------------------------------------------------===//
27
28static std::optional<spirv::Version> getDotProductMinVersion() {
29 return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
30}
31
32static std::optional<spirv::Version> getDotProductMaxVersion() {
33 return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
34}
35
36SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
37 if (isa<BFloat16Type>(Val: getType())) {
38 static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
39 return {extension};
40 }
41
42 return {};
43}
44
45SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
46 if (isa<BFloat16Type>(Val: getType())) {
47 static const auto capability = spirv::Capability::BFloat16DotProductKHR;
48 return {capability};
49 }
50
51 return {};
52}
53
54std::optional<spirv::Version> DotOp::getMinVersion() {
55 return getDotProductMinVersion();
56}
57
58std::optional<spirv::Version> DotOp::getMaxVersion() {
59 return getDotProductMaxVersion();
60}
61
62//===----------------------------------------------------------------------===//
63// Integer Dot Product ops
64//===----------------------------------------------------------------------===//
65
66template <typename IntegerDotProductOpTy>
67static LogicalResult verifyIntegerDotProduct(Operation *op) {
68 assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
69 "Not an integer dot product op?");
70 assert(op->getNumResults() == 1 && "Expected a single result");
71
72 // ODS enforces that vector 1 and vector 2, and result and the accumulator
73 // have the same types.
74 Type factorTy = op->getOperand(idx: 0).getType();
75 StringAttr packedVectorFormatAttrName =
76 IntegerDotProductOpTy::getFormatAttrName(op->getName());
77 if (auto intTy = llvm::dyn_cast<IntegerType>(Val&: factorTy)) {
78 auto packedVectorFormat =
79 llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
80 Val: op->getAttr(name: packedVectorFormatAttrName));
81 if (!packedVectorFormat)
82 return op->emitOpError(message: "requires Packed Vector Format attribute for "
83 "integer vector operands");
84
85 assert(packedVectorFormat.getValue() ==
86 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
87 "Unknown Packed Vector Format");
88 if (intTy.getWidth() != 32)
89 return op->emitOpError(
90 message: llvm::formatv(Fmt: "with specified Packed Vector Format ({0}) requires "
91 "integer vector operands to be 32-bits wide",
92 Vals: packedVectorFormat.getValue()));
93 } else {
94 if (op->hasAttr(name: packedVectorFormatAttrName))
95 return op->emitOpError(message: llvm::formatv(
96 Fmt: "with invalid format attribute for vector operands of type '{0}'",
97 Vals&: factorTy));
98 }
99
100 Type resultTy = op->getResultTypes().front();
101 unsigned factorBitWidth = getBitWidth(type: factorTy);
102 unsigned resultBitWidth = getBitWidth(type: resultTy);
103 if (factorBitWidth > resultBitWidth)
104 return op->emitOpError(
105 message: llvm::formatv(Fmt: "result type has insufficient bit-width ({0} bits) "
106 "for the specified vector operand type ({1} bits)",
107 Vals&: resultBitWidth, Vals&: factorBitWidth));
108
109 return success();
110}
111
112static SmallVector<ArrayRef<spirv::Extension>, 1>
113getIntegerDotProductExtensions() {
114 // Requires the SPV_KHR_integer_dot_product extension, specified either
115 // explicitly or implied by target env's SPIR-V version >= 1.6.
116 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
117 return {extension};
118}
119
120template <typename IntegerDotProductOpTy>
121static SmallVector<ArrayRef<spirv::Capability>, 1>
122getIntegerDotProductCapabilities(Operation *op) {
123 // Requires the the DotProduct capability and capabilities that depend on
124 // exact op types.
125 static const auto dotProductCap = spirv::Capability::DotProduct;
126 static const auto dotProductInput4x8BitPackedCap =
127 spirv::Capability::DotProductInput4x8BitPacked;
128 static const auto dotProductInput4x8BitCap =
129 spirv::Capability::DotProductInput4x8Bit;
130 static const auto dotProductInputAllCap =
131 spirv::Capability::DotProductInputAll;
132
133 SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
134
135 Type factorTy = op->getOperand(idx: 0).getType();
136 StringAttr packedVectorFormatAttrName =
137 IntegerDotProductOpTy::getFormatAttrName(op->getName());
138 if (auto intTy = llvm::dyn_cast<IntegerType>(Val&: factorTy)) {
139 auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
140 Val: op->getAttr(name: packedVectorFormatAttrName));
141 if (formatAttr.getValue() ==
142 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
143 capabilities.push_back(Elt: dotProductInput4x8BitPackedCap);
144
145 return capabilities;
146 }
147
148 auto vecTy = llvm::cast<VectorType>(Val&: factorTy);
149 if (vecTy.getElementTypeBitWidth() == 8) {
150 capabilities.push_back(Elt: dotProductInput4x8BitCap);
151 return capabilities;
152 }
153
154 capabilities.push_back(Elt: dotProductInputAllCap);
155 return capabilities;
156}
157
158#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
159 LogicalResult OpName::verify() { \
160 return verifyIntegerDotProduct<OpName>(*this); \
161 } \
162 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
163 return getIntegerDotProductExtensions(); \
164 } \
165 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
166 return getIntegerDotProductCapabilities<OpName>(*this); \
167 } \
168 std::optional<spirv::Version> OpName::getMinVersion() { \
169 return getDotProductMinVersion(); \
170 } \
171 std::optional<spirv::Version> OpName::getMaxVersion() { \
172 return getDotProductMaxVersion(); \
173 }
174
175SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
176SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp)
177SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp)
178SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp)
179SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp)
180SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp)
181
182#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
183
184} // namespace mlir::spirv
185

source code of mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp