1//===- Quant.cpp - C Interface for Quant dialect --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/Dialect/Quant.h"
10#include "mlir-c/BuiltinAttributes.h"
11#include "mlir/CAPI/Registration.h"
12#include "mlir/Dialect/Quant/IR/Quant.h"
13#include "mlir/Dialect/Quant/IR/QuantTypes.h"
14
15using namespace mlir;
16
// Expands (macro from mlir/CAPI/Registration.h) to the C API dialect-handle
// definitions for the `quant` dialect, backed by quant::QuantDialect.
// NOTE(review): presumably this provides mlirGetDialectHandle__quant__ —
// confirm against the macro definition.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect)
18
19//===---------------------------------------------------------------------===//
20// QuantizedType
21//===---------------------------------------------------------------------===//
22
23bool mlirTypeIsAQuantizedType(MlirType type) {
24 return isa<quant::QuantizedType>(Val: unwrap(c: type));
25}
26
27unsigned mlirQuantizedTypeGetSignedFlag() {
28 return quant::QuantizationFlags::Signed;
29}
30
31int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
32 unsigned integralWidth) {
33 return quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
34 integralWidth);
35}
36
37int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
38 unsigned integralWidth) {
39 return quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
40 integralWidth);
41}
42
43MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
44 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)).getExpressedType());
45}
46
47unsigned mlirQuantizedTypeGetFlags(MlirType type) {
48 return cast<quant::QuantizedType>(Val: unwrap(c: type)).getFlags();
49}
50
51bool mlirQuantizedTypeIsSigned(MlirType type) {
52 return cast<quant::QuantizedType>(Val: unwrap(c: type)).isSigned();
53}
54
55MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
56 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageType());
57}
58
59int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
60 return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeMin();
61}
62
63int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
64 return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeMax();
65}
66
67unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
68 return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeIntegralWidth();
69}
70
71bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
72 MlirType candidate) {
73 return cast<quant::QuantizedType>(Val: unwrap(c: type))
74 .isCompatibleExpressedType(candidateExpressedType: unwrap(c: candidate));
75}
76
77MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
78 return wrap(cpp: quant::QuantizedType::getQuantizedElementType(primitiveOrContainerType: unwrap(c: type)));
79}
80
81MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
82 MlirType candidate) {
83 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type))
84 .castFromStorageType(candidateType: unwrap(c: candidate)));
85}
86
87MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
88 return wrap(cpp: quant::QuantizedType::castToStorageType(
89 quantizedType: cast<quant::QuantizedType>(Val: unwrap(c: type))));
90}
91
92MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
93 MlirType candidate) {
94 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type))
95 .castFromExpressedType(candidateType: unwrap(c: candidate)));
96}
97
98MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
99 return wrap(cpp: quant::QuantizedType::castToExpressedType(quantizedType: unwrap(c: type)));
100}
101
102MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
103 MlirType candidate) {
104 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type))
105 .castExpressedToStorageType(candidateType: unwrap(c: candidate)));
106}
107
108//===---------------------------------------------------------------------===//
109// AnyQuantizedType
110//===---------------------------------------------------------------------===//
111
112bool mlirTypeIsAAnyQuantizedType(MlirType type) {
113 return isa<quant::AnyQuantizedType>(Val: unwrap(c: type));
114}
115
116MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
117 MlirType expressedType, int64_t storageTypeMin,
118 int64_t storageTypeMax) {
119 return wrap(cpp: quant::AnyQuantizedType::get(flags, storageType: unwrap(c: storageType),
120 expressedType: unwrap(c: expressedType),
121 storageTypeMin, storageTypeMax));
122}
123
124//===---------------------------------------------------------------------===//
125// UniformQuantizedType
126//===---------------------------------------------------------------------===//
127
128bool mlirTypeIsAUniformQuantizedType(MlirType type) {
129 return isa<quant::UniformQuantizedType>(Val: unwrap(c: type));
130}
131
132MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
133 MlirType expressedType, double scale,
134 int64_t zeroPoint, int64_t storageTypeMin,
135 int64_t storageTypeMax) {
136 return wrap(cpp: quant::UniformQuantizedType::get(
137 flags, storageType: unwrap(c: storageType), expressedType: unwrap(c: expressedType), scale, zeroPoint,
138 storageTypeMin, storageTypeMax));
139}
140
141double mlirUniformQuantizedTypeGetScale(MlirType type) {
142 return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).getScale();
143}
144
145int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
146 return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).getZeroPoint();
147}
148
149bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
150 return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).isFixedPoint();
151}
152
153//===---------------------------------------------------------------------===//
154// UniformQuantizedPerAxisType
155//===---------------------------------------------------------------------===//
156
157bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
158 return isa<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type));
159}
160
161MlirType mlirUniformQuantizedPerAxisTypeGet(
162 unsigned flags, MlirType storageType, MlirType expressedType,
163 intptr_t nDims, double *scales, int64_t *zeroPoints,
164 int32_t quantizedDimension, int64_t storageTypeMin,
165 int64_t storageTypeMax) {
166 return wrap(cpp: quant::UniformQuantizedPerAxisType::get(
167 flags, storageType: unwrap(c: storageType), expressedType: unwrap(c: expressedType),
168 scales: llvm::ArrayRef(scales, nDims), zeroPoints: llvm::ArrayRef(zeroPoints, nDims),
169 quantizedDimension, storageTypeMin, storageTypeMax));
170}
171
172intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
173 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type))
174 .getScales()
175 .size();
176}
177
178double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
179 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type))
180 .getScales()[pos];
181}
182
183int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
184 intptr_t pos) {
185 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type))
186 .getZeroPoints()[pos];
187}
188
189int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
190 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type))
191 .getQuantizedDimension();
192}
193
194bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
195 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type)).isFixedPoint();
196}
197
198//===---------------------------------------------------------------------===//
199// UniformQuantizedSubChannelType
200//===---------------------------------------------------------------------===//
201
202bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
203 return isa<quant::UniformQuantizedSubChannelType>(Val: unwrap(c: type));
204}
205
206MlirType mlirUniformQuantizedSubChannelTypeGet(
207 unsigned flags, MlirType storageType, MlirType expressedType,
208 MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
209 int32_t *quantizedDimensions, int64_t *blockSizes, int64_t storageTypeMin,
210 int64_t storageTypeMax) {
211 auto scales = dyn_cast<mlir::DenseElementsAttr>(Val: unwrap(c: scalesAttr));
212 auto zeroPoints = dyn_cast<mlir::DenseElementsAttr>(Val: unwrap(c: zeroPointsAttr));
213
214 if (!scales || !zeroPoints) {
215 return {};
216 }
217
218 return wrap(cpp: quant::UniformQuantizedSubChannelType::get(
219 flags, storageType: unwrap(c: storageType), expressedType: unwrap(c: expressedType), scales, zeroPoints,
220 quantizedDimensions: llvm::ArrayRef<int32_t>(quantizedDimensions, nDims),
221 blockSizes: llvm::ArrayRef<int64_t>(blockSizes, nDims), storageTypeMin,
222 storageTypeMax));
223}
224
225intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) {
226 return cast<quant::UniformQuantizedSubChannelType>(Val: unwrap(c: type))
227 .getBlockSizes()
228 .size();
229}
230
231int32_t mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
232 intptr_t pos) {
233 return cast<quant::UniformQuantizedSubChannelType>(Val: unwrap(c: type))
234 .getQuantizedDimensions()[pos];
235}
236
237int64_t mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type,
238 intptr_t pos) {
239 return cast<quant::UniformQuantizedSubChannelType>(Val: unwrap(c: type))
240 .getBlockSizes()[pos];
241}
242
243MlirAttribute mlirUniformQuantizedSubChannelTypeGetScales(MlirType type) {
244 return wrap(
245 cpp: cast<quant::UniformQuantizedSubChannelType>(Val: unwrap(c: type)).getScales());
246}
247
248MlirAttribute mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type) {
249 return wrap(cpp: cast<quant::UniformQuantizedSubChannelType>(Val: unwrap(c: type))
250 .getZeroPoints());
251}
252
253//===---------------------------------------------------------------------===//
254// CalibratedQuantizedType
255//===---------------------------------------------------------------------===//
256
257bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
258 return isa<quant::CalibratedQuantizedType>(Val: unwrap(c: type));
259}
260
261MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
262 double max) {
263 return wrap(
264 cpp: quant::CalibratedQuantizedType::get(expressedType: unwrap(c: expressedType), min, max));
265}
266
267double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
268 return cast<quant::CalibratedQuantizedType>(Val: unwrap(c: type)).getMin();
269}
270
271double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
272 return cast<quant::CalibratedQuantizedType>(Val: unwrap(c: type)).getMax();
273}
274

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/CAPI/Dialect/Quant.cpp