1//===- QuantUtils.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains TOSA numerical support functions and quantization
10// attribute builders.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"

#include <algorithm>
#include <cmath>
#include <limits>
15
16using namespace mlir;
17using namespace mlir::tosa;
18
/// Decomposes \p scale into a 16-bit fixed-point \p multiplier and a right
/// \p shift such that scale ~= multiplier * 2^(-shift).
///
/// std::frexp yields a mantissa whose magnitude lies in [0.5, 1.0). That
/// mantissa is scaled by 2^15 and rounded. If rounding pushes it up to exactly
/// 2^15, it is halved and the exponent incremented so the value stays in range.
/// The resulting shift is capped at 62; any excess is folded into the
/// multiplier as an extra right shift (of at most 31 bits).
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kFractionBits = 15;
  constexpr int64_t kOne = int64_t(1) << kFractionBits;
  constexpr int32_t kMaxShift = 62;

  int exponent = 0;
  double scaled = std::round(std::frexp(scale, &exponent) * kOne);

  // A normalized mantissa is below 1.0, so after scaling and rounding it can
  // at most reach 2^15.
  assert(scaled <= kOne && "Shifted mantissa exceeds 16 signed bits");
  if (scaled == kOne) {
    scaled /= 2;
    ++exponent;
  }

  // TOSA expects a positive right shift with the 2^15 fixed-point scaling
  // folded into it.
  int32_t rightShift = kFractionBits - exponent;

  assert(scaled <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");
  int32_t fixedPoint = static_cast<int32_t>(scaled);

  // Cap the shift at 62 so it can be applied as two right shifts of 31; the
  // remainder is applied to the multiplier, never by more than 31 bits.
  if (rightShift > kMaxShift) {
    fixedPoint >>= std::min<int32_t>(31, rightShift - kMaxShift);
    rightShift = kMaxShift;
  }

  multiplier = fixedPoint;
  shift = rightShift;
}
56
/// Decomposes \p scale into a 32-bit fixed-point \p multiplier and a right
/// \p shift such that scale ~= multiplier * 2^(-shift).
///
/// std::frexp yields a mantissa whose magnitude lies in [0.5, 1.0). That
/// mantissa is scaled by 2^31 and rounded. If rounding pushes it up to exactly
/// 2^31, it is halved and the exponent incremented so it still fits in int32.
/// The resulting shift is capped at 62; any excess is folded into the
/// multiplier as an extra right shift (of at most 31 bits).
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  constexpr int32_t kFractionBits = 31;
  constexpr int64_t kOne = int64_t(1) << kFractionBits;
  constexpr int32_t kMaxShift = 62;

  int exponent = 0;
  double scaled = std::round(std::frexp(scale, &exponent) * kOne);

  // A normalized mantissa is below 1.0, so after scaling and rounding it can
  // at most reach 2^31.
  assert(scaled <= kOne && "Shifted mantissa exceeds 32 signed bits");
  if (scaled == kOne) {
    scaled /= 2;
    ++exponent;
  }

  // TOSA expects a positive right shift with the 2^31 fixed-point scaling
  // folded into it.
  int32_t rightShift = kFractionBits - exponent;

  assert(scaled <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");
  int32_t fixedPoint = static_cast<int32_t>(scaled);

  // Cap the shift at 62 so it can be applied as two right shifts of 31; the
  // remainder is applied to the multiplier, never by more than 31 bits.
  if (rightShift > kMaxShift) {
    fixedPoint >>= std::min<int32_t>(31, rightShift - kMaxShift);
    rightShift = kMaxShift;
  }

  multiplier = fixedPoint;
  shift = rightShift;
}
93
94/// Generates a quantized multiplier/shift from double.
95bool mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
96 int32_t &shift, int32_t scaleWidth) {
97
98 switch (scaleWidth) {
99 case 16:
100 computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
101
102 // In some cases computeMultiplierAndShiftTosaScale16 can return
103 // a value less then 2, which is not valid in the TOSA spec.
104 return (!(shift < 2));
105 case 32:
106 computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
107
108 // In some cases computeMultiplierAndShiftTosaScale32 can return
109 // a value less then 2, which is not valid in the TOSA spec.
110 return (!(shift < 2));
111 default:
112 assert(0 && "Unsupported Tosa quantized_scale regime specified!");
113 return false;
114 }
115}
116
// Shorthands that view the element type of a ShapedType-like value as a
// quantized type via llvm::dyn_cast, yielding a null type on mismatch:
//   GET_UQTYPE -> quant::UniformQuantizedType
//   GET_QTYPE  -> quant::QuantizedType (any quantized element type)
#define GET_UQTYPE(shapedTy)                                                   \
  (llvm::dyn_cast<quant::UniformQuantizedType>(                               \
      (shapedTy).getElementType()))
