//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Quant/QuantOps.h" |
10 | #include "mlir/Dialect/Quant/QuantTypes.h" |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/IR/DialectImplementation.h" |
13 | #include "mlir/IR/Location.h" |
14 | #include "mlir/IR/Types.h" |
15 | #include "llvm/ADT/APFloat.h" |
16 | #include "llvm/Support/Format.h" |
17 | #include "llvm/Support/MathExtras.h" |
18 | #include "llvm/Support/SourceMgr.h" |
19 | #include "llvm/Support/raw_ostream.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace quant; |
23 | |
24 | static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { |
25 | auto typeLoc = parser.getCurrentLocation(); |
26 | IntegerType type; |
27 | |
28 | // Parse storage type (alpha_ident, integer_literal). |
29 | StringRef identifier; |
30 | unsigned storageTypeWidth = 0; |
31 | OptionalParseResult result = parser.parseOptionalType(result&: type); |
32 | if (result.has_value()) { |
33 | if (!succeeded(result: *result)) |
34 | return nullptr; |
35 | isSigned = !type.isUnsigned(); |
36 | storageTypeWidth = type.getWidth(); |
37 | } else if (succeeded(result: parser.parseKeyword(keyword: &identifier))) { |
38 | // Otherwise, this must be an unsigned integer (`u` integer-literal). |
39 | if (!identifier.consume_front(Prefix: "u" )) { |
40 | parser.emitError(loc: typeLoc, message: "illegal storage type prefix" ); |
41 | return nullptr; |
42 | } |
43 | if (identifier.getAsInteger(Radix: 10, Result&: storageTypeWidth)) { |
44 | parser.emitError(loc: typeLoc, message: "expected storage type width" ); |
45 | return nullptr; |
46 | } |
47 | isSigned = false; |
48 | type = parser.getBuilder().getIntegerType(storageTypeWidth); |
49 | } else { |
50 | return nullptr; |
51 | } |
52 | |
53 | if (storageTypeWidth == 0 || |
54 | storageTypeWidth > QuantizedType::MaxStorageBits) { |
55 | parser.emitError(loc: typeLoc, message: "illegal storage type size: " ) |
56 | << storageTypeWidth; |
57 | return nullptr; |
58 | } |
59 | |
60 | return type; |
61 | } |
62 | |
63 | static ParseResult parseStorageRange(DialectAsmParser &parser, |
64 | IntegerType storageType, bool isSigned, |
65 | int64_t &storageTypeMin, |
66 | int64_t &storageTypeMax) { |
67 | int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger( |
68 | isSigned, integralWidth: storageType.getWidth()); |
69 | int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger( |
70 | isSigned, integralWidth: storageType.getWidth()); |
71 | if (failed(result: parser.parseOptionalLess())) { |
72 | storageTypeMin = defaultIntegerMin; |
73 | storageTypeMax = defaultIntegerMax; |
74 | return success(); |
75 | } |
76 | |
77 | // Explicit storage min and storage max. |
78 | SMLoc minLoc = parser.getCurrentLocation(), maxLoc; |
79 | if (parser.parseInteger(result&: storageTypeMin) || parser.parseColon() || |
80 | parser.getCurrentLocation(loc: &maxLoc) || |
81 | parser.parseInteger(result&: storageTypeMax) || parser.parseGreater()) |
82 | return failure(); |
83 | if (storageTypeMin < defaultIntegerMin) { |
84 | return parser.emitError(loc: minLoc, message: "illegal storage type minimum: " ) |
85 | << storageTypeMin; |
86 | } |
87 | if (storageTypeMax > defaultIntegerMax) { |
88 | return parser.emitError(loc: maxLoc, message: "illegal storage type maximum: " ) |
89 | << storageTypeMax; |
90 | } |
91 | return success(); |
92 | } |
93 | |
94 | static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, |
95 | double &min, double &max) { |
96 | auto typeLoc = parser.getCurrentLocation(); |
97 | FloatType type; |
98 | |
99 | if (failed(result: parser.parseType(result&: type))) { |
100 | parser.emitError(loc: typeLoc, message: "expecting float expressed type" ); |
101 | return nullptr; |
102 | } |
103 | |
104 | // Calibrated min and max values. |
105 | if (parser.parseLess() || parser.parseFloat(result&: min) || parser.parseColon() || |
106 | parser.parseFloat(result&: max) || parser.parseGreater()) { |
107 | parser.emitError(loc: typeLoc, message: "calibrated values must be present" ); |
108 | return nullptr; |
109 | } |
110 | return type; |
111 | } |
112 | |
113 | /// Parses an AnyQuantizedType. |
114 | /// |
115 | /// any ::= `any<` storage-spec (expressed-type-spec)?`>` |
116 | /// storage-spec ::= storage-type (`<` storage-range `>`)? |
117 | /// storage-range ::= integer-literal `:` integer-literal |
118 | /// storage-type ::= (`i` | `u`) integer-literal |
119 | /// expressed-type-spec ::= `:` `f` integer-literal |
120 | static Type parseAnyType(DialectAsmParser &parser) { |
121 | IntegerType storageType; |
122 | FloatType expressedType; |
123 | unsigned typeFlags = 0; |
124 | int64_t storageTypeMin; |
125 | int64_t storageTypeMax; |
126 | |
127 | // Type specification. |
128 | if (parser.parseLess()) |
129 | return nullptr; |
130 | |
131 | // Storage type. |
132 | bool isSigned = false; |
133 | storageType = parseStorageType(parser, isSigned); |
134 | if (!storageType) { |
135 | return nullptr; |
136 | } |
137 | if (isSigned) { |
138 | typeFlags |= QuantizationFlags::Signed; |
139 | } |
140 | |
141 | // Storage type range. |
142 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, |
143 | storageTypeMax)) { |
144 | return nullptr; |
145 | } |
146 | |
147 | // Optional expressed type. |
148 | if (succeeded(result: parser.parseOptionalColon())) { |
149 | if (parser.parseType(result&: expressedType)) { |
150 | return nullptr; |
151 | } |
152 | } |
153 | |
154 | if (parser.parseGreater()) { |
155 | return nullptr; |
156 | } |
157 | |
158 | return parser.getChecked<AnyQuantizedType>( |
159 | typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); |
160 | } |
161 | |
162 | static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, |
163 | int64_t &zeroPoint) { |
164 | // scale[:zeroPoint]? |
165 | // scale. |
166 | if (parser.parseFloat(result&: scale)) |
167 | return failure(); |
168 | |
169 | // zero point. |
170 | zeroPoint = 0; |
171 | if (failed(result: parser.parseOptionalColon())) { |
172 | // Default zero point. |
173 | return success(); |
174 | } |
175 | |
176 | return parser.parseInteger(result&: zeroPoint); |
177 | } |
178 | |
179 | /// Parses a UniformQuantizedType. |
180 | /// |
181 | /// uniform_type ::= uniform_per_layer |
182 | /// | uniform_per_axis |
183 | /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec |
184 | /// `,` scale-zero `>` |
185 | /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec |
186 | /// axis-spec `,` scale-zero-list `>` |
187 | /// storage-spec ::= storage-type (`<` storage-range `>`)? |
188 | /// storage-range ::= integer-literal `:` integer-literal |
189 | /// storage-type ::= (`i` | `u`) integer-literal |
190 | /// expressed-type-spec ::= `:` `f` integer-literal |
191 | /// axis-spec ::= `:` integer-literal |
192 | /// scale-zero ::= float-literal `:` integer-literal |
193 | /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}` |
194 | static Type parseUniformType(DialectAsmParser &parser) { |
195 | IntegerType storageType; |
196 | FloatType expressedType; |
197 | unsigned typeFlags = 0; |
198 | int64_t storageTypeMin; |
199 | int64_t storageTypeMax; |
200 | bool isPerAxis = false; |
201 | int32_t quantizedDimension; |
202 | SmallVector<double, 1> scales; |
203 | SmallVector<int64_t, 1> zeroPoints; |
204 | |
205 | // Type specification. |
206 | if (parser.parseLess()) { |
207 | return nullptr; |
208 | } |
209 | |
210 | // Storage type. |
211 | bool isSigned = false; |
212 | storageType = parseStorageType(parser, isSigned); |
213 | if (!storageType) { |
214 | return nullptr; |
215 | } |
216 | if (isSigned) { |
217 | typeFlags |= QuantizationFlags::Signed; |
218 | } |
219 | |
220 | // Storage type range. |
221 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, |
222 | storageTypeMax)) { |
223 | return nullptr; |
224 | } |
225 | |
226 | // Expressed type. |
227 | if (parser.parseColon() || parser.parseType(result&: expressedType)) { |
228 | return nullptr; |
229 | } |
230 | |
231 | // Optionally parse quantized dimension for per-axis quantization. |
232 | if (succeeded(result: parser.parseOptionalColon())) { |
233 | if (parser.parseInteger(result&: quantizedDimension)) |
234 | return nullptr; |
235 | isPerAxis = true; |
236 | } |
237 | |
238 | // Comma leading into range_spec. |
239 | if (parser.parseComma()) { |
240 | return nullptr; |
241 | } |
242 | |
243 | // Parameter specification. |
244 | // For per-axis, ranges are in a {} delimitted list. |
245 | if (isPerAxis) { |
246 | if (parser.parseLBrace()) { |
247 | return nullptr; |
248 | } |
249 | } |
250 | |
251 | // Parse scales/zeroPoints. |
252 | SMLoc scaleZPLoc = parser.getCurrentLocation(); |
253 | do { |
254 | scales.resize(N: scales.size() + 1); |
255 | zeroPoints.resize(N: zeroPoints.size() + 1); |
256 | if (parseQuantParams(parser, scale&: scales.back(), zeroPoint&: zeroPoints.back())) { |
257 | return nullptr; |
258 | } |
259 | } while (isPerAxis && succeeded(result: parser.parseOptionalComma())); |
260 | |
261 | if (isPerAxis) { |
262 | if (parser.parseRBrace()) { |
263 | return nullptr; |
264 | } |
265 | } |
266 | |
267 | if (parser.parseGreater()) { |
268 | return nullptr; |
269 | } |
270 | |
271 | if (!isPerAxis && scales.size() > 1) { |
272 | return (parser.emitError(loc: scaleZPLoc, |
273 | message: "multiple scales/zeroPoints provided, but " |
274 | "quantizedDimension wasn't specified" ), |
275 | nullptr); |
276 | } |
277 | |
278 | if (isPerAxis) { |
279 | ArrayRef<double> scalesRef(scales.begin(), scales.end()); |
280 | ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); |
281 | return parser.getChecked<UniformQuantizedPerAxisType>( |
282 | typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, |
283 | quantizedDimension, storageTypeMin, storageTypeMax); |
284 | } |
285 | |
286 | return parser.getChecked<UniformQuantizedType>( |
287 | typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(), |
288 | storageTypeMin, storageTypeMax); |
289 | } |
290 | |
291 | /// Parses an CalibratedQuantizedType. |
292 | /// |
293 | /// calibrated ::= `calibrated<` expressed-spec `>` |
294 | /// expressed-spec ::= expressed-type `<` calibrated-range `>` |
295 | /// expressed-type ::= `f` integer-literal |
296 | /// calibrated-range ::= float-literal `:` float-literal |
297 | static Type parseCalibratedType(DialectAsmParser &parser) { |
298 | FloatType expressedType; |
299 | double min; |
300 | double max; |
301 | |
302 | // Type specification. |
303 | if (parser.parseLess()) |
304 | return nullptr; |
305 | |
306 | // Expressed type. |
307 | expressedType = parseExpressedTypeAndRange(parser, min, max); |
308 | if (!expressedType) { |
309 | return nullptr; |
310 | } |
311 | |
312 | if (parser.parseGreater()) { |
313 | return nullptr; |
314 | } |
315 | |
316 | return parser.getChecked<CalibratedQuantizedType>(params&: expressedType, params&: min, params&: max); |
317 | } |
318 | |
319 | /// Parse a type registered to this dialect. |
320 | Type QuantizationDialect::parseType(DialectAsmParser &parser) const { |
321 | // All types start with an identifier that we switch on. |
322 | StringRef typeNameSpelling; |
323 | if (failed(result: parser.parseKeyword(keyword: &typeNameSpelling))) |
324 | return nullptr; |
325 | |
326 | if (typeNameSpelling == "uniform" ) |
327 | return parseUniformType(parser); |
328 | if (typeNameSpelling == "any" ) |
329 | return parseAnyType(parser); |
330 | if (typeNameSpelling == "calibrated" ) |
331 | return parseCalibratedType(parser); |
332 | |
333 | parser.emitError(loc: parser.getNameLoc(), |
334 | message: "unknown quantized type " + typeNameSpelling); |
335 | return nullptr; |
336 | } |
337 | |
338 | static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { |
339 | // storage type |
340 | unsigned storageWidth = type.getStorageTypeIntegralWidth(); |
341 | bool isSigned = type.isSigned(); |
342 | if (isSigned) { |
343 | out << "i" << storageWidth; |
344 | } else { |
345 | out << "u" << storageWidth; |
346 | } |
347 | |
348 | // storageTypeMin and storageTypeMax if not default. |
349 | int64_t defaultIntegerMin = |
350 | QuantizedType::getDefaultMinimumForInteger(isSigned, integralWidth: storageWidth); |
351 | int64_t defaultIntegerMax = |
352 | QuantizedType::getDefaultMaximumForInteger(isSigned, integralWidth: storageWidth); |
353 | if (defaultIntegerMin != type.getStorageTypeMin() || |
354 | defaultIntegerMax != type.getStorageTypeMax()) { |
355 | out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() |
356 | << ">" ; |
357 | } |
358 | } |
359 | |
360 | static void printQuantParams(double scale, int64_t zeroPoint, |
361 | DialectAsmPrinter &out) { |
362 | out << scale; |
363 | if (zeroPoint != 0) { |
364 | out << ":" << zeroPoint; |
365 | } |
366 | } |
367 | |
368 | /// Helper that prints a AnyQuantizedType. |
369 | static void printAnyQuantizedType(AnyQuantizedType type, |
370 | DialectAsmPrinter &out) { |
371 | out << "any<" ; |
372 | printStorageType(type, out); |
373 | if (Type expressedType = type.getExpressedType()) { |
374 | out << ":" << expressedType; |
375 | } |
376 | out << ">" ; |
377 | } |
378 | |
379 | /// Helper that prints a UniformQuantizedType. |
380 | static void printUniformQuantizedType(UniformQuantizedType type, |
381 | DialectAsmPrinter &out) { |
382 | out << "uniform<" ; |
383 | printStorageType(type, out); |
384 | out << ":" << type.getExpressedType() << ", " ; |
385 | |
386 | // scheme specific parameters |
387 | printQuantParams(scale: type.getScale(), zeroPoint: type.getZeroPoint(), out); |
388 | out << ">" ; |
389 | } |
390 | |
391 | /// Helper that prints a UniformQuantizedPerAxisType. |
392 | static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, |
393 | DialectAsmPrinter &out) { |
394 | out << "uniform<" ; |
395 | printStorageType(type, out); |
396 | out << ":" << type.getExpressedType() << ":" ; |
397 | out << type.getQuantizedDimension(); |
398 | out << ", " ; |
399 | |
400 | // scheme specific parameters |
401 | ArrayRef<double> scales = type.getScales(); |
402 | ArrayRef<int64_t> zeroPoints = type.getZeroPoints(); |
403 | out << "{" ; |
404 | llvm::interleave( |
405 | c: llvm::seq<size_t>(Begin: 0, End: scales.size()), os&: out, |
406 | each_fn: [&](size_t index) { |
407 | printQuantParams(scale: scales[index], zeroPoint: zeroPoints[index], out); |
408 | }, |
409 | separator: "," ); |
410 | out << "}>" ; |
411 | } |
412 | |
413 | /// Helper that prints a CalibratedQuantizedType. |
414 | static void printCalibratedQuantizedType(CalibratedQuantizedType type, |
415 | DialectAsmPrinter &out) { |
416 | out << "calibrated<" << type.getExpressedType(); |
417 | out << "<" << type.getMin() << ":" << type.getMax() << ">" ; |
418 | out << ">" ; |
419 | } |
420 | |
421 | /// Print a type registered to this dialect. |
422 | void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { |
423 | if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(Val&: type)) |
424 | printAnyQuantizedType(type: anyType, out&: os); |
425 | else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(Val&: type)) |
426 | printUniformQuantizedType(type: uniformType, out&: os); |
427 | else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(Val&: type)) |
428 | printUniformQuantizedPerAxisType(type: perAxisType, out&: os); |
429 | else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(Val&: type)) |
430 | printCalibratedQuantizedType(type: calibratedType, out&: os); |
431 | else |
432 | llvm_unreachable("Unhandled quantized type" ); |
433 | } |
434 | |