1//===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cstdint>
10#include <vector>
11
12#include "mlir-c/BuiltinAttributes.h"
13#include "mlir-c/Dialect/Quant.h"
14#include "mlir-c/IR.h"
15#include "mlir/Bindings/Python/Nanobind.h"
16#include "mlir/Bindings/Python/NanobindAdaptors.h"
17
18namespace nb = nanobind;
19using namespace llvm;
20using namespace mlir;
21using namespace mlir::python::nanobind_adaptors;
22
23static void populateDialectQuantSubmodule(const nb::module_ &m) {
24 //===-------------------------------------------------------------------===//
25 // QuantizedType
26 //===-------------------------------------------------------------------===//
27
28 auto quantizedType =
29 mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
30 quantizedType.def_staticmethod(
31 "default_minimum_for_integer",
32 [](bool isSigned, unsigned integralWidth) {
33 return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned,
34 integralWidth);
35 },
36 "Default minimum value for the integer with the specified signedness and "
37 "bit width.",
38 nb::arg("is_signed"), nb::arg("integral_width"));
39 quantizedType.def_staticmethod(
40 "default_maximum_for_integer",
41 [](bool isSigned, unsigned integralWidth) {
42 return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned,
43 integralWidth);
44 },
45 "Default maximum value for the integer with the specified signedness and "
46 "bit width.",
47 nb::arg("is_signed"), nb::arg("integral_width"));
48 quantizedType.def_property_readonly(
49 "expressed_type",
50 [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); },
51 "Type expressed by this quantized type.");
52 quantizedType.def_property_readonly(
53 "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); },
54 "Flags of this quantized type (named accessors should be preferred to "
55 "this)");
56 quantizedType.def_property_readonly(
57 "is_signed",
58 [](MlirType type) { return mlirQuantizedTypeIsSigned(type); },
59 "Signedness of this quantized type.");
60 quantizedType.def_property_readonly(
61 "storage_type",
62 [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); },
63 "Storage type backing this quantized type.");
64 quantizedType.def_property_readonly(
65 "storage_type_min",
66 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); },
67 "The minimum value held by the storage type of this quantized type.");
68 quantizedType.def_property_readonly(
69 "storage_type_max",
70 [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); },
71 "The maximum value held by the storage type of this quantized type.");
72 quantizedType.def_property_readonly(
73 "storage_type_integral_width",
74 [](MlirType type) {
75 return mlirQuantizedTypeGetStorageTypeIntegralWidth(type);
76 },
77 "The bitwidth of the storage type of this quantized type.");
78 quantizedType.def(
79 "is_compatible_expressed_type",
80 [](MlirType type, MlirType candidate) {
81 return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate);
82 },
83 "Checks whether the candidate type can be expressed by this quantized "
84 "type.",
85 nb::arg("candidate"));
86 quantizedType.def_property_readonly(
87 "quantized_element_type",
88 [](MlirType type) {
89 return mlirQuantizedTypeGetQuantizedElementType(type);
90 },
91 "Element type of this quantized type expressed as quantized type.");
92 quantizedType.def(
93 "cast_from_storage_type",
94 [](MlirType type, MlirType candidate) {
95 MlirType castResult =
96 mlirQuantizedTypeCastFromStorageType(type, candidate);
97 if (!mlirTypeIsNull(castResult))
98 return castResult;
99 throw nb::type_error("Invalid cast.");
100 },
101 "Casts from a type based on the storage type of this quantized type to a "
102 "corresponding type based on the quantized type. Raises TypeError if the "
103 "cast is not valid.",
104 nb::arg("candidate"));
105 quantizedType.def_staticmethod(
106 "cast_to_storage_type",
107 [](MlirType type) {
108 MlirType castResult = mlirQuantizedTypeCastToStorageType(type);
109 if (!mlirTypeIsNull(castResult))
110 return castResult;
111 throw nb::type_error("Invalid cast.");
112 },
113 "Casts from a type based on a quantized type to a corresponding type "
114 "based on the storage type of this quantized type. Raises TypeError if "
115 "the cast is not valid.",
116 nb::arg("type"));
117 quantizedType.def(
118 "cast_from_expressed_type",
119 [](MlirType type, MlirType candidate) {
120 MlirType castResult =
121 mlirQuantizedTypeCastFromExpressedType(type, candidate);
122 if (!mlirTypeIsNull(castResult))
123 return castResult;
124 throw nb::type_error("Invalid cast.");
125 },
126 "Casts from a type based on the expressed type of this quantized type to "
127 "a corresponding type based on the quantized type. Raises TypeError if "
128 "the cast is not valid.",
129 nb::arg("candidate"));
130 quantizedType.def_staticmethod(
131 "cast_to_expressed_type",
132 [](MlirType type) {
133 MlirType castResult = mlirQuantizedTypeCastToExpressedType(type);
134 if (!mlirTypeIsNull(castResult))
135 return castResult;
136 throw nb::type_error("Invalid cast.");
137 },
138 "Casts from a type based on a quantized type to a corresponding type "
139 "based on the expressed type of this quantized type. Raises TypeError if "
140 "the cast is not valid.",
141 nb::arg("type"));
142 quantizedType.def(
143 "cast_expressed_to_storage_type",
144 [](MlirType type, MlirType candidate) {
145 MlirType castResult =
146 mlirQuantizedTypeCastExpressedToStorageType(type, candidate);
147 if (!mlirTypeIsNull(castResult))
148 return castResult;
149 throw nb::type_error("Invalid cast.");
150 },
151 "Casts from a type based on the expressed type of this quantized type to "
152 "a corresponding type based on the storage type. Raises TypeError if the "
153 "cast is not valid.",
154 nb::arg("candidate"));
155
156 quantizedType.get_class().attr("FLAG_SIGNED") =
157 mlirQuantizedTypeGetSignedFlag();
158
159 //===-------------------------------------------------------------------===//
160 // AnyQuantizedType
161 //===-------------------------------------------------------------------===//
162
163 auto anyQuantizedType =
164 mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType,
165 quantizedType.get_class());
166 anyQuantizedType.def_classmethod(
167 "get",
168 [](nb::object cls, unsigned flags, MlirType storageType,
169 MlirType expressedType, int64_t storageTypeMin,
170 int64_t storageTypeMax) {
171 return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType,
172 storageTypeMin, storageTypeMax));
173 },
174 "Gets an instance of AnyQuantizedType in the same context as the "
175 "provided storage type.",
176 nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
177 nb::arg("expressed_type"), nb::arg("storage_type_min"),
178 nb::arg("storage_type_max"));
179
180 //===-------------------------------------------------------------------===//
181 // UniformQuantizedType
182 //===-------------------------------------------------------------------===//
183
184 auto uniformQuantizedType = mlir_type_subclass(
185 m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType,
186 quantizedType.get_class());
187 uniformQuantizedType.def_classmethod(
188 "get",
189 [](nb::object cls, unsigned flags, MlirType storageType,
190 MlirType expressedType, double scale, int64_t zeroPoint,
191 int64_t storageTypeMin, int64_t storageTypeMax) {
192 return cls(mlirUniformQuantizedTypeGet(flags, storageType,
193 expressedType, scale, zeroPoint,
194 storageTypeMin, storageTypeMax));
195 },
196 "Gets an instance of UniformQuantizedType in the same context as the "
197 "provided storage type.",
198 nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
199 nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"),
200 nb::arg("storage_type_min"), nb::arg("storage_type_max"));
201 uniformQuantizedType.def_property_readonly(
202 "scale",
203 [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); },
204 "The scale designates the difference between the real values "
205 "corresponding to consecutive quantized values differing by 1.");
206 uniformQuantizedType.def_property_readonly(
207 "zero_point",
208 [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); },
209 "The storage value corresponding to the real value 0 in the affine "
210 "equation.");
211 uniformQuantizedType.def_property_readonly(
212 "is_fixed_point",
213 [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); },
214 "Fixed point values are real numbers divided by a scale.");
215
216 //===-------------------------------------------------------------------===//
217 // UniformQuantizedPerAxisType
218 //===-------------------------------------------------------------------===//
219 auto uniformQuantizedPerAxisType = mlir_type_subclass(
220 m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType,
221 quantizedType.get_class());
222 uniformQuantizedPerAxisType.def_classmethod(
223 "get",
224 [](nb::object cls, unsigned flags, MlirType storageType,
225 MlirType expressedType, std::vector<double> scales,
226 std::vector<int64_t> zeroPoints, int32_t quantizedDimension,
227 int64_t storageTypeMin, int64_t storageTypeMax) {
228 if (scales.size() != zeroPoints.size())
229 throw nb::value_error(
230 "Mismatching number of scales and zero points.");
231 auto nDims = static_cast<intptr_t>(scales.size());
232 return cls(mlirUniformQuantizedPerAxisTypeGet(
233 flags, storageType, expressedType, nDims, scales.data(),
234 zeroPoints.data(), quantizedDimension, storageTypeMin,
235 storageTypeMax));
236 },
237 "Gets an instance of UniformQuantizedPerAxisType in the same context as "
238 "the provided storage type.",
239 nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
240 nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
241 nb::arg("quantized_dimension"), nb::arg("storage_type_min"),
242 nb::arg("storage_type_max"));
243 uniformQuantizedPerAxisType.def_property_readonly(
244 "scales",
245 [](MlirType type) {
246 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
247 std::vector<double> scales;
248 scales.reserve(n: nDim);
249 for (intptr_t i = 0; i < nDim; ++i) {
250 double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i);
251 scales.push_back(x: scale);
252 }
253 return scales;
254 },
255 "The scales designate the difference between the real values "
256 "corresponding to consecutive quantized values differing by 1. The ith "
257 "scale corresponds to the ith slice in the quantized_dimension.");
258 uniformQuantizedPerAxisType.def_property_readonly(
259 "zero_points",
260 [](MlirType type) {
261 intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type);
262 std::vector<int64_t> zeroPoints;
263 zeroPoints.reserve(n: nDim);
264 for (intptr_t i = 0; i < nDim; ++i) {
265 int64_t zeroPoint =
266 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i);
267 zeroPoints.push_back(x: zeroPoint);
268 }
269 return zeroPoints;
270 },
271 "the storage values corresponding to the real value 0 in the affine "
272 "equation. The ith zero point corresponds to the ith slice in the "
273 "quantized_dimension.");
274 uniformQuantizedPerAxisType.def_property_readonly(
275 "quantized_dimension",
276 [](MlirType type) {
277 return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type);
278 },
279 "Specifies the dimension of the shape that the scales and zero points "
280 "correspond to.");
281 uniformQuantizedPerAxisType.def_property_readonly(
282 "is_fixed_point",
283 [](MlirType type) {
284 return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type);
285 },
286 "Fixed point values are real numbers divided by a scale.");
287
288 //===-------------------------------------------------------------------===//
289 // UniformQuantizedSubChannelType
290 //===-------------------------------------------------------------------===//
291 auto uniformQuantizedSubChannelType = mlir_type_subclass(
292 m, "UniformQuantizedSubChannelType",
293 mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class());
294 uniformQuantizedSubChannelType.def_classmethod(
295 "get",
296 [](nb::object cls, unsigned flags, MlirType storageType,
297 MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints,
298 std::vector<int32_t> quantizedDimensions,
299 std::vector<int64_t> blockSizes, int64_t storageTypeMin,
300 int64_t storageTypeMax) {
301 return cls(mlirUniformQuantizedSubChannelTypeGet(
302 flags, storageType, expressedType, scales, zeroPoints,
303 static_cast<intptr_t>(blockSizes.size()),
304 quantizedDimensions.data(), blockSizes.data(), storageTypeMin,
305 storageTypeMax));
306 },
307 "Gets an instance of UniformQuantizedSubChannel in the same context as "
308 "the provided storage type.",
309 nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"),
310 nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"),
311 nb::arg("quantized_dimensions"), nb::arg("block_sizes"),
312 nb::arg("storage_type_min"), nb::arg("storage_type_max"));
313 uniformQuantizedSubChannelType.def_property_readonly(
314 "quantized_dimensions",
315 [](MlirType type) {
316 intptr_t nDim =
317 mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
318 std::vector<int32_t> quantizedDimensions;
319 quantizedDimensions.reserve(n: nDim);
320 for (intptr_t i = 0; i < nDim; ++i) {
321 quantizedDimensions.push_back(
322 mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i));
323 }
324 return quantizedDimensions;
325 },
326 "Gets the quantized dimensions. Each element in the returned list "
327 "represents an axis of the quantized data tensor that has a specified "
328 "block size. The order of elements corresponds to the order of block "
329 "sizes returned by 'block_sizes' method. It means that the data tensor "
330 "is quantized along the i-th dimension in the returned list using the "
331 "i-th block size from block_sizes method.");
332 uniformQuantizedSubChannelType.def_property_readonly(
333 "block_sizes",
334 [](MlirType type) {
335 intptr_t nDim =
336 mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type);
337 std::vector<int64_t> blockSizes;
338 blockSizes.reserve(n: nDim);
339 for (intptr_t i = 0; i < nDim; ++i) {
340 blockSizes.push_back(
341 mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i));
342 }
343 return blockSizes;
344 },
345 "Gets the block sizes for the quantized dimensions. The i-th element in "
346 "the returned list corresponds to the block size for the i-th dimension "
347 "in the list returned by quantized_dimensions method.");
348 uniformQuantizedSubChannelType.def_property_readonly(
349 "scales",
350 [](MlirType type) -> MlirAttribute {
351 return mlirUniformQuantizedSubChannelTypeGetScales(type);
352 },
353 "The scales of the quantized type.");
354 uniformQuantizedSubChannelType.def_property_readonly(
355 "zero_points",
356 [](MlirType type) -> MlirAttribute {
357 return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type);
358 },
359 "The zero points of the quantized type.");
360
361 //===-------------------------------------------------------------------===//
362 // CalibratedQuantizedType
363 //===-------------------------------------------------------------------===//
364
365 auto calibratedQuantizedType = mlir_type_subclass(
366 m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType,
367 quantizedType.get_class());
368 calibratedQuantizedType.def_classmethod(
369 "get",
370 [](nb::object cls, MlirType expressedType, double min, double max) {
371 return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max));
372 },
373 "Gets an instance of CalibratedQuantizedType in the same context as the "
374 "provided expressed type.",
375 nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"),
376 nb::arg("max"));
377 calibratedQuantizedType.def_property_readonly("min", [](MlirType type) {
378 return mlirCalibratedQuantizedTypeGetMin(type);
379 });
380 calibratedQuantizedType.def_property_readonly("max", [](MlirType type) {
381 return mlirCalibratedQuantizedTypeGetMax(type);
382 });
383}
384
385NB_MODULE(_mlirDialectsQuant, m) {
386 m.doc() = "MLIR Quantization dialect";
387
388 populateDialectQuantSubmodule(m);
389}
390

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Bindings/Python/DialectQuant.cpp