//===- QuantTypes.cpp - Quantization Types Implementation ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Quant/QuantTypes.h"
10#include "TypeDetail.h"
11#include "mlir/Dialect/Quant/QuantOps.h"
12
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/MLIRContext.h"
15#include "llvm/ADT/StringRef.h"
16#include "llvm/ADT/Twine.h"
17#include "llvm/Support/MathExtras.h"
18
19using namespace mlir;
20using namespace mlir::quant;
21using namespace mlir::quant::detail;
22
23unsigned QuantizedType::getFlags() const {
24 return static_cast<ImplType *>(impl)->flags;
25}
26
27bool QuantizedType::classof(Type type) {
28 return llvm::isa<QuantizationDialect>(type.getDialect());
29}
30
31LogicalResult
32QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
33 unsigned flags, Type storageType, Type expressedType,
34 int64_t storageTypeMin, int64_t storageTypeMax) {
35 // Verify that the storage type is integral.
36 // This restriction may be lifted at some point in favor of using bf16
37 // or f16 as exact representations on hardware where that is advantageous.
38 auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
39 if (!intStorageType)
40 return emitError() << "storage type must be integral";
41 unsigned integralWidth = intStorageType.getWidth();
42
43 // Verify storage width.
44 if (integralWidth == 0 || integralWidth > MaxStorageBits)
45 return emitError() << "illegal storage type size: " << integralWidth;
46
47 // Verify storageTypeMin and storageTypeMax.
48 bool isSigned =
49 (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
50 int64_t defaultIntegerMin =
51 getDefaultMinimumForInteger(isSigned, integralWidth);
52 int64_t defaultIntegerMax =
53 getDefaultMaximumForInteger(isSigned, integralWidth);
54 if (storageTypeMax - storageTypeMin <= 0 ||
55 storageTypeMin < defaultIntegerMin ||
56 storageTypeMax > defaultIntegerMax) {
57 return emitError() << "illegal storage min and storage max: ("
58 << storageTypeMin << ":" << storageTypeMax << ")";
59 }
60 return success();
61}
62
63Type QuantizedType::getStorageType() const {
64 return static_cast<ImplType *>(impl)->storageType;
65}
66
67int64_t QuantizedType::getStorageTypeMin() const {
68 return static_cast<ImplType *>(impl)->storageTypeMin;
69}
70
71int64_t QuantizedType::getStorageTypeMax() const {
72 return static_cast<ImplType *>(impl)->storageTypeMax;
73}
74
75unsigned QuantizedType::getStorageTypeIntegralWidth() const {
76 // NOTE: If ever supporting non-integral storage types, some other scheme
77 // for determining the width will be needed.
78 return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
79}
80
81Type QuantizedType::getExpressedType() const {
82 return static_cast<ImplType *>(impl)->expressedType;
83}
84
85bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
86 if (llvm::isa<ShapedType>(candidateExpressedType)) {
87 return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
88 getExpressedType();
89 }
90 return candidateExpressedType == getExpressedType();
91}
92
93QuantizedType
94QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
95 if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
96 Type elementType =
97 llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
98 return llvm::dyn_cast<QuantizedType>(Val&: elementType);
99 }
100 return llvm::dyn_cast<QuantizedType>(Val&: primitiveOrContainerType);
101}
102
103Type QuantizedType::castFromStorageType(Type candidateType) {
104 if (candidateType == getStorageType()) {
105 // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
106 return *this;
107 }
108 if (llvm::isa<RankedTensorType>(Val: candidateType)) {
109 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
110 return RankedTensorType::get(
111 llvm::cast<RankedTensorType>(candidateType).getShape(),
112 getStorageType());
113 }
114 if (llvm::isa<UnrankedTensorType>(Val: candidateType)) {
115 // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
116 return UnrankedTensorType::get(getStorageType());
117 }
118 if (llvm::isa<VectorType>(Val: candidateType)) {
119 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
120 return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
121 getStorageType());
122 }
123
124 return nullptr;
125}
126
127Type QuantizedType::castToStorageType(Type quantizedType) {
128 if (llvm::isa<QuantizedType>(Val: quantizedType)) {
129 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
130 return llvm::cast<QuantizedType>(Val&: quantizedType).getStorageType();
131 }
132 if (llvm::isa<ShapedType>(quantizedType)) {
133 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
134 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
135 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
136 return nullptr;
137 }
138 Type storageType =
139 llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
140 if (llvm::isa<RankedTensorType>(Val: quantizedType)) {
141 return RankedTensorType::get(sType.getShape(), storageType);
142 }
143 if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) {
144 return UnrankedTensorType::get(storageType);
145 }
146 if (llvm::isa<VectorType>(Val: quantizedType)) {
147 return VectorType::get(sType.getShape(), storageType);
148 }
149 }
150
151 return nullptr;
152}
153
154Type QuantizedType::castFromExpressedType(Type candidateType) {
155 if (candidateType == getExpressedType()) {
156 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
157 return *this;
158 }
159 if (llvm::isa<ShapedType>(candidateType)) {
160 ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
161 if (candidateShapedType.getElementType() != getExpressedType()) {
162 return nullptr;
163 }
164
165 if (llvm::isa<RankedTensorType>(Val: candidateType)) {
166 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
167 return RankedTensorType::get(candidateShapedType.getShape(), *this);
168 }
169 if (llvm::isa<UnrankedTensorType>(Val: candidateType)) {
170 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
171 return UnrankedTensorType::get(*this);
172 }
173 if (llvm::isa<VectorType>(Val: candidateType)) {
174 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
175 return VectorType::get(candidateShapedType.getShape(), *this);
176 }
177 }
178
179 return nullptr;
180}
181
182Type QuantizedType::castToExpressedType(Type quantizedType) {
183 if (llvm::isa<QuantizedType>(Val: quantizedType)) {
184 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
185 return llvm::cast<QuantizedType>(Val&: quantizedType).getExpressedType();
186 }
187 if (llvm::isa<ShapedType>(quantizedType)) {
188 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
189 ShapedType sType = llvm::cast<ShapedType>(quantizedType);
190 if (!llvm::isa<QuantizedType>(sType.getElementType())) {
191 return nullptr;
192 }
193 Type expressedType =
194 llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
195 if (llvm::isa<RankedTensorType>(Val: quantizedType)) {
196 return RankedTensorType::get(sType.getShape(), expressedType);
197 }
198 if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) {
199 return UnrankedTensorType::get(expressedType);
200 }
201 if (llvm::isa<VectorType>(Val: quantizedType)) {
202 return VectorType::get(sType.getShape(), expressedType);
203 }
204 }
205
206 return nullptr;
207}
208
209Type QuantizedType::castExpressedToStorageType(Type candidateType) {
210 Type expressedQuantizedType = castFromExpressedType(candidateType);
211 if (!expressedQuantizedType) {
212 return nullptr;
213 }
214 return QuantizedType::castToStorageType(quantizedType: expressedQuantizedType);
215}
216
217AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
218 Type expressedType,
219 int64_t storageTypeMin,
220 int64_t storageTypeMax) {
221 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
222 args&: storageTypeMin, args&: storageTypeMax);
223}
224
225AnyQuantizedType
226AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
227 unsigned flags, Type storageType,
228 Type expressedType, int64_t storageTypeMin,
229 int64_t storageTypeMax) {
230 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
231 args: storageType, args: expressedType, args: storageTypeMin,
232 args: storageTypeMax);
233}
234
235LogicalResult
236AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
237 unsigned flags, Type storageType, Type expressedType,
238 int64_t storageTypeMin, int64_t storageTypeMax) {
239 if (failed(result: QuantizedType::verify(emitError, flags, storageType, expressedType,
240 storageTypeMin, storageTypeMax))) {
241 return failure();
242 }
243
244 // Verify that the expressed type is floating point.
245 // If this restriction is ever eliminated, the parser/printer must be
246 // extended.
247 if (expressedType && !llvm::isa<FloatType>(Val: expressedType))
248 return emitError() << "expressed type must be floating point";
249
250 return success();
251}
252
253UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
254 Type expressedType, double scale,
255 int64_t zeroPoint,
256 int64_t storageTypeMin,
257 int64_t storageTypeMax) {
258 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
259 args&: scale, args&: zeroPoint, args&: storageTypeMin, args&: storageTypeMax);
260}
261
262UniformQuantizedType UniformQuantizedType::getChecked(
263 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
264 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
265 int64_t storageTypeMin, int64_t storageTypeMax) {
266 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
267 args: storageType, args: expressedType, args: scale, args: zeroPoint,
268 args: storageTypeMin, args: storageTypeMax);
269}
270
271LogicalResult UniformQuantizedType::verify(
272 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
273 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
274 int64_t storageTypeMin, int64_t storageTypeMax) {
275 if (failed(result: QuantizedType::verify(emitError, flags, storageType, expressedType,
276 storageTypeMin, storageTypeMax))) {
277 return failure();
278 }
279
280 // Uniform quantization requires fully expressed parameters, including
281 // expressed type.
282 if (!expressedType)
283 return emitError() << "uniform quantization requires expressed type";
284
285 // Verify that the expressed type is floating point.
286 // If this restriction is ever eliminated, the parser/printer must be
287 // extended.
288 if (!llvm::isa<FloatType>(Val: expressedType))
289 return emitError() << "expressed type must be floating point";
290
291 // Verify scale.
292 if (scale <= 0.0 || std::isinf(x: scale) || std::isnan(x: scale))
293 return emitError() << "illegal scale: " << scale;
294
295 return success();
296}
297
298double UniformQuantizedType::getScale() const { return getImpl()->scale; }
299
300int64_t UniformQuantizedType::getZeroPoint() const {
301 return getImpl()->zeroPoint;
302}
303
304UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
305 unsigned flags, Type storageType, Type expressedType,
306 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
307 int32_t quantizedDimension, int64_t storageTypeMin,
308 int64_t storageTypeMax) {
309 return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType,
310 args&: scales, args&: zeroPoints, args&: quantizedDimension, args&: storageTypeMin,
311 args&: storageTypeMax);
312}
313
314UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
315 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
316 Type storageType, Type expressedType, ArrayRef<double> scales,
317 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
318 int64_t storageTypeMin, int64_t storageTypeMax) {
319 return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags,
320 args: storageType, args: expressedType, args: scales, args: zeroPoints,
321 args: quantizedDimension, args: storageTypeMin, args: storageTypeMax);
322}
323
324LogicalResult UniformQuantizedPerAxisType::verify(
325 function_ref<InFlightDiagnostic()> emitError, unsigned flags,
326 Type storageType, Type expressedType, ArrayRef<double> scales,
327 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
328 int64_t storageTypeMin, int64_t storageTypeMax) {
329 if (failed(result: QuantizedType::verify(emitError, flags, storageType, expressedType,
330 storageTypeMin, storageTypeMax))) {
331 return failure();
332 }
333
334 // Uniform quantization requires fully expressed parameters, including
335 // expressed type.
336 if (!expressedType)
337 return emitError() << "uniform quantization requires expressed type";
338
339 // Verify that the expressed type is floating point.
340 // If this restriction is ever eliminated, the parser/printer must be
341 // extended.
342 if (!llvm::isa<FloatType>(Val: expressedType))
343 return emitError() << "expressed type must be floating point";
344
345 // Ensure that the number of scales and zeroPoints match.
346 if (scales.size() != zeroPoints.size())
347 return emitError() << "illegal number of scales and zeroPoints: "
348 << scales.size() << ", " << zeroPoints.size();
349
350 // Verify scale.
351 for (double scale : scales) {
352 if (scale <= 0.0 || std::isinf(x: scale) || std::isnan(x: scale))
353 return emitError() << "illegal scale: " << scale;
354 }
355
356 return success();
357}
358
359ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
360 return getImpl()->getScales();
361}
362
363ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
364 return getImpl()->getZeroPoints();
365}
366
367int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
368 return getImpl()->quantizedDimension;
369}
370
371CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
372 double min, double max) {
373 return Base::get(ctx: expressedType.getContext(), args&: expressedType, args&: min, args&: max);
374}
375
376CalibratedQuantizedType CalibratedQuantizedType::getChecked(
377 function_ref<InFlightDiagnostic()> emitError, Type expressedType,
378 double min, double max) {
379 return Base::getChecked(emitErrorFn: emitError, ctx: expressedType.getContext(), args: expressedType,
380 args: min, args: max);
381}
382
383LogicalResult
384CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
385 Type expressedType, double min, double max) {
386 // Verify that the expressed type is floating point.
387 // If this restriction is ever eliminated, the parser/printer must be
388 // extended.
389 if (!llvm::isa<FloatType>(Val: expressedType))
390 return emitError() << "expressed type must be floating point";
391 if (max <= min)
392 return emitError() << "illegal min and max: (" << min << ":" << max << ")";
393
394 return success();
395}
396
397double CalibratedQuantizedType::getMin() const { return getImpl()->min; }
398
399double CalibratedQuantizedType::getMax() const { return getImpl()->max; }
400

// Source: mlir/lib/Dialect/Quant/IR/QuantTypes.cpp