1 | //===- quant.c - Test of Quant dialect C API ------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
4 | // Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | // RUN: mlir-capi-quant-test 2>&1 | FileCheck %s |
11 | |
12 | #include "mlir-c/Dialect/Quant.h" |
13 | #include "mlir-c/BuiltinTypes.h" |
14 | #include "mlir-c/IR.h" |
15 | |
16 | #include <assert.h> |
17 | #include <inttypes.h> |
18 | #include <stdio.h> |
19 | #include <stdlib.h> |
20 | |
21 | // CHECK-LABEL: testTypeHierarchy |
22 | static void testTypeHierarchy(MlirContext ctx) { |
23 | fprintf(stderr, format: "testTypeHierarchy\n" ); |
24 | |
25 | MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8); |
26 | MlirType any = mlirTypeParseGet( |
27 | context: ctx, type: mlirStringRefCreateFromCString(str: "!quant.any<i8<-8:7>:f32>" )); |
28 | MlirType uniform = |
29 | mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString( |
30 | str: "!quant.uniform<i8<-8:7>:f32, 0.99872:127>" )); |
31 | MlirType perAxis = mlirTypeParseGet( |
32 | context: ctx, type: mlirStringRefCreateFromCString( |
33 | str: "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )); |
34 | MlirType calibrated = mlirTypeParseGet( |
35 | context: ctx, |
36 | type: mlirStringRefCreateFromCString(str: "!quant.calibrated<f32<-0.998:1.2321>>" )); |
37 | |
38 | // The parser itself is checked in C++ dialect tests. |
39 | assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType" ); |
40 | assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType" ); |
41 | assert(!mlirTypeIsNull(perAxis) && |
42 | "couldn't parse UniformQuantizedPerAxisType" ); |
43 | assert(!mlirTypeIsNull(calibrated) && |
44 | "couldn't parse CalibratedQuantizedType" ); |
45 | |
46 | // CHECK: i8 isa QuantizedType: 0 |
47 | fprintf(stderr, format: "i8 isa QuantizedType: %d\n" , mlirTypeIsAQuantizedType(type: i8)); |
48 | // CHECK: any isa QuantizedType: 1 |
49 | fprintf(stderr, format: "any isa QuantizedType: %d\n" , mlirTypeIsAQuantizedType(type: any)); |
50 | // CHECK: uniform isa QuantizedType: 1 |
51 | fprintf(stderr, format: "uniform isa QuantizedType: %d\n" , |
52 | mlirTypeIsAQuantizedType(type: uniform)); |
53 | // CHECK: perAxis isa QuantizedType: 1 |
54 | fprintf(stderr, format: "perAxis isa QuantizedType: %d\n" , |
55 | mlirTypeIsAQuantizedType(type: perAxis)); |
56 | // CHECK: calibrated isa QuantizedType: 1 |
57 | fprintf(stderr, format: "calibrated isa QuantizedType: %d\n" , |
58 | mlirTypeIsAQuantizedType(type: calibrated)); |
59 | |
60 | // CHECK: any isa AnyQuantizedType: 1 |
61 | fprintf(stderr, format: "any isa AnyQuantizedType: %d\n" , |
62 | mlirTypeIsAAnyQuantizedType(type: any)); |
63 | // CHECK: uniform isa UniformQuantizedType: 1 |
64 | fprintf(stderr, format: "uniform isa UniformQuantizedType: %d\n" , |
65 | mlirTypeIsAUniformQuantizedType(type: uniform)); |
66 | // CHECK: perAxis isa UniformQuantizedPerAxisType: 1 |
67 | fprintf(stderr, format: "perAxis isa UniformQuantizedPerAxisType: %d\n" , |
68 | mlirTypeIsAUniformQuantizedPerAxisType(type: perAxis)); |
69 | // CHECK: calibrated isa CalibratedQuantizedType: 1 |
70 | fprintf(stderr, format: "calibrated isa CalibratedQuantizedType: %d\n" , |
71 | mlirTypeIsACalibratedQuantizedType(type: calibrated)); |
72 | |
73 | // CHECK: perAxis isa UniformQuantizedType: 0 |
74 | fprintf(stderr, format: "perAxis isa UniformQuantizedType: %d\n" , |
75 | mlirTypeIsAUniformQuantizedType(type: perAxis)); |
76 | // CHECK: uniform isa CalibratedQuantizedType: 0 |
77 | fprintf(stderr, format: "uniform isa CalibratedQuantizedType: %d\n" , |
78 | mlirTypeIsACalibratedQuantizedType(type: uniform)); |
79 | fprintf(stderr, format: "\n" ); |
80 | } |
81 | |
82 | // CHECK-LABEL: testAnyQuantizedType |
83 | void testAnyQuantizedType(MlirContext ctx) { |
84 | fprintf(stderr, format: "testAnyQuantizedType\n" ); |
85 | |
86 | MlirType anyParsed = mlirTypeParseGet( |
87 | context: ctx, type: mlirStringRefCreateFromCString(str: "!quant.any<i8<-8:7>:f32>" )); |
88 | |
89 | MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8); |
90 | MlirType f32 = mlirF32TypeGet(ctx); |
91 | MlirType any = |
92 | mlirAnyQuantizedTypeGet(flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, storageTypeMin: -8, storageTypeMax: 7); |
93 | |
94 | // CHECK: flags: 1 |
95 | fprintf(stderr, format: "flags: %u\n" , mlirQuantizedTypeGetFlags(type: any)); |
96 | // CHECK: signed: 1 |
97 | fprintf(stderr, format: "signed: %u\n" , mlirQuantizedTypeIsSigned(type: any)); |
98 | // CHECK: storage type: i8 |
99 | fprintf(stderr, format: "storage type: " ); |
100 | mlirTypeDump(type: mlirQuantizedTypeGetStorageType(type: any)); |
101 | fprintf(stderr, format: "\n" ); |
102 | // CHECK: expressed type: f32 |
103 | fprintf(stderr, format: "expressed type: " ); |
104 | mlirTypeDump(type: mlirQuantizedTypeGetExpressedType(type: any)); |
105 | fprintf(stderr, format: "\n" ); |
106 | // CHECK: storage min: -8 |
107 | fprintf(stderr, format: "storage min: %" PRId64 "\n" , |
108 | mlirQuantizedTypeGetStorageTypeMin(type: any)); |
109 | // CHECK: storage max: 7 |
110 | fprintf(stderr, format: "storage max: %" PRId64 "\n" , |
111 | mlirQuantizedTypeGetStorageTypeMax(type: any)); |
112 | // CHECK: storage width: 8 |
113 | fprintf(stderr, format: "storage width: %u\n" , |
114 | mlirQuantizedTypeGetStorageTypeIntegralWidth(type: any)); |
115 | // CHECK: quantized element type: !quant.any<i8<-8:7>:f32> |
116 | fprintf(stderr, format: "quantized element type: " ); |
117 | mlirTypeDump(type: mlirQuantizedTypeGetQuantizedElementType(type: any)); |
118 | fprintf(stderr, format: "\n" ); |
119 | |
120 | // CHECK: equal: 1 |
121 | fprintf(stderr, format: "equal: %d\n" , mlirTypeEqual(t1: anyParsed, t2: any)); |
122 | // CHECK: !quant.any<i8<-8:7>:f32> |
123 | mlirTypeDump(type: any); |
124 | fprintf(stderr, format: "\n\n" ); |
125 | } |
126 | |
127 | // CHECK-LABEL: testUniformType |
128 | void testUniformType(MlirContext ctx) { |
129 | fprintf(stderr, format: "testUniformType\n" ); |
130 | |
131 | MlirType uniformParsed = |
132 | mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString( |
133 | str: "!quant.uniform<i8<-8:7>:f32, 0.99872:127>" )); |
134 | |
135 | MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8); |
136 | MlirType f32 = mlirF32TypeGet(ctx); |
137 | MlirType uniform = mlirUniformQuantizedTypeGet( |
138 | flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, scale: 0.99872, zeroPoint: 127, storageTypeMin: -8, storageTypeMax: 7); |
139 | |
140 | // CHECK: scale: 0.998720 |
141 | fprintf(stderr, format: "scale: %lf\n" , mlirUniformQuantizedTypeGetScale(type: uniform)); |
142 | // CHECK: zero point: 127 |
143 | fprintf(stderr, format: "zero point: %" PRId64 "\n" , |
144 | mlirUniformQuantizedTypeGetZeroPoint(type: uniform)); |
145 | // CHECK: fixed point: 0 |
146 | fprintf(stderr, format: "fixed point: %d\n" , |
147 | mlirUniformQuantizedTypeIsFixedPoint(type: uniform)); |
148 | |
149 | // CHECK: equal: 1 |
150 | fprintf(stderr, format: "equal: %d\n" , mlirTypeEqual(t1: uniform, t2: uniformParsed)); |
151 | // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127> |
152 | mlirTypeDump(type: uniform); |
153 | fprintf(stderr, format: "\n\n" ); |
154 | } |
155 | |
156 | // CHECK-LABEL: testUniformPerAxisType |
157 | void testUniformPerAxisType(MlirContext ctx) { |
158 | fprintf(stderr, format: "testUniformPerAxisType\n" ); |
159 | |
160 | MlirType perAxisParsed = mlirTypeParseGet( |
161 | context: ctx, type: mlirStringRefCreateFromCString( |
162 | str: "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>" )); |
163 | |
164 | MlirType i8 = mlirIntegerTypeGet(ctx, bitwidth: 8); |
165 | MlirType f32 = mlirF32TypeGet(ctx); |
166 | double scales[] = {200.0, 0.99872}; |
167 | int64_t zeroPoints[] = {0, 120}; |
168 | MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet( |
169 | flags: mlirQuantizedTypeGetSignedFlag(), storageType: i8, expressedType: f32, |
170 | /*nDims=*/2, scales, zeroPoints, |
171 | /*quantizedDimension=*/1, |
172 | storageTypeMin: mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true, |
173 | /*integralWidth=*/8), |
174 | storageTypeMax: mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true, |
175 | /*integralWidth=*/8)); |
176 | |
177 | // CHECK: num dims: 2 |
178 | fprintf(stderr, format: "num dims: %" PRIdPTR "\n" , |
179 | mlirUniformQuantizedPerAxisTypeGetNumDims(type: perAxis)); |
180 | // CHECK: scale 0: 200.000000 |
181 | fprintf(stderr, format: "scale 0: %lf\n" , |
182 | mlirUniformQuantizedPerAxisTypeGetScale(type: perAxis, pos: 0)); |
183 | // CHECK: scale 1: 0.998720 |
184 | fprintf(stderr, format: "scale 1: %lf\n" , |
185 | mlirUniformQuantizedPerAxisTypeGetScale(type: perAxis, pos: 1)); |
186 | // CHECK: zero point 0: 0 |
187 | fprintf(stderr, format: "zero point 0: %" PRId64 "\n" , |
188 | mlirUniformQuantizedPerAxisTypeGetZeroPoint(type: perAxis, pos: 0)); |
189 | // CHECK: zero point 1: 120 |
190 | fprintf(stderr, format: "zero point 1: %" PRId64 "\n" , |
191 | mlirUniformQuantizedPerAxisTypeGetZeroPoint(type: perAxis, pos: 1)); |
192 | // CHECK: quantized dim: 1 |
193 | fprintf(stderr, format: "quantized dim: %" PRId32 "\n" , |
194 | mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type: perAxis)); |
195 | // CHECK: fixed point: 0 |
196 | fprintf(stderr, format: "fixed point: %d\n" , |
197 | mlirUniformQuantizedPerAxisTypeIsFixedPoint(type: perAxis)); |
198 | |
199 | // CHECK: equal: 1 |
200 | fprintf(stderr, format: "equal: %d\n" , mlirTypeEqual(t1: perAxis, t2: perAxisParsed)); |
201 | // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}> |
202 | mlirTypeDump(type: perAxis); |
203 | fprintf(stderr, format: "\n\n" ); |
204 | } |
205 | |
206 | // CHECK-LABEL: testCalibratedType |
207 | void testCalibratedType(MlirContext ctx) { |
208 | fprintf(stderr, format: "testCalibratedType\n" ); |
209 | |
210 | MlirType calibratedParsed = mlirTypeParseGet( |
211 | context: ctx, |
212 | type: mlirStringRefCreateFromCString(str: "!quant.calibrated<f32<-0.998:1.2321>>" )); |
213 | |
214 | MlirType f32 = mlirF32TypeGet(ctx); |
215 | MlirType calibrated = mlirCalibratedQuantizedTypeGet(expressedType: f32, min: -0.998, max: 1.2321); |
216 | |
217 | // CHECK: min: -0.998000 |
218 | fprintf(stderr, format: "min: %lf\n" , mlirCalibratedQuantizedTypeGetMin(type: calibrated)); |
219 | // CHECK: max: 1.232100 |
220 | fprintf(stderr, format: "max: %lf\n" , mlirCalibratedQuantizedTypeGetMax(type: calibrated)); |
221 | |
222 | // CHECK: equal: 1 |
223 | fprintf(stderr, format: "equal: %d\n" , mlirTypeEqual(t1: calibrated, t2: calibratedParsed)); |
224 | // CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>> |
225 | mlirTypeDump(type: calibrated); |
226 | fprintf(stderr, format: "\n\n" ); |
227 | } |
228 | |
229 | int main(void) { |
230 | MlirContext ctx = mlirContextCreate(); |
231 | mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx); |
232 | testTypeHierarchy(ctx); |
233 | testAnyQuantizedType(ctx); |
234 | testUniformType(ctx); |
235 | testUniformPerAxisType(ctx); |
236 | testCalibratedType(ctx); |
237 | mlirContextDestroy(context: ctx); |
238 | return EXIT_SUCCESS; |
239 | } |
240 | |