1 | //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/Dialect/Quant.h" |
10 | #include "mlir-c/IR.h" |
11 | #include "mlir/Bindings/Python/PybindAdaptors.h" |
12 | #include <cstdint> |
13 | #include <pybind11/cast.h> |
14 | #include <pybind11/detail/common.h> |
15 | #include <pybind11/pybind11.h> |
16 | #include <vector> |
17 | |
18 | namespace py = pybind11; |
19 | using namespace llvm; |
20 | using namespace mlir; |
21 | using namespace mlir::python::adaptors; |
22 | |
23 | static void populateDialectQuantSubmodule(const py::module &m) { |
24 | //===-------------------------------------------------------------------===// |
25 | // QuantizedType |
26 | //===-------------------------------------------------------------------===// |
27 | |
28 | auto quantizedType = |
29 | mlir_type_subclass(m, "QuantizedType" , mlirTypeIsAQuantizedType); |
30 | quantizedType.def_staticmethod( |
31 | "default_minimum_for_integer" , |
32 | [](bool isSigned, unsigned integralWidth) { |
33 | return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, |
34 | integralWidth); |
35 | }, |
36 | "Default minimum value for the integer with the specified signedness and " |
37 | "bit width." , |
38 | py::arg("is_signed" ), py::arg("integral_width" )); |
39 | quantizedType.def_staticmethod( |
40 | "default_maximum_for_integer" , |
41 | [](bool isSigned, unsigned integralWidth) { |
42 | return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, |
43 | integralWidth); |
44 | }, |
45 | "Default maximum value for the integer with the specified signedness and " |
46 | "bit width." , |
47 | py::arg("is_signed" ), py::arg("integral_width" )); |
48 | quantizedType.def_property_readonly( |
49 | "expressed_type" , |
50 | [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, |
51 | "Type expressed by this quantized type." ); |
52 | quantizedType.def_property_readonly( |
53 | "flags" , [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, |
54 | "Flags of this quantized type (named accessors should be preferred to " |
55 | "this)" ); |
56 | quantizedType.def_property_readonly( |
57 | "is_signed" , |
58 | [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, |
59 | "Signedness of this quantized type." ); |
60 | quantizedType.def_property_readonly( |
61 | "storage_type" , |
62 | [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, |
63 | "Storage type backing this quantized type." ); |
64 | quantizedType.def_property_readonly( |
65 | "storage_type_min" , |
66 | [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, |
67 | "The minimum value held by the storage type of this quantized type." ); |
68 | quantizedType.def_property_readonly( |
69 | "storage_type_max" , |
70 | [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, |
71 | "The maximum value held by the storage type of this quantized type." ); |
72 | quantizedType.def_property_readonly( |
73 | "storage_type_integral_width" , |
74 | [](MlirType type) { |
75 | return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); |
76 | }, |
77 | "The bitwidth of the storage type of this quantized type." ); |
78 | quantizedType.def( |
79 | "is_compatible_expressed_type" , |
80 | [](MlirType type, MlirType candidate) { |
81 | return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); |
82 | }, |
83 | "Checks whether the candidate type can be expressed by this quantized " |
84 | "type." , |
85 | py::arg("candidate" )); |
86 | quantizedType.def_property_readonly( |
87 | "quantized_element_type" , |
88 | [](MlirType type) { |
89 | return mlirQuantizedTypeGetQuantizedElementType(type); |
90 | }, |
91 | "Element type of this quantized type expressed as quantized type." ); |
92 | quantizedType.def( |
93 | "cast_from_storage_type" , |
94 | [](MlirType type, MlirType candidate) { |
95 | MlirType castResult = |
96 | mlirQuantizedTypeCastFromStorageType(type, candidate); |
97 | if (!mlirTypeIsNull(castResult)) |
98 | return castResult; |
99 | throw py::type_error("Invalid cast." ); |
100 | }, |
101 | "Casts from a type based on the storage type of this quantized type to a " |
102 | "corresponding type based on the quantized type. Raises TypeError if the " |
103 | "cast is not valid." , |
104 | py::arg("candidate" )); |
105 | quantizedType.def_staticmethod( |
106 | "cast_to_storage_type" , |
107 | [](MlirType type) { |
108 | MlirType castResult = mlirQuantizedTypeCastToStorageType(type); |
109 | if (!mlirTypeIsNull(castResult)) |
110 | return castResult; |
111 | throw py::type_error("Invalid cast." ); |
112 | }, |
113 | "Casts from a type based on a quantized type to a corresponding type " |
114 | "based on the storage type of this quantized type. Raises TypeError if " |
115 | "the cast is not valid." , |
116 | py::arg("type" )); |
117 | quantizedType.def( |
118 | "cast_from_expressed_type" , |
119 | [](MlirType type, MlirType candidate) { |
120 | MlirType castResult = |
121 | mlirQuantizedTypeCastFromExpressedType(type, candidate); |
122 | if (!mlirTypeIsNull(castResult)) |
123 | return castResult; |
124 | throw py::type_error("Invalid cast." ); |
125 | }, |
126 | "Casts from a type based on the expressed type of this quantized type to " |
127 | "a corresponding type based on the quantized type. Raises TypeError if " |
128 | "the cast is not valid." , |
129 | py::arg("candidate" )); |
130 | quantizedType.def_staticmethod( |
131 | "cast_to_expressed_type" , |
132 | [](MlirType type) { |
133 | MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); |
134 | if (!mlirTypeIsNull(castResult)) |
135 | return castResult; |
136 | throw py::type_error("Invalid cast." ); |
137 | }, |
138 | "Casts from a type based on a quantized type to a corresponding type " |
139 | "based on the expressed type of this quantized type. Raises TypeError if " |
140 | "the cast is not valid." , |
141 | py::arg("type" )); |
142 | quantizedType.def( |
143 | "cast_expressed_to_storage_type" , |
144 | [](MlirType type, MlirType candidate) { |
145 | MlirType castResult = |
146 | mlirQuantizedTypeCastExpressedToStorageType(type, candidate); |
147 | if (!mlirTypeIsNull(castResult)) |
148 | return castResult; |
149 | throw py::type_error("Invalid cast." ); |
150 | }, |
151 | "Casts from a type based on the expressed type of this quantized type to " |
152 | "a corresponding type based on the storage type. Raises TypeError if the " |
153 | "cast is not valid." , |
154 | py::arg("candidate" )); |
155 | |
156 | quantizedType.get_class().attr("FLAG_SIGNED" ) = |
157 | mlirQuantizedTypeGetSignedFlag(); |
158 | |
159 | //===-------------------------------------------------------------------===// |
160 | // AnyQuantizedType |
161 | //===-------------------------------------------------------------------===// |
162 | |
163 | auto anyQuantizedType = |
164 | mlir_type_subclass(m, "AnyQuantizedType" , mlirTypeIsAAnyQuantizedType, |
165 | quantizedType.get_class()); |
166 | anyQuantizedType.def_classmethod( |
167 | "get" , |
168 | [](py::object cls, unsigned flags, MlirType storageType, |
169 | MlirType expressedType, int64_t storageTypeMin, |
170 | int64_t storageTypeMax) { |
171 | return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, |
172 | storageTypeMin, storageTypeMax)); |
173 | }, |
174 | "Gets an instance of AnyQuantizedType in the same context as the " |
175 | "provided storage type." , |
176 | py::arg("cls" ), py::arg("flags" ), py::arg("storage_type" ), |
177 | py::arg("expressed_type" ), py::arg("storage_type_min" ), |
178 | py::arg("storage_type_max" )); |
179 | |
180 | //===-------------------------------------------------------------------===// |
181 | // UniformQuantizedType |
182 | //===-------------------------------------------------------------------===// |
183 | |
184 | auto uniformQuantizedType = mlir_type_subclass( |
185 | m, "UniformQuantizedType" , mlirTypeIsAUniformQuantizedType, |
186 | quantizedType.get_class()); |
187 | uniformQuantizedType.def_classmethod( |
188 | "get" , |
189 | [](py::object cls, unsigned flags, MlirType storageType, |
190 | MlirType expressedType, double scale, int64_t zeroPoint, |
191 | int64_t storageTypeMin, int64_t storageTypeMax) { |
192 | return cls(mlirUniformQuantizedTypeGet(flags, storageType, |
193 | expressedType, scale, zeroPoint, |
194 | storageTypeMin, storageTypeMax)); |
195 | }, |
196 | "Gets an instance of UniformQuantizedType in the same context as the " |
197 | "provided storage type." , |
198 | py::arg("cls" ), py::arg("flags" ), py::arg("storage_type" ), |
199 | py::arg("expressed_type" ), py::arg("scale" ), py::arg("zero_point" ), |
200 | py::arg("storage_type_min" ), py::arg("storage_type_max" )); |
201 | uniformQuantizedType.def_property_readonly( |
202 | "scale" , |
203 | [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, |
204 | "The scale designates the difference between the real values " |
205 | "corresponding to consecutive quantized values differing by 1." ); |
206 | uniformQuantizedType.def_property_readonly( |
207 | "zero_point" , |
208 | [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, |
209 | "The storage value corresponding to the real value 0 in the affine " |
210 | "equation." ); |
211 | uniformQuantizedType.def_property_readonly( |
212 | "is_fixed_point" , |
213 | [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, |
214 | "Fixed point values are real numbers divided by a scale." ); |
215 | |
216 | //===-------------------------------------------------------------------===// |
217 | // UniformQuantizedPerAxisType |
218 | //===-------------------------------------------------------------------===// |
219 | auto uniformQuantizedPerAxisType = mlir_type_subclass( |
220 | m, "UniformQuantizedPerAxisType" , mlirTypeIsAUniformQuantizedPerAxisType, |
221 | quantizedType.get_class()); |
222 | uniformQuantizedPerAxisType.def_classmethod( |
223 | "get" , |
224 | [](py::object cls, unsigned flags, MlirType storageType, |
225 | MlirType expressedType, std::vector<double> scales, |
226 | std::vector<int64_t> zeroPoints, int32_t quantizedDimension, |
227 | int64_t storageTypeMin, int64_t storageTypeMax) { |
228 | if (scales.size() != zeroPoints.size()) |
229 | throw py::value_error( |
230 | "Mismatching number of scales and zero points." ); |
231 | auto nDims = static_cast<intptr_t>(scales.size()); |
232 | return cls(mlirUniformQuantizedPerAxisTypeGet( |
233 | flags, storageType, expressedType, nDims, scales.data(), |
234 | zeroPoints.data(), quantizedDimension, storageTypeMin, |
235 | storageTypeMax)); |
236 | }, |
237 | "Gets an instance of UniformQuantizedPerAxisType in the same context as " |
238 | "the provided storage type." , |
239 | py::arg("cls" ), py::arg("flags" ), py::arg("storage_type" ), |
240 | py::arg("expressed_type" ), py::arg("scales" ), py::arg("zero_points" ), |
241 | py::arg("quantized_dimension" ), py::arg("storage_type_min" ), |
242 | py::arg("storage_type_max" )); |
243 | uniformQuantizedPerAxisType.def_property_readonly( |
244 | "scales" , |
245 | [](MlirType type) { |
246 | intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); |
247 | std::vector<double> scales; |
248 | scales.reserve(n: nDim); |
249 | for (intptr_t i = 0; i < nDim; ++i) { |
250 | double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); |
251 | scales.push_back(x: scale); |
252 | } |
253 | }, |
254 | "The scales designate the difference between the real values " |
255 | "corresponding to consecutive quantized values differing by 1. The ith " |
256 | "scale corresponds to the ith slice in the quantized_dimension." ); |
257 | uniformQuantizedPerAxisType.def_property_readonly( |
258 | "zero_points" , |
259 | [](MlirType type) { |
260 | intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); |
261 | std::vector<int64_t> zeroPoints; |
262 | zeroPoints.reserve(n: nDim); |
263 | for (intptr_t i = 0; i < nDim; ++i) { |
264 | int64_t zeroPoint = |
265 | mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); |
266 | zeroPoints.push_back(x: zeroPoint); |
267 | } |
268 | }, |
269 | "the storage values corresponding to the real value 0 in the affine " |
270 | "equation. The ith zero point corresponds to the ith slice in the " |
271 | "quantized_dimension." ); |
272 | uniformQuantizedPerAxisType.def_property_readonly( |
273 | "quantized_dimension" , |
274 | [](MlirType type) { |
275 | return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); |
276 | }, |
277 | "Specifies the dimension of the shape that the scales and zero points " |
278 | "correspond to." ); |
279 | uniformQuantizedPerAxisType.def_property_readonly( |
280 | "is_fixed_point" , |
281 | [](MlirType type) { |
282 | return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); |
283 | }, |
284 | "Fixed point values are real numbers divided by a scale." ); |
285 | |
286 | //===-------------------------------------------------------------------===// |
287 | // CalibratedQuantizedType |
288 | //===-------------------------------------------------------------------===// |
289 | |
290 | auto calibratedQuantizedType = mlir_type_subclass( |
291 | m, "CalibratedQuantizedType" , mlirTypeIsACalibratedQuantizedType, |
292 | quantizedType.get_class()); |
293 | calibratedQuantizedType.def_classmethod( |
294 | "get" , |
295 | [](py::object cls, MlirType expressedType, double min, double max) { |
296 | return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); |
297 | }, |
298 | "Gets an instance of CalibratedQuantizedType in the same context as the " |
299 | "provided expressed type." , |
300 | py::arg("cls" ), py::arg("expressed_type" ), py::arg("min" ), |
301 | py::arg("max" )); |
302 | calibratedQuantizedType.def_property_readonly("min" , [](MlirType type) { |
303 | return mlirCalibratedQuantizedTypeGetMin(type); |
304 | }); |
305 | calibratedQuantizedType.def_property_readonly("max" , [](MlirType type) { |
306 | return mlirCalibratedQuantizedTypeGetMax(type); |
307 | }); |
308 | } |
309 | |
310 | PYBIND11_MODULE(_mlirDialectsQuant, m) { |
311 | m.doc() = "MLIR Quantization dialect" ; |
312 | |
313 | populateDialectQuantSubmodule(m); |
314 | } |
315 | |