1 | //===- Quant.cpp - C Interface for Quant dialect --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/Dialect/Quant.h" |
10 | #include "mlir/CAPI/Registration.h" |
11 | #include "mlir/Dialect/Quant/QuantOps.h" |
12 | #include "mlir/Dialect/Quant/QuantTypes.h" |
13 | |
14 | using namespace mlir; |
15 | |
// Defines the C API dialect-registration entry point(s) for the quant dialect,
// backed by quant::QuantizationDialect. The exact symbols emitted are
// determined by the macro in mlir/CAPI/Registration.h (presumably
// mlirGetDialectHandle__quant__ — confirm against that header).
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
17 | |
18 | //===---------------------------------------------------------------------===// |
19 | // QuantizedType |
20 | //===---------------------------------------------------------------------===// |
21 | |
22 | bool mlirTypeIsAQuantizedType(MlirType type) { |
23 | return isa<quant::QuantizedType>(Val: unwrap(c: type)); |
24 | } |
25 | |
26 | unsigned mlirQuantizedTypeGetSignedFlag() { |
27 | return quant::QuantizationFlags::Signed; |
28 | } |
29 | |
30 | int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned, |
31 | unsigned integralWidth) { |
32 | return quant::QuantizedType::getDefaultMinimumForInteger(isSigned, |
33 | integralWidth); |
34 | } |
35 | |
36 | int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, |
37 | unsigned integralWidth) { |
38 | return quant::QuantizedType::getDefaultMaximumForInteger(isSigned, |
39 | integralWidth); |
40 | } |
41 | |
42 | MlirType mlirQuantizedTypeGetExpressedType(MlirType type) { |
43 | return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)).getExpressedType()); |
44 | } |
45 | |
46 | unsigned mlirQuantizedTypeGetFlags(MlirType type) { |
47 | return cast<quant::QuantizedType>(Val: unwrap(c: type)).getFlags(); |
48 | } |
49 | |
50 | bool mlirQuantizedTypeIsSigned(MlirType type) { |
51 | return cast<quant::QuantizedType>(Val: unwrap(c: type)).isSigned(); |
52 | } |
53 | |
54 | MlirType mlirQuantizedTypeGetStorageType(MlirType type) { |
55 | return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageType()); |
56 | } |
57 | |
58 | int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) { |
59 | return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeMin(); |
60 | } |
61 | |
62 | int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) { |
63 | return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeMax(); |
64 | } |
65 | |
66 | unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) { |
67 | return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeIntegralWidth(); |
68 | } |
69 | |
70 | bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, |
71 | MlirType candidate) { |
72 | return cast<quant::QuantizedType>(Val: unwrap(c: type)) |
73 | .isCompatibleExpressedType(candidateExpressedType: unwrap(c: candidate)); |
74 | } |
75 | |
76 | MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { |
77 | return wrap(cpp: quant::QuantizedType::getQuantizedElementType(primitiveOrContainerType: unwrap(c: type))); |
78 | } |
79 | |
80 | MlirType mlirQuantizedTypeCastFromStorageType(MlirType type, |
81 | MlirType candidate) { |
82 | return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)) |
83 | .castFromStorageType(candidateType: unwrap(c: candidate))); |
84 | } |
85 | |
86 | MlirType mlirQuantizedTypeCastToStorageType(MlirType type) { |
87 | return wrap(cpp: quant::QuantizedType::castToStorageType( |
88 | quantizedType: cast<quant::QuantizedType>(Val: unwrap(c: type)))); |
89 | } |
90 | |
91 | MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type, |
92 | MlirType candidate) { |
93 | return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)) |
94 | .castFromExpressedType(candidateType: unwrap(c: candidate))); |
95 | } |
96 | |
97 | MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { |
98 | return wrap(cpp: quant::QuantizedType::castToExpressedType(quantizedType: unwrap(c: type))); |
99 | } |
100 | |
101 | MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, |
102 | MlirType candidate) { |
103 | return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)) |
104 | .castExpressedToStorageType(candidateType: unwrap(c: candidate))); |
105 | } |
106 | |
107 | //===---------------------------------------------------------------------===// |
108 | // AnyQuantizedType |
109 | //===---------------------------------------------------------------------===// |
110 | |
111 | bool mlirTypeIsAAnyQuantizedType(MlirType type) { |
112 | return isa<quant::AnyQuantizedType>(Val: unwrap(c: type)); |
113 | } |
114 | |
115 | MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, |
116 | MlirType expressedType, int64_t storageTypeMin, |
117 | int64_t storageTypeMax) { |
118 | return wrap(cpp: quant::AnyQuantizedType::get(flags, storageType: unwrap(c: storageType), |
119 | expressedType: unwrap(c: expressedType), |
120 | storageTypeMin, storageTypeMax)); |
121 | } |
122 | |
123 | //===---------------------------------------------------------------------===// |
124 | // UniformQuantizedType |
125 | //===---------------------------------------------------------------------===// |
126 | |
127 | bool mlirTypeIsAUniformQuantizedType(MlirType type) { |
128 | return isa<quant::UniformQuantizedType>(Val: unwrap(c: type)); |
129 | } |
130 | |
131 | MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, |
132 | MlirType expressedType, double scale, |
133 | int64_t zeroPoint, int64_t storageTypeMin, |
134 | int64_t storageTypeMax) { |
135 | return wrap(cpp: quant::UniformQuantizedType::get( |
136 | flags, storageType: unwrap(c: storageType), expressedType: unwrap(c: expressedType), scale, zeroPoint, |
137 | storageTypeMin, storageTypeMax)); |
138 | } |
139 | |
140 | double mlirUniformQuantizedTypeGetScale(MlirType type) { |
141 | return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).getScale(); |
142 | } |
143 | |
144 | int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) { |
145 | return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).getZeroPoint(); |
146 | } |
147 | |
148 | bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { |
149 | return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).isFixedPoint(); |
150 | } |
151 | |
152 | //===---------------------------------------------------------------------===// |
153 | // UniformQuantizedPerAxisType |
154 | //===---------------------------------------------------------------------===// |
155 | |
156 | bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { |
157 | return isa<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type)); |
158 | } |
159 | |
160 | MlirType mlirUniformQuantizedPerAxisTypeGet( |
161 | unsigned flags, MlirType storageType, MlirType expressedType, |
162 | intptr_t nDims, double *scales, int64_t *zeroPoints, |
163 | int32_t quantizedDimension, int64_t storageTypeMin, |
164 | int64_t storageTypeMax) { |
165 | return wrap(cpp: quant::UniformQuantizedPerAxisType::get( |
166 | flags, storageType: unwrap(c: storageType), expressedType: unwrap(c: expressedType), |
167 | scales: llvm::ArrayRef(scales, nDims), zeroPoints: llvm::ArrayRef(zeroPoints, nDims), |
168 | quantizedDimension, storageTypeMin, storageTypeMax)); |
169 | } |
170 | |
171 | intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) { |
172 | return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type)) |
173 | .getScales() |
174 | .size(); |
175 | } |
176 | |
177 | double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) { |
178 | return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type)) |
179 | .getScales()[pos]; |
180 | } |
181 | |
182 | int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, |
183 | intptr_t pos) { |
184 | return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type)) |
185 | .getZeroPoints()[pos]; |
186 | } |
187 | |
188 | int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) { |
189 | return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type)) |
190 | .getQuantizedDimension(); |
191 | } |
192 | |
193 | bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { |
194 | return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type)).isFixedPoint(); |
195 | } |
196 | |
197 | //===---------------------------------------------------------------------===// |
198 | // CalibratedQuantizedType |
199 | //===---------------------------------------------------------------------===// |
200 | |
201 | bool mlirTypeIsACalibratedQuantizedType(MlirType type) { |
202 | return isa<quant::CalibratedQuantizedType>(Val: unwrap(c: type)); |
203 | } |
204 | |
205 | MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, |
206 | double max) { |
207 | return wrap( |
208 | cpp: quant::CalibratedQuantizedType::get(expressedType: unwrap(c: expressedType), min, max)); |
209 | } |
210 | |
211 | double mlirCalibratedQuantizedTypeGetMin(MlirType type) { |
212 | return cast<quant::CalibratedQuantizedType>(Val: unwrap(c: type)).getMin(); |
213 | } |
214 | |
215 | double mlirCalibratedQuantizedTypeGetMax(MlirType type) { |
216 | return cast<quant::CalibratedQuantizedType>(Val: unwrap(c: type)).getMax(); |
217 | } |
218 | |