1 | //===- ir.c - Simple test of C APIs ---------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
4 | // Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s |
11 | */ |
12 | |
13 | #include "mlir-c/IR.h" |
14 | #include "mlir-c/AffineExpr.h" |
15 | #include "mlir-c/AffineMap.h" |
16 | #include "mlir-c/BuiltinAttributes.h" |
17 | #include "mlir-c/BuiltinTypes.h" |
18 | #include "mlir-c/Diagnostics.h" |
19 | #include "mlir-c/Dialect/Func.h" |
20 | #include "mlir-c/IntegerSet.h" |
21 | #include "mlir-c/RegisterEverything.h" |
22 | #include "mlir-c/Support.h" |
23 | |
24 | #include <assert.h> |
25 | #include <inttypes.h> |
26 | #include <math.h> |
27 | #include <stdio.h> |
28 | #include <stdlib.h> |
29 | #include <string.h> |
30 | |
31 | static void registerAllUpstreamDialects(MlirContext ctx) { |
32 | MlirDialectRegistry registry = mlirDialectRegistryCreate(); |
33 | mlirRegisterAllDialects(registry); |
34 | mlirContextAppendDialectRegistry(ctx, registry); |
35 | mlirDialectRegistryDestroy(registry); |
36 | } |
37 | |
/// User data attached to a resource; `name` is what the deletion report
/// prints.
struct ResourceDeleteUserData {
  const char *name;
};

static struct ResourceDeleteUserData resourceI64BlobUserData = {
    "resource_i64_blob"};

