1 | //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Quant/QuantTypes.h" |
10 | #include "TypeDetail.h" |
11 | #include "mlir/Dialect/Quant/QuantOps.h" |
12 | |
13 | #include "mlir/IR/BuiltinTypes.h" |
14 | #include "mlir/IR/MLIRContext.h" |
15 | #include "llvm/ADT/StringRef.h" |
16 | #include "llvm/ADT/Twine.h" |
17 | #include "llvm/Support/MathExtras.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace mlir::quant; |
21 | using namespace mlir::quant::detail; |
22 | |
23 | unsigned QuantizedType::getFlags() const { |
24 | return static_cast<ImplType *>(impl)->flags; |
25 | } |
26 | |
27 | bool QuantizedType::classof(Type type) { |
28 | return llvm::isa<QuantizationDialect>(type.getDialect()); |
29 | } |
30 | |
31 | LogicalResult |
32 | QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError, |
33 | unsigned flags, Type storageType, Type expressedType, |
34 | int64_t storageTypeMin, int64_t storageTypeMax) { |
35 | // Verify that the storage type is integral. |
36 | // This restriction may be lifted at some point in favor of using bf16 |
37 | // or f16 as exact representations on hardware where that is advantageous. |
38 | auto intStorageType = llvm::dyn_cast<IntegerType>(storageType); |
39 | if (!intStorageType) |
40 | return emitError() << "storage type must be integral" ; |
41 | unsigned integralWidth = intStorageType.getWidth(); |
42 | |
43 | // Verify storage width. |
44 | if (integralWidth == 0 || integralWidth > MaxStorageBits) |
45 | return emitError() << "illegal storage type size: " << integralWidth; |
46 | |
47 | // Verify storageTypeMin and storageTypeMax. |
48 | bool isSigned = |
49 | (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed; |
50 | int64_t defaultIntegerMin = |
51 | getDefaultMinimumForInteger(isSigned, integralWidth); |
52 | int64_t defaultIntegerMax = |
53 | getDefaultMaximumForInteger(isSigned, integralWidth); |
54 | if (storageTypeMax - storageTypeMin <= 0 || |
55 | storageTypeMin < defaultIntegerMin || |
56 | storageTypeMax > defaultIntegerMax) { |
57 | return emitError() << "illegal storage min and storage max: (" |
58 | << storageTypeMin << ":" << storageTypeMax << ")" ; |
59 | } |
60 | return success(); |
61 | } |
62 | |
63 | Type QuantizedType::getStorageType() const { |
64 | return static_cast<ImplType *>(impl)->storageType; |
65 | } |
66 | |
67 | int64_t QuantizedType::getStorageTypeMin() const { |
68 | return static_cast<ImplType *>(impl)->storageTypeMin; |
69 | } |
70 | |
71 | int64_t QuantizedType::getStorageTypeMax() const { |
72 | return static_cast<ImplType *>(impl)->storageTypeMax; |
73 | } |
74 | |
75 | unsigned QuantizedType::getStorageTypeIntegralWidth() const { |
76 | // NOTE: If ever supporting non-integral storage types, some other scheme |
77 | // for determining the width will be needed. |
78 | return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth(); |
79 | } |
80 | |
81 | Type QuantizedType::getExpressedType() const { |
82 | return static_cast<ImplType *>(impl)->expressedType; |
83 | } |
84 | |
85 | bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { |
86 | if (llvm::isa<ShapedType>(candidateExpressedType)) { |
87 | return llvm::cast<ShapedType>(candidateExpressedType).getElementType() == |
88 | getExpressedType(); |
89 | } |
90 | return candidateExpressedType == getExpressedType(); |
91 | } |
92 | |
93 | QuantizedType |
94 | QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { |
95 | if (llvm::isa<ShapedType>(primitiveOrContainerType)) { |
96 | Type elementType = |
97 | llvm::cast<ShapedType>(primitiveOrContainerType).getElementType(); |
98 | return llvm::dyn_cast<QuantizedType>(Val&: elementType); |
99 | } |
100 | return llvm::dyn_cast<QuantizedType>(Val&: primitiveOrContainerType); |
101 | } |
102 | |
103 | Type QuantizedType::castFromStorageType(Type candidateType) { |
104 | if (candidateType == getStorageType()) { |
105 | // i.e. i32 -> quant<"uniform[i8:f32]{1.0}"> |
106 | return *this; |
107 | } |
108 | if (llvm::isa<RankedTensorType>(Val: candidateType)) { |
109 | // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
110 | return RankedTensorType::get( |
111 | llvm::cast<RankedTensorType>(candidateType).getShape(), |
112 | getStorageType()); |
113 | } |
114 | if (llvm::isa<UnrankedTensorType>(Val: candidateType)) { |
115 | // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">> |
116 | return UnrankedTensorType::get(getStorageType()); |
117 | } |
118 | if (llvm::isa<VectorType>(Val: candidateType)) { |
119 | // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
120 | return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(), |
121 | getStorageType()); |
122 | } |
123 | |
124 | return nullptr; |
125 | } |
126 | |
127 | Type QuantizedType::castToStorageType(Type quantizedType) { |
128 | if (llvm::isa<QuantizedType>(Val: quantizedType)) { |
129 | // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 |
130 | return llvm::cast<QuantizedType>(Val&: quantizedType).getStorageType(); |
131 | } |
132 | if (llvm::isa<ShapedType>(quantizedType)) { |
133 | // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
134 | ShapedType sType = llvm::cast<ShapedType>(quantizedType); |
135 | if (!llvm::isa<QuantizedType>(sType.getElementType())) { |
136 | return nullptr; |
137 | } |
138 | Type storageType = |
139 | llvm::cast<QuantizedType>(sType.getElementType()).getStorageType(); |
140 | if (llvm::isa<RankedTensorType>(Val: quantizedType)) { |
141 | return RankedTensorType::get(sType.getShape(), storageType); |
142 | } |
143 | if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) { |
144 | return UnrankedTensorType::get(storageType); |
145 | } |
146 | if (llvm::isa<VectorType>(Val: quantizedType)) { |
147 | return VectorType::get(sType.getShape(), storageType); |
148 | } |
149 | } |
150 | |
151 | return nullptr; |
152 | } |
153 | |
154 | Type QuantizedType::castFromExpressedType(Type candidateType) { |
155 | if (candidateType == getExpressedType()) { |
156 | // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> |
157 | return *this; |
158 | } |
159 | if (llvm::isa<ShapedType>(candidateType)) { |
160 | ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType); |
161 | if (candidateShapedType.getElementType() != getExpressedType()) { |
162 | return nullptr; |
163 | } |
164 | |
165 | if (llvm::isa<RankedTensorType>(Val: candidateType)) { |
166 | // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
167 | return RankedTensorType::get(candidateShapedType.getShape(), *this); |
168 | } |
169 | if (llvm::isa<UnrankedTensorType>(Val: candidateType)) { |
170 | // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">> |
171 | return UnrankedTensorType::get(*this); |
172 | } |
173 | if (llvm::isa<VectorType>(Val: candidateType)) { |
174 | // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
175 | return VectorType::get(candidateShapedType.getShape(), *this); |
176 | } |
177 | } |
178 | |
179 | return nullptr; |
180 | } |
181 | |
182 | Type QuantizedType::castToExpressedType(Type quantizedType) { |
183 | if (llvm::isa<QuantizedType>(Val: quantizedType)) { |
184 | // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 |
185 | return llvm::cast<QuantizedType>(Val&: quantizedType).getExpressedType(); |
186 | } |
187 | if (llvm::isa<ShapedType>(quantizedType)) { |
188 | // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
189 | ShapedType sType = llvm::cast<ShapedType>(quantizedType); |
190 | if (!llvm::isa<QuantizedType>(sType.getElementType())) { |
191 | return nullptr; |
192 | } |
193 | Type expressedType = |
194 | llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType(); |
195 | if (llvm::isa<RankedTensorType>(Val: quantizedType)) { |
196 | return RankedTensorType::get(sType.getShape(), expressedType); |
197 | } |
198 | if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) { |
199 | return UnrankedTensorType::get(expressedType); |
200 | } |
201 | if (llvm::isa<VectorType>(Val: quantizedType)) { |
202 | return VectorType::get(sType.getShape(), expressedType); |
203 | } |
204 | } |
205 | |
206 | return nullptr; |
207 | } |
208 | |
209 | Type QuantizedType::castExpressedToStorageType(Type candidateType) { |
210 | Type expressedQuantizedType = castFromExpressedType(candidateType); |
211 | if (!expressedQuantizedType) { |
212 | return nullptr; |
213 | } |
214 | return QuantizedType::castToStorageType(quantizedType: expressedQuantizedType); |
215 | } |
216 | |
217 | AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType, |
218 | Type expressedType, |
219 | int64_t storageTypeMin, |
220 | int64_t storageTypeMax) { |
221 | return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType, |
222 | args&: storageTypeMin, args&: storageTypeMax); |
223 | } |
224 | |
225 | AnyQuantizedType |
226 | AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
227 | unsigned flags, Type storageType, |
228 | Type expressedType, int64_t storageTypeMin, |
229 | int64_t storageTypeMax) { |
230 | return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags, |
231 | args: storageType, args: expressedType, args: storageTypeMin, |
232 | args: storageTypeMax); |
233 | } |
234 | |
235 | LogicalResult |
236 | AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError, |
237 | unsigned flags, Type storageType, Type expressedType, |
238 | int64_t storageTypeMin, int64_t storageTypeMax) { |
239 | if (failed(result: QuantizedType::verify(emitError, flags, storageType, expressedType, |
240 | storageTypeMin, storageTypeMax))) { |
241 | return failure(); |
242 | } |
243 | |
244 | // Verify that the expressed type is floating point. |
245 | // If this restriction is ever eliminated, the parser/printer must be |
246 | // extended. |
247 | if (expressedType && !llvm::isa<FloatType>(Val: expressedType)) |
248 | return emitError() << "expressed type must be floating point" ; |
249 | |
250 | return success(); |
251 | } |
252 | |
253 | UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType, |
254 | Type expressedType, double scale, |
255 | int64_t zeroPoint, |
256 | int64_t storageTypeMin, |
257 | int64_t storageTypeMax) { |
258 | return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType, |
259 | args&: scale, args&: zeroPoint, args&: storageTypeMin, args&: storageTypeMax); |
260 | } |
261 | |
262 | UniformQuantizedType UniformQuantizedType::getChecked( |
263 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
264 | Type storageType, Type expressedType, double scale, int64_t zeroPoint, |
265 | int64_t storageTypeMin, int64_t storageTypeMax) { |
266 | return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags, |
267 | args: storageType, args: expressedType, args: scale, args: zeroPoint, |
268 | args: storageTypeMin, args: storageTypeMax); |
269 | } |
270 | |
271 | LogicalResult UniformQuantizedType::verify( |
272 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
273 | Type storageType, Type expressedType, double scale, int64_t zeroPoint, |
274 | int64_t storageTypeMin, int64_t storageTypeMax) { |
275 | if (failed(result: QuantizedType::verify(emitError, flags, storageType, expressedType, |
276 | storageTypeMin, storageTypeMax))) { |
277 | return failure(); |
278 | } |
279 | |
280 | // Uniform quantization requires fully expressed parameters, including |
281 | // expressed type. |
282 | if (!expressedType) |
283 | return emitError() << "uniform quantization requires expressed type" ; |
284 | |
285 | // Verify that the expressed type is floating point. |
286 | // If this restriction is ever eliminated, the parser/printer must be |
287 | // extended. |
288 | if (!llvm::isa<FloatType>(Val: expressedType)) |
289 | return emitError() << "expressed type must be floating point" ; |
290 | |
291 | // Verify scale. |
292 | if (scale <= 0.0 || std::isinf(x: scale) || std::isnan(x: scale)) |
293 | return emitError() << "illegal scale: " << scale; |
294 | |
295 | return success(); |
296 | } |
297 | |
298 | double UniformQuantizedType::getScale() const { return getImpl()->scale; } |
299 | |
300 | int64_t UniformQuantizedType::getZeroPoint() const { |
301 | return getImpl()->zeroPoint; |
302 | } |
303 | |
304 | UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get( |
305 | unsigned flags, Type storageType, Type expressedType, |
306 | ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, |
307 | int32_t quantizedDimension, int64_t storageTypeMin, |
308 | int64_t storageTypeMax) { |
309 | return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType, |
310 | args&: scales, args&: zeroPoints, args&: quantizedDimension, args&: storageTypeMin, |
311 | args&: storageTypeMax); |
312 | } |
313 | |
314 | UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked( |
315 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
316 | Type storageType, Type expressedType, ArrayRef<double> scales, |
317 | ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, |
318 | int64_t storageTypeMin, int64_t storageTypeMax) { |
319 | return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags, |
320 | args: storageType, args: expressedType, args: scales, args: zeroPoints, |
321 | args: quantizedDimension, args: storageTypeMin, args: storageTypeMax); |
322 | } |
323 | |
324 | LogicalResult UniformQuantizedPerAxisType::verify( |
325 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
326 | Type storageType, Type expressedType, ArrayRef<double> scales, |
327 | ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, |
328 | int64_t storageTypeMin, int64_t storageTypeMax) { |
329 | if (failed(result: QuantizedType::verify(emitError, flags, storageType, expressedType, |
330 | storageTypeMin, storageTypeMax))) { |
331 | return failure(); |
332 | } |
333 | |
334 | // Uniform quantization requires fully expressed parameters, including |
335 | // expressed type. |
336 | if (!expressedType) |
337 | return emitError() << "uniform quantization requires expressed type" ; |
338 | |
339 | // Verify that the expressed type is floating point. |
340 | // If this restriction is ever eliminated, the parser/printer must be |
341 | // extended. |
342 | if (!llvm::isa<FloatType>(Val: expressedType)) |
343 | return emitError() << "expressed type must be floating point" ; |
344 | |
345 | // Ensure that the number of scales and zeroPoints match. |
346 | if (scales.size() != zeroPoints.size()) |
347 | return emitError() << "illegal number of scales and zeroPoints: " |
348 | << scales.size() << ", " << zeroPoints.size(); |
349 | |
350 | // Verify scale. |
351 | for (double scale : scales) { |
352 | if (scale <= 0.0 || std::isinf(x: scale) || std::isnan(x: scale)) |
353 | return emitError() << "illegal scale: " << scale; |
354 | } |
355 | |
356 | return success(); |
357 | } |
358 | |
359 | ArrayRef<double> UniformQuantizedPerAxisType::getScales() const { |
360 | return getImpl()->getScales(); |
361 | } |
362 | |
363 | ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const { |
364 | return getImpl()->getZeroPoints(); |
365 | } |
366 | |
367 | int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { |
368 | return getImpl()->quantizedDimension; |
369 | } |
370 | |
371 | CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType, |
372 | double min, double max) { |
373 | return Base::get(ctx: expressedType.getContext(), args&: expressedType, args&: min, args&: max); |
374 | } |
375 | |
376 | CalibratedQuantizedType CalibratedQuantizedType::getChecked( |
377 | function_ref<InFlightDiagnostic()> emitError, Type expressedType, |
378 | double min, double max) { |
379 | return Base::getChecked(emitErrorFn: emitError, ctx: expressedType.getContext(), args: expressedType, |
380 | args: min, args: max); |
381 | } |
382 | |
383 | LogicalResult |
384 | CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError, |
385 | Type expressedType, double min, double max) { |
386 | // Verify that the expressed type is floating point. |
387 | // If this restriction is ever eliminated, the parser/printer must be |
388 | // extended. |
389 | if (!llvm::isa<FloatType>(Val: expressedType)) |
390 | return emitError() << "expressed type must be floating point" ; |
391 | if (max <= min) |
392 | return emitError() << "illegal min and max: (" << min << ":" << max << ")" ; |
393 | |
394 | return success(); |
395 | } |
396 | |
397 | double CalibratedQuantizedType::getMin() const { return getImpl()->min; } |
398 | |
399 | double CalibratedQuantizedType::getMax() const { return getImpl()->max; } |
400 | |