1//===- QuantUtils.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains TOSA numerical support functions and quantization
10// attribute builders.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
15
16using namespace mlir;
17using namespace mlir::tosa;
18
/// Splits \p scale into a fixed-point multiplier and a right shift for TOSA's
/// 16-bit scaling mode, so that scale ~= multiplier * 2^(-shift).
///
/// std::frexp yields a mantissa in [0.5, 1.0) (or (-1.0, -0.5] for negative
/// scales) and a binary exponent; the mantissa is rounded to 15 fractional
/// bits to form the multiplier. The right shift absorbs those 15 bits and is
/// clamped to 62 so it can always be applied as two right shifts of 31.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int64_t kUnity = int64_t(1) << 15;

  int exponent = 0;
  const double mantissa = std::frexp(scale, &exponent);
  double fixedMantissa = std::round(mantissa * kUnity);

  // |mantissa| < 1.0, so rounding can at most land exactly on 1.0.
  assert(fixedMantissa <= kUnity && "Shifted mantissa exceeds 16 signed bits");

  // A mantissa that rounded up to 1.0 no longer fits; renormalize it to 0.5
  // and move the lost factor of two into the exponent.
  if (fixedMantissa == kUnity) {
    fixedMantissa /= 2;
    ++exponent;
  }

  assert(fixedMantissa <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  // TOSA wants a positive right shift that also embeds the 15 fractional bits.
  int32_t rightShift = 15 - exponent;
  int32_t fixedMultiplier = static_cast<int32_t>(fixedMantissa);

  // Clamp the shift to 62; the excess is pre-applied to the multiplier
  // (shifting the multiplier by more than 31 bits is unnecessary).
  if (rightShift > 62) {
    fixedMultiplier >>= std::min<int32_t>(31, rightShift - 62);
    rightShift = 62;
  }

  multiplier = fixedMultiplier;
  shift = rightShift;
}
56
/// Splits \p scale into a fixed-point multiplier and a right shift for TOSA's
/// 32-bit scaling mode, so that scale ~= multiplier * 2^(-shift).
///
/// std::frexp yields a mantissa in [0.5, 1.0) (or (-1.0, -0.5] for negative
/// scales) and a binary exponent; the mantissa is rounded to 31 fractional
/// bits to form the multiplier. The right shift absorbs those 31 bits and is
/// clamped to 62 so it can always be applied as two right shifts of 31.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int64_t kUnity = int64_t(1) << 31;

  int exponent = 0;
  const double mantissa = std::frexp(scale, &exponent);
  double fixedMantissa = std::round(mantissa * kUnity);

  // |mantissa| < 1.0, so rounding can at most land exactly on 1.0.
  assert(fixedMantissa <= kUnity && "Shifted mantissa exceeds 32 signed bits");

  // A mantissa that rounded up to 1.0 does not fit in int32_t; renormalize it
  // to 0.5 and move the lost factor of two into the exponent.
  if (fixedMantissa == kUnity) {
    fixedMantissa /= 2;
    ++exponent;
  }

  assert(fixedMantissa <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  // TOSA wants a positive right shift that also embeds the 31 fractional bits.
  int32_t rightShift = 31 - exponent;
  int32_t fixedMultiplier = static_cast<int32_t>(fixedMantissa);

  // Clamp the shift to 62; the excess is pre-applied to the multiplier
  // (shifting the multiplier by more than 31 bits is unnecessary).
  if (rightShift > 62) {
    fixedMultiplier >>= std::min<int32_t>(31, rightShift - 62);
    rightShift = 62;
  }

  multiplier = fixedMultiplier;
  shift = rightShift;
}
93
94/// Generates a quantized multiplier/shift from double.
95void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
96 int32_t &shift, int32_t scaleWidth) {
97
98 switch (scaleWidth) {
99 case 16:
100 computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
101 return;
102 case 32:
103 computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
104 return;
105 default:
106 assert(0 && "Unsupported Tosa quantized_scale regime specified!");
107 }
108}
109
// Shorthands that extract the quantized element type of a ShapedType.
// Both follow llvm::dyn_cast semantics: they evaluate to a null type when the
// element type is not of the requested kind.
//   GET_UQTYPE: element type as quant::UniformQuantizedType.
//   GET_QTYPE:  element type as any quant::QuantizedType.
#define GET_UQTYPE(shapedTy)                                                   \
  (llvm::dyn_cast<quant::UniformQuantizedType>((shapedTy).getElementType()))
