//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Quant/IR/Quant.h"
10#include "mlir/Dialect/Quant/IR/QuantTypes.h"
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Types.h"
15#include "llvm/ADT/APFloat.h"
16#include "llvm/Support/Format.h"
17#include "llvm/Support/MathExtras.h"
18#include "llvm/Support/SourceMgr.h"
19#include "llvm/Support/raw_ostream.h"
20
21using namespace mlir;
22using namespace quant;
23
24static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
25 auto typeLoc = parser.getCurrentLocation();
26 IntegerType type;
27
28 // Parse storage type (alpha_ident, integer_literal).
29 StringRef identifier;
30 unsigned storageTypeWidth = 0;
31 OptionalParseResult result = parser.parseOptionalType(result&: type);
32 if (result.has_value()) {
33 if (!succeeded(Result: *result))
34 return nullptr;
35 isSigned = !type.isUnsigned();
36 storageTypeWidth = type.getWidth();
37 } else if (succeeded(Result: parser.parseKeyword(keyword: &identifier))) {
38 // Otherwise, this must be an unsigned integer (`u` integer-literal).
39 if (!identifier.consume_front(Prefix: "u")) {
40 parser.emitError(loc: typeLoc, message: "illegal storage type prefix");
41 return nullptr;
42 }
43 if (identifier.getAsInteger(Radix: 10, Result&: storageTypeWidth)) {
44 parser.emitError(loc: typeLoc, message: "expected storage type width");
45 return nullptr;
46 }
47 isSigned = false;
48 type = parser.getBuilder().getIntegerType(storageTypeWidth);
49 } else {
50 return nullptr;
51 }
52
53 if (storageTypeWidth == 0 ||
54 storageTypeWidth > QuantizedType::MaxStorageBits) {
55 parser.emitError(loc: typeLoc, message: "illegal storage type size: ")
56 << storageTypeWidth;
57 return nullptr;
58 }
59
60 return type;
61}
62
63static ParseResult parseStorageRange(DialectAsmParser &parser,
64 IntegerType storageType, bool isSigned,
65 int64_t &storageTypeMin,
66 int64_t &storageTypeMax) {
67 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
68 isSigned, integralWidth: storageType.getWidth());
69 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
70 isSigned, integralWidth: storageType.getWidth());
71 if (failed(Result: parser.parseOptionalLess())) {
72 storageTypeMin = defaultIntegerMin;
73 storageTypeMax = defaultIntegerMax;
74 return success();
75 }
76
77 // Explicit storage min and storage max.
78 SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
79 if (parser.parseInteger(result&: storageTypeMin) || parser.parseColon() ||
80 parser.getCurrentLocation(loc: &maxLoc) ||
81 parser.parseInteger(result&: storageTypeMax) || parser.parseGreater())
82 return failure();
83 if (storageTypeMin < defaultIntegerMin) {
84 return parser.emitError(loc: minLoc, message: "illegal storage type minimum: ")
85 << storageTypeMin;
86 }
87 if (storageTypeMax > defaultIntegerMax) {
88 return parser.emitError(loc: maxLoc, message: "illegal storage type maximum: ")
89 << storageTypeMax;
90 }
91 return success();
92}
93
94static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
95 double &min, double &max) {
96 auto typeLoc = parser.getCurrentLocation();
97 FloatType type;
98
99 if (failed(parser.parseType(type))) {
100 parser.emitError(loc: typeLoc, message: "expecting float expressed type");
101 return nullptr;
102 }
103
104 // Calibrated min and max values.
105 if (parser.parseLess() || parser.parseFloat(result&: min) || parser.parseColon() ||
106 parser.parseFloat(result&: max) || parser.parseGreater()) {
107 parser.emitError(loc: typeLoc, message: "calibrated values must be present");
108 return nullptr;
109 }
110 return type;
111}
112
113/// Parses an AnyQuantizedType.
114///
115/// any ::= `any<` storage-spec (expressed-type-spec)?`>`
116/// storage-spec ::= storage-type (`<` storage-range `>`)?
117/// storage-range ::= integer-literal `:` integer-literal
118/// storage-type ::= (`i` | `u`) integer-literal
119/// expressed-type-spec ::= `:` `f` integer-literal
120static Type parseAnyType(DialectAsmParser &parser) {
121 IntegerType storageType;
122 FloatType expressedType;
123 unsigned typeFlags = 0;
124 int64_t storageTypeMin;
125 int64_t storageTypeMax;
126
127 // Type specification.
128 if (parser.parseLess())
129 return nullptr;
130
131 // Storage type.
132 bool isSigned = false;
133 storageType = parseStorageType(parser, isSigned);
134 if (!storageType) {
135 return nullptr;
136 }
137 if (isSigned) {
138 typeFlags |= QuantizationFlags::Signed;
139 }
140
141 // Storage type range.
142 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
143 storageTypeMax)) {
144 return nullptr;
145 }
146
147 // Optional expressed type.
148 if (succeeded(Result: parser.parseOptionalColon())) {
149 if (parser.parseType(expressedType)) {
150 return nullptr;
151 }
152 }
153
154 if (parser.parseGreater()) {
155 return nullptr;
156 }
157
158 return parser.getChecked<AnyQuantizedType>(
159 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
160}
161
162/// Checks if the given scale value is within the valid range of the expressed
163/// type. The `expressedType` argument is the floating-point type used for
164/// expressing the quantized values, and `scale` is the double value to check.
165LogicalResult
166isScaleInExpressedTypeRange(function_ref<InFlightDiagnostic()> emitError,
167 Type expressedType, double scale) {
168 auto floatType = cast<FloatType>(expressedType);
169 double minScale =
170 APFloat::getSmallest(Sem: floatType.getFloatSemantics()).convertToDouble();
171 double maxScale =
172 APFloat::getLargest(Sem: floatType.getFloatSemantics()).convertToDouble();
173 if (scale < minScale || scale > maxScale)
174 return emitError() << "scale " << scale << " out of expressed type range ["
175 << minScale << ", " << maxScale << "]";
176 return success();
177}
178
179/// Parses a quantization parameter, which is either a scale value (float) or a
180/// scale-zero point pair (float:integer). `expressedType`, expressing the type
181/// of scale values, is used to validate the scale. The parsed scale and zero
182/// point (if any) are stored in `scale` and `zeroPoint`.
183static ParseResult parseQuantParams(DialectAsmParser &parser,
184 Type expressedType, double &scale,
185 int64_t &zeroPoint) {
186
187 if (parser.parseFloat(result&: scale)) {
188 return failure();
189 }
190
191 if (failed(Result: isScaleInExpressedTypeRange(
192 emitError: [&]() { return parser.emitError(loc: parser.getCurrentLocation()); },
193 expressedType, scale))) {
194 return failure();
195 }
196
197 zeroPoint = 0;
198 if (failed(Result: parser.parseOptionalColon())) {
199 return success();
200 }
201
202 return parser.parseInteger(result&: zeroPoint);
203}
204
205/// Parses block size information for sub-channel quantization, assuming the
206/// leading '{' has already been parsed. The block size information is provided
207/// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'.
208///
209/// The parsed axis indices are stored in `quantizedDimensions`, and the
210/// corresponding block sizes are stored in `blockSizes`.
211static ParseResult
212parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser,
213 SmallVectorImpl<int32_t> &quantizedDimensions,
214 SmallVectorImpl<int64_t> &blockSizes) {
215 // Empty block-sizes info.
216 if (succeeded(Result: parser.parseOptionalRBrace())) {
217 return success();
218 }
219
220 auto parseBlockSizeElements = [&]() -> ParseResult {
221 quantizedDimensions.resize(N: quantizedDimensions.size() + 1);
222 blockSizes.resize(N: blockSizes.size() + 1);
223 if (parser.parseInteger(result&: quantizedDimensions.back()) ||
224 parser.parseColon() || parser.parseInteger(result&: blockSizes.back()))
225 return failure();
226 return success();
227 };
228
229 if (parser.parseCommaSeparatedList(parseElementFn: parseBlockSizeElements) ||
230 parser.parseRBrace()) {
231 return failure();
232 }
233
234 return success();
235}
236
237/// Parses a bracketed list of quantization parameters, returning the dimensions
238/// of the parsed sub-tensors in `dims`. The dimension of the list is prepended
239/// to the dimensions of the sub-tensors. This function assumes that the initial
240/// left brace has already been parsed. For example:
241///
242/// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success,
243/// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4]
244///
245/// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success,
246/// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1,
247/// 9]
248///
249/// This function expects all sub-tensors to have the same rank.
250static ParseResult
251parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType,
252 SmallVectorImpl<double> &scales,
253 SmallVectorImpl<int64_t> &zeroPoints,
254 SmallVectorImpl<int64_t> &dims) {
255 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
256 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
257 if (prevDims == newDims)
258 return success();
259 return parser.emitError(loc: parser.getCurrentLocation())
260 << "tensor literal is invalid; ranks are not consistent "
261 "between elements";
262 };
263
264 bool first = true;
265 SmallVector<int64_t, 4> newDims;
266 unsigned size = 0;
267
268 auto parseOneElement = [&]() -> ParseResult {
269 SmallVector<int64_t, 4> thisDims;
270 if (succeeded(Result: parser.parseOptionalLBrace())) {
271 if (parseQuantParamListUntilRBrace(parser, expressedType, scales,
272 zeroPoints, dims&: thisDims))
273 return failure();
274 } else {
275 zeroPoints.resize(N: zeroPoints.size() + 1);
276 scales.resize(N: scales.size() + 1);
277 if (parseQuantParams(parser, expressedType, scale&: scales.back(),
278 zeroPoint&: zeroPoints.back())) {
279 return failure();
280 }
281 }
282 ++size;
283 if (!first)
284 return checkDims(newDims, thisDims);
285 newDims = thisDims;
286 first = false;
287 return success();
288 };
289
290 if (parser.parseCommaSeparatedList(parseElementFn: parseOneElement) || parser.parseRBrace()) {
291 return failure();
292 }
293
294 // Return the sublists' dimensions with 'size' prepended.
295 dims.clear();
296 dims.push_back(Elt: size);
297 dims.append(in_start: newDims.begin(), in_end: newDims.end());
298
299 return success();
300}
301
302/// Parses a UniformQuantizedType.
303///
304/// uniform_type ::= uniform_per_layer
305/// | uniform_per_axis
306/// | uniform_sub_channel
307/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
308/// `,` scale-zero `>`
309/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
310/// axis-spec `,` `{` scale-zero-list `}` `>`
311/// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec
312/// block-size-info `,` scale-zero-tensor `>`
313/// storage-spec ::= storage-type (`<` storage-range `>`)?
314/// storage-range ::= integer-literal `:` integer-literal
315/// storage-type ::= (`i` | `u`) integer-literal
316/// expressed-type-spec ::= `:` `f` integer-literal
317/// axis-spec ::= `:` integer-literal
318/// scale-zero ::= scale (`:` zero-point)?
319/// scale ::= float-literal
320/// zero-point ::= integer-literal
321/// scale-zero-list ::= scale-zero (`,` scale-zero)*
322/// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}`
323/// axis-block ::= axis-spec `:` block-size-spec
324/// block-size-spec ::= integer-literal
325/// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list
326/// scale-zero-dense-exp ::= `{`
327/// scale-zero-tensor (`,` scale-zero-tensor)*
328/// `}`
329static Type parseUniformType(DialectAsmParser &parser) {
330 IntegerType storageType;
331 FloatType expressedType;
332 unsigned typeFlags = 0;
333 int64_t storageTypeMin;
334 int64_t storageTypeMax;
335 bool isPerAxis = false;
336 bool isSubChannel = false;
337 SmallVector<int32_t, 1> quantizedDimensions;
338 SmallVector<int64_t, 1> blockSizes;
339 SmallVector<double, 1> scales;
340 SmallVector<int64_t, 1> zeroPoints;
341
342 // Type specification.
343 if (parser.parseLess()) {
344 return nullptr;
345 }
346
347 // Storage type.
348 bool isSigned = false;
349 storageType = parseStorageType(parser, isSigned);
350 if (!storageType) {
351 return nullptr;
352 }
353 if (isSigned) {
354 typeFlags |= QuantizationFlags::Signed;
355 }
356
357 // Storage type range.
358 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
359 storageTypeMax)) {
360 return nullptr;
361 }
362
363 // Expressed type.
364 if (parser.parseColon() || parser.parseType(expressedType)) {
365 return nullptr;
366 }
367
368 // Optionally parse quantized dimension for per-axis or sub-channel
369 // quantization.
370 if (succeeded(Result: parser.parseOptionalColon())) {
371 if (succeeded(Result: parser.parseOptionalLBrace())) {
372 isSubChannel = true;
373 if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions,
374 blockSizes)) {
375 return nullptr;
376 }
377 } else {
378 isPerAxis = true;
379 quantizedDimensions.resize(N: 1);
380 if (parser.parseInteger(result&: quantizedDimensions.back())) {
381 return nullptr;
382 }
383 }
384 }
385
386 // Comma leading into range_spec.
387 if (parser.parseComma()) {
388 return nullptr;
389 }
390
391 // Quantization parameter (scales/zeroPoints) specification.
392 bool isPerTensor = !isPerAxis && !isSubChannel;
393 SmallVector<int64_t> dims;
394 if (isPerTensor) {
395 zeroPoints.resize(N: zeroPoints.size() + 1);
396 scales.resize(N: scales.size() + 1);
397 if (parseQuantParams(parser, expressedType, scales.back(),
398 zeroPoints.back())) {
399 return nullptr;
400 }
401
402 } else {
403 if (parser.parseLBrace() ||
404 parseQuantParamListUntilRBrace(parser, expressedType, scales,
405 zeroPoints, dims)) {
406 return nullptr;
407 }
408 }
409
410 if (parser.parseGreater()) {
411 return nullptr;
412 }
413
414 if (isPerAxis) {
415 return parser.getChecked<UniformQuantizedPerAxisType>(
416 typeFlags, storageType, expressedType, scales, zeroPoints,
417 quantizedDimensions[0], storageTypeMin, storageTypeMax);
418 } else if (isSubChannel) {
419 SmallVector<APFloat> apFloatScales =
420 llvm::to_vector(Range: llvm::map_range(C&: scales, F: [&](double scale) -> APFloat {
421 APFloat apFloatScale(scale);
422 bool unused;
423 apFloatScale.convert(ToSemantics: expressedType.getFloatSemantics(),
424 RM: APFloat::rmNearestTiesToEven, losesInfo: &unused);
425 return apFloatScale;
426 }));
427 SmallVector<APInt> apIntZeroPoints = llvm::to_vector(
428 Range: llvm::map_range(C&: zeroPoints, F: [&](int64_t zeroPoint) -> APInt {
429 return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint);
430 }));
431 auto scalesRef = mlir::DenseElementsAttr::get(
432 RankedTensorType::get(dims, expressedType), apFloatScales);
433 auto zeroPointsRef = mlir::DenseElementsAttr::get(
434 RankedTensorType::get(dims, storageType), apIntZeroPoints);
435 return parser.getChecked<UniformQuantizedSubChannelType>(
436 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
437 quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax);
438 }
439
440 return parser.getChecked<UniformQuantizedType>(
441 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
442 storageTypeMin, storageTypeMax);
443}
444
445/// Parses an CalibratedQuantizedType.
446///
447/// calibrated ::= `calibrated<` expressed-spec `>`
448/// expressed-spec ::= expressed-type `<` calibrated-range `>`
449/// expressed-type ::= `f` integer-literal
450/// calibrated-range ::= float-literal `:` float-literal
451static Type parseCalibratedType(DialectAsmParser &parser) {
452 FloatType expressedType;
453 double min;
454 double max;
455
456 // Type specification.
457 if (parser.parseLess())
458 return nullptr;
459
460 // Expressed type.
461 expressedType = parseExpressedTypeAndRange(parser, min, max);
462 if (!expressedType) {
463 return nullptr;
464 }
465
466 if (parser.parseGreater()) {
467 return nullptr;
468 }
469
470 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
471}
472
473/// Parse a type registered to this dialect.
474Type QuantDialect::parseType(DialectAsmParser &parser) const {
475 // All types start with an identifier that we switch on.
476 StringRef typeNameSpelling;
477 if (failed(parser.parseKeyword(&typeNameSpelling)))
478 return nullptr;
479
480 if (typeNameSpelling == "uniform")
481 return parseUniformType(parser);
482 if (typeNameSpelling == "any")
483 return parseAnyType(parser);
484 if (typeNameSpelling == "calibrated")
485 return parseCalibratedType(parser);
486
487 parser.emitError(parser.getNameLoc(),
488 "unknown quantized type " + typeNameSpelling);
489 return nullptr;
490}
491
492static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
493 // storage type
494 unsigned storageWidth = type.getStorageTypeIntegralWidth();
495 bool isSigned = type.isSigned();
496 if (isSigned) {
497 out << "i" << storageWidth;
498 } else {
499 out << "u" << storageWidth;
500 }
501
502 // storageTypeMin and storageTypeMax if not default.
503 if (type.hasStorageTypeBounds()) {
504 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
505 << ">";
506 }
507}
508
509static void printQuantParams(double scale, int64_t zeroPoint,
510 DialectAsmPrinter &out) {
511 out << scale;
512 if (zeroPoint != 0) {
513 out << ":" << zeroPoint;
514 }
515}
516
517static void
518printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo,
519 DialectAsmPrinter &out) {
520 out << "{";
521 llvm::interleaveComma(
522 c: llvm::seq<size_t>(Begin: 0, End: blockSizeInfo.size()), os&: out, each_fn: [&](size_t index) {
523 out << blockSizeInfo[index].first << ":" << blockSizeInfo[index].second;
524 });
525 out << "}";
526}
527
528/// Helper that prints a AnyQuantizedType.
529static void printAnyQuantizedType(AnyQuantizedType type,
530 DialectAsmPrinter &out) {
531 out << "any<";
532 printStorageType(type, out);
533 if (Type expressedType = type.getExpressedType()) {
534 out << ":" << expressedType;
535 }
536 out << ">";
537}
538
539/// Helper that prints a UniformQuantizedType.
540static void printUniformQuantizedType(UniformQuantizedType type,
541 DialectAsmPrinter &out) {
542 out << "uniform<";
543 printStorageType(type, out);
544 out << ":" << type.getExpressedType() << ", ";
545
546 // scheme specific parameters
547 printQuantParams(scale: type.getScale(), zeroPoint: type.getZeroPoint(), out);
548 out << ">";
549}
550
551/// Helper that prints a UniformQuantizedPerAxisType.
552static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
553 DialectAsmPrinter &out) {
554 out << "uniform<";
555 printStorageType(type, out);
556 out << ":" << type.getExpressedType() << ":";
557 out << type.getQuantizedDimension();
558 out << ", ";
559
560 // scheme specific parameters
561 ArrayRef<double> scales = type.getScales();
562 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
563 out << "{";
564 llvm::interleave(
565 c: llvm::seq<size_t>(Begin: 0, End: scales.size()), os&: out,
566 each_fn: [&](size_t index) {
567 printQuantParams(scale: scales[index], zeroPoint: zeroPoints[index], out);
568 },
569 separator: ",");
570 out << "}>";
571}
572
573/// Prints quantization parameters as a nested list of `scale`[:`zero_point`]
574/// elements. The nesting corresponds to the `shape` dimensions.
575///
576/// Elements are delimited by commas, and the inner dimensions are enclosed in
577/// braces. `zero_point` is only printed if it is non-zero. For example:
578///
579/// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0],
580/// zeroPoints=[0, 0, 1, 9],
581/// shape=[2, 2])
582///
583/// would print:
584///
585/// {{1.0, 2.0}, {3.0:1, 4.0:9}}
586void printDenseQuantizationParameters(ArrayRef<APFloat> scales,
587 ArrayRef<APInt> zeroPoints,
588 ArrayRef<int64_t> shape,
589 DialectAsmPrinter &out) {
590 int64_t rank = shape.size();
591 SmallVector<unsigned, 4> counter(rank, 0);
592 unsigned openBrackets = 0;
593
594 auto incrementCounterAndDelimit = [&]() {
595 ++counter[rank - 1];
596 for (unsigned i = rank - 1; i > 0; --i) {
597 if (counter[i] >= shape[i]) {
598 counter[i] = 0;
599 ++counter[i - 1];
600 --openBrackets;
601 out << '}';
602 }
603 }
604 };
605
606 for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) {
607 if (idx != 0)
608 out << ", ";
609 while (openBrackets++ < rank)
610 out << '{';
611 openBrackets = rank;
612 out << scales[idx];
613 if (zeroPoints[idx] != 0) {
614 out << ":" << zeroPoints[idx];
615 }
616 incrementCounterAndDelimit();
617 }
618 while (openBrackets-- > 0)
619 out << '}';
620}
621
622/// Helper that prints a UniformQuantizedSubChannelType.
623static void
624printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type,
625 DialectAsmPrinter &out) {
626 out << "uniform<";
627 printStorageType(type, out);
628 out << ":" << type.getExpressedType() << ":";
629 printBlockSizeInfo(blockSizeInfo: type.getBlockSizeInfo(), out);
630 out << ", ";
631
632 auto scalesItr = type.getScales().getValues<APFloat>();
633 auto zeroPointsItr = type.getZeroPoints().getValues<APInt>();
634 SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end());
635 SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end());
636 printDenseQuantizationParameters(scales, zeroPoints,
637 type.getScales().getType().getShape(), out);
638 out << ">";
639}
640
641/// Helper that prints a CalibratedQuantizedType.
642static void printCalibratedQuantizedType(CalibratedQuantizedType type,
643 DialectAsmPrinter &out) {
644 out << "calibrated<" << type.getExpressedType();
645 out << "<" << type.getMin() << ":" << type.getMax() << ">";
646 out << ">";
647}
648
649/// Print a type registered to this dialect.
650void QuantDialect::printType(Type type, DialectAsmPrinter &os) const {
651 if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
652 printAnyQuantizedType(anyType, os);
653 else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
654 printUniformQuantizedType(uniformType, os);
655 else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
656 printUniformQuantizedPerAxisType(perAxisType, os);
657 else if (auto perAxisType =
658 llvm::dyn_cast<UniformQuantizedSubChannelType>(type))
659 printUniformQuantizedSubChannelType(perAxisType, os);
660 else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
661 printCalibratedQuantizedType(calibratedType, os);
662 else
663 llvm_unreachable("Unhandled quantized type");
664}
665

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Quant/IR/TypeParser.cpp