1//===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef TYPE_DETAIL_H_
10#define TYPE_DETAIL_H_
11
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/TypeSupport.h"
15#include "mlir/IR/Types.h"
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/Hashing.h"
18#include "llvm/ADT/bit.h"
19
20namespace mlir {
21namespace quant {
22namespace detail {
23
24struct QuantizedTypeStorage : public mlir::TypeStorage {
25 QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType,
26 int64_t storageTypeMin, int64_t storageTypeMax)
27 : flags(flags), storageType(storageType), expressedType(expressedType),
28 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
29
30 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
31 unsigned flags;
32
33 // Integral type for the storage point representation.
34 Type storageType;
35
36 // Floating point type that the quantized type approximates.
37 Type expressedType;
38
39 // The minimum value storageType can take.
40 int64_t storageTypeMin;
41
42 // The maximum value storageType can take.
43 int64_t storageTypeMax;
44};
45
46struct AnyQuantizedTypeStorage : public QuantizedTypeStorage {
47 struct KeyTy {
48 KeyTy(unsigned flags, Type storageType, Type expressedType,
49 int64_t storageTypeMin, int64_t storageTypeMax)
50 : flags(flags), storageType(storageType), expressedType(expressedType),
51 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
52 unsigned flags;
53 Type storageType;
54 Type expressedType;
55 int64_t storageTypeMin;
56 int64_t storageTypeMax;
57
58 // Check for equality of two structures that share KeyTy data members
59 // (by name).
60 template <typename T, typename U>
61 static bool genericIsEqual(const T &lhs, const U &rhs) {
62 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
63 lhs.expressedType == rhs.expressedType &&
64 lhs.storageTypeMin == rhs.storageTypeMin &&
65 lhs.storageTypeMax == rhs.storageTypeMax;
66 }
67
68 bool operator==(const KeyTy &other) const {
69 return genericIsEqual(lhs: *this, rhs: other);
70 }
71
72 unsigned getHashValue() const {
73 return llvm::hash_combine(args: flags, args: storageType, args: expressedType,
74 args: storageTypeMin, args: storageTypeMax);
75 }
76 };
77
78 AnyQuantizedTypeStorage(const KeyTy &key)
79 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
80 key.storageTypeMin, key.storageTypeMax) {}
81
82 bool operator==(const KeyTy &key) const {
83 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
84 }
85
86 /// Construction.
87 static AnyQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
88 const KeyTy &key) {
89 return new (allocator.allocate<AnyQuantizedTypeStorage>())
90 AnyQuantizedTypeStorage(key);
91 }
92
93 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
94};
95
96struct UniformQuantizedTypeStorage : public QuantizedTypeStorage {
97 struct KeyTy {
98 KeyTy(unsigned flags, Type storageType, Type expressedType, double scale,
99 int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
100 : flags(flags), storageType(storageType), expressedType(expressedType),
101 scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin),
102 storageTypeMax(storageTypeMax) {}
103 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
104 unsigned flags;
105
106 // Integral type for the storage point representation.
107 Type storageType;
108
109 // Floating point type that the quantized type approximates.
110 Type expressedType;
111
112 double scale;
113 int64_t zeroPoint;
114 int64_t storageTypeMin;
115 int64_t storageTypeMax;
116
117 // Check for equality of two structures that share KeyTy data members
118 // (by name).
119 template <typename T, typename U>
120 static bool genericIsEqual(const T &lhs, const U &rhs) {
121 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
122 lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale &&
123 lhs.zeroPoint == rhs.zeroPoint &&
124 lhs.storageTypeMin == rhs.storageTypeMin &&
125 lhs.storageTypeMax == rhs.storageTypeMax;
126 }
127
128 bool operator==(const KeyTy &other) const {
129 return genericIsEqual(lhs: *this, rhs: other);
130 }
131
132 unsigned getHashValue() const {
133 int64_t scaleBits = llvm::bit_cast<int64_t>(from: scale);
134 return llvm::hash_combine(args: flags, args: storageType, args: expressedType, args: scaleBits,
135 args: zeroPoint, args: storageTypeMin, args: storageTypeMax);
136 }
137 };
138
139 UniformQuantizedTypeStorage(const KeyTy &key)
140 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
141 key.storageTypeMin, key.storageTypeMax),
142 scale(key.scale), zeroPoint(key.zeroPoint) {}
143
144 bool operator==(const KeyTy &key) const {
145 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
146 }
147
148 /// Construction.
149 static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
150 const KeyTy &key) {
151 return new (allocator.allocate<UniformQuantizedTypeStorage>())
152 UniformQuantizedTypeStorage(key);
153 }
154
155 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
156
157 double scale;
158 int64_t zeroPoint;
159};
160
161struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
162 struct KeyTy {
163 KeyTy(unsigned flags, Type storageType, Type expressedType,
164 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
165 int32_t quantizedDimension, int64_t storageTypeMin,
166 int64_t storageTypeMax)
167 : flags(flags), storageType(storageType), expressedType(expressedType),
168 scales(scales), zeroPoints(zeroPoints),
169 quantizedDimension(quantizedDimension),
170 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
171 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
172 unsigned flags;
173
174 // Integral type for the storage point representation.
175 Type storageType;
176
177 // Floating point type that the quantized type approximates.
178 Type expressedType;
179
180 ArrayRef<double> scales;
181 ArrayRef<int64_t> zeroPoints;
182 int32_t quantizedDimension;
183 int64_t storageTypeMin;
184 int64_t storageTypeMax;
185
186 ArrayRef<double> getScales() const { return scales; }
187
188 ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; }
189
190 // Check for equality of two structures that share KeyTy data members
191 // (by name).
192 template <typename T, typename U>
193 static bool genericIsEqual(const T &lhs, const U &rhs) {
194 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
195 lhs.expressedType == rhs.expressedType &&
196 lhs.getScales() == rhs.getScales() &&
197 lhs.getZeroPoints() == rhs.getZeroPoints() &&
198 lhs.quantizedDimension == rhs.quantizedDimension &&
199 lhs.storageTypeMin == rhs.storageTypeMin &&
200 lhs.storageTypeMax == rhs.storageTypeMax;
201 }
202
203 bool operator==(const KeyTy &other) const {
204 return genericIsEqual(lhs: *this, rhs: other);
205 }
206
207 unsigned getHashValue() const {
208 int64_t *scalesCast = llvm::bit_cast<int64_t *>(from: scales.data());
209 ArrayRef<int64_t> scalesBits(scalesCast, scales.size());
210 return llvm::hash_combine(args: flags, args: storageType, args: expressedType,
211 args: llvm::hash_combine_range(R&: scalesBits),
212 args: llvm::hash_combine_range(R: zeroPoints),
213 args: storageTypeMin, args: storageTypeMax);
214 }
215 };
216
217 // We pass scales and zeroPoints in directly rather than relying on KeyTy
218 // because we have to create new reallocated versions in `construct` below.
219 UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales,
220 ArrayRef<int64_t> zeroPoints)
221 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
222 key.storageTypeMin, key.storageTypeMax),
223 scaleElements(scales.data()), zeroPointElements(zeroPoints.data()),
224 quantParamsSize(scales.size()),
225 quantizedDimension(key.quantizedDimension) {}
226
227 bool operator==(const KeyTy &key) const {
228 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
229 }
230
231 /// Construction.
232 static UniformQuantizedPerAxisTypeStorage *
233 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
234 ArrayRef<double> scales = allocator.copyInto(elements: key.scales);
235 ArrayRef<int64_t> zeroPoints = allocator.copyInto(elements: key.zeroPoints);
236 return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>())
237 UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints);
238 }
239
240 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
241
242 ArrayRef<double> getScales() const {
243 return ArrayRef<double>(scaleElements, quantParamsSize);
244 }
245
246 ArrayRef<int64_t> getZeroPoints() const {
247 return ArrayRef<int64_t>(zeroPointElements, quantParamsSize);
248 }
249
250 const double *scaleElements;
251 const int64_t *zeroPointElements;
252 unsigned quantParamsSize;
253 int32_t quantizedDimension;
254};
255
256struct UniformQuantizedSubChannelTypeStorage : public QuantizedTypeStorage {
257 struct KeyTy {
258 KeyTy(unsigned flags, Type storageType, Type expressedType,
259 DenseElementsAttr scales, DenseElementsAttr zeroPoints,
260 ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
261 int64_t storageTypeMin, int64_t storageTypeMax)
262 : flags(flags), storageType(storageType), expressedType(expressedType),
263 scales(scales), zeroPoints(zeroPoints),
264 quantizedDimensions(quantizedDimensions), blockSizes(blockSizes),
265 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
266 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
267 unsigned flags;
268
269 // Integral type for the storage point representation.
270 Type storageType;
271
272 // Floating point type that the quantized type approximates.
273 Type expressedType;
274
275 DenseElementsAttr scales;
276 DenseElementsAttr zeroPoints;
277 ArrayRef<int32_t> quantizedDimensions;
278 ArrayRef<int64_t> blockSizes;
279 int64_t storageTypeMin;
280 int64_t storageTypeMax;
281
282 DenseElementsAttr getScales() const { return scales; }
283
284 DenseElementsAttr getZeroPoints() const { return zeroPoints; }
285
286 // Check for equality of two structures that share KeyTy data members
287 // (by name).
288 template <typename T, typename U>
289 static bool genericIsEqual(const T &lhs, const U &rhs) {
290 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
291 lhs.expressedType == rhs.expressedType &&
292 lhs.scales == rhs.scales && lhs.zeroPoints == rhs.zeroPoints &&
293 lhs.quantizedDimensions == rhs.quantizedDimensions &&
294 lhs.blockSizes == rhs.blockSizes &&
295 lhs.storageTypeMin == rhs.storageTypeMin &&
296 lhs.storageTypeMax == rhs.storageTypeMax;
297 }
298
299 bool operator==(const KeyTy &other) const {
300 return genericIsEqual(lhs: *this, rhs: other);
301 }
302
303 unsigned getHashValue() const {
304 // Hash the scalar attributes.
305 unsigned hash = llvm::hash_combine(args: flags, args: storageType, args: expressedType,
306 args: storageTypeMin, args: storageTypeMax);
307
308 // Hash the scales.
309 for (auto scaleAttr : scales.getValues<APFloat>()) {
310 hash = llvm::hash_combine(
311 hash, llvm::bit_cast<int64_t>(scaleAttr.convertToDouble()));
312 }
313
314 // Hash the zero points. (Assumed to be integers, adjust if needed).
315 for (auto zeroPointAttr : zeroPoints.getValues<APInt>()) {
316 hash = llvm::hash_combine(hash, zeroPointAttr.getSExtValue());
317 }
318
319 // Hash the quantized dimensions and block sizes.
320 hash = llvm::hash_combine(args: hash,
321 args: llvm::hash_combine_range(R: quantizedDimensions),
322 args: llvm::hash_combine_range(R: blockSizes));
323
324 return hash;
325 }
326 };
327
328 // We pass scales and zeroPoints in directly rather than relying on KeyTy
329 // because we have to create new reallocated versions in `construct` below.
330 UniformQuantizedSubChannelTypeStorage(const KeyTy &key,
331 DenseElementsAttr scales,
332 DenseElementsAttr zeroPoints,
333 ArrayRef<int32_t> quantizedDimensions,
334 ArrayRef<int64_t> blockSizes)
335 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
336 key.storageTypeMin, key.storageTypeMax),
337 scales(scales), zeroPoints(zeroPoints),
338 quantizedDimensions(quantizedDimensions), blockSizes(blockSizes) {}
339
340 bool operator==(const KeyTy &key) const {
341 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
342 }
343
344 /// Construction.
345 static UniformQuantizedSubChannelTypeStorage *
346 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
347 DenseElementsAttr scales = key.scales;
348 DenseElementsAttr zeroPoints = key.zeroPoints;
349 ArrayRef<int32_t> quantizedDimensions =
350 allocator.copyInto(elements: key.quantizedDimensions);
351 ArrayRef<int64_t> blockSizes = allocator.copyInto(elements: key.blockSizes);
352 return new (allocator.allocate<UniformQuantizedSubChannelTypeStorage>())
353 UniformQuantizedSubChannelTypeStorage(key, scales, zeroPoints,
354 quantizedDimensions, blockSizes);
355 }
356
357 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
358
359 DenseElementsAttr getScales() const { return scales; }
360
361 DenseElementsAttr getZeroPoints() const { return zeroPoints; }
362
363 ArrayRef<int32_t> getQuantizedDimensions() const {
364 return quantizedDimensions;
365 }
366
367 ArrayRef<int64_t> getBlockSizes() const { return blockSizes; }
368
369 DenseElementsAttr scales;
370 DenseElementsAttr zeroPoints;
371 ArrayRef<int32_t> quantizedDimensions;
372 ArrayRef<int64_t> blockSizes;
373};
374
375struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage {
376 struct KeyTy {
377 KeyTy(Type expressedType, double min, double max)
378 : expressedType(expressedType), min(min), max(max) {}
379 // Floating point type that the quantized type approximates.
380 Type expressedType;
381
382 double min;
383 double max;
384
385 // Check for equality of two structures that share KeyTy data members
386 // (by name).
387 template <typename T, typename U>
388 static bool genericIsEqual(const T &lhs, const U &rhs) {
389 return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min &&
390 lhs.max == rhs.max;
391 }
392
393 bool operator==(const KeyTy &other) const {
394 return genericIsEqual(lhs: *this, rhs: other);
395 }
396
397 unsigned getHashValue() const {
398 int64_t minBits = llvm::bit_cast<double>(from: min);
399 int64_t maxBits = llvm::bit_cast<double>(from: max);
400 return llvm::hash_combine(args: expressedType, args: minBits, args: maxBits);
401 }
402 };
403
404 CalibratedQuantizedTypeStorage(const KeyTy &key)
405 : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0),
406 min(key.min), max(key.max) {}
407
408 bool operator==(const KeyTy &key) const {
409 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
410 }
411
412 /// Construction.
413 static CalibratedQuantizedTypeStorage *
414 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
415 return new (allocator.allocate<CalibratedQuantizedTypeStorage>())
416 CalibratedQuantizedTypeStorage(key);
417 }
418
419 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
420
421 double min;
422 double max;
423};
424
425} // namespace detail
426} // namespace quant
427} // namespace mlir
428
429#endif // TYPE_DETAIL_H_
430

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Quant/IR/TypeDetail.h