//===- QuantTypes.cpp - Quantization Types Implementation -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Quant/IR/QuantTypes.h"
10#include "TypeDetail.h"
11#include "mlir/Dialect/Quant/IR/Quant.h"
12
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/MLIRContext.h"
15
16using namespace mlir;
17using namespace mlir::quant;
18using namespace mlir::quant::detail;
19
20namespace {
21
22// Return the minimum scale representable in a given float type
23double getMinScale(Type expressedType) {
24 auto floatType = cast<FloatType>(Val&: expressedType);
25 return APFloat::getSmallest(Sem: floatType.getFloatSemantics()).convertToDouble();
26}
27
28// Return the maximum scale representable in a given float type
29double getMaxScale(Type expressedType) {
30 auto floatType = cast<FloatType>(Val&: expressedType);
31 return APFloat::getLargest(Sem: floatType.getFloatSemantics()).convertToDouble();
32}
33
34} // namespace
35
36unsigned QuantizedType::getFlags() const {
37 return static_cast<ImplType *>(impl)->flags;
38}
39
40bool QuantizedType::classof(Type type) {
41 return llvm::isa<QuantDialect>(Val: type.getDialect());
42}
43
44LogicalResult
45QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
46 unsigned flags, Type storageType,
47 Type expressedType, int64_t storageTypeMin,
48 int64_t storageTypeMax) {
49 // Verify that the storage type is integral.
50 // This restriction may be lifted at some point in favor of using bf16
51 // or f16 as exact representations on hardware where that is advantageous.
52 auto intStorageType = llvm::dyn_cast<IntegerType>(Val&: storageType);
53 if (!intStorageType)
54 return emitError() << "storage type must be integral";
55 unsigned integralWidth = intStorageType.getWidth();
56
57 // Verify storage width.
58 if (integralWidth == 0 || integralWidth > MaxStorageBits)
59 return emitError() << "illegal storage type size: " << integralWidth;
60
61 // Verify storageTypeMin and storageTypeMax.
62 bool isSigned =
63 (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
64 int64_t defaultIntegerMin =
65 getDefaultMinimumForInteger(isSigned, integralWidth);
66 int64_t defaultIntegerMax =
67 getDefaultMaximumForInteger(isSigned, integralWidth);
68 if (storageTypeMax - storageTypeMin <= 0 ||
69 storageTypeMin < defaultIntegerMin ||
70 storageTypeMax > defaultIntegerMax) {
71 return emitError() << "illegal storage min and storage max: ("
72 << storageTypeMin << ":" << storageTypeMax << ")";
73 }
74 return success();
75}
76
77Type QuantizedType::getStorageType() const {
78 return static_cast<ImplType *>(impl)->storageType;
79}
80
81int64_t QuantizedType::getStorageTypeMin() const {
82 return static_cast<ImplType *>(impl)->storageTypeMin;
83}
84
85int64_t QuantizedType::getStorageTypeMax() const {
86 return static_cast<ImplType *>(impl)->storageTypeMax;
87}
88
89bool QuantizedType::hasStorageTypeBounds() const {
90 unsigned int integralWidth = getStorageTypeIntegralWidth();
91 bool isSignedInteger = isSigned();
92 int64_t defaultIntegerMin =
93 getDefaultMinimumForInteger(isSigned: isSignedInteger, integralWidth);
94 int64_t defaultIntegerMax =
95 getDefaultMaximumForInteger(isSigned: isSignedInteger, integralWidth);
96 return defaultIntegerMin != getStorageTypeMin() ||
97 defaultIntegerMax != getStorageTypeMax();
98}
99
100unsigned QuantizedType::getStorageTypeIntegralWidth() const {
101 // NOTE: If ever supporting non-integral storage types, some other scheme
102 // for determining the width will be needed.
103 return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
104}
105
106Type QuantizedType::getExpressedType() const {
107 return static_cast<ImplType *>(impl)->expressedType;
108}
109
110bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
111 if (llvm::isa<ShapedType>(Val: candidateExpressedType)) {
112 return llvm::cast<ShapedType>(Val&: candidateExpressedType).getElementType() ==
113 getExpressedType();
114 }
115 return candidateExpressedType == getExpressedType();
116}
117
118QuantizedType
119QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
120 if (llvm::isa<ShapedType>(Val: primitiveOrContainerType)) {
121 Type elementType =
122 llvm::cast<ShapedType>(Val&: primitiveOrContainerType).getElementType();
123 return llvm::dyn_cast<QuantizedType>(Val&: elementType);
124 }
125 return llvm::dyn_cast<QuantizedType>(Val&: primitiveOrContainerType);
126}
127
128Type QuantizedType::castFromStorageType(Type candidateType) {
129 if (candidateType == getStorageType()) {
130 // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
131 return *this;
132 }
133 if (llvm::isa<RankedTensorType>(Val: candidateType)) {
134 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
135 return RankedTensorType::get(
136 shape: llvm::cast<RankedTensorType>(Val&: candidateType).getShape(),
137 elementType: getStorageType());
138 }
139 if (llvm::isa<UnrankedTensorType>(Val: candidateType)) {
140 // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
141 return UnrankedTensorType::get(elementType: getStorageType());
142 }
143 if (llvm::isa<VectorType>(Val: candidateType)) {
144 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
145 return VectorType::get(shape: llvm::cast<VectorType>(Val&: candidateType).getShape(),
146 elementType: getStorageType());
147 }
148
149 return nullptr;
150}
151
152Type QuantizedType::castToStorageType(Type quantizedType) {
153 if (llvm::isa<QuantizedType>(Val: quantizedType)) {
154 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
155 return llvm::cast<QuantizedType>(Val&: quantizedType).getStorageType();
156 }
157 if (llvm::isa<ShapedType>(Val: quantizedType)) {
158 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
159 ShapedType sType = llvm::cast<ShapedType>(Val&: quantizedType);
160 if (!llvm::isa<QuantizedType>(Val: sType.getElementType())) {
161 return nullptr;
162 }
163 Type storageType =
164 llvm::cast<QuantizedType>(Val: sType.getElementType()).getStorageType();
165 if (llvm::isa<RankedTensorType>(Val: quantizedType)) {
166 return RankedTensorType::get(shape: sType.getShape(), elementType: storageType);
167 }
168 if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) {
169 return UnrankedTensorType::get(elementType: storageType);
170 }
171 if (llvm::isa<VectorType>(Val: quantizedType)) {
172 return VectorType::get(shape: sType.getShape(), elementType: storageType);
173 }
174 }
175
176 return nullptr;
177}
178
179Type QuantizedType::castFromExpressedType(Type candidateType) {
180 if (candidateType == getExpressedType()) {
181 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
182 return *this;
183 }
184 if (llvm::isa<ShapedType>(Val: candidateType)) {
185 ShapedType candidateShapedType = llvm::cast<ShapedType>(Val&: candidateType);
186 if (candidateShapedType.getElementType() != getExpressedType()) {
187 return nullptr;
188 }
189
190 if (llvm::isa<RankedTensorType>(Val: candidateType)) {
191 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
192 return RankedTensorType::get(shape: candidateShapedType.getShape(), elementType: *this);
193 }
194 if (llvm::isa<UnrankedTensorType>(Val: candidateType)) {
195 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
196 return UnrankedTensorType::get(elementType: *this);
197 }
198 if (llvm::isa<VectorType>(Val: candidateType)) {
199 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
200 return VectorType::get(shape: candidateShapedType.getShape(), elementType: *this);
201 }
202 }
203
204 return nullptr;
205}
206
207Type QuantizedType::castToExpressedType(Type quantizedType) {
208 if (llvm::isa<QuantizedType>(Val: quantizedType)) {
209 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
210 return llvm::cast<QuantizedType>(Val&: quantizedType).getExpressedType();
211 }
212 if (llvm::isa<ShapedType>(Val: quantizedType)) {
213 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
214 ShapedType sType = llvm::cast<ShapedType>(Val&: quantizedType);
215 if (!llvm::isa<QuantizedType>(Val: sType.getElementType())) {
216 return nullptr;
217 }
218 Type expressedType =
219 llvm::cast<QuantizedType>(Val: sType.getElementType()).getExpressedType();
220 if (llvm::isa<RankedTensorType>(Val: quantizedType)) {
221 return RankedTensorType::get(shape: sType.getShape(), elementType: expressedType);
222 }
223 if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) {
224 return UnrankedTensorType::get(elementType: expressedType);
225 }
226 if (llvm::isa<VectorType>(Val: quantizedType)) {
227 return VectorType::get(shape: sType.getShape(), elementType: expressedType);
228 }
229 }
230
231 return nullptr;
232}
233
234Type QuantizedType::castExpressedToStorageType(Type candidateType) {
235 Type expressedQuantizedType = castFromExpressedType(candidateType);
236 if (!expressedQuantizedType) {
237 return nullptr;
238 }
239 return QuantizedType::castToStorageType(quantizedType: expressedQuantizedType);
240}
241
242AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
243 Type expressedType,
244 int64_t storageTypeMin,
245 int64_t storageTypeMax) {
246 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
247 args&: storageTypeMin, args&: storageTypeMax);
248}
249
250AnyQuantizedType
251AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
252 unsigned flags, Type storageType,
253 Type expressedType, int64_t storageTypeMin,
254 int64_t storageTypeMax) {
255 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
256 args: storageType, args: expressedType, args: storageTypeMin,
257 args: storageTypeMax);
258}
259
260LogicalResult
261AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
262 unsigned flags, Type storageType,
263 Type expressedType, int64_t storageTypeMin,
264 int64_t storageTypeMax) {
265 if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType,
266 expressedType, storageTypeMin,
267 storageTypeMax))) {
268 return failure();
269 }
270
271 // Verify that the expressed type is floating point.
272 // If this restriction is ever eliminated, the parser/printer must be
273 // extended.
274 if (expressedType && !llvm::isa<FloatType>(Val: expressedType))
275 return emitError() << "expressed type must be floating point";
276
277 return success();
278}
279
280UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
281 Type expressedType, double scale,
282 int64_t zeroPoint,
283 int64_t storageTypeMin,
284 int64_t storageTypeMax) {
285 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
286 args&: scale, args&: zeroPoint, args&: storageTypeMin, args&: storageTypeMax);
287}
288
289UniformQuantizedType UniformQuantizedType::getChecked(
290 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
291 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
292 int64_t storageTypeMin, int64_t storageTypeMax) {
293 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
294 args: storageType, args: expressedType, args: scale, args: zeroPoint,
295 args: storageTypeMin, args: storageTypeMax);
296}
297
298LogicalResult UniformQuantizedType::verifyInvariants(
299 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
300 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
301 int64_t storageTypeMin, int64_t storageTypeMax) {
302 if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType,
303 expressedType, storageTypeMin,
304 storageTypeMax))) {
305 return failure();
306 }
307
308 // Uniform quantization requires fully expressed parameters, including
309 // expressed type.
310 if (!expressedType)
311 return emitError() << "uniform quantization requires expressed type";
312
313 // Verify that the expressed type is floating point.
314 // If this restriction is ever eliminated, the parser/printer must be
315 // extended.
316 if (!llvm::isa<FloatType>(Val: expressedType))
317 return emitError() << "expressed type must be floating point";
318
319 // Verify scale.
320 double minScale = getMinScale(expressedType);
321 double maxScale = getMaxScale(expressedType);
322 if (scale < minScale || scale > maxScale)
323 return emitError() << "scale out of expressed type range [" << minScale
324 << ", " << maxScale << "]";
325
326 return success();
327}
328
329double UniformQuantizedType::getScale() const { return getImpl()->scale; }
330
331int64_t UniformQuantizedType::getZeroPoint() const {
332 return getImpl()->zeroPoint;
333}
334
335UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
336 unsigned flags, Type storageType, Type expressedType,
337 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
338 int32_t quantizedDimension, int64_t storageTypeMin,
339 int64_t storageTypeMax) {
340 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
341 args&: scales, args&: zeroPoints, args&: quantizedDimension, args&: storageTypeMin,
342 args&: storageTypeMax);
343}
344
345UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
346 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
347 Type storageType, Type expressedType, ArrayRef<double> scales,
348 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
349 int64_t storageTypeMin, int64_t storageTypeMax) {
350 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
351 args: storageType, args: expressedType, args: scales, args: zeroPoints,
352 args: quantizedDimension, args: storageTypeMin, args: storageTypeMax);
353}
354
355LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
356 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
357 Type storageType, Type expressedType, ArrayRef<double> scales,
358 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
359 int64_t storageTypeMin, int64_t storageTypeMax) {
360 if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType,
361 expressedType, storageTypeMin,
362 storageTypeMax))) {
363 return failure();
364 }
365
366 // Uniform quantization requires fully expressed parameters, including
367 // expressed type.
368 if (!expressedType)
369 return emitError() << "uniform quantization requires expressed type";
370
371 // Verify that the expressed type is floating point.
372 // If this restriction is ever eliminated, the parser/printer must be
373 // extended.
374 if (!llvm::isa<FloatType>(Val: expressedType))
375 return emitError() << "expressed type must be floating point";
376
377 // Ensure that the number of scales and zeroPoints match.
378 if (scales.size() != zeroPoints.size())
379 return emitError() << "illegal number of scales and zeroPoints: "
380 << scales.size() << ", " << zeroPoints.size();
381
382 // Verify scale.
383 double minScale = getMinScale(expressedType);
384 double maxScale = getMaxScale(expressedType);
385 for (double scale : scales) {
386 if (scale < minScale || scale > maxScale)
387 return emitError() << "scale out of expressed type range [" << minScale
388 << ", " << maxScale << "]";
389 }
390
391 // Verify quantized dimension.
392 if (quantizedDimension < 0)
393 return emitError() << "illegal quantized dimension: " << quantizedDimension;
394
395 return success();
396}
397
398ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
399 return getImpl()->getScales();
400}
401
402ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
403 return getImpl()->getZeroPoints();
404}
405
406int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
407 return getImpl()->quantizedDimension;
408}
409
410UniformQuantizedSubChannelType UniformQuantizedSubChannelType::get(
411 unsigned flags, Type storageType, Type expressedType,
412 DenseElementsAttr scales, DenseElementsAttr zeroPoints,
413 ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
414 int64_t storageTypeMin, int64_t storageTypeMax) {
415 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
416 args&: scales, args&: zeroPoints, args&: quantizedDimensions, args&: blockSizes,
417 args&: storageTypeMin, args&: storageTypeMax);
418}
419
420UniformQuantizedSubChannelType UniformQuantizedSubChannelType::getChecked(
421 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
422 Type storageType, Type expressedType, DenseElementsAttr scales,
423 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
424 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
425 int64_t storageTypeMax) {
426 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
427 args: storageType, args: expressedType, args: scales, args: zeroPoints,
428 args: quantizedDimensions, args: blockSizes, args: storageTypeMin,
429 args: storageTypeMax);
430}
431
432LogicalResult UniformQuantizedSubChannelType::verifyInvariants(
433 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
434 Type storageType, Type expressedType, DenseElementsAttr scales,
435 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
436 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
437 int64_t storageTypeMax) {
438 if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType,
439 expressedType, storageTypeMin,
440 storageTypeMax))) {
441 return failure();
442 }
443
444 // Uniform quantization requires fully expressed parameters, including
445 // expressed type.
446 if (!expressedType)
447 return emitError() << "uniform quantization requires expressed type";
448
449 // Verify that the expressed type is floating point.
450 // If this restriction is ever eliminated, the parser/printer must be
451 // extended.
452 if (!llvm::isa<FloatType>(Val: expressedType))
453 return emitError() << "expressed type must be floating point";
454
455 // Verify scale type to match expressedType.
456 if (scales.getType().getElementType() != expressedType) {
457 return emitError() << "type of scale values "
458 << scales.getType().getElementType()
459 << " must match the expressed type " << expressedType;
460 }
461
462 // Verify zero-point type to match storageType.
463 if (zeroPoints.getType().getElementType() != storageType) {
464 return emitError() << "type of zero point values "
465 << zeroPoints.getType().getElementType()
466 << " must match the storage type " << storageType;
467 }
468
469 // Ensure that the shape of scales and zeroPoints match.
470 if (scales.getType().getShape() != zeroPoints.getType().getShape())
471 return emitError() << "shape of scales and zeroPoints ("
472 << scales.getType().getShape() << " vs "
473 << zeroPoints.getType().getShape() << ") does not match";
474
475 // Ensure that the number of quantized-dimensions and block-sizes match.
476 if (quantizedDimensions.size() != blockSizes.size())
477 return emitError() << "number of quantized dimensions and block sizes ("
478 << scales.size() << " vs " << zeroPoints.size()
479 << ") does not match";
480
481 // Verify quantized dimension.
482 for (auto quantizedDimension : quantizedDimensions) {
483 if (quantizedDimension < 0)
484 return emitError() << "illegal quantized dimension: "
485 << quantizedDimension;
486 }
487
488 // Verify block sizes.
489 for (auto blockSize : blockSizes) {
490 if (blockSize <= 0)
491 return emitError() << "illegal block size: " << blockSize;
492 }
493
494 return success();
495}
496
497DenseElementsAttr UniformQuantizedSubChannelType::getScales() const {
498 return getImpl()->getScales();
499}
500
501DenseElementsAttr UniformQuantizedSubChannelType::getZeroPoints() const {
502 return getImpl()->getZeroPoints();
503}
504
505ArrayRef<int32_t>
506UniformQuantizedSubChannelType::getQuantizedDimensions() const {
507 return getImpl()->getQuantizedDimensions();
508}
509
510ArrayRef<int64_t> UniformQuantizedSubChannelType::getBlockSizes() const {
511 return getImpl()->getBlockSizes();
512}
513
514const SmallVector<std::pair<int32_t, int64_t>>
515UniformQuantizedSubChannelType::getBlockSizeInfo() const {
516 SmallVector<std::pair<int32_t, int64_t>> result;
517 result.reserve(N: getQuantizedDimensions().size());
518
519 for (auto [dim, size] :
520 llvm::zip(t: getQuantizedDimensions(), u: getBlockSizes())) {
521 result.push_back(Elt: {dim, size});
522 }
523
524 return result;
525}
526
527CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
528 double min, double max) {
529 return Base::get(ctx: expressedType.getContext(), args&: expressedType, args&: min, args&: max);
530}
531
532CalibratedQuantizedType CalibratedQuantizedType::getChecked(
533 function_ref<InFlightDiagnostic()> emitError, Type expressedType,
534 double min, double max) {
535 return Base::getChecked(emitErrorFn: emitError, ctx: expressedType.getContext(), args: expressedType,
536 args: min, args: max);
537}
538
539LogicalResult CalibratedQuantizedType::verifyInvariants(
540 function_ref<InFlightDiagnostic()> emitError, Type expressedType,
541 double min, double max) {
542 // Verify that the expressed type is floating point.
543 // If this restriction is ever eliminated, the parser/printer must be
544 // extended.
545 if (!llvm::isa<FloatType>(Val: expressedType))
546 return emitError() << "expressed type must be floating point";
547 if (max <= min)
548 return emitError() << "illegal min and max: (" << min << ":" << max << ")";
549
550 return success();
551}
552
553double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
554
555double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
556

// Source: mlir/lib/Dialect/Quant/IR/QuantTypes.cpp