#define GET_QTYPE(shapedTy)                                                    \
  (llvm::dyn_cast<quant::QuantizedType>((shapedTy).getElementType()))
121
122static std::optional<std::pair<std::int64_t, std::int64_t>>
123getConvZeroPoints(Value input, Value weight) {
124
125 auto inputType = dyn_cast<ShapedType>(input.getType());
126 auto weightType = dyn_cast<ShapedType>(weight.getType());
127
128 if (!inputType || !weightType)
129 return std::nullopt;
130
131 auto inputQType = GET_UQTYPE(inputType);
132 auto weightPerTensorQType = GET_UQTYPE(weightType);
133 auto weightPerAxisQType =
134 dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType());
135
136 // Weights must be either per-tensor quantized or per-axis quantized.
137 assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
138 "Weights must be either per-tensor or per-axis quantized");
139
140 // Either all quantized or all not quantized.
141 assert(!((bool)inputQType ^
142 ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
143 "Inputs and weights must be all quantized or all not quantized");
144
145 if (inputQType) {
146 int64_t inputZp = inputQType.getZeroPoint();
147 int64_t weightZp = 0;
148
149 if (weightPerTensorQType) {
150 weightZp = weightPerTensorQType.getZeroPoint();
151 } else if (weightPerAxisQType) {
152 weightZp = weightPerAxisQType.getZeroPoints().front();
153 }
154
155 return std::make_pair(x&: inputZp, y&: weightZp);
156 }
157
158 return std::nullopt;
159}
160
161std::pair<Value, Value>
162mlir::tosa::createZPsAsConst(OpBuilder &builder, Value input, Value weight) {
163 std::int64_t inputZp, weightZp;
164
165 auto inputEType = getElementTypeOrSelf(type: input.getType());
166 auto weightEType = getElementTypeOrSelf(type: weight.getType());
167
168 if (mlir::isa<FloatType>(Val: inputEType) && mlir::isa<FloatType>(Val: weightEType)) {
169 inputZp = 0;
170 weightZp = 0;
171 } else {
172 auto maybeZps = getConvZeroPoints(input, weight);
173 if (!maybeZps.has_value())
174 return {};
175
176 inputZp = maybeZps->first;
177 weightZp = maybeZps->second;
178 }
179
180 auto maybeInputZpValue =
181 createZeroPointTensor(builder, loc: input.getLoc(), srcElemType: inputEType, zp: inputZp);
182 if (!maybeInputZpValue.has_value())
183 return {};
184
185 auto maybeWeightZpValue =
186 createZeroPointTensor(builder, loc: weight.getLoc(), srcElemType: weightEType, zp: weightZp);
187 if (!maybeWeightZpValue.has_value())
188 return {};
189
190 return std::make_pair(x&: *maybeInputZpValue, y&: *maybeWeightZpValue);
191}
192
193/// Method to build ConvOpQuantizationAttr, called from
194/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
195/// input_zp: input zeropoint
196/// weight_zp: weight zeropoint.
197ConvOpQuantizationAttr
198mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
199 Value weight) {
200
201 auto maybeZps = getConvZeroPoints(input, weight);
202 if (!maybeZps.has_value())
203 return nullptr;
204
205 return builder.getAttr<tosa::ConvOpQuantizationAttr>(maybeZps->first,
206 maybeZps->second);
207}
208
209/// Builds MatMulOpQuantizationAttr, called from
210/// MatMulOpQuantInfoBuilder:
211/// aZp: input a zeropoint
212/// bZp: input b zeropoint.
213MatMulOpQuantizationAttr
214mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
215 Value b) {
216
217 auto aType = dyn_cast<ShapedType>(a.getType());
218 auto bType = dyn_cast<ShapedType>(b.getType());
219
220 if (!aType || !bType)
221 return nullptr;
222
223 auto aQType = GET_UQTYPE(aType);
224 auto bQType = GET_UQTYPE(bType);
225
226 // A and B are either all quantized or all not quantized.
227 assert(!((bool)aQType ^ (bool)bQType) &&
228 "Matmul operands must be all quantized or all not quantized");
229
230 if (aQType) {
231 return builder.getAttr<tosa::MatMulOpQuantizationAttr>(
232 aQType.getZeroPoint(), bQType.getZeroPoint());
233 }
234
235 return nullptr;
236}
237
238/// Builds UnaryOpQuantizationAttr
239/// UnaryOpQuantInfoBuilder:
240/// inputZp: input zeropoint
241/// outputZp: output zeropoint.
242UnaryOpQuantizationAttr
243mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
244 Type outputRawType) {
245
246 auto inputType = dyn_cast<ShapedType>(input.getType());
247 auto outputType = dyn_cast<ShapedType>(outputRawType);
248
249 if (!inputType || !outputType)
250 return nullptr;
251
252 auto inputQType = GET_UQTYPE(inputType);
253 auto outputQType = GET_UQTYPE(outputType);
254
255 // Either all quantized or all not quantized.
256 assert(!((bool)inputQType ^ (bool)outputQType) &&
257 "Unary inputs/outputs must be all quantized or all not quantized");
258
259 if (inputQType) {
260 return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(),
261 outputQType.getZeroPoint());
262 }
263
264 return nullptr;
265}
266
267/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
268/// inputZp: input zeropoint.
269PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
270 Value input) {
271
272 auto inputType = dyn_cast<ShapedType>(input.getType());
273
274 if (!inputType)
275 return nullptr;
276
277 auto inputQType = GET_UQTYPE(inputType);
278
279 if (inputQType) {
280 return builder.getAttr<tosa::PadOpQuantizationAttr>(
281 inputQType.getZeroPoint());
282 }
283
284 return nullptr;
285}
286
287/// Builds output type for a quantized ConvOp with the right bitwidth.
288/// This is called by the builder when dealing with quantized content.
289Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
290 Value input, Value weight) {
291
292 auto inputType = dyn_cast<ShapedType>(input.getType());
293 auto weightType = dyn_cast<ShapedType>(weight.getType());
294
295 assert(inputType && weightType &&
296 "Could not extract input or weight tensors from Conv op");
297
298 auto inputQType = GET_QTYPE(inputType);
299 auto weightQType = GET_QTYPE(weightType);
300
301 assert(inputQType && weightQType &&
302 "Could not extract input or weight tensor types from Conv op");
303
304 unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
305 unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
306
307 auto outputShapedType = dyn_cast<ShapedType>(outputType);
308 assert(outputShapedType &&
309 "Could not extract output shape type from Conv op");
310
311 IntegerType accElementType;
312 if (inputBits == 16 && weightBits == 8)
313 accElementType = builder.getIntegerType(48);
314 else
315 accElementType = builder.getI32Type();
316 auto accType = outputShapedType.clone(accElementType);
317 return accType;
318}
319
320/// Builds Tosa quantization attributes from min/max values.
321Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
322 Attribute minAttr, Attribute maxAttr,
323 IntegerAttr quantBits, int filterQuantDim,
324 bool isSigned, BoolAttr narrowRange) {
325
326 quant::QuantizedType retType;
327
328 auto convfunc =
329 quant::ExpressedToQuantizedConverter::forInputType(inputType: inputDType);
330
331 auto minElems = dyn_cast<DenseFPElementsAttr>(Val&: minAttr);
332 auto maxElems = dyn_cast<DenseFPElementsAttr>(Val&: maxAttr);
333
334 SmallVector<double, 2> min, max;
335
336 // At least one is per-axis quantized elementsattr.
337 if (minElems || maxElems) {
338 // Must have the same number of elements.
339 if (minElems.getNumElements() != maxElems.getNumElements())
340 return {};
341 min.reserve(N: minElems.getNumElements());
342 max.reserve(N: maxElems.getNumElements());
343 for (auto i : minElems)
344 min.push_back(FloatAttr::getValueAsDouble(i));
345 for (auto i : maxElems)
346 max.push_back(FloatAttr::getValueAsDouble(i));
347 } else { // Just a single FP value.
348 auto minVal = dyn_cast<FloatAttr>(minAttr);
349 if (minVal)
350 min.push_back(Elt: minVal.getValueAsDouble());
351 else
352 return {};
353 auto maxVal = dyn_cast<FloatAttr>(maxAttr);
354 if (maxVal)
355 max.push_back(Elt: maxVal.getValueAsDouble());
356 else
357 return {};
358 }
359
360 if (min.size() == max.size()) {
361 if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
362 retType = quant::fakeQuantAttrsToType(
363 builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
364 narrowRange.getValue(), convfunc.expressedType, isSigned);
365 } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
366 auto shape = dyn_cast<ShapedType>(inputDType);
367 if (!shape)
368 return {};
369 if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
370 retType = quant::fakeQuantAttrsToType(
371 builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
372 max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
373 }
374 } else {
375 return {};
376 }
377 } else {
378 return {};
379 }
380
381 if (!retType)
382 return {};
383
384 return convfunc.convert(elementalType: retType);
385}
386
387/// Builds Tosa quantization attributes from min/max values.
388TypeAttr
389mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
390 Attribute minAttr, Attribute maxAttr,
391 IntegerAttr quantBits, int filterQuantDim,
392 bool isSigned, BoolAttr narrowRange) {
393
394 return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
395 maxAttr, quantBits, filterQuantDim,
396 isSigned, narrowRange));
397}
398

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp