//===- QuantTypes.cpp - Quantization Type Implementation --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Quant/IR/QuantTypes.h"
10#include "TypeDetail.h"
11#include "mlir/Dialect/Quant/IR/Quant.h"
12
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/MLIRContext.h"
15#include "llvm/ADT/StringRef.h"
16#include "llvm/ADT/Twine.h"
17#include "llvm/Support/MathExtras.h"
18
19using namespace mlir;
20using namespace mlir::quant;
21using namespace mlir::quant::detail;
22
23namespace {
24
25// Return the minimum scale representable in a given float type
26double getMinScale(Type expressedType) {
27 auto floatType = cast<FloatType>(expressedType);
28 return APFloat::getSmallest(Sem: floatType.getFloatSemantics()).convertToDouble();
29}
30
31// Return the maximum scale representable in a given float type
32double getMaxScale(Type expressedType) {
33 auto floatType = cast<FloatType>(expressedType);
34 return APFloat::getLargest(Sem: floatType.getFloatSemantics()).convertToDouble();
35}
36
37} // namespace
38
39unsigned QuantizedType::getFlags() const {
40 return static_cast<ImplType *>(impl)->flags;
41}
42
43bool QuantizedType::classof(Type type) {
44 return llvm::isa<QuantDialect>(type.getDialect());
45}
46
47LogicalResult
48QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
49 unsigned flags, Type storageType,
50 Type expressedType, int64_t storageTypeMin,
51 int64_t storageTypeMax) {
52 // Verify that the storage type is integral.
53 // This restriction may be lifted at some point in favor of using bf16
54 // or f16 as exact representations on hardware where that is advantageous.
55 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
56 if (!intStorageType)
57 return emitError() << "storage type must be integral";
58 unsigned integralWidth = intStorageType.getWidth();
59
60 // Verify storage width.
61 if (integralWidth == 0 || integralWidth > MaxStorageBits)
62 return emitError() << "illegal storage type size: " << integralWidth;
63
64 // Verify storageTypeMin and storageTypeMax.
65 bool isSigned =
66 (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
67 int64_t defaultIntegerMin =
68 getDefaultMinimumForInteger(isSigned, integralWidth);
69 int64_t defaultIntegerMax =
70 getDefaultMaximumForInteger(isSigned, integralWidth);
71 if (storageTypeMax - storageTypeMin <= 0 ||
72 storageTypeMin < defaultIntegerMin ||
73 storageTypeMax > defaultIntegerMax) {
74 return emitError() << "illegal storage min and storage max: ("
75 << storageTypeMin << ":" << storageTypeMax << ")";
76 }
77 return success();
78}
79
80Type QuantizedType::getStorageType() const {
81 return static_cast<ImplType *>(impl)->storageType;
82}
83
84int64_t QuantizedType::getStorageTypeMin() const {
85 return static_cast<ImplType *>(impl)->storageTypeMin;
86}
87
88int64_t QuantizedType::getStorageTypeMax() const {
89 return static_cast<ImplType *>(impl)->storageTypeMax;
90}
91
92bool QuantizedType::hasStorageTypeBounds() const {
93 unsigned int integralWidth = getStorageTypeIntegralWidth();
94 bool isSignedInteger = isSigned();
95 int64_t defaultIntegerMin =
96 getDefaultMinimumForInteger(isSigned: isSignedInteger, integralWidth);
97 int64_t defaultIntegerMax =
98 getDefaultMaximumForInteger(isSigned: isSignedInteger, integralWidth);
99 return defaultIntegerMin != getStorageTypeMin() ||
100 defaultIntegerMax != getStorageTypeMax();
101}
102
103unsigned QuantizedType::getStorageTypeIntegralWidth() const {
104 // NOTE: If ever supporting non-integral storage types, some other scheme
105 // for determining the width will be needed.
106 return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
107}
108
109Type QuantizedType::getExpressedType() const {
110 return static_cast<ImplType *>(impl)->expressedType;
111}
112
113bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
114 if (llvm::isa<ShapedType>(candidateExpressedType)) {
115 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
116 getExpressedType();
117 }
118 return candidateExpressedType == getExpressedType();
119}
120
121QuantizedType
122QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
123 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
124 Type elementType =
125 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
126 return llvm::dyn_cast<QuantizedType>(Val&: elementType);
127 }
128 return llvm::dyn_cast<QuantizedType>(Val&: primitiveOrContainerType);
129}
130
131Type QuantizedType::castFromStorageType(Type candidateType) {
132 if (candidateType == getStorageType()) {
133 // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
134 return *this;
135 }
136 if (llvm::isa<RankedTensorType>(Val: candidateType)) {
137 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
138 return RankedTensorType::get(
139 llvm::cast<RankedTensorType>(candidateType).getShape(),
140 getStorageType());
141 }
142 if (llvm::isa<UnrankedTensorType>(Val: candidateType)) {
143 // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
144 return UnrankedTensorType::get(getStorageType());
145 }
146 if (llvm::isa<VectorType>(Val: candidateType)) {
147 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
148 return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
149 getStorageType());
150 }
151
152 return nullptr;
153}
154
155Type QuantizedType::castToStorageType(Type quantizedType) {
156 if (llvm::isa<QuantizedType>(Val: quantizedType)) {
157 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
158 return llvm::cast<QuantizedType>(Val&: quantizedType).getStorageType();
159 }
160 if (llvm::isa<ShapedType>(quantizedType)) {
161 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
162 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
163 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
164 return nullptr;
165 }
166 Type storageType =
167 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
168 if (llvm::isa<RankedTensorType>(Val: quantizedType)) {
169 return RankedTensorType::get(sType.getShape(), storageType);
170 }
171 if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) {
172 return UnrankedTensorType::get(storageType);
173 }
174 if (llvm::isa<VectorType>(Val: quantizedType)) {
175 return VectorType::get(sType.getShape(), storageType);
176 }
177 }
178
179 return nullptr;
180}
181
182Type QuantizedType::castFromExpressedType(Type candidateType) {
183 if (candidateType == getExpressedType()) {
184 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
185 return *this;
186 }
187 if (llvm::isa<ShapedType>(candidateType)) {
188 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
189 if (candidateShapedType.getElementType() != getExpressedType()) {
190 return nullptr;
191 }
192
193 if (llvm::isa<RankedTensorType>(Val: candidateType)) {
194 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
195 return RankedTensorType::get(candidateShapedType.getShape(), *this);
196 }
197 if (llvm::isa<UnrankedTensorType>(Val: candidateType)) {
198 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
199 return UnrankedTensorType::get(*this);
200 }
201 if (llvm::isa<VectorType>(Val: candidateType)) {
202 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
203 return VectorType::get(candidateShapedType.getShape(), *this);
204 }
205 }
206
207 return nullptr;
208}
209
210Type QuantizedType::castToExpressedType(Type quantizedType) {
211 if (llvm::isa<QuantizedType>(Val: quantizedType)) {
212 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
213 return llvm::cast<QuantizedType>(Val&: quantizedType).getExpressedType();
214 }
215 if (llvm::isa<ShapedType>(quantizedType)) {
216 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
217 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
218 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
219 return nullptr;
220 }
221 Type expressedType =
222 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
223 if (llvm::isa<RankedTensorType>(Val: quantizedType)) {
224 return RankedTensorType::get(sType.getShape(), expressedType);
225 }
226 if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) {
227 return UnrankedTensorType::get(expressedType);
228 }
229 if (llvm::isa<VectorType>(Val: quantizedType)) {
230 return VectorType::get(sType.getShape(), expressedType);
231 }
232 }
233
234 return nullptr;
235}
236
237Type QuantizedType::castExpressedToStorageType(Type candidateType) {
238 Type expressedQuantizedType = castFromExpressedType(candidateType);
239 if (!expressedQuantizedType) {
240 return nullptr;
241 }
242 return QuantizedType::castToStorageType(quantizedType: expressedQuantizedType);
243}
244
245AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
246 Type expressedType,
247 int64_t storageTypeMin,
248 int64_t storageTypeMax) {
249 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
250 args&: storageTypeMin, args&: storageTypeMax);
251}
252
253AnyQuantizedType
254AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
255 unsigned flags, Type storageType,
256 Type expressedType, int64_t storageTypeMin,
257 int64_t storageTypeMax) {
258 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
259 args: storageType, args: expressedType, args: storageTypeMin,
260 args: storageTypeMax);
261}
262
263LogicalResult
264AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
265 unsigned flags, Type storageType,
266 Type expressedType, int64_t storageTypeMin,
267 int64_t storageTypeMax) {
268 if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType,
269 expressedType, storageTypeMin,
270 storageTypeMax))) {
271 return failure();
272 }
273
274 // Verify that the expressed type is floating point.
275 // If this restriction is ever eliminated, the parser/printer must be
276 // extended.
277 if (expressedType && !llvm::isa<FloatType>(Val: expressedType))
278 return emitError() << "expressed type must be floating point";
279
280 return success();
281}
282
283UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
284 Type expressedType, double scale,
285 int64_t zeroPoint,
286 int64_t storageTypeMin,
287 int64_t storageTypeMax) {
288 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
289 args&: scale, args&: zeroPoint, args&: storageTypeMin, args&: storageTypeMax);
290}
291
292UniformQuantizedType UniformQuantizedType::getChecked(
293 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
294 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
295 int64_t storageTypeMin, int64_t storageTypeMax) {
296 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
297 args: storageType, args: expressedType, args: scale, args: zeroPoint,
298 args: storageTypeMin, args: storageTypeMax);
299}
300
301LogicalResult UniformQuantizedType::verifyInvariants(
302 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
303 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
304 int64_t storageTypeMin, int64_t storageTypeMax) {
305 if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType,
306 expressedType, storageTypeMin,
307 storageTypeMax))) {
308 return failure();
309 }
310
311 // Uniform quantization requires fully expressed parameters, including
312 // expressed type.
313 if (!expressedType)
314 return emitError() << "uniform quantization requires expressed type";
315
316 // Verify that the expressed type is floating point.
317 // If this restriction is ever eliminated, the parser/printer must be
318 // extended.
319 if (!llvm::isa<FloatType>(Val: expressedType))
320 return emitError() << "expressed type must be floating point";
321
322 // Verify scale.
323 double minScale = getMinScale(expressedType);
324 double maxScale = getMaxScale(expressedType);
325 if (scale < minScale || scale > maxScale)
326 return emitError() << "scale out of expressed type range [" << minScale
327 << ", " << maxScale << "]";
328
329 return success();
330}
331
332double UniformQuantizedType::getScale() const { return getImpl()->scale; }
333
334int64_t UniformQuantizedType::getZeroPoint() const {
335 return getImpl()->zeroPoint;
336}
337
338UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
339 unsigned flags, Type storageType, Type expressedType,
340 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
341 int32_t quantizedDimension, int64_t storageTypeMin,
342 int64_t storageTypeMax) {
343 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
344 args&: scales, args&: zeroPoints, args&: quantizedDimension, args&: storageTypeMin,
345 args&: storageTypeMax);
346}
347
348UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
349 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
350 Type storageType, Type expressedType, ArrayRef<double> scales,
351 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
352 int64_t storageTypeMin, int64_t storageTypeMax) {
353 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
354 args: storageType, args: expressedType, args: scales, args: zeroPoints,
355 args: quantizedDimension, args: storageTypeMin, args: storageTypeMax);
356}
357
358LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
359 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
360 Type storageType, Type expressedType, ArrayRef<double> scales,
361 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
362 int64_t storageTypeMin, int64_t storageTypeMax) {
363 if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType,
364 expressedType, storageTypeMin,
365 storageTypeMax))) {
366 return failure();
367 }
368
369 // Uniform quantization requires fully expressed parameters, including
370 // expressed type.
371 if (!expressedType)
372 return emitError() << "uniform quantization requires expressed type";
373
374 // Verify that the expressed type is floating point.
375 // If this restriction is ever eliminated, the parser/printer must be
376 // extended.
377 if (!llvm::isa<FloatType>(Val: expressedType))
378 return emitError() << "expressed type must be floating point";
379
380 // Ensure that the number of scales and zeroPoints match.
381 if (scales.size() != zeroPoints.size())
382 return emitError() << "illegal number of scales and zeroPoints: "
383 << scales.size() << ", " << zeroPoints.size();
384
385 // Verify scale.
386 double minScale = getMinScale(expressedType);
387 double maxScale = getMaxScale(expressedType);
388 for (double scale : scales) {
389 if (scale < minScale || scale > maxScale)
390 return emitError() << "scale out of expressed type range [" << minScale
391 << ", " << maxScale << "]";
392 }
393
394 // Verify quantized dimension.
395 if (quantizedDimension < 0)
396 return emitError() << "illegal quantized dimension: " << quantizedDimension;
397
398 return success();
399}
400
401ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
402 return getImpl()->getScales();
403}
404
405ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
406 return getImpl()->getZeroPoints();
407}
408
409int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
410 return getImpl()->quantizedDimension;
411}
412
413UniformQuantizedSubChannelType UniformQuantizedSubChannelType::get(
414 unsigned flags, Type storageType, Type expressedType,
415 DenseElementsAttr scales, DenseElementsAttr zeroPoints,
416 ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes,
417 int64_t storageTypeMin, int64_t storageTypeMax) {
418 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
419 args&: scales, args&: zeroPoints, args&: quantizedDimensions, args&: blockSizes,
420 args&: storageTypeMin, args&: storageTypeMax);
421}
422
423UniformQuantizedSubChannelType UniformQuantizedSubChannelType::getChecked(
424 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
425 Type storageType, Type expressedType, DenseElementsAttr scales,
426 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
427 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
428 int64_t storageTypeMax) {
429 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
430 args: storageType, args: expressedType, args: scales, args: zeroPoints,
431 args: quantizedDimensions, args: blockSizes, args: storageTypeMin,
432 args: storageTypeMax);
433}
434
435LogicalResult UniformQuantizedSubChannelType::verifyInvariants(
436 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
437 Type storageType, Type expressedType, DenseElementsAttr scales,
438 DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions,
439 ArrayRef<int64_t> blockSizes, int64_t storageTypeMin,
440 int64_t storageTypeMax) {
441 if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType,
442 expressedType, storageTypeMin,
443 storageTypeMax))) {
444 return failure();
445 }
446
447 // Uniform quantization requires fully expressed parameters, including
448 // expressed type.
449 if (!expressedType)
450 return emitError() << "uniform quantization requires expressed type";
451
452 // Verify that the expressed type is floating point.
453 // If this restriction is ever eliminated, the parser/printer must be
454 // extended.
455 if (!llvm::isa<FloatType>(Val: expressedType))
456 return emitError() << "expressed type must be floating point";
457
458 // Verify scale type to match expressedType.
459 if (scales.getType().getElementType() != expressedType) {
460 return emitError() << "type of scale values "
461 << scales.getType().getElementType()
462 << " must match the expressed type " << expressedType;
463 }
464
465 // Verify zero-point type to match storageType.
466 if (zeroPoints.getType().getElementType() != storageType) {
467 return emitError() << "type of zero point values "
468 << zeroPoints.getType().getElementType()
469 << " must match the storage type " << storageType;
470 }
471
472 // Ensure that the shape of scales and zeroPoints match.
473 if (scales.getType().getShape() != zeroPoints.getType().getShape())
474 return emitError() << "shape of scales and zeroPoints ("
475 << scales.getType().getShape() << " vs "
476 << zeroPoints.getType().getShape() << ") does not match";
477
478 // Ensure that the number of quantized-dimensions and block-sizes match.
479 if (quantizedDimensions.size() != blockSizes.size())
480 return emitError() << "number of quantized dimensions and block sizes ("
481 << scales.size() << " vs " << zeroPoints.size()
482 << ") does not match";
483
484 // Verify quantized dimension.
485 for (auto quantizedDimension : quantizedDimensions) {
486 if (quantizedDimension < 0)
487 return emitError() << "illegal quantized dimension: "
488 << quantizedDimension;
489 }
490
491 // Verify block sizes.
492 for (auto blockSize : blockSizes) {
493 if (blockSize <= 0)
494 return emitError() << "illegal block size: " << blockSize;
495 }
496
497 return success();
498}
499
500DenseElementsAttr UniformQuantizedSubChannelType::getScales() const {
501 return getImpl()->getScales();
502}
503
504DenseElementsAttr UniformQuantizedSubChannelType::getZeroPoints() const {
505 return getImpl()->getZeroPoints();
506}
507
508ArrayRef<int32_t>
509UniformQuantizedSubChannelType::getQuantizedDimensions() const {
510 return getImpl()->getQuantizedDimensions();
511}
512
513ArrayRef<int64_t> UniformQuantizedSubChannelType::getBlockSizes() const {
514 return getImpl()->getBlockSizes();
515}
516
517const SmallVector<std::pair<int32_t, int64_t>>
518UniformQuantizedSubChannelType::getBlockSizeInfo() const {
519 SmallVector<std::pair<int32_t, int64_t>> result;
520 result.reserve(N: getQuantizedDimensions().size());
521
522 for (auto [dim, size] :
523 llvm::zip(t: getQuantizedDimensions(), u: getBlockSizes())) {
524 result.push_back(Elt: {dim, size});
525 }
526
527 return result;
528}
529
530CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
531 double min, double max) {
532 return Base::get(ctx: expressedType.getContext(), args&: expressedType, args&: min, args&: max);
533}
534
535CalibratedQuantizedType CalibratedQuantizedType::getChecked(
536 function_ref<InFlightDiagnostic()> emitError, Type expressedType,
537 double min, double max) {
538 return Base::getChecked(emitErrorFn: emitError, ctx: expressedType.getContext(), args: expressedType,
539 args: min, args: max);
540}
541
542LogicalResult CalibratedQuantizedType::verifyInvariants(
543 function_ref<InFlightDiagnostic()> emitError, Type expressedType,
544 double min, double max) {
545 // Verify that the expressed type is floating point.
546 // If this restriction is ever eliminated, the parser/printer must be
547 // extended.
548 if (!llvm::isa<FloatType>(Val: expressedType))
549 return emitError() << "expressed type must be floating point";
550 if (max <= min)
551 return emitError() << "illegal min and max: (" << min << ":" << max << ")";
552
553 return success();
554}
555
556double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
557
558double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
559

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Quant/IR/QuantTypes.cpp