//===- quant.c - Test of Quant dialect C API ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s

#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"

#include <assert.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

22// CHECK-LABEL: testTypeHierarchy
23static void testTypeHierarchy(MlirContext ctx) {
24 fprintf(stderr, format: "testTypeHierarchy\n");
25
26 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
27 MlirType any = mlirTypeParseGet(
28 context: ctx, type: mlirStringRefCreateFromCString(str: "!quant.any<i8<-8:7>:f32>"));
29 MlirType uniform =
30 mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(
31 str: "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
32 MlirType perAxis = mlirTypeParseGet(
33 context: ctx, type: mlirStringRefCreateFromCString(
34 str: "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
35 MlirType calibrated = mlirTypeParseGet(
36 context: ctx,
37 type: mlirStringRefCreateFromCString(str: "!quant.calibrated<f32<-0.998:1.2321>>"));
38
39 // The parser itself is checked in C++ dialect tests.
40 assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType");
41 assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType");
42 assert(!mlirTypeIsNull(perAxis) &&
43 "couldn't parse UniformQuantizedPerAxisType");
44 assert(!mlirTypeIsNull(calibrated) &&
45 "couldn't parse CalibratedQuantizedType");
46
47 // CHECK: i8 isa QuantizedType: 0
48 fprintf(stderr, format: "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(type: i8));
49 // CHECK: any isa QuantizedType: 1
50 fprintf(stderr, format: "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(type: any));
51 // CHECK: uniform isa QuantizedType: 1
52 fprintf(stderr, format: "uniform isa QuantizedType: %d\n",
53 mlirTypeIsAQuantizedType(type: uniform));
54 // CHECK: perAxis isa QuantizedType: 1
55 fprintf(stderr, format: "perAxis isa QuantizedType: %d\n",
56 mlirTypeIsAQuantizedType(type: perAxis));
57 // CHECK: calibrated isa QuantizedType: 1
58 fprintf(stderr, format: "calibrated isa QuantizedType: %d\n",
59 mlirTypeIsAQuantizedType(type: calibrated));
60
61 // CHECK: any isa AnyQuantizedType: 1
62 fprintf(stderr, format: "any isa AnyQuantizedType: %d\n",
63 mlirTypeIsAAnyQuantizedType(type: any));
64 // CHECK: uniform isa UniformQuantizedType: 1
65 fprintf(stderr, format: "uniform isa UniformQuantizedType: %d\n",
66 mlirTypeIsAUniformQuantizedType(type: uniform));
67 // CHECK: perAxis isa UniformQuantizedPerAxisType: 1
68 fprintf(stderr, format: "perAxis isa UniformQuantizedPerAxisType: %d\n",
69 mlirTypeIsAUniformQuantizedPerAxisType(type: perAxis));
70 // CHECK: calibrated isa CalibratedQuantizedType: 1
71 fprintf(stderr, format: "calibrated isa CalibratedQuantizedType: %d\n",
72 mlirTypeIsACalibratedQuantizedType(type: calibrated));
73
74 // CHECK: perAxis isa UniformQuantizedType: 0
75 fprintf(stderr, format: "perAxis isa UniformQuantizedType: %d\n",
76 mlirTypeIsAUniformQuantizedType(type: perAxis));
77 // CHECK: uniform isa CalibratedQuantizedType: 0
78 fprintf(stderr, format: "uniform isa CalibratedQuantizedType: %d\n",
79 mlirTypeIsACalibratedQuantizedType(type: uniform));
80 fprintf(stderr, format: "\n");
81}
82
83// CHECK-LABEL: testAnyQuantizedType
84void testAnyQuantizedType(MlirContext ctx) {
85 fprintf(stderr, format: "testAnyQuantizedType\n");
86
87 MlirType anyParsed = mlirTypeParseGet(
88 context: ctx, type: mlirStringRefCreateFromCString(str: "!quant.any<i8<-8:7>:f32>"));
89
90 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
91 MlirType f32 = mlirF32TypeGet(ctx);
92 MlirType any =
93 mlirAnyQuantizedTypeGet(flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, storageTypeMin: -8, storageTypeMax: 7);
94
95 // CHECK: flags: 1
96 fprintf(stderr, format: "flags: %u\n", mlirQuantizedTypeGetFlags(type: any));
97 // CHECK: signed: 1
98 fprintf(stderr, format: "signed: %u\n", mlirQuantizedTypeIsSigned(type: any));
99 // CHECK: storage type: i8
100 fprintf(stderr, format: "storage type: ");
101 mlirTypeDump(type: mlirQuantizedTypeGetStorageType(type: any));
102 fprintf(stderr, format: "\n");
103 // CHECK: expressed type: f32
104 fprintf(stderr, format: "expressed type: ");
105 mlirTypeDump(type: mlirQuantizedTypeGetExpressedType(type: any));
106 fprintf(stderr, format: "\n");
107 // CHECK: storage min: -8
108 fprintf(stderr, format: "storage min: %" PRId64 "\n",
109 mlirQuantizedTypeGetStorageTypeMin(type: any));
110 // CHECK: storage max: 7
111 fprintf(stderr, format: "storage max: %" PRId64 "\n",
112 mlirQuantizedTypeGetStorageTypeMax(type: any));
113 // CHECK: storage width: 8
114 fprintf(stderr, format: "storage width: %u\n",
115 mlirQuantizedTypeGetStorageTypeIntegralWidth(type: any));
116 // CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
117 fprintf(stderr, format: "quantized element type: ");
118 mlirTypeDump(type: mlirQuantizedTypeGetQuantizedElementType(type: any));
119 fprintf(stderr, format: "\n");
120
121 // CHECK: equal: 1
122 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: anyParsed, t2: any));
123 // CHECK: !quant.any<i8<-8:7>:f32>
124 mlirTypeDump(type: any);
125 fprintf(stderr, format: "\n\n");
126}
127
128// CHECK-LABEL: testUniformType
129void testUniformType(MlirContext ctx) {
130 fprintf(stderr, format: "testUniformType\n");
131
132 MlirType uniformParsed =
133 mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(
134 str: "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
135
136 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
137 MlirType f32 = mlirF32TypeGet(ctx);
138 MlirType uniform = mlirUniformQuantizedTypeGet(
139 flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, scale: 0.99872, zeroPoint: 127, storageTypeMin: -8, storageTypeMax: 7);
140
141 // CHECK: scale: 0.998720
142 fprintf(stderr, format: "scale: %lf\n", mlirUniformQuantizedTypeGetScale(type: uniform));
143 // CHECK: zero point: 127
144 fprintf(stderr, format: "zero point: %" PRId64 "\n",
145 mlirUniformQuantizedTypeGetZeroPoint(type: uniform));
146 // CHECK: fixed point: 0
147 fprintf(stderr, format: "fixed point: %d\n",
148 mlirUniformQuantizedTypeIsFixedPoint(type: uniform));
149
150 // CHECK: equal: 1
151 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: uniform, t2: uniformParsed));
152 // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
153 mlirTypeDump(type: uniform);
154 fprintf(stderr, format: "\n\n");
155}
156
157// CHECK-LABEL: testUniformPerAxisType
158void testUniformPerAxisType(MlirContext ctx) {
159 fprintf(stderr, format: "testUniformPerAxisType\n");
160
161 MlirType perAxisParsed = mlirTypeParseGet(
162 context: ctx, type: mlirStringRefCreateFromCString(
163 str: "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
164
165 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
166 MlirType f32 = mlirF32TypeGet(ctx);
167 double scales[] = {200.0, 0.99872};
168 int64_t zeroPoints[] = {0, 120};
169 MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(
170 flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32,
171 /*nDims=*/2, scales, zeroPoints,
172 /*quantizedDimension=*/1,
173 storageTypeMin: mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
174 /*integralWidth=*/8),
175 storageTypeMax: mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
176 /*integralWidth=*/8));
177
178 // CHECK: num dims: 2
179 fprintf(stderr, format: "num dims: %" PRIdPTR "\n",
180 mlirUniformQuantizedPerAxisTypeGetNumDims(type: perAxis));
181 // CHECK: scale 0: 200.000000
182 fprintf(stderr, format: "scale 0: %lf\n",
183 mlirUniformQuantizedPerAxisTypeGetScale(type: perAxis, pos: 0));
184 // CHECK: scale 1: 0.998720
185 fprintf(stderr, format: "scale 1: %lf\n",
186 mlirUniformQuantizedPerAxisTypeGetScale(type: perAxis, pos: 1));
187 // CHECK: zero point 0: 0
188 fprintf(stderr, format: "zero point 0: %" PRId64 "\n",
189 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type: perAxis, pos: 0));
190 // CHECK: zero point 1: 120
191 fprintf(stderr, format: "zero point 1: %" PRId64 "\n",
192 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type: perAxis, pos: 1));
193 // CHECK: quantized dim: 1
194 fprintf(stderr, format: "quantized dim: %" PRId32 "\n",
195 mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type: perAxis));
196 // CHECK: fixed point: 0
197 fprintf(stderr, format: "fixed point: %d\n",
198 mlirUniformQuantizedPerAxisTypeIsFixedPoint(type: perAxis));
199
200 // CHECK: equal: 1
201 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: perAxis, t2: perAxisParsed));
202 // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
203 mlirTypeDump(type: perAxis);
204 fprintf(stderr, format: "\n\n");
205}
206
207// CHECK-LABEL: testUniformSubChannelType
208void testUniformSubChannelType(MlirContext ctx) {
209 fprintf(stderr, format: "testUniformSubChannelType\n");
210
211 MlirType subChannelParsed =
212 mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(
213 str: "!quant.uniform<i8:f32:{0:1, 1:2}, "
214 "{{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>"));
215
216 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
217 MlirType f32 = mlirF32TypeGet(ctx);
218
219 // block-size information
220 int32_t quantizedDimensions[] = {0, 1};
221 int64_t blockSizes[] = {1, 2};
222 int64_t numBlockSizes = 2;
223
224 // quantization parameters
225 int64_t quantParamShape[] = {2, 2};
226 int64_t quantParamRank = 2;
227 int64_t numQuantizationParams = 4;
228 MlirAttribute scales[] = {mlirFloatAttrDoubleGet(ctx, type: f32, value: 2.0),
229 mlirFloatAttrDoubleGet(ctx, type: f32, value: 3.0),
230 mlirFloatAttrDoubleGet(ctx, type: f32, value: 4.0),
231 mlirFloatAttrDoubleGet(ctx, type: f32, value: 5.0)};
232 MlirAttribute zeroPoints[] = {
233 mlirIntegerAttrGet(type: i8, value: 10), mlirIntegerAttrGet(type: i8, value: 20),
234 mlirIntegerAttrGet(type: i8, value: 30), mlirIntegerAttrGet(type: i8, value: 40)};
235
236 MlirType scalesType =
237 mlirRankedTensorTypeGet(rank: quantParamRank, shape: quantParamShape, elementType: f32,
238 /*encoding=*/mlirAttributeGetNull());
239 MlirType zeroPointsType = mlirRankedTensorTypeGet(
240 rank: quantParamRank, shape: quantParamShape, elementType: i8, /*encoding=*/mlirAttributeGetNull());
241 MlirAttribute denseScalesAttr =
242 mlirDenseElementsAttrGet(shapedType: scalesType, numElements: numQuantizationParams, elements: scales);
243 MlirAttribute denseZeroPointsAttr = mlirDenseElementsAttrGet(
244 shapedType: zeroPointsType, numElements: numQuantizationParams, elements: zeroPoints);
245
246 MlirType subChannel = mlirUniformQuantizedSubChannelTypeGet(
247 flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, scalesAttr: denseScalesAttr,
248 zeroPointsAttr: denseZeroPointsAttr, blockSizeInfoLength: numBlockSizes, quantizedDimensions, blockSizes,
249 storageTypeMin: mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
250 /*integralWidth=*/8),
251 storageTypeMax: mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
252 /*integralWidth=*/8));
253
254 MlirAttribute arrayScalesAttr =
255 mlirArrayAttrGet(ctx, numElements: numQuantizationParams, elements: scales);
256 MlirAttribute arrayZeroPointsAttr =
257 mlirArrayAttrGet(ctx, numElements: numQuantizationParams, elements: zeroPoints);
258 MlirType illegalSubChannel = mlirUniformQuantizedSubChannelTypeGet(
259 flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, scalesAttr: arrayScalesAttr,
260 zeroPointsAttr: arrayZeroPointsAttr, blockSizeInfoLength: numBlockSizes, quantizedDimensions, blockSizes,
261 storageTypeMin: mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
262 /*integralWidth=*/8),
263 storageTypeMax: mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
264 /*integralWidth=*/8));
265
266 // CHECK: is null sub-channel type: 1
267 fprintf(stderr, format: "is null sub-channel type: %d\n",
268 mlirTypeIsNull(type: illegalSubChannel));
269
270 // CHECK: num dims: 2
271 fprintf(stderr, format: "num dims: %" PRIdPTR "\n",
272 mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type: subChannel));
273
274 // CHECK: axis-block-size-pair[0]: 0:1
275 fprintf(
276 stderr, format: "axis-block-size-pair[0]: %" PRId32 ":%" PRId64 "\n",
277 mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type: subChannel, pos: 0),
278 mlirUniformQuantizedSubChannelTypeGetBlockSize(type: subChannel, pos: 0));
279
280 // CHECK: axis-block-size-pair[1]: 1:2
281 fprintf(
282 stderr, format: "axis-block-size-pair[1]: %" PRId32 ":%" PRId64 "\n",
283 mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type: subChannel, pos: 1),
284 mlirUniformQuantizedSubChannelTypeGetBlockSize(type: subChannel, pos: 1));
285
286 denseScalesAttr = mlirUniformQuantizedSubChannelTypeGetScales(type: subChannel);
287 denseZeroPointsAttr =
288 mlirUniformQuantizedSubChannelTypeGetZeroPoints(type: subChannel);
289 scalesType = mlirAttributeGetType(attribute: denseScalesAttr);
290 zeroPointsType = mlirAttributeGetType(attribute: denseZeroPointsAttr);
291
292 // CHECK: tensor<2x2xf32>
293 mlirTypeDump(type: scalesType);
294 // CHECK: tensor<2x2xi8>
295 mlirTypeDump(type: zeroPointsType);
296
297 // CHECK: number of quantization parameters: 4
298 fprintf(stderr, format: "number of quantization parameters: %" PRId64 "\n",
299 mlirElementsAttrGetNumElements(attr: denseScalesAttr));
300
301 // CHECK: quantization-parameter[0]: 2.000000:10
302 fprintf(stderr, format: "quantization-parameter[0]: %lf:%" PRId8 "\n",
303 mlirDenseElementsAttrGetFloatValue(attr: denseScalesAttr, pos: 0),
304 mlirDenseElementsAttrGetInt8Value(attr: denseZeroPointsAttr, pos: 0));
305
306 // CHECK: quantization-parameter[1]: 3.000000:20
307 fprintf(stderr, format: "quantization-parameter[1]: %lf:%" PRId8 "\n",
308 mlirDenseElementsAttrGetFloatValue(attr: denseScalesAttr, pos: 1),
309 mlirDenseElementsAttrGetInt8Value(attr: denseZeroPointsAttr, pos: 1));
310
311 // CHECK: quantization-parameter[2]: 4.000000:30
312 fprintf(stderr, format: "quantization-parameter[2]: %lf:%" PRId8 "\n",
313 mlirDenseElementsAttrGetFloatValue(attr: denseScalesAttr, pos: 2),
314 mlirDenseElementsAttrGetInt8Value(attr: denseZeroPointsAttr, pos: 2));
315
316 // CHECK: quantization-parameter[3]: 5.000000:40
317 fprintf(stderr, format: "quantization-parameter[3]: %lf:%" PRId8 "\n",
318 mlirDenseElementsAttrGetFloatValue(attr: denseScalesAttr, pos: 3),
319 mlirDenseElementsAttrGetInt8Value(attr: denseZeroPointsAttr, pos: 3));
320
321 // CHECK: equal: 1
322 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: subChannel, t2: subChannelParsed));
323
324 // CHECK: !quant.uniform<i8:f32:{0:1, 1:2},
325 // {{.*}}2.000000e+00:10, 3.000000e+00:20},
326 // {4.000000e+00:30, 5.000000e+00:40{{.*}}}}>
327 mlirTypeDump(type: subChannel);
328 fprintf(stderr, format: "\n\n");
329}
330
331// CHECK-LABEL: testCalibratedType
332void testCalibratedType(MlirContext ctx) {
333 fprintf(stderr, format: "testCalibratedType\n");
334
335 MlirType calibratedParsed = mlirTypeParseGet(
336 context: ctx,
337 type: mlirStringRefCreateFromCString(str: "!quant.calibrated<f32<-0.998:1.2321>>"));
338
339 MlirType f32 = mlirF32TypeGet(ctx);
340 MlirType calibrated = mlirCalibratedQuantizedTypeGet(expressedType: f32, min: -0.998, max: 1.2321);
341
342 // CHECK: min: -0.998000
343 fprintf(stderr, format: "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(type: calibrated));
344 // CHECK: max: 1.232100
345 fprintf(stderr, format: "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(type: calibrated));
346
347 // CHECK: equal: 1
348 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: calibrated, t2: calibratedParsed));
349 // CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
350 mlirTypeDump(type: calibrated);
351 fprintf(stderr, format: "\n\n");
352}
353
354int main(void) {
355 MlirContext ctx = mlirContextCreate();
356 mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);
357 testTypeHierarchy(ctx);
358 testAnyQuantizedType(ctx);
359 testUniformType(ctx);
360 testUniformPerAxisType(ctx);
361 testUniformSubChannelType(ctx);
362 testCalibratedType(ctx);
363 mlirContextDestroy(context: ctx);
364 return EXIT_SUCCESS;
365}
366

// Web-page footer captured together with this listing (not part of the test):
//   Provided by KDAB
//   Privacy Policy
//   Update your C++ knowledge – Modern C++11/14/17 Training
//   Find out more
//   source code of mlir/test/CAPI/quant.c