1 | //===- QuantUtils.cpp -----------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains TOSA numerical support functions and quantization |
10 | // attribute builders. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::tosa; |
18 | |
/// Splits `scale` into a 16-bit fixed-point multiplier and a right shift so
/// that scale ~= multiplier * 2^(-shift).
///
/// std::frexp produces a fraction whose magnitude lies in [0.5, 1.0); it is
/// scaled by 2^15 and rounded. When rounding lands exactly on 2^15 (i.e. 1.0)
/// the fraction is halved and the exponent bumped so the value stays in range.
/// The resulting shift is capped at 62, with the excess folded into the
/// multiplier.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int64_t kOne = int64_t(1) << 15;

  int32_t exponent = 0;
  const double fraction = std::frexp(scale, &exponent);
  double scaledFraction = std::round(fraction * kOne);

  // The fraction is strictly below 1.0, so rounding can reach but not pass 2^15.
  assert(scaledFraction <= kOne &&
         "Shifted mantissa exceeds 16 signed bits");

  if (scaledFraction == kOne) {
    scaledFraction /= 2;
    ++exponent;
  }

  // TOSA wants a positive right shift that also accounts for the 2^15 scaling
  // applied to the fraction above.
  int32_t rightShift = 15 - exponent;

  assert(scaledFraction <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  int32_t fixedMultiplier = static_cast<int32_t>(scaledFraction);

  // Keep the shift at most 62 so it can be split into two right shifts of 31;
  // whatever exceeds that is applied to the multiplier (never more than 31).
  if (rightShift > 62) {
    const int32_t excess = std::min<int32_t>(31, rightShift - 62);
    fixedMultiplier >>= excess;
    rightShift = 62;
  }

  multiplier = fixedMultiplier;
  shift = rightShift;
}
56 | |
/// Decomposes `scale` into a 32-bit fixed-point multiplier and a right shift
/// such that scale ~= multiplier * 2^(-shift).
///
/// std::frexp yields a mantissa whose magnitude lies in [0.5, 1.0); it is
/// scaled by 2^31 and rounded. If rounding reaches exactly 2^31 (i.e. 1.0),
/// the mantissa is halved and the exponent incremented so the multiplier still
/// fits in a signed 32-bit integer.
///
/// The returned shift is capped at 62 so that it can be applied as two right
/// shifts of 31; any excess is folded into the multiplier instead. No lower
/// bound is enforced here: scales with magnitude >= 2^31 yield a negative
/// shift.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  const double mantissa = std::frexp(scale, &shift);
  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // |mantissa| < 1.0, so after rounding it can at most reach 1.0 * 2^31.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");
  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive, and embed (1 << 31) into right
  // shift bits.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
  // The limit of 62 on shift allows the shift to be decomposed as
  // two right shifts of 31.
  if (shift > 62) {
    // Shifting the multiplier by more than 31 bits is unnecessary: a 31-bit
    // magnitude is already reduced to 0 (or -1 for negative values under an
    // arithmetic shift) at that point.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
    shift = 62;
  }
}
93 | |
94 | /// Generates a quantized multiplier/shift from double. |
95 | void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier, |
96 | int32_t &shift, int32_t scaleWidth) { |
97 | |
98 | switch (scaleWidth) { |
99 | case 16: |
100 | computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); |
101 | return; |
102 | case 32: |
103 | computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); |
104 | return; |
105 | default: |
106 | assert(0 && "Unsupported Tosa quantized_scale regime specified!" ); |
107 | } |
108 | } |
109 | |
// Shorthands for reading a shaped type's element type as a quantized type.
// Both expand to an llvm::dyn_cast, so they evaluate to a null type when the
// element type is not of the requested kind. The argument must be non-null
// and expose getElementType() (every use in this file passes a ShapedType).
//   GET_UQTYPE: element type as quant::UniformQuantizedType.
//   GET_QTYPE:  element type as any quant::QuantizedType.
#define GET_UQTYPE(inputType) \
  (llvm::dyn_cast<quant::UniformQuantizedType>((inputType).getElementType()))
#define GET_QTYPE(inputType) \
  (llvm::dyn_cast<quant::QuantizedType>((inputType).getElementType()))
