1 | //===- TypeParser.h - Quantization Type Parser ------------------*- C++ -*-===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Quant/IR/Quant.h" |
10 | #include "mlir/Dialect/Quant/IR/QuantTypes.h" |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/IR/DialectImplementation.h" |
13 | #include "mlir/IR/Location.h" |
14 | #include "mlir/IR/Types.h" |
15 | #include "llvm/ADT/APFloat.h" |
16 | #include "llvm/Support/Format.h" |
17 | #include "llvm/Support/MathExtras.h" |
18 | #include "llvm/Support/SourceMgr.h" |
19 | #include "llvm/Support/raw_ostream.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace quant; |
23 | |
24 | static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { |
25 | auto typeLoc = parser.getCurrentLocation(); |
26 | IntegerType type; |
27 | |
28 | // Parse storage type (alpha_ident, integer_literal). |
29 | StringRef identifier; |
30 | unsigned storageTypeWidth = 0; |
31 | OptionalParseResult result = parser.parseOptionalType(result&: type); |
32 | if (result.has_value()) { |
33 | if (!succeeded(Result: *result)) |
34 | return nullptr; |
35 | isSigned = !type.isUnsigned(); |
36 | storageTypeWidth = type.getWidth(); |
37 | } else if (succeeded(Result: parser.parseKeyword(keyword: &identifier))) { |
38 | // Otherwise, this must be an unsigned integer (`u` integer-literal). |
39 | if (!identifier.consume_front(Prefix: "u")) { |
40 | parser.emitError(loc: typeLoc, message: "illegal storage type prefix"); |
41 | return nullptr; |
42 | } |
43 | if (identifier.getAsInteger(Radix: 10, Result&: storageTypeWidth)) { |
44 | parser.emitError(loc: typeLoc, message: "expected storage type width"); |
45 | return nullptr; |
46 | } |
47 | isSigned = false; |
48 | type = parser.getBuilder().getIntegerType(storageTypeWidth); |
49 | } else { |
50 | return nullptr; |
51 | } |
52 | |
53 | if (storageTypeWidth == 0 || |
54 | storageTypeWidth > QuantizedType::MaxStorageBits) { |
55 | parser.emitError(loc: typeLoc, message: "illegal storage type size: ") |
56 | << storageTypeWidth; |
57 | return nullptr; |
58 | } |
59 | |
60 | return type; |
61 | } |
62 | |
63 | static ParseResult parseStorageRange(DialectAsmParser &parser, |
64 | IntegerType storageType, bool isSigned, |
65 | int64_t &storageTypeMin, |
66 | int64_t &storageTypeMax) { |
67 | int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger( |
68 | isSigned, integralWidth: storageType.getWidth()); |
69 | int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger( |
70 | isSigned, integralWidth: storageType.getWidth()); |
71 | if (failed(Result: parser.parseOptionalLess())) { |
72 | storageTypeMin = defaultIntegerMin; |
73 | storageTypeMax = defaultIntegerMax; |
74 | return success(); |
75 | } |
76 | |
77 | // Explicit storage min and storage max. |
78 | SMLoc minLoc = parser.getCurrentLocation(), maxLoc; |
79 | if (parser.parseInteger(result&: storageTypeMin) || parser.parseColon() || |
80 | parser.getCurrentLocation(loc: &maxLoc) || |
81 | parser.parseInteger(result&: storageTypeMax) || parser.parseGreater()) |
82 | return failure(); |
83 | if (storageTypeMin < defaultIntegerMin) { |
84 | return parser.emitError(loc: minLoc, message: "illegal storage type minimum: ") |
85 | << storageTypeMin; |
86 | } |
87 | if (storageTypeMax > defaultIntegerMax) { |
88 | return parser.emitError(loc: maxLoc, message: "illegal storage type maximum: ") |
89 | << storageTypeMax; |
90 | } |
91 | return success(); |
92 | } |
93 | |
94 | static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, |
95 | double &min, double &max) { |
96 | auto typeLoc = parser.getCurrentLocation(); |
97 | FloatType type; |
98 | |
99 | if (failed(parser.parseType(type))) { |
100 | parser.emitError(loc: typeLoc, message: "expecting float expressed type"); |
101 | return nullptr; |
102 | } |
103 | |
104 | // Calibrated min and max values. |
105 | if (parser.parseLess() || parser.parseFloat(result&: min) || parser.parseColon() || |
106 | parser.parseFloat(result&: max) || parser.parseGreater()) { |
107 | parser.emitError(loc: typeLoc, message: "calibrated values must be present"); |
108 | return nullptr; |
109 | } |
110 | return type; |
111 | } |
112 | |
113 | /// Parses an AnyQuantizedType. |
114 | /// |
115 | /// any ::= `any<` storage-spec (expressed-type-spec)?`>` |
116 | /// storage-spec ::= storage-type (`<` storage-range `>`)? |
117 | /// storage-range ::= integer-literal `:` integer-literal |
118 | /// storage-type ::= (`i` | `u`) integer-literal |
119 | /// expressed-type-spec ::= `:` `f` integer-literal |
120 | static Type parseAnyType(DialectAsmParser &parser) { |
121 | IntegerType storageType; |
122 | FloatType expressedType; |
123 | unsigned typeFlags = 0; |
124 | int64_t storageTypeMin; |
125 | int64_t storageTypeMax; |
126 | |
127 | // Type specification. |
128 | if (parser.parseLess()) |
129 | return nullptr; |
130 | |
131 | // Storage type. |
132 | bool isSigned = false; |
133 | storageType = parseStorageType(parser, isSigned); |
134 | if (!storageType) { |
135 | return nullptr; |
136 | } |
137 | if (isSigned) { |
138 | typeFlags |= QuantizationFlags::Signed; |
139 | } |
140 | |
141 | // Storage type range. |
142 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, |
143 | storageTypeMax)) { |
144 | return nullptr; |
145 | } |
146 | |
147 | // Optional expressed type. |
148 | if (succeeded(Result: parser.parseOptionalColon())) { |
149 | if (parser.parseType(expressedType)) { |
150 | return nullptr; |
151 | } |
152 | } |
153 | |
154 | if (parser.parseGreater()) { |
155 | return nullptr; |
156 | } |
157 | |
158 | return parser.getChecked<AnyQuantizedType>( |
159 | typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); |
160 | } |
161 | |
162 | /// Checks if the given scale value is within the valid range of the expressed |
163 | /// type. The `expressedType` argument is the floating-point type used for |
164 | /// expressing the quantized values, and `scale` is the double value to check. |
165 | LogicalResult |
166 | isScaleInExpressedTypeRange(function_ref<InFlightDiagnostic()> emitError, |
167 | Type expressedType, double scale) { |
168 | auto floatType = cast<FloatType>(expressedType); |
169 | double minScale = |
170 | APFloat::getSmallest(Sem: floatType.getFloatSemantics()).convertToDouble(); |
171 | double maxScale = |
172 | APFloat::getLargest(Sem: floatType.getFloatSemantics()).convertToDouble(); |
173 | if (scale < minScale || scale > maxScale) |
174 | return emitError() << "scale "<< scale << " out of expressed type range [" |
175 | << minScale << ", "<< maxScale << "]"; |
176 | return success(); |
177 | } |
178 | |
179 | /// Parses a quantization parameter, which is either a scale value (float) or a |
180 | /// scale-zero point pair (float:integer). `expressedType`, expressing the type |
181 | /// of scale values, is used to validate the scale. The parsed scale and zero |
182 | /// point (if any) are stored in `scale` and `zeroPoint`. |
183 | static ParseResult parseQuantParams(DialectAsmParser &parser, |
184 | Type expressedType, double &scale, |
185 | int64_t &zeroPoint) { |
186 | |
187 | if (parser.parseFloat(result&: scale)) { |
188 | return failure(); |
189 | } |
190 | |
191 | if (failed(Result: isScaleInExpressedTypeRange( |
192 | emitError: [&]() { return parser.emitError(loc: parser.getCurrentLocation()); }, |
193 | expressedType, scale))) { |
194 | return failure(); |
195 | } |
196 | |
197 | zeroPoint = 0; |
198 | if (failed(Result: parser.parseOptionalColon())) { |
199 | return success(); |
200 | } |
201 | |
202 | return parser.parseInteger(result&: zeroPoint); |
203 | } |
204 | |
205 | /// Parses block size information for sub-channel quantization, assuming the |
206 | /// leading '{' has already been parsed. The block size information is provided |
207 | /// as a comma-separated list of "Axis:BlockSize" pairs, terminated by a '}'. |
208 | /// |
209 | /// The parsed axis indices are stored in `quantizedDimensions`, and the |
210 | /// corresponding block sizes are stored in `blockSizes`. |
211 | static ParseResult |
212 | parseBlockSizeInfoUntilRBrace(DialectAsmParser &parser, |
213 | SmallVectorImpl<int32_t> &quantizedDimensions, |
214 | SmallVectorImpl<int64_t> &blockSizes) { |
215 | // Empty block-sizes info. |
216 | if (succeeded(Result: parser.parseOptionalRBrace())) { |
217 | return success(); |
218 | } |
219 | |
220 | auto parseBlockSizeElements = [&]() -> ParseResult { |
221 | quantizedDimensions.resize(N: quantizedDimensions.size() + 1); |
222 | blockSizes.resize(N: blockSizes.size() + 1); |
223 | if (parser.parseInteger(result&: quantizedDimensions.back()) || |
224 | parser.parseColon() || parser.parseInteger(result&: blockSizes.back())) |
225 | return failure(); |
226 | return success(); |
227 | }; |
228 | |
229 | if (parser.parseCommaSeparatedList(parseElementFn: parseBlockSizeElements) || |
230 | parser.parseRBrace()) { |
231 | return failure(); |
232 | } |
233 | |
234 | return success(); |
235 | } |
236 | |
237 | /// Parses a bracketed list of quantization parameters, returning the dimensions |
238 | /// of the parsed sub-tensors in `dims`. The dimension of the list is prepended |
239 | /// to the dimensions of the sub-tensors. This function assumes that the initial |
240 | /// left brace has already been parsed. For example: |
241 | /// |
242 | /// parseQuantParamListUntilRBrace(1.0:1, 2.0:4, 3.0:4}) -> Success, |
243 | /// dims = [3], scales = [1.0, 2.0, 3.0], zeroPoints = [1, 4, 4] |
244 | /// |
245 | /// parseQuantParamListUntilRBrace({1.0, 2.0}, {3.0:1, 4.0:9}}) -> Success, |
246 | /// dims = [2, 2], scales = [1.0, 2.0, 3.0, 4.0], zeroPoints = [0, 0, 1, |
247 | /// 9] |
248 | /// |
249 | /// This function expects all sub-tensors to have the same rank. |
250 | static ParseResult |
251 | parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, |
252 | SmallVectorImpl<double> &scales, |
253 | SmallVectorImpl<int64_t> &zeroPoints, |
254 | SmallVectorImpl<int64_t> &dims) { |
255 | auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims, |
256 | const SmallVectorImpl<int64_t> &newDims) -> ParseResult { |
257 | if (prevDims == newDims) |
258 | return success(); |
259 | return parser.emitError(loc: parser.getCurrentLocation()) |
260 | << "tensor literal is invalid; ranks are not consistent " |
261 | "between elements"; |
262 | }; |
263 | |
264 | bool first = true; |
265 | SmallVector<int64_t, 4> newDims; |
266 | unsigned size = 0; |
267 | |
268 | auto parseOneElement = [&]() -> ParseResult { |
269 | SmallVector<int64_t, 4> thisDims; |
270 | if (succeeded(Result: parser.parseOptionalLBrace())) { |
271 | if (parseQuantParamListUntilRBrace(parser, expressedType, scales, |
272 | zeroPoints, dims&: thisDims)) |
273 | return failure(); |
274 | } else { |
275 | zeroPoints.resize(N: zeroPoints.size() + 1); |
276 | scales.resize(N: scales.size() + 1); |
277 | if (parseQuantParams(parser, expressedType, scale&: scales.back(), |
278 | zeroPoint&: zeroPoints.back())) { |
279 | return failure(); |
280 | } |
281 | } |
282 | ++size; |
283 | if (!first) |
284 | return checkDims(newDims, thisDims); |
285 | newDims = thisDims; |
286 | first = false; |
287 | return success(); |
288 | }; |
289 | |
290 | if (parser.parseCommaSeparatedList(parseElementFn: parseOneElement) || parser.parseRBrace()) { |
291 | return failure(); |
292 | } |
293 | |
294 | // Return the sublists' dimensions with 'size' prepended. |
295 | dims.clear(); |
296 | dims.push_back(Elt: size); |
297 | dims.append(in_start: newDims.begin(), in_end: newDims.end()); |
298 | |
299 | return success(); |
300 | } |
301 | |
302 | /// Parses a UniformQuantizedType. |
303 | /// |
304 | /// uniform_type ::= uniform_per_layer |
305 | /// | uniform_per_axis |
306 | /// | uniform_sub_channel |
307 | /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec |
308 | /// `,` scale-zero `>` |
309 | /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec |
310 | /// axis-spec `,` `{` scale-zero-list `}` `>` |
311 | /// uniform_sub_channel ::= `uniform<` storage-spec expressed-type-spec |
312 | /// block-size-info `,` scale-zero-tensor `>` |
313 | /// storage-spec ::= storage-type (`<` storage-range `>`)? |
314 | /// storage-range ::= integer-literal `:` integer-literal |
315 | /// storage-type ::= (`i` | `u`) integer-literal |
316 | /// expressed-type-spec ::= `:` `f` integer-literal |
317 | /// axis-spec ::= `:` integer-literal |
318 | /// scale-zero ::= scale (`:` zero-point)? |
319 | /// scale ::= float-literal |
320 | /// zero-point ::= integer-literal |
321 | /// scale-zero-list ::= scale-zero (`,` scale-zero)* |
322 | /// block-size-info ::= `{` `}` | `{` axis-block `:` (`,` axis-block)* `}` |
323 | /// axis-block ::= axis-spec `:` block-size-spec |
324 | /// block-size-spec ::= integer-literal |
325 | /// scale-zero-tensor ::= scale-zero-dense-exp | scale-zero-list |
326 | /// scale-zero-dense-exp ::= `{` |
327 | /// scale-zero-tensor (`,` scale-zero-tensor)* |
328 | /// `}` |
329 | static Type parseUniformType(DialectAsmParser &parser) { |
330 | IntegerType storageType; |
331 | FloatType expressedType; |
332 | unsigned typeFlags = 0; |
333 | int64_t storageTypeMin; |
334 | int64_t storageTypeMax; |
335 | bool isPerAxis = false; |
336 | bool isSubChannel = false; |
337 | SmallVector<int32_t, 1> quantizedDimensions; |
338 | SmallVector<int64_t, 1> blockSizes; |
339 | SmallVector<double, 1> scales; |
340 | SmallVector<int64_t, 1> zeroPoints; |
341 | |
342 | // Type specification. |
343 | if (parser.parseLess()) { |
344 | return nullptr; |
345 | } |
346 | |
347 | // Storage type. |
348 | bool isSigned = false; |
349 | storageType = parseStorageType(parser, isSigned); |
350 | if (!storageType) { |
351 | return nullptr; |
352 | } |
353 | if (isSigned) { |
354 | typeFlags |= QuantizationFlags::Signed; |
355 | } |
356 | |
357 | // Storage type range. |
358 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, |
359 | storageTypeMax)) { |
360 | return nullptr; |
361 | } |
362 | |
363 | // Expressed type. |
364 | if (parser.parseColon() || parser.parseType(expressedType)) { |
365 | return nullptr; |
366 | } |
367 | |
368 | // Optionally parse quantized dimension for per-axis or sub-channel |
369 | // quantization. |
370 | if (succeeded(Result: parser.parseOptionalColon())) { |
371 | if (succeeded(Result: parser.parseOptionalLBrace())) { |
372 | isSubChannel = true; |
373 | if (parseBlockSizeInfoUntilRBrace(parser, quantizedDimensions, |
374 | blockSizes)) { |
375 | return nullptr; |
376 | } |
377 | } else { |
378 | isPerAxis = true; |
379 | quantizedDimensions.resize(N: 1); |
380 | if (parser.parseInteger(result&: quantizedDimensions.back())) { |
381 | return nullptr; |
382 | } |
383 | } |
384 | } |
385 | |
386 | // Comma leading into range_spec. |
387 | if (parser.parseComma()) { |
388 | return nullptr; |
389 | } |
390 | |
391 | // Quantization parameter (scales/zeroPoints) specification. |
392 | bool isPerTensor = !isPerAxis && !isSubChannel; |
393 | SmallVector<int64_t> dims; |
394 | if (isPerTensor) { |
395 | zeroPoints.resize(N: zeroPoints.size() + 1); |
396 | scales.resize(N: scales.size() + 1); |
397 | if (parseQuantParams(parser, expressedType, scales.back(), |
398 | zeroPoints.back())) { |
399 | return nullptr; |
400 | } |
401 | |
402 | } else { |
403 | if (parser.parseLBrace() || |
404 | parseQuantParamListUntilRBrace(parser, expressedType, scales, |
405 | zeroPoints, dims)) { |
406 | return nullptr; |
407 | } |
408 | } |
409 | |
410 | if (parser.parseGreater()) { |
411 | return nullptr; |
412 | } |
413 | |
414 | if (isPerAxis) { |
415 | return parser.getChecked<UniformQuantizedPerAxisType>( |
416 | typeFlags, storageType, expressedType, scales, zeroPoints, |
417 | quantizedDimensions[0], storageTypeMin, storageTypeMax); |
418 | } else if (isSubChannel) { |
419 | SmallVector<APFloat> apFloatScales = |
420 | llvm::to_vector(Range: llvm::map_range(C&: scales, F: [&](double scale) -> APFloat { |
421 | APFloat apFloatScale(scale); |
422 | bool unused; |
423 | apFloatScale.convert(ToSemantics: expressedType.getFloatSemantics(), |
424 | RM: APFloat::rmNearestTiesToEven, losesInfo: &unused); |
425 | return apFloatScale; |
426 | })); |
427 | SmallVector<APInt> apIntZeroPoints = llvm::to_vector( |
428 | Range: llvm::map_range(C&: zeroPoints, F: [&](int64_t zeroPoint) -> APInt { |
429 | return APInt(storageType.getIntOrFloatBitWidth(), zeroPoint); |
430 | })); |
431 | auto scalesRef = mlir::DenseElementsAttr::get( |
432 | RankedTensorType::get(dims, expressedType), apFloatScales); |
433 | auto zeroPointsRef = mlir::DenseElementsAttr::get( |
434 | RankedTensorType::get(dims, storageType), apIntZeroPoints); |
435 | return parser.getChecked<UniformQuantizedSubChannelType>( |
436 | typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, |
437 | quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax); |
438 | } |
439 | |
440 | return parser.getChecked<UniformQuantizedType>( |
441 | typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(), |
442 | storageTypeMin, storageTypeMax); |
443 | } |
444 | |
445 | /// Parses an CalibratedQuantizedType. |
446 | /// |
447 | /// calibrated ::= `calibrated<` expressed-spec `>` |
448 | /// expressed-spec ::= expressed-type `<` calibrated-range `>` |
449 | /// expressed-type ::= `f` integer-literal |
450 | /// calibrated-range ::= float-literal `:` float-literal |
451 | static Type parseCalibratedType(DialectAsmParser &parser) { |
452 | FloatType expressedType; |
453 | double min; |
454 | double max; |
455 | |
456 | // Type specification. |
457 | if (parser.parseLess()) |
458 | return nullptr; |
459 | |
460 | // Expressed type. |
461 | expressedType = parseExpressedTypeAndRange(parser, min, max); |
462 | if (!expressedType) { |
463 | return nullptr; |
464 | } |
465 | |
466 | if (parser.parseGreater()) { |
467 | return nullptr; |
468 | } |
469 | |
470 | return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max); |
471 | } |
472 | |
473 | /// Parse a type registered to this dialect. |
474 | Type QuantDialect::parseType(DialectAsmParser &parser) const { |
475 | // All types start with an identifier that we switch on. |
476 | StringRef typeNameSpelling; |
477 | if (failed(parser.parseKeyword(&typeNameSpelling))) |
478 | return nullptr; |
479 | |
480 | if (typeNameSpelling == "uniform") |
481 | return parseUniformType(parser); |
482 | if (typeNameSpelling == "any") |
483 | return parseAnyType(parser); |
484 | if (typeNameSpelling == "calibrated") |
485 | return parseCalibratedType(parser); |
486 | |
487 | parser.emitError(parser.getNameLoc(), |
488 | "unknown quantized type "+ typeNameSpelling); |
489 | return nullptr; |
490 | } |
491 | |
492 | static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { |
493 | // storage type |
494 | unsigned storageWidth = type.getStorageTypeIntegralWidth(); |
495 | bool isSigned = type.isSigned(); |
496 | if (isSigned) { |
497 | out << "i"<< storageWidth; |
498 | } else { |
499 | out << "u"<< storageWidth; |
500 | } |
501 | |
502 | // storageTypeMin and storageTypeMax if not default. |
503 | if (type.hasStorageTypeBounds()) { |
504 | out << "<"<< type.getStorageTypeMin() << ":"<< type.getStorageTypeMax() |
505 | << ">"; |
506 | } |
507 | } |
508 | |
509 | static void printQuantParams(double scale, int64_t zeroPoint, |
510 | DialectAsmPrinter &out) { |
511 | out << scale; |
512 | if (zeroPoint != 0) { |
513 | out << ":"<< zeroPoint; |
514 | } |
515 | } |
516 | |
517 | static void |
518 | printBlockSizeInfo(ArrayRef<std::pair<int32_t, int64_t>> blockSizeInfo, |
519 | DialectAsmPrinter &out) { |
520 | out << "{"; |
521 | llvm::interleaveComma( |
522 | c: llvm::seq<size_t>(Begin: 0, End: blockSizeInfo.size()), os&: out, each_fn: [&](size_t index) { |
523 | out << blockSizeInfo[index].first << ":"<< blockSizeInfo[index].second; |
524 | }); |
525 | out << "}"; |
526 | } |
527 | |
528 | /// Helper that prints a AnyQuantizedType. |
529 | static void printAnyQuantizedType(AnyQuantizedType type, |
530 | DialectAsmPrinter &out) { |
531 | out << "any<"; |
532 | printStorageType(type, out); |
533 | if (Type expressedType = type.getExpressedType()) { |
534 | out << ":"<< expressedType; |
535 | } |
536 | out << ">"; |
537 | } |
538 | |
539 | /// Helper that prints a UniformQuantizedType. |
540 | static void printUniformQuantizedType(UniformQuantizedType type, |
541 | DialectAsmPrinter &out) { |
542 | out << "uniform<"; |
543 | printStorageType(type, out); |
544 | out << ":"<< type.getExpressedType() << ", "; |
545 | |
546 | // scheme specific parameters |
547 | printQuantParams(scale: type.getScale(), zeroPoint: type.getZeroPoint(), out); |
548 | out << ">"; |
549 | } |
550 | |
551 | /// Helper that prints a UniformQuantizedPerAxisType. |
552 | static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, |
553 | DialectAsmPrinter &out) { |
554 | out << "uniform<"; |
555 | printStorageType(type, out); |
556 | out << ":"<< type.getExpressedType() << ":"; |
557 | out << type.getQuantizedDimension(); |
558 | out << ", "; |
559 | |
560 | // scheme specific parameters |
561 | ArrayRef<double> scales = type.getScales(); |
562 | ArrayRef<int64_t> zeroPoints = type.getZeroPoints(); |
563 | out << "{"; |
564 | llvm::interleave( |
565 | c: llvm::seq<size_t>(Begin: 0, End: scales.size()), os&: out, |
566 | each_fn: [&](size_t index) { |
567 | printQuantParams(scale: scales[index], zeroPoint: zeroPoints[index], out); |
568 | }, |
569 | separator: ","); |
570 | out << "}>"; |
571 | } |
572 | |
573 | /// Prints quantization parameters as a nested list of `scale`[:`zero_point`] |
574 | /// elements. The nesting corresponds to the `shape` dimensions. |
575 | /// |
576 | /// Elements are delimited by commas, and the inner dimensions are enclosed in |
577 | /// braces. `zero_point` is only printed if it is non-zero. For example: |
578 | /// |
579 | /// printDenseQuantizationParameters(scales=[1.0, 2.0, 3.0, 4.0], |
580 | /// zeroPoints=[0, 0, 1, 9], |
581 | /// shape=[2, 2]) |
582 | /// |
583 | /// would print: |
584 | /// |
585 | /// {{1.0, 2.0}, {3.0:1, 4.0:9}} |
586 | void printDenseQuantizationParameters(ArrayRef<APFloat> scales, |
587 | ArrayRef<APInt> zeroPoints, |
588 | ArrayRef<int64_t> shape, |
589 | DialectAsmPrinter &out) { |
590 | int64_t rank = shape.size(); |
591 | SmallVector<unsigned, 4> counter(rank, 0); |
592 | unsigned openBrackets = 0; |
593 | |
594 | auto incrementCounterAndDelimit = [&]() { |
595 | ++counter[rank - 1]; |
596 | for (unsigned i = rank - 1; i > 0; --i) { |
597 | if (counter[i] >= shape[i]) { |
598 | counter[i] = 0; |
599 | ++counter[i - 1]; |
600 | --openBrackets; |
601 | out << '}'; |
602 | } |
603 | } |
604 | }; |
605 | |
606 | for (unsigned idx = 0, e = scales.size(); idx < e; ++idx) { |
607 | if (idx != 0) |
608 | out << ", "; |
609 | while (openBrackets++ < rank) |
610 | out << '{'; |
611 | openBrackets = rank; |
612 | out << scales[idx]; |
613 | if (zeroPoints[idx] != 0) { |
614 | out << ":"<< zeroPoints[idx]; |
615 | } |
616 | incrementCounterAndDelimit(); |
617 | } |
618 | while (openBrackets-- > 0) |
619 | out << '}'; |
620 | } |
621 | |
622 | /// Helper that prints a UniformQuantizedSubChannelType. |
623 | static void |
624 | printUniformQuantizedSubChannelType(UniformQuantizedSubChannelType type, |
625 | DialectAsmPrinter &out) { |
626 | out << "uniform<"; |
627 | printStorageType(type, out); |
628 | out << ":"<< type.getExpressedType() << ":"; |
629 | printBlockSizeInfo(blockSizeInfo: type.getBlockSizeInfo(), out); |
630 | out << ", "; |
631 | |
632 | auto scalesItr = type.getScales().getValues<APFloat>(); |
633 | auto zeroPointsItr = type.getZeroPoints().getValues<APInt>(); |
634 | SmallVector<APFloat> scales(scalesItr.begin(), scalesItr.end()); |
635 | SmallVector<APInt> zeroPoints(zeroPointsItr.begin(), zeroPointsItr.end()); |
636 | printDenseQuantizationParameters(scales, zeroPoints, |
637 | type.getScales().getType().getShape(), out); |
638 | out << ">"; |
639 | } |
640 | |
641 | /// Helper that prints a CalibratedQuantizedType. |
642 | static void printCalibratedQuantizedType(CalibratedQuantizedType type, |
643 | DialectAsmPrinter &out) { |
644 | out << "calibrated<"<< type.getExpressedType(); |
645 | out << "<"<< type.getMin() << ":"<< type.getMax() << ">"; |
646 | out << ">"; |
647 | } |
648 | |
649 | /// Print a type registered to this dialect. |
650 | void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { |
651 | if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type)) |
652 | printAnyQuantizedType(anyType, os); |
653 | else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type)) |
654 | printUniformQuantizedType(uniformType, os); |
655 | else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type)) |
656 | printUniformQuantizedPerAxisType(perAxisType, os); |
657 | else if (auto perAxisType = |
658 | llvm::dyn_cast<UniformQuantizedSubChannelType>(type)) |
659 | printUniformQuantizedSubChannelType(perAxisType, os); |
660 | else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type)) |
661 | printCalibratedQuantizedType(calibratedType, os); |
662 | else |
663 | llvm_unreachable("Unhandled quantized type"); |
664 | } |
665 |
Definitions
- parseStorageType
- parseStorageRange
- parseExpressedTypeAndRange
- parseAnyType
- isScaleInExpressedTypeRange
- parseQuantParams
- parseBlockSizeInfoUntilRBrace
- parseQuantParamListUntilRBrace
- parseUniformType
- parseCalibratedType
- printStorageType
- printQuantParams
- printBlockSizeInfo
- printAnyQuantizedType
- printUniformQuantizedType
- printUniformQuantizedPerAxisType
- printDenseQuantizationParameters
- printUniformQuantizedSubChannelType
Learn to use CMake with our Intro Training
Find out more