1 | //===- QuantUtils.cpp -----------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains TOSA numerical support functions and quantization |
10 | // attribute builders. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::tosa; |
18 | |
/// Converts a floating-point `scale` into a 16-bit style fixed-point pair
/// (`multiplier`, `shift`) such that scale ~= multiplier * 2^(-shift).
///
/// The scale is split with std::frexp into a fraction (magnitude in
/// [0.5, 1.0)) and a power-of-two exponent. The fraction is scaled by 2^15 and
/// rounded to form `multiplier`; the exponent is folded into `shift`, which is
/// expressed as a right shift and clamped to at most 62.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kMantissaBits = 15;
  constexpr int32_t kMaxShift = 62;
  constexpr int64_t kFixedPointOne = int64_t(1) << kMantissaBits;

  // scale == fraction * 2^shift after this call.
  const double fraction = std::frexp(scale, &shift);
  double fixedPoint = std::round(fraction * kFixedPointOne);

  // Can't be greater than 1.0.
  assert(fixedPoint <= kFixedPointOne &&
         "Shifted mantissa exceeds 16 signed bits");

  // Rounding can land exactly on 1.0; renormalize to 0.5 and bump the
  // exponent so the represented value is unchanged.
  if (fixedPoint == kFixedPointOne) {
    fixedPoint /= 2;
    ++shift;
  }

  // TOSA expects a positive right shift, with the 2^15 mantissa scaling
  // folded into the shift amount.
  shift = kMantissaBits - shift;

  assert(fixedPoint <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");
  multiplier = static_cast<int32_t>(fixedPoint);

  // The shift tops out at 62 so that it can be decomposed into two right
  // shifts of 31. Any excess is applied to the multiplier instead; shifting
  // the multiplier by more than 31 bits is unnecessary.
  if (shift > kMaxShift) {
    const int32_t excess = std::min<int32_t>(31, shift - kMaxShift);
    multiplier >>= excess;
    shift = kMaxShift;
  }
}
56 | |
/// Converts a floating-point `scale` into a 32-bit style fixed-point pair
/// (`multiplier`, `shift`) such that scale ~= multiplier * 2^(-shift).
///
/// The scale is split with std::frexp into a fraction (magnitude in
/// [0.5, 1.0)) and a power-of-two exponent. The fraction is scaled by 2^31 and
/// rounded to form `multiplier`; the exponent is folded into `shift`, which is
/// expressed as a right shift and clamped to at most 62.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kMantissaBits = 31;
  constexpr int32_t kMaxShift = 62;
  constexpr int64_t kFixedPointOne = int64_t(1) << kMantissaBits;

  // scale == fraction * 2^shift after this call.
  const double fraction = std::frexp(scale, &shift);
  double fixedPoint = std::round(fraction * kFixedPointOne);

  // Can't be greater than 1.0.
  assert(fixedPoint <= kFixedPointOne &&
         "Shifted mantissa exceeds 32 signed bits");

  // Rounding can land exactly on 1.0; renormalize to 0.5 and bump the
  // exponent so the value still fits in a signed 32-bit multiplier.
  if (fixedPoint == kFixedPointOne) {
    fixedPoint /= 2;
    ++shift;
  }

  // TOSA expects a positive right shift, with the 2^31 mantissa scaling
  // folded into the shift amount.
  shift = kMantissaBits - shift;

  assert(fixedPoint <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");
  multiplier = static_cast<int32_t>(fixedPoint);

  // The shift tops out at 62 so that it can be decomposed into two right
  // shifts of 31. Any excess is applied to the multiplier instead, limited to
  // a 31-bit shift.
  if (shift > kMaxShift) {
    const int32_t excess = std::min<int32_t>(31, shift - kMaxShift);
    multiplier >>= excess;
    shift = kMaxShift;
  }
}
93 | |
94 | /// Generates a quantized multiplier/shift from double. |
95 | bool mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier, |
96 | int32_t &shift, int32_t scaleWidth) { |
97 | |
98 | switch (scaleWidth) { |
99 | case 16: |
100 | computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); |
101 | |
102 | // In some cases computeMultiplierAndShiftTosaScale16 can return |
103 | // a value less then 2, which is not valid in the TOSA spec. |
104 | return (!(shift < 2)); |
105 | case 32: |
106 | computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); |
107 | |
108 | // In some cases computeMultiplierAndShiftTosaScale32 can return |
109 | // a value less then 2, which is not valid in the TOSA spec. |
110 | return (!(shift < 2)); |
111 | default: |
112 | assert(0 && "Unsupported Tosa quantized_scale regime specified!" ); |
113 | return false; |
114 | } |
115 | } |
116 | |
/// Yields the element type of `shapedTy` as a quant::UniformQuantizedType, or
/// a null type when the element type is of a different kind.
#define GET_UQTYPE(shapedTy)                                                   \
  (llvm::dyn_cast<quant::UniformQuantizedType>((shapedTy).getElementType()))
