| 1 | //===- QuantUtils.cpp -----------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file contains TOSA numerical support functions and quantization |
| 10 | // attribute builders. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
| 15 | |
| 16 | using namespace mlir; |
| 17 | using namespace mlir::tosa; |
| 18 | |
/// Decomposes \p scale into a 16-bit fixed-point multiplier and a TOSA right
/// shift such that scale ~= multiplier * 2^(-shift).
///
/// std::frexp splits the scale into mantissa * 2^exponent with the mantissa
/// magnitude in [0.5, 1.0) (or zero). The mantissa is rounded to an integer
/// scaled by 2^15 and the exponent is rewritten as a positive right shift.
///
/// \param scale       The floating-point scale to decompose.
/// \param multiplier  Receives the rounded, 2^15-scaled mantissa.
/// \param shift       Receives the right-shift amount, capped at 62.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  const int64_t oneQ15 = int64_t(1) << 15;

  // frexp writes the binary exponent into `shift`; it is converted to a
  // right-shift amount further below.
  const double mantissa = std::frexp(scale, &shift);
  double shiftedM = std::round(mantissa * oneQ15);

  // Can't be greater than 1.0.
  assert(shiftedM <= oneQ15 && "Shifted mantissa exceeds 16 signed bits");

  // Rounding may push the mantissa up to exactly 1.0; renormalize to 0.5 and
  // bump the exponent so the represented value is unchanged.
  if (shiftedM == oneQ15) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive and embed (1 << 15) into right
  // shift bits.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
  // The limit of 62 on shift allows the shift to be decomposed as
  // two right shifts of 31.
  if (shift > 62) {
    // Shifting the multiplier by more than 31 bits is unnecessary.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
    shift = 62;
  }
}
| 56 | |
/// Decomposes \p scale into a 32-bit fixed-point multiplier and a TOSA right
/// shift such that scale ~= multiplier * 2^(-shift).
///
/// std::frexp splits the scale into mantissa * 2^exponent with the mantissa
/// magnitude in [0.5, 1.0) (or zero). The mantissa is rounded to an integer
/// scaled by 2^31 and the exponent is rewritten as a positive right shift.
///
/// \param scale       The floating-point scale to decompose.
/// \param multiplier  Receives the rounded, 2^31-scaled mantissa.
/// \param shift       Receives the right-shift amount, capped at 62.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  const int64_t oneQ31 = int64_t(1) << 31;

  // frexp writes the binary exponent into `shift`; it is converted to a
  // right-shift amount further below.
  const double mantissa = std::frexp(scale, &shift);
  double shiftedM = std::round(mantissa * oneQ31);

  // Can't be greater than 1.0.
  assert(shiftedM <= oneQ31 && "Shifted mantissa exceeds 32 signed bits");

  // Rounding may push the mantissa up to exactly 1.0, which would not fit in
  // int32_t; renormalize to 0.5 and bump the exponent to compensate.
  if (shiftedM == oneQ31) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive, and embed (1 << 31) into right
  // shift bits.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
  // The limit of 62 on shift allows the shift to be decomposed as
  // two right shifts of 31.
  if (shift > 62) {
    // Shifting the multiplier by more than 31 bits is unnecessary.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
    shift = 62;
  }
}
| 93 | |
| 94 | /// Generates a quantized multiplier/shift from double. |
| 95 | bool mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier, |
| 96 | int32_t &shift, int32_t scaleWidth) { |
| 97 | |
| 98 | switch (scaleWidth) { |
| 99 | case 16: |
| 100 | computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); |
| 101 | |
| 102 | // In some cases computeMultiplierAndShiftTosaScale16 can return |
| 103 | // a value less then 2, which is not valid in the TOSA spec. |
| 104 | return (!(shift < 2)); |
| 105 | case 32: |
| 106 | computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); |
| 107 | |
| 108 | // In some cases computeMultiplierAndShiftTosaScale32 can return |
| 109 | // a value less then 2, which is not valid in the TOSA spec. |
| 110 | return (!(shift < 2)); |
| 111 | default: |
| 112 | assert(0 && "Unsupported Tosa quantized_scale regime specified!" ); |
| 113 | return false; |
| 114 | } |
| 115 | } |
| 116 | |
// Casts the element type of a shaped type to quant::UniformQuantizedType;
// the result is null when the element type is not of that class.
#define GET_UQTYPE(shapedTy)                                                   \
  (llvm::dyn_cast<quant::UniformQuantizedType>((shapedTy).getElementType()))
// Casts the element type of a shaped type to quant::QuantizedType;
// the result is null when the element type is not of that class.
#define GET_QTYPE(shapedTy)                                                    \
  (llvm::dyn_cast<quant::QuantizedType>((shapedTy).getElementType()))
| 121 | |
| 122 | static std::optional<std::pair<std::int64_t, std::int64_t>> |
| 123 | getConvZeroPoints(Value input, Value weight) { |
| 124 | |
| 125 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
| 126 | auto weightType = dyn_cast<ShapedType>(weight.getType()); |
| 127 | |
| 128 | if (!inputType || !weightType) |
| 129 | return std::nullopt; |
| 130 | |
| 131 | auto inputQType = GET_UQTYPE(inputType); |
| 132 | auto weightPerTensorQType = GET_UQTYPE(weightType); |
| 133 | auto weightPerAxisQType = |
| 134 | dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType()); |
| 135 | |
| 136 | // Weights must be either per-tensor quantized or per-axis quantized. |
| 137 | assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && |
| 138 | "Weights must be either per-tensor or per-axis quantized" ); |
| 139 | |
| 140 | // Either all quantized or all not quantized. |
| 141 | assert(!((bool)inputQType ^ |
| 142 | ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) && |
| 143 | "Inputs and weights must be all quantized or all not quantized" ); |
| 144 | |
| 145 | if (inputQType) { |
| 146 | int64_t inputZp = inputQType.getZeroPoint(); |
| 147 | int64_t weightZp = 0; |
| 148 | |
| 149 | if (weightPerTensorQType) { |
| 150 | weightZp = weightPerTensorQType.getZeroPoint(); |
| 151 | } else if (weightPerAxisQType) { |
| 152 | weightZp = weightPerAxisQType.getZeroPoints().front(); |
| 153 | } |
| 154 | |
| 155 | return std::make_pair(x&: inputZp, y&: weightZp); |
| 156 | } |
| 157 | |
| 158 | return std::nullopt; |
| 159 | } |
| 160 | |
| 161 | std::pair<Value, Value> |
| 162 | mlir::tosa::createZPsAsConst(OpBuilder &builder, Value input, Value weight) { |
| 163 | std::int64_t inputZp, weightZp; |
| 164 | |
| 165 | auto inputEType = getElementTypeOrSelf(type: input.getType()); |
| 166 | auto weightEType = getElementTypeOrSelf(type: weight.getType()); |
| 167 | |
| 168 | if (mlir::isa<FloatType>(Val: inputEType) && mlir::isa<FloatType>(Val: weightEType)) { |
| 169 | inputZp = 0; |
| 170 | weightZp = 0; |
| 171 | } else { |
| 172 | auto maybeZps = getConvZeroPoints(input, weight); |
| 173 | if (!maybeZps.has_value()) |
| 174 | return {}; |
| 175 | |
| 176 | inputZp = maybeZps->first; |
| 177 | weightZp = maybeZps->second; |
| 178 | } |
| 179 | |
| 180 | auto maybeInputZpValue = |
| 181 | createZeroPointTensor(builder, loc: input.getLoc(), srcElemType: inputEType, zp: inputZp); |
| 182 | if (!maybeInputZpValue.has_value()) |
| 183 | return {}; |
| 184 | |
| 185 | auto maybeWeightZpValue = |
| 186 | createZeroPointTensor(builder, loc: weight.getLoc(), srcElemType: weightEType, zp: weightZp); |
| 187 | if (!maybeWeightZpValue.has_value()) |
| 188 | return {}; |
| 189 | |
| 190 | return std::make_pair(x&: *maybeInputZpValue, y&: *maybeWeightZpValue); |
| 191 | } |
| 192 | |
| 193 | /// Method to build ConvOpQuantizationAttr, called from |
| 194 | /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: |
| 195 | /// input_zp: input zeropoint |
| 196 | /// weight_zp: weight zeropoint. |
| 197 | ConvOpQuantizationAttr |
| 198 | mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, |
| 199 | Value weight) { |
| 200 | |
| 201 | auto maybeZps = getConvZeroPoints(input, weight); |
| 202 | if (!maybeZps.has_value()) |
| 203 | return nullptr; |
| 204 | |
| 205 | return builder.getAttr<tosa::ConvOpQuantizationAttr>(maybeZps->first, |
| 206 | maybeZps->second); |
| 207 | } |
| 208 | |
| 209 | /// Builds MatMulOpQuantizationAttr, called from |
| 210 | /// MatMulOpQuantInfoBuilder: |
| 211 | /// aZp: input a zeropoint |
| 212 | /// bZp: input b zeropoint. |
| 213 | MatMulOpQuantizationAttr |
| 214 | mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, |
| 215 | Value b) { |
| 216 | |
| 217 | auto aType = dyn_cast<ShapedType>(a.getType()); |
| 218 | auto bType = dyn_cast<ShapedType>(b.getType()); |
| 219 | |
| 220 | if (!aType || !bType) |
| 221 | return nullptr; |
| 222 | |
| 223 | auto aQType = GET_UQTYPE(aType); |
| 224 | auto bQType = GET_UQTYPE(bType); |
| 225 | |
| 226 | // A and B are either all quantized or all not quantized. |
| 227 | assert(!((bool)aQType ^ (bool)bQType) && |
| 228 | "Matmul operands must be all quantized or all not quantized" ); |
| 229 | |
| 230 | if (aQType) { |
| 231 | return builder.getAttr<tosa::MatMulOpQuantizationAttr>( |
| 232 | aQType.getZeroPoint(), bQType.getZeroPoint()); |
| 233 | } |
| 234 | |
| 235 | return nullptr; |
| 236 | } |
| 237 | |
| 238 | /// Builds UnaryOpQuantizationAttr |
| 239 | /// UnaryOpQuantInfoBuilder: |
| 240 | /// inputZp: input zeropoint |
| 241 | /// outputZp: output zeropoint. |
| 242 | UnaryOpQuantizationAttr |
| 243 | mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, |
| 244 | Type outputRawType) { |
| 245 | |
| 246 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
| 247 | auto outputType = dyn_cast<ShapedType>(outputRawType); |
| 248 | |
| 249 | if (!inputType || !outputType) |
| 250 | return nullptr; |
| 251 | |
| 252 | auto inputQType = GET_UQTYPE(inputType); |
| 253 | auto outputQType = GET_UQTYPE(outputType); |
| 254 | |
| 255 | // Either all quantized or all not quantized. |
| 256 | assert(!((bool)inputQType ^ (bool)outputQType) && |
| 257 | "Unary inputs/outputs must be all quantized or all not quantized" ); |
| 258 | |
| 259 | if (inputQType) { |
| 260 | return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(), |
| 261 | outputQType.getZeroPoint()); |
| 262 | } |
| 263 | |
| 264 | return nullptr; |
| 265 | } |
| 266 | |
| 267 | /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: |
| 268 | /// inputZp: input zeropoint. |
| 269 | PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, |
| 270 | Value input) { |
| 271 | |
| 272 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
| 273 | |
| 274 | if (!inputType) |
| 275 | return nullptr; |
| 276 | |
| 277 | auto inputQType = GET_UQTYPE(inputType); |
| 278 | |
| 279 | if (inputQType) { |
| 280 | return builder.getAttr<tosa::PadOpQuantizationAttr>( |
| 281 | inputQType.getZeroPoint()); |
| 282 | } |
| 283 | |
| 284 | return nullptr; |
| 285 | } |
| 286 | |
| 287 | /// Builds output type for a quantized ConvOp with the right bitwidth. |
| 288 | /// This is called by the builder when dealing with quantized content. |
| 289 | Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, |
| 290 | Value input, Value weight) { |
| 291 | |
| 292 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
| 293 | auto weightType = dyn_cast<ShapedType>(weight.getType()); |
| 294 | |
| 295 | assert(inputType && weightType && |
| 296 | "Could not extract input or weight tensors from Conv op" ); |
| 297 | |
| 298 | auto inputQType = GET_QTYPE(inputType); |
| 299 | auto weightQType = GET_QTYPE(weightType); |
| 300 | |
| 301 | assert(inputQType && weightQType && |
| 302 | "Could not extract input or weight tensor types from Conv op" ); |
| 303 | |
| 304 | unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); |
| 305 | unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); |
| 306 | |
| 307 | auto outputShapedType = dyn_cast<ShapedType>(outputType); |
| 308 | assert(outputShapedType && |
| 309 | "Could not extract output shape type from Conv op" ); |
| 310 | |
| 311 | IntegerType accElementType; |
| 312 | if (inputBits == 16 && weightBits == 8) |
| 313 | accElementType = builder.getIntegerType(48); |
| 314 | else |
| 315 | accElementType = builder.getI32Type(); |
| 316 | auto accType = outputShapedType.clone(accElementType); |
| 317 | return accType; |
| 318 | } |
| 319 | |
| 320 | /// Builds Tosa quantization attributes from min/max values. |
| 321 | Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType, |
| 322 | Attribute minAttr, Attribute maxAttr, |
| 323 | IntegerAttr quantBits, int filterQuantDim, |
| 324 | bool isSigned, BoolAttr narrowRange) { |
| 325 | |
| 326 | quant::QuantizedType retType; |
| 327 | |
| 328 | auto convfunc = |
| 329 | quant::ExpressedToQuantizedConverter::forInputType(inputType: inputDType); |
| 330 | |
| 331 | auto minElems = dyn_cast<DenseFPElementsAttr>(Val&: minAttr); |
| 332 | auto maxElems = dyn_cast<DenseFPElementsAttr>(Val&: maxAttr); |
| 333 | |
| 334 | SmallVector<double, 2> min, max; |
| 335 | |
| 336 | // At least one is per-axis quantized elementsattr. |
| 337 | if (minElems || maxElems) { |
| 338 | // Must have the same number of elements. |
| 339 | if (minElems.getNumElements() != maxElems.getNumElements()) |
| 340 | return {}; |
| 341 | min.reserve(N: minElems.getNumElements()); |
| 342 | max.reserve(N: maxElems.getNumElements()); |
| 343 | for (auto i : minElems) |
| 344 | min.push_back(FloatAttr::getValueAsDouble(i)); |
| 345 | for (auto i : maxElems) |
| 346 | max.push_back(FloatAttr::getValueAsDouble(i)); |
| 347 | } else { // Just a single FP value. |
| 348 | auto minVal = dyn_cast<FloatAttr>(minAttr); |
| 349 | if (minVal) |
| 350 | min.push_back(Elt: minVal.getValueAsDouble()); |
| 351 | else |
| 352 | return {}; |
| 353 | auto maxVal = dyn_cast<FloatAttr>(maxAttr); |
| 354 | if (maxVal) |
| 355 | max.push_back(Elt: maxVal.getValueAsDouble()); |
| 356 | else |
| 357 | return {}; |
| 358 | } |
| 359 | |
| 360 | if (min.size() == max.size()) { |
| 361 | if (min.size() == 1) { // Per-tensor quantization with one min/max pair. |
| 362 | retType = quant::fakeQuantAttrsToType( |
| 363 | builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], |
| 364 | narrowRange.getValue(), convfunc.expressedType, isSigned); |
| 365 | } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. |
| 366 | auto shape = dyn_cast<ShapedType>(inputDType); |
| 367 | if (!shape) |
| 368 | return {}; |
| 369 | if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { |
| 370 | retType = quant::fakeQuantAttrsToType( |
| 371 | builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], |
| 372 | max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); |
| 373 | } |
| 374 | } else { |
| 375 | return {}; |
| 376 | } |
| 377 | } else { |
| 378 | return {}; |
| 379 | } |
| 380 | |
| 381 | if (!retType) |
| 382 | return {}; |
| 383 | |
| 384 | return convfunc.convert(elementalType: retType); |
| 385 | } |
| 386 | |
| 387 | /// Builds Tosa quantization attributes from min/max values. |
| 388 | TypeAttr |
| 389 | mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, |
| 390 | Attribute minAttr, Attribute maxAttr, |
| 391 | IntegerAttr quantBits, int filterQuantDim, |
| 392 | bool isSigned, BoolAttr narrowRange) { |
| 393 | |
| 394 | return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, |
| 395 | maxAttr, quantBits, filterQuantDim, |
| 396 | isSigned, narrowRange)); |
| 397 | } |
| 398 | |