/// Reports the deletion of a resource by printing
/// "reportResourceDelete: <name>" to stderr.
///
/// `userData` must point to a `struct ResourceDeleteUserData`. The signature
/// suggests this is used as a resource-blob deleter callback; confirm against
/// its registration site, which lies outside this view. The blob's `data`,
/// `size` and `align` are not needed for the report.
static void reportResourceDelete(void *userData, const void *data, size_t size,
                                 size_t align) {
  // Silence unused-parameter warnings, as printToStderr does.
  (void)data;
  (void)size;
  (void)align;
  const struct ResourceDeleteUserData *info =
      (const struct ResourceDeleteUserData *)userData;
  fprintf(stderr, "reportResourceDelete: %s\n", info->name);
}
48 | |
49 | void populateLoopBody(MlirContext ctx, MlirBlock loopBody, |
50 | MlirLocation location, MlirBlock funcBody) { |
51 | MlirValue iv = mlirBlockGetArgument(block: loopBody, pos: 0); |
52 | MlirValue funcArg0 = mlirBlockGetArgument(block: funcBody, pos: 0); |
53 | MlirValue funcArg1 = mlirBlockGetArgument(block: funcBody, pos: 1); |
54 | MlirType f32Type = |
55 | mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(str: "f32" )); |
56 | |
57 | MlirOperationState loadLHSState = mlirOperationStateGet( |
58 | name: mlirStringRefCreateFromCString(str: "memref.load" ), loc: location); |
59 | MlirValue loadLHSOperands[] = {funcArg0, iv}; |
60 | mlirOperationStateAddOperands(state: &loadLHSState, n: 2, operands: loadLHSOperands); |
61 | mlirOperationStateAddResults(state: &loadLHSState, n: 1, results: &f32Type); |
62 | MlirOperation loadLHS = mlirOperationCreate(state: &loadLHSState); |
63 | mlirBlockAppendOwnedOperation(block: loopBody, operation: loadLHS); |
64 | |
65 | MlirOperationState loadRHSState = mlirOperationStateGet( |
66 | name: mlirStringRefCreateFromCString(str: "memref.load" ), loc: location); |
67 | MlirValue loadRHSOperands[] = {funcArg1, iv}; |
68 | mlirOperationStateAddOperands(state: &loadRHSState, n: 2, operands: loadRHSOperands); |
69 | mlirOperationStateAddResults(state: &loadRHSState, n: 1, results: &f32Type); |
70 | MlirOperation loadRHS = mlirOperationCreate(state: &loadRHSState); |
71 | mlirBlockAppendOwnedOperation(block: loopBody, operation: loadRHS); |
72 | |
73 | MlirOperationState addState = mlirOperationStateGet( |
74 | name: mlirStringRefCreateFromCString(str: "arith.addf" ), loc: location); |
75 | MlirValue addOperands[] = {mlirOperationGetResult(op: loadLHS, pos: 0), |
76 | mlirOperationGetResult(op: loadRHS, pos: 0)}; |
77 | mlirOperationStateAddOperands(state: &addState, n: 2, operands: addOperands); |
78 | mlirOperationStateAddResults(state: &addState, n: 1, results: &f32Type); |
79 | MlirOperation add = mlirOperationCreate(state: &addState); |
80 | mlirBlockAppendOwnedOperation(block: loopBody, operation: add); |
81 | |
82 | MlirOperationState storeState = mlirOperationStateGet( |
83 | name: mlirStringRefCreateFromCString(str: "memref.store" ), loc: location); |
84 | MlirValue storeOperands[] = {mlirOperationGetResult(op: add, pos: 0), funcArg0, iv}; |
85 | mlirOperationStateAddOperands(state: &storeState, n: 3, operands: storeOperands); |
86 | MlirOperation store = mlirOperationCreate(state: &storeState); |
87 | mlirBlockAppendOwnedOperation(block: loopBody, operation: store); |
88 | |
89 | MlirOperationState yieldState = mlirOperationStateGet( |
90 | name: mlirStringRefCreateFromCString(str: "scf.yield" ), loc: location); |
91 | MlirOperation yield = mlirOperationCreate(state: &yieldState); |
92 | mlirBlockAppendOwnedOperation(block: loopBody, operation: yield); |
93 | } |
94 | |
95 | MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) { |
96 | MlirModule moduleOp = mlirModuleCreateEmpty(location); |
97 | MlirBlock moduleBody = mlirModuleGetBody(module: moduleOp); |
98 | |
99 | MlirType memrefType = |
100 | mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(str: "memref<?xf32>" )); |
101 | MlirType funcBodyArgTypes[] = {memrefType, memrefType}; |
102 | MlirLocation funcBodyArgLocs[] = {location, location}; |
103 | MlirRegion funcBodyRegion = mlirRegionCreate(); |
104 | MlirBlock funcBody = |
105 | mlirBlockCreate(nArgs: sizeof(funcBodyArgTypes) / sizeof(MlirType), |
106 | args: funcBodyArgTypes, locs: funcBodyArgLocs); |
107 | mlirRegionAppendOwnedBlock(region: funcBodyRegion, block: funcBody); |
108 | |
109 | MlirAttribute funcTypeAttr = mlirAttributeParseGet( |
110 | context: ctx, |
111 | attr: mlirStringRefCreateFromCString(str: "(memref<?xf32>, memref<?xf32>) -> ()" )); |
112 | MlirAttribute funcNameAttr = |
113 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "\"add\"" )); |
114 | MlirNamedAttribute funcAttrs[] = { |
115 | mlirNamedAttributeGet( |
116 | name: mlirIdentifierGet(context: ctx, |
117 | str: mlirStringRefCreateFromCString(str: "function_type" )), |
118 | attr: funcTypeAttr), |
119 | mlirNamedAttributeGet( |
120 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "sym_name" )), |
121 | attr: funcNameAttr)}; |
122 | MlirOperationState funcState = mlirOperationStateGet( |
123 | name: mlirStringRefCreateFromCString(str: "func.func" ), loc: location); |
124 | mlirOperationStateAddAttributes(state: &funcState, n: 2, attributes: funcAttrs); |
125 | mlirOperationStateAddOwnedRegions(state: &funcState, n: 1, regions: &funcBodyRegion); |
126 | MlirOperation func = mlirOperationCreate(state: &funcState); |
127 | mlirBlockInsertOwnedOperation(block: moduleBody, pos: 0, operation: func); |
128 | |
129 | MlirType indexType = |
130 | mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(str: "index" )); |
131 | MlirAttribute indexZeroLiteral = |
132 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "0 : index" )); |
133 | MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( |
134 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value" )), |
135 | attr: indexZeroLiteral); |
136 | MlirOperationState constZeroState = mlirOperationStateGet( |
137 | name: mlirStringRefCreateFromCString(str: "arith.constant" ), loc: location); |
138 | mlirOperationStateAddResults(state: &constZeroState, n: 1, results: &indexType); |
139 | mlirOperationStateAddAttributes(state: &constZeroState, n: 1, attributes: &indexZeroValueAttr); |
140 | MlirOperation constZero = mlirOperationCreate(state: &constZeroState); |
141 | mlirBlockAppendOwnedOperation(block: funcBody, operation: constZero); |
142 | |
143 | MlirValue funcArg0 = mlirBlockGetArgument(block: funcBody, pos: 0); |
144 | MlirValue constZeroValue = mlirOperationGetResult(op: constZero, pos: 0); |
145 | MlirValue dimOperands[] = {funcArg0, constZeroValue}; |
146 | MlirOperationState dimState = mlirOperationStateGet( |
147 | name: mlirStringRefCreateFromCString(str: "memref.dim" ), loc: location); |
148 | mlirOperationStateAddOperands(state: &dimState, n: 2, operands: dimOperands); |
149 | mlirOperationStateAddResults(state: &dimState, n: 1, results: &indexType); |
150 | MlirOperation dim = mlirOperationCreate(state: &dimState); |
151 | mlirBlockAppendOwnedOperation(block: funcBody, operation: dim); |
152 | |
153 | MlirRegion loopBodyRegion = mlirRegionCreate(); |
154 | MlirBlock loopBody = mlirBlockCreate(nArgs: 0, NULL, NULL); |
155 | mlirBlockAddArgument(block: loopBody, type: indexType, loc: location); |
156 | mlirRegionAppendOwnedBlock(region: loopBodyRegion, block: loopBody); |
157 | |
158 | MlirAttribute indexOneLiteral = |
159 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "1 : index" )); |
160 | MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet( |
161 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value" )), |
162 | attr: indexOneLiteral); |
163 | MlirOperationState constOneState = mlirOperationStateGet( |
164 | name: mlirStringRefCreateFromCString(str: "arith.constant" ), loc: location); |
165 | mlirOperationStateAddResults(state: &constOneState, n: 1, results: &indexType); |
166 | mlirOperationStateAddAttributes(state: &constOneState, n: 1, attributes: &indexOneValueAttr); |
167 | MlirOperation constOne = mlirOperationCreate(state: &constOneState); |
168 | mlirBlockAppendOwnedOperation(block: funcBody, operation: constOne); |
169 | |
170 | MlirValue dimValue = mlirOperationGetResult(op: dim, pos: 0); |
171 | MlirValue constOneValue = mlirOperationGetResult(op: constOne, pos: 0); |
172 | MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue}; |
173 | MlirOperationState loopState = mlirOperationStateGet( |
174 | name: mlirStringRefCreateFromCString(str: "scf.for" ), loc: location); |
175 | mlirOperationStateAddOperands(state: &loopState, n: 3, operands: loopOperands); |
176 | mlirOperationStateAddOwnedRegions(state: &loopState, n: 1, regions: &loopBodyRegion); |
177 | MlirOperation loop = mlirOperationCreate(state: &loopState); |
178 | mlirBlockAppendOwnedOperation(block: funcBody, operation: loop); |
179 | |
180 | populateLoopBody(ctx, loopBody, location, funcBody); |
181 | |
182 | MlirOperationState retState = mlirOperationStateGet( |
183 | name: mlirStringRefCreateFromCString(str: "func.return" ), loc: location); |
184 | MlirOperation ret = mlirOperationCreate(state: &retState); |
185 | mlirBlockAppendOwnedOperation(block: funcBody, operation: ret); |
186 | |
187 | MlirOperation module = mlirModuleGetOperation(module: moduleOp); |
188 | mlirOperationDump(op: module); |
189 | // clang-format off |
190 | // CHECK: module { |
191 | // CHECK: func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) { |
192 | // CHECK: %[[C0:.*]] = arith.constant 0 : index |
193 | // CHECK: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32> |
194 | // CHECK: %[[C1:.*]] = arith.constant 1 : index |
195 | // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] { |
196 | // CHECK: %[[LHS:.*]] = memref.load %[[ARG0]][%[[I]]] : memref<?xf32> |
197 | // CHECK: %[[RHS:.*]] = memref.load %[[ARG1]][%[[I]]] : memref<?xf32> |
198 | // CHECK: %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 |
199 | // CHECK: memref.store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32> |
200 | // CHECK: } |
201 | // CHECK: return |
202 | // CHECK: } |
203 | // CHECK: } |
204 | // clang-format on |
205 | |
206 | return moduleOp; |
207 | } |
208 | |
209 | struct OpListNode { |
210 | MlirOperation op; |
211 | struct OpListNode *next; |
212 | }; |
213 | typedef struct OpListNode OpListNode; |
214 | |
/// Counters accumulated by collectStats while walking an operation tree.
typedef struct ModuleStats {
  unsigned numOperations;
  unsigned numAttributes;
  unsigned numBlocks;
  unsigned numRegions;
  unsigned numValues; // expected to equal numBlockArguments + numOpResults
  unsigned numBlockArguments;
  unsigned numOpResults;
} ModuleStats;
225 | |
226 | int collectStatsSingle(OpListNode *head, ModuleStats *stats) { |
227 | MlirOperation operation = head->op; |
228 | stats->numOperations += 1; |
229 | stats->numValues += mlirOperationGetNumResults(op: operation); |
230 | stats->numAttributes += mlirOperationGetNumAttributes(op: operation); |
231 | |
232 | unsigned numRegions = mlirOperationGetNumRegions(op: operation); |
233 | |
234 | stats->numRegions += numRegions; |
235 | |
236 | intptr_t numResults = mlirOperationGetNumResults(op: operation); |
237 | for (intptr_t i = 0; i < numResults; ++i) { |
238 | MlirValue result = mlirOperationGetResult(op: operation, pos: i); |
239 | if (!mlirValueIsAOpResult(value: result)) |
240 | return 1; |
241 | if (mlirValueIsABlockArgument(value: result)) |
242 | return 2; |
243 | if (!mlirOperationEqual(op: operation, other: mlirOpResultGetOwner(value: result))) |
244 | return 3; |
245 | if (i != mlirOpResultGetResultNumber(value: result)) |
246 | return 4; |
247 | ++stats->numOpResults; |
248 | } |
249 | |
250 | MlirRegion region = mlirOperationGetFirstRegion(op: operation); |
251 | while (!mlirRegionIsNull(region)) { |
252 | for (MlirBlock block = mlirRegionGetFirstBlock(region); |
253 | !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) { |
254 | ++stats->numBlocks; |
255 | intptr_t numArgs = mlirBlockGetNumArguments(block); |
256 | stats->numValues += numArgs; |
257 | for (intptr_t j = 0; j < numArgs; ++j) { |
258 | MlirValue arg = mlirBlockGetArgument(block, pos: j); |
259 | if (!mlirValueIsABlockArgument(value: arg)) |
260 | return 5; |
261 | if (mlirValueIsAOpResult(value: arg)) |
262 | return 6; |
263 | if (!mlirBlockEqual(block, other: mlirBlockArgumentGetOwner(value: arg))) |
264 | return 7; |
265 | if (j != mlirBlockArgumentGetArgNumber(value: arg)) |
266 | return 8; |
267 | ++stats->numBlockArguments; |
268 | } |
269 | |
270 | for (MlirOperation child = mlirBlockGetFirstOperation(block); |
271 | !mlirOperationIsNull(op: child); |
272 | child = mlirOperationGetNextInBlock(op: child)) { |
273 | OpListNode *node = malloc(size: sizeof(OpListNode)); |
274 | node->op = child; |
275 | node->next = head->next; |
276 | head->next = node; |
277 | } |
278 | } |
279 | region = mlirRegionGetNextInOperation(region); |
280 | } |
281 | return 0; |
282 | } |
283 | |
284 | int collectStats(MlirOperation operation) { |
285 | OpListNode *head = malloc(size: sizeof(OpListNode)); |
286 | head->op = operation; |
287 | head->next = NULL; |
288 | |
289 | ModuleStats stats; |
290 | stats.numOperations = 0; |
291 | stats.numAttributes = 0; |
292 | stats.numBlocks = 0; |
293 | stats.numRegions = 0; |
294 | stats.numValues = 0; |
295 | stats.numBlockArguments = 0; |
296 | stats.numOpResults = 0; |
297 | |
298 | do { |
299 | int retval = collectStatsSingle(head, stats: &stats); |
300 | if (retval) { |
301 | free(ptr: head); |
302 | return retval; |
303 | } |
304 | OpListNode *next = head->next; |
305 | free(ptr: head); |
306 | head = next; |
307 | } while (head); |
308 | |
309 | if (stats.numValues != stats.numBlockArguments + stats.numOpResults) |
310 | return 100; |
311 | |
312 | fprintf(stderr, format: "@stats\n" ); |
313 | fprintf(stderr, format: "Number of operations: %u\n" , stats.numOperations); |
314 | fprintf(stderr, format: "Number of attributes: %u\n" , stats.numAttributes); |
315 | fprintf(stderr, format: "Number of blocks: %u\n" , stats.numBlocks); |
316 | fprintf(stderr, format: "Number of regions: %u\n" , stats.numRegions); |
317 | fprintf(stderr, format: "Number of values: %u\n" , stats.numValues); |
318 | fprintf(stderr, format: "Number of block arguments: %u\n" , stats.numBlockArguments); |
319 | fprintf(stderr, format: "Number of op results: %u\n" , stats.numOpResults); |
320 | // clang-format off |
321 | // CHECK-LABEL: @stats |
322 | // CHECK: Number of operations: 12 |
323 | // CHECK: Number of attributes: 5 |
324 | // CHECK: Number of blocks: 3 |
325 | // CHECK: Number of regions: 3 |
326 | // CHECK: Number of values: 9 |
327 | // CHECK: Number of block arguments: 3 |
328 | // CHECK: Number of op results: 6 |
329 | // clang-format on |
330 | return 0; |
331 | } |
332 | |
333 | static void printToStderr(MlirStringRef str, void *userData) { |
334 | (void)userData; |
335 | fwrite(ptr: str.data, size: 1, n: str.length, stderr); |
336 | } |
337 | |
338 | static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { |
339 | // Assuming we are given a module, go to the first operation of the first |
340 | // function. |
341 | MlirRegion region = mlirOperationGetRegion(op: operation, pos: 0); |
342 | MlirBlock block = mlirRegionGetFirstBlock(region); |
343 | operation = mlirBlockGetFirstOperation(block); |
344 | region = mlirOperationGetRegion(op: operation, pos: 0); |
345 | MlirOperation parentOperation = operation; |
346 | block = mlirRegionGetFirstBlock(region); |
347 | operation = mlirBlockGetFirstOperation(block); |
348 | assert(mlirModuleIsNull(mlirModuleFromOperation(operation))); |
349 | |
350 | // Verify that parent operation and block report correctly. |
351 | // CHECK: Parent operation eq: 1 |
352 | fprintf(stderr, format: "Parent operation eq: %d\n" , |
353 | mlirOperationEqual(op: mlirOperationGetParentOperation(op: operation), |
354 | other: parentOperation)); |
355 | // CHECK: Block eq: 1 |
356 | fprintf(stderr, format: "Block eq: %d\n" , |
357 | mlirBlockEqual(block: mlirOperationGetBlock(op: operation), other: block)); |
358 | // CHECK: Block parent operation eq: 1 |
359 | fprintf( |
360 | stderr, format: "Block parent operation eq: %d\n" , |
361 | mlirOperationEqual(op: mlirBlockGetParentOperation(block), other: parentOperation)); |
362 | // CHECK: Block parent region eq: 1 |
363 | fprintf(stderr, format: "Block parent region eq: %d\n" , |
364 | mlirRegionEqual(region: mlirBlockGetParentRegion(block), other: region)); |
365 | |
366 | // In the module we created, the first operation of the first function is |
367 | // an "memref.dim", which has an attribute and a single result that we can |
368 | // use to test the printing mechanism. |
369 | mlirBlockPrint(block, callback: printToStderr, NULL); |
370 | fprintf(stderr, format: "\n" ); |
371 | fprintf(stderr, format: "First operation: " ); |
372 | mlirOperationPrint(op: operation, callback: printToStderr, NULL); |
373 | fprintf(stderr, format: "\n" ); |
374 | // clang-format off |
375 | // CHECK: %[[C0:.*]] = arith.constant 0 : index |
376 | // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32> |
377 | // CHECK: %[[C1:.*]] = arith.constant 1 : index |
378 | // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] { |
379 | // CHECK: %[[LHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32> |
380 | // CHECK: %[[RHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32> |
381 | // CHECK: %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 |
382 | // CHECK: memref.store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32> |
383 | // CHECK: } |
384 | // CHECK: return |
385 | // CHECK: First operation: {{.*}} = arith.constant 0 : index |
386 | // clang-format on |
387 | |
388 | // Get the operation name and print it. |
389 | MlirIdentifier ident = mlirOperationGetName(op: operation); |
390 | MlirStringRef identStr = mlirIdentifierStr(ident); |
391 | fprintf(stderr, format: "Operation name: '" ); |
392 | for (size_t i = 0; i < identStr.length; ++i) |
393 | fputc(c: identStr.data[i], stderr); |
394 | fprintf(stderr, format: "'\n" ); |
395 | // CHECK: Operation name: 'arith.constant' |
396 | |
397 | // Get the identifier again and verify equal. |
398 | MlirIdentifier identAgain = mlirIdentifierGet(context: ctx, str: identStr); |
399 | fprintf(stderr, format: "Identifier equal: %d\n" , |
400 | mlirIdentifierEqual(ident, other: identAgain)); |
401 | // CHECK: Identifier equal: 1 |
402 | |
403 | // Get the block terminator and print it. |
404 | MlirOperation terminator = mlirBlockGetTerminator(block); |
405 | fprintf(stderr, format: "Terminator: " ); |
406 | mlirOperationPrint(op: terminator, callback: printToStderr, NULL); |
407 | fprintf(stderr, format: "\n" ); |
408 | // CHECK: Terminator: func.return |
409 | |
410 | // Get the attribute by name. |
411 | bool hasValueAttr = mlirOperationHasInherentAttributeByName( |
412 | op: operation, name: mlirStringRefCreateFromCString(str: "value" )); |
413 | if (hasValueAttr) |
414 | // CHECK: Has attr "value" |
415 | fprintf(stderr, format: "Has attr \"value\"" ); |
416 | |
417 | MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName( |
418 | op: operation, name: mlirStringRefCreateFromCString(str: "value" )); |
419 | fprintf(stderr, format: "Get attr \"value\": " ); |
420 | mlirAttributePrint(attr: valueAttr0, callback: printToStderr, NULL); |
421 | fprintf(stderr, format: "\n" ); |
422 | // CHECK: Get attr "value": 0 : index |
423 | |
424 | // Get a non-existing attribute and assert that it is null (sanity). |
425 | fprintf(stderr, format: "does_not_exist is null: %d\n" , |
426 | mlirAttributeIsNull(attr: mlirOperationGetDiscardableAttributeByName( |
427 | op: operation, name: mlirStringRefCreateFromCString(str: "does_not_exist" )))); |
428 | // CHECK: does_not_exist is null: 1 |
429 | |
430 | // Get result 0 and its type. |
431 | MlirValue value = mlirOperationGetResult(op: operation, pos: 0); |
432 | fprintf(stderr, format: "Result 0: " ); |
433 | mlirValuePrint(value, callback: printToStderr, NULL); |
434 | fprintf(stderr, format: "\n" ); |
435 | fprintf(stderr, format: "Value is null: %d\n" , mlirValueIsNull(value)); |
436 | // CHECK: Result 0: {{.*}} = arith.constant 0 : index |
437 | // CHECK: Value is null: 0 |
438 | |
439 | MlirType type = mlirValueGetType(value); |
440 | fprintf(stderr, format: "Result 0 type: " ); |
441 | mlirTypePrint(type, callback: printToStderr, NULL); |
442 | fprintf(stderr, format: "\n" ); |
443 | // CHECK: Result 0 type: index |
444 | |
445 | // Set a discardable attribute. |
446 | mlirOperationSetDiscardableAttributeByName( |
447 | op: operation, name: mlirStringRefCreateFromCString(str: "custom_attr" ), |
448 | attr: mlirBoolAttrGet(ctx, value: 1)); |
449 | fprintf(stderr, format: "Op with set attr: " ); |
450 | mlirOperationPrint(op: operation, callback: printToStderr, NULL); |
451 | fprintf(stderr, format: "\n" ); |
452 | // CHECK: Op with set attr: {{.*}} {custom_attr = true} |
453 | |
454 | // Remove the attribute. |
455 | fprintf(stderr, format: "Remove attr: %d\n" , |
456 | mlirOperationRemoveDiscardableAttributeByName( |
457 | op: operation, name: mlirStringRefCreateFromCString(str: "custom_attr" ))); |
458 | fprintf(stderr, format: "Remove attr again: %d\n" , |
459 | mlirOperationRemoveDiscardableAttributeByName( |
460 | op: operation, name: mlirStringRefCreateFromCString(str: "custom_attr" ))); |
461 | fprintf(stderr, format: "Removed attr is null: %d\n" , |
462 | mlirAttributeIsNull(attr: mlirOperationGetDiscardableAttributeByName( |
463 | op: operation, name: mlirStringRefCreateFromCString(str: "custom_attr" )))); |
464 | // CHECK: Remove attr: 1 |
465 | // CHECK: Remove attr again: 0 |
466 | // CHECK: Removed attr is null: 1 |
467 | |
468 | // Add a large attribute to verify printing flags. |
469 | int64_t eltsShape[] = {4}; |
470 | int32_t eltsData[] = {1, 2, 3, 4}; |
471 | mlirOperationSetDiscardableAttributeByName( |
472 | op: operation, name: mlirStringRefCreateFromCString(str: "elts" ), |
473 | attr: mlirDenseElementsAttrInt32Get( |
474 | shapedType: mlirRankedTensorTypeGet(rank: 1, shape: eltsShape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 32), |
475 | encoding: mlirAttributeGetNull()), |
476 | numElements: 4, elements: eltsData)); |
477 | MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); |
478 | mlirOpPrintingFlagsElideLargeElementsAttrs(flags, largeElementLimit: 2); |
479 | mlirOpPrintingFlagsPrintGenericOpForm(flags); |
480 | mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/1, /*prettyForm=*/0); |
481 | mlirOpPrintingFlagsUseLocalScope(flags); |
482 | fprintf(stderr, format: "Op print with all flags: " ); |
483 | mlirOperationPrintWithFlags(op: operation, flags, callback: printToStderr, NULL); |
484 | fprintf(stderr, format: "\n" ); |
485 | fprintf(stderr, format: "Op print with state: " ); |
486 | MlirAsmState state = mlirAsmStateCreateForOperation(op: parentOperation, flags); |
487 | mlirOperationPrintWithState(op: operation, state, callback: printToStderr, NULL); |
488 | fprintf(stderr, format: "\n" ); |
489 | // clang-format off |
490 | // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown) |
491 | // clang-format on |
492 | |
493 | fprintf(stderr, format: "With state: |" ); |
494 | mlirValuePrintAsOperand(value, state, callback: printToStderr, NULL); |
495 | // CHECK: With state: |%0| |
496 | fprintf(stderr, format: "|\n" ); |
497 | mlirAsmStateDestroy(state); |
498 | |
499 | mlirOpPrintingFlagsDestroy(flags); |
500 | } |
501 | |
502 | static int constructAndTraverseIr(MlirContext ctx) { |
503 | MlirLocation location = mlirLocationUnknownGet(context: ctx); |
504 | |
505 | MlirModule moduleOp = makeAndDumpAdd(ctx, location); |
506 | MlirOperation module = mlirModuleGetOperation(module: moduleOp); |
507 | assert(!mlirModuleIsNull(mlirModuleFromOperation(module))); |
508 | |
509 | int errcode = collectStats(operation: module); |
510 | if (errcode) |
511 | return errcode; |
512 | |
513 | printFirstOfEach(ctx, operation: module); |
514 | |
515 | mlirModuleDestroy(module: moduleOp); |
516 | return 0; |
517 | } |
518 | |
519 | /// Creates an operation with a region containing multiple blocks with |
520 | /// operations and dumps it. The blocks and operations are inserted using |
521 | /// block/operation-relative API and their final order is checked. |
522 | static void buildWithInsertionsAndPrint(MlirContext ctx) { |
523 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
524 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
525 | |
526 | MlirRegion owningRegion = mlirRegionCreate(); |
527 | MlirBlock nullBlock = mlirRegionGetFirstBlock(region: owningRegion); |
528 | MlirOperationState state = mlirOperationStateGet( |
529 | name: mlirStringRefCreateFromCString(str: "insertion.order.test" ), loc); |
530 | mlirOperationStateAddOwnedRegions(state: &state, n: 1, regions: &owningRegion); |
531 | MlirOperation op = mlirOperationCreate(state: &state); |
532 | MlirRegion region = mlirOperationGetRegion(op, pos: 0); |
533 | |
534 | // Use integer types of different bitwidth as block arguments in order to |
535 | // differentiate blocks. |
536 | MlirType i1 = mlirIntegerTypeGet(ctx, bitwidth: 1); |
537 | MlirType i2 = mlirIntegerTypeGet(ctx, bitwidth: 2); |
538 | MlirType i3 = mlirIntegerTypeGet(ctx, bitwidth: 3); |
539 | MlirType i4 = mlirIntegerTypeGet(ctx, bitwidth: 4); |
540 | MlirType i5 = mlirIntegerTypeGet(ctx, bitwidth: 5); |
541 | MlirBlock block1 = mlirBlockCreate(nArgs: 1, args: &i1, locs: &loc); |
542 | MlirBlock block2 = mlirBlockCreate(nArgs: 1, args: &i2, locs: &loc); |
543 | MlirBlock block3 = mlirBlockCreate(nArgs: 1, args: &i3, locs: &loc); |
544 | MlirBlock block4 = mlirBlockCreate(nArgs: 1, args: &i4, locs: &loc); |
545 | MlirBlock block5 = mlirBlockCreate(nArgs: 1, args: &i5, locs: &loc); |
546 | // Insert blocks so as to obtain the 1-2-3-4 order, |
547 | mlirRegionInsertOwnedBlockBefore(region, reference: nullBlock, block: block3); |
548 | mlirRegionInsertOwnedBlockBefore(region, reference: block3, block: block2); |
549 | mlirRegionInsertOwnedBlockAfter(region, reference: nullBlock, block: block1); |
550 | mlirRegionInsertOwnedBlockAfter(region, reference: block3, block: block4); |
551 | mlirRegionInsertOwnedBlockBefore(region, reference: block3, block: block5); |
552 | |
553 | MlirOperationState op1State = |
554 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op1" ), loc); |
555 | MlirOperationState op2State = |
556 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op2" ), loc); |
557 | MlirOperationState op3State = |
558 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op3" ), loc); |
559 | MlirOperationState op4State = |
560 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op4" ), loc); |
561 | MlirOperationState op5State = |
562 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op5" ), loc); |
563 | MlirOperationState op6State = |
564 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op6" ), loc); |
565 | MlirOperationState op7State = |
566 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op7" ), loc); |
567 | MlirOperationState op8State = |
568 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op8" ), loc); |
569 | MlirOperation op1 = mlirOperationCreate(state: &op1State); |
570 | MlirOperation op2 = mlirOperationCreate(state: &op2State); |
571 | MlirOperation op3 = mlirOperationCreate(state: &op3State); |
572 | MlirOperation op4 = mlirOperationCreate(state: &op4State); |
573 | MlirOperation op5 = mlirOperationCreate(state: &op5State); |
574 | MlirOperation op6 = mlirOperationCreate(state: &op6State); |
575 | MlirOperation op7 = mlirOperationCreate(state: &op7State); |
576 | MlirOperation op8 = mlirOperationCreate(state: &op8State); |
577 | |
578 | // Insert operations in the first block so as to obtain the 1-2-3-4 order. |
579 | MlirOperation nullOperation = mlirBlockGetFirstOperation(block: block1); |
580 | assert(mlirOperationIsNull(nullOperation)); |
581 | mlirBlockInsertOwnedOperationBefore(block: block1, reference: nullOperation, operation: op3); |
582 | mlirBlockInsertOwnedOperationBefore(block: block1, reference: op3, operation: op2); |
583 | mlirBlockInsertOwnedOperationAfter(block: block1, reference: nullOperation, operation: op1); |
584 | mlirBlockInsertOwnedOperationAfter(block: block1, reference: op3, operation: op4); |
585 | |
586 | // Append operations to the rest of blocks to make them non-empty and thus |
587 | // printable. |
588 | mlirBlockAppendOwnedOperation(block: block2, operation: op5); |
589 | mlirBlockAppendOwnedOperation(block: block3, operation: op6); |
590 | mlirBlockAppendOwnedOperation(block: block4, operation: op7); |
591 | mlirBlockAppendOwnedOperation(block: block5, operation: op8); |
592 | |
593 | // Remove block5. |
594 | mlirBlockDetach(block: block5); |
595 | mlirBlockDestroy(block: block5); |
596 | |
597 | mlirOperationDump(op); |
598 | mlirOperationDestroy(op); |
599 | mlirContextSetAllowUnregisteredDialects(context: ctx, false); |
600 | // clang-format off |
601 | // CHECK-LABEL: "insertion.order.test" |
602 | // CHECK: ^{{.*}}(%{{.*}}: i1 |
603 | // CHECK: "dummy.op1" |
604 | // CHECK-NEXT: "dummy.op2" |
605 | // CHECK-NEXT: "dummy.op3" |
606 | // CHECK-NEXT: "dummy.op4" |
607 | // CHECK: ^{{.*}}(%{{.*}}: i2 |
608 | // CHECK: "dummy.op5" |
609 | // CHECK-NOT: ^{{.*}}(%{{.*}}: i5 |
610 | // CHECK-NOT: "dummy.op8" |
611 | // CHECK: ^{{.*}}(%{{.*}}: i3 |
612 | // CHECK: "dummy.op6" |
613 | // CHECK: ^{{.*}}(%{{.*}}: i4 |
614 | // CHECK: "dummy.op7" |
615 | // clang-format on |
616 | } |
617 | |
618 | /// Creates operations with type inference and tests various failure modes. |
619 | static int createOperationWithTypeInference(MlirContext ctx) { |
620 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
621 | MlirAttribute iAttr = mlirIntegerAttrGet(type: mlirIntegerTypeGet(ctx, bitwidth: 32), value: 4); |
622 | |
623 | // The shape.const_size op implements result type inference and is only used |
624 | // for that reason. |
625 | MlirOperationState state = mlirOperationStateGet( |
626 | name: mlirStringRefCreateFromCString(str: "shape.const_size" ), loc); |
627 | MlirNamedAttribute valueAttr = mlirNamedAttributeGet( |
628 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value" )), attr: iAttr); |
629 | mlirOperationStateAddAttributes(state: &state, n: 1, attributes: &valueAttr); |
630 | mlirOperationStateEnableResultTypeInference(state: &state); |
631 | |
632 | // Expect result type inference to succeed. |
633 | MlirOperation op = mlirOperationCreate(state: &state); |
634 | if (mlirOperationIsNull(op)) { |
635 | fprintf(stderr, format: "ERROR: Result type inference unexpectedly failed" ); |
636 | return 1; |
637 | } |
638 | |
639 | // CHECK: RESULT_TYPE_INFERENCE: !shape.size |
640 | fprintf(stderr, format: "RESULT_TYPE_INFERENCE: " ); |
641 | mlirTypeDump(type: mlirValueGetType(value: mlirOperationGetResult(op, pos: 0))); |
642 | fprintf(stderr, format: "\n" ); |
643 | mlirOperationDestroy(op); |
644 | return 0; |
645 | } |
646 | |
647 | /// Dumps instances of all builtin types to check that C API works correctly. |
648 | /// Additionally, performs simple identity checks that a builtin type |
649 | /// constructed with C API can be inspected and has the expected type. The |
650 | /// latter achieves full coverage of C API for builtin types. Returns 0 on |
651 | /// success and a non-zero error code on failure. |
652 | static int printBuiltinTypes(MlirContext ctx) { |
653 | // Integer types. |
654 | MlirType i32 = mlirIntegerTypeGet(ctx, bitwidth: 32); |
655 | MlirType si32 = mlirIntegerTypeSignedGet(ctx, bitwidth: 32); |
656 | MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32); |
657 | if (!mlirTypeIsAInteger(type: i32) || mlirTypeIsAF32(type: i32)) |
658 | return 1; |
659 | if (!mlirTypeIsAInteger(type: si32) || !mlirIntegerTypeIsSigned(type: si32)) |
660 | return 2; |
661 | if (!mlirTypeIsAInteger(type: ui32) || !mlirIntegerTypeIsUnsigned(type: ui32)) |
662 | return 3; |
663 | if (mlirTypeEqual(t1: i32, t2: ui32) || mlirTypeEqual(t1: i32, t2: si32)) |
664 | return 4; |
665 | if (mlirIntegerTypeGetWidth(type: i32) != mlirIntegerTypeGetWidth(type: si32)) |
666 | return 5; |
667 | fprintf(stderr, format: "@types\n" ); |
668 | mlirTypeDump(type: i32); |
669 | fprintf(stderr, format: "\n" ); |
670 | mlirTypeDump(type: si32); |
671 | fprintf(stderr, format: "\n" ); |
672 | mlirTypeDump(type: ui32); |
673 | fprintf(stderr, format: "\n" ); |
674 | // CHECK-LABEL: @types |
675 | // CHECK: i32 |
676 | // CHECK: si32 |
677 | // CHECK: ui32 |
678 | |
679 | // Index type. |
680 | MlirType index = mlirIndexTypeGet(ctx); |
681 | if (!mlirTypeIsAIndex(type: index)) |
682 | return 6; |
683 | mlirTypeDump(type: index); |
684 | fprintf(stderr, format: "\n" ); |
685 | // CHECK: index |
686 | |
687 | // Floating-point types. |
688 | MlirType bf16 = mlirBF16TypeGet(ctx); |
689 | MlirType f16 = mlirF16TypeGet(ctx); |
690 | MlirType f32 = mlirF32TypeGet(ctx); |
691 | MlirType f64 = mlirF64TypeGet(ctx); |
692 | if (!mlirTypeIsABF16(type: bf16)) |
693 | return 7; |
694 | if (!mlirTypeIsAF16(type: f16)) |
695 | return 9; |
696 | if (!mlirTypeIsAF32(type: f32)) |
697 | return 10; |
698 | if (!mlirTypeIsAF64(type: f64)) |
699 | return 11; |
700 | mlirTypeDump(type: bf16); |
701 | fprintf(stderr, format: "\n" ); |
702 | mlirTypeDump(type: f16); |
703 | fprintf(stderr, format: "\n" ); |
704 | mlirTypeDump(type: f32); |
705 | fprintf(stderr, format: "\n" ); |
706 | mlirTypeDump(type: f64); |
707 | fprintf(stderr, format: "\n" ); |
708 | // CHECK: bf16 |
709 | // CHECK: f16 |
710 | // CHECK: f32 |
711 | // CHECK: f64 |
712 | |
713 | // None type. |
714 | MlirType none = mlirNoneTypeGet(ctx); |
715 | if (!mlirTypeIsANone(type: none)) |
716 | return 12; |
717 | mlirTypeDump(type: none); |
718 | fprintf(stderr, format: "\n" ); |
719 | // CHECK: none |
720 | |
721 | // Complex type. |
722 | MlirType cplx = mlirComplexTypeGet(elementType: f32); |
723 | if (!mlirTypeIsAComplex(type: cplx) || |
724 | !mlirTypeEqual(t1: mlirComplexTypeGetElementType(type: cplx), t2: f32)) |
725 | return 13; |
726 | mlirTypeDump(type: cplx); |
727 | fprintf(stderr, format: "\n" ); |
728 | // CHECK: complex<f32> |
729 | |
730 | // Vector (and Shaped) type. ShapedType is a common base class for vectors, |
731 | // memrefs and tensors, one cannot create instances of this class so it is |
732 | // tested on an instance of vector type. |
733 | int64_t shape[] = {2, 3}; |
734 | MlirType vector = |
735 | mlirVectorTypeGet(rank: sizeof(shape) / sizeof(int64_t), shape, elementType: f32); |
736 | if (!mlirTypeIsAVector(type: vector) || !mlirTypeIsAShaped(type: vector)) |
737 | return 14; |
738 | if (!mlirTypeEqual(t1: mlirShapedTypeGetElementType(type: vector), t2: f32) || |
739 | !mlirShapedTypeHasRank(type: vector) || mlirShapedTypeGetRank(type: vector) != 2 || |
740 | mlirShapedTypeGetDimSize(type: vector, dim: 0) != 2 || |
741 | mlirShapedTypeIsDynamicDim(type: vector, dim: 0) || |
742 | mlirShapedTypeGetDimSize(type: vector, dim: 1) != 3 || |
743 | !mlirShapedTypeHasStaticShape(type: vector)) |
744 | return 15; |
745 | mlirTypeDump(type: vector); |
746 | fprintf(stderr, format: "\n" ); |
747 | // CHECK: vector<2x3xf32> |
748 | |
749 | // Scalable vector type. |
750 | bool scalable[] = {false, true}; |
751 | MlirType scalableVector = mlirVectorTypeGetScalable( |
752 | rank: sizeof(shape) / sizeof(int64_t), shape, scalable, elementType: f32); |
753 | if (!mlirTypeIsAVector(type: scalableVector)) |
754 | return 16; |
755 | if (!mlirVectorTypeIsScalable(type: scalableVector) || |
756 | mlirVectorTypeIsDimScalable(type: scalableVector, dim: 0) || |
757 | !mlirVectorTypeIsDimScalable(type: scalableVector, dim: 1)) |
758 | return 17; |
759 | mlirTypeDump(type: scalableVector); |
760 | fprintf(stderr, format: "\n" ); |
761 | // CHECK: vector<2x[3]xf32> |
762 | |
763 | // Ranked tensor type. |
764 | MlirType rankedTensor = mlirRankedTensorTypeGet( |
765 | rank: sizeof(shape) / sizeof(int64_t), shape, elementType: f32, encoding: mlirAttributeGetNull()); |
766 | if (!mlirTypeIsATensor(type: rankedTensor) || |
767 | !mlirTypeIsARankedTensor(type: rankedTensor) || |
768 | !mlirAttributeIsNull(attr: mlirRankedTensorTypeGetEncoding(type: rankedTensor))) |
769 | return 18; |
770 | mlirTypeDump(type: rankedTensor); |
771 | fprintf(stderr, format: "\n" ); |
772 | // CHECK: tensor<2x3xf32> |
773 | |
774 | // Unranked tensor type. |
775 | MlirType unrankedTensor = mlirUnrankedTensorTypeGet(elementType: f32); |
776 | if (!mlirTypeIsATensor(type: unrankedTensor) || |
777 | !mlirTypeIsAUnrankedTensor(type: unrankedTensor) || |
778 | mlirShapedTypeHasRank(type: unrankedTensor)) |
779 | return 19; |
780 | mlirTypeDump(type: unrankedTensor); |
781 | fprintf(stderr, format: "\n" ); |
782 | // CHECK: tensor<*xf32> |
783 | |
784 | // MemRef type. |
785 | MlirAttribute memSpace2 = mlirIntegerAttrGet(type: mlirIntegerTypeGet(ctx, bitwidth: 64), value: 2); |
786 | MlirType memRef = mlirMemRefTypeContiguousGet( |
787 | elementType: f32, rank: sizeof(shape) / sizeof(int64_t), shape, memorySpace: memSpace2); |
788 | if (!mlirTypeIsAMemRef(type: memRef) || |
789 | !mlirAttributeEqual(a1: mlirMemRefTypeGetMemorySpace(type: memRef), a2: memSpace2)) |
790 | return 20; |
791 | mlirTypeDump(type: memRef); |
792 | fprintf(stderr, format: "\n" ); |
793 | // CHECK: memref<2x3xf32, 2> |
794 | |
795 | // Unranked MemRef type. |
796 | MlirAttribute memSpace4 = mlirIntegerAttrGet(type: mlirIntegerTypeGet(ctx, bitwidth: 64), value: 4); |
797 | MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(elementType: f32, memorySpace: memSpace4); |
798 | if (!mlirTypeIsAUnrankedMemRef(type: unrankedMemRef) || |
799 | mlirTypeIsAMemRef(type: unrankedMemRef) || |
800 | !mlirAttributeEqual(a1: mlirUnrankedMemrefGetMemorySpace(type: unrankedMemRef), |
801 | a2: memSpace4)) |
802 | return 21; |
803 | mlirTypeDump(type: unrankedMemRef); |
804 | fprintf(stderr, format: "\n" ); |
805 | // CHECK: memref<*xf32, 4> |
806 | |
807 | // Tuple type. |
808 | MlirType types[] = {unrankedMemRef, f32}; |
809 | MlirType tuple = mlirTupleTypeGet(ctx, numElements: 2, elements: types); |
810 | if (!mlirTypeIsATuple(type: tuple) || mlirTupleTypeGetNumTypes(type: tuple) != 2 || |
811 | !mlirTypeEqual(t1: mlirTupleTypeGetType(type: tuple, pos: 0), t2: unrankedMemRef) || |
812 | !mlirTypeEqual(t1: mlirTupleTypeGetType(type: tuple, pos: 1), t2: f32)) |
813 | return 22; |
814 | mlirTypeDump(type: tuple); |
815 | fprintf(stderr, format: "\n" ); |
816 | // CHECK: tuple<memref<*xf32, 4>, f32> |
817 | |
818 | // Function type. |
819 | MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, bitwidth: 1)}; |
820 | MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, bitwidth: 16), |
821 | mlirIntegerTypeGet(ctx, bitwidth: 32), |
822 | mlirIntegerTypeGet(ctx, bitwidth: 64)}; |
823 | MlirType funcType = mlirFunctionTypeGet(ctx, numInputs: 2, inputs: funcInputs, numResults: 3, results: funcResults); |
824 | if (mlirFunctionTypeGetNumInputs(type: funcType) != 2) |
825 | return 23; |
826 | if (mlirFunctionTypeGetNumResults(type: funcType) != 3) |
827 | return 24; |
828 | if (!mlirTypeEqual(t1: funcInputs[0], t2: mlirFunctionTypeGetInput(type: funcType, pos: 0)) || |
829 | !mlirTypeEqual(t1: funcInputs[1], t2: mlirFunctionTypeGetInput(type: funcType, pos: 1))) |
830 | return 25; |
831 | if (!mlirTypeEqual(t1: funcResults[0], t2: mlirFunctionTypeGetResult(type: funcType, pos: 0)) || |
832 | !mlirTypeEqual(t1: funcResults[1], t2: mlirFunctionTypeGetResult(type: funcType, pos: 1)) || |
833 | !mlirTypeEqual(t1: funcResults[2], t2: mlirFunctionTypeGetResult(type: funcType, pos: 2))) |
834 | return 26; |
835 | mlirTypeDump(type: funcType); |
836 | fprintf(stderr, format: "\n" ); |
837 | // CHECK: (index, i1) -> (i16, i32, i64) |
838 | |
839 | // Opaque type. |
840 | MlirStringRef namespace = mlirStringRefCreate(str: "dialect" , length: 7); |
841 | MlirStringRef data = mlirStringRefCreate(str: "type" , length: 4); |
842 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
843 | MlirType opaque = mlirOpaqueTypeGet(ctx, dialectNamespace: namespace, typeData: data); |
844 | mlirContextSetAllowUnregisteredDialects(context: ctx, false); |
845 | if (!mlirTypeIsAOpaque(type: opaque) || |
846 | !mlirStringRefEqual(string: mlirOpaqueTypeGetDialectNamespace(type: opaque), |
847 | other: namespace) || |
848 | !mlirStringRefEqual(string: mlirOpaqueTypeGetData(type: opaque), other: data)) |
849 | return 27; |
850 | mlirTypeDump(type: opaque); |
851 | fprintf(stderr, format: "\n" ); |
852 | // CHECK: !dialect.type |
853 | |
854 | return 0; |
855 | } |
856 | |
/// Copies a length-delimited string into the buffer passed through
/// `userData`, with `strncpy` semantics.
///
/// At most `len` bytes are read from `data`. Copying stops early at an
/// embedded NUL, and the rest of the first `len` destination bytes are then
/// zero-filled. Nothing is written past `len`, so no terminator is added when
/// `data` has no NUL within `len` bytes. NOTE(review): callers presumably pass
/// a zero-initialized buffer longer than `len`; confirm at the call sites.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  char *destination = (char *)userData;
  size_t byteCount = (size_t)len;
  strncpy(destination, data, byteCount);
}
861 | |
862 | bool stringIsEqual(const char *lhs, MlirStringRef rhs) { |
863 | if (strlen(s: lhs) != rhs.length) { |
864 | return false; |
865 | } |
866 | return !strncmp(s1: lhs, s2: rhs.data, n: rhs.length); |
867 | } |
868 | |
869 | int printBuiltinAttributes(MlirContext ctx) { |
870 | MlirAttribute floating = |
871 | mlirFloatAttrDoubleGet(ctx, type: mlirF64TypeGet(ctx), value: 2.0); |
872 | if (!mlirAttributeIsAFloat(attr: floating) || |
873 | fabs(x: mlirFloatAttrGetValueDouble(attr: floating) - 2.0) > 1E-6) |
874 | return 1; |
875 | fprintf(stderr, format: "@attrs\n" ); |
876 | mlirAttributeDump(attr: floating); |
877 | // CHECK-LABEL: @attrs |
878 | // CHECK: 2.000000e+00 : f64 |
879 | |
880 | // Exercise mlirAttributeGetType() just for the first one. |
881 | MlirType floatingType = mlirAttributeGetType(attribute: floating); |
882 | mlirTypeDump(type: floatingType); |
883 | // CHECK: f64 |
884 | |
885 | MlirAttribute integer = mlirIntegerAttrGet(type: mlirIntegerTypeGet(ctx, bitwidth: 32), value: 42); |
886 | MlirAttribute signedInteger = |
887 | mlirIntegerAttrGet(type: mlirIntegerTypeSignedGet(ctx, bitwidth: 8), value: -1); |
888 | MlirAttribute unsignedInteger = |
889 | mlirIntegerAttrGet(type: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 8), value: 255); |
890 | if (!mlirAttributeIsAInteger(attr: integer) || |
891 | mlirIntegerAttrGetValueInt(attr: integer) != 42 || |
892 | mlirIntegerAttrGetValueSInt(attr: signedInteger) != -1 || |
893 | mlirIntegerAttrGetValueUInt(attr: unsignedInteger) != 255) |
894 | return 2; |
895 | mlirAttributeDump(attr: integer); |
896 | mlirAttributeDump(attr: signedInteger); |
897 | mlirAttributeDump(attr: unsignedInteger); |
898 | // CHECK: 42 : i32 |
899 | // CHECK: -1 : si8 |
900 | // CHECK: 255 : ui8 |
901 | |
902 | MlirAttribute boolean = mlirBoolAttrGet(ctx, value: 1); |
903 | if (!mlirAttributeIsABool(attr: boolean) || !mlirBoolAttrGetValue(attr: boolean)) |
904 | return 3; |
905 | mlirAttributeDump(attr: boolean); |
906 | // CHECK: true |
907 | |
908 | const char data[] = "abcdefghijklmnopqestuvwxyz" ; |
909 | MlirAttribute opaque = |
910 | mlirOpaqueAttrGet(ctx, dialectNamespace: mlirStringRefCreateFromCString(str: "func" ), dataLength: 3, data, |
911 | type: mlirNoneTypeGet(ctx)); |
912 | if (!mlirAttributeIsAOpaque(attr: opaque) || |
913 | !stringIsEqual(lhs: "func" , rhs: mlirOpaqueAttrGetDialectNamespace(attr: opaque))) |
914 | return 4; |
915 | |
916 | MlirStringRef opaqueData = mlirOpaqueAttrGetData(attr: opaque); |
917 | if (opaqueData.length != 3 || |
918 | strncmp(s1: data, s2: opaqueData.data, n: opaqueData.length)) |
919 | return 5; |
920 | mlirAttributeDump(attr: opaque); |
921 | // CHECK: #func.abc |
922 | |
923 | MlirAttribute string = |
924 | mlirStringAttrGet(ctx, str: mlirStringRefCreate(str: data + 3, length: 2)); |
925 | if (!mlirAttributeIsAString(attr: string)) |
926 | return 6; |
927 | |
928 | MlirStringRef stringValue = mlirStringAttrGetValue(attr: string); |
929 | if (stringValue.length != 2 || |
930 | strncmp(s1: data + 3, s2: stringValue.data, n: stringValue.length)) |
931 | return 7; |
932 | mlirAttributeDump(attr: string); |
933 | // CHECK: "de" |
934 | |
935 | MlirAttribute flatSymbolRef = |
936 | mlirFlatSymbolRefAttrGet(ctx, symbol: mlirStringRefCreate(str: data + 5, length: 3)); |
937 | if (!mlirAttributeIsAFlatSymbolRef(attr: flatSymbolRef)) |
938 | return 8; |
939 | |
940 | MlirStringRef flatSymbolRefValue = |
941 | mlirFlatSymbolRefAttrGetValue(attr: flatSymbolRef); |
942 | if (flatSymbolRefValue.length != 3 || |
943 | strncmp(s1: data + 5, s2: flatSymbolRefValue.data, n: flatSymbolRefValue.length)) |
944 | return 9; |
945 | mlirAttributeDump(attr: flatSymbolRef); |
946 | // CHECK: @fgh |
947 | |
948 | MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef}; |
949 | MlirAttribute symbolRef = |
950 | mlirSymbolRefAttrGet(ctx, symbol: mlirStringRefCreate(str: data + 8, length: 2), numReferences: 2, references: symbols); |
951 | if (!mlirAttributeIsASymbolRef(attr: symbolRef) || |
952 | mlirSymbolRefAttrGetNumNestedReferences(attr: symbolRef) != 2 || |
953 | !mlirAttributeEqual(a1: mlirSymbolRefAttrGetNestedReference(attr: symbolRef, pos: 0), |
954 | a2: flatSymbolRef) || |
955 | !mlirAttributeEqual(a1: mlirSymbolRefAttrGetNestedReference(attr: symbolRef, pos: 1), |
956 | a2: flatSymbolRef)) |
957 | return 10; |
958 | |
959 | MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(attr: symbolRef); |
960 | MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(attr: symbolRef); |
961 | if (symbolRefLeaf.length != 3 || |
962 | strncmp(s1: data + 5, s2: symbolRefLeaf.data, n: symbolRefLeaf.length) || |
963 | symbolRefRoot.length != 2 || |
964 | strncmp(s1: data + 8, s2: symbolRefRoot.data, n: symbolRefRoot.length)) |
965 | return 11; |
966 | mlirAttributeDump(attr: symbolRef); |
967 | // CHECK: @ij::@fgh::@fgh |
968 | |
969 | MlirAttribute type = mlirTypeAttrGet(type: mlirF32TypeGet(ctx)); |
970 | if (!mlirAttributeIsAType(attr: type) || |
971 | !mlirTypeEqual(t1: mlirF32TypeGet(ctx), t2: mlirTypeAttrGetValue(attr: type))) |
972 | return 12; |
973 | mlirAttributeDump(attr: type); |
974 | // CHECK: f32 |
975 | |
976 | MlirAttribute unit = mlirUnitAttrGet(ctx); |
977 | if (!mlirAttributeIsAUnit(attr: unit)) |
978 | return 13; |
979 | mlirAttributeDump(attr: unit); |
980 | // CHECK: unit |
981 | |
982 | int64_t shape[] = {1, 2}; |
983 | |
984 | int bools[] = {0, 1}; |
985 | uint8_t uints8[] = {0u, 1u}; |
986 | int8_t ints8[] = {0, 1}; |
987 | uint16_t uints16[] = {0u, 1u}; |
988 | int16_t ints16[] = {0, 1}; |
989 | uint32_t uints32[] = {0u, 1u}; |
990 | int32_t ints32[] = {0, 1}; |
991 | uint64_t uints64[] = {0u, 1u}; |
992 | int64_t ints64[] = {0, 1}; |
993 | float floats[] = {0.0f, 1.0f}; |
994 | double doubles[] = {0.0, 1.0}; |
995 | uint16_t bf16s[] = {0x0, 0x3f80}; |
996 | uint16_t f16s[] = {0x0, 0x3c00}; |
997 | MlirAttribute encoding = mlirAttributeGetNull(); |
998 | MlirAttribute boolElements = mlirDenseElementsAttrBoolGet( |
999 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 1), encoding), |
1000 | numElements: 2, elements: bools); |
1001 | MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get( |
1002 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 8), |
1003 | encoding), |
1004 | numElements: 2, elements: uints8); |
1005 | MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get( |
1006 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 8), encoding), |
1007 | numElements: 2, elements: ints8); |
1008 | MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get( |
1009 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 16), |
1010 | encoding), |
1011 | numElements: 2, elements: uints16); |
1012 | MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get( |
1013 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 16), encoding), |
1014 | numElements: 2, elements: ints16); |
1015 | MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get( |
1016 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32), |
1017 | encoding), |
1018 | numElements: 2, elements: uints32); |
1019 | MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get( |
1020 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 32), encoding), |
1021 | numElements: 2, elements: ints32); |
1022 | MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get( |
1023 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 64), |
1024 | encoding), |
1025 | numElements: 2, elements: uints64); |
1026 | MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get( |
1027 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1028 | numElements: 2, elements: ints64); |
1029 | MlirAttribute floatElements = mlirDenseElementsAttrFloatGet( |
1030 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF32TypeGet(ctx), encoding), numElements: 2, |
1031 | elements: floats); |
1032 | MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet( |
1033 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF64TypeGet(ctx), encoding), numElements: 2, |
1034 | elements: doubles); |
1035 | MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get( |
1036 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirBF16TypeGet(ctx), encoding), numElements: 2, |
1037 | elements: bf16s); |
1038 | MlirAttribute f16Elements = mlirDenseElementsAttrFloat16Get( |
1039 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF16TypeGet(ctx), encoding), numElements: 2, |
1040 | elements: f16s); |
1041 | |
1042 | if (!mlirAttributeIsADenseElements(attr: boolElements) || |
1043 | !mlirAttributeIsADenseElements(attr: uint8Elements) || |
1044 | !mlirAttributeIsADenseElements(attr: int8Elements) || |
1045 | !mlirAttributeIsADenseElements(attr: uint32Elements) || |
1046 | !mlirAttributeIsADenseElements(attr: int32Elements) || |
1047 | !mlirAttributeIsADenseElements(attr: uint64Elements) || |
1048 | !mlirAttributeIsADenseElements(attr: int64Elements) || |
1049 | !mlirAttributeIsADenseElements(attr: floatElements) || |
1050 | !mlirAttributeIsADenseElements(attr: doubleElements) || |
1051 | !mlirAttributeIsADenseElements(attr: bf16Elements) || |
1052 | !mlirAttributeIsADenseElements(attr: f16Elements)) |
1053 | return 14; |
1054 | |
1055 | if (mlirDenseElementsAttrGetBoolValue(attr: boolElements, pos: 1) != 1 || |
1056 | mlirDenseElementsAttrGetUInt8Value(attr: uint8Elements, pos: 1) != 1 || |
1057 | mlirDenseElementsAttrGetInt8Value(attr: int8Elements, pos: 1) != 1 || |
1058 | mlirDenseElementsAttrGetUInt16Value(attr: uint16Elements, pos: 1) != 1 || |
1059 | mlirDenseElementsAttrGetInt16Value(attr: int16Elements, pos: 1) != 1 || |
1060 | mlirDenseElementsAttrGetUInt32Value(attr: uint32Elements, pos: 1) != 1 || |
1061 | mlirDenseElementsAttrGetInt32Value(attr: int32Elements, pos: 1) != 1 || |
1062 | mlirDenseElementsAttrGetUInt64Value(attr: uint64Elements, pos: 1) != 1 || |
1063 | mlirDenseElementsAttrGetInt64Value(attr: int64Elements, pos: 1) != 1 || |
1064 | fabsf(x: mlirDenseElementsAttrGetFloatValue(attr: floatElements, pos: 1) - 1.0f) > |
1065 | 1E-6f || |
1066 | fabs(x: mlirDenseElementsAttrGetDoubleValue(attr: doubleElements, pos: 1) - 1.0) > 1E-6) |
1067 | return 15; |
1068 | |
1069 | mlirAttributeDump(attr: boolElements); |
1070 | mlirAttributeDump(attr: uint8Elements); |
1071 | mlirAttributeDump(attr: int8Elements); |
1072 | mlirAttributeDump(attr: uint32Elements); |
1073 | mlirAttributeDump(attr: int32Elements); |
1074 | mlirAttributeDump(attr: uint64Elements); |
1075 | mlirAttributeDump(attr: int64Elements); |
1076 | mlirAttributeDump(attr: floatElements); |
1077 | mlirAttributeDump(attr: doubleElements); |
1078 | mlirAttributeDump(attr: bf16Elements); |
1079 | mlirAttributeDump(attr: f16Elements); |
1080 | // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1> |
1081 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8> |
1082 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8> |
1083 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32> |
1084 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32> |
1085 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64> |
1086 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64> |
1087 | // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32> |
1088 | // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64> |
1089 | // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16> |
1090 | // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf16> |
1091 | |
1092 | MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet( |
1093 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 1), encoding), |
1094 | element: 1); |
1095 | MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet( |
1096 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 8), |
1097 | encoding), |
1098 | element: 1); |
1099 | MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet( |
1100 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 8), encoding), |
1101 | element: 1); |
1102 | MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet( |
1103 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32), |
1104 | encoding), |
1105 | element: 1); |
1106 | MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet( |
1107 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 32), encoding), |
1108 | element: 1); |
1109 | MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet( |
1110 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 64), |
1111 | encoding), |
1112 | element: 1); |
1113 | MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet( |
1114 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1115 | element: 1); |
1116 | MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet( |
1117 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF32TypeGet(ctx), encoding), element: 1.0f); |
1118 | MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet( |
1119 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF64TypeGet(ctx), encoding), element: 1.0); |
1120 | |
1121 | if (!mlirAttributeIsADenseElements(attr: splatBool) || |
1122 | !mlirDenseElementsAttrIsSplat(attr: splatBool) || |
1123 | !mlirAttributeIsADenseElements(attr: splatUInt8) || |
1124 | !mlirDenseElementsAttrIsSplat(attr: splatUInt8) || |
1125 | !mlirAttributeIsADenseElements(attr: splatInt8) || |
1126 | !mlirDenseElementsAttrIsSplat(attr: splatInt8) || |
1127 | !mlirAttributeIsADenseElements(attr: splatUInt32) || |
1128 | !mlirDenseElementsAttrIsSplat(attr: splatUInt32) || |
1129 | !mlirAttributeIsADenseElements(attr: splatInt32) || |
1130 | !mlirDenseElementsAttrIsSplat(attr: splatInt32) || |
1131 | !mlirAttributeIsADenseElements(attr: splatUInt64) || |
1132 | !mlirDenseElementsAttrIsSplat(attr: splatUInt64) || |
1133 | !mlirAttributeIsADenseElements(attr: splatInt64) || |
1134 | !mlirDenseElementsAttrIsSplat(attr: splatInt64) || |
1135 | !mlirAttributeIsADenseElements(attr: splatFloat) || |
1136 | !mlirDenseElementsAttrIsSplat(attr: splatFloat) || |
1137 | !mlirAttributeIsADenseElements(attr: splatDouble) || |
1138 | !mlirDenseElementsAttrIsSplat(attr: splatDouble)) |
1139 | return 16; |
1140 | |
1141 | if (mlirDenseElementsAttrGetBoolSplatValue(attr: splatBool) != 1 || |
1142 | mlirDenseElementsAttrGetUInt8SplatValue(attr: splatUInt8) != 1 || |
1143 | mlirDenseElementsAttrGetInt8SplatValue(attr: splatInt8) != 1 || |
1144 | mlirDenseElementsAttrGetUInt32SplatValue(attr: splatUInt32) != 1 || |
1145 | mlirDenseElementsAttrGetInt32SplatValue(attr: splatInt32) != 1 || |
1146 | mlirDenseElementsAttrGetUInt64SplatValue(attr: splatUInt64) != 1 || |
1147 | mlirDenseElementsAttrGetInt64SplatValue(attr: splatInt64) != 1 || |
1148 | fabsf(x: mlirDenseElementsAttrGetFloatSplatValue(attr: splatFloat) - 1.0f) > |
1149 | 1E-6f || |
1150 | fabs(x: mlirDenseElementsAttrGetDoubleSplatValue(attr: splatDouble) - 1.0) > 1E-6) |
1151 | return 17; |
1152 | |
1153 | const uint8_t *uint8RawData = |
1154 | (const uint8_t *)mlirDenseElementsAttrGetRawData(attr: uint8Elements); |
1155 | const int8_t *int8RawData = |
1156 | (const int8_t *)mlirDenseElementsAttrGetRawData(attr: int8Elements); |
1157 | const uint32_t *uint32RawData = |
1158 | (const uint32_t *)mlirDenseElementsAttrGetRawData(attr: uint32Elements); |
1159 | const int32_t *int32RawData = |
1160 | (const int32_t *)mlirDenseElementsAttrGetRawData(attr: int32Elements); |
1161 | const uint64_t *uint64RawData = |
1162 | (const uint64_t *)mlirDenseElementsAttrGetRawData(attr: uint64Elements); |
1163 | const int64_t *int64RawData = |
1164 | (const int64_t *)mlirDenseElementsAttrGetRawData(attr: int64Elements); |
1165 | const float *floatRawData = |
1166 | (const float *)mlirDenseElementsAttrGetRawData(attr: floatElements); |
1167 | const double *doubleRawData = |
1168 | (const double *)mlirDenseElementsAttrGetRawData(attr: doubleElements); |
1169 | const uint16_t *bf16RawData = |
1170 | (const uint16_t *)mlirDenseElementsAttrGetRawData(attr: bf16Elements); |
1171 | const uint16_t *f16RawData = |
1172 | (const uint16_t *)mlirDenseElementsAttrGetRawData(attr: f16Elements); |
1173 | if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 || |
1174 | int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u || |
1175 | int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u || |
1176 | uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 || |
1177 | floatRawData[0] != 0.0f || floatRawData[1] != 1.0f || |
1178 | doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 || |
1179 | bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80 || f16RawData[0] != 0 || |
1180 | f16RawData[1] != 0x3c00) |
1181 | return 18; |
1182 | |
1183 | mlirAttributeDump(attr: splatBool); |
1184 | mlirAttributeDump(attr: splatUInt8); |
1185 | mlirAttributeDump(attr: splatInt8); |
1186 | mlirAttributeDump(attr: splatUInt32); |
1187 | mlirAttributeDump(attr: splatInt32); |
1188 | mlirAttributeDump(attr: splatUInt64); |
1189 | mlirAttributeDump(attr: splatInt64); |
1190 | mlirAttributeDump(attr: splatFloat); |
1191 | mlirAttributeDump(attr: splatDouble); |
1192 | // CHECK: dense<true> : tensor<1x2xi1> |
1193 | // CHECK: dense<1> : tensor<1x2xui8> |
1194 | // CHECK: dense<1> : tensor<1x2xi8> |
1195 | // CHECK: dense<1> : tensor<1x2xui32> |
1196 | // CHECK: dense<1> : tensor<1x2xi32> |
1197 | // CHECK: dense<1> : tensor<1x2xui64> |
1198 | // CHECK: dense<1> : tensor<1x2xi64> |
1199 | // CHECK: dense<1.000000e+00> : tensor<1x2xf32> |
1200 | // CHECK: dense<1.000000e+00> : tensor<1x2xf64> |
1201 | |
1202 | mlirAttributeDump(attr: mlirElementsAttrGetValue(attr: floatElements, rank: 2, idxs: uints64)); |
1203 | mlirAttributeDump(attr: mlirElementsAttrGetValue(attr: doubleElements, rank: 2, idxs: uints64)); |
1204 | mlirAttributeDump(attr: mlirElementsAttrGetValue(attr: bf16Elements, rank: 2, idxs: uints64)); |
1205 | mlirAttributeDump(attr: mlirElementsAttrGetValue(attr: f16Elements, rank: 2, idxs: uints64)); |
1206 | // CHECK: 1.000000e+00 : f32 |
1207 | // CHECK: 1.000000e+00 : f64 |
1208 | // CHECK: 1.000000e+00 : bf16 |
1209 | // CHECK: 1.000000e+00 : f16 |
1210 | |
1211 | int64_t indices[] = {0, 1}; |
1212 | int64_t one = 1; |
1213 | MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get( |
1214 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1215 | numElements: 2, elements: indices); |
1216 | MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet( |
1217 | shapedType: mlirRankedTensorTypeGet(rank: 1, shape: &one, elementType: mlirF32TypeGet(ctx), encoding), numElements: 1, |
1218 | elements: floats); |
1219 | MlirAttribute sparseAttr = mlirSparseElementsAttribute( |
1220 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF32TypeGet(ctx), encoding), |
1221 | denseIndices: indicesAttr, denseValues: valuesAttr); |
1222 | mlirAttributeDump(attr: sparseAttr); |
1223 | // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32> |
1224 | |
1225 | MlirAttribute boolArray = mlirDenseBoolArrayGet(ctx, size: 2, values: bools); |
1226 | MlirAttribute int8Array = mlirDenseI8ArrayGet(ctx, size: 2, values: ints8); |
1227 | MlirAttribute int16Array = mlirDenseI16ArrayGet(ctx, size: 2, values: ints16); |
1228 | MlirAttribute int32Array = mlirDenseI32ArrayGet(ctx, size: 2, values: ints32); |
1229 | MlirAttribute int64Array = mlirDenseI64ArrayGet(ctx, size: 2, values: ints64); |
1230 | MlirAttribute floatArray = mlirDenseF32ArrayGet(ctx, size: 2, values: floats); |
1231 | MlirAttribute doubleArray = mlirDenseF64ArrayGet(ctx, size: 2, values: doubles); |
1232 | if (!mlirAttributeIsADenseBoolArray(attr: boolArray) || |
1233 | !mlirAttributeIsADenseI8Array(attr: int8Array) || |
1234 | !mlirAttributeIsADenseI16Array(attr: int16Array) || |
1235 | !mlirAttributeIsADenseI32Array(attr: int32Array) || |
1236 | !mlirAttributeIsADenseI64Array(attr: int64Array) || |
1237 | !mlirAttributeIsADenseF32Array(attr: floatArray) || |
1238 | !mlirAttributeIsADenseF64Array(attr: doubleArray)) |
1239 | return 19; |
1240 | |
1241 | if (mlirDenseArrayGetNumElements(attr: boolArray) != 2 || |
1242 | mlirDenseArrayGetNumElements(attr: int8Array) != 2 || |
1243 | mlirDenseArrayGetNumElements(attr: int16Array) != 2 || |
1244 | mlirDenseArrayGetNumElements(attr: int32Array) != 2 || |
1245 | mlirDenseArrayGetNumElements(attr: int64Array) != 2 || |
1246 | mlirDenseArrayGetNumElements(attr: floatArray) != 2 || |
1247 | mlirDenseArrayGetNumElements(attr: doubleArray) != 2) |
1248 | return 20; |
1249 | |
1250 | if (mlirDenseBoolArrayGetElement(attr: boolArray, pos: 1) != 1 || |
1251 | mlirDenseI8ArrayGetElement(attr: int8Array, pos: 1) != 1 || |
1252 | mlirDenseI16ArrayGetElement(attr: int16Array, pos: 1) != 1 || |
1253 | mlirDenseI32ArrayGetElement(attr: int32Array, pos: 1) != 1 || |
1254 | mlirDenseI64ArrayGetElement(attr: int64Array, pos: 1) != 1 || |
1255 | fabsf(x: mlirDenseF32ArrayGetElement(attr: floatArray, pos: 1) - 1.0f) > 1E-6f || |
1256 | fabs(x: mlirDenseF64ArrayGetElement(attr: doubleArray, pos: 1) - 1.0) > 1E-6) |
1257 | return 21; |
1258 | |
1259 | int64_t layoutStrides[3] = {5, 7, 13}; |
1260 | MlirAttribute stridedLayoutAttr = |
1261 | mlirStridedLayoutAttrGet(ctx, offset: 42, numStrides: 3, strides: &layoutStrides[0]); |
1262 | |
1263 | // CHECK: strided<[5, 7, 13], offset: 42> |
1264 | mlirAttributeDump(attr: stridedLayoutAttr); |
1265 | |
1266 | if (mlirStridedLayoutAttrGetOffset(attr: stridedLayoutAttr) != 42 || |
1267 | mlirStridedLayoutAttrGetNumStrides(attr: stridedLayoutAttr) != 3 || |
1268 | mlirStridedLayoutAttrGetStride(attr: stridedLayoutAttr, pos: 0) != 5 || |
1269 | mlirStridedLayoutAttrGetStride(attr: stridedLayoutAttr, pos: 1) != 7 || |
1270 | mlirStridedLayoutAttrGetStride(attr: stridedLayoutAttr, pos: 2) != 13) |
1271 | return 22; |
1272 | |
1273 | MlirAttribute uint8Blob = mlirUnmanagedDenseUInt8ResourceElementsAttrGet( |
1274 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 8), |
1275 | encoding), |
1276 | name: mlirStringRefCreateFromCString(str: "resource_ui8" ), numElements: 2, elements: uints8); |
1277 | MlirAttribute uint16Blob = mlirUnmanagedDenseUInt16ResourceElementsAttrGet( |
1278 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 16), |
1279 | encoding), |
1280 | name: mlirStringRefCreateFromCString(str: "resource_ui16" ), numElements: 2, elements: uints16); |
1281 | MlirAttribute uint32Blob = mlirUnmanagedDenseUInt32ResourceElementsAttrGet( |
1282 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32), |
1283 | encoding), |
1284 | name: mlirStringRefCreateFromCString(str: "resource_ui32" ), numElements: 2, elements: uints32); |
1285 | MlirAttribute uint64Blob = mlirUnmanagedDenseUInt64ResourceElementsAttrGet( |
1286 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 64), |
1287 | encoding), |
1288 | name: mlirStringRefCreateFromCString(str: "resource_ui64" ), numElements: 2, elements: uints64); |
1289 | MlirAttribute int8Blob = mlirUnmanagedDenseInt8ResourceElementsAttrGet( |
1290 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 8), encoding), |
1291 | name: mlirStringRefCreateFromCString(str: "resource_i8" ), numElements: 2, elements: ints8); |
1292 | MlirAttribute int16Blob = mlirUnmanagedDenseInt16ResourceElementsAttrGet( |
1293 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 16), encoding), |
1294 | name: mlirStringRefCreateFromCString(str: "resource_i16" ), numElements: 2, elements: ints16); |
1295 | MlirAttribute int32Blob = mlirUnmanagedDenseInt32ResourceElementsAttrGet( |
1296 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 32), encoding), |
1297 | name: mlirStringRefCreateFromCString(str: "resource_i32" ), numElements: 2, elements: ints32); |
1298 | MlirAttribute int64Blob = mlirUnmanagedDenseInt64ResourceElementsAttrGet( |
1299 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1300 | name: mlirStringRefCreateFromCString(str: "resource_i64" ), numElements: 2, elements: ints64); |
1301 | MlirAttribute floatsBlob = mlirUnmanagedDenseFloatResourceElementsAttrGet( |
1302 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF32TypeGet(ctx), encoding), |
1303 | name: mlirStringRefCreateFromCString(str: "resource_f32" ), numElements: 2, elements: floats); |
1304 | MlirAttribute doublesBlob = mlirUnmanagedDenseDoubleResourceElementsAttrGet( |
1305 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF64TypeGet(ctx), encoding), |
1306 | name: mlirStringRefCreateFromCString(str: "resource_f64" ), numElements: 2, elements: doubles); |
1307 | MlirAttribute blobBlob = mlirUnmanagedDenseResourceElementsAttrGet( |
1308 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1309 | name: mlirStringRefCreateFromCString(str: "resource_i64_blob" ), /*data=*/uints64, |
1310 | /*dataLength=*/sizeof(uints64), |
1311 | /*dataAlignment=*/_Alignof(uint64_t), |
1312 | /*dataIsMutable=*/false, |
1313 | /*deleter=*/reportResourceDelete, |
1314 | /*userData=*/(void *)&resourceI64BlobUserData); |
1315 | |
1316 | mlirAttributeDump(attr: uint8Blob); |
1317 | mlirAttributeDump(attr: uint16Blob); |
1318 | mlirAttributeDump(attr: uint32Blob); |
1319 | mlirAttributeDump(attr: uint64Blob); |
1320 | mlirAttributeDump(attr: int8Blob); |
1321 | mlirAttributeDump(attr: int16Blob); |
1322 | mlirAttributeDump(attr: int32Blob); |
1323 | mlirAttributeDump(attr: int64Blob); |
1324 | mlirAttributeDump(attr: floatsBlob); |
1325 | mlirAttributeDump(attr: doublesBlob); |
1326 | mlirAttributeDump(attr: blobBlob); |
1327 | // CHECK: dense_resource<resource_ui8> : tensor<1x2xui8> |
1328 | // CHECK: dense_resource<resource_ui16> : tensor<1x2xui16> |
1329 | // CHECK: dense_resource<resource_ui32> : tensor<1x2xui32> |
1330 | // CHECK: dense_resource<resource_ui64> : tensor<1x2xui64> |
1331 | // CHECK: dense_resource<resource_i8> : tensor<1x2xi8> |
1332 | // CHECK: dense_resource<resource_i16> : tensor<1x2xi16> |
1333 | // CHECK: dense_resource<resource_i32> : tensor<1x2xi32> |
1334 | // CHECK: dense_resource<resource_i64> : tensor<1x2xi64> |
1335 | // CHECK: dense_resource<resource_f32> : tensor<1x2xf32> |
1336 | // CHECK: dense_resource<resource_f64> : tensor<1x2xf64> |
1337 | // CHECK: dense_resource<resource_i64_blob> : tensor<1x2xi64> |
1338 | |
1339 | if (mlirDenseUInt8ResourceElementsAttrGetValue(attr: uint8Blob, pos: 1) != 1 || |
1340 | mlirDenseUInt16ResourceElementsAttrGetValue(attr: uint16Blob, pos: 1) != 1 || |
1341 | mlirDenseUInt32ResourceElementsAttrGetValue(attr: uint32Blob, pos: 1) != 1 || |
1342 | mlirDenseUInt64ResourceElementsAttrGetValue(attr: uint64Blob, pos: 1) != 1 || |
1343 | mlirDenseInt8ResourceElementsAttrGetValue(attr: int8Blob, pos: 1) != 1 || |
1344 | mlirDenseInt16ResourceElementsAttrGetValue(attr: int16Blob, pos: 1) != 1 || |
1345 | mlirDenseInt32ResourceElementsAttrGetValue(attr: int32Blob, pos: 1) != 1 || |
1346 | mlirDenseInt64ResourceElementsAttrGetValue(attr: int64Blob, pos: 1) != 1 || |
1347 | fabsf(x: mlirDenseF32ArrayGetElement(attr: floatArray, pos: 1) - 1.0f) > 1E-6f || |
1348 | fabsf(x: mlirDenseFloatResourceElementsAttrGetValue(attr: floatsBlob, pos: 1) - 1.0f) > |
1349 | 1e-6 || |
1350 | fabs(x: mlirDenseDoubleResourceElementsAttrGetValue(attr: doublesBlob, pos: 1) - 1.0f) > |
1351 | 1e-6 || |
1352 | mlirDenseUInt64ResourceElementsAttrGetValue(attr: blobBlob, pos: 1) != 1) |
1353 | return 23; |
1354 | |
1355 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
1356 | MlirAttribute locAttr = mlirLocationGetAttribute(location: loc); |
1357 | if (!mlirAttributeIsALocation(attr: locAttr)) |
1358 | return 24; |
1359 | |
1360 | return 0; |
1361 | } |
1362 | |
1363 | int printAffineMap(MlirContext ctx) { |
1364 | MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx); |
1365 | MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, dimCount: 3, symbolCount: 2); |
1366 | MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, val: 2); |
1367 | MlirAffineMap multiDimIdentityAffineMap = |
1368 | mlirAffineMapMultiDimIdentityGet(ctx, numDims: 3); |
1369 | MlirAffineMap minorIdentityAffineMap = |
1370 | mlirAffineMapMinorIdentityGet(ctx, dims: 3, results: 2); |
1371 | unsigned permutation[] = {1, 2, 0}; |
1372 | MlirAffineMap permutationAffineMap = mlirAffineMapPermutationGet( |
1373 | ctx, size: sizeof(permutation) / sizeof(unsigned), permutation); |
1374 | |
1375 | fprintf(stderr, format: "@affineMap\n" ); |
1376 | mlirAffineMapDump(affineMap: emptyAffineMap); |
1377 | mlirAffineMapDump(affineMap); |
1378 | mlirAffineMapDump(affineMap: constAffineMap); |
1379 | mlirAffineMapDump(affineMap: multiDimIdentityAffineMap); |
1380 | mlirAffineMapDump(affineMap: minorIdentityAffineMap); |
1381 | mlirAffineMapDump(affineMap: permutationAffineMap); |
1382 | // CHECK-LABEL: @affineMap |
1383 | // CHECK: () -> () |
1384 | // CHECK: (d0, d1, d2)[s0, s1] -> () |
1385 | // CHECK: () -> (2) |
1386 | // CHECK: (d0, d1, d2) -> (d0, d1, d2) |
1387 | // CHECK: (d0, d1, d2) -> (d1, d2) |
1388 | // CHECK: (d0, d1, d2) -> (d1, d2, d0) |
1389 | |
1390 | if (!mlirAffineMapIsIdentity(affineMap: emptyAffineMap) || |
1391 | mlirAffineMapIsIdentity(affineMap) || |
1392 | mlirAffineMapIsIdentity(affineMap: constAffineMap) || |
1393 | !mlirAffineMapIsIdentity(affineMap: multiDimIdentityAffineMap) || |
1394 | mlirAffineMapIsIdentity(affineMap: minorIdentityAffineMap) || |
1395 | mlirAffineMapIsIdentity(affineMap: permutationAffineMap)) |
1396 | return 1; |
1397 | |
1398 | if (!mlirAffineMapIsMinorIdentity(affineMap: emptyAffineMap) || |
1399 | mlirAffineMapIsMinorIdentity(affineMap) || |
1400 | !mlirAffineMapIsMinorIdentity(affineMap: multiDimIdentityAffineMap) || |
1401 | !mlirAffineMapIsMinorIdentity(affineMap: minorIdentityAffineMap) || |
1402 | mlirAffineMapIsMinorIdentity(affineMap: permutationAffineMap)) |
1403 | return 2; |
1404 | |
1405 | if (!mlirAffineMapIsEmpty(affineMap: emptyAffineMap) || |
1406 | mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(affineMap: constAffineMap) || |
1407 | mlirAffineMapIsEmpty(affineMap: multiDimIdentityAffineMap) || |
1408 | mlirAffineMapIsEmpty(affineMap: minorIdentityAffineMap) || |
1409 | mlirAffineMapIsEmpty(affineMap: permutationAffineMap)) |
1410 | return 3; |
1411 | |
1412 | if (mlirAffineMapIsSingleConstant(affineMap: emptyAffineMap) || |
1413 | mlirAffineMapIsSingleConstant(affineMap) || |
1414 | !mlirAffineMapIsSingleConstant(affineMap: constAffineMap) || |
1415 | mlirAffineMapIsSingleConstant(affineMap: multiDimIdentityAffineMap) || |
1416 | mlirAffineMapIsSingleConstant(affineMap: minorIdentityAffineMap) || |
1417 | mlirAffineMapIsSingleConstant(affineMap: permutationAffineMap)) |
1418 | return 4; |
1419 | |
1420 | if (mlirAffineMapGetSingleConstantResult(affineMap: constAffineMap) != 2) |
1421 | return 5; |
1422 | |
1423 | if (mlirAffineMapGetNumDims(affineMap: emptyAffineMap) != 0 || |
1424 | mlirAffineMapGetNumDims(affineMap) != 3 || |
1425 | mlirAffineMapGetNumDims(affineMap: constAffineMap) != 0 || |
1426 | mlirAffineMapGetNumDims(affineMap: multiDimIdentityAffineMap) != 3 || |
1427 | mlirAffineMapGetNumDims(affineMap: minorIdentityAffineMap) != 3 || |
1428 | mlirAffineMapGetNumDims(affineMap: permutationAffineMap) != 3) |
1429 | return 6; |
1430 | |
1431 | if (mlirAffineMapGetNumSymbols(affineMap: emptyAffineMap) != 0 || |
1432 | mlirAffineMapGetNumSymbols(affineMap) != 2 || |
1433 | mlirAffineMapGetNumSymbols(affineMap: constAffineMap) != 0 || |
1434 | mlirAffineMapGetNumSymbols(affineMap: multiDimIdentityAffineMap) != 0 || |
1435 | mlirAffineMapGetNumSymbols(affineMap: minorIdentityAffineMap) != 0 || |
1436 | mlirAffineMapGetNumSymbols(affineMap: permutationAffineMap) != 0) |
1437 | return 7; |
1438 | |
1439 | if (mlirAffineMapGetNumResults(affineMap: emptyAffineMap) != 0 || |
1440 | mlirAffineMapGetNumResults(affineMap) != 0 || |
1441 | mlirAffineMapGetNumResults(affineMap: constAffineMap) != 1 || |
1442 | mlirAffineMapGetNumResults(affineMap: multiDimIdentityAffineMap) != 3 || |
1443 | mlirAffineMapGetNumResults(affineMap: minorIdentityAffineMap) != 2 || |
1444 | mlirAffineMapGetNumResults(affineMap: permutationAffineMap) != 3) |
1445 | return 8; |
1446 | |
1447 | if (mlirAffineMapGetNumInputs(affineMap: emptyAffineMap) != 0 || |
1448 | mlirAffineMapGetNumInputs(affineMap) != 5 || |
1449 | mlirAffineMapGetNumInputs(affineMap: constAffineMap) != 0 || |
1450 | mlirAffineMapGetNumInputs(affineMap: multiDimIdentityAffineMap) != 3 || |
1451 | mlirAffineMapGetNumInputs(affineMap: minorIdentityAffineMap) != 3 || |
1452 | mlirAffineMapGetNumInputs(affineMap: permutationAffineMap) != 3) |
1453 | return 9; |
1454 | |
1455 | if (!mlirAffineMapIsProjectedPermutation(affineMap: emptyAffineMap) || |
1456 | !mlirAffineMapIsPermutation(affineMap: emptyAffineMap) || |
1457 | mlirAffineMapIsProjectedPermutation(affineMap) || |
1458 | mlirAffineMapIsPermutation(affineMap) || |
1459 | mlirAffineMapIsProjectedPermutation(affineMap: constAffineMap) || |
1460 | mlirAffineMapIsPermutation(affineMap: constAffineMap) || |
1461 | !mlirAffineMapIsProjectedPermutation(affineMap: multiDimIdentityAffineMap) || |
1462 | !mlirAffineMapIsPermutation(affineMap: multiDimIdentityAffineMap) || |
1463 | !mlirAffineMapIsProjectedPermutation(affineMap: minorIdentityAffineMap) || |
1464 | mlirAffineMapIsPermutation(affineMap: minorIdentityAffineMap) || |
1465 | !mlirAffineMapIsProjectedPermutation(affineMap: permutationAffineMap) || |
1466 | !mlirAffineMapIsPermutation(affineMap: permutationAffineMap)) |
1467 | return 10; |
1468 | |
1469 | intptr_t sub[] = {1}; |
1470 | |
1471 | MlirAffineMap subMap = mlirAffineMapGetSubMap( |
1472 | affineMap: multiDimIdentityAffineMap, size: sizeof(sub) / sizeof(intptr_t), resultPos: sub); |
1473 | MlirAffineMap majorSubMap = |
1474 | mlirAffineMapGetMajorSubMap(affineMap: multiDimIdentityAffineMap, numResults: 1); |
1475 | MlirAffineMap minorSubMap = |
1476 | mlirAffineMapGetMinorSubMap(affineMap: multiDimIdentityAffineMap, numResults: 1); |
1477 | |
1478 | mlirAffineMapDump(affineMap: subMap); |
1479 | mlirAffineMapDump(affineMap: majorSubMap); |
1480 | mlirAffineMapDump(affineMap: minorSubMap); |
1481 | // CHECK: (d0, d1, d2) -> (d1) |
1482 | // CHECK: (d0, d1, d2) -> (d0) |
1483 | // CHECK: (d0, d1, d2) -> (d2) |
1484 | |
1485 | // CHECK: distinct[0]<"foo"> |
1486 | mlirAttributeDump(attr: mlirDisctinctAttrCreate( |
1487 | referencedAttr: mlirStringAttrGet(ctx, str: mlirStringRefCreateFromCString(str: "foo" )))); |
1488 | |
1489 | return 0; |
1490 | } |
1491 | |
1492 | int printAffineExpr(MlirContext ctx) { |
1493 | MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, position: 5); |
1494 | MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, position: 5); |
1495 | MlirAffineExpr affineConstantExpr = mlirAffineConstantExprGet(ctx, constant: 5); |
1496 | MlirAffineExpr affineAddExpr = |
1497 | mlirAffineAddExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1498 | MlirAffineExpr affineMulExpr = |
1499 | mlirAffineMulExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1500 | MlirAffineExpr affineModExpr = |
1501 | mlirAffineModExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1502 | MlirAffineExpr affineFloorDivExpr = |
1503 | mlirAffineFloorDivExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1504 | MlirAffineExpr affineCeilDivExpr = |
1505 | mlirAffineCeilDivExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1506 | |
1507 | // Tests mlirAffineExprDump. |
1508 | fprintf(stderr, format: "@affineExpr\n" ); |
1509 | mlirAffineExprDump(affineExpr: affineDimExpr); |
1510 | mlirAffineExprDump(affineExpr: affineSymbolExpr); |
1511 | mlirAffineExprDump(affineExpr: affineConstantExpr); |
1512 | mlirAffineExprDump(affineExpr: affineAddExpr); |
1513 | mlirAffineExprDump(affineExpr: affineMulExpr); |
1514 | mlirAffineExprDump(affineExpr: affineModExpr); |
1515 | mlirAffineExprDump(affineExpr: affineFloorDivExpr); |
1516 | mlirAffineExprDump(affineExpr: affineCeilDivExpr); |
1517 | // CHECK-LABEL: @affineExpr |
1518 | // CHECK: d5 |
1519 | // CHECK: s5 |
1520 | // CHECK: 5 |
1521 | // CHECK: d5 + s5 |
1522 | // CHECK: d5 * s5 |
1523 | // CHECK: d5 mod s5 |
1524 | // CHECK: d5 floordiv s5 |
1525 | // CHECK: d5 ceildiv s5 |
1526 | |
1527 | // Tests methods of affine binary operation expression, takes add expression |
1528 | // as an example. |
1529 | mlirAffineExprDump(affineExpr: mlirAffineBinaryOpExprGetLHS(affineExpr: affineAddExpr)); |
1530 | mlirAffineExprDump(affineExpr: mlirAffineBinaryOpExprGetRHS(affineExpr: affineAddExpr)); |
1531 | // CHECK: d5 |
1532 | // CHECK: s5 |
1533 | |
1534 | // Tests methods of affine dimension expression. |
1535 | if (mlirAffineDimExprGetPosition(affineExpr: affineDimExpr) != 5) |
1536 | return 1; |
1537 | |
1538 | // Tests methods of affine symbol expression. |
1539 | if (mlirAffineSymbolExprGetPosition(affineExpr: affineSymbolExpr) != 5) |
1540 | return 2; |
1541 | |
1542 | // Tests methods of affine constant expression. |
1543 | if (mlirAffineConstantExprGetValue(affineExpr: affineConstantExpr) != 5) |
1544 | return 3; |
1545 | |
1546 | // Tests methods of affine expression. |
1547 | if (mlirAffineExprIsSymbolicOrConstant(affineExpr: affineDimExpr) || |
1548 | !mlirAffineExprIsSymbolicOrConstant(affineExpr: affineSymbolExpr) || |
1549 | !mlirAffineExprIsSymbolicOrConstant(affineExpr: affineConstantExpr) || |
1550 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineAddExpr) || |
1551 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineMulExpr) || |
1552 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineModExpr) || |
1553 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineFloorDivExpr) || |
1554 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineCeilDivExpr)) |
1555 | return 4; |
1556 | |
1557 | if (!mlirAffineExprIsPureAffine(affineExpr: affineDimExpr) || |
1558 | !mlirAffineExprIsPureAffine(affineExpr: affineSymbolExpr) || |
1559 | !mlirAffineExprIsPureAffine(affineExpr: affineConstantExpr) || |
1560 | !mlirAffineExprIsPureAffine(affineExpr: affineAddExpr) || |
1561 | mlirAffineExprIsPureAffine(affineExpr: affineMulExpr) || |
1562 | mlirAffineExprIsPureAffine(affineExpr: affineModExpr) || |
1563 | mlirAffineExprIsPureAffine(affineExpr: affineFloorDivExpr) || |
1564 | mlirAffineExprIsPureAffine(affineExpr: affineCeilDivExpr)) |
1565 | return 5; |
1566 | |
1567 | if (mlirAffineExprGetLargestKnownDivisor(affineExpr: affineDimExpr) != 1 || |
1568 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineSymbolExpr) != 1 || |
1569 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineConstantExpr) != 5 || |
1570 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineAddExpr) != 1 || |
1571 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineMulExpr) != 1 || |
1572 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineModExpr) != 1 || |
1573 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineFloorDivExpr) != 1 || |
1574 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineCeilDivExpr) != 1) |
1575 | return 6; |
1576 | |
1577 | if (!mlirAffineExprIsMultipleOf(affineExpr: affineDimExpr, factor: 1) || |
1578 | !mlirAffineExprIsMultipleOf(affineExpr: affineSymbolExpr, factor: 1) || |
1579 | !mlirAffineExprIsMultipleOf(affineExpr: affineConstantExpr, factor: 5) || |
1580 | !mlirAffineExprIsMultipleOf(affineExpr: affineAddExpr, factor: 1) || |
1581 | !mlirAffineExprIsMultipleOf(affineExpr: affineMulExpr, factor: 1) || |
1582 | !mlirAffineExprIsMultipleOf(affineExpr: affineModExpr, factor: 1) || |
1583 | !mlirAffineExprIsMultipleOf(affineExpr: affineFloorDivExpr, factor: 1) || |
1584 | !mlirAffineExprIsMultipleOf(affineExpr: affineCeilDivExpr, factor: 1)) |
1585 | return 7; |
1586 | |
1587 | if (!mlirAffineExprIsFunctionOfDim(affineExpr: affineDimExpr, position: 5) || |
1588 | mlirAffineExprIsFunctionOfDim(affineExpr: affineSymbolExpr, position: 5) || |
1589 | mlirAffineExprIsFunctionOfDim(affineExpr: affineConstantExpr, position: 5) || |
1590 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineAddExpr, position: 5) || |
1591 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineMulExpr, position: 5) || |
1592 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineModExpr, position: 5) || |
1593 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineFloorDivExpr, position: 5) || |
1594 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineCeilDivExpr, position: 5)) |
1595 | return 8; |
1596 | |
1597 | // Tests 'IsA' methods of affine binary operation expression. |
1598 | if (!mlirAffineExprIsAAdd(affineExpr: affineAddExpr)) |
1599 | return 9; |
1600 | |
1601 | if (!mlirAffineExprIsAMul(affineExpr: affineMulExpr)) |
1602 | return 10; |
1603 | |
1604 | if (!mlirAffineExprIsAMod(affineExpr: affineModExpr)) |
1605 | return 11; |
1606 | |
1607 | if (!mlirAffineExprIsAFloorDiv(affineExpr: affineFloorDivExpr)) |
1608 | return 12; |
1609 | |
1610 | if (!mlirAffineExprIsACeilDiv(affineExpr: affineCeilDivExpr)) |
1611 | return 13; |
1612 | |
1613 | if (!mlirAffineExprIsABinary(affineExpr: affineAddExpr)) |
1614 | return 14; |
1615 | |
1616 | // Test other 'IsA' method on affine expressions. |
1617 | if (!mlirAffineExprIsAConstant(affineExpr: affineConstantExpr)) |
1618 | return 15; |
1619 | |
1620 | if (!mlirAffineExprIsADim(affineExpr: affineDimExpr)) |
1621 | return 16; |
1622 | |
1623 | if (!mlirAffineExprIsASymbol(affineExpr: affineSymbolExpr)) |
1624 | return 17; |
1625 | |
1626 | // Test equality and nullity. |
1627 | MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, position: 5); |
1628 | if (!mlirAffineExprEqual(lhs: affineDimExpr, rhs: otherDimExpr)) |
1629 | return 18; |
1630 | |
1631 | if (mlirAffineExprIsNull(affineExpr: affineDimExpr)) |
1632 | return 19; |
1633 | |
1634 | return 0; |
1635 | } |
1636 | |
1637 | int affineMapFromExprs(MlirContext ctx) { |
1638 | MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, position: 0); |
1639 | MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, position: 1); |
1640 | MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr}; |
1641 | MlirAffineMap map = mlirAffineMapGet(ctx, dimCount: 3, symbolCount: 3, nAffineExprs: 2, affineExprs: exprs); |
1642 | |
1643 | // CHECK-LABEL: @affineMapFromExprs |
1644 | fprintf(stderr, format: "@affineMapFromExprs" ); |
1645 | // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1) |
1646 | mlirAffineMapDump(affineMap: map); |
1647 | |
1648 | if (mlirAffineMapGetNumResults(affineMap: map) != 2) |
1649 | return 1; |
1650 | |
1651 | if (!mlirAffineExprEqual(lhs: mlirAffineMapGetResult(affineMap: map, pos: 0), rhs: affineDimExpr)) |
1652 | return 2; |
1653 | |
1654 | if (!mlirAffineExprEqual(lhs: mlirAffineMapGetResult(affineMap: map, pos: 1), rhs: affineSymbolExpr)) |
1655 | return 3; |
1656 | |
1657 | MlirAffineExpr affineDim2Expr = mlirAffineDimExprGet(ctx, position: 1); |
1658 | MlirAffineExpr composed = mlirAffineExprCompose(affineExpr: affineDim2Expr, affineMap: map); |
1659 | // CHECK: s1 |
1660 | mlirAffineExprDump(affineExpr: composed); |
1661 | if (!mlirAffineExprEqual(lhs: composed, rhs: affineSymbolExpr)) |
1662 | return 4; |
1663 | |
1664 | return 0; |
1665 | } |
1666 | |
1667 | int printIntegerSet(MlirContext ctx) { |
1668 | MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(context: ctx, numDims: 2, numSymbols: 1); |
1669 | |
1670 | // CHECK-LABEL: @printIntegerSet |
1671 | fprintf(stderr, format: "@printIntegerSet" ); |
1672 | |
1673 | // CHECK: (d0, d1)[s0] : (1 == 0) |
1674 | mlirIntegerSetDump(set: emptySet); |
1675 | |
1676 | if (!mlirIntegerSetIsCanonicalEmpty(set: emptySet)) |
1677 | return 1; |
1678 | |
1679 | MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(context: ctx, numDims: 2, numSymbols: 1); |
1680 | if (!mlirIntegerSetEqual(s1: emptySet, s2: anotherEmptySet)) |
1681 | return 2; |
1682 | |
1683 | // Construct a set constrained by: |
1684 | // d0 - s0 == 0, |
1685 | // d1 - 42 >= 0. |
1686 | MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, constant: -1); |
1687 | MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, constant: -42); |
1688 | MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, position: 0); |
1689 | MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, position: 1); |
1690 | MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, position: 0); |
1691 | MlirAffineExpr negS0 = mlirAffineMulExprGet(lhs: negOne, rhs: s0); |
1692 | MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(lhs: d0, rhs: negS0); |
1693 | MlirAffineExpr d1minus42 = mlirAffineAddExprGet(lhs: d1, rhs: negFortyTwo); |
1694 | MlirAffineExpr constraints[] = {d0minusS0, d1minus42}; |
1695 | bool flags[] = {true, false}; |
1696 | |
1697 | MlirIntegerSet set = mlirIntegerSetGet(context: ctx, numDims: 2, numSymbols: 1, numConstraints: 2, constraints, eqFlags: flags); |
1698 | // CHECK: (d0, d1)[s0] : ( |
1699 | // CHECK-DAG: d0 - s0 == 0 |
1700 | // CHECK-DAG: d1 - 42 >= 0 |
1701 | mlirIntegerSetDump(set); |
1702 | |
1703 | // Transform d1 into s0. |
1704 | MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, position: 1); |
1705 | MlirAffineExpr repl[] = {d0, s1}; |
1706 | MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, dimReplacements: repl, symbolReplacements: &s0, numResultDims: 1, numResultSymbols: 2); |
1707 | // CHECK: (d0)[s0, s1] : ( |
1708 | // CHECK-DAG: d0 - s0 == 0 |
1709 | // CHECK-DAG: s1 - 42 >= 0 |
1710 | mlirIntegerSetDump(set: replaced); |
1711 | |
1712 | if (mlirIntegerSetGetNumDims(set) != 2) |
1713 | return 3; |
1714 | if (mlirIntegerSetGetNumDims(set: replaced) != 1) |
1715 | return 4; |
1716 | |
1717 | if (mlirIntegerSetGetNumSymbols(set) != 1) |
1718 | return 5; |
1719 | if (mlirIntegerSetGetNumSymbols(set: replaced) != 2) |
1720 | return 6; |
1721 | |
1722 | if (mlirIntegerSetGetNumInputs(set) != 3) |
1723 | return 7; |
1724 | |
1725 | if (mlirIntegerSetGetNumConstraints(set) != 2) |
1726 | return 8; |
1727 | |
1728 | if (mlirIntegerSetGetNumEqualities(set) != 1) |
1729 | return 9; |
1730 | |
1731 | if (mlirIntegerSetGetNumInequalities(set) != 1) |
1732 | return 10; |
1733 | |
1734 | MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, pos: 0); |
1735 | MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, pos: 1); |
1736 | bool isEq1 = mlirIntegerSetIsConstraintEq(set, pos: 0); |
1737 | bool isEq2 = mlirIntegerSetIsConstraintEq(set, pos: 1); |
1738 | if (!mlirAffineExprEqual(lhs: cstr1, rhs: isEq1 ? d0minusS0 : d1minus42)) |
1739 | return 11; |
1740 | if (!mlirAffineExprEqual(lhs: cstr2, rhs: isEq2 ? d0minusS0 : d1minus42)) |
1741 | return 12; |
1742 | |
1743 | return 0; |
1744 | } |
1745 | |
1746 | int registerOnlyStd(void) { |
1747 | MlirContext ctx = mlirContextCreate(); |
1748 | // The built-in dialect is always loaded. |
1749 | if (mlirContextGetNumLoadedDialects(context: ctx) != 1) |
1750 | return 1; |
1751 | |
1752 | MlirDialectHandle stdHandle = mlirGetDialectHandle__func__(); |
1753 | |
1754 | MlirDialect std = mlirContextGetOrLoadDialect( |
1755 | context: ctx, name: mlirDialectHandleGetNamespace(stdHandle)); |
1756 | if (!mlirDialectIsNull(dialect: std)) |
1757 | return 2; |
1758 | |
1759 | mlirDialectHandleRegisterDialect(stdHandle, ctx); |
1760 | |
1761 | std = mlirContextGetOrLoadDialect(context: ctx, |
1762 | name: mlirDialectHandleGetNamespace(stdHandle)); |
1763 | if (mlirDialectIsNull(dialect: std)) |
1764 | return 3; |
1765 | |
1766 | MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx); |
1767 | if (!mlirDialectEqual(dialect1: std, dialect2: alsoStd)) |
1768 | return 4; |
1769 | |
1770 | MlirStringRef stdNs = mlirDialectGetNamespace(dialect: std); |
1771 | MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle); |
1772 | if (stdNs.length != alsoStdNs.length || |
1773 | strncmp(s1: stdNs.data, s2: alsoStdNs.data, n: stdNs.length)) |
1774 | return 5; |
1775 | |
1776 | fprintf(stderr, format: "@registration\n" ); |
1777 | // CHECK-LABEL: @registration |
1778 | |
1779 | // CHECK: func.call is_registered: 1 |
1780 | fprintf(stderr, format: "func.call is_registered: %d\n" , |
1781 | mlirContextIsRegisteredOperation( |
1782 | context: ctx, name: mlirStringRefCreateFromCString(str: "func.call" ))); |
1783 | |
1784 | // CHECK: func.not_existing_op is_registered: 0 |
1785 | fprintf(stderr, format: "func.not_existing_op is_registered: %d\n" , |
1786 | mlirContextIsRegisteredOperation( |
1787 | context: ctx, name: mlirStringRefCreateFromCString(str: "func.not_existing_op" ))); |
1788 | |
1789 | // CHECK: not_existing_dialect.not_existing_op is_registered: 0 |
1790 | fprintf(stderr, format: "not_existing_dialect.not_existing_op is_registered: %d\n" , |
1791 | mlirContextIsRegisteredOperation( |
1792 | context: ctx, name: mlirStringRefCreateFromCString( |
1793 | str: "not_existing_dialect.not_existing_op" ))); |
1794 | |
1795 | mlirContextDestroy(context: ctx); |
1796 | return 0; |
1797 | } |
1798 | |
1799 | /// Tests backreference APIs |
1800 | static int testBackreferences(void) { |
1801 | fprintf(stderr, format: "@test_backreferences\n" ); |
1802 | |
1803 | MlirContext ctx = mlirContextCreate(); |
1804 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
1805 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
1806 | |
1807 | MlirOperationState opState = |
1808 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "invalid.op" ), loc); |
1809 | MlirRegion region = mlirRegionCreate(); |
1810 | MlirBlock block = mlirBlockCreate(nArgs: 0, NULL, NULL); |
1811 | mlirRegionAppendOwnedBlock(region, block); |
1812 | mlirOperationStateAddOwnedRegions(state: &opState, n: 1, regions: ®ion); |
1813 | MlirOperation op = mlirOperationCreate(state: &opState); |
1814 | MlirIdentifier ident = |
1815 | mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "identifier" )); |
1816 | |
1817 | if (!mlirContextEqual(ctx1: ctx, ctx2: mlirOperationGetContext(op))) { |
1818 | fprintf(stderr, format: "ERROR: Getting context from operation failed\n" ); |
1819 | return 1; |
1820 | } |
1821 | if (!mlirOperationEqual(op, other: mlirBlockGetParentOperation(block))) { |
1822 | fprintf(stderr, format: "ERROR: Getting parent operation from block failed\n" ); |
1823 | return 2; |
1824 | } |
1825 | if (!mlirContextEqual(ctx1: ctx, ctx2: mlirIdentifierGetContext(ident))) { |
1826 | fprintf(stderr, format: "ERROR: Getting context from identifier failed\n" ); |
1827 | return 3; |
1828 | } |
1829 | |
1830 | mlirOperationDestroy(op); |
1831 | mlirContextDestroy(context: ctx); |
1832 | |
1833 | // CHECK-LABEL: @test_backreferences |
1834 | return 0; |
1835 | } |
1836 | |
1837 | /// Tests operand APIs. |
1838 | int testOperands(void) { |
1839 | fprintf(stderr, format: "@testOperands\n" ); |
1840 | // CHECK-LABEL: @testOperands |
1841 | |
1842 | MlirContext ctx = mlirContextCreate(); |
1843 | registerAllUpstreamDialects(ctx); |
1844 | |
1845 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "arith" )); |
1846 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "test" )); |
1847 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
1848 | MlirType indexType = mlirIndexTypeGet(ctx); |
1849 | |
1850 | // Create some constants to use as operands. |
1851 | MlirAttribute indexZeroLiteral = |
1852 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "0 : index" )); |
1853 | MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( |
1854 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value" )), |
1855 | attr: indexZeroLiteral); |
1856 | MlirOperationState constZeroState = mlirOperationStateGet( |
1857 | name: mlirStringRefCreateFromCString(str: "arith.constant" ), loc); |
1858 | mlirOperationStateAddResults(state: &constZeroState, n: 1, results: &indexType); |
1859 | mlirOperationStateAddAttributes(state: &constZeroState, n: 1, attributes: &indexZeroValueAttr); |
1860 | MlirOperation constZero = mlirOperationCreate(state: &constZeroState); |
1861 | MlirValue constZeroValue = mlirOperationGetResult(op: constZero, pos: 0); |
1862 | |
1863 | MlirAttribute indexOneLiteral = |
1864 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "1 : index" )); |
1865 | MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet( |
1866 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value" )), |
1867 | attr: indexOneLiteral); |
1868 | MlirOperationState constOneState = mlirOperationStateGet( |
1869 | name: mlirStringRefCreateFromCString(str: "arith.constant" ), loc); |
1870 | mlirOperationStateAddResults(state: &constOneState, n: 1, results: &indexType); |
1871 | mlirOperationStateAddAttributes(state: &constOneState, n: 1, attributes: &indexOneValueAttr); |
1872 | MlirOperation constOne = mlirOperationCreate(state: &constOneState); |
1873 | MlirValue constOneValue = mlirOperationGetResult(op: constOne, pos: 0); |
1874 | |
1875 | // Create the operation under test. |
1876 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
1877 | MlirOperationState opState = |
1878 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op" ), loc); |
1879 | MlirValue initialOperands[] = {constZeroValue}; |
1880 | mlirOperationStateAddOperands(state: &opState, n: 1, operands: initialOperands); |
1881 | MlirOperation op = mlirOperationCreate(state: &opState); |
1882 | |
1883 | // Test operand APIs. |
1884 | intptr_t numOperands = mlirOperationGetNumOperands(op); |
1885 | fprintf(stderr, format: "Num Operands: %" PRIdPTR "\n" , numOperands); |
1886 | // CHECK: Num Operands: 1 |
1887 | |
1888 | MlirValue opOperand1 = mlirOperationGetOperand(op, pos: 0); |
1889 | fprintf(stderr, format: "Original operand: " ); |
1890 | mlirValuePrint(value: opOperand1, callback: printToStderr, NULL); |
1891 | // CHECK: Original operand: {{.+}} arith.constant 0 : index |
1892 | |
1893 | mlirOperationSetOperand(op, pos: 0, newValue: constOneValue); |
1894 | MlirValue opOperand2 = mlirOperationGetOperand(op, pos: 0); |
1895 | fprintf(stderr, format: "Updated operand: " ); |
1896 | mlirValuePrint(value: opOperand2, callback: printToStderr, NULL); |
1897 | // CHECK: Updated operand: {{.+}} arith.constant 1 : index |
1898 | |
1899 | // Test op operand APIs. |
1900 | MlirOpOperand use1 = mlirValueGetFirstUse(value: opOperand1); |
1901 | if (!mlirOpOperandIsNull(opOperand: use1)) { |
1902 | fprintf(stderr, format: "ERROR: Use should be null\n" ); |
1903 | return 1; |
1904 | } |
1905 | |
1906 | MlirOpOperand use2 = mlirValueGetFirstUse(value: opOperand2); |
1907 | if (mlirOpOperandIsNull(opOperand: use2)) { |
1908 | fprintf(stderr, format: "ERROR: Use should not be null\n" ); |
1909 | return 2; |
1910 | } |
1911 | |
1912 | fprintf(stderr, format: "Use owner: " ); |
1913 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use2), callback: printToStderr, NULL); |
1914 | fprintf(stderr, format: "\n" ); |
1915 | // CHECK: Use owner: "dummy.op" |
1916 | |
1917 | fprintf(stderr, format: "Use operandNumber: %d\n" , |
1918 | mlirOpOperandGetOperandNumber(opOperand: use2)); |
1919 | // CHECK: Use operandNumber: 0 |
1920 | |
1921 | use2 = mlirOpOperandGetNextUse(opOperand: use2); |
1922 | if (!mlirOpOperandIsNull(opOperand: use2)) { |
1923 | fprintf(stderr, format: "ERROR: Next use should be null\n" ); |
1924 | return 3; |
1925 | } |
1926 | |
1927 | MlirOperationState op2State = |
1928 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op2" ), loc); |
1929 | MlirValue initialOperands2[] = {constOneValue}; |
1930 | mlirOperationStateAddOperands(state: &op2State, n: 1, operands: initialOperands2); |
1931 | MlirOperation op2 = mlirOperationCreate(state: &op2State); |
1932 | |
1933 | MlirOpOperand use3 = mlirValueGetFirstUse(value: constOneValue); |
1934 | fprintf(stderr, format: "First use owner: " ); |
1935 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use3), callback: printToStderr, NULL); |
1936 | fprintf(stderr, format: "\n" ); |
1937 | // CHECK: First use owner: "dummy.op2" |
1938 | |
1939 | use3 = mlirOpOperandGetNextUse(opOperand: mlirValueGetFirstUse(value: constOneValue)); |
1940 | fprintf(stderr, format: "Second use owner: " ); |
1941 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use3), callback: printToStderr, NULL); |
1942 | fprintf(stderr, format: "\n" ); |
1943 | // CHECK: Second use owner: "dummy.op" |
1944 | |
1945 | MlirAttribute indexTwoLiteral = |
1946 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "2 : index" )); |
1947 | MlirNamedAttribute indexTwoValueAttr = mlirNamedAttributeGet( |
1948 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value" )), |
1949 | attr: indexTwoLiteral); |
1950 | MlirOperationState constTwoState = mlirOperationStateGet( |
1951 | name: mlirStringRefCreateFromCString(str: "arith.constant" ), loc); |
1952 | mlirOperationStateAddResults(state: &constTwoState, n: 1, results: &indexType); |
1953 | mlirOperationStateAddAttributes(state: &constTwoState, n: 1, attributes: &indexTwoValueAttr); |
1954 | MlirOperation constTwo = mlirOperationCreate(state: &constTwoState); |
1955 | MlirValue constTwoValue = mlirOperationGetResult(op: constTwo, pos: 0); |
1956 | |
1957 | mlirValueReplaceAllUsesOfWith(of: constOneValue, with: constTwoValue); |
1958 | |
1959 | use3 = mlirValueGetFirstUse(value: constOneValue); |
1960 | if (!mlirOpOperandIsNull(opOperand: use3)) { |
1961 | fprintf(stderr, format: "ERROR: Use should be null\n" ); |
1962 | return 4; |
1963 | } |
1964 | |
1965 | MlirOpOperand use4 = mlirValueGetFirstUse(value: constTwoValue); |
1966 | fprintf(stderr, format: "First replacement use owner: " ); |
1967 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use4), callback: printToStderr, NULL); |
1968 | fprintf(stderr, format: "\n" ); |
1969 | // CHECK: First replacement use owner: "dummy.op" |
1970 | |
1971 | use4 = mlirOpOperandGetNextUse(opOperand: mlirValueGetFirstUse(value: constTwoValue)); |
1972 | fprintf(stderr, format: "Second replacement use owner: " ); |
1973 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use4), callback: printToStderr, NULL); |
1974 | fprintf(stderr, format: "\n" ); |
1975 | // CHECK: Second replacement use owner: "dummy.op2" |
1976 | |
1977 | MlirOpOperand use5 = mlirValueGetFirstUse(value: constTwoValue); |
1978 | MlirOpOperand use6 = mlirOpOperandGetNextUse(opOperand: use5); |
1979 | if (!mlirValueEqual(value1: mlirOpOperandGetValue(opOperand: use5), |
1980 | value2: mlirOpOperandGetValue(opOperand: use6))) { |
1981 | fprintf(stderr, |
1982 | format: "ERROR: First and second operand should share the same value\n" ); |
1983 | return 5; |
1984 | } |
1985 | |
1986 | mlirOperationDestroy(op); |
1987 | mlirOperationDestroy(op: op2); |
1988 | mlirOperationDestroy(op: constZero); |
1989 | mlirOperationDestroy(op: constOne); |
1990 | mlirOperationDestroy(op: constTwo); |
1991 | mlirContextDestroy(context: ctx); |
1992 | |
1993 | return 0; |
1994 | } |
1995 | |
1996 | /// Tests clone APIs. |
1997 | int testClone(void) { |
1998 | fprintf(stderr, format: "@testClone\n" ); |
1999 | // CHECK-LABEL: @testClone |
2000 | |
2001 | MlirContext ctx = mlirContextCreate(); |
2002 | registerAllUpstreamDialects(ctx); |
2003 | |
2004 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "func" )); |
2005 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "arith" )); |
2006 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
2007 | MlirType indexType = mlirIndexTypeGet(ctx); |
2008 | MlirStringRef valueStringRef = mlirStringRefCreateFromCString(str: "value" ); |
2009 | |
2010 | MlirAttribute indexZeroLiteral = |
2011 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "0 : index" )); |
2012 | MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( |
2013 | name: mlirIdentifierGet(context: ctx, str: valueStringRef), attr: indexZeroLiteral); |
2014 | MlirOperationState constZeroState = mlirOperationStateGet( |
2015 | name: mlirStringRefCreateFromCString(str: "arith.constant" ), loc); |
2016 | mlirOperationStateAddResults(state: &constZeroState, n: 1, results: &indexType); |
2017 | mlirOperationStateAddAttributes(state: &constZeroState, n: 1, attributes: &indexZeroValueAttr); |
2018 | MlirOperation constZero = mlirOperationCreate(state: &constZeroState); |
2019 | |
2020 | MlirAttribute indexOneLiteral = |
2021 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "1 : index" )); |
2022 | MlirOperation constOne = mlirOperationClone(op: constZero); |
2023 | mlirOperationSetAttributeByName(op: constOne, name: valueStringRef, attr: indexOneLiteral); |
2024 | |
2025 | mlirOperationPrint(op: constZero, callback: printToStderr, NULL); |
2026 | mlirOperationPrint(op: constOne, callback: printToStderr, NULL); |
2027 | // CHECK: arith.constant 0 : index |
2028 | // CHECK: arith.constant 1 : index |
2029 | |
2030 | mlirOperationDestroy(op: constZero); |
2031 | mlirOperationDestroy(op: constOne); |
2032 | mlirContextDestroy(context: ctx); |
2033 | return 0; |
2034 | } |
2035 | |
2036 | // Wraps a diagnostic into additional text we can match against. |
2037 | MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) { |
2038 | fprintf(stderr, format: "processing diagnostic (userData: %" PRIdPTR ") <<\n" , |
2039 | (intptr_t)userData); |
2040 | mlirDiagnosticPrint(diagnostic, callback: printToStderr, NULL); |
2041 | fprintf(stderr, format: "\n" ); |
2042 | MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); |
2043 | mlirLocationPrint(location: loc, callback: printToStderr, NULL); |
2044 | assert(mlirDiagnosticGetNumNotes(diagnostic) == 0); |
2045 | fprintf(stderr, format: "\n>> end of diagnostic (userData: %" PRIdPTR ")\n" , |
2046 | (intptr_t)userData); |
2047 | return mlirLogicalResultSuccess(); |
2048 | } |
2049 | |
// User-data deleter passed alongside the diagnostic handler. It only logs its
// invocation, echoing the user-data pointer as an integer so the test output
// can be matched.
static void deleteUserData(void *userData) {
  const intptr_t tag = (intptr_t)userData;
  fprintf(stderr, "deleting user data (userData: %" PRIdPTR ")\n", tag);
}
2055 | |
2056 | int testTypeID(MlirContext ctx) { |
2057 | fprintf(stderr, format: "@testTypeID\n" ); |
2058 | |
2059 | // Test getting and comparing type and attribute type ids. |
2060 | MlirType i32 = mlirIntegerTypeGet(ctx, bitwidth: 32); |
2061 | MlirTypeID i32ID = mlirTypeGetTypeID(type: i32); |
2062 | MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32); |
2063 | MlirTypeID ui32ID = mlirTypeGetTypeID(type: ui32); |
2064 | MlirType f32 = mlirF32TypeGet(ctx); |
2065 | MlirTypeID f32ID = mlirTypeGetTypeID(type: f32); |
2066 | MlirAttribute i32Attr = mlirIntegerAttrGet(type: i32, value: 1); |
2067 | MlirTypeID i32AttrID = mlirAttributeGetTypeID(attribute: i32Attr); |
2068 | |
2069 | if (mlirTypeIDIsNull(typeID: i32ID) || mlirTypeIDIsNull(typeID: ui32ID) || |
2070 | mlirTypeIDIsNull(typeID: f32ID) || mlirTypeIDIsNull(typeID: i32AttrID)) { |
2071 | fprintf(stderr, format: "ERROR: Expected type ids to be present\n" ); |
2072 | return 1; |
2073 | } |
2074 | |
2075 | if (!mlirTypeIDEqual(typeID1: i32ID, typeID2: ui32ID) || |
2076 | mlirTypeIDHashValue(typeID: i32ID) != mlirTypeIDHashValue(typeID: ui32ID)) { |
2077 | fprintf( |
2078 | stderr, |
2079 | format: "ERROR: Expected different integer types to have the same type id\n" ); |
2080 | return 2; |
2081 | } |
2082 | |
2083 | if (mlirTypeIDEqual(typeID1: i32ID, typeID2: f32ID)) { |
2084 | fprintf(stderr, |
2085 | format: "ERROR: Expected integer type id to not equal float type id\n" ); |
2086 | return 3; |
2087 | } |
2088 | |
2089 | if (mlirTypeIDEqual(typeID1: i32ID, typeID2: i32AttrID)) { |
2090 | fprintf(stderr, format: "ERROR: Expected integer type id to not equal integer " |
2091 | "attribute type id\n" ); |
2092 | return 4; |
2093 | } |
2094 | |
2095 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
2096 | MlirType indexType = mlirIndexTypeGet(ctx); |
2097 | MlirStringRef valueStringRef = mlirStringRefCreateFromCString(str: "value" ); |
2098 | |
2099 | // Create a registered operation, which should have a type id. |
2100 | MlirAttribute indexZeroLiteral = |
2101 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "0 : index" )); |
2102 | MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( |
2103 | name: mlirIdentifierGet(context: ctx, str: valueStringRef), attr: indexZeroLiteral); |
2104 | MlirOperationState constZeroState = mlirOperationStateGet( |
2105 | name: mlirStringRefCreateFromCString(str: "arith.constant" ), loc); |
2106 | mlirOperationStateAddResults(state: &constZeroState, n: 1, results: &indexType); |
2107 | mlirOperationStateAddAttributes(state: &constZeroState, n: 1, attributes: &indexZeroValueAttr); |
2108 | MlirOperation constZero = mlirOperationCreate(state: &constZeroState); |
2109 | |
2110 | if (!mlirOperationVerify(op: constZero)) { |
2111 | fprintf(stderr, format: "ERROR: Expected operation to verify correctly\n" ); |
2112 | return 5; |
2113 | } |
2114 | |
2115 | if (mlirOperationIsNull(op: constZero)) { |
2116 | fprintf(stderr, format: "ERROR: Expected registered operation to be present\n" ); |
2117 | return 6; |
2118 | } |
2119 | |
2120 | MlirTypeID registeredOpID = mlirOperationGetTypeID(op: constZero); |
2121 | |
2122 | if (mlirTypeIDIsNull(typeID: registeredOpID)) { |
2123 | fprintf(stderr, |
2124 | format: "ERROR: Expected registered operation type id to be present\n" ); |
2125 | return 7; |
2126 | } |
2127 | |
2128 | // Create an unregistered operation, which should not have a type id. |
2129 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
2130 | MlirOperationState opState = |
2131 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op" ), loc); |
2132 | MlirOperation unregisteredOp = mlirOperationCreate(state: &opState); |
2133 | if (mlirOperationIsNull(op: unregisteredOp)) { |
2134 | fprintf(stderr, format: "ERROR: Expected unregistered operation to be present\n" ); |
2135 | return 8; |
2136 | } |
2137 | |
2138 | MlirTypeID unregisteredOpID = mlirOperationGetTypeID(op: unregisteredOp); |
2139 | |
2140 | if (!mlirTypeIDIsNull(typeID: unregisteredOpID)) { |
2141 | fprintf(stderr, |
2142 | format: "ERROR: Expected unregistered operation type id to be null\n" ); |
2143 | return 9; |
2144 | } |
2145 | |
2146 | mlirOperationDestroy(op: constZero); |
2147 | mlirOperationDestroy(op: unregisteredOp); |
2148 | |
2149 | return 0; |
2150 | } |
2151 | |
2152 | int testSymbolTable(MlirContext ctx) { |
2153 | fprintf(stderr, format: "@testSymbolTable\n" ); |
2154 | |
2155 | const char *moduleString = "func.func private @foo()" |
2156 | "func.func private @bar()" ; |
2157 | const char *otherModuleString = "func.func private @qux()" |
2158 | "func.func private @foo()" ; |
2159 | |
2160 | MlirModule module = |
2161 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
2162 | MlirModule otherModule = mlirModuleCreateParse( |
2163 | context: ctx, module: mlirStringRefCreateFromCString(str: otherModuleString)); |
2164 | |
2165 | MlirSymbolTable symbolTable = |
2166 | mlirSymbolTableCreate(operation: mlirModuleGetOperation(module)); |
2167 | |
2168 | MlirOperation funcFoo = |
2169 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "foo" )); |
2170 | if (mlirOperationIsNull(op: funcFoo)) |
2171 | return 1; |
2172 | |
2173 | MlirOperation funcBar = |
2174 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "bar" )); |
2175 | if (mlirOperationEqual(op: funcFoo, other: funcBar)) |
2176 | return 2; |
2177 | |
2178 | MlirOperation missing = |
2179 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "qux" )); |
2180 | if (!mlirOperationIsNull(op: missing)) |
2181 | return 3; |
2182 | |
2183 | MlirBlock moduleBody = mlirModuleGetBody(module); |
2184 | MlirBlock otherModuleBody = mlirModuleGetBody(module: otherModule); |
2185 | MlirOperation operation = mlirBlockGetFirstOperation(block: otherModuleBody); |
2186 | mlirOperationRemoveFromParent(op: operation); |
2187 | mlirBlockAppendOwnedOperation(block: moduleBody, operation); |
2188 | |
2189 | // At this moment, the operation is still missing from the symbol table. |
2190 | MlirOperation stillMissing = |
2191 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "qux" )); |
2192 | if (!mlirOperationIsNull(op: stillMissing)) |
2193 | return 4; |
2194 | |
2195 | // After it is added to the symbol table, and not only the operation with |
2196 | // which the table is associated, it can be looked up. |
2197 | mlirSymbolTableInsert(symbolTable, operation); |
2198 | MlirOperation funcQux = |
2199 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "qux" )); |
2200 | if (!mlirOperationEqual(op: operation, other: funcQux)) |
2201 | return 5; |
2202 | |
2203 | // Erasing from the symbol table also removes the operation. |
2204 | mlirSymbolTableErase(symbolTable, operation: funcBar); |
2205 | MlirOperation nowMissing = |
2206 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "bar" )); |
2207 | if (!mlirOperationIsNull(op: nowMissing)) |
2208 | return 6; |
2209 | |
2210 | // Adding a symbol with the same name to the table should rename. |
2211 | MlirOperation duplicateNameOp = mlirBlockGetFirstOperation(block: otherModuleBody); |
2212 | mlirOperationRemoveFromParent(op: duplicateNameOp); |
2213 | mlirBlockAppendOwnedOperation(block: moduleBody, operation: duplicateNameOp); |
2214 | MlirAttribute newName = mlirSymbolTableInsert(symbolTable, operation: duplicateNameOp); |
2215 | MlirStringRef newNameStr = mlirStringAttrGetValue(attr: newName); |
2216 | if (mlirStringRefEqual(string: newNameStr, other: mlirStringRefCreateFromCString(str: "foo" ))) |
2217 | return 7; |
2218 | MlirAttribute updatedName = mlirOperationGetAttributeByName( |
2219 | op: duplicateNameOp, name: mlirSymbolTableGetSymbolAttributeName()); |
2220 | if (!mlirAttributeEqual(a1: updatedName, a2: newName)) |
2221 | return 8; |
2222 | |
2223 | mlirOperationDump(op: mlirModuleGetOperation(module)); |
2224 | mlirOperationDump(op: mlirModuleGetOperation(module: otherModule)); |
2225 | // clang-format off |
2226 | // CHECK-LABEL: @testSymbolTable |
2227 | // CHECK: module |
2228 | // CHECK: func private @foo |
2229 | // CHECK: func private @qux |
2230 | // CHECK: func private @foo{{.+}} |
2231 | // CHECK: module |
2232 | // CHECK-NOT: @qux |
2233 | // CHECK-NOT: @foo |
2234 | // clang-format on |
2235 | |
2236 | mlirSymbolTableDestroy(symbolTable); |
2237 | mlirModuleDestroy(module); |
2238 | mlirModuleDestroy(module: otherModule); |
2239 | |
2240 | return 0; |
2241 | } |
2242 | |
2243 | typedef struct { |
2244 | const char *x; |
2245 | } callBackData; |
2246 | |
2247 | MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) { |
2248 | fprintf(stderr, format: "%s: %s\n" , ((callBackData *)(rootOpVoid))->x, |
2249 | mlirIdentifierStr(ident: mlirOperationGetName(op)).data); |
2250 | return MlirWalkResultAdvance; |
2251 | } |
2252 | |
2253 | MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) { |
2254 | fprintf(stderr, format: "%s: %s\n" , ((callBackData *)(rootOpVoid))->x, |
2255 | mlirIdentifierStr(ident: mlirOperationGetName(op)).data); |
2256 | if (strcmp(s1: mlirIdentifierStr(ident: mlirOperationGetName(op)).data, s2: "func.func" ) == |
2257 | 0) |
2258 | return MlirWalkResultSkip; |
2259 | if (strcmp(s1: mlirIdentifierStr(ident: mlirOperationGetName(op)).data, s2: "arith.addi" ) == |
2260 | 0) |
2261 | return MlirWalkResultInterrupt; |
2262 | return MlirWalkResultAdvance; |
2263 | } |
2264 | |
2265 | int testOperationWalk(MlirContext ctx) { |
2266 | // CHECK-LABEL: @testOperationWalk |
2267 | fprintf(stderr, format: "@testOperationWalk\n" ); |
2268 | |
2269 | const char *moduleString = "module {\n" |
2270 | " func.func @foo() {\n" |
2271 | " %1 = arith.constant 10: i32\n" |
2272 | " arith.addi %1, %1: i32\n" |
2273 | " return\n" |
2274 | " }\n" |
2275 | " func.func @bar() {\n" |
2276 | " return\n" |
2277 | " }\n" |
2278 | "}" ; |
2279 | MlirModule module = |
2280 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
2281 | |
2282 | callBackData data; |
2283 | data.x = "i love you" ; |
2284 | |
2285 | // CHECK-NEXT: i love you: arith.constant |
2286 | // CHECK-NEXT: i love you: arith.addi |
2287 | // CHECK-NEXT: i love you: func.return |
2288 | // CHECK-NEXT: i love you: func.func |
2289 | // CHECK-NEXT: i love you: func.return |
2290 | // CHECK-NEXT: i love you: func.func |
2291 | // CHECK-NEXT: i love you: builtin.module |
2292 | mlirOperationWalk(op: mlirModuleGetOperation(module), callback: walkCallBack, |
2293 | userData: (void *)(&data), walkOrder: MlirWalkPostOrder); |
2294 | |
2295 | data.x = "i don't love you" ; |
2296 | // CHECK-NEXT: i don't love you: builtin.module |
2297 | // CHECK-NEXT: i don't love you: func.func |
2298 | // CHECK-NEXT: i don't love you: arith.constant |
2299 | // CHECK-NEXT: i don't love you: arith.addi |
2300 | // CHECK-NEXT: i don't love you: func.return |
2301 | // CHECK-NEXT: i don't love you: func.func |
2302 | // CHECK-NEXT: i don't love you: func.return |
2303 | mlirOperationWalk(op: mlirModuleGetOperation(module), callback: walkCallBack, |
2304 | userData: (void *)(&data), walkOrder: MlirWalkPreOrder); |
2305 | |
2306 | data.x = "interrupt" ; |
2307 | // Interrupted at `arith.addi` |
2308 | // CHECK-NEXT: interrupt: arith.constant |
2309 | // CHECK-NEXT: interrupt: arith.addi |
2310 | mlirOperationWalk(op: mlirModuleGetOperation(module), callback: walkCallBackTestWalkResult, |
2311 | userData: (void *)(&data), walkOrder: MlirWalkPostOrder); |
2312 | |
2313 | data.x = "skip" ; |
2314 | // Skip at `func.func` |
2315 | // CHECK-NEXT: skip: builtin.module |
2316 | // CHECK-NEXT: skip: func.func |
2317 | // CHECK-NEXT: skip: func.func |
2318 | mlirOperationWalk(op: mlirModuleGetOperation(module), callback: walkCallBackTestWalkResult, |
2319 | userData: (void *)(&data), walkOrder: MlirWalkPreOrder); |
2320 | |
2321 | mlirModuleDestroy(module); |
2322 | return 0; |
2323 | } |
2324 | |
2325 | int testDialectRegistry(void) { |
2326 | fprintf(stderr, format: "@testDialectRegistry\n" ); |
2327 | |
2328 | MlirDialectRegistry registry = mlirDialectRegistryCreate(); |
2329 | if (mlirDialectRegistryIsNull(registry)) { |
2330 | fprintf(stderr, format: "ERROR: Expected registry to be present\n" ); |
2331 | return 1; |
2332 | } |
2333 | |
2334 | MlirDialectHandle stdHandle = mlirGetDialectHandle__func__(); |
2335 | mlirDialectHandleInsertDialect(stdHandle, registry); |
2336 | |
2337 | MlirContext ctx = mlirContextCreate(); |
2338 | if (mlirContextGetNumRegisteredDialects(context: ctx) != 0) { |
2339 | fprintf(stderr, |
2340 | format: "ERROR: Expected no dialects to be registered to new context\n" ); |
2341 | } |
2342 | |
2343 | mlirContextAppendDialectRegistry(ctx, registry); |
2344 | if (mlirContextGetNumRegisteredDialects(context: ctx) != 1) { |
2345 | fprintf(stderr, format: "ERROR: Expected the dialect in the registry to be " |
2346 | "registered to the context\n" ); |
2347 | } |
2348 | |
2349 | mlirContextDestroy(context: ctx); |
2350 | mlirDialectRegistryDestroy(registry); |
2351 | |
2352 | return 0; |
2353 | } |
2354 | |
2355 | void testExplicitThreadPools(void) { |
2356 | MlirLlvmThreadPool threadPool = mlirLlvmThreadPoolCreate(); |
2357 | MlirDialectRegistry registry = mlirDialectRegistryCreate(); |
2358 | mlirRegisterAllDialects(registry); |
2359 | MlirContext context = |
2360 | mlirContextCreateWithRegistry(registry, /*threadingEnabled=*/false); |
2361 | mlirContextSetThreadPool(context, threadPool); |
2362 | mlirContextDestroy(context); |
2363 | mlirDialectRegistryDestroy(registry); |
2364 | mlirLlvmThreadPoolDestroy(pool: threadPool); |
2365 | } |
2366 | |
2367 | void testDiagnostics(void) { |
2368 | MlirContext ctx = mlirContextCreate(); |
2369 | MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( |
2370 | context: ctx, handler: errorHandler, userData: (void *)42, deleteUserData); |
2371 | fprintf(stderr, format: "@test_diagnostics\n" ); |
2372 | MlirLocation unknownLoc = mlirLocationUnknownGet(context: ctx); |
2373 | mlirEmitError(location: unknownLoc, message: "test diagnostics" ); |
2374 | MlirAttribute unknownAttr = mlirLocationGetAttribute(location: unknownLoc); |
2375 | MlirLocation unknownClone = mlirLocationFromAttribute(attribute: unknownAttr); |
2376 | mlirEmitError(location: unknownClone, message: "test clone" ); |
2377 | MlirLocation fileLineColLoc = mlirLocationFileLineColGet( |
2378 | context: ctx, filename: mlirStringRefCreateFromCString(str: "file.c" ), line: 1, col: 2); |
2379 | mlirEmitError(location: fileLineColLoc, message: "test diagnostics" ); |
2380 | MlirLocation callSiteLoc = mlirLocationCallSiteGet( |
2381 | callee: mlirLocationFileLineColGet( |
2382 | context: ctx, filename: mlirStringRefCreateFromCString(str: "other-file.c" ), line: 2, col: 3), |
2383 | caller: fileLineColLoc); |
2384 | mlirEmitError(location: callSiteLoc, message: "test diagnostics" ); |
2385 | MlirLocation null = {0}; |
2386 | MlirLocation nameLoc = |
2387 | mlirLocationNameGet(context: ctx, name: mlirStringRefCreateFromCString(str: "named" ), childLoc: null); |
2388 | mlirEmitError(location: nameLoc, message: "test diagnostics" ); |
2389 | MlirLocation locs[2] = {nameLoc, callSiteLoc}; |
2390 | MlirAttribute nullAttr = {0}; |
2391 | MlirLocation fusedLoc = mlirLocationFusedGet(ctx, nLocations: 2, locations: locs, metadata: nullAttr); |
2392 | mlirEmitError(location: fusedLoc, message: "test diagnostics" ); |
2393 | mlirContextDetachDiagnosticHandler(context: ctx, id); |
2394 | mlirEmitError(location: unknownLoc, message: "more test diagnostics" ); |
2395 | // CHECK-LABEL: @test_diagnostics |
2396 | // CHECK: processing diagnostic (userData: 42) << |
2397 | // CHECK: test diagnostics |
2398 | // CHECK: loc(unknown) |
2399 | // CHECK: processing diagnostic (userData: 42) << |
2400 | // CHECK: test clone |
2401 | // CHECK: loc(unknown) |
2402 | // CHECK: >> end of diagnostic (userData: 42) |
2403 | // CHECK: processing diagnostic (userData: 42) << |
2404 | // CHECK: test diagnostics |
2405 | // CHECK: loc("file.c":1:2) |
2406 | // CHECK: >> end of diagnostic (userData: 42) |
2407 | // CHECK: processing diagnostic (userData: 42) << |
2408 | // CHECK: test diagnostics |
2409 | // CHECK: loc(callsite("other-file.c":2:3 at "file.c":1:2)) |
2410 | // CHECK: >> end of diagnostic (userData: 42) |
2411 | // CHECK: processing diagnostic (userData: 42) << |
2412 | // CHECK: test diagnostics |
2413 | // CHECK: loc("named") |
2414 | // CHECK: >> end of diagnostic (userData: 42) |
2415 | // CHECK: processing diagnostic (userData: 42) << |
2416 | // CHECK: test diagnostics |
2417 | // CHECK: loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)]) |
2418 | // CHECK: deleting user data (userData: 42) |
2419 | // CHECK-NOT: processing diagnostic |
2420 | // CHECK: more test diagnostics |
2421 | mlirContextDestroy(context: ctx); |
2422 | } |
2423 | |
2424 | int main(void) { |
2425 | MlirContext ctx = mlirContextCreate(); |
2426 | registerAllUpstreamDialects(ctx); |
2427 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "func" )); |
2428 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "memref" )); |
2429 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "shape" )); |
2430 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "scf" )); |
2431 | |
2432 | if (constructAndTraverseIr(ctx)) |
2433 | return 1; |
2434 | buildWithInsertionsAndPrint(ctx); |
2435 | if (createOperationWithTypeInference(ctx)) |
2436 | return 2; |
2437 | |
2438 | if (printBuiltinTypes(ctx)) |
2439 | return 3; |
2440 | if (printBuiltinAttributes(ctx)) |
2441 | return 4; |
2442 | if (printAffineMap(ctx)) |
2443 | return 5; |
2444 | if (printAffineExpr(ctx)) |
2445 | return 6; |
2446 | if (affineMapFromExprs(ctx)) |
2447 | return 7; |
2448 | if (printIntegerSet(ctx)) |
2449 | return 8; |
2450 | if (registerOnlyStd()) |
2451 | return 9; |
2452 | if (testBackreferences()) |
2453 | return 10; |
2454 | if (testOperands()) |
2455 | return 11; |
2456 | if (testClone()) |
2457 | return 12; |
2458 | if (testTypeID(ctx)) |
2459 | return 13; |
2460 | if (testSymbolTable(ctx)) |
2461 | return 14; |
2462 | if (testDialectRegistry()) |
2463 | return 15; |
2464 | if (testOperationWalk(ctx)) |
2465 | return 16; |
2466 | |
2467 | testExplicitThreadPools(); |
2468 | testDiagnostics(); |
2469 | |
2470 | // CHECK: DESTROY MAIN CONTEXT |
2471 | // CHECK: reportResourceDelete: resource_i64_blob |
2472 | fprintf(stderr, format: "DESTROY MAIN CONTEXT\n" ); |
2473 | mlirContextDestroy(context: ctx); |
2474 | |
2475 | return 0; |
2476 | } |
2477 | |