1 | //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Quant/IR/QuantTypes.h" |
10 | #include "TypeDetail.h" |
11 | #include "mlir/Dialect/Quant/IR/Quant.h" |
12 | |
13 | #include "mlir/IR/BuiltinTypes.h" |
14 | #include "mlir/IR/MLIRContext.h" |
15 | #include "llvm/ADT/StringRef.h" |
16 | #include "llvm/ADT/Twine.h" |
17 | #include "llvm/Support/MathExtras.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace mlir::quant; |
21 | using namespace mlir::quant::detail; |
22 | |
23 | namespace { |
24 | |
25 | // Return the minimum scale representable in a given float type |
26 | double getMinScale(Type expressedType) { |
27 | auto floatType = cast<FloatType>(expressedType); |
28 | return APFloat::getSmallest(Sem: floatType.getFloatSemantics()).convertToDouble(); |
29 | } |
30 | |
31 | // Return the maximum scale representable in a given float type |
32 | double getMaxScale(Type expressedType) { |
33 | auto floatType = cast<FloatType>(expressedType); |
34 | return APFloat::getLargest(Sem: floatType.getFloatSemantics()).convertToDouble(); |
35 | } |
36 | |
37 | } // namespace |
38 | |
39 | unsigned QuantizedType::getFlags() const { |
40 | return static_cast<ImplType *>(impl)->flags; |
41 | } |
42 | |
43 | bool QuantizedType::classof(Type type) { |
44 | return llvm::isa<QuantDialect>(type.getDialect()); |
45 | } |
46 | |
47 | LogicalResult |
48 | QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, |
49 | unsigned flags, Type storageType, |
50 | Type expressedType, int64_t storageTypeMin, |
51 | int64_t storageTypeMax) { |
52 | // Verify that the storage type is integral. |
53 | // This restriction may be lifted at some point in favor of using bf16 |
54 | // or f16 as exact representations on hardware where that is advantageous. |
55 | auto intStorageType = llvm::dyn_cast<IntegerType>(storageType); |
56 | if (!intStorageType) |
57 | return emitError() << "storage type must be integral"; |
58 | unsigned integralWidth = intStorageType.getWidth(); |
59 | |
60 | // Verify storage width. |
61 | if (integralWidth == 0 || integralWidth > MaxStorageBits) |
62 | return emitError() << "illegal storage type size: "<< integralWidth; |
63 | |
64 | // Verify storageTypeMin and storageTypeMax. |
65 | bool isSigned = |
66 | (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed; |
67 | int64_t defaultIntegerMin = |
68 | getDefaultMinimumForInteger(isSigned, integralWidth); |
69 | int64_t defaultIntegerMax = |
70 | getDefaultMaximumForInteger(isSigned, integralWidth); |
71 | if (storageTypeMax - storageTypeMin <= 0 || |
72 | storageTypeMin < defaultIntegerMin || |
73 | storageTypeMax > defaultIntegerMax) { |
74 | return emitError() << "illegal storage min and storage max: (" |
75 | << storageTypeMin << ":"<< storageTypeMax << ")"; |
76 | } |
77 | return success(); |
78 | } |
79 | |
80 | Type QuantizedType::getStorageType() const { |
81 | return static_cast<ImplType *>(impl)->storageType; |
82 | } |
83 | |
84 | int64_t QuantizedType::getStorageTypeMin() const { |
85 | return static_cast<ImplType *>(impl)->storageTypeMin; |
86 | } |
87 | |
88 | int64_t QuantizedType::getStorageTypeMax() const { |
89 | return static_cast<ImplType *>(impl)->storageTypeMax; |
90 | } |
91 | |
92 | bool QuantizedType::hasStorageTypeBounds() const { |
93 | unsigned int integralWidth = getStorageTypeIntegralWidth(); |
94 | bool isSignedInteger = isSigned(); |
95 | int64_t defaultIntegerMin = |
96 | getDefaultMinimumForInteger(isSigned: isSignedInteger, integralWidth); |
97 | int64_t defaultIntegerMax = |
98 | getDefaultMaximumForInteger(isSigned: isSignedInteger, integralWidth); |
99 | return defaultIntegerMin != getStorageTypeMin() || |
100 | defaultIntegerMax != getStorageTypeMax(); |
101 | } |
102 | |
103 | unsigned QuantizedType::getStorageTypeIntegralWidth() const { |
104 | // NOTE: If ever supporting non-integral storage types, some other scheme |
105 | // for determining the width will be needed. |
106 | return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth(); |
107 | } |
108 | |
109 | Type QuantizedType::getExpressedType() const { |
110 | return static_cast<ImplType *>(impl)->expressedType; |
111 | } |
112 | |
113 | bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { |
114 | if (llvm::isa<ShapedType>(candidateExpressedType)) { |
115 | return llvm::cast<ShapedType>(candidateExpressedType).getElementType() == |
116 | getExpressedType(); |
117 | } |
118 | return candidateExpressedType == getExpressedType(); |
119 | } |
120 | |
121 | QuantizedType |
122 | QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { |
123 | if (llvm::isa<ShapedType>(primitiveOrContainerType)) { |
124 | Type elementType = |
125 | llvm::cast<ShapedType>(primitiveOrContainerType).getElementType(); |
126 | return llvm::dyn_cast<QuantizedType>(Val&: elementType); |
127 | } |
128 | return llvm::dyn_cast<QuantizedType>(Val&: primitiveOrContainerType); |
129 | } |
130 | |
131 | Type QuantizedType::castFromStorageType(Type candidateType) { |
132 | if (candidateType == getStorageType()) { |
133 | // i.e. i32 -> quant<"uniform[i8:f32]{1.0}"> |
134 | return *this; |
135 | } |
136 | if (llvm::isa<RankedTensorType>(Val: candidateType)) { |
137 | // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
138 | return RankedTensorType::get( |
139 | llvm::cast<RankedTensorType>(candidateType).getShape(), |
140 | getStorageType()); |
141 | } |
142 | if (llvm::isa<UnrankedTensorType>(Val: candidateType)) { |
143 | // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">> |
144 | return UnrankedTensorType::get(getStorageType()); |
145 | } |
146 | if (llvm::isa<VectorType>(Val: candidateType)) { |
147 | // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
148 | return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(), |
149 | getStorageType()); |
150 | } |
151 | |
152 | return nullptr; |
153 | } |
154 | |
155 | Type QuantizedType::castToStorageType(Type quantizedType) { |
156 | if (llvm::isa<QuantizedType>(Val: quantizedType)) { |
157 | // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 |
158 | return llvm::cast<QuantizedType>(Val&: quantizedType).getStorageType(); |
159 | } |
160 | if (llvm::isa<ShapedType>(quantizedType)) { |
161 | // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
162 | ShapedType sType = llvm::cast<ShapedType>(quantizedType); |
163 | if (!llvm::isa<QuantizedType>(sType.getElementType())) { |
164 | return nullptr; |
165 | } |
166 | Type storageType = |
167 | llvm::cast<QuantizedType>(sType.getElementType()).getStorageType(); |
168 | if (llvm::isa<RankedTensorType>(Val: quantizedType)) { |
169 | return RankedTensorType::get(sType.getShape(), storageType); |
170 | } |
171 | if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) { |
172 | return UnrankedTensorType::get(storageType); |
173 | } |
174 | if (llvm::isa<VectorType>(Val: quantizedType)) { |
175 | return VectorType::get(sType.getShape(), storageType); |
176 | } |
177 | } |
178 | |
179 | return nullptr; |
180 | } |
181 | |
182 | Type QuantizedType::castFromExpressedType(Type candidateType) { |
183 | if (candidateType == getExpressedType()) { |
184 | // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> |
185 | return *this; |
186 | } |
187 | if (llvm::isa<ShapedType>(candidateType)) { |
188 | ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType); |
189 | if (candidateShapedType.getElementType() != getExpressedType()) { |
190 | return nullptr; |
191 | } |
192 | |
193 | if (llvm::isa<RankedTensorType>(Val: candidateType)) { |
194 | // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
195 | return RankedTensorType::get(candidateShapedType.getShape(), *this); |
196 | } |
197 | if (llvm::isa<UnrankedTensorType>(Val: candidateType)) { |
198 | // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">> |
199 | return UnrankedTensorType::get(*this); |
200 | } |
201 | if (llvm::isa<VectorType>(Val: candidateType)) { |
202 | // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
203 | return VectorType::get(candidateShapedType.getShape(), *this); |
204 | } |
205 | } |
206 | |
207 | return nullptr; |
208 | } |
209 | |
210 | Type QuantizedType::castToExpressedType(Type quantizedType) { |
211 | if (llvm::isa<QuantizedType>(Val: quantizedType)) { |
212 | // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 |
213 | return llvm::cast<QuantizedType>(Val&: quantizedType).getExpressedType(); |
214 | } |
215 | if (llvm::isa<ShapedType>(quantizedType)) { |
216 | // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
217 | ShapedType sType = llvm::cast<ShapedType>(quantizedType); |
218 | if (!llvm::isa<QuantizedType>(sType.getElementType())) { |
219 | return nullptr; |
220 | } |
221 | Type expressedType = |
222 | llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType(); |
223 | if (llvm::isa<RankedTensorType>(Val: quantizedType)) { |
224 | return RankedTensorType::get(sType.getShape(), expressedType); |
225 | } |
226 | if (llvm::isa<UnrankedTensorType>(Val: quantizedType)) { |
227 | return UnrankedTensorType::get(expressedType); |
228 | } |
229 | if (llvm::isa<VectorType>(Val: quantizedType)) { |
230 | return VectorType::get(sType.getShape(), expressedType); |
231 | } |
232 | } |
233 | |
234 | return nullptr; |
235 | } |
236 | |
237 | Type QuantizedType::castExpressedToStorageType(Type candidateType) { |
238 | Type expressedQuantizedType = castFromExpressedType(candidateType); |
239 | if (!expressedQuantizedType) { |
240 | return nullptr; |
241 | } |
242 | return QuantizedType::castToStorageType(quantizedType: expressedQuantizedType); |
243 | } |
244 | |
245 | AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType, |
246 | Type expressedType, |
247 | int64_t storageTypeMin, |
248 | int64_t storageTypeMax) { |
249 | return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType, |
250 | args&: storageTypeMin, args&: storageTypeMax); |
251 | } |
252 | |
253 | AnyQuantizedType |
254 | AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
255 | unsigned flags, Type storageType, |
256 | Type expressedType, int64_t storageTypeMin, |
257 | int64_t storageTypeMax) { |
258 | return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags, |
259 | args: storageType, args: expressedType, args: storageTypeMin, |
260 | args: storageTypeMax); |
261 | } |
262 | |
263 | LogicalResult |
264 | AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError, |
265 | unsigned flags, Type storageType, |
266 | Type expressedType, int64_t storageTypeMin, |
267 | int64_t storageTypeMax) { |
268 | if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType, |
269 | expressedType, storageTypeMin, |
270 | storageTypeMax))) { |
271 | return failure(); |
272 | } |
273 | |
274 | // Verify that the expressed type is floating point. |
275 | // If this restriction is ever eliminated, the parser/printer must be |
276 | // extended. |
277 | if (expressedType && !llvm::isa<FloatType>(Val: expressedType)) |
278 | return emitError() << "expressed type must be floating point"; |
279 | |
280 | return success(); |
281 | } |
282 | |
283 | UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType, |
284 | Type expressedType, double scale, |
285 | int64_t zeroPoint, |
286 | int64_t storageTypeMin, |
287 | int64_t storageTypeMax) { |
288 | return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType, |
289 | args&: scale, args&: zeroPoint, args&: storageTypeMin, args&: storageTypeMax); |
290 | } |
291 | |
292 | UniformQuantizedType UniformQuantizedType::getChecked( |
293 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
294 | Type storageType, Type expressedType, double scale, int64_t zeroPoint, |
295 | int64_t storageTypeMin, int64_t storageTypeMax) { |
296 | return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags, |
297 | args: storageType, args: expressedType, args: scale, args: zeroPoint, |
298 | args: storageTypeMin, args: storageTypeMax); |
299 | } |
300 | |
301 | LogicalResult UniformQuantizedType::verifyInvariants( |
302 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
303 | Type storageType, Type expressedType, double scale, int64_t zeroPoint, |
304 | int64_t storageTypeMin, int64_t storageTypeMax) { |
305 | if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType, |
306 | expressedType, storageTypeMin, |
307 | storageTypeMax))) { |
308 | return failure(); |
309 | } |
310 | |
311 | // Uniform quantization requires fully expressed parameters, including |
312 | // expressed type. |
313 | if (!expressedType) |
314 | return emitError() << "uniform quantization requires expressed type"; |
315 | |
316 | // Verify that the expressed type is floating point. |
317 | // If this restriction is ever eliminated, the parser/printer must be |
318 | // extended. |
319 | if (!llvm::isa<FloatType>(Val: expressedType)) |
320 | return emitError() << "expressed type must be floating point"; |
321 | |
322 | // Verify scale. |
323 | double minScale = getMinScale(expressedType); |
324 | double maxScale = getMaxScale(expressedType); |
325 | if (scale < minScale || scale > maxScale) |
326 | return emitError() << "scale out of expressed type range ["<< minScale |
327 | << ", "<< maxScale << "]"; |
328 | |
329 | return success(); |
330 | } |
331 | |
332 | double UniformQuantizedType::getScale() const { return getImpl()->scale; } |
333 | |
334 | int64_t UniformQuantizedType::getZeroPoint() const { |
335 | return getImpl()->zeroPoint; |
336 | } |
337 | |
338 | UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get( |
339 | unsigned flags, Type storageType, Type expressedType, |
340 | ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, |
341 | int32_t quantizedDimension, int64_t storageTypeMin, |
342 | int64_t storageTypeMax) { |
343 | return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType, |
344 | args&: scales, args&: zeroPoints, args&: quantizedDimension, args&: storageTypeMin, |
345 | args&: storageTypeMax); |
346 | } |
347 | |
348 | UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked( |
349 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
350 | Type storageType, Type expressedType, ArrayRef<double> scales, |
351 | ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, |
352 | int64_t storageTypeMin, int64_t storageTypeMax) { |
353 | return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags, |
354 | args: storageType, args: expressedType, args: scales, args: zeroPoints, |
355 | args: quantizedDimension, args: storageTypeMin, args: storageTypeMax); |
356 | } |
357 | |
358 | LogicalResult UniformQuantizedPerAxisType::verifyInvariants( |
359 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
360 | Type storageType, Type expressedType, ArrayRef<double> scales, |
361 | ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, |
362 | int64_t storageTypeMin, int64_t storageTypeMax) { |
363 | if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType, |
364 | expressedType, storageTypeMin, |
365 | storageTypeMax))) { |
366 | return failure(); |
367 | } |
368 | |
369 | // Uniform quantization requires fully expressed parameters, including |
370 | // expressed type. |
371 | if (!expressedType) |
372 | return emitError() << "uniform quantization requires expressed type"; |
373 | |
374 | // Verify that the expressed type is floating point. |
375 | // If this restriction is ever eliminated, the parser/printer must be |
376 | // extended. |
377 | if (!llvm::isa<FloatType>(Val: expressedType)) |
378 | return emitError() << "expressed type must be floating point"; |
379 | |
380 | // Ensure that the number of scales and zeroPoints match. |
381 | if (scales.size() != zeroPoints.size()) |
382 | return emitError() << "illegal number of scales and zeroPoints: " |
383 | << scales.size() << ", "<< zeroPoints.size(); |
384 | |
385 | // Verify scale. |
386 | double minScale = getMinScale(expressedType); |
387 | double maxScale = getMaxScale(expressedType); |
388 | for (double scale : scales) { |
389 | if (scale < minScale || scale > maxScale) |
390 | return emitError() << "scale out of expressed type range ["<< minScale |
391 | << ", "<< maxScale << "]"; |
392 | } |
393 | |
394 | // Verify quantized dimension. |
395 | if (quantizedDimension < 0) |
396 | return emitError() << "illegal quantized dimension: "<< quantizedDimension; |
397 | |
398 | return success(); |
399 | } |
400 | |
401 | ArrayRef<double> UniformQuantizedPerAxisType::getScales() const { |
402 | return getImpl()->getScales(); |
403 | } |
404 | |
405 | ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const { |
406 | return getImpl()->getZeroPoints(); |
407 | } |
408 | |
409 | int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { |
410 | return getImpl()->quantizedDimension; |
411 | } |
412 | |
413 | UniformQuantizedSubChannelType UniformQuantizedSubChannelType::get( |
414 | unsigned flags, Type storageType, Type expressedType, |
415 | DenseElementsAttr scales, DenseElementsAttr zeroPoints, |
416 | ArrayRef<int32_t> quantizedDimensions, ArrayRef<int64_t> blockSizes, |
417 | int64_t storageTypeMin, int64_t storageTypeMax) { |
418 | return Base::get(ctx: storageType.getContext(), args&: flags, args&: storageType, args&: expressedType, |
419 | args&: scales, args&: zeroPoints, args&: quantizedDimensions, args&: blockSizes, |
420 | args&: storageTypeMin, args&: storageTypeMax); |
421 | } |
422 | |
423 | UniformQuantizedSubChannelType UniformQuantizedSubChannelType::getChecked( |
424 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
425 | Type storageType, Type expressedType, DenseElementsAttr scales, |
426 | DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions, |
427 | ArrayRef<int64_t> blockSizes, int64_t storageTypeMin, |
428 | int64_t storageTypeMax) { |
429 | return Base::getChecked(emitErrorFn: emitError, ctx: storageType.getContext(), args: flags, |
430 | args: storageType, args: expressedType, args: scales, args: zeroPoints, |
431 | args: quantizedDimensions, args: blockSizes, args: storageTypeMin, |
432 | args: storageTypeMax); |
433 | } |
434 | |
435 | LogicalResult UniformQuantizedSubChannelType::verifyInvariants( |
436 | function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
437 | Type storageType, Type expressedType, DenseElementsAttr scales, |
438 | DenseElementsAttr zeroPoints, ArrayRef<int32_t> quantizedDimensions, |
439 | ArrayRef<int64_t> blockSizes, int64_t storageTypeMin, |
440 | int64_t storageTypeMax) { |
441 | if (failed(Result: QuantizedType::verifyInvariants(emitError, flags, storageType, |
442 | expressedType, storageTypeMin, |
443 | storageTypeMax))) { |
444 | return failure(); |
445 | } |
446 | |
447 | // Uniform quantization requires fully expressed parameters, including |
448 | // expressed type. |
449 | if (!expressedType) |
450 | return emitError() << "uniform quantization requires expressed type"; |
451 | |
452 | // Verify that the expressed type is floating point. |
453 | // If this restriction is ever eliminated, the parser/printer must be |
454 | // extended. |
455 | if (!llvm::isa<FloatType>(Val: expressedType)) |
456 | return emitError() << "expressed type must be floating point"; |
457 | |
458 | // Verify scale type to match expressedType. |
459 | if (scales.getType().getElementType() != expressedType) { |
460 | return emitError() << "type of scale values " |
461 | << scales.getType().getElementType() |
462 | << " must match the expressed type "<< expressedType; |
463 | } |
464 | |
465 | // Verify zero-point type to match storageType. |
466 | if (zeroPoints.getType().getElementType() != storageType) { |
467 | return emitError() << "type of zero point values " |
468 | << zeroPoints.getType().getElementType() |
469 | << " must match the storage type "<< storageType; |
470 | } |
471 | |
472 | // Ensure that the shape of scales and zeroPoints match. |
473 | if (scales.getType().getShape() != zeroPoints.getType().getShape()) |
474 | return emitError() << "shape of scales and zeroPoints (" |
475 | << scales.getType().getShape() << " vs " |
476 | << zeroPoints.getType().getShape() << ") does not match"; |
477 | |
478 | // Ensure that the number of quantized-dimensions and block-sizes match. |
479 | if (quantizedDimensions.size() != blockSizes.size()) |
480 | return emitError() << "number of quantized dimensions and block sizes (" |
481 | << scales.size() << " vs "<< zeroPoints.size() |
482 | << ") does not match"; |
483 | |
484 | // Verify quantized dimension. |
485 | for (auto quantizedDimension : quantizedDimensions) { |
486 | if (quantizedDimension < 0) |
487 | return emitError() << "illegal quantized dimension: " |
488 | << quantizedDimension; |
489 | } |
490 | |
491 | // Verify block sizes. |
492 | for (auto blockSize : blockSizes) { |
493 | if (blockSize <= 0) |
494 | return emitError() << "illegal block size: "<< blockSize; |
495 | } |
496 | |
497 | return success(); |
498 | } |
499 | |
500 | DenseElementsAttr UniformQuantizedSubChannelType::getScales() const { |
501 | return getImpl()->getScales(); |
502 | } |
503 | |
504 | DenseElementsAttr UniformQuantizedSubChannelType::getZeroPoints() const { |
505 | return getImpl()->getZeroPoints(); |
506 | } |
507 | |
508 | ArrayRef<int32_t> |
509 | UniformQuantizedSubChannelType::getQuantizedDimensions() const { |
510 | return getImpl()->getQuantizedDimensions(); |
511 | } |
512 | |
513 | ArrayRef<int64_t> UniformQuantizedSubChannelType::getBlockSizes() const { |
514 | return getImpl()->getBlockSizes(); |
515 | } |
516 | |
517 | const SmallVector<std::pair<int32_t, int64_t>> |
518 | UniformQuantizedSubChannelType::getBlockSizeInfo() const { |
519 | SmallVector<std::pair<int32_t, int64_t>> result; |
520 | result.reserve(N: getQuantizedDimensions().size()); |
521 | |
522 | for (auto [dim, size] : |
523 | llvm::zip(t: getQuantizedDimensions(), u: getBlockSizes())) { |
524 | result.push_back(Elt: {dim, size}); |
525 | } |
526 | |
527 | return result; |
528 | } |
529 | |
530 | CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType, |
531 | double min, double max) { |
532 | return Base::get(ctx: expressedType.getContext(), args&: expressedType, args&: min, args&: max); |
533 | } |
534 | |
535 | CalibratedQuantizedType CalibratedQuantizedType::getChecked( |
536 | function_ref<InFlightDiagnostic()> emitError, Type expressedType, |
537 | double min, double max) { |
538 | return Base::getChecked(emitErrorFn: emitError, ctx: expressedType.getContext(), args: expressedType, |
539 | args: min, args: max); |
540 | } |
541 | |
542 | LogicalResult CalibratedQuantizedType::verifyInvariants( |
543 | function_ref<InFlightDiagnostic()> emitError, Type expressedType, |
544 | double min, double max) { |
545 | // Verify that the expressed type is floating point. |
546 | // If this restriction is ever eliminated, the parser/printer must be |
547 | // extended. |
548 | if (!llvm::isa<FloatType>(Val: expressedType)) |
549 | return emitError() << "expressed type must be floating point"; |
550 | if (max <= min) |
551 | return emitError() << "illegal min and max: ("<< min << ":"<< max << ")"; |
552 | |
553 | return success(); |
554 | } |
555 | |
556 | double CalibratedQuantizedType::getMin() const { return getImpl()->min; } |
557 | |
558 | double CalibratedQuantizedType::getMax() const { return getImpl()->max; } |
559 |
Definitions
- getMinScale
- getMaxScale
- getFlags
- classof
- verifyInvariants
- getStorageType
- getStorageTypeMin
- getStorageTypeMax
- hasStorageTypeBounds
- getStorageTypeIntegralWidth
- getExpressedType
- isCompatibleExpressedType
- getQuantizedElementType
- castFromStorageType
- castToStorageType
- castFromExpressedType
- castToExpressedType
- castExpressedToStorageType
- get
- getChecked
- verifyInvariants
- get
- getChecked
- verifyInvariants
- getScale
- getZeroPoint
- get
- getChecked
- verifyInvariants
- getScales
- getZeroPoints
- getQuantizedDimension
- get
- getChecked
- verifyInvariants
- getScales
- getZeroPoints
- getQuantizedDimensions
- getBlockSizes
- getBlockSizeInfo
- get
- getChecked
- verifyInvariants
- getMin
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more