1 | //===-- QuantUtils.h - TOSA numerical support declarations ------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Function declarations for TOSA numerical support functions and quantization |
10 | // attribute builders |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_DIALECT_TOSA_UTILS_QUANTUTILS_H |
15 | #define MLIR_DIALECT_TOSA_UTILS_QUANTUTILS_H |
16 | |
17 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
18 | |
19 | #include "mlir/Dialect/Quant/Utils/FakeQuantSupport.h" |
20 | #include "mlir/Dialect/Quant/Utils/UniformSupport.h" |
21 | |
22 | namespace mlir { |
23 | namespace tosa { |
24 | |
25 | //===----------------------------------------------------------------------===// |
26 | // Utility functions to support quantization handling in Tosa. |
27 | //===----------------------------------------------------------------------===// |
28 | |
/// From a scale value, computes multiplier and shift values
/// for 16 or 32-bit scale widths.
///
/// \param scale       The scale value to decompose.
/// \param multiplier  Output parameter (non-const reference) for the
///                    multiplier.
/// \param shift       Output parameter (non-const reference) for the shift.
/// \param scaleWidth  Scale width selector; 16 or 32 per the summary above.
/// \returns A bool whose exact meaning is not visible from this declaration.
///          NOTE(review): presumably indicates success; confirm against the
///          definition.
bool computeMultiplierAndShift(double scale, int32_t &multiplier,
                               int32_t &shift, int32_t scaleWidth);
33 | |
34 | // Return a const value for array of IntType vec |
35 | template <typename IntType> |
36 | Value getConstTensorInt(OpBuilder &builder, Location loc, |
37 | ArrayRef<IntType> vec) { |
38 | static_assert( |
39 | std::is_same<IntType, int8_t>::value || |
40 | std::is_same<IntType, int16_t>::value || |
41 | std::is_same<IntType, int32_t>::value, |
42 | "getConstTensorInt only supports int8_t, int16_t, and int32_t types." ); |
43 | |
44 | int64_t count = vec.size(); |
45 | assert(count > 0 && "Vector must not be empty" ); |
46 | auto element_type = builder.getIntegerType(sizeof(IntType) * 8); |
47 | mlir::RankedTensorType const_type = |
48 | RankedTensorType::get({count}, element_type); |
49 | mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec); |
50 | auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr); |
51 | return const_op.getResult(); |
52 | } |
53 | |
/// Builds ConvOpQuantizationAttr from input and weight.
///
/// \param builder  Builder used to construct the attribute.
/// \param input    Convolution input value.
/// \param weight   Convolution weight value.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                   Value input, Value weight);
57 | |
/// Returns a pair of Values computed from `input` and `weight`.
///
/// NOTE(review): judging by the name, the pair presumably holds the zero
/// points of `input` and `weight` materialized as constants, in that order;
/// confirm against the definition.
std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
                                         Value weight);
60 | |
/// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
///
/// \param builder  Builder used to construct the attribute.
/// \param a        First MatMul operand (A).
/// \param b        Second MatMul operand (B).
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
                                                       Value a, Value b);
64 | |
/// Builds UnaryOpQuantizationAttr for unary operations from input values.
///
/// \param builder        Builder used to construct the attribute.
/// \param input          Operand of the unary operation.
/// \param outputRawType  Type of the operation's output.
///                       NOTE(review): "raw" presumably means the type as
///                       given, before any quantized element type is
///                       extracted; confirm against the definition.
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder,
                                                     Value input,
                                                     Type outputRawType);
69 | |
/// Builds PadOpQuantizationAttr for pad operations from input values.
///
/// \param builder  Builder used to construct the attribute.
/// \param input    Operand of the pad operation.
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder,
                                                 Value input);
73 | |
/// Constructs the ConvOp output type with the correct bitwidth based on the
/// input/weight width.
///
/// \param builder     Builder used to construct types.
/// \param outputType  The output type supplied by the caller.
/// \param input       Convolution input value.
/// \param weight      Convolution weight value.
/// \returns The resulting output Type.
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input,
                               Value weight);
78 | |
/// Builds a Tosa quantized type from min/max values and returns it as a Type.
///
/// NOTE(review): `builder` is taken by value here, unlike the other helpers in
/// this header which take `OpBuilder &`; confirm whether that is intentional.
///
/// \param builder         Builder used to construct the type.
/// \param inputDType      Input data type.
/// \param minAttr         Attribute carrying the minimum value(s).
/// \param maxAttr         Attribute carrying the maximum value(s).
/// \param quantBits       Attribute carrying the number of quantization bits.
/// \param filterQuantDim  Filter quantization dimension.
///                        NOTE(review): the convention for "no dimension" is
///                        not visible here; check the definition.
/// \param isSigned        Whether the quantized representation is signed.
/// \param narrowRange     Attribute carrying the narrow-range flag.
Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr,
                          Attribute maxAttr, IntegerAttr quantBits,
                          int filterQuantDim, bool isSigned,
                          BoolAttr narrowRange);
84 | |
/// Builds a Tosa quantized type from min/max values and returns it as a
/// TypeAttr.
///
/// Takes the same parameter list as buildQTypeFromMinMax; only the return type
/// differs (TypeAttr instead of Type).
/// NOTE(review): presumably wraps the result of buildQTypeFromMinMax; confirm
/// against the definition. `builder` is taken by value here, unlike the other
/// helpers in this header which take `OpBuilder &`.
TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
                                  Attribute minAttr, Attribute maxAttr,
                                  IntegerAttr quantBits, int filterQuantDim,
                                  bool isSigned, BoolAttr narrowRange);
90 | |
91 | } // namespace tosa |
92 | } // namespace mlir |
93 | |
94 | #endif // MLIR_DIALECT_TOSA_UTILS_QUANTUTILS_H |
95 | |