//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Quant/FakeQuantSupport.h"
#include "mlir/Dialect/Quant/QuantTypes.h"

#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

using namespace mlir;
using namespace mlir::quant;

15static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
16 bool isSigned, MLIRContext *ctx,
17 Type &storageType, int64_t &qmin,
18 int64_t &qmax) {
19 // Hard-coded type mapping from TFLite.
20 if (numBits <= 8) {
21 storageType = IntegerType::get(ctx, 8);
22 if (isSigned) {
23 qmin = -128;
24 qmax = 127;
25 } else {
26 qmin = 0;
27 qmax = 255;
28 }
29 } else if (numBits <= 16) {
30 storageType = IntegerType::get(ctx, 16);
31 if (isSigned) {
32 qmin = -32768;
33 qmax = 32767;
34 } else {
35 qmin = 0;
36 qmax = 65535;
37 }
38 } else if (numBits <= 32) {
39 storageType = IntegerType::get(ctx, 32);
40 if (isSigned) {
41 qmin = std::numeric_limits<int32_t>::min();
42 qmax = std::numeric_limits<int32_t>::max();
43 } else {
44 qmin = std::numeric_limits<uint32_t>::min();
45 qmax = std::numeric_limits<uint32_t>::max();
46 }
47 } else {
48 return true;
49 }
50
51 // Handle narrowRange.
52 if (narrowRange) {
53 qmin += 1;
54 }
55 return false;
56}
57
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
// to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero
// point is derived from the shifted range, and the scale isn't changed. As
// a consequence some values, which are supposed in the original [rmin, rmax]
// range will be outside the shifted range and be clamped during quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
//
// Inputs:  [qmin, qmax] is the quantized range, [rmin, rmax] the real range.
// Outputs: `scale` = (rmax - rmin) / (qmax - qmin), and `nudgedZeroPoint`, an
//          integer zero point clamped to [qmin, qmax].
// The function divides by (qmax - qmin) and by `scale`; it performs no check
// that either is non-zero, so callers must not pass qmin == qmax or
// rmin == rmax.
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  // Determine the scale.
  const double qminDouble = static_cast<double>(qmin);
  const double qmaxDouble = static_cast<double>(qmax);
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Zero point computation.
  // In float, solve the affine equation for any known pair
  // (real value, corresponding quantized value), of which, two such pairs
  // are known: (rmin, qmin), (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair will be
  // roughly machine_epsilon * (sum of absolute values of terms).
  // Use the variant that adds the smaller error.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMinError =
      std::abs(qminDouble) + std::abs(rmin / scale);
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double zeroPointFromMaxError =
      std::abs(qmaxDouble) + std::abs(rmax / scale);

  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                     ? zeroPointFromMin
                                     : zeroPointFromMax;

  // Now nudge the zero point to be an integer, clamping it into the quantized
  // range first so the rounding below only ever sees in-range values.
  if (zeroPointDouble < qminDouble) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointDouble));
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}

107UniformQuantizedType
108mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
109 double rmax, bool narrowRange,
110 Type expressedType, bool isSigned) {
111 MLIRContext *ctx = expressedType.getContext();
112 unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
113 Type storageType;
114 int64_t qmin;
115 int64_t qmax;
116 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
117 qmin, qmax)) {
118 return (emitError(loc, message: "unsupported FakeQuant number of bits: ") << numBits,
119 nullptr);
120 }
121
122 // Special case where min/max is close enough. The tensor contents are all
123 // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
124 // points and dequantized to 0.0.
125 if (std::fabs(x: rmax - rmin) < std::numeric_limits<double>::epsilon()) {
126 return UniformQuantizedType::getChecked(
127 loc, args&: flags, args&: storageType, args&: expressedType, args: 1.0, args&: qmin, args&: qmin, args&: qmax);
128 }
129
130 double scale;
131 int64_t nudgedZeroPoint;
132 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
133
134 return UniformQuantizedType::getChecked(loc, args&: flags, args&: storageType,
135 args&: expressedType, args&: scale, args&: nudgedZeroPoint,
136 args&: qmin, args&: qmax);
137}
138
139UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
140 Location loc, unsigned numBits, int32_t quantizedDimension,
141 ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
142 Type expressedType, bool isSigned) {
143 size_t axisSize = rmins.size();
144 if (axisSize != rmaxs.size()) {
145 return (emitError(loc, message: "mismatched per-axis min and max size: ")
146 << axisSize << " vs. " << rmaxs.size(),
147 nullptr);
148 }
149
150 MLIRContext *ctx = expressedType.getContext();
151 Type storageType;
152 int64_t qmin;
153 int64_t qmax;
154 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
155 qmin, qmax)) {
156 return (emitError(loc, message: "unsupported FakeQuant number of bits: ") << numBits,
157 nullptr);
158 }
159
160 SmallVector<double, 4> scales;
161 SmallVector<int64_t, 4> zeroPoints;
162 scales.reserve(N: axisSize);
163 zeroPoints.reserve(N: axisSize);
164 for (size_t axis = 0; axis != axisSize; ++axis) {
165 double rmin = rmins[axis];
166 double rmax = rmaxs[axis];
167 if (std::fabs(x: rmax - rmin) < std::numeric_limits<double>::epsilon()) {
168 scales.push_back(Elt: 1.0);
169 zeroPoints.push_back(Elt: qmin);
170 continue;
171 }
172
173 double scale;
174 int64_t nudgedZeroPoint;
175 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
176 scales.push_back(Elt: scale);
177 zeroPoints.push_back(Elt: nudgedZeroPoint);
178 }
179
180 unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
181 return UniformQuantizedPerAxisType::getChecked(
182 loc, args&: flags, args&: storageType, args&: expressedType, args&: scales, args&: zeroPoints,
183 args&: quantizedDimension, args&: qmin, args&: qmax);
184}
185

// Source: mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp