| 1 | //===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef TYPE_DETAIL_H_ |
| 10 | #define TYPE_DETAIL_H_ |
| 11 | |
| 12 | #include "mlir/IR/BuiltinAttributes.h" |
| 13 | #include "mlir/IR/BuiltinTypes.h" |
| 14 | #include "mlir/IR/TypeSupport.h" |
| 15 | #include "mlir/IR/Types.h" |
| 16 | #include "llvm/ADT/DenseMap.h" |
| 17 | #include "llvm/ADT/Hashing.h" |
| 18 | #include "llvm/ADT/bit.h" |
| 19 | |
| 20 | namespace mlir { |
| 21 | namespace quant { |
| 22 | namespace detail { |
| 23 | |
| 24 | struct QuantizedTypeStorage : public mlir::TypeStorage { |
| 25 | QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType, |
| 26 | int64_t storageTypeMin, int64_t storageTypeMax) |
| 27 | : flags(flags), storageType(storageType), expressedType(expressedType), |
| 28 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
| 29 | |
| 30 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
| 31 | unsigned flags; |
| 32 | |
| 33 | // Integral type for the storage point representation. |
| 34 | Type storageType; |
| 35 | |
| 36 | // Floating point type that the quantized type approximates. |
| 37 | Type expressedType; |
| 38 | |
| 39 | // The minimum value storageType can take. |
| 40 | int64_t storageTypeMin; |
| 41 | |
| 42 | // The maximum value storageType can take. |
| 43 | int64_t storageTypeMax; |
| 44 | }; |
| 45 | |
| 46 | struct AnyQuantizedTypeStorage : public QuantizedTypeStorage { |
| 47 | struct KeyTy { |
| 48 | KeyTy(unsigned flags, Type storageType, Type expressedType, |
| 49 | int64_t storageTypeMin, int64_t storageTypeMax) |
| 50 | : flags(flags), storageType(storageType), expressedType(expressedType), |
| 51 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
| 52 | unsigned flags; |
| 53 | Type storageType; |
| 54 | Type expressedType; |
| 55 | int64_t storageTypeMin; |
| 56 | int64_t storageTypeMax; |
| 57 | |
| 58 | // Check for equality of two structures that share KeyTy data members |
| 59 | // (by name). |
| 60 | template <typename T, typename U> |
| 61 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
| 62 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
| 63 | lhs.expressedType == rhs.expressedType && |
| 64 | lhs.storageTypeMin == rhs.storageTypeMin && |
| 65 | lhs.storageTypeMax == rhs.storageTypeMax; |
| 66 | } |
| 67 | |
| 68 | bool operator==(const KeyTy &other) const { |
| 69 | return genericIsEqual(lhs: *this, rhs: other); |
| 70 | } |
| 71 | |
| 72 | unsigned getHashValue() const { |
| 73 | return llvm::hash_combine(args: flags, args: storageType, args: expressedType, |
| 74 | args: storageTypeMin, args: storageTypeMax); |
| 75 | } |
| 76 | }; |
| 77 | |
| 78 | AnyQuantizedTypeStorage(const KeyTy &key) |
| 79 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
| 80 | key.storageTypeMin, key.storageTypeMax) {} |
| 81 | |
| 82 | bool operator==(const KeyTy &key) const { |
| 83 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
| 84 | } |
| 85 | |
| 86 | /// Construction. |
| 87 | static AnyQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, |
| 88 | const KeyTy &key) { |
| 89 | return new (allocator.allocate<AnyQuantizedTypeStorage>()) |
| 90 | AnyQuantizedTypeStorage(key); |
| 91 | } |
| 92 | |
| 93 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
| 94 | }; |
| 95 | |
| 96 | struct UniformQuantizedTypeStorage : public QuantizedTypeStorage { |
| 97 | struct KeyTy { |
| 98 | KeyTy(unsigned flags, Type storageType, Type expressedType, double scale, |
| 99 | int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) |
| 100 | : flags(flags), storageType(storageType), expressedType(expressedType), |
| 101 | scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin), |
| 102 | storageTypeMax(storageTypeMax) {} |
| 103 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
| 104 | unsigned flags; |
| 105 | |
| 106 | // Integral type for the storage point representation. |
| 107 | Type storageType; |
| 108 | |
| 109 | // Floating point type that the quantized type approximates. |
| 110 | Type expressedType; |
| 111 | |
| 112 | double scale; |
| 113 | int64_t zeroPoint; |
| 114 | int64_t storageTypeMin; |
| 115 | int64_t storageTypeMax; |
| 116 | |
| 117 | // Check for equality of two structures that share KeyTy data members |
| 118 | // (by name). |
| 119 | template <typename T, typename U> |
| 120 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
| 121 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
| 122 | lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale && |
| 123 | lhs.zeroPoint == rhs.zeroPoint && |
| 124 | lhs.storageTypeMin == rhs.storageTypeMin && |
| 125 | lhs.storageTypeMax == rhs.storageTypeMax; |
| 126 | } |
| 127 | |
| 128 | bool operator==(const KeyTy &other) const { |
| 129 | return genericIsEqual(lhs: *this, rhs: other); |
| 130 | } |
| 131 | |
| 132 | unsigned getHashValue() const { |
| 133 | int64_t scaleBits = llvm::bit_cast<int64_t>(from: scale); |
| 134 | return llvm::hash_combine(args: flags, args: storageType, args: expressedType, args: scaleBits, |
| 135 | args: zeroPoint, args: storageTypeMin, args: storageTypeMax); |
| 136 | } |
| 137 | }; |
| 138 | |
| 139 | UniformQuantizedTypeStorage(const KeyTy &key) |
| 140 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
| 141 | key.storageTypeMin, key.storageTypeMax), |
| 142 | scale(key.scale), zeroPoint(key.zeroPoint) {} |
| 143 | |
| 144 | bool operator==(const KeyTy &key) const { |
| 145 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
| 146 | } |
| 147 | |
| 148 | /// Construction. |
| 149 | static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator, |
| 150 | const KeyTy &key) { |
| 151 | return new (allocator.allocate<UniformQuantizedTypeStorage>()) |
| 152 | UniformQuantizedTypeStorage(key); |
| 153 | } |
| 154 | |
| 155 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
| 156 | |
| 157 | double scale; |
| 158 | int64_t zeroPoint; |
| 159 | }; |
| 160 | |
| 161 | struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage { |
| 162 | struct KeyTy { |
| 163 | KeyTy(unsigned flags, Type storageType, Type expressedType, |
| 164 | ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, |
| 165 | int32_t quantizedDimension, int64_t storageTypeMin, |
| 166 | int64_t storageTypeMax) |
| 167 | : flags(flags), storageType(storageType), expressedType(expressedType), |
| 168 | scales(scales), zeroPoints(zeroPoints), |
| 169 | quantizedDimension(quantizedDimension), |
| 170 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
| 171 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
| 172 | unsigned flags; |
| 173 | |
| 174 | // Integral type for the storage point representation. |
| 175 | Type storageType; |
| 176 | |
| 177 | // Floating point type that the quantized type approximates. |
| 178 | Type expressedType; |
| 179 | |
| 180 | ArrayRef<double> scales; |
| 181 | ArrayRef<int64_t> zeroPoints; |
| 182 | int32_t quantizedDimension; |
| 183 | int64_t storageTypeMin; |
| 184 | int64_t storageTypeMax; |
| 185 | |
| 186 | ArrayRef<double> getScales() const { return scales; } |
| 187 | |
| 188 | ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; } |
| 189 | |
| 190 | // Check for equality of two structures that share KeyTy data members |
| 191 | // (by name). |
| 192 | template <typename T, typename U> |
| 193 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
| 194 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
| 195 | lhs.expressedType == rhs.expressedType && |
| 196 | lhs.getScales() == rhs.getScales() && |
| 197 | lhs.getZeroPoints() == rhs.getZeroPoints() && |
| 198 | lhs.quantizedDimension == rhs.quantizedDimension && |
| 199 | lhs.storageTypeMin == rhs.storageTypeMin && |
| 200 | lhs.storageTypeMax == rhs.storageTypeMax; |
| 201 | } |
| 202 | |
| 203 | bool operator==(const KeyTy &other) const { |
| 204 | return genericIsEqual(lhs: *this, rhs: other); |
| 205 | } |
| 206 | |
| 207 | unsigned getHashValue() const { |
| 208 | int64_t *scalesCast = llvm::bit_cast<int64_t *>(from: scales.data()); |
| 209 | ArrayRef<int64_t> scalesBits(scalesCast, scales.size()); |
| 210 | return llvm::hash_combine(args: flags, args: storageType, args: expressedType, |
| 211 | args: llvm::hash_combine_range(R&: scalesBits), |
| 212 | args: llvm::hash_combine_range(R: zeroPoints), |
| 213 | args: storageTypeMin, args: storageTypeMax); |
| 214 | } |
| 215 | }; |
| 216 | |
| 217 | // We pass scales and zeroPoints in directly rather than relying on KeyTy |
| 218 | // because we have to create new reallocated versions in `construct` below. |
| 219 | UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales, |
| 220 | ArrayRef<int64_t> zeroPoints) |
| 221 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
| 222 | key.storageTypeMin, key.storageTypeMax), |
| 223 | scaleElements(scales.data()), zeroPointElements(zeroPoints.data()), |
| 224 | quantParamsSize(scales.size()), |
| 225 | quantizedDimension(key.quantizedDimension) {} |
| 226 | |
| 227 | bool operator==(const KeyTy &key) const { |
| 228 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
| 229 | } |
| 230 | |
| 231 | /// Construction. |
| 232 | static UniformQuantizedPerAxisTypeStorage * |
| 233 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
| 234 | ArrayRef<double> scales = allocator.copyInto(elements: key.scales); |
| 235 | ArrayRef<int64_t> zeroPoints = allocator.copyInto(elements: key.zeroPoints); |
| 236 | return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>()) |
| 237 | UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints); |
| 238 | } |
| 239 | |
| 240 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
| 241 | |
| 242 | ArrayRef<double> getScales() const { |
| 243 | return ArrayRef<double>(scaleElements, quantParamsSize); |
| 244 | } |
| 245 | |
| 246 | ArrayRef<int64_t> getZeroPoints() const { |
| 247 | return ArrayRef<int64_t>(zeroPointElements, quantParamsSize); |
| 248 | } |
| 249 | |
| 250 | const double *scaleElements; |
| 251 | const int64_t *zeroPointElements; |
| 252 | unsigned quantParamsSize; |
| 253 | int32_t quantizedDimension; |
| 254 | }; |
| 255 | |
| 256 | struct UniformQuantizedSubChannelTypeStorage : public QuantizedTypeStorage { |
| 257 | struct KeyTy { |
| 258 | KeyTy(unsigned flags, Type storageType, Type expressedType, |
| 259 | DenseElementsAttr scales, DenseElementsAttr zeroPoints, |
| 260 | ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes, |
| 261 | int64_t storageTypeMin, int64_t storageTypeMax) |
| 262 | : flags(flags), storageType(storageType), expressedType(expressedType), |
| 263 | scales(scales), zeroPoints(zeroPoints), |
| 264 | quantizedDimensions(quantizedDimensions), blockSizes(blockSizes), |
| 265 | storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} |
| 266 | /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue. |
| 267 | unsigned flags; |
| 268 | |
| 269 | // Integral type for the storage point representation. |
| 270 | Type storageType; |
| 271 | |
| 272 | // Floating point type that the quantized type approximates. |
| 273 | Type expressedType; |
| 274 | |
| 275 | DenseElementsAttr scales; |
| 276 | DenseElementsAttr zeroPoints; |
| 277 | ArrayRef<int32_t> quantizedDimensions; |
| 278 | ArrayRef<int64_t> blockSizes; |
| 279 | int64_t storageTypeMin; |
| 280 | int64_t storageTypeMax; |
| 281 | |
| 282 | DenseElementsAttr getScales() const { return scales; } |
| 283 | |
| 284 | DenseElementsAttr getZeroPoints() const { return zeroPoints; } |
| 285 | |
| 286 | // Check for equality of two structures that share KeyTy data members |
| 287 | // (by name). |
| 288 | template <typename T, typename U> |
| 289 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
| 290 | return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType && |
| 291 | lhs.expressedType == rhs.expressedType && |
| 292 | lhs.scales == rhs.scales && lhs.zeroPoints == rhs.zeroPoints && |
| 293 | lhs.quantizedDimensions == rhs.quantizedDimensions && |
| 294 | lhs.blockSizes == rhs.blockSizes && |
| 295 | lhs.storageTypeMin == rhs.storageTypeMin && |
| 296 | lhs.storageTypeMax == rhs.storageTypeMax; |
| 297 | } |
| 298 | |
| 299 | bool operator==(const KeyTy &other) const { |
| 300 | return genericIsEqual(lhs: *this, rhs: other); |
| 301 | } |
| 302 | |
| 303 | unsigned getHashValue() const { |
| 304 | // Hash the scalar attributes. |
| 305 | unsigned hash = llvm::hash_combine(args: flags, args: storageType, args: expressedType, |
| 306 | args: storageTypeMin, args: storageTypeMax); |
| 307 | |
| 308 | // Hash the scales. |
| 309 | for (auto scaleAttr : scales.getValues<APFloat>()) { |
| 310 | hash = llvm::hash_combine( |
| 311 | hash, llvm::bit_cast<int64_t>(scaleAttr.convertToDouble())); |
| 312 | } |
| 313 | |
| 314 | // Hash the zero points. (Assumed to be integers, adjust if needed). |
| 315 | for (auto zeroPointAttr : zeroPoints.getValues<APInt>()) { |
| 316 | hash = llvm::hash_combine(hash, zeroPointAttr.getSExtValue()); |
| 317 | } |
| 318 | |
| 319 | // Hash the quantized dimensions and block sizes. |
| 320 | hash = llvm::hash_combine(args: hash, |
| 321 | args: llvm::hash_combine_range(R: quantizedDimensions), |
| 322 | args: llvm::hash_combine_range(R: blockSizes)); |
| 323 | |
| 324 | return hash; |
| 325 | } |
| 326 | }; |
| 327 | |
| 328 | // We pass scales and zeroPoints in directly rather than relying on KeyTy |
| 329 | // because we have to create new reallocated versions in `construct` below. |
| 330 | UniformQuantizedSubChannelTypeStorage(const KeyTy &key, |
| 331 | DenseElementsAttr scales, |
| 332 | DenseElementsAttr zeroPoints, |
| 333 | ArrayRef<int32_t> quantizedDimensions, |
| 334 | ArrayRef<int64_t> blockSizes) |
| 335 | : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, |
| 336 | key.storageTypeMin, key.storageTypeMax), |
| 337 | scales(scales), zeroPoints(zeroPoints), |
| 338 | quantizedDimensions(quantizedDimensions), blockSizes(blockSizes) {} |
| 339 | |
| 340 | bool operator==(const KeyTy &key) const { |
| 341 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
| 342 | } |
| 343 | |
| 344 | /// Construction. |
| 345 | static UniformQuantizedSubChannelTypeStorage * |
| 346 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
| 347 | DenseElementsAttr scales = key.scales; |
| 348 | DenseElementsAttr zeroPoints = key.zeroPoints; |
| 349 | ArrayRef<int32_t> quantizedDimensions = |
| 350 | allocator.copyInto(elements: key.quantizedDimensions); |
| 351 | ArrayRef<int64_t> blockSizes = allocator.copyInto(elements: key.blockSizes); |
| 352 | return new (allocator.allocate<UniformQuantizedSubChannelTypeStorage>()) |
| 353 | UniformQuantizedSubChannelTypeStorage(key, scales, zeroPoints, |
| 354 | quantizedDimensions, blockSizes); |
| 355 | } |
| 356 | |
| 357 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
| 358 | |
| 359 | DenseElementsAttr getScales() const { return scales; } |
| 360 | |
| 361 | DenseElementsAttr getZeroPoints() const { return zeroPoints; } |
| 362 | |
| 363 | ArrayRef<int32_t> getQuantizedDimensions() const { |
| 364 | return quantizedDimensions; |
| 365 | } |
| 366 | |
| 367 | ArrayRef<int64_t> getBlockSizes() const { return blockSizes; } |
| 368 | |
| 369 | DenseElementsAttr scales; |
| 370 | DenseElementsAttr zeroPoints; |
| 371 | ArrayRef<int32_t> quantizedDimensions; |
| 372 | ArrayRef<int64_t> blockSizes; |
| 373 | }; |
| 374 | |
| 375 | struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage { |
| 376 | struct KeyTy { |
| 377 | KeyTy(Type expressedType, double min, double max) |
| 378 | : expressedType(expressedType), min(min), max(max) {} |
| 379 | // Floating point type that the quantized type approximates. |
| 380 | Type expressedType; |
| 381 | |
| 382 | double min; |
| 383 | double max; |
| 384 | |
| 385 | // Check for equality of two structures that share KeyTy data members |
| 386 | // (by name). |
| 387 | template <typename T, typename U> |
| 388 | static bool genericIsEqual(const T &lhs, const U &rhs) { |
| 389 | return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min && |
| 390 | lhs.max == rhs.max; |
| 391 | } |
| 392 | |
| 393 | bool operator==(const KeyTy &other) const { |
| 394 | return genericIsEqual(lhs: *this, rhs: other); |
| 395 | } |
| 396 | |
| 397 | unsigned getHashValue() const { |
| 398 | int64_t minBits = llvm::bit_cast<double>(from: min); |
| 399 | int64_t maxBits = llvm::bit_cast<double>(from: max); |
| 400 | return llvm::hash_combine(args: expressedType, args: minBits, args: maxBits); |
| 401 | } |
| 402 | }; |
| 403 | |
| 404 | CalibratedQuantizedTypeStorage(const KeyTy &key) |
| 405 | : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0), |
| 406 | min(key.min), max(key.max) {} |
| 407 | |
| 408 | bool operator==(const KeyTy &key) const { |
| 409 | return KeyTy::genericIsEqual(lhs: *this, rhs: key); |
| 410 | } |
| 411 | |
| 412 | /// Construction. |
| 413 | static CalibratedQuantizedTypeStorage * |
| 414 | construct(TypeStorageAllocator &allocator, const KeyTy &key) { |
| 415 | return new (allocator.allocate<CalibratedQuantizedTypeStorage>()) |
| 416 | CalibratedQuantizedTypeStorage(key); |
| 417 | } |
| 418 | |
| 419 | static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } |
| 420 | |
| 421 | double min; |
| 422 | double max; |
| 423 | }; |
| 424 | |
| 425 | } // namespace detail |
| 426 | } // namespace quant |
| 427 | } // namespace mlir |
| 428 | |
| 429 | #endif // TYPE_DETAIL_H_ |
| 430 | |