1 | //===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef TYPE_DETAIL_H_ |
10 | #define TYPE_DETAIL_H_ |
11 | |
12 | #include "mlir/IR/BuiltinTypes.h" |
13 | #include "mlir/IR/TypeSupport.h" |
14 | #include "mlir/IR/Types.h" |
15 | #include "llvm/ADT/DenseMap.h" |
16 | #include "llvm/ADT/Hashing.h" |
17 | #include "llvm/ADT/bit.h" |
18 | |
19 | namespace mlir { |
20 | namespace quant { |
21 | namespace detail { |
22 | |
23 | struct QuantizedTypeStorage : public mlir::TypeStorage { |
24 | QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType, |
25 | int64_t storageTypeMin, int64_t storageTypeMax) |
26 | : flags(flags), storageType(storageType), expressedType(expressedType), |
27 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
28 | |
29 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
30 | unsigned flags; |
31 | |
32 | // Integral type for the storage point representation. |
33 | Type storageType; |
34 | |
35 | // Floating point type that the quantized type approximates. |
36 | Type expressedType; |
37 | |
38 | // The minimum value storageType can take. |
39 | int64_t storageTypeMin; |
40 | |
41 | // The maximum value storageType can take. |
42 | int64_t storageTypeMax; |
43 | }; |
44 | |
45 | struct AnyQuantizedTypeStorage : public QuantizedTypeStorage { |
46 | struct KeyTy { |
47 | KeyTy(unsigned flags, Type storageType, Type expressedType, |
48 | int64_t storageTypeMin, int64_t storageTypeMax) |
49 | : flags(flags), storageType(storageType), expressedType(expressedType), |
50 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
51 | unsigned flags; |
52 | Type storageType; |
53 | Type expressedType; |
54 | int64_t storageTypeMin; |
55 | int64_t storageTypeMax; |
56 | |
57 | // Check for equality of two structures that share KeyTy data members |
58 | // (by name). |
59 | template <typename T, typename U> |
60 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
61 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
62 | lhs.expressedType == rhs.expressedType && |
63 | lhs.storageTypeMin == rhs.storageTypeMin && |
64 | lhs.storageTypeMax == rhs.storageTypeMax; |
65 | } |
66 | |
67 | bool operator==(const KeyTy &other) const { |
68 | return genericIsEqual(lhs: *this, rhs: other); |
69 | } |
70 | |
71 | unsigned getHashValue() const { |
72 | return llvm::hash_combine(args: flags, args: storageType, args: expressedType, |
73 | args: storageTypeMin, args: storageTypeMax); |
74 | } |
75 | }; |
76 | |
77 | AnyQuantizedTypeStorage(const KeyTy &key) |
78 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
79 | key.storageTypeMin, key.storageTypeMax) {} |
80 | |
81 | bool operator==(const KeyTy &key) const { |
82 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
83 | } |
84 | |
85 | /// Construction. |
86 | static AnyQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, |
87 | const KeyTy &key) { |
88 | return new (allocator.allocate<AnyQuantizedTypeStorage>()) |
89 | AnyQuantizedTypeStorage(key); |
90 | } |
91 | |
92 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
93 | }; |
94 | |
95 | struct UniformQuantizedTypeStorage : public QuantizedTypeStorage { |
96 | struct KeyTy { |
97 | KeyTy(unsigned flags, Type storageType, Type expressedType, double scale, |
98 | int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) |
99 | : flags(flags), storageType(storageType), expressedType(expressedType), |
100 | scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin), |
101 | storageTypeMax(storageTypeMax) {} |
102 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
103 | unsigned flags; |
104 | |
105 | // Integral type for the storage point representation. |
106 | Type storageType; |
107 | |
108 | // Floating point type that the quantized type approximates. |
109 | Type expressedType; |
110 | |
111 | double scale; |
112 | int64_t zeroPoint; |
113 | int64_t storageTypeMin; |
114 | int64_t storageTypeMax; |
115 | |
116 | // Check for equality of two structures that share KeyTy data members |
117 | // (by name). |
118 | template <typename T, typename U> |
119 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
120 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
121 | lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale && |
122 | lhs.zeroPoint == rhs.zeroPoint && |
123 | lhs.storageTypeMin == rhs.storageTypeMin && |
124 | lhs.storageTypeMax == rhs.storageTypeMax; |
125 | } |
126 | |
127 | bool operator==(const KeyTy &other) const { |
128 | return genericIsEqual(lhs: *this, rhs: other); |
129 | } |
130 | |
131 | unsigned getHashValue() const { |
132 | int64_t scaleBits = llvm::bit_cast<int64_t>(from: scale); |
133 | return llvm::hash_combine(args: flags, args: storageType, args: expressedType, args: scaleBits, |
134 | args: zeroPoint, args: storageTypeMin, args: storageTypeMax); |
135 | } |
136 | }; |
137 | |
138 | UniformQuantizedTypeStorage(const KeyTy &key) |
139 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
140 | key.storageTypeMin, key.storageTypeMax), |
141 | scale(key.scale), zeroPoint(key.zeroPoint) {} |
142 | |
143 | bool operator==(const KeyTy &key) const { |
144 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
145 | } |
146 | |
147 | /// Construction. |
148 | static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, |
149 | const KeyTy &key) { |
150 | return new (allocator.allocate<UniformQuantizedTypeStorage>()) |
151 | UniformQuantizedTypeStorage(key); |
152 | } |
153 | |
154 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
155 | |
156 | double scale; |
157 | int64_t zeroPoint; |
158 | }; |
159 | |
160 | struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { |
161 | struct KeyTy { |
162 | KeyTy(unsigned flags, Type storageType, Type expressedType, |
163 | ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, |
164 | int32_t quantizedDimension, int64_t storageTypeMin, |
165 | int64_t storageTypeMax) |
166 | : flags(flags), storageType(storageType), expressedType(expressedType), |
167 | scales(scales), zeroPoints(zeroPoints), |
168 | quantizedDimension(quantizedDimension), |
169 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
170 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
171 | unsigned flags; |
172 | |
173 | // Integral type for the storage point representation. |
174 | Type storageType; |
175 | |
176 | // Floating point type that the quantized type approximates. |
177 | Type expressedType; |
178 | |
179 | ArrayRef<double> scales; |
180 | ArrayRef<int64_t> zeroPoints; |
181 | int32_t quantizedDimension; |
182 | int64_t storageTypeMin; |
183 | int64_t storageTypeMax; |
184 | |
185 | ArrayRef<double> getScales() const { return scales; } |
186 | |
187 | ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; } |
188 | |
189 | // Check for equality of two structures that share KeyTy data members |
190 | // (by name). |
191 | template <typename T, typename U> |
192 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
193 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
194 | lhs.expressedType == rhs.expressedType && |
195 | lhs.getScales() == rhs.getScales() && |
196 | lhs.getZeroPoints() == rhs.getZeroPoints() && |
197 | lhs.quantizedDimension == rhs.quantizedDimension && |
198 | lhs.storageTypeMin == rhs.storageTypeMin && |
199 | lhs.storageTypeMax == rhs.storageTypeMax; |
200 | } |
201 | |
202 | bool operator==(const KeyTy &other) const { |
203 | return genericIsEqual(lhs: *this, rhs: other); |
204 | } |
205 | |
206 | unsigned getHashValue() const { |
207 | int64_t *scalesCast = llvm::bit_cast<int64_t *>(from: scales.data()); |
208 | ArrayRef<int64_t> scalesBits(scalesCast, scales.size()); |
209 | return llvm::hash_combine( |
210 | args: flags, args: storageType, args: expressedType, |
211 | args: llvm::hash_combine_range(first: scalesBits.begin(), last: scalesBits.end()), |
212 | args: llvm::hash_combine_range(first: zeroPoints.begin(), last: zeroPoints.end()), |
213 | args: storageTypeMin, args: storageTypeMax); |
214 | } |
215 | }; |
216 | |
217 | // We pass scales and zeroPoints in directly rather than relying on KeyTy |
218 | // because we have to create new reallocated versions in `construct` below. |
219 | UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales, |
220 | ArrayRef<int64_t> zeroPoints) |
221 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
222 | key.storageTypeMin, key.storageTypeMax), |
223 | scaleElements(scales.data()), zeroPointElements(zeroPoints.data()), |
224 | quantParamsSize(scales.size()), |
225 | quantizedDimension(key.quantizedDimension) {} |
226 | |
227 | bool operator==(const KeyTy &key) const { |
228 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
229 | } |
230 | |
231 | /// Construction. |
232 | static UniformQuantizedPerAxisTypeStorage * |
233 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
234 | ArrayRef<double> scales = allocator.copyInto(elements: key.scales); |
235 | ArrayRef<int64_t> zeroPoints = allocator.copyInto(elements: key.zeroPoints); |
236 | return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>()) |
237 | UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints); |
238 | } |
239 | |
240 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
241 | |
242 | ArrayRef<double> getScales() const { |
243 | return ArrayRef<double>(scaleElements, quantParamsSize); |
244 | } |
245 | |
246 | ArrayRef<int64_t> getZeroPoints() const { |
247 | return ArrayRef<int64_t>(zeroPointElements, quantParamsSize); |
248 | } |
249 | |
250 | const double *scaleElements; |
251 | const int64_t *zeroPointElements; |
252 | unsigned quantParamsSize; |
253 | int32_t quantizedDimension; |
254 | }; |
255 | |
256 | struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage { |
257 | struct KeyTy { |
258 | KeyTy(Type expressedType, double min, double max) |
259 | : expressedType(expressedType), min(min), max(max) {} |
260 | // Floating point type that the quantized type approximates. |
261 | Type expressedType; |
262 | |
263 | double min; |
264 | double max; |
265 | |
266 | // Check for equality of two structures that share KeyTy data members |
267 | // (by name). |
268 | template <typename T, typename U> |
269 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
270 | return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min && |
271 | lhs.max == rhs.max; |
272 | } |
273 | |
274 | bool operator==(const KeyTy &other) const { |
275 | return genericIsEqual(lhs: *this, rhs: other); |
276 | } |
277 | |
278 | unsigned getHashValue() const { |
279 | int64_t minBits = llvm::bit_cast<double>(from: min); |
280 | int64_t maxBits = llvm::bit_cast<double>(from: max); |
281 | return llvm::hash_combine(args: expressedType, args: minBits, args: maxBits); |
282 | } |
283 | }; |
284 | |
285 | CalibratedQuantizedTypeStorage(const KeyTy &key) |
286 | : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0), |
287 | min(key.min), max(key.max) {} |
288 | |
289 | bool operator==(const KeyTy &key) const { |
290 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
291 | } |
292 | |
293 | /// Construction. |
294 | static CalibratedQuantizedTypeStorage * |
295 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
296 | return new (allocator.allocate<CalibratedQuantizedTypeStorage>()) |
297 | CalibratedQuantizedTypeStorage(key); |
298 | } |
299 | |
300 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
301 | |
302 | double min; |
303 | double max; |
304 | }; |
305 | |
306 | } // namespace detail |
307 | } // namespace quant |
308 | } // namespace mlir |
309 | |
310 | #endif // TYPE_DETAIL_H_ |
311 | |