114 | |
115 | /// Method to build ConvOpQuantizationAttr, called from |
116 | /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: |
117 | /// input_zp: input zeropoint |
118 | /// weight_zp: weight zeropoint. |
119 | ConvOpQuantizationAttr |
120 | mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, |
121 | Value weight) { |
122 | |
123 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
124 | auto weightType = dyn_cast<ShapedType>(weight.getType()); |
125 | |
126 | if (!inputType || !weightType) |
127 | return nullptr; |
128 | |
129 | auto inputQType = GET_UQTYPE(inputType); |
130 | auto weightPerTensorQType = GET_UQTYPE(weightType); |
131 | auto weightPerAxisQType = |
132 | dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType()); |
133 | |
134 | // Weights must be either per-tensor quantized or per-axis quantized. |
135 | assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && |
136 | "Weights must be either per-tensor or per-axis quantized" ); |
137 | |
138 | // Either all quantized or all not quantized. |
139 | assert(!((bool)inputQType ^ |
140 | ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) && |
141 | "Inputs and weights must be all quantized or all not quantized" ); |
142 | |
143 | if (inputQType) { |
144 | int64_t inputZp = inputQType.getZeroPoint(); |
145 | int64_t weightZp = 0; |
146 | |
147 | if (weightPerTensorQType) { |
148 | weightZp = weightPerTensorQType.getZeroPoint(); |
149 | } else if (weightPerAxisQType) { |
150 | weightZp = weightPerAxisQType.getZeroPoints().front(); |
151 | } |
152 | |
153 | return builder.getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp); |
154 | } |
155 | |
156 | return nullptr; |
157 | } |
158 | |
159 | /// Builds MatMulOpQuantizationAttr, called from |
160 | /// MatMulOpQuantInfoBuilder: |
161 | /// aZp: input a zeropoint |
162 | /// bZp: input b zeropoint. |
163 | MatMulOpQuantizationAttr |
164 | mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, |
165 | Value b) { |
166 | |
167 | auto aType = dyn_cast<ShapedType>(a.getType()); |
168 | auto bType = dyn_cast<ShapedType>(b.getType()); |
169 | |
170 | if (!aType || !bType) |
171 | return nullptr; |
172 | |
173 | auto aQType = GET_UQTYPE(aType); |
174 | auto bQType = GET_UQTYPE(bType); |
175 | |
176 | // A and B are either all quantized or all not quantized. |
177 | assert(!((bool)aQType ^ (bool)bQType) && |
178 | "Matmul operands must be all quantized or all not quantized" ); |
179 | |
180 | if (aQType) { |
181 | return builder.getAttr<tosa::MatMulOpQuantizationAttr>( |
182 | aQType.getZeroPoint(), bQType.getZeroPoint()); |
183 | } |
184 | |
185 | return nullptr; |
186 | } |
187 | |
188 | /// Builds UnaryOpQuantizationAttr |
189 | /// UnaryOpQuantInfoBuilder: |
190 | /// inputZp: input zeropoint |
191 | /// outputZp: output zeropoint. |
192 | UnaryOpQuantizationAttr |
193 | mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, |
194 | Type outputRawType) { |
195 | |
196 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
197 | auto outputType = dyn_cast<ShapedType>(outputRawType); |
198 | |
199 | if (!inputType || !outputType) |
200 | return nullptr; |
201 | |
202 | auto inputQType = GET_UQTYPE(inputType); |
203 | auto outputQType = GET_UQTYPE(outputType); |
204 | |
205 | // Either all quantized or all not quantized. |
206 | assert(!((bool)inputQType ^ (bool)outputQType) && |
207 | "Unary inputs/outputs must be all quantized or all not quantized" ); |
208 | |
209 | if (inputQType) { |
210 | return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(), |
211 | outputQType.getZeroPoint()); |
212 | } |
213 | |
214 | return nullptr; |
215 | } |
216 | |
217 | /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: |
218 | /// inputZp: input zeropoint. |
219 | PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, |
220 | Value input) { |
221 | |
222 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
223 | |
224 | if (!inputType) |
225 | return nullptr; |
226 | |
227 | auto inputQType = GET_UQTYPE(inputType); |
228 | |
229 | if (inputQType) { |
230 | return builder.getAttr<tosa::PadOpQuantizationAttr>( |
231 | inputQType.getZeroPoint()); |
232 | } |
233 | |
234 | return nullptr; |
235 | } |
236 | |
237 | /// Builds output type for a quantized ConvOp with the right bitwidth. |
238 | /// This is called by the builder when dealing with quantized content. |
239 | Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, |
240 | Value input, Value weight) { |
241 | |
242 | auto inputType = dyn_cast<ShapedType>(input.getType()); |
243 | auto weightType = dyn_cast<ShapedType>(weight.getType()); |
244 | |
245 | assert(inputType && weightType && |
246 | "Could not extract input or weight tensors from Conv op" ); |
247 | |
248 | auto inputQType = GET_QTYPE(inputType); |
249 | auto weightQType = GET_QTYPE(weightType); |
250 | |
251 | assert(inputQType && weightQType && |
252 | "Could not extract input or weight tensor types from Conv op" ); |
253 | |
254 | unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); |
255 | unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); |
256 | |
257 | auto outputShapedType = dyn_cast<ShapedType>(outputType); |
258 | assert(outputShapedType && |
259 | "Could not extract output shape type from Conv op" ); |
260 | |
261 | IntegerType accElementType; |
262 | if (inputBits == 16 && weightBits == 8) |
263 | accElementType = builder.getIntegerType(48); |
264 | else |
265 | accElementType = builder.getI32Type(); |
266 | auto accType = outputShapedType.clone(accElementType); |
267 | return accType; |
268 | } |
269 | |
270 | /// Builds Tosa quantization attributes from min/max values. |
271 | Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType, |
272 | Attribute minAttr, Attribute maxAttr, |
273 | IntegerAttr quantBits, int filterQuantDim, |
274 | bool isSigned, BoolAttr narrowRange) { |
275 | |
276 | quant::QuantizedType retType; |
277 | |
278 | auto convfunc = |
279 | quant::ExpressedToQuantizedConverter::forInputType(inputType: inputDType); |
280 | |
281 | auto minElems = dyn_cast<DenseFPElementsAttr>(Val&: minAttr); |
282 | auto maxElems = dyn_cast<DenseFPElementsAttr>(Val&: maxAttr); |
283 | |
284 | SmallVector<double, 2> min, max; |
285 | |
286 | // At least one is per-axis quantized elementsattr. |
287 | if (minElems || maxElems) { |
288 | // Must have the same number of elements. |
289 | if (minElems.getNumElements() != maxElems.getNumElements()) |
290 | return {}; |
291 | min.reserve(N: minElems.getNumElements()); |
292 | max.reserve(N: maxElems.getNumElements()); |
293 | for (auto i : minElems) |
294 | min.push_back(FloatAttr::getValueAsDouble(i)); |
295 | for (auto i : maxElems) |
296 | max.push_back(FloatAttr::getValueAsDouble(i)); |
297 | } else { // Just a single FP value. |
298 | auto minVal = dyn_cast<FloatAttr>(minAttr); |
299 | if (minVal) |
300 | min.push_back(Elt: minVal.getValueAsDouble()); |
301 | else |
302 | return {}; |
303 | auto maxVal = dyn_cast<FloatAttr>(maxAttr); |
304 | if (maxVal) |
305 | max.push_back(Elt: maxVal.getValueAsDouble()); |
306 | else |
307 | return {}; |
308 | } |
309 | |
310 | if (min.size() == max.size()) { |
311 | if (min.size() == 1) { // Per-tensor quantization with one min/max pair. |
312 | retType = quant::fakeQuantAttrsToType( |
313 | builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], |
314 | narrowRange.getValue(), convfunc.expressedType, isSigned); |
315 | } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. |
316 | auto shape = dyn_cast<ShapedType>(inputDType); |
317 | if (!shape) |
318 | return {}; |
319 | if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { |
320 | retType = quant::fakeQuantAttrsToType( |
321 | builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], |
322 | max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); |
323 | } |
324 | } else { |
325 | return {}; |
326 | } |
327 | } else { |
328 | return {}; |
329 | } |
330 | |
331 | if (!retType) |
332 | return {}; |
333 | |
334 | return convfunc.convert(elementalType: retType); |
335 | } |
336 | |
337 | /// Builds Tosa quantization attributes from min/max values. |
338 | TypeAttr |
339 | mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, |
340 | Attribute minAttr, Attribute maxAttr, |
341 | IntegerAttr quantBits, int filterQuantDim, |
342 | bool isSigned, BoolAttr narrowRange) { |
343 | |
344 | return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, |
345 | maxAttr, quantBits, filterQuantDim, |
346 | isSigned, narrowRange)); |
347 | } |
348 | |