#define GET_QTYPE(shapedTy)                                                    \
  (llvm::dyn_cast<quant::QuantizedType>((shapedTy).getElementType()))
114
115/// Method to build ConvOpQuantizationAttr, called from
116/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
117/// input_zp: input zeropoint
118/// weight_zp: weight zeropoint.
119ConvOpQuantizationAttr
120mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
121 Value weight) {
122
123 auto inputType = dyn_cast<ShapedType>(input.getType());
124 auto weightType = dyn_cast<ShapedType>(weight.getType());
125
126 if (!inputType || !weightType)
127 return nullptr;
128
129 auto inputQType = GET_UQTYPE(inputType);
130 auto weightPerTensorQType = GET_UQTYPE(weightType);
131 auto weightPerAxisQType =
132 dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType());
133
134 // Weights must be either per-tensor quantized or per-axis quantized.
135 assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
136 "Weights must be either per-tensor or per-axis quantized");
137
138 // Either all quantized or all not quantized.
139 assert(!((bool)inputQType ^
140 ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
141 "Inputs and weights must be all quantized or all not quantized");
142
143 if (inputQType) {
144 int64_t inputZp = inputQType.getZeroPoint();
145 int64_t weightZp = 0;
146
147 if (weightPerTensorQType) {
148 weightZp = weightPerTensorQType.getZeroPoint();
149 } else if (weightPerAxisQType) {
150 weightZp = weightPerAxisQType.getZeroPoints().front();
151 }
152
153 return builder.getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp);
154 }
155
156 return nullptr;
157}
158
159/// Builds MatMulOpQuantizationAttr, called from
160/// MatMulOpQuantInfoBuilder:
161/// aZp: input a zeropoint
162/// bZp: input b zeropoint.
163MatMulOpQuantizationAttr
164mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
165 Value b) {
166
167 auto aType = dyn_cast<ShapedType>(a.getType());
168 auto bType = dyn_cast<ShapedType>(b.getType());
169
170 if (!aType || !bType)
171 return nullptr;
172
173 auto aQType = GET_UQTYPE(aType);
174 auto bQType = GET_UQTYPE(bType);
175
176 // A and B are either all quantized or all not quantized.
177 assert(!((bool)aQType ^ (bool)bQType) &&
178 "Matmul operands must be all quantized or all not quantized");
179
180 if (aQType) {
181 return builder.getAttr<tosa::MatMulOpQuantizationAttr>(
182 aQType.getZeroPoint(), bQType.getZeroPoint());
183 }
184
185 return nullptr;
186}
187
188/// Builds UnaryOpQuantizationAttr
189/// UnaryOpQuantInfoBuilder:
190/// inputZp: input zeropoint
191/// outputZp: output zeropoint.
192UnaryOpQuantizationAttr
193mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
194 Type outputRawType) {
195
196 auto inputType = dyn_cast<ShapedType>(input.getType());
197 auto outputType = dyn_cast<ShapedType>(outputRawType);
198
199 if (!inputType || !outputType)
200 return nullptr;
201
202 auto inputQType = GET_UQTYPE(inputType);
203 auto outputQType = GET_UQTYPE(outputType);
204
205 // Either all quantized or all not quantized.
206 assert(!((bool)inputQType ^ (bool)outputQType) &&
207 "Unary inputs/outputs must be all quantized or all not quantized");
208
209 if (inputQType) {
210 return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(),
211 outputQType.getZeroPoint());
212 }
213
214 return nullptr;
215}
216
217/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
218/// inputZp: input zeropoint.
219PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
220 Value input) {
221
222 auto inputType = dyn_cast<ShapedType>(input.getType());
223
224 if (!inputType)
225 return nullptr;
226
227 auto inputQType = GET_UQTYPE(inputType);
228
229 if (inputQType) {
230 return builder.getAttr<tosa::PadOpQuantizationAttr>(
231 inputQType.getZeroPoint());
232 }
233
234 return nullptr;
235}
236
237/// Builds output type for a quantized ConvOp with the right bitwidth.
238/// This is called by the builder when dealing with quantized content.
239Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
240 Value input, Value weight) {
241
242 auto inputType = dyn_cast<ShapedType>(input.getType());
243 auto weightType = dyn_cast<ShapedType>(weight.getType());
244
245 assert(inputType && weightType &&
246 "Could not extract input or weight tensors from Conv op");
247
248 auto inputQType = GET_QTYPE(inputType);
249 auto weightQType = GET_QTYPE(weightType);
250
251 assert(inputQType && weightQType &&
252 "Could not extract input or weight tensor types from Conv op");
253
254 unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
255 unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
256
257 auto outputShapedType = dyn_cast<ShapedType>(outputType);
258 assert(outputShapedType &&
259 "Could not extract output shape type from Conv op");
260
261 IntegerType accElementType;
262 if (inputBits == 16 && weightBits == 8)
263 accElementType = builder.getIntegerType(48);
264 else
265 accElementType = builder.getI32Type();
266 auto accType = outputShapedType.clone(accElementType);
267 return accType;
268}
269
270/// Builds Tosa quantization attributes from min/max values.
271Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
272 Attribute minAttr, Attribute maxAttr,
273 IntegerAttr quantBits, int filterQuantDim,
274 bool isSigned, BoolAttr narrowRange) {
275
276 quant::QuantizedType retType;
277
278 auto convfunc =
279 quant::ExpressedToQuantizedConverter::forInputType(inputType: inputDType);
280
281 auto minElems = dyn_cast<DenseFPElementsAttr>(Val&: minAttr);
282 auto maxElems = dyn_cast<DenseFPElementsAttr>(Val&: maxAttr);
283
284 SmallVector<double, 2> min, max;
285
286 // At least one is per-axis quantized elementsattr.
287 if (minElems || maxElems) {
288 // Must have the same number of elements.
289 if (minElems.getNumElements() != maxElems.getNumElements())
290 return {};
291 min.reserve(N: minElems.getNumElements());
292 max.reserve(N: maxElems.getNumElements());
293 for (auto i : minElems)
294 min.push_back(FloatAttr::getValueAsDouble(i));
295 for (auto i : maxElems)
296 max.push_back(FloatAttr::getValueAsDouble(i));
297 } else { // Just a single FP value.
298 auto minVal = dyn_cast<FloatAttr>(minAttr);
299 if (minVal)
300 min.push_back(Elt: minVal.getValueAsDouble());
301 else
302 return {};
303 auto maxVal = dyn_cast<FloatAttr>(maxAttr);
304 if (maxVal)
305 max.push_back(Elt: maxVal.getValueAsDouble());
306 else
307 return {};
308 }
309
310 if (min.size() == max.size()) {
311 if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
312 retType = quant::fakeQuantAttrsToType(
313 builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
314 narrowRange.getValue(), convfunc.expressedType, isSigned);
315 } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
316 auto shape = dyn_cast<ShapedType>(inputDType);
317 if (!shape)
318 return {};
319 if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
320 retType = quant::fakeQuantAttrsToType(
321 builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
322 max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
323 }
324 } else {
325 return {};
326 }
327 } else {
328 return {};
329 }
330
331 if (!retType)
332 return {};
333
334 return convfunc.convert(elementalType: retType);
335}
336
337/// Builds Tosa quantization attributes from min/max values.
338TypeAttr
339mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
340 Attribute minAttr, Attribute maxAttr,
341 IntegerAttr quantBits, int filterQuantDim,
342 bool isSigned, BoolAttr narrowRange) {
343
344 return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
345 maxAttr, quantBits, filterQuantDim,
346 isSigned, narrowRange));
347}
348

source code of mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp