| 1 | //===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include <cstdint> |
| 10 | #include <vector> |
| 11 | |
| 12 | #include "mlir-c/BuiltinAttributes.h" |
| 13 | #include "mlir-c/Dialect/Quant.h" |
| 14 | #include "mlir-c/IR.h" |
| 15 | #include "mlir/Bindings/Python/Nanobind.h" |
| 16 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
| 17 | |
| 18 | namespace nb = nanobind; |
| 19 | using namespace llvm; |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::python::nanobind_adaptors; |
| 22 | |
| 23 | static void populateDialectQuantSubmodule(const nb::module_ &m) { |
| 24 | //===-------------------------------------------------------------------===// |
| 25 | // QuantizedType |
| 26 | //===-------------------------------------------------------------------===// |
| 27 | |
| 28 | auto quantizedType = |
| 29 | mlir_type_subclass(m, "QuantizedType" , mlirTypeIsAQuantizedType); |
| 30 | quantizedType.def_staticmethod( |
| 31 | "default_minimum_for_integer" , |
| 32 | [](bool isSigned, unsigned integralWidth) { |
| 33 | return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, |
| 34 | integralWidth); |
| 35 | }, |
| 36 | "Default minimum value for the integer with the specified signedness and " |
| 37 | "bit width." , |
| 38 | nb::arg("is_signed" ), nb::arg("integral_width" )); |
| 39 | quantizedType.def_staticmethod( |
| 40 | "default_maximum_for_integer" , |
| 41 | [](bool isSigned, unsigned integralWidth) { |
| 42 | return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, |
| 43 | integralWidth); |
| 44 | }, |
| 45 | "Default maximum value for the integer with the specified signedness and " |
| 46 | "bit width." , |
| 47 | nb::arg("is_signed" ), nb::arg("integral_width" )); |
| 48 | quantizedType.def_property_readonly( |
| 49 | "expressed_type" , |
| 50 | [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, |
| 51 | "Type expressed by this quantized type." ); |
| 52 | quantizedType.def_property_readonly( |
| 53 | "flags" , [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, |
| 54 | "Flags of this quantized type (named accessors should be preferred to " |
| 55 | "this)" ); |
| 56 | quantizedType.def_property_readonly( |
| 57 | "is_signed" , |
| 58 | [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, |
| 59 | "Signedness of this quantized type." ); |
| 60 | quantizedType.def_property_readonly( |
| 61 | "storage_type" , |
| 62 | [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, |
| 63 | "Storage type backing this quantized type." ); |
| 64 | quantizedType.def_property_readonly( |
| 65 | "storage_type_min" , |
| 66 | [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, |
| 67 | "The minimum value held by the storage type of this quantized type." ); |
| 68 | quantizedType.def_property_readonly( |
| 69 | "storage_type_max" , |
| 70 | [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, |
| 71 | "The maximum value held by the storage type of this quantized type." ); |
| 72 | quantizedType.def_property_readonly( |
| 73 | "storage_type_integral_width" , |
| 74 | [](MlirType type) { |
| 75 | return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); |
| 76 | }, |
| 77 | "The bitwidth of the storage type of this quantized type." ); |
| 78 | quantizedType.def( |
| 79 | "is_compatible_expressed_type" , |
| 80 | [](MlirType type, MlirType candidate) { |
| 81 | return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); |
| 82 | }, |
| 83 | "Checks whether the candidate type can be expressed by this quantized " |
| 84 | "type." , |
| 85 | nb::arg("candidate" )); |
| 86 | quantizedType.def_property_readonly( |
| 87 | "quantized_element_type" , |
| 88 | [](MlirType type) { |
| 89 | return mlirQuantizedTypeGetQuantizedElementType(type); |
| 90 | }, |
| 91 | "Element type of this quantized type expressed as quantized type." ); |
| 92 | quantizedType.def( |
| 93 | "cast_from_storage_type" , |
| 94 | [](MlirType type, MlirType candidate) { |
| 95 | MlirType castResult = |
| 96 | mlirQuantizedTypeCastFromStorageType(type, candidate); |
| 97 | if (!mlirTypeIsNull(castResult)) |
| 98 | return castResult; |
| 99 | throw nb::type_error("Invalid cast." ); |
| 100 | }, |
| 101 | "Casts from a type based on the storage type of this quantized type to a " |
| 102 | "corresponding type based on the quantized type. Raises TypeError if the " |
| 103 | "cast is not valid." , |
| 104 | nb::arg("candidate" )); |
| 105 | quantizedType.def_staticmethod( |
| 106 | "cast_to_storage_type" , |
| 107 | [](MlirType type) { |
| 108 | MlirType castResult = mlirQuantizedTypeCastToStorageType(type); |
| 109 | if (!mlirTypeIsNull(castResult)) |
| 110 | return castResult; |
| 111 | throw nb::type_error("Invalid cast." ); |
| 112 | }, |
| 113 | "Casts from a type based on a quantized type to a corresponding type " |
| 114 | "based on the storage type of this quantized type. Raises TypeError if " |
| 115 | "the cast is not valid." , |
| 116 | nb::arg("type" )); |
| 117 | quantizedType.def( |
| 118 | "cast_from_expressed_type" , |
| 119 | [](MlirType type, MlirType candidate) { |
| 120 | MlirType castResult = |
| 121 | mlirQuantizedTypeCastFromExpressedType(type, candidate); |
| 122 | if (!mlirTypeIsNull(castResult)) |
| 123 | return castResult; |
| 124 | throw nb::type_error("Invalid cast." ); |
| 125 | }, |
| 126 | "Casts from a type based on the expressed type of this quantized type to " |
| 127 | "a corresponding type based on the quantized type. Raises TypeError if " |
| 128 | "the cast is not valid." , |
| 129 | nb::arg("candidate" )); |
| 130 | quantizedType.def_staticmethod( |
| 131 | "cast_to_expressed_type" , |
| 132 | [](MlirType type) { |
| 133 | MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); |
| 134 | if (!mlirTypeIsNull(castResult)) |
| 135 | return castResult; |
| 136 | throw nb::type_error("Invalid cast." ); |
| 137 | }, |
| 138 | "Casts from a type based on a quantized type to a corresponding type " |
| 139 | "based on the expressed type of this quantized type. Raises TypeError if " |
| 140 | "the cast is not valid." , |
| 141 | nb::arg("type" )); |
| 142 | quantizedType.def( |
| 143 | "cast_expressed_to_storage_type" , |
| 144 | [](MlirType type, MlirType candidate) { |
| 145 | MlirType castResult = |
| 146 | mlirQuantizedTypeCastExpressedToStorageType(type, candidate); |
| 147 | if (!mlirTypeIsNull(castResult)) |
| 148 | return castResult; |
| 149 | throw nb::type_error("Invalid cast." ); |
| 150 | }, |
| 151 | "Casts from a type based on the expressed type of this quantized type to " |
| 152 | "a corresponding type based on the storage type. Raises TypeError if the " |
| 153 | "cast is not valid." , |
| 154 | nb::arg("candidate" )); |
| 155 | |
| 156 | quantizedType.get_class().attr("FLAG_SIGNED" ) = |
| 157 | mlirQuantizedTypeGetSignedFlag(); |
| 158 | |
| 159 | //===-------------------------------------------------------------------===// |
| 160 | // AnyQuantizedType |
| 161 | //===-------------------------------------------------------------------===// |
| 162 | |
| 163 | auto anyQuantizedType = |
| 164 | mlir_type_subclass(m, "AnyQuantizedType" , mlirTypeIsAAnyQuantizedType, |
| 165 | quantizedType.get_class()); |
| 166 | anyQuantizedType.def_classmethod( |
| 167 | "get" , |
| 168 | [](nb::object cls, unsigned flags, MlirType storageType, |
| 169 | MlirType expressedType, int64_t storageTypeMin, |
| 170 | int64_t storageTypeMax) { |
| 171 | return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, |
| 172 | storageTypeMin, storageTypeMax)); |
| 173 | }, |
| 174 | "Gets an instance of AnyQuantizedType in the same context as the " |
| 175 | "provided storage type." , |
| 176 | nb::arg("cls" ), nb::arg("flags" ), nb::arg("storage_type" ), |
| 177 | nb::arg("expressed_type" ), nb::arg("storage_type_min" ), |
| 178 | nb::arg("storage_type_max" )); |
| 179 | |
| 180 | //===-------------------------------------------------------------------===// |
| 181 | // UniformQuantizedType |
| 182 | //===-------------------------------------------------------------------===// |
| 183 | |
| 184 | auto uniformQuantizedType = mlir_type_subclass( |
| 185 | m, "UniformQuantizedType" , mlirTypeIsAUniformQuantizedType, |
| 186 | quantizedType.get_class()); |
| 187 | uniformQuantizedType.def_classmethod( |
| 188 | "get" , |
| 189 | [](nb::object cls, unsigned flags, MlirType storageType, |
| 190 | MlirType expressedType, double scale, int64_t zeroPoint, |
| 191 | int64_t storageTypeMin, int64_t storageTypeMax) { |
| 192 | return cls(mlirUniformQuantizedTypeGet(flags, storageType, |
| 193 | expressedType, scale, zeroPoint, |
| 194 | storageTypeMin, storageTypeMax)); |
| 195 | }, |
| 196 | "Gets an instance of UniformQuantizedType in the same context as the " |
| 197 | "provided storage type." , |
| 198 | nb::arg("cls" ), nb::arg("flags" ), nb::arg("storage_type" ), |
| 199 | nb::arg("expressed_type" ), nb::arg("scale" ), nb::arg("zero_point" ), |
| 200 | nb::arg("storage_type_min" ), nb::arg("storage_type_max" )); |
| 201 | uniformQuantizedType.def_property_readonly( |
| 202 | "scale" , |
| 203 | [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, |
| 204 | "The scale designates the difference between the real values " |
| 205 | "corresponding to consecutive quantized values differing by 1." ); |
| 206 | uniformQuantizedType.def_property_readonly( |
| 207 | "zero_point" , |
| 208 | [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, |
| 209 | "The storage value corresponding to the real value 0 in the affine " |
| 210 | "equation." ); |
| 211 | uniformQuantizedType.def_property_readonly( |
| 212 | "is_fixed_point" , |
| 213 | [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, |
| 214 | "Fixed point values are real numbers divided by a scale." ); |
| 215 | |
| 216 | //===-------------------------------------------------------------------===// |
| 217 | // UniformQuantizedPerAxisType |
| 218 | //===-------------------------------------------------------------------===// |
| 219 | auto uniformQuantizedPerAxisType = mlir_type_subclass( |
| 220 | m, "UniformQuantizedPerAxisType" , mlirTypeIsAUniformQuantizedPerAxisType, |
| 221 | quantizedType.get_class()); |
| 222 | uniformQuantizedPerAxisType.def_classmethod( |
| 223 | "get" , |
| 224 | [](nb::object cls, unsigned flags, MlirType storageType, |
| 225 | MlirType expressedType, std::vector<double> scales, |
| 226 | std::vector<int64_t> zeroPoints, int32_t quantizedDimension, |
| 227 | int64_t storageTypeMin, int64_t storageTypeMax) { |
| 228 | if (scales.size() != zeroPoints.size()) |
| 229 | throw nb::value_error( |
| 230 | "Mismatching number of scales and zero points." ); |
| 231 | auto nDims = static_cast<intptr_t>(scales.size()); |
| 232 | return cls(mlirUniformQuantizedPerAxisTypeGet( |
| 233 | flags, storageType, expressedType, nDims, scales.data(), |
| 234 | zeroPoints.data(), quantizedDimension, storageTypeMin, |
| 235 | storageTypeMax)); |
| 236 | }, |
| 237 | "Gets an instance of UniformQuantizedPerAxisType in the same context as " |
| 238 | "the provided storage type." , |
| 239 | nb::arg("cls" ), nb::arg("flags" ), nb::arg("storage_type" ), |
| 240 | nb::arg("expressed_type" ), nb::arg("scales" ), nb::arg("zero_points" ), |
| 241 | nb::arg("quantized_dimension" ), nb::arg("storage_type_min" ), |
| 242 | nb::arg("storage_type_max" )); |
| 243 | uniformQuantizedPerAxisType.def_property_readonly( |
| 244 | "scales" , |
| 245 | [](MlirType type) { |
| 246 | intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); |
| 247 | std::vector<double> scales; |
| 248 | scales.reserve(n: nDim); |
| 249 | for (intptr_t i = 0; i < nDim; ++i) { |
| 250 | double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); |
| 251 | scales.push_back(x: scale); |
| 252 | } |
| 253 | return scales; |
| 254 | }, |
| 255 | "The scales designate the difference between the real values " |
| 256 | "corresponding to consecutive quantized values differing by 1. The ith " |
| 257 | "scale corresponds to the ith slice in the quantized_dimension." ); |
| 258 | uniformQuantizedPerAxisType.def_property_readonly( |
| 259 | "zero_points" , |
| 260 | [](MlirType type) { |
| 261 | intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); |
| 262 | std::vector<int64_t> zeroPoints; |
| 263 | zeroPoints.reserve(n: nDim); |
| 264 | for (intptr_t i = 0; i < nDim; ++i) { |
| 265 | int64_t zeroPoint = |
| 266 | mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); |
| 267 | zeroPoints.push_back(x: zeroPoint); |
| 268 | } |
| 269 | return zeroPoints; |
| 270 | }, |
| 271 | "the storage values corresponding to the real value 0 in the affine " |
| 272 | "equation. The ith zero point corresponds to the ith slice in the " |
| 273 | "quantized_dimension." ); |
| 274 | uniformQuantizedPerAxisType.def_property_readonly( |
| 275 | "quantized_dimension" , |
| 276 | [](MlirType type) { |
| 277 | return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); |
| 278 | }, |
| 279 | "Specifies the dimension of the shape that the scales and zero points " |
| 280 | "correspond to." ); |
| 281 | uniformQuantizedPerAxisType.def_property_readonly( |
| 282 | "is_fixed_point" , |
| 283 | [](MlirType type) { |
| 284 | return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); |
| 285 | }, |
| 286 | "Fixed point values are real numbers divided by a scale." ); |
| 287 | |
| 288 | //===-------------------------------------------------------------------===// |
| 289 | // UniformQuantizedSubChannelType |
| 290 | //===-------------------------------------------------------------------===// |
| 291 | auto uniformQuantizedSubChannelType = mlir_type_subclass( |
| 292 | m, "UniformQuantizedSubChannelType" , |
| 293 | mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class()); |
| 294 | uniformQuantizedSubChannelType.def_classmethod( |
| 295 | "get" , |
| 296 | [](nb::object cls, unsigned flags, MlirType storageType, |
| 297 | MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints, |
| 298 | std::vector<int32_t> quantizedDimensions, |
| 299 | std::vector<int64_t> blockSizes, int64_t storageTypeMin, |
| 300 | int64_t storageTypeMax) { |
| 301 | return cls(mlirUniformQuantizedSubChannelTypeGet( |
| 302 | flags, storageType, expressedType, scales, zeroPoints, |
| 303 | static_cast<intptr_t>(blockSizes.size()), |
| 304 | quantizedDimensions.data(), blockSizes.data(), storageTypeMin, |
| 305 | storageTypeMax)); |
| 306 | }, |
| 307 | "Gets an instance of UniformQuantizedSubChannel in the same context as " |
| 308 | "the provided storage type." , |
| 309 | nb::arg("cls" ), nb::arg("flags" ), nb::arg("storage_type" ), |
| 310 | nb::arg("expressed_type" ), nb::arg("scales" ), nb::arg("zero_points" ), |
| 311 | nb::arg("quantized_dimensions" ), nb::arg("block_sizes" ), |
| 312 | nb::arg("storage_type_min" ), nb::arg("storage_type_max" )); |
| 313 | uniformQuantizedSubChannelType.def_property_readonly( |
| 314 | "quantized_dimensions" , |
| 315 | [](MlirType type) { |
| 316 | intptr_t nDim = |
| 317 | mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); |
| 318 | std::vector<int32_t> quantizedDimensions; |
| 319 | quantizedDimensions.reserve(n: nDim); |
| 320 | for (intptr_t i = 0; i < nDim; ++i) { |
| 321 | quantizedDimensions.push_back( |
| 322 | mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i)); |
| 323 | } |
| 324 | return quantizedDimensions; |
| 325 | }, |
| 326 | "Gets the quantized dimensions. Each element in the returned list " |
| 327 | "represents an axis of the quantized data tensor that has a specified " |
| 328 | "block size. The order of elements corresponds to the order of block " |
| 329 | "sizes returned by 'block_sizes' method. It means that the data tensor " |
| 330 | "is quantized along the i-th dimension in the returned list using the " |
| 331 | "i-th block size from block_sizes method." ); |
| 332 | uniformQuantizedSubChannelType.def_property_readonly( |
| 333 | "block_sizes" , |
| 334 | [](MlirType type) { |
| 335 | intptr_t nDim = |
| 336 | mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); |
| 337 | std::vector<int64_t> blockSizes; |
| 338 | blockSizes.reserve(n: nDim); |
| 339 | for (intptr_t i = 0; i < nDim; ++i) { |
| 340 | blockSizes.push_back( |
| 341 | mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i)); |
| 342 | } |
| 343 | return blockSizes; |
| 344 | }, |
| 345 | "Gets the block sizes for the quantized dimensions. The i-th element in " |
| 346 | "the returned list corresponds to the block size for the i-th dimension " |
| 347 | "in the list returned by quantized_dimensions method." ); |
| 348 | uniformQuantizedSubChannelType.def_property_readonly( |
| 349 | "scales" , |
| 350 | [](MlirType type) -> MlirAttribute { |
| 351 | return mlirUniformQuantizedSubChannelTypeGetScales(type); |
| 352 | }, |
| 353 | "The scales of the quantized type." ); |
| 354 | uniformQuantizedSubChannelType.def_property_readonly( |
| 355 | "zero_points" , |
| 356 | [](MlirType type) -> MlirAttribute { |
| 357 | return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type); |
| 358 | }, |
| 359 | "The zero points of the quantized type." ); |
| 360 | |
| 361 | //===-------------------------------------------------------------------===// |
| 362 | // CalibratedQuantizedType |
| 363 | //===-------------------------------------------------------------------===// |
| 364 | |
| 365 | auto calibratedQuantizedType = mlir_type_subclass( |
| 366 | m, "CalibratedQuantizedType" , mlirTypeIsACalibratedQuantizedType, |
| 367 | quantizedType.get_class()); |
| 368 | calibratedQuantizedType.def_classmethod( |
| 369 | "get" , |
| 370 | [](nb::object cls, MlirType expressedType, double min, double max) { |
| 371 | return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); |
| 372 | }, |
| 373 | "Gets an instance of CalibratedQuantizedType in the same context as the " |
| 374 | "provided expressed type." , |
| 375 | nb::arg("cls" ), nb::arg("expressed_type" ), nb::arg("min" ), |
| 376 | nb::arg("max" )); |
| 377 | calibratedQuantizedType.def_property_readonly("min" , [](MlirType type) { |
| 378 | return mlirCalibratedQuantizedTypeGetMin(type); |
| 379 | }); |
| 380 | calibratedQuantizedType.def_property_readonly("max" , [](MlirType type) { |
| 381 | return mlirCalibratedQuantizedTypeGetMax(type); |
| 382 | }); |
| 383 | } |
| 384 | |
| 385 | NB_MODULE(_mlirDialectsQuant, m) { |
| 386 | m.doc() = "MLIR Quantization dialect" ; |
| 387 | |
| 388 | populateDialectQuantSubmodule(m); |
| 389 | } |
| 390 | |