/// Yields the element type of `shapedTy` as a quant::QuantizedType, or a null
/// type when the element type is of a different kind.
#define GET_QTYPE(shapedTy)                                                    \
  (llvm::dyn_cast<quant::QuantizedType>((shapedTy).getElementType()))
121 | |
122 | static std::optional<std::pair<std::int64_t, std::int64_t>> |
123 | getConvZeroPoints(Value input, Value weight) { |
124 | |
125 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
126 | auto weightType = dyn_cast<ShapedType>(weight.getType()); |
127 | |
128 | if (!inputType || !weightType) |
129 | return std::nullopt; |
130 | |
131 | auto inputQType = GET_UQTYPE(inputType); |
132 | auto weightPerTensorQType = GET_UQTYPE(weightType); |
133 | auto weightPerAxisQType = |
134 | dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType()); |
135 | |
136 | // Weights must be either per-tensor quantized or per-axis quantized. |
137 | assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && |
138 | "Weights must be either per-tensor or per-axis quantized" ); |
139 | |
140 | // Either all quantized or all not quantized. |
141 | assert(!((bool)inputQType ^ |
142 | ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) && |
143 | "Inputs and weights must be all quantized or all not quantized" ); |
144 | |
145 | if (inputQType) { |
146 | int64_t inputZp = inputQType.getZeroPoint(); |
147 | int64_t weightZp = 0; |
148 | |
149 | if (weightPerTensorQType) { |
150 | weightZp = weightPerTensorQType.getZeroPoint(); |
151 | } else if (weightPerAxisQType) { |
152 | weightZp = weightPerAxisQType.getZeroPoints().front(); |
153 | } |
154 | |
155 | return std::make_pair(x&: inputZp, y&: weightZp); |
156 | } |
157 | |
158 | return std::nullopt; |
159 | } |
160 | |
161 | std::pair<Value, Value> |
162 | mlir::tosa::createZPsAsConst(OpBuilder &builder, Value input, Value weight) { |
163 | std::int64_t inputZp, weightZp; |
164 | |
165 | auto inputEType = getElementTypeOrSelf(type: input.getType()); |
166 | auto weightEType = getElementTypeOrSelf(type: weight.getType()); |
167 | |
168 | if (mlir::isa<FloatType>(Val: inputEType) && mlir::isa<FloatType>(Val: weightEType)) { |
169 | inputZp = 0; |
170 | weightZp = 0; |
171 | } else { |
172 | auto maybeZps = getConvZeroPoints(input, weight); |
173 | if (!maybeZps.has_value()) |
174 | return {}; |
175 | |
176 | inputZp = maybeZps->first; |
177 | weightZp = maybeZps->second; |
178 | } |
179 | |
180 | auto maybeInputZpValue = |
181 | createZeroPointTensor(builder, loc: input.getLoc(), srcElemType: inputEType, zp: inputZp); |
182 | if (!maybeInputZpValue.has_value()) |
183 | return {}; |
184 | |
185 | auto maybeWeightZpValue = |
186 | createZeroPointTensor(builder, loc: weight.getLoc(), srcElemType: weightEType, zp: weightZp); |
187 | if (!maybeWeightZpValue.has_value()) |
188 | return {}; |
189 | |
190 | return std::make_pair(x&: *maybeInputZpValue, y&: *maybeWeightZpValue); |
191 | } |
192 | |
193 | /// Method to build ConvOpQuantizationAttr, called from |
194 | /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: |
195 | /// input_zp: input zeropoint |
196 | /// weight_zp: weight zeropoint. |
197 | ConvOpQuantizationAttr |
198 | mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, |
199 | Value weight) { |
200 | |
201 | auto maybeZps = getConvZeroPoints(input, weight); |
202 | if (!maybeZps.has_value()) |
203 | return nullptr; |
204 | |
205 | return builder.getAttr<tosa::ConvOpQuantizationAttr>(maybeZps->first, |
206 | maybeZps->second); |
207 | } |
208 | |
209 | /// Builds MatMulOpQuantizationAttr, called from |
210 | /// MatMulOpQuantInfoBuilder: |
211 | /// aZp: input a zeropoint |
212 | /// bZp: input b zeropoint. |
213 | MatMulOpQuantizationAttr |
214 | mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, |
215 | Value b) { |
216 | |
217 | auto aType = dyn_cast<ShapedType>(a.getType()); |
218 | auto bType = dyn_cast<ShapedType>(b.getType()); |
219 | |
220 | if (!aType || !bType) |
221 | return nullptr; |
222 | |
223 | auto aQType = GET_UQTYPE(aType); |
224 | auto bQType = GET_UQTYPE(bType); |
225 | |
226 | // A and B are either all quantized or all not quantized. |
227 | assert(!((bool)aQType ^ (bool)bQType) && |
228 | "Matmul operands must be all quantized or all not quantized" ); |
229 | |
230 | if (aQType) { |
231 | return builder.getAttr<tosa::MatMulOpQuantizationAttr>( |
232 | aQType.getZeroPoint(), bQType.getZeroPoint()); |
233 | } |
234 | |
235 | return nullptr; |
236 | } |
237 | |
238 | /// Builds UnaryOpQuantizationAttr |
239 | /// UnaryOpQuantInfoBuilder: |
240 | /// inputZp: input zeropoint |
241 | /// outputZp: output zeropoint. |
242 | UnaryOpQuantizationAttr |
243 | mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, |
244 | Type outputRawType) { |
245 | |
246 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
247 | auto outputType = dyn_cast<ShapedType>(outputRawType); |
248 | |
249 | if (!inputType || !outputType) |
250 | return nullptr; |
251 | |
252 | auto inputQType = GET_UQTYPE(inputType); |
253 | auto outputQType = GET_UQTYPE(outputType); |
254 | |
255 | // Either all quantized or all not quantized. |
256 | assert(!((bool)inputQType ^ (bool)outputQType) && |
257 | "Unary inputs/outputs must be all quantized or all not quantized" ); |
258 | |
259 | if (inputQType) { |
260 | return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(), |
261 | outputQType.getZeroPoint()); |
262 | } |
263 | |
264 | return nullptr; |
265 | } |
266 | |
267 | /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: |
268 | /// inputZp: input zeropoint. |
269 | PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, |
270 | Value input) { |
271 | |
272 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
273 | |
274 | if (!inputType) |
275 | return nullptr; |
276 | |
277 | auto inputQType = GET_UQTYPE(inputType); |
278 | |
279 | if (inputQType) { |
280 | return builder.getAttr<tosa::PadOpQuantizationAttr>( |
281 | inputQType.getZeroPoint()); |
282 | } |
283 | |
284 | return nullptr; |
285 | } |
286 | |
287 | /// Builds output type for a quantized ConvOp with the right bitwidth. |
288 | /// This is called by the builder when dealing with quantized content. |
289 | Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, |
290 | Value input, Value weight) { |
291 | |
292 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
293 | auto weightType = dyn_cast<ShapedType>(weight.getType()); |
294 | |
295 | assert(inputType && weightType && |
296 | "Could not extract input or weight tensors from Conv op" ); |
297 | |
298 | auto inputQType = GET_QTYPE(inputType); |
299 | auto weightQType = GET_QTYPE(weightType); |
300 | |
301 | assert(inputQType && weightQType && |
302 | "Could not extract input or weight tensor types from Conv op" ); |
303 | |
304 | unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); |
305 | unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); |
306 | |
307 | auto outputShapedType = dyn_cast<ShapedType>(outputType); |
308 | assert(outputShapedType && |
309 | "Could not extract output shape type from Conv op" ); |
310 | |
311 | IntegerType accElementType; |
312 | if (inputBits == 16 && weightBits == 8) |
313 | accElementType = builder.getIntegerType(48); |
314 | else |
315 | accElementType = builder.getI32Type(); |
316 | auto accType = outputShapedType.clone(accElementType); |
317 | return accType; |
318 | } |
319 | |
320 | /// Builds Tosa quantization attributes from min/max values. |
321 | Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType, |
322 | Attribute minAttr, Attribute maxAttr, |
323 | IntegerAttr quantBits, int filterQuantDim, |
324 | bool isSigned, BoolAttr narrowRange) { |
325 | |
326 | quant::QuantizedType retType; |
327 | |
328 | auto convfunc = |
329 | quant::ExpressedToQuantizedConverter::forInputType(inputType: inputDType); |
330 | |
331 | auto minElems = dyn_cast<DenseFPElementsAttr>(Val&: minAttr); |
332 | auto maxElems = dyn_cast<DenseFPElementsAttr>(Val&: maxAttr); |
333 | |
334 | SmallVector<double, 2> min, max; |
335 | |
336 | // At least one is per-axis quantized elementsattr. |
337 | if (minElems || maxElems) { |
338 | // Must have the same number of elements. |
339 | if (minElems.getNumElements() != maxElems.getNumElements()) |
340 | return {}; |
341 | min.reserve(N: minElems.getNumElements()); |
342 | max.reserve(N: maxElems.getNumElements()); |
343 | for (auto i : minElems) |
344 | min.push_back(FloatAttr::getValueAsDouble(i)); |
345 | for (auto i : maxElems) |
346 | max.push_back(FloatAttr::getValueAsDouble(i)); |
347 | } else { // Just a single FP value. |
348 | auto minVal = dyn_cast<FloatAttr>(minAttr); |
349 | if (minVal) |
350 | min.push_back(Elt: minVal.getValueAsDouble()); |
351 | else |
352 | return {}; |
353 | auto maxVal = dyn_cast<FloatAttr>(maxAttr); |
354 | if (maxVal) |
355 | max.push_back(Elt: maxVal.getValueAsDouble()); |
356 | else |
357 | return {}; |
358 | } |
359 | |
360 | if (min.size() == max.size()) { |
361 | if (min.size() == 1) { // Per-tensor quantization with one min/max pair. |
362 | retType = quant::fakeQuantAttrsToType( |
363 | builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], |
364 | narrowRange.getValue(), convfunc.expressedType, isSigned); |
365 | } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. |
366 | auto shape = dyn_cast<ShapedType>(inputDType); |
367 | if (!shape) |
368 | return {}; |
369 | if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { |
370 | retType = quant::fakeQuantAttrsToType( |
371 | builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], |
372 | max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); |
373 | } |
374 | } else { |
375 | return {}; |
376 | } |
377 | } else { |
378 | return {}; |
379 | } |
380 | |
381 | if (!retType) |
382 | return {}; |
383 | |
384 | return convfunc.convert(elementalType: retType); |
385 | } |
386 | |
387 | /// Builds Tosa quantization attributes from min/max values. |
388 | TypeAttr |
389 | mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, |
390 | Attribute minAttr, Attribute maxAttr, |
391 | IntegerAttr quantBits, int filterQuantDim, |
392 | bool isSigned, BoolAttr narrowRange) { |
393 | |
394 | return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, |
395 | maxAttr, quantBits, filterQuantDim, |
396 | isSigned, narrowRange)); |
397 | } |
398 | |