1//===- quant.c - Test of Quant dialect C API ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
11
12#include "mlir-c/Dialect/Quant.h"
13#include "mlir-c/BuiltinTypes.h"
14#include "mlir-c/IR.h"
15
16#include <assert.h>
17#include <inttypes.h>
18#include <stdio.h>
19#include <stdlib.h>
20
21// CHECK-LABEL: testTypeHierarchy
22static void testTypeHierarchy(MlirContext ctx) {
23 fprintf(stderr, format: "testTypeHierarchy\n");
24
25 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
26 MlirType any = mlirTypeParseGet(
27 context: ctx, type: mlirStringRefCreateFromCString(str: "!quant.any<i8<-8:7>:f32>"));
28 MlirType uniform =
29 mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(
30 str: "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
31 MlirType perAxis = mlirTypeParseGet(
32 context: ctx, type: mlirStringRefCreateFromCString(
33 str: "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
34 MlirType calibrated = mlirTypeParseGet(
35 context: ctx,
36 type: mlirStringRefCreateFromCString(str: "!quant.calibrated<f32<-0.998:1.2321>>"));
37
38 // The parser itself is checked in C++ dialect tests.
39 assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType");
40 assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType");
41 assert(!mlirTypeIsNull(perAxis) &&
42 "couldn't parse UniformQuantizedPerAxisType");
43 assert(!mlirTypeIsNull(calibrated) &&
44 "couldn't parse CalibratedQuantizedType");
45
46 // CHECK: i8 isa QuantizedType: 0
47 fprintf(stderr, format: "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(type: i8));
48 // CHECK: any isa QuantizedType: 1
49 fprintf(stderr, format: "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(type: any));
50 // CHECK: uniform isa QuantizedType: 1
51 fprintf(stderr, format: "uniform isa QuantizedType: %d\n",
52 mlirTypeIsAQuantizedType(type: uniform));
53 // CHECK: perAxis isa QuantizedType: 1
54 fprintf(stderr, format: "perAxis isa QuantizedType: %d\n",
55 mlirTypeIsAQuantizedType(type: perAxis));
56 // CHECK: calibrated isa QuantizedType: 1
57 fprintf(stderr, format: "calibrated isa QuantizedType: %d\n",
58 mlirTypeIsAQuantizedType(type: calibrated));
59
60 // CHECK: any isa AnyQuantizedType: 1
61 fprintf(stderr, format: "any isa AnyQuantizedType: %d\n",
62 mlirTypeIsAAnyQuantizedType(type: any));
63 // CHECK: uniform isa UniformQuantizedType: 1
64 fprintf(stderr, format: "uniform isa UniformQuantizedType: %d\n",
65 mlirTypeIsAUniformQuantizedType(type: uniform));
66 // CHECK: perAxis isa UniformQuantizedPerAxisType: 1
67 fprintf(stderr, format: "perAxis isa UniformQuantizedPerAxisType: %d\n",
68 mlirTypeIsAUniformQuantizedPerAxisType(type: perAxis));
69 // CHECK: calibrated isa CalibratedQuantizedType: 1
70 fprintf(stderr, format: "calibrated isa CalibratedQuantizedType: %d\n",
71 mlirTypeIsACalibratedQuantizedType(type: calibrated));
72
73 // CHECK: perAxis isa UniformQuantizedType: 0
74 fprintf(stderr, format: "perAxis isa UniformQuantizedType: %d\n",
75 mlirTypeIsAUniformQuantizedType(type: perAxis));
76 // CHECK: uniform isa CalibratedQuantizedType: 0
77 fprintf(stderr, format: "uniform isa CalibratedQuantizedType: %d\n",
78 mlirTypeIsACalibratedQuantizedType(type: uniform));
79 fprintf(stderr, format: "\n");
80}
81
82// CHECK-LABEL: testAnyQuantizedType
83void testAnyQuantizedType(MlirContext ctx) {
84 fprintf(stderr, format: "testAnyQuantizedType\n");
85
86 MlirType anyParsed = mlirTypeParseGet(
87 context: ctx, type: mlirStringRefCreateFromCString(str: "!quant.any<i8<-8:7>:f32>"));
88
89 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
90 MlirType f32 = mlirF32TypeGet(ctx);
91 MlirType any =
92 mlirAnyQuantizedTypeGet(flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, storageTypeMin: -8, storageTypeMax: 7);
93
94 // CHECK: flags: 1
95 fprintf(stderr, format: "flags: %u\n", mlirQuantizedTypeGetFlags(type: any));
96 // CHECK: signed: 1
97 fprintf(stderr, format: "signed: %u\n", mlirQuantizedTypeIsSigned(type: any));
98 // CHECK: storage type: i8
99 fprintf(stderr, format: "storage type: ");
100 mlirTypeDump(type: mlirQuantizedTypeGetStorageType(type: any));
101 fprintf(stderr, format: "\n");
102 // CHECK: expressed type: f32
103 fprintf(stderr, format: "expressed type: ");
104 mlirTypeDump(type: mlirQuantizedTypeGetExpressedType(type: any));
105 fprintf(stderr, format: "\n");
106 // CHECK: storage min: -8
107 fprintf(stderr, format: "storage min: %" PRId64 "\n",
108 mlirQuantizedTypeGetStorageTypeMin(type: any));
109 // CHECK: storage max: 7
110 fprintf(stderr, format: "storage max: %" PRId64 "\n",
111 mlirQuantizedTypeGetStorageTypeMax(type: any));
112 // CHECK: storage width: 8
113 fprintf(stderr, format: "storage width: %u\n",
114 mlirQuantizedTypeGetStorageTypeIntegralWidth(type: any));
115 // CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
116 fprintf(stderr, format: "quantized element type: ");
117 mlirTypeDump(type: mlirQuantizedTypeGetQuantizedElementType(type: any));
118 fprintf(stderr, format: "\n");
119
120 // CHECK: equal: 1
121 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: anyParsed, t2: any));
122 // CHECK: !quant.any<i8<-8:7>:f32>
123 mlirTypeDump(type: any);
124 fprintf(stderr, format: "\n\n");
125}
126
127// CHECK-LABEL: testUniformType
128void testUniformType(MlirContext ctx) {
129 fprintf(stderr, format: "testUniformType\n");
130
131 MlirType uniformParsed =
132 mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(
133 str: "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
134
135 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
136 MlirType f32 = mlirF32TypeGet(ctx);
137 MlirType uniform = mlirUniformQuantizedTypeGet(
138 flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, scale: 0.99872, zeroPoint: 127, storageTypeMin: -8, storageTypeMax: 7);
139
140 // CHECK: scale: 0.998720
141 fprintf(stderr, format: "scale: %lf\n", mlirUniformQuantizedTypeGetScale(type: uniform));
142 // CHECK: zero point: 127
143 fprintf(stderr, format: "zero point: %" PRId64 "\n",
144 mlirUniformQuantizedTypeGetZeroPoint(type: uniform));
145 // CHECK: fixed point: 0
146 fprintf(stderr, format: "fixed point: %d\n",
147 mlirUniformQuantizedTypeIsFixedPoint(type: uniform));
148
149 // CHECK: equal: 1
150 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: uniform, t2: uniformParsed));
151 // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
152 mlirTypeDump(type: uniform);
153 fprintf(stderr, format: "\n\n");
154}
155
156// CHECK-LABEL: testUniformPerAxisType
157void testUniformPerAxisType(MlirContext ctx) {
158 fprintf(stderr, format: "testUniformPerAxisType\n");
159
160 MlirType perAxisParsed = mlirTypeParseGet(
161 context: ctx, type: mlirStringRefCreateFromCString(
162 str: "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
163
164 MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8);
165 MlirType f32 = mlirF32TypeGet(ctx);
166 double scales[] = {200.0, 0.99872};
167 int64_t zeroPoints[] = {0, 120};
168 MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(
169 flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32,
170 /*nDims=*/2, scales, zeroPoints,
171 /*quantizedDimension=*/1,
172 storageTypeMin: mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
173 /*integralWidth=*/8),
174 storageTypeMax: mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
175 /*integralWidth=*/8));
176
177 // CHECK: num dims: 2
178 fprintf(stderr, format: "num dims: %" PRIdPTR "\n",
179 mlirUniformQuantizedPerAxisTypeGetNumDims(type: perAxis));
180 // CHECK: scale 0: 200.000000
181 fprintf(stderr, format: "scale 0: %lf\n",
182 mlirUniformQuantizedPerAxisTypeGetScale(type: perAxis, pos: 0));
183 // CHECK: scale 1: 0.998720
184 fprintf(stderr, format: "scale 1: %lf\n",
185 mlirUniformQuantizedPerAxisTypeGetScale(type: perAxis, pos: 1));
186 // CHECK: zero point 0: 0
187 fprintf(stderr, format: "zero point 0: %" PRId64 "\n",
188 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type: perAxis, pos: 0));
189 // CHECK: zero point 1: 120
190 fprintf(stderr, format: "zero point 1: %" PRId64 "\n",
191 mlirUniformQuantizedPerAxisTypeGetZeroPoint(type: perAxis, pos: 1));
192 // CHECK: quantized dim: 1
193 fprintf(stderr, format: "quantized dim: %" PRId32 "\n",
194 mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type: perAxis));
195 // CHECK: fixed point: 0
196 fprintf(stderr, format: "fixed point: %d\n",
197 mlirUniformQuantizedPerAxisTypeIsFixedPoint(type: perAxis));
198
199 // CHECK: equal: 1
200 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: perAxis, t2: perAxisParsed));
201 // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
202 mlirTypeDump(type: perAxis);
203 fprintf(stderr, format: "\n\n");
204}
205
206// CHECK-LABEL: testCalibratedType
207void testCalibratedType(MlirContext ctx) {
208 fprintf(stderr, format: "testCalibratedType\n");
209
210 MlirType calibratedParsed = mlirTypeParseGet(
211 context: ctx,
212 type: mlirStringRefCreateFromCString(str: "!quant.calibrated<f32<-0.998:1.2321>>"));
213
214 MlirType f32 = mlirF32TypeGet(ctx);
215 MlirType calibrated = mlirCalibratedQuantizedTypeGet(expressedType: f32, min: -0.998, max: 1.2321);
216
217 // CHECK: min: -0.998000
218 fprintf(stderr, format: "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(type: calibrated));
219 // CHECK: max: 1.232100
220 fprintf(stderr, format: "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(type: calibrated));
221
222 // CHECK: equal: 1
223 fprintf(stderr, format: "equal: %d\n", mlirTypeEqual(t1: calibrated, t2: calibratedParsed));
224 // CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
225 mlirTypeDump(type: calibrated);
226 fprintf(stderr, format: "\n\n");
227}
228
229int main(void) {
230 MlirContext ctx = mlirContextCreate();
231 mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);
232 testTypeHierarchy(ctx);
233 testAnyQuantizedType(ctx);
234 testUniformType(ctx);
235 testUniformPerAxisType(ctx);
236 testCalibratedType(ctx);
237 mlirContextDestroy(context: ctx);
238 return EXIT_SUCCESS;
239}
240

source code of mlir/test/CAPI/quant.c