//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Quant/QuantOps.h"
10#include "mlir/Dialect/Quant/QuantTypes.h"
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Types.h"
15#include "llvm/ADT/APFloat.h"
16#include "llvm/Support/Format.h"
17#include "llvm/Support/MathExtras.h"
18#include "llvm/Support/SourceMgr.h"
19#include "llvm/Support/raw_ostream.h"
20
21using namespace mlir;
22using namespace quant;
23
24static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
25 auto typeLoc = parser.getCurrentLocation();
26 IntegerType type;
27
28 // Parse storage type (alpha_ident, integer_literal).
29 StringRef identifier;
30 unsigned storageTypeWidth = 0;
31 OptionalParseResult result = parser.parseOptionalType(result&: type);
32 if (result.has_value()) {
33 if (!succeeded(result: *result))
34 return nullptr;
35 isSigned = !type.isUnsigned();
36 storageTypeWidth = type.getWidth();
37 } else if (succeeded(result: parser.parseKeyword(keyword: &identifier))) {
38 // Otherwise, this must be an unsigned integer (`u` integer-literal).
39 if (!identifier.consume_front(Prefix: "u")) {
40 parser.emitError(loc: typeLoc, message: "illegal storage type prefix");
41 return nullptr;
42 }
43 if (identifier.getAsInteger(Radix: 10, Result&: storageTypeWidth)) {
44 parser.emitError(loc: typeLoc, message: "expected storage type width");
45 return nullptr;
46 }
47 isSigned = false;
48 type = parser.getBuilder().getIntegerType(storageTypeWidth);
49 } else {
50 return nullptr;
51 }
52
53 if (storageTypeWidth == 0 ||
54 storageTypeWidth > QuantizedType::MaxStorageBits) {
55 parser.emitError(loc: typeLoc, message: "illegal storage type size: ")
56 << storageTypeWidth;
57 return nullptr;
58 }
59
60 return type;
61}
62
63static ParseResult parseStorageRange(DialectAsmParser &parser,
64 IntegerType storageType, bool isSigned,
65 int64_t &storageTypeMin,
66 int64_t &storageTypeMax) {
67 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
68 isSigned, integralWidth: storageType.getWidth());
69 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
70 isSigned, integralWidth: storageType.getWidth());
71 if (failed(result: parser.parseOptionalLess())) {
72 storageTypeMin = defaultIntegerMin;
73 storageTypeMax = defaultIntegerMax;
74 return success();
75 }
76
77 // Explicit storage min and storage max.
78 SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
79 if (parser.parseInteger(result&: storageTypeMin) || parser.parseColon() ||
80 parser.getCurrentLocation(loc: &maxLoc) ||
81 parser.parseInteger(result&: storageTypeMax) || parser.parseGreater())
82 return failure();
83 if (storageTypeMin < defaultIntegerMin) {
84 return parser.emitError(loc: minLoc, message: "illegal storage type minimum: ")
85 << storageTypeMin;
86 }
87 if (storageTypeMax > defaultIntegerMax) {
88 return parser.emitError(loc: maxLoc, message: "illegal storage type maximum: ")
89 << storageTypeMax;
90 }
91 return success();
92}
93
94static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
95 double &min, double &max) {
96 auto typeLoc = parser.getCurrentLocation();
97 FloatType type;
98
99 if (failed(result: parser.parseType(result&: type))) {
100 parser.emitError(loc: typeLoc, message: "expecting float expressed type");
101 return nullptr;
102 }
103
104 // Calibrated min and max values.
105 if (parser.parseLess() || parser.parseFloat(result&: min) || parser.parseColon() ||
106 parser.parseFloat(result&: max) || parser.parseGreater()) {
107 parser.emitError(loc: typeLoc, message: "calibrated values must be present");
108 return nullptr;
109 }
110 return type;
111}
112
113/// Parses an AnyQuantizedType.
114///
115/// any ::= `any<` storage-spec (expressed-type-spec)?`>`
116/// storage-spec ::= storage-type (`<` storage-range `>`)?
117/// storage-range ::= integer-literal `:` integer-literal
118/// storage-type ::= (`i` | `u`) integer-literal
119/// expressed-type-spec ::= `:` `f` integer-literal
120static Type parseAnyType(DialectAsmParser &parser) {
121 IntegerType storageType;
122 FloatType expressedType;
123 unsigned typeFlags = 0;
124 int64_t storageTypeMin;
125 int64_t storageTypeMax;
126
127 // Type specification.
128 if (parser.parseLess())
129 return nullptr;
130
131 // Storage type.
132 bool isSigned = false;
133 storageType = parseStorageType(parser, isSigned);
134 if (!storageType) {
135 return nullptr;
136 }
137 if (isSigned) {
138 typeFlags |= QuantizationFlags::Signed;
139 }
140
141 // Storage type range.
142 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
143 storageTypeMax)) {
144 return nullptr;
145 }
146
147 // Optional expressed type.
148 if (succeeded(result: parser.parseOptionalColon())) {
149 if (parser.parseType(result&: expressedType)) {
150 return nullptr;
151 }
152 }
153
154 if (parser.parseGreater()) {
155 return nullptr;
156 }
157
158 return parser.getChecked<AnyQuantizedType>(
159 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
160}
161
162static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
163 int64_t &zeroPoint) {
164 // scale[:zeroPoint]?
165 // scale.
166 if (parser.parseFloat(result&: scale))
167 return failure();
168
169 // zero point.
170 zeroPoint = 0;
171 if (failed(result: parser.parseOptionalColon())) {
172 // Default zero point.
173 return success();
174 }
175
176 return parser.parseInteger(result&: zeroPoint);
177}
178
179/// Parses a UniformQuantizedType.
180///
181/// uniform_type ::= uniform_per_layer
182/// | uniform_per_axis
183/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
184/// `,` scale-zero `>`
185/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
186/// axis-spec `,` scale-zero-list `>`
187/// storage-spec ::= storage-type (`<` storage-range `>`)?
188/// storage-range ::= integer-literal `:` integer-literal
189/// storage-type ::= (`i` | `u`) integer-literal
190/// expressed-type-spec ::= `:` `f` integer-literal
191/// axis-spec ::= `:` integer-literal
192/// scale-zero ::= float-literal `:` integer-literal
193/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
194static Type parseUniformType(DialectAsmParser &parser) {
195 IntegerType storageType;
196 FloatType expressedType;
197 unsigned typeFlags = 0;
198 int64_t storageTypeMin;
199 int64_t storageTypeMax;
200 bool isPerAxis = false;
201 int32_t quantizedDimension;
202 SmallVector<double, 1> scales;
203 SmallVector<int64_t, 1> zeroPoints;
204
205 // Type specification.
206 if (parser.parseLess()) {
207 return nullptr;
208 }
209
210 // Storage type.
211 bool isSigned = false;
212 storageType = parseStorageType(parser, isSigned);
213 if (!storageType) {
214 return nullptr;
215 }
216 if (isSigned) {
217 typeFlags |= QuantizationFlags::Signed;
218 }
219
220 // Storage type range.
221 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
222 storageTypeMax)) {
223 return nullptr;
224 }
225
226 // Expressed type.
227 if (parser.parseColon() || parser.parseType(result&: expressedType)) {
228 return nullptr;
229 }
230
231 // Optionally parse quantized dimension for per-axis quantization.
232 if (succeeded(result: parser.parseOptionalColon())) {
233 if (parser.parseInteger(result&: quantizedDimension))
234 return nullptr;
235 isPerAxis = true;
236 }
237
238 // Comma leading into range_spec.
239 if (parser.parseComma()) {
240 return nullptr;
241 }
242
243 // Parameter specification.
244 // For per-axis, ranges are in a {} delimitted list.
245 if (isPerAxis) {
246 if (parser.parseLBrace()) {
247 return nullptr;
248 }
249 }
250
251 // Parse scales/zeroPoints.
252 SMLoc scaleZPLoc = parser.getCurrentLocation();
253 do {
254 scales.resize(N: scales.size() + 1);
255 zeroPoints.resize(N: zeroPoints.size() + 1);
256 if (parseQuantParams(parser, scale&: scales.back(), zeroPoint&: zeroPoints.back())) {
257 return nullptr;
258 }
259 } while (isPerAxis && succeeded(result: parser.parseOptionalComma()));
260
261 if (isPerAxis) {
262 if (parser.parseRBrace()) {
263 return nullptr;
264 }
265 }
266
267 if (parser.parseGreater()) {
268 return nullptr;
269 }
270
271 if (!isPerAxis && scales.size() > 1) {
272 return (parser.emitError(loc: scaleZPLoc,
273 message: "multiple scales/zeroPoints provided, but "
274 "quantizedDimension wasn't specified"),
275 nullptr);
276 }
277
278 if (isPerAxis) {
279 ArrayRef<double> scalesRef(scales.begin(), scales.end());
280 ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
281 return parser.getChecked<UniformQuantizedPerAxisType>(
282 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
283 quantizedDimension, storageTypeMin, storageTypeMax);
284 }
285
286 return parser.getChecked<UniformQuantizedType>(
287 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
288 storageTypeMin, storageTypeMax);
289}
290
291/// Parses an CalibratedQuantizedType.
292///
293/// calibrated ::= `calibrated<` expressed-spec `>`
294/// expressed-spec ::= expressed-type `<` calibrated-range `>`
295/// expressed-type ::= `f` integer-literal
296/// calibrated-range ::= float-literal `:` float-literal
297static Type parseCalibratedType(DialectAsmParser &parser) {
298 FloatType expressedType;
299 double min;
300 double max;
301
302 // Type specification.
303 if (parser.parseLess())
304 return nullptr;
305
306 // Expressed type.
307 expressedType = parseExpressedTypeAndRange(parser, min, max);
308 if (!expressedType) {
309 return nullptr;
310 }
311
312 if (parser.parseGreater()) {
313 return nullptr;
314 }
315
316 return parser.getChecked<CalibratedQuantizedType>(params&: expressedType, params&: min, params&: max);
317}
318
319/// Parse a type registered to this dialect.
320Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
321 // All types start with an identifier that we switch on.
322 StringRef typeNameSpelling;
323 if (failed(result: parser.parseKeyword(keyword: &typeNameSpelling)))
324 return nullptr;
325
326 if (typeNameSpelling == "uniform")
327 return parseUniformType(parser);
328 if (typeNameSpelling == "any")
329 return parseAnyType(parser);
330 if (typeNameSpelling == "calibrated")
331 return parseCalibratedType(parser);
332
333 parser.emitError(loc: parser.getNameLoc(),
334 message: "unknown quantized type " + typeNameSpelling);
335 return nullptr;
336}
337
338static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
339 // storage type
340 unsigned storageWidth = type.getStorageTypeIntegralWidth();
341 bool isSigned = type.isSigned();
342 if (isSigned) {
343 out << "i" << storageWidth;
344 } else {
345 out << "u" << storageWidth;
346 }
347
348 // storageTypeMin and storageTypeMax if not default.
349 int64_t defaultIntegerMin =
350 QuantizedType::getDefaultMinimumForInteger(isSigned, integralWidth: storageWidth);
351 int64_t defaultIntegerMax =
352 QuantizedType::getDefaultMaximumForInteger(isSigned, integralWidth: storageWidth);
353 if (defaultIntegerMin != type.getStorageTypeMin() ||
354 defaultIntegerMax != type.getStorageTypeMax()) {
355 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
356 << ">";
357 }
358}
359
360static void printQuantParams(double scale, int64_t zeroPoint,
361 DialectAsmPrinter &out) {
362 out << scale;
363 if (zeroPoint != 0) {
364 out << ":" << zeroPoint;
365 }
366}
367
368/// Helper that prints a AnyQuantizedType.
369static void printAnyQuantizedType(AnyQuantizedType type,
370 DialectAsmPrinter &out) {
371 out << "any<";
372 printStorageType(type, out);
373 if (Type expressedType = type.getExpressedType()) {
374 out << ":" << expressedType;
375 }
376 out << ">";
377}
378
379/// Helper that prints a UniformQuantizedType.
380static void printUniformQuantizedType(UniformQuantizedType type,
381 DialectAsmPrinter &out) {
382 out << "uniform<";
383 printStorageType(type, out);
384 out << ":" << type.getExpressedType() << ", ";
385
386 // scheme specific parameters
387 printQuantParams(scale: type.getScale(), zeroPoint: type.getZeroPoint(), out);
388 out << ">";
389}
390
391/// Helper that prints a UniformQuantizedPerAxisType.
392static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
393 DialectAsmPrinter &out) {
394 out << "uniform<";
395 printStorageType(type, out);
396 out << ":" << type.getExpressedType() << ":";
397 out << type.getQuantizedDimension();
398 out << ", ";
399
400 // scheme specific parameters
401 ArrayRef<double> scales = type.getScales();
402 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
403 out << "{";
404 llvm::interleave(
405 c: llvm::seq<size_t>(Begin: 0, End: scales.size()), os&: out,
406 each_fn: [&](size_t index) {
407 printQuantParams(scale: scales[index], zeroPoint: zeroPoints[index], out);
408 },
409 separator: ",");
410 out << "}>";
411}
412
413/// Helper that prints a CalibratedQuantizedType.
414static void printCalibratedQuantizedType(CalibratedQuantizedType type,
415 DialectAsmPrinter &out) {
416 out << "calibrated<" << type.getExpressedType();
417 out << "<" << type.getMin() << ":" << type.getMax() << ">";
418 out << ">";
419}
420
421/// Print a type registered to this dialect.
422void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
423 if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(Val&: type))
424 printAnyQuantizedType(type: anyType, out&: os);
425 else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(Val&: type))
426 printUniformQuantizedType(type: uniformType, out&: os);
427 else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(Val&: type))
428 printUniformQuantizedPerAxisType(type: perAxisType, out&: os);
429 else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(Val&: type))
430 printCalibratedQuantizedType(type: calibratedType, out&: os);
431 else
432 llvm_unreachable("Unhandled quantized type");
433}
434

source code of mlir/lib/Dialect/Quant/IR/TypeParser.cpp