1 | //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <cstdint> |
10 | #include <vector> |
11 | |
12 | #include "mlir-c/BuiltinAttributes.h" |
13 | #include "mlir-c/Dialect/Quant.h" |
14 | #include "mlir-c/IR.h" |
15 | #include "mlir/Bindings/Python/Nanobind.h" |
16 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
17 | |
18 | namespace nb = nanobind; |
19 | using namespace llvm; |
20 | using namespace mlir; |
21 | using namespace mlir::python::nanobind_adaptors; |
22 | |
23 | static void populateDialectQuantSubmodule(const nb::module_ &m) { |
24 | //===-------------------------------------------------------------------===// |
25 | // QuantizedType |
26 | //===-------------------------------------------------------------------===// |
27 | |
28 | auto quantizedType = |
29 | mlir_type_subclass(m, "QuantizedType" , mlirTypeIsAQuantizedType); |
30 | quantizedType.def_staticmethod( |
31 | "default_minimum_for_integer" , |
32 | [](bool isSigned, unsigned integralWidth) { |
33 | return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, |
34 | integralWidth); |
35 | }, |
36 | "Default minimum value for the integer with the specified signedness and " |
37 | "bit width." , |
38 | nb::arg("is_signed" ), nb::arg("integral_width" )); |
39 | quantizedType.def_staticmethod( |
40 | "default_maximum_for_integer" , |
41 | [](bool isSigned, unsigned integralWidth) { |
42 | return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, |
43 | integralWidth); |
44 | }, |
45 | "Default maximum value for the integer with the specified signedness and " |
46 | "bit width." , |
47 | nb::arg("is_signed" ), nb::arg("integral_width" )); |
48 | quantizedType.def_property_readonly( |
49 | "expressed_type" , |
50 | [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, |
51 | "Type expressed by this quantized type." ); |
52 | quantizedType.def_property_readonly( |
53 | "flags" , [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, |
54 | "Flags of this quantized type (named accessors should be preferred to " |
55 | "this)" ); |
56 | quantizedType.def_property_readonly( |
57 | "is_signed" , |
58 | [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, |
59 | "Signedness of this quantized type." ); |
60 | quantizedType.def_property_readonly( |
61 | "storage_type" , |
62 | [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, |
63 | "Storage type backing this quantized type." ); |
64 | quantizedType.def_property_readonly( |
65 | "storage_type_min" , |
66 | [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, |
67 | "The minimum value held by the storage type of this quantized type." ); |
68 | quantizedType.def_property_readonly( |
69 | "storage_type_max" , |
70 | [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, |
71 | "The maximum value held by the storage type of this quantized type." ); |
72 | quantizedType.def_property_readonly( |
73 | "storage_type_integral_width" , |
74 | [](MlirType type) { |
75 | return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); |
76 | }, |
77 | "The bitwidth of the storage type of this quantized type." ); |
78 | quantizedType.def( |
79 | "is_compatible_expressed_type" , |
80 | [](MlirType type, MlirType candidate) { |
81 | return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); |
82 | }, |
83 | "Checks whether the candidate type can be expressed by this quantized " |
84 | "type." , |
85 | nb::arg("candidate" )); |
86 | quantizedType.def_property_readonly( |
87 | "quantized_element_type" , |
88 | [](MlirType type) { |
89 | return mlirQuantizedTypeGetQuantizedElementType(type); |
90 | }, |
91 | "Element type of this quantized type expressed as quantized type." ); |
92 | quantizedType.def( |
93 | "cast_from_storage_type" , |
94 | [](MlirType type, MlirType candidate) { |
95 | MlirType castResult = |
96 | mlirQuantizedTypeCastFromStorageType(type, candidate); |
97 | if (!mlirTypeIsNull(castResult)) |
98 | return castResult; |
99 | throw nb::type_error("Invalid cast." ); |
100 | }, |
101 | "Casts from a type based on the storage type of this quantized type to a " |
102 | "corresponding type based on the quantized type. Raises TypeError if the " |
103 | "cast is not valid." , |
104 | nb::arg("candidate" )); |
105 | quantizedType.def_staticmethod( |
106 | "cast_to_storage_type" , |
107 | [](MlirType type) { |
108 | MlirType castResult = mlirQuantizedTypeCastToStorageType(type); |
109 | if (!mlirTypeIsNull(castResult)) |
110 | return castResult; |
111 | throw nb::type_error("Invalid cast." ); |
112 | }, |
113 | "Casts from a type based on a quantized type to a corresponding type " |
114 | "based on the storage type of this quantized type. Raises TypeError if " |
115 | "the cast is not valid." , |
116 | nb::arg("type" )); |
117 | quantizedType.def( |
118 | "cast_from_expressed_type" , |
119 | [](MlirType type, MlirType candidate) { |
120 | MlirType castResult = |
121 | mlirQuantizedTypeCastFromExpressedType(type, candidate); |
122 | if (!mlirTypeIsNull(castResult)) |
123 | return castResult; |
124 | throw nb::type_error("Invalid cast." ); |
125 | }, |
126 | "Casts from a type based on the expressed type of this quantized type to " |
127 | "a corresponding type based on the quantized type. Raises TypeError if " |
128 | "the cast is not valid." , |
129 | nb::arg("candidate" )); |
130 | quantizedType.def_staticmethod( |
131 | "cast_to_expressed_type" , |
132 | [](MlirType type) { |
133 | MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); |
134 | if (!mlirTypeIsNull(castResult)) |
135 | return castResult; |
136 | throw nb::type_error("Invalid cast." ); |
137 | }, |
138 | "Casts from a type based on a quantized type to a corresponding type " |
139 | "based on the expressed type of this quantized type. Raises TypeError if " |
140 | "the cast is not valid." , |
141 | nb::arg("type" )); |
142 | quantizedType.def( |
143 | "cast_expressed_to_storage_type" , |
144 | [](MlirType type, MlirType candidate) { |
145 | MlirType castResult = |
146 | mlirQuantizedTypeCastExpressedToStorageType(type, candidate); |
147 | if (!mlirTypeIsNull(castResult)) |
148 | return castResult; |
149 | throw nb::type_error("Invalid cast." ); |
150 | }, |
151 | "Casts from a type based on the expressed type of this quantized type to " |
152 | "a corresponding type based on the storage type. Raises TypeError if the " |
153 | "cast is not valid." , |
154 | nb::arg("candidate" )); |
155 | |
156 | quantizedType.get_class().attr("FLAG_SIGNED" ) = |
157 | mlirQuantizedTypeGetSignedFlag(); |
158 | |
159 | //===-------------------------------------------------------------------===// |
160 | // AnyQuantizedType |
161 | //===-------------------------------------------------------------------===// |
162 | |
163 | auto anyQuantizedType = |
164 | mlir_type_subclass(m, "AnyQuantizedType" , mlirTypeIsAAnyQuantizedType, |
165 | quantizedType.get_class()); |
166 | anyQuantizedType.def_classmethod( |
167 | "get" , |
168 | [](nb::object cls, unsigned flags, MlirType storageType, |
169 | MlirType expressedType, int64_t storageTypeMin, |
170 | int64_t storageTypeMax) { |
171 | return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, |
172 | storageTypeMin, storageTypeMax)); |
173 | }, |
174 | "Gets an instance of AnyQuantizedType in the same context as the " |
175 | "provided storage type." , |
176 | nb::arg("cls" ), nb::arg("flags" ), nb::arg("storage_type" ), |
177 | nb::arg("expressed_type" ), nb::arg("storage_type_min" ), |
178 | nb::arg("storage_type_max" )); |
179 | |
180 | //===-------------------------------------------------------------------===// |
181 | // UniformQuantizedType |
182 | //===-------------------------------------------------------------------===// |
183 | |
184 | auto uniformQuantizedType = mlir_type_subclass( |
185 | m, "UniformQuantizedType" , mlirTypeIsAUniformQuantizedType, |
186 | quantizedType.get_class()); |
187 | uniformQuantizedType.def_classmethod( |
188 | "get" , |
189 | [](nb::object cls, unsigned flags, MlirType storageType, |
190 | MlirType expressedType, double scale, int64_t zeroPoint, |
191 | int64_t storageTypeMin, int64_t storageTypeMax) { |
192 | return cls(mlirUniformQuantizedTypeGet(flags, storageType, |
193 | expressedType, scale, zeroPoint, |
194 | storageTypeMin, storageTypeMax)); |
195 | }, |
196 | "Gets an instance of UniformQuantizedType in the same context as the " |
197 | "provided storage type." , |
198 | nb::arg("cls" ), nb::arg("flags" ), nb::arg("storage_type" ), |
199 | nb::arg("expressed_type" ), nb::arg("scale" ), nb::arg("zero_point" ), |
200 | nb::arg("storage_type_min" ), nb::arg("storage_type_max" )); |
201 | uniformQuantizedType.def_property_readonly( |
202 | "scale" , |
203 | [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, |
204 | "The scale designates the difference between the real values " |
205 | "corresponding to consecutive quantized values differing by 1." ); |
206 | uniformQuantizedType.def_property_readonly( |
207 | "zero_point" , |
208 | [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, |
209 | "The storage value corresponding to the real value 0 in the affine " |
210 | "equation." ); |
211 | uniformQuantizedType.def_property_readonly( |
212 | "is_fixed_point" , |
213 | [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, |
214 | "Fixed point values are real numbers divided by a scale." ); |
215 | |
216 | //===-------------------------------------------------------------------===// |
217 | // UniformQuantizedPerAxisType |
218 | //===-------------------------------------------------------------------===// |
219 | auto uniformQuantizedPerAxisType = mlir_type_subclass( |
220 | m, "UniformQuantizedPerAxisType" , mlirTypeIsAUniformQuantizedPerAxisType, |
221 | quantizedType.get_class()); |
222 | uniformQuantizedPerAxisType.def_classmethod( |
223 | "get" , |
224 | [](nb::object cls, unsigned flags, MlirType storageType, |
225 | MlirType expressedType, std::vector<double> scales, |
226 | std::vector<int64_t> zeroPoints, int32_t quantizedDimension, |
227 | int64_t storageTypeMin, int64_t storageTypeMax) { |
228 | if (scales.size() != zeroPoints.size()) |
229 | throw nb::value_error( |
230 | "Mismatching number of scales and zero points." ); |
231 | auto nDims = static_cast<intptr_t>(scales.size()); |
232 | return cls(mlirUniformQuantizedPerAxisTypeGet( |
233 | flags, storageType, expressedType, nDims, scales.data(), |
234 | zeroPoints.data(), quantizedDimension, storageTypeMin, |
235 | storageTypeMax)); |
236 | }, |
237 | "Gets an instance of UniformQuantizedPerAxisType in the same context as " |
238 | "the provided storage type." , |
239 | nb::arg("cls" ), nb::arg("flags" ), nb::arg("storage_type" ), |
240 | nb::arg("expressed_type" ), nb::arg("scales" ), nb::arg("zero_points" ), |
241 | nb::arg("quantized_dimension" ), nb::arg("storage_type_min" ), |
242 | nb::arg("storage_type_max" )); |
243 | uniformQuantizedPerAxisType.def_property_readonly( |
244 | "scales" , |
245 | [](MlirType type) { |
246 | intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); |
247 | std::vector<double> scales; |
248 | scales.reserve(n: nDim); |
249 | for (intptr_t i = 0; i < nDim; ++i) { |
250 | double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); |
251 | scales.push_back(x: scale); |
252 | } |
253 | return scales; |
254 | }, |
255 | "The scales designate the difference between the real values " |
256 | "corresponding to consecutive quantized values differing by 1. The ith " |
257 | "scale corresponds to the ith slice in the quantized_dimension." ); |
258 | uniformQuantizedPerAxisType.def_property_readonly( |
259 | "zero_points" , |
260 | [](MlirType type) { |
261 | intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); |
262 | std::vector<int64_t> zeroPoints; |
263 | zeroPoints.reserve(n: nDim); |
264 | for (intptr_t i = 0; i < nDim; ++i) { |
265 | int64_t zeroPoint = |
266 | mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); |
267 | zeroPoints.push_back(x: zeroPoint); |
268 | } |
269 | return zeroPoints; |
270 | }, |
271 | "the storage values corresponding to the real value 0 in the affine " |
272 | "equation. The ith zero point corresponds to the ith slice in the " |
273 | "quantized_dimension." ); |
274 | uniformQuantizedPerAxisType.def_property_readonly( |
275 | "quantized_dimension" , |
276 | [](MlirType type) { |
277 | return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); |
278 | }, |
279 | "Specifies the dimension of the shape that the scales and zero points " |
280 | "correspond to." ); |
281 | uniformQuantizedPerAxisType.def_property_readonly( |
282 | "is_fixed_point" , |
283 | [](MlirType type) { |
284 | return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); |
285 | }, |
286 | "Fixed point values are real numbers divided by a scale." ); |
287 | |
288 | //===-------------------------------------------------------------------===// |
289 | // UniformQuantizedSubChannelType |
290 | //===-------------------------------------------------------------------===// |
291 | auto uniformQuantizedSubChannelType = mlir_type_subclass( |
292 | m, "UniformQuantizedSubChannelType" , |
293 | mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class()); |
294 | uniformQuantizedSubChannelType.def_classmethod( |
295 | "get" , |
296 | [](nb::object cls, unsigned flags, MlirType storageType, |
297 | MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints, |
298 | std::vector<int32_t> quantizedDimensions, |
299 | std::vector<int64_t> blockSizes, int64_t storageTypeMin, |
300 | int64_t storageTypeMax) { |
301 | return cls(mlirUniformQuantizedSubChannelTypeGet( |
302 | flags, storageType, expressedType, scales, zeroPoints, |
303 | static_cast<intptr_t>(blockSizes.size()), |
304 | quantizedDimensions.data(), blockSizes.data(), storageTypeMin, |
305 | storageTypeMax)); |
306 | }, |
307 | "Gets an instance of UniformQuantizedSubChannel in the same context as " |
308 | "the provided storage type." , |
309 | nb::arg("cls" ), nb::arg("flags" ), nb::arg("storage_type" ), |
310 | nb::arg("expressed_type" ), nb::arg("scales" ), nb::arg("zero_points" ), |
311 | nb::arg("quantized_dimensions" ), nb::arg("block_sizes" ), |
312 | nb::arg("storage_type_min" ), nb::arg("storage_type_max" )); |
313 | uniformQuantizedSubChannelType.def_property_readonly( |
314 | "quantized_dimensions" , |
315 | [](MlirType type) { |
316 | intptr_t nDim = |
317 | mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); |
318 | std::vector<int32_t> quantizedDimensions; |
319 | quantizedDimensions.reserve(n: nDim); |
320 | for (intptr_t i = 0; i < nDim; ++i) { |
321 | quantizedDimensions.push_back( |
322 | mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i)); |
323 | } |
324 | return quantizedDimensions; |
325 | }, |
326 | "Gets the quantized dimensions. Each element in the returned list " |
327 | "represents an axis of the quantized data tensor that has a specified " |
328 | "block size. The order of elements corresponds to the order of block " |
329 | "sizes returned by 'block_sizes' method. It means that the data tensor " |
330 | "is quantized along the i-th dimension in the returned list using the " |
331 | "i-th block size from block_sizes method." ); |
332 | uniformQuantizedSubChannelType.def_property_readonly( |
333 | "block_sizes" , |
334 | [](MlirType type) { |
335 | intptr_t nDim = |
336 | mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); |
337 | std::vector<int64_t> blockSizes; |
338 | blockSizes.reserve(n: nDim); |
339 | for (intptr_t i = 0; i < nDim; ++i) { |
340 | blockSizes.push_back( |
341 | mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i)); |
342 | } |
343 | return blockSizes; |
344 | }, |
345 | "Gets the block sizes for the quantized dimensions. The i-th element in " |
346 | "the returned list corresponds to the block size for the i-th dimension " |
347 | "in the list returned by quantized_dimensions method." ); |
348 | uniformQuantizedSubChannelType.def_property_readonly( |
349 | "scales" , |
350 | [](MlirType type) -> MlirAttribute { |
351 | return mlirUniformQuantizedSubChannelTypeGetScales(type); |
352 | }, |
353 | "The scales of the quantized type." ); |
354 | uniformQuantizedSubChannelType.def_property_readonly( |
355 | "zero_points" , |
356 | [](MlirType type) -> MlirAttribute { |
357 | return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type); |
358 | }, |
359 | "The zero points of the quantized type." ); |
360 | |
361 | //===-------------------------------------------------------------------===// |
362 | // CalibratedQuantizedType |
363 | //===-------------------------------------------------------------------===// |
364 | |
365 | auto calibratedQuantizedType = mlir_type_subclass( |
366 | m, "CalibratedQuantizedType" , mlirTypeIsACalibratedQuantizedType, |
367 | quantizedType.get_class()); |
368 | calibratedQuantizedType.def_classmethod( |
369 | "get" , |
370 | [](nb::object cls, MlirType expressedType, double min, double max) { |
371 | return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); |
372 | }, |
373 | "Gets an instance of CalibratedQuantizedType in the same context as the " |
374 | "provided expressed type." , |
375 | nb::arg("cls" ), nb::arg("expressed_type" ), nb::arg("min" ), |
376 | nb::arg("max" )); |
377 | calibratedQuantizedType.def_property_readonly("min" , [](MlirType type) { |
378 | return mlirCalibratedQuantizedTypeGetMin(type); |
379 | }); |
380 | calibratedQuantizedType.def_property_readonly("max" , [](MlirType type) { |
381 | return mlirCalibratedQuantizedTypeGetMax(type); |
382 | }); |
383 | } |
384 | |
385 | NB_MODULE(_mlirDialectsQuant, m) { |
386 | m.doc() = "MLIR Quantization dialect" ; |
387 | |
388 | populateDialectQuantSubmodule(m); |
389 | } |
390 | |