//===-- QuantUtils.h - TOSA numerical support declarations ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Function declarations for TOSA numerical support functions and quantization
// attribute builders.
//
//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_DIALECT_TOSA_UTILS_QUANTUTILS_H
15#define MLIR_DIALECT_TOSA_UTILS_QUANTUTILS_H
16
17#include "mlir/Dialect/Tosa/IR/TosaOps.h"
18
19#include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h"
20#include "mlir/Dialect/Quant/Utils/UniformSupport.h"
21
22namespace mlir {
23namespace tosa {
24
//===----------------------------------------------------------------------===//
// Utility functions to support quantization handling in Tosa.
//===----------------------------------------------------------------------===//
28
/// Computes, from a floating-point \p scale, the integer multiplier and shift
/// values for a 16- or 32-bit scale width.
///
/// \param scale       Scale value to convert.
/// \param multiplier  Output: receives the computed multiplier.
/// \param shift       Output: receives the computed shift.
/// \param scaleWidth  Scale width in bits (16 or 32).
/// \returns a bool whose meaning is set by the out-of-line definition.
///          NOTE(review): presumably false when \p scaleWidth is neither 16
///          nor 32 -- confirm against the definition.
bool computeMultiplierAndShift(double scale, int32_t &multiplier,
                               int32_t &shift, int32_t scaleWidth);
33
34// Return a const value for array of IntType vec
35template <typename IntType>
36Value getConstTensorInt(OpBuilder &builder, Location loc,
37 ArrayRef<IntType> vec) {
38 static_assert(
39 std::is_same<IntType, int8_t>::value ||
40 std::is_same<IntType, int16_t>::value ||
41 std::is_same<IntType, int32_t>::value,
42 "getConstTensorInt only supports int8_t, int16_t, and int32_t types.");
43
44 int64_t count = vec.size();
45 assert(count > 0 && "Vector must not be empty");
46 auto element_type = builder.getIntegerType(sizeof(IntType) * 8);
47 mlir::RankedTensorType const_type =
48 RankedTensorType::get({count}, element_type);
49 mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
50 auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
51 return const_op.getResult();
52}
53
54//// Builds ConvOpQuantizationAttr from input and weight.
55ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
56 Value input, Value weight);
57
58std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
59 Value weight);
60
61//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
62MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
63 Value a, Value b);
64
65//// Builds UnaryOpQuantizationAttr for unary operations from input values.
66UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder,
67 Value input,
68 Type outputRawType);
69
70//// Builds PadOpQuantizationAttr for pad operations from input values.
71PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder,
72 Value input);
73
74//// construct ConvOp output type with correct bitwidth based on input/weight
75/// width.
76Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input,
77 Value weight);
78
79/// Builds Tosa quantization attributes from min/max values.
80Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr,
81 Attribute maxAttr, IntegerAttr quantBits,
82 int filterQuantDim, bool isSigned,
83 BoolAttr narrowRange);
84
85/// Builds Tosa quantization attributes from min/max values.
86TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
87 Attribute minAttr, Attribute maxAttr,
88 IntegerAttr quantBits, int filterQuantDim,
89 bool isSigned, BoolAttr narrowRange);
90
91} // namespace tosa
92} // namespace mlir
93
94#endif // MLIR_DIALECT_TOSA_UTILS_QUANTUTILS_H
95

// Source code of mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h