//===- Quant.cpp - C Interface for Quant dialect --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
#include "mlir-c/Dialect/Quant.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"

#include <cassert>

using namespace mlir;
15
// Defines the C API dialect-registration entry point for the `quant` dialect
// namespace, bound to the C++ quant::QuantizationDialect (macro provided by
// mlir/CAPI/Registration.h).
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
17
//===---------------------------------------------------------------------===//
// QuantizedType
//===---------------------------------------------------------------------===//
21
22bool mlirTypeIsAQuantizedType(MlirType type) {
23 return isa<quant::QuantizedType>(Val: unwrap(c: type));
24}
25
26unsigned mlirQuantizedTypeGetSignedFlag() {
27 return quant::QuantizationFlags::Signed;
28}
29
30int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
31 unsigned integralWidth) {
32 return quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
33 integralWidth);
34}
35
36int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
37 unsigned integralWidth) {
38 return quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
39 integralWidth);
40}
41
42MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
43 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)).getExpressedType());
44}
45
46unsigned mlirQuantizedTypeGetFlags(MlirType type) {
47 return cast<quant::QuantizedType>(Val: unwrap(c: type)).getFlags();
48}
49
50bool mlirQuantizedTypeIsSigned(MlirType type) {
51 return cast<quant::QuantizedType>(Val: unwrap(c: type)).isSigned();
52}
53
54MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
55 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageType());
56}
57
58int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
59 return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeMin();
60}
61
62int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
63 return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeMax();
64}
65
66unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
67 return cast<quant::QuantizedType>(Val: unwrap(c: type)).getStorageTypeIntegralWidth();
68}
69
70bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
71 MlirType candidate) {
72 return cast<quant::QuantizedType>(Val: unwrap(c: type))
73 .isCompatibleExpressedType(candidateExpressedType: unwrap(c: candidate));
74}
75
76MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
77 return wrap(cpp: quant::QuantizedType::getQuantizedElementType(primitiveOrContainerType: unwrap(c: type)));
78}
79
80MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
81 MlirType candidate) {
82 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type))
83 .castFromStorageType(candidateType: unwrap(c: candidate)));
84}
85
86MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
87 return wrap(cpp: quant::QuantizedType::castToStorageType(
88 quantizedType: cast<quant::QuantizedType>(Val: unwrap(c: type))));
89}
90
91MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
92 MlirType candidate) {
93 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type))
94 .castFromExpressedType(candidateType: unwrap(c: candidate)));
95}
96
97MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
98 return wrap(cpp: quant::QuantizedType::castToExpressedType(quantizedType: unwrap(c: type)));
99}
100
101MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
102 MlirType candidate) {
103 return wrap(cpp: cast<quant::QuantizedType>(Val: unwrap(c: type))
104 .castExpressedToStorageType(candidateType: unwrap(c: candidate)));
105}
106
//===---------------------------------------------------------------------===//
// AnyQuantizedType
//===---------------------------------------------------------------------===//
110
111bool mlirTypeIsAAnyQuantizedType(MlirType type) {
112 return isa<quant::AnyQuantizedType>(Val: unwrap(c: type));
113}
114
115MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
116 MlirType expressedType, int64_t storageTypeMin,
117 int64_t storageTypeMax) {
118 return wrap(cpp: quant::AnyQuantizedType::get(flags, storageType: unwrap(c: storageType),
119 expressedType: unwrap(c: expressedType),
120 storageTypeMin, storageTypeMax));
121}
122
//===---------------------------------------------------------------------===//
// UniformQuantizedType
//===---------------------------------------------------------------------===//
126
127bool mlirTypeIsAUniformQuantizedType(MlirType type) {
128 return isa<quant::UniformQuantizedType>(Val: unwrap(c: type));
129}
130
131MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
132 MlirType expressedType, double scale,
133 int64_t zeroPoint, int64_t storageTypeMin,
134 int64_t storageTypeMax) {
135 return wrap(cpp: quant::UniformQuantizedType::get(
136 flags, storageType: unwrap(c: storageType), expressedType: unwrap(c: expressedType), scale, zeroPoint,
137 storageTypeMin, storageTypeMax));
138}
139
140double mlirUniformQuantizedTypeGetScale(MlirType type) {
141 return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).getScale();
142}
143
144int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
145 return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).getZeroPoint();
146}
147
148bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
149 return cast<quant::UniformQuantizedType>(Val: unwrap(c: type)).isFixedPoint();
150}
151
//===---------------------------------------------------------------------===//
// UniformQuantizedPerAxisType
//===---------------------------------------------------------------------===//
155
156bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
157 return isa<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type));
158}
159
160MlirType mlirUniformQuantizedPerAxisTypeGet(
161 unsigned flags, MlirType storageType, MlirType expressedType,
162 intptr_t nDims, double *scales, int64_t *zeroPoints,
163 int32_t quantizedDimension, int64_t storageTypeMin,
164 int64_t storageTypeMax) {
165 return wrap(cpp: quant::UniformQuantizedPerAxisType::get(
166 flags, storageType: unwrap(c: storageType), expressedType: unwrap(c: expressedType),
167 scales: llvm::ArrayRef(scales, nDims), zeroPoints: llvm::ArrayRef(zeroPoints, nDims),
168 quantizedDimension, storageTypeMin, storageTypeMax));
169}
170
171intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
172 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type))
173 .getScales()
174 .size();
175}
176
177double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
178 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type))
179 .getScales()[pos];
180}
181
182int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
183 intptr_t pos) {
184 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type))
185 .getZeroPoints()[pos];
186}
187
188int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
189 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type))
190 .getQuantizedDimension();
191}
192
193bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
194 return cast<quant::UniformQuantizedPerAxisType>(Val: unwrap(c: type)).isFixedPoint();
195}
196
//===---------------------------------------------------------------------===//
// CalibratedQuantizedType
//===---------------------------------------------------------------------===//
200
201bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
202 return isa<quant::CalibratedQuantizedType>(Val: unwrap(c: type));
203}
204
205MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
206 double max) {
207 return wrap(
208 cpp: quant::CalibratedQuantizedType::get(expressedType: unwrap(c: expressedType), min, max));
209}
210
211double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
212 return cast<quant::CalibratedQuantizedType>(Val: unwrap(c: type)).getMin();
213}
214
215double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
216 return cast<quant::CalibratedQuantizedType>(Val: unwrap(c: type)).getMax();
217}
218

// Source: mlir/lib/CAPI/Dialect/Quant.cpp