1 | //===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef TYPE_DETAIL_H_ |
10 | #define TYPE_DETAIL_H_ |
11 | |
12 | #include "mlir/IR/BuiltinAttributes.h" |
13 | #include "mlir/IR/BuiltinTypes.h" |
14 | #include "mlir/IR/TypeSupport.h" |
15 | #include "mlir/IR/Types.h" |
16 | #include "llvm/ADT/DenseMap.h" |
17 | #include "llvm/ADT/Hashing.h" |
18 | #include "llvm/ADT/bit.h" |
19 | |
20 | namespace mlir { |
21 | namespace quant { |
22 | namespace detail { |
23 | |
24 | struct QuantizedTypeStorage : public mlir::TypeStorage { |
25 | QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType, |
26 | int64_t storageTypeMin, int64_t storageTypeMax) |
27 | : flags(flags), storageType(storageType), expressedType(expressedType), |
28 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
29 | |
30 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
31 | unsigned flags; |
32 | |
33 | // Integral type for the storage point representation. |
34 | Type storageType; |
35 | |
36 | // Floating point type that the quantized type approximates. |
37 | Type expressedType; |
38 | |
39 | // The minimum value storageType can take. |
40 | int64_t storageTypeMin; |
41 | |
42 | // The maximum value storageType can take. |
43 | int64_t storageTypeMax; |
44 | }; |
45 | |
46 | struct AnyQuantizedTypeStorage : public QuantizedTypeStorage { |
47 | struct KeyTy { |
48 | KeyTy(unsigned flags, Type storageType, Type expressedType, |
49 | int64_t storageTypeMin, int64_t storageTypeMax) |
50 | : flags(flags), storageType(storageType), expressedType(expressedType), |
51 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
52 | unsigned flags; |
53 | Type storageType; |
54 | Type expressedType; |
55 | int64_t storageTypeMin; |
56 | int64_t storageTypeMax; |
57 | |
58 | // Check for equality of two structures that share KeyTy data members |
59 | // (by name). |
60 | template <typename T, typename U> |
61 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
62 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
63 | lhs.expressedType == rhs.expressedType && |
64 | lhs.storageTypeMin == rhs.storageTypeMin && |
65 | lhs.storageTypeMax == rhs.storageTypeMax; |
66 | } |
67 | |
68 | bool operator==(const KeyTy &other) const { |
69 | return genericIsEqual(lhs: *this, rhs: other); |
70 | } |
71 | |
72 | unsigned getHashValue() const { |
73 | return llvm::hash_combine(args: flags, args: storageType, args: expressedType, |
74 | args: storageTypeMin, args: storageTypeMax); |
75 | } |
76 | }; |
77 | |
78 | AnyQuantizedTypeStorage(const KeyTy &key) |
79 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
80 | key.storageTypeMin, key.storageTypeMax) {} |
81 | |
82 | bool operator==(const KeyTy &key) const { |
83 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
84 | } |
85 | |
86 | /// Construction. |
87 | static AnyQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, |
88 | const KeyTy &key) { |
89 | return new (allocator.allocate<AnyQuantizedTypeStorage>()) |
90 | AnyQuantizedTypeStorage(key); |
91 | } |
92 | |
93 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
94 | }; |
95 | |
96 | struct UniformQuantizedTypeStorage : public QuantizedTypeStorage { |
97 | struct KeyTy { |
98 | KeyTy(unsigned flags, Type storageType, Type expressedType, double scale, |
99 | int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) |
100 | : flags(flags), storageType(storageType), expressedType(expressedType), |
101 | scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin), |
102 | storageTypeMax(storageTypeMax) {} |
103 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
104 | unsigned flags; |
105 | |
106 | // Integral type for the storage point representation. |
107 | Type storageType; |
108 | |
109 | // Floating point type that the quantized type approximates. |
110 | Type expressedType; |
111 | |
112 | double scale; |
113 | int64_t zeroPoint; |
114 | int64_t storageTypeMin; |
115 | int64_t storageTypeMax; |
116 | |
117 | // Check for equality of two structures that share KeyTy data members |
118 | // (by name). |
119 | template <typename T, typename U> |
120 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
121 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
122 | lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale && |
123 | lhs.zeroPoint == rhs.zeroPoint && |
124 | lhs.storageTypeMin == rhs.storageTypeMin && |
125 | lhs.storageTypeMax == rhs.storageTypeMax; |
126 | } |
127 | |
128 | bool operator==(const KeyTy &other) const { |
129 | return genericIsEqual(lhs: *this, rhs: other); |
130 | } |
131 | |
132 | unsigned getHashValue() const { |
133 | int64_t scaleBits = llvm::bit_cast<int64_t>(from: scale); |
134 | return llvm::hash_combine(args: flags, args: storageType, args: expressedType, args: scaleBits, |
135 | args: zeroPoint, args: storageTypeMin, args: storageTypeMax); |
136 | } |
137 | }; |
138 | |
139 | UniformQuantizedTypeStorage(const KeyTy &key) |
140 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
141 | key.storageTypeMin, key.storageTypeMax), |
142 | scale(key.scale), zeroPoint(key.zeroPoint) {} |
143 | |
144 | bool operator==(const KeyTy &key) const { |
145 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
146 | } |
147 | |
148 | /// Construction. |
149 | static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, |
150 | const KeyTy &key) { |
151 | return new (allocator.allocate<UniformQuantizedTypeStorage>()) |
152 | UniformQuantizedTypeStorage(key); |
153 | } |
154 | |
155 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
156 | |
157 | double scale; |
158 | int64_t zeroPoint; |
159 | }; |
160 | |
161 | struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { |
162 | struct KeyTy { |
163 | KeyTy(unsigned flags, Type storageType, Type expressedType, |
164 | ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, |
165 | int32_t quantizedDimension, int64_t storageTypeMin, |
166 | int64_t storageTypeMax) |
167 | : flags(flags), storageType(storageType), expressedType(expressedType), |
168 | scales(scales), zeroPoints(zeroPoints), |
169 | quantizedDimension(quantizedDimension), |
170 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
171 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
172 | unsigned flags; |
173 | |
174 | // Integral type for the storage point representation. |
175 | Type storageType; |
176 | |
177 | // Floating point type that the quantized type approximates. |
178 | Type expressedType; |
179 | |
180 | ArrayRef<double> scales; |
181 | ArrayRef<int64_t> zeroPoints; |
182 | int32_t quantizedDimension; |
183 | int64_t storageTypeMin; |
184 | int64_t storageTypeMax; |
185 | |
186 | ArrayRef<double> getScales() const { return scales; } |
187 | |
188 | ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; } |
189 | |
190 | // Check for equality of two structures that share KeyTy data members |
191 | // (by name). |
192 | template <typename T, typename U> |
193 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
194 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
195 | lhs.expressedType == rhs.expressedType && |
196 | lhs.getScales() == rhs.getScales() && |
197 | lhs.getZeroPoints() == rhs.getZeroPoints() && |
198 | lhs.quantizedDimension == rhs.quantizedDimension && |
199 | lhs.storageTypeMin == rhs.storageTypeMin && |
200 | lhs.storageTypeMax == rhs.storageTypeMax; |
201 | } |
202 | |
203 | bool operator==(const KeyTy &other) const { |
204 | return genericIsEqual(lhs: *this, rhs: other); |
205 | } |
206 | |
207 | unsigned getHashValue() const { |
208 | int64_t *scalesCast = llvm::bit_cast<int64_t *>(from: scales.data()); |
209 | ArrayRef<int64_t> scalesBits(scalesCast, scales.size()); |
210 | return llvm::hash_combine(args: flags, args: storageType, args: expressedType, |
211 | args: llvm::hash_combine_range(R&: scalesBits), |
212 | args: llvm::hash_combine_range(R: zeroPoints), |
213 | args: storageTypeMin, args: storageTypeMax); |
214 | } |
215 | }; |
216 | |
217 | // We pass scales and zeroPoints in directly rather than relying on KeyTy |
218 | // because we have to create new reallocated versions in `construct` below. |
219 | UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales, |
220 | ArrayRef<int64_t> zeroPoints) |
221 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
222 | key.storageTypeMin, key.storageTypeMax), |
223 | scaleElements(scales.data()), zeroPointElements(zeroPoints.data()), |
224 | quantParamsSize(scales.size()), |
225 | quantizedDimension(key.quantizedDimension) {} |
226 | |
227 | bool operator==(const KeyTy &key) const { |
228 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
229 | } |
230 | |
231 | /// Construction. |
232 | static UniformQuantizedPerAxisTypeStorage * |
233 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
234 | ArrayRef<double> scales = allocator.copyInto(elements: key.scales); |
235 | ArrayRef<int64_t> zeroPoints = allocator.copyInto(elements: key.zeroPoints); |
236 | return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>()) |
237 | UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints); |
238 | } |
239 | |
240 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
241 | |
242 | ArrayRef<double> getScales() const { |
243 | return ArrayRef<double>(scaleElements, quantParamsSize); |
244 | } |
245 | |
246 | ArrayRef<int64_t> getZeroPoints() const { |
247 | return ArrayRef<int64_t>(zeroPointElements, quantParamsSize); |
248 | } |
249 | |
250 | const double *scaleElements; |
251 | const int64_t *zeroPointElements; |
252 | unsigned quantParamsSize; |
253 | int32_t quantizedDimension; |
254 | }; |
255 | |
256 | struct UniformQuantizedSubChannelTypeStorage : public QuantizedTypeStorage { |
257 | struct KeyTy { |
258 | KeyTy(unsigned flags, Type storageType, Type expressedType, |
259 | DenseElementsAttr scales, DenseElementsAttr zeroPoints, |
260 | ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes, |
261 | int64_t storageTypeMin, int64_t storageTypeMax) |
262 | : flags(flags), storageType(storageType), expressedType(expressedType), |
263 | scales(scales), zeroPoints(zeroPoints), |
264 | quantizedDimensions(quantizedDimensions), blockSizes(blockSizes), |
265 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
266 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
267 | unsigned flags; |
268 | |
269 | // Integral type for the storage point representation. |
270 | Type storageType; |
271 | |
272 | // Floating point type that the quantized type approximates. |
273 | Type expressedType; |
274 | |
275 | DenseElementsAttr scales; |
276 | DenseElementsAttr zeroPoints; |
277 | ArrayRef<int32_t> quantizedDimensions; |
278 | ArrayRef<int64_t> blockSizes; |
279 | int64_t storageTypeMin; |
280 | int64_t storageTypeMax; |
281 | |
282 | DenseElementsAttr getScales() const { return scales; } |
283 | |
284 | DenseElementsAttr getZeroPoints() const { return zeroPoints; } |
285 | |
286 | // Check for equality of two structures that share KeyTy data members |
287 | // (by name). |
288 | template <typename T, typename U> |
289 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
290 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
291 | lhs.expressedType == rhs.expressedType && |
292 | lhs.scales == rhs.scales && lhs.zeroPoints == rhs.zeroPoints && |
293 | lhs.quantizedDimensions == rhs.quantizedDimensions && |
294 | lhs.blockSizes == rhs.blockSizes && |
295 | lhs.storageTypeMin == rhs.storageTypeMin && |
296 | lhs.storageTypeMax == rhs.storageTypeMax; |
297 | } |
298 | |
299 | bool operator==(const KeyTy &other) const { |
300 | return genericIsEqual(lhs: *this, rhs: other); |
301 | } |
302 | |
303 | unsigned getHashValue() const { |
304 | // Hash the scalar attributes. |
305 | unsigned hash = llvm::hash_combine(args: flags, args: storageType, args: expressedType, |
306 | args: storageTypeMin, args: storageTypeMax); |
307 | |
308 | // Hash the scales. |
309 | for (auto scaleAttr : scales.getValues<APFloat>()) { |
310 | hash = llvm::hash_combine( |
311 | hash, llvm::bit_cast<int64_t>(scaleAttr.convertToDouble())); |
312 | } |
313 | |
314 | // Hash the zero points. (Assumed to be integers, adjust if needed). |
315 | for (auto zeroPointAttr : zeroPoints.getValues<APInt>()) { |
316 | hash = llvm::hash_combine(hash, zeroPointAttr.getSExtValue()); |
317 | } |
318 | |
319 | // Hash the quantized dimensions and block sizes. |
320 | hash = llvm::hash_combine(args: hash, |
321 | args: llvm::hash_combine_range(R: quantizedDimensions), |
322 | args: llvm::hash_combine_range(R: blockSizes)); |
323 | |
324 | return hash; |
325 | } |
326 | }; |
327 | |
328 | // We pass scales and zeroPoints in directly rather than relying on KeyTy |
329 | // because we have to create new reallocated versions in `construct` below. |
330 | UniformQuantizedSubChannelTypeStorage(const KeyTy &key, |
331 | DenseElementsAttr scales, |
332 | DenseElementsAttr zeroPoints, |
333 | ArrayRef<int32_t> quantizedDimensions, |
334 | ArrayRef<int64_t> blockSizes) |
335 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
336 | key.storageTypeMin, key.storageTypeMax), |
337 | scales(scales), zeroPoints(zeroPoints), |
338 | quantizedDimensions(quantizedDimensions), blockSizes(blockSizes) {} |
339 | |
340 | bool operator==(const KeyTy &key) const { |
341 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
342 | } |
343 | |
344 | /// Construction. |
345 | static UniformQuantizedSubChannelTypeStorage * |
346 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
347 | DenseElementsAttr scales = key.scales; |
348 | DenseElementsAttr zeroPoints = key.zeroPoints; |
349 | ArrayRef<int32_t> quantizedDimensions = |
350 | allocator.copyInto(elements: key.quantizedDimensions); |
351 | ArrayRef<int64_t> blockSizes = allocator.copyInto(elements: key.blockSizes); |
352 | return new (allocator.allocate<UniformQuantizedSubChannelTypeStorage>()) |
353 | UniformQuantizedSubChannelTypeStorage(key, scales, zeroPoints, |
354 | quantizedDimensions, blockSizes); |
355 | } |
356 | |
357 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
358 | |
359 | DenseElementsAttr getScales() const { return scales; } |
360 | |
361 | DenseElementsAttr getZeroPoints() const { return zeroPoints; } |
362 | |
363 | ArrayRef<int32_t> getQuantizedDimensions() const { |
364 | return quantizedDimensions; |
365 | } |
366 | |
367 | ArrayRef<int64_t> getBlockSizes() const { return blockSizes; } |
368 | |
369 | DenseElementsAttr scales; |
370 | DenseElementsAttr zeroPoints; |
371 | ArrayRef<int32_t> quantizedDimensions; |
372 | ArrayRef<int64_t> blockSizes; |
373 | }; |
374 | |
375 | struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage { |
376 | struct KeyTy { |
377 | KeyTy(Type expressedType, double min, double max) |
378 | : expressedType(expressedType), min(min), max(max) {} |
379 | // Floating point type that the quantized type approximates. |
380 | Type expressedType; |
381 | |
382 | double min; |
383 | double max; |
384 | |
385 | // Check for equality of two structures that share KeyTy data members |
386 | // (by name). |
387 | template <typename T, typename U> |
388 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
389 | return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min && |
390 | lhs.max == rhs.max; |
391 | } |
392 | |
393 | bool operator==(const KeyTy &other) const { |
394 | return genericIsEqual(lhs: *this, rhs: other); |
395 | } |
396 | |
397 | unsigned getHashValue() const { |
398 | int64_t minBits = llvm::bit_cast<double>(from: min); |
399 | int64_t maxBits = llvm::bit_cast<double>(from: max); |
400 | return llvm::hash_combine(args: expressedType, args: minBits, args: maxBits); |
401 | } |
402 | }; |
403 | |
404 | CalibratedQuantizedTypeStorage(const KeyTy &key) |
405 | : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0), |
406 | min(key.min), max(key.max) {} |
407 | |
408 | bool operator==(const KeyTy &key) const { |
409 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
410 | } |
411 | |
412 | /// Construction. |
413 | static CalibratedQuantizedTypeStorage * |
414 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
415 | return new (allocator.allocate<CalibratedQuantizedTypeStorage>()) |
416 | CalibratedQuantizedTypeStorage(key); |
417 | } |
418 | |
419 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
420 | |
421 | double min; |
422 | double max; |
423 | }; |
424 | |
425 | } // namespace detail |
426 | } // namespace quant |
427 | } // namespace mlir |
428 | |
429 | #endif // TYPE_DETAIL_H_ |
430 |
Definitions
- QuantizedTypeStorage
- QuantizedTypeStorage
- AnyQuantizedTypeStorage
- KeyTy
- KeyTy
- genericIsEqual
- operator==
- getHashValue
- AnyQuantizedTypeStorage
- operator==
- construct
- hashKey
- UniformQuantizedTypeStorage
- KeyTy
- KeyTy
- genericIsEqual
- operator==
- getHashValue
- UniformQuantizedTypeStorage
- operator==
- construct
- hashKey
- UniformQuantizedPerAxisTypeStorage
- KeyTy
- KeyTy
- getScales
- getZeroPoints
- genericIsEqual
- operator==
- getHashValue
- UniformQuantizedPerAxisTypeStorage
- operator==
- construct
- hashKey
- getScales
- getZeroPoints
- UniformQuantizedSubChannelTypeStorage
- KeyTy
- KeyTy
- getScales
- getZeroPoints
- genericIsEqual
- operator==
- getHashValue
- UniformQuantizedSubChannelTypeStorage
- operator==
- construct
- hashKey
- getScales
- getZeroPoints
- getQuantizedDimensions
- getBlockSizes
- CalibratedQuantizedTypeStorage
- KeyTy
- KeyTy
- genericIsEqual
- operator==
- getHashValue
- CalibratedQuantizedTypeStorage
- operator==
- construct
Improve your Profiling and Debugging skills
Find out more