1//===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef TYPE_DETAIL_H_
10#define TYPE_DETAIL_H_
11
12#include "mlir/IR/BuiltinTypes.h"
13#include "mlir/IR/TypeSupport.h"
14#include "mlir/IR/Types.h"
15#include "llvm/ADT/DenseMap.h"
16#include "llvm/ADT/Hashing.h"
17#include "llvm/ADT/bit.h"
18
19namespace mlir {
20namespace quant {
21namespace detail {
22
23struct QuantizedTypeStorage : public mlir::TypeStorage {
24 QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType,
25 int64_t storageTypeMin, int64_t storageTypeMax)
26 : flags(flags), storageType(storageType), expressedType(expressedType),
27 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
28
29 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
30 unsigned flags;
31
32 // Integral type for the storage point representation.
33 Type storageType;
34
35 // Floating point type that the quantized type approximates.
36 Type expressedType;
37
38 // The minimum value storageType can take.
39 int64_t storageTypeMin;
40
41 // The maximum value storageType can take.
42 int64_t storageTypeMax;
43};
44
45struct AnyQuantizedTypeStorage : public QuantizedTypeStorage {
46 struct KeyTy {
47 KeyTy(unsigned flags, Type storageType, Type expressedType,
48 int64_t storageTypeMin, int64_t storageTypeMax)
49 : flags(flags), storageType(storageType), expressedType(expressedType),
50 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
51 unsigned flags;
52 Type storageType;
53 Type expressedType;
54 int64_t storageTypeMin;
55 int64_t storageTypeMax;
56
57 // Check for equality of two structures that share KeyTy data members
58 // (by name).
59 template <typename T, typename U>
60 static bool genericIsEqual(const T &lhs, const U &rhs) {
61 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
62 lhs.expressedType == rhs.expressedType &&
63 lhs.storageTypeMin == rhs.storageTypeMin &&
64 lhs.storageTypeMax == rhs.storageTypeMax;
65 }
66
67 bool operator==(const KeyTy &other) const {
68 return genericIsEqual(lhs: *this, rhs: other);
69 }
70
71 unsigned getHashValue() const {
72 return llvm::hash_combine(args: flags, args: storageType, args: expressedType,
73 args: storageTypeMin, args: storageTypeMax);
74 }
75 };
76
77 AnyQuantizedTypeStorage(const KeyTy &key)
78 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
79 key.storageTypeMin, key.storageTypeMax) {}
80
81 bool operator==(const KeyTy &key) const {
82 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
83 }
84
85 /// Construction.
86 static AnyQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
87 const KeyTy &key) {
88 return new (allocator.allocate<AnyQuantizedTypeStorage>())
89 AnyQuantizedTypeStorage(key);
90 }
91
92 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
93};
94
95struct UniformQuantizedTypeStorage : public QuantizedTypeStorage {
96 struct KeyTy {
97 KeyTy(unsigned flags, Type storageType, Type expressedType, double scale,
98 int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
99 : flags(flags), storageType(storageType), expressedType(expressedType),
100 scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin),
101 storageTypeMax(storageTypeMax) {}
102 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
103 unsigned flags;
104
105 // Integral type for the storage point representation.
106 Type storageType;
107
108 // Floating point type that the quantized type approximates.
109 Type expressedType;
110
111 double scale;
112 int64_t zeroPoint;
113 int64_t storageTypeMin;
114 int64_t storageTypeMax;
115
116 // Check for equality of two structures that share KeyTy data members
117 // (by name).
118 template <typename T, typename U>
119 static bool genericIsEqual(const T &lhs, const U &rhs) {
120 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
121 lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale &&
122 lhs.zeroPoint == rhs.zeroPoint &&
123 lhs.storageTypeMin == rhs.storageTypeMin &&
124 lhs.storageTypeMax == rhs.storageTypeMax;
125 }
126
127 bool operator==(const KeyTy &other) const {
128 return genericIsEqual(lhs: *this, rhs: other);
129 }
130
131 unsigned getHashValue() const {
132 int64_t scaleBits = llvm::bit_cast<int64_t>(from: scale);
133 return llvm::hash_combine(args: flags, args: storageType, args: expressedType, args: scaleBits,
134 args: zeroPoint, args: storageTypeMin, args: storageTypeMax);
135 }
136 };
137
138 UniformQuantizedTypeStorage(const KeyTy &key)
139 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
140 key.storageTypeMin, key.storageTypeMax),
141 scale(key.scale), zeroPoint(key.zeroPoint) {}
142
143 bool operator==(const KeyTy &key) const {
144 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
145 }
146
147 /// Construction.
148 static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
149 const KeyTy &key) {
150 return new (allocator.allocate<UniformQuantizedTypeStorage>())
151 UniformQuantizedTypeStorage(key);
152 }
153
154 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
155
156 double scale;
157 int64_t zeroPoint;
158};
159
160struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
161 struct KeyTy {
162 KeyTy(unsigned flags, Type storageType, Type expressedType,
163 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
164 int32_t quantizedDimension, int64_t storageTypeMin,
165 int64_t storageTypeMax)
166 : flags(flags), storageType(storageType), expressedType(expressedType),
167 scales(scales), zeroPoints(zeroPoints),
168 quantizedDimension(quantizedDimension),
169 storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
170 /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
171 unsigned flags;
172
173 // Integral type for the storage point representation.
174 Type storageType;
175
176 // Floating point type that the quantized type approximates.
177 Type expressedType;
178
179 ArrayRef<double> scales;
180 ArrayRef<int64_t> zeroPoints;
181 int32_t quantizedDimension;
182 int64_t storageTypeMin;
183 int64_t storageTypeMax;
184
185 ArrayRef<double> getScales() const { return scales; }
186
187 ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; }
188
189 // Check for equality of two structures that share KeyTy data members
190 // (by name).
191 template <typename T, typename U>
192 static bool genericIsEqual(const T &lhs, const U &rhs) {
193 return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
194 lhs.expressedType == rhs.expressedType &&
195 lhs.getScales() == rhs.getScales() &&
196 lhs.getZeroPoints() == rhs.getZeroPoints() &&
197 lhs.quantizedDimension == rhs.quantizedDimension &&
198 lhs.storageTypeMin == rhs.storageTypeMin &&
199 lhs.storageTypeMax == rhs.storageTypeMax;
200 }
201
202 bool operator==(const KeyTy &other) const {
203 return genericIsEqual(lhs: *this, rhs: other);
204 }
205
206 unsigned getHashValue() const {
207 int64_t *scalesCast = llvm::bit_cast<int64_t *>(from: scales.data());
208 ArrayRef<int64_t> scalesBits(scalesCast, scales.size());
209 return llvm::hash_combine(
210 args: flags, args: storageType, args: expressedType,
211 args: llvm::hash_combine_range(first: scalesBits.begin(), last: scalesBits.end()),
212 args: llvm::hash_combine_range(first: zeroPoints.begin(), last: zeroPoints.end()),
213 args: storageTypeMin, args: storageTypeMax);
214 }
215 };
216
217 // We pass scales and zeroPoints in directly rather than relying on KeyTy
218 // because we have to create new reallocated versions in `construct` below.
219 UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales,
220 ArrayRef<int64_t> zeroPoints)
221 : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
222 key.storageTypeMin, key.storageTypeMax),
223 scaleElements(scales.data()), zeroPointElements(zeroPoints.data()),
224 quantParamsSize(scales.size()),
225 quantizedDimension(key.quantizedDimension) {}
226
227 bool operator==(const KeyTy &key) const {
228 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
229 }
230
231 /// Construction.
232 static UniformQuantizedPerAxisTypeStorage *
233 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
234 ArrayRef<double> scales = allocator.copyInto(elements: key.scales);
235 ArrayRef<int64_t> zeroPoints = allocator.copyInto(elements: key.zeroPoints);
236 return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>())
237 UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints);
238 }
239
240 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
241
242 ArrayRef<double> getScales() const {
243 return ArrayRef<double>(scaleElements, quantParamsSize);
244 }
245
246 ArrayRef<int64_t> getZeroPoints() const {
247 return ArrayRef<int64_t>(zeroPointElements, quantParamsSize);
248 }
249
250 const double *scaleElements;
251 const int64_t *zeroPointElements;
252 unsigned quantParamsSize;
253 int32_t quantizedDimension;
254};
255
256struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage {
257 struct KeyTy {
258 KeyTy(Type expressedType, double min, double max)
259 : expressedType(expressedType), min(min), max(max) {}
260 // Floating point type that the quantized type approximates.
261 Type expressedType;
262
263 double min;
264 double max;
265
266 // Check for equality of two structures that share KeyTy data members
267 // (by name).
268 template <typename T, typename U>
269 static bool genericIsEqual(const T &lhs, const U &rhs) {
270 return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min &&
271 lhs.max == rhs.max;
272 }
273
274 bool operator==(const KeyTy &other) const {
275 return genericIsEqual(lhs: *this, rhs: other);
276 }
277
278 unsigned getHashValue() const {
279 int64_t minBits = llvm::bit_cast<double>(from: min);
280 int64_t maxBits = llvm::bit_cast<double>(from: max);
281 return llvm::hash_combine(args: expressedType, args: minBits, args: maxBits);
282 }
283 };
284
285 CalibratedQuantizedTypeStorage(const KeyTy &key)
286 : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0),
287 min(key.min), max(key.max) {}
288
289 bool operator==(const KeyTy &key) const {
290 return KeyTy::genericIsEqual(lhs: *this, rhs: key);
291 }
292
293 /// Construction.
294 static CalibratedQuantizedTypeStorage *
295 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
296 return new (allocator.allocate<CalibratedQuantizedTypeStorage>())
297 CalibratedQuantizedTypeStorage(key);
298 }
299
300 static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
301
302 double min;
303 double max;
304};
305
306} // namespace detail
307} // namespace quant
308} // namespace mlir
309
310#endif // TYPE_DETAIL_H_
311

source code of mlir/lib/Dialect/Quant/IR/TypeDetail.h