1 | //===- ir.c - Simple test of C APIs ---------------------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
4 | // Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s |
11 | */ |
12 | |
13 | #include "mlir-c/IR.h" |
14 | #include "mlir-c/AffineExpr.h" |
15 | #include "mlir-c/AffineMap.h" |
16 | #include "mlir-c/BuiltinAttributes.h" |
17 | #include "mlir-c/BuiltinTypes.h" |
18 | #include "mlir-c/Diagnostics.h" |
19 | #include "mlir-c/Dialect/Func.h" |
20 | #include "mlir-c/IntegerSet.h" |
21 | #include "mlir-c/RegisterEverything.h" |
22 | #include "mlir-c/Support.h" |
23 | |
24 | #include <assert.h> |
25 | #include <inttypes.h> |
26 | #include <math.h> |
27 | #include <stdio.h> |
28 | #include <stdlib.h> |
29 | #include <string.h> |
30 | |
31 | static void registerAllUpstreamDialects(MlirContext ctx) { |
32 | MlirDialectRegistry registry = mlirDialectRegistryCreate(); |
33 | mlirRegisterAllDialects(registry); |
34 | mlirContextAppendDialectRegistry(ctx, registry); |
35 | mlirDialectRegistryDestroy(registry); |
36 | } |
37 | |
/// Payload that reportResourceDelete expects behind its `userData` pointer.
struct ResourceDeleteUserData {
  const char *name; // Label printed when the resource is reported deleted.
};

/// User data for the "resource_i64_blob" resource.
static struct ResourceDeleteUserData resourceI64BlobUserData = {
    .name = "resource_i64_blob"};
43 | static void reportResourceDelete(void *userData, const void *data, size_t size, |
44 | size_t align) { |
45 | fprintf(stderr, format: "reportResourceDelete: %s\n", |
46 | ((struct ResourceDeleteUserData *)userData)->name); |
47 | } |
48 | |
49 | void populateLoopBody(MlirContext ctx, MlirBlock loopBody, |
50 | MlirLocation location, MlirBlock funcBody) { |
51 | MlirValue iv = mlirBlockGetArgument(block: loopBody, pos: 0); |
52 | MlirValue funcArg0 = mlirBlockGetArgument(block: funcBody, pos: 0); |
53 | MlirValue funcArg1 = mlirBlockGetArgument(block: funcBody, pos: 1); |
54 | MlirType f32Type = |
55 | mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(str: "f32")); |
56 | |
57 | MlirOperationState loadLHSState = mlirOperationStateGet( |
58 | name: mlirStringRefCreateFromCString(str: "memref.load"), loc: location); |
59 | MlirValue loadLHSOperands[] = {funcArg0, iv}; |
60 | mlirOperationStateAddOperands(state: &loadLHSState, n: 2, operands: loadLHSOperands); |
61 | mlirOperationStateAddResults(state: &loadLHSState, n: 1, results: &f32Type); |
62 | MlirOperation loadLHS = mlirOperationCreate(state: &loadLHSState); |
63 | mlirBlockAppendOwnedOperation(block: loopBody, operation: loadLHS); |
64 | |
65 | MlirOperationState loadRHSState = mlirOperationStateGet( |
66 | name: mlirStringRefCreateFromCString(str: "memref.load"), loc: location); |
67 | MlirValue loadRHSOperands[] = {funcArg1, iv}; |
68 | mlirOperationStateAddOperands(state: &loadRHSState, n: 2, operands: loadRHSOperands); |
69 | mlirOperationStateAddResults(state: &loadRHSState, n: 1, results: &f32Type); |
70 | MlirOperation loadRHS = mlirOperationCreate(state: &loadRHSState); |
71 | mlirBlockAppendOwnedOperation(block: loopBody, operation: loadRHS); |
72 | |
73 | MlirOperationState addState = mlirOperationStateGet( |
74 | name: mlirStringRefCreateFromCString(str: "arith.addf"), loc: location); |
75 | MlirValue addOperands[] = {mlirOperationGetResult(op: loadLHS, pos: 0), |
76 | mlirOperationGetResult(op: loadRHS, pos: 0)}; |
77 | mlirOperationStateAddOperands(state: &addState, n: 2, operands: addOperands); |
78 | mlirOperationStateAddResults(state: &addState, n: 1, results: &f32Type); |
79 | MlirOperation add = mlirOperationCreate(state: &addState); |
80 | mlirBlockAppendOwnedOperation(block: loopBody, operation: add); |
81 | |
82 | MlirOperationState storeState = mlirOperationStateGet( |
83 | name: mlirStringRefCreateFromCString(str: "memref.store"), loc: location); |
84 | MlirValue storeOperands[] = {mlirOperationGetResult(op: add, pos: 0), funcArg0, iv}; |
85 | mlirOperationStateAddOperands(state: &storeState, n: 3, operands: storeOperands); |
86 | MlirOperation store = mlirOperationCreate(state: &storeState); |
87 | mlirBlockAppendOwnedOperation(block: loopBody, operation: store); |
88 | |
89 | MlirOperationState yieldState = mlirOperationStateGet( |
90 | name: mlirStringRefCreateFromCString(str: "scf.yield"), loc: location); |
91 | MlirOperation yield = mlirOperationCreate(state: &yieldState); |
92 | mlirBlockAppendOwnedOperation(block: loopBody, operation: yield); |
93 | } |
94 | |
95 | MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) { |
96 | MlirModule moduleOp = mlirModuleCreateEmpty(location); |
97 | MlirBlock moduleBody = mlirModuleGetBody(module: moduleOp); |
98 | |
99 | MlirType memrefType = |
100 | mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(str: "memref<?xf32>")); |
101 | MlirType funcBodyArgTypes[] = {memrefType, memrefType}; |
102 | MlirLocation funcBodyArgLocs[] = {location, location}; |
103 | MlirRegion funcBodyRegion = mlirRegionCreate(); |
104 | MlirBlock funcBody = |
105 | mlirBlockCreate(nArgs: sizeof(funcBodyArgTypes) / sizeof(MlirType), |
106 | args: funcBodyArgTypes, locs: funcBodyArgLocs); |
107 | mlirRegionAppendOwnedBlock(region: funcBodyRegion, block: funcBody); |
108 | |
109 | MlirAttribute funcTypeAttr = mlirAttributeParseGet( |
110 | context: ctx, |
111 | attr: mlirStringRefCreateFromCString(str: "(memref<?xf32>, memref<?xf32>) -> ()")); |
112 | MlirAttribute funcNameAttr = |
113 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "\"add\"")); |
114 | MlirNamedAttribute funcAttrs[] = { |
115 | mlirNamedAttributeGet( |
116 | name: mlirIdentifierGet(context: ctx, |
117 | str: mlirStringRefCreateFromCString(str: "function_type")), |
118 | attr: funcTypeAttr), |
119 | mlirNamedAttributeGet( |
120 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "sym_name")), |
121 | attr: funcNameAttr)}; |
122 | MlirOperationState funcState = mlirOperationStateGet( |
123 | name: mlirStringRefCreateFromCString(str: "func.func"), loc: location); |
124 | mlirOperationStateAddAttributes(state: &funcState, n: 2, attributes: funcAttrs); |
125 | mlirOperationStateAddOwnedRegions(state: &funcState, n: 1, regions: &funcBodyRegion); |
126 | MlirOperation func = mlirOperationCreate(state: &funcState); |
127 | mlirBlockInsertOwnedOperation(block: moduleBody, pos: 0, operation: func); |
128 | |
129 | MlirType indexType = |
130 | mlirTypeParseGet(context: ctx, type: mlirStringRefCreateFromCString(str: "index")); |
131 | MlirAttribute indexZeroLiteral = |
132 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "0 : index")); |
133 | MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( |
134 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value")), |
135 | attr: indexZeroLiteral); |
136 | MlirOperationState constZeroState = mlirOperationStateGet( |
137 | name: mlirStringRefCreateFromCString(str: "arith.constant"), loc: location); |
138 | mlirOperationStateAddResults(state: &constZeroState, n: 1, results: &indexType); |
139 | mlirOperationStateAddAttributes(state: &constZeroState, n: 1, attributes: &indexZeroValueAttr); |
140 | MlirOperation constZero = mlirOperationCreate(state: &constZeroState); |
141 | mlirBlockAppendOwnedOperation(block: funcBody, operation: constZero); |
142 | |
143 | MlirValue funcArg0 = mlirBlockGetArgument(block: funcBody, pos: 0); |
144 | MlirValue constZeroValue = mlirOperationGetResult(op: constZero, pos: 0); |
145 | MlirValue dimOperands[] = {funcArg0, constZeroValue}; |
146 | MlirOperationState dimState = mlirOperationStateGet( |
147 | name: mlirStringRefCreateFromCString(str: "memref.dim"), loc: location); |
148 | mlirOperationStateAddOperands(state: &dimState, n: 2, operands: dimOperands); |
149 | mlirOperationStateAddResults(state: &dimState, n: 1, results: &indexType); |
150 | MlirOperation dim = mlirOperationCreate(state: &dimState); |
151 | mlirBlockAppendOwnedOperation(block: funcBody, operation: dim); |
152 | |
153 | MlirRegion loopBodyRegion = mlirRegionCreate(); |
154 | MlirBlock loopBody = mlirBlockCreate(nArgs: 0, NULL, NULL); |
155 | mlirBlockAddArgument(block: loopBody, type: indexType, loc: location); |
156 | mlirRegionAppendOwnedBlock(region: loopBodyRegion, block: loopBody); |
157 | |
158 | MlirAttribute indexOneLiteral = |
159 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "1 : index")); |
160 | MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet( |
161 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value")), |
162 | attr: indexOneLiteral); |
163 | MlirOperationState constOneState = mlirOperationStateGet( |
164 | name: mlirStringRefCreateFromCString(str: "arith.constant"), loc: location); |
165 | mlirOperationStateAddResults(state: &constOneState, n: 1, results: &indexType); |
166 | mlirOperationStateAddAttributes(state: &constOneState, n: 1, attributes: &indexOneValueAttr); |
167 | MlirOperation constOne = mlirOperationCreate(state: &constOneState); |
168 | mlirBlockAppendOwnedOperation(block: funcBody, operation: constOne); |
169 | |
170 | MlirValue dimValue = mlirOperationGetResult(op: dim, pos: 0); |
171 | MlirValue constOneValue = mlirOperationGetResult(op: constOne, pos: 0); |
172 | MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue}; |
173 | MlirOperationState loopState = mlirOperationStateGet( |
174 | name: mlirStringRefCreateFromCString(str: "scf.for"), loc: location); |
175 | mlirOperationStateAddOperands(state: &loopState, n: 3, operands: loopOperands); |
176 | mlirOperationStateAddOwnedRegions(state: &loopState, n: 1, regions: &loopBodyRegion); |
177 | MlirOperation loop = mlirOperationCreate(state: &loopState); |
178 | mlirBlockAppendOwnedOperation(block: funcBody, operation: loop); |
179 | |
180 | populateLoopBody(ctx, loopBody, location, funcBody); |
181 | |
182 | MlirOperationState retState = mlirOperationStateGet( |
183 | name: mlirStringRefCreateFromCString(str: "func.return"), loc: location); |
184 | MlirOperation ret = mlirOperationCreate(state: &retState); |
185 | mlirBlockAppendOwnedOperation(block: funcBody, operation: ret); |
186 | |
187 | MlirOperation module = mlirModuleGetOperation(module: moduleOp); |
188 | mlirOperationDump(op: module); |
189 | // clang-format off |
190 | // CHECK: module { |
191 | // CHECK: func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) { |
192 | // CHECK: %[[C0:.*]] = arith.constant 0 : index |
193 | // CHECK: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32> |
194 | // CHECK: %[[C1:.*]] = arith.constant 1 : index |
195 | // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] { |
196 | // CHECK: %[[LHS:.*]] = memref.load %[[ARG0]][%[[I]]] : memref<?xf32> |
197 | // CHECK: %[[RHS:.*]] = memref.load %[[ARG1]][%[[I]]] : memref<?xf32> |
198 | // CHECK: %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 |
199 | // CHECK: memref.store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32> |
200 | // CHECK: } |
201 | // CHECK: return |
202 | // CHECK: } |
203 | // CHECK: } |
204 | // clang-format on |
205 | |
206 | return moduleOp; |
207 | } |
208 | |
209 | struct OpListNode { |
210 | MlirOperation op; |
211 | struct OpListNode *next; |
212 | }; |
213 | typedef struct OpListNode OpListNode; |
214 | |
/// Counters accumulated while walking a module in collectStats.
typedef struct ModuleStats {
  unsigned numOperations;     // Operations visited, including the root.
  unsigned numAttributes;     // Attributes attached to those operations.
  unsigned numBlocks;         // Blocks in all visited regions.
  unsigned numRegions;        // Regions of all visited operations.
  unsigned numValues;         // Op results plus block arguments.
  unsigned numBlockArguments; // Block arguments that passed the checks.
  unsigned numOpResults;      // Op results that passed the checks.
} ModuleStats;
225 | |
226 | int collectStatsSingle(OpListNode *head, ModuleStats *stats) { |
227 | MlirOperation operation = head->op; |
228 | stats->numOperations += 1; |
229 | stats->numValues += mlirOperationGetNumResults(op: operation); |
230 | stats->numAttributes += mlirOperationGetNumAttributes(op: operation); |
231 | |
232 | unsigned numRegions = mlirOperationGetNumRegions(op: operation); |
233 | |
234 | stats->numRegions += numRegions; |
235 | |
236 | intptr_t numResults = mlirOperationGetNumResults(op: operation); |
237 | for (intptr_t i = 0; i < numResults; ++i) { |
238 | MlirValue result = mlirOperationGetResult(op: operation, pos: i); |
239 | if (!mlirValueIsAOpResult(value: result)) |
240 | return 1; |
241 | if (mlirValueIsABlockArgument(value: result)) |
242 | return 2; |
243 | if (!mlirOperationEqual(op: operation, other: mlirOpResultGetOwner(value: result))) |
244 | return 3; |
245 | if (i != mlirOpResultGetResultNumber(value: result)) |
246 | return 4; |
247 | ++stats->numOpResults; |
248 | } |
249 | |
250 | MlirRegion region = mlirOperationGetFirstRegion(op: operation); |
251 | while (!mlirRegionIsNull(region)) { |
252 | for (MlirBlock block = mlirRegionGetFirstBlock(region); |
253 | !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) { |
254 | ++stats->numBlocks; |
255 | intptr_t numArgs = mlirBlockGetNumArguments(block); |
256 | stats->numValues += numArgs; |
257 | for (intptr_t j = 0; j < numArgs; ++j) { |
258 | MlirValue arg = mlirBlockGetArgument(block, pos: j); |
259 | if (!mlirValueIsABlockArgument(value: arg)) |
260 | return 5; |
261 | if (mlirValueIsAOpResult(value: arg)) |
262 | return 6; |
263 | if (!mlirBlockEqual(block, other: mlirBlockArgumentGetOwner(value: arg))) |
264 | return 7; |
265 | if (j != mlirBlockArgumentGetArgNumber(value: arg)) |
266 | return 8; |
267 | ++stats->numBlockArguments; |
268 | } |
269 | |
270 | for (MlirOperation child = mlirBlockGetFirstOperation(block); |
271 | !mlirOperationIsNull(op: child); |
272 | child = mlirOperationGetNextInBlock(op: child)) { |
273 | OpListNode *node = malloc(size: sizeof(OpListNode)); |
274 | node->op = child; |
275 | node->next = head->next; |
276 | head->next = node; |
277 | } |
278 | } |
279 | region = mlirRegionGetNextInOperation(region); |
280 | } |
281 | return 0; |
282 | } |
283 | |
284 | int collectStats(MlirOperation operation) { |
285 | OpListNode *head = malloc(size: sizeof(OpListNode)); |
286 | head->op = operation; |
287 | head->next = NULL; |
288 | |
289 | ModuleStats stats; |
290 | stats.numOperations = 0; |
291 | stats.numAttributes = 0; |
292 | stats.numBlocks = 0; |
293 | stats.numRegions = 0; |
294 | stats.numValues = 0; |
295 | stats.numBlockArguments = 0; |
296 | stats.numOpResults = 0; |
297 | |
298 | do { |
299 | int retval = collectStatsSingle(head, stats: &stats); |
300 | if (retval) { |
301 | free(ptr: head); |
302 | return retval; |
303 | } |
304 | OpListNode *next = head->next; |
305 | free(ptr: head); |
306 | head = next; |
307 | } while (head); |
308 | |
309 | if (stats.numValues != stats.numBlockArguments + stats.numOpResults) |
310 | return 100; |
311 | |
312 | fprintf(stderr, format: "@stats\n"); |
313 | fprintf(stderr, format: "Number of operations: %u\n", stats.numOperations); |
314 | fprintf(stderr, format: "Number of attributes: %u\n", stats.numAttributes); |
315 | fprintf(stderr, format: "Number of blocks: %u\n", stats.numBlocks); |
316 | fprintf(stderr, format: "Number of regions: %u\n", stats.numRegions); |
317 | fprintf(stderr, format: "Number of values: %u\n", stats.numValues); |
318 | fprintf(stderr, format: "Number of block arguments: %u\n", stats.numBlockArguments); |
319 | fprintf(stderr, format: "Number of op results: %u\n", stats.numOpResults); |
320 | // clang-format off |
321 | // CHECK-LABEL: @stats |
322 | // CHECK: Number of operations: 12 |
323 | // CHECK: Number of attributes: 5 |
324 | // CHECK: Number of blocks: 3 |
325 | // CHECK: Number of regions: 3 |
326 | // CHECK: Number of values: 9 |
327 | // CHECK: Number of block arguments: 3 |
328 | // CHECK: Number of op results: 6 |
329 | // clang-format on |
330 | return 0; |
331 | } |
332 | |
333 | static void printToStderr(MlirStringRef str, void *userData) { |
334 | (void)userData; |
335 | fwrite(ptr: str.data, size: 1, n: str.length, stderr); |
336 | } |
337 | |
338 | static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { |
339 | // Assuming we are given a module, go to the first operation of the first |
340 | // function. |
341 | MlirRegion region = mlirOperationGetRegion(op: operation, pos: 0); |
342 | MlirBlock block = mlirRegionGetFirstBlock(region); |
343 | MlirOperation function = mlirBlockGetFirstOperation(block); |
344 | region = mlirOperationGetRegion(op: function, pos: 0); |
345 | MlirOperation parentOperation = function; |
346 | block = mlirRegionGetFirstBlock(region); |
347 | operation = mlirBlockGetFirstOperation(block); |
348 | assert(mlirModuleIsNull(mlirModuleFromOperation(operation))); |
349 | |
350 | // Verify that parent operation and block report correctly. |
351 | // CHECK: Parent operation eq: 1 |
352 | fprintf(stderr, format: "Parent operation eq: %d\n", |
353 | mlirOperationEqual(op: mlirOperationGetParentOperation(op: operation), |
354 | other: parentOperation)); |
355 | // CHECK: Block eq: 1 |
356 | fprintf(stderr, format: "Block eq: %d\n", |
357 | mlirBlockEqual(block: mlirOperationGetBlock(op: operation), other: block)); |
358 | // CHECK: Block parent operation eq: 1 |
359 | fprintf( |
360 | stderr, format: "Block parent operation eq: %d\n", |
361 | mlirOperationEqual(op: mlirBlockGetParentOperation(block), other: parentOperation)); |
362 | // CHECK: Block parent region eq: 1 |
363 | fprintf(stderr, format: "Block parent region eq: %d\n", |
364 | mlirRegionEqual(region: mlirBlockGetParentRegion(block), other: region)); |
365 | |
366 | // In the module we created, the first operation of the first function is |
367 | // an "memref.dim", which has an attribute and a single result that we can |
368 | // use to test the printing mechanism. |
369 | mlirBlockPrint(block, callback: printToStderr, NULL); |
370 | fprintf(stderr, format: "\n"); |
371 | fprintf(stderr, format: "First operation: "); |
372 | mlirOperationPrint(op: operation, callback: printToStderr, NULL); |
373 | fprintf(stderr, format: "\n"); |
374 | // clang-format off |
375 | // CHECK: %[[C0:.*]] = arith.constant 0 : index |
376 | // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32> |
377 | // CHECK: %[[C1:.*]] = arith.constant 1 : index |
378 | // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] { |
379 | // CHECK: %[[LHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32> |
380 | // CHECK: %[[RHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32> |
381 | // CHECK: %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 |
382 | // CHECK: memref.store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32> |
383 | // CHECK: } |
384 | // CHECK: return |
385 | // CHECK: First operation: {{.*}} = arith.constant 0 : index |
386 | // clang-format on |
387 | |
388 | // Get the operation name and print it. |
389 | MlirIdentifier ident = mlirOperationGetName(op: operation); |
390 | MlirStringRef identStr = mlirIdentifierStr(ident); |
391 | fprintf(stderr, format: "Operation name: '"); |
392 | for (size_t i = 0; i < identStr.length; ++i) |
393 | fputc(c: identStr.data[i], stderr); |
394 | fprintf(stderr, format: "'\n"); |
395 | // CHECK: Operation name: 'arith.constant' |
396 | |
397 | // Get the identifier again and verify equal. |
398 | MlirIdentifier identAgain = mlirIdentifierGet(context: ctx, str: identStr); |
399 | fprintf(stderr, format: "Identifier equal: %d\n", |
400 | mlirIdentifierEqual(ident, other: identAgain)); |
401 | // CHECK: Identifier equal: 1 |
402 | |
403 | // Get the block terminator and print it. |
404 | MlirOperation terminator = mlirBlockGetTerminator(block); |
405 | fprintf(stderr, format: "Terminator: "); |
406 | mlirOperationPrint(op: terminator, callback: printToStderr, NULL); |
407 | fprintf(stderr, format: "\n"); |
408 | // CHECK: Terminator: func.return |
409 | |
410 | // Get the attribute by name. |
411 | bool hasValueAttr = mlirOperationHasInherentAttributeByName( |
412 | op: operation, name: mlirStringRefCreateFromCString(str: "value")); |
413 | if (hasValueAttr) |
414 | // CHECK: Has attr "value" |
415 | fprintf(stderr, format: "Has attr \"value\""); |
416 | |
417 | MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName( |
418 | op: operation, name: mlirStringRefCreateFromCString(str: "value")); |
419 | fprintf(stderr, format: "Get attr \"value\": "); |
420 | mlirAttributePrint(attr: valueAttr0, callback: printToStderr, NULL); |
421 | fprintf(stderr, format: "\n"); |
422 | // CHECK: Get attr "value": 0 : index |
423 | |
424 | // Get a non-existing attribute and assert that it is null (sanity). |
425 | fprintf(stderr, format: "does_not_exist is null: %d\n", |
426 | mlirAttributeIsNull(attr: mlirOperationGetDiscardableAttributeByName( |
427 | op: operation, name: mlirStringRefCreateFromCString(str: "does_not_exist")))); |
428 | // CHECK: does_not_exist is null: 1 |
429 | |
430 | // Get result 0 and its type. |
431 | MlirValue value = mlirOperationGetResult(op: operation, pos: 0); |
432 | fprintf(stderr, format: "Result 0: "); |
433 | mlirValuePrint(value, callback: printToStderr, NULL); |
434 | fprintf(stderr, format: "\n"); |
435 | fprintf(stderr, format: "Value is null: %d\n", mlirValueIsNull(value)); |
436 | // CHECK: Result 0: {{.*}} = arith.constant 0 : index |
437 | // CHECK: Value is null: 0 |
438 | |
439 | MlirType type = mlirValueGetType(value); |
440 | fprintf(stderr, format: "Result 0 type: "); |
441 | mlirTypePrint(type, callback: printToStderr, NULL); |
442 | fprintf(stderr, format: "\n"); |
443 | // CHECK: Result 0 type: index |
444 | |
445 | // Set a discardable attribute. |
446 | mlirOperationSetDiscardableAttributeByName( |
447 | op: operation, name: mlirStringRefCreateFromCString(str: "custom_attr"), |
448 | attr: mlirBoolAttrGet(ctx, value: 1)); |
449 | fprintf(stderr, format: "Op with set attr: "); |
450 | mlirOperationPrint(op: operation, callback: printToStderr, NULL); |
451 | fprintf(stderr, format: "\n"); |
452 | // CHECK: Op with set attr: {{.*}} {custom_attr = true} |
453 | |
454 | // Remove the attribute. |
455 | fprintf(stderr, format: "Remove attr: %d\n", |
456 | mlirOperationRemoveDiscardableAttributeByName( |
457 | op: operation, name: mlirStringRefCreateFromCString(str: "custom_attr"))); |
458 | fprintf(stderr, format: "Remove attr again: %d\n", |
459 | mlirOperationRemoveDiscardableAttributeByName( |
460 | op: operation, name: mlirStringRefCreateFromCString(str: "custom_attr"))); |
461 | fprintf(stderr, format: "Removed attr is null: %d\n", |
462 | mlirAttributeIsNull(attr: mlirOperationGetDiscardableAttributeByName( |
463 | op: operation, name: mlirStringRefCreateFromCString(str: "custom_attr")))); |
464 | // CHECK: Remove attr: 1 |
465 | // CHECK: Remove attr again: 0 |
466 | // CHECK: Removed attr is null: 1 |
467 | |
468 | // Add a large attribute to verify printing flags. |
469 | int64_t eltsShape[] = {4}; |
470 | int32_t eltsData[] = {1, 2, 3, 4}; |
471 | mlirOperationSetDiscardableAttributeByName( |
472 | op: operation, name: mlirStringRefCreateFromCString(str: "elts"), |
473 | attr: mlirDenseElementsAttrInt32Get( |
474 | shapedType: mlirRankedTensorTypeGet(rank: 1, shape: eltsShape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 32), |
475 | encoding: mlirAttributeGetNull()), |
476 | numElements: 4, elements: eltsData)); |
477 | MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); |
478 | mlirOpPrintingFlagsElideLargeElementsAttrs(flags, largeElementLimit: 2); |
479 | mlirOpPrintingFlagsPrintGenericOpForm(flags); |
480 | mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/1, /*prettyForm=*/0); |
481 | mlirOpPrintingFlagsUseLocalScope(flags); |
482 | fprintf(stderr, format: "Op print with all flags: "); |
483 | mlirOperationPrintWithFlags(op: operation, flags, callback: printToStderr, NULL); |
484 | fprintf(stderr, format: "\n"); |
485 | fprintf(stderr, format: "Op print with state: "); |
486 | MlirAsmState state = mlirAsmStateCreateForOperation(op: parentOperation, flags); |
487 | mlirOperationPrintWithState(op: operation, state, callback: printToStderr, NULL); |
488 | fprintf(stderr, format: "\n"); |
489 | // clang-format off |
490 | // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown) |
491 | // clang-format on |
492 | |
493 | mlirOpPrintingFlagsDestroy(flags); |
494 | flags = mlirOpPrintingFlagsCreate(); |
495 | mlirOpPrintingFlagsSkipRegions(flags); |
496 | fprintf(stderr, format: "Op print with skip regions flag: "); |
497 | mlirOperationPrintWithFlags(op: function, flags, callback: printToStderr, NULL); |
498 | fprintf(stderr, format: "\n"); |
499 | // clang-format off |
500 | // CHECK: Op print with skip regions flag: func.func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) |
501 | // CHECK-NOT: constant |
502 | // CHECK-NOT: return |
503 | // clang-format on |
504 | |
505 | fprintf(stderr, format: "With state: |"); |
506 | mlirValuePrintAsOperand(value, state, callback: printToStderr, NULL); |
507 | // CHECK: With state: |%0| |
508 | fprintf(stderr, format: "|\n"); |
509 | mlirAsmStateDestroy(state); |
510 | |
511 | mlirOpPrintingFlagsDestroy(flags); |
512 | } |
513 | |
514 | static int constructAndTraverseIr(MlirContext ctx) { |
515 | MlirLocation location = mlirLocationUnknownGet(context: ctx); |
516 | |
517 | MlirModule moduleOp = makeAndDumpAdd(ctx, location); |
518 | MlirOperation module = mlirModuleGetOperation(module: moduleOp); |
519 | assert(!mlirModuleIsNull(mlirModuleFromOperation(module))); |
520 | |
521 | int errcode = collectStats(operation: module); |
522 | if (errcode) |
523 | return errcode; |
524 | |
525 | printFirstOfEach(ctx, operation: module); |
526 | |
527 | mlirModuleDestroy(module: moduleOp); |
528 | return 0; |
529 | } |
530 | |
531 | /// Creates an operation with a region containing multiple blocks with |
532 | /// operations and dumps it. The blocks and operations are inserted using |
533 | /// block/operation-relative API and their final order is checked. |
534 | static void buildWithInsertionsAndPrint(MlirContext ctx) { |
535 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
536 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
537 | |
538 | MlirRegion owningRegion = mlirRegionCreate(); |
539 | MlirBlock nullBlock = mlirRegionGetFirstBlock(region: owningRegion); |
540 | MlirOperationState state = mlirOperationStateGet( |
541 | name: mlirStringRefCreateFromCString(str: "insertion.order.test"), loc); |
542 | mlirOperationStateAddOwnedRegions(state: &state, n: 1, regions: &owningRegion); |
543 | MlirOperation op = mlirOperationCreate(state: &state); |
544 | MlirRegion region = mlirOperationGetRegion(op, pos: 0); |
545 | |
546 | // Use integer types of different bitwidth as block arguments in order to |
547 | // differentiate blocks. |
548 | MlirType i1 = mlirIntegerTypeGet(ctx, bitwidth: 1); |
549 | MlirType i2 = mlirIntegerTypeGet(ctx, bitwidth: 2); |
550 | MlirType i3 = mlirIntegerTypeGet(ctx, bitwidth: 3); |
551 | MlirType i4 = mlirIntegerTypeGet(ctx, bitwidth: 4); |
552 | MlirType i5 = mlirIntegerTypeGet(ctx, bitwidth: 5); |
553 | MlirBlock block1 = mlirBlockCreate(nArgs: 1, args: &i1, locs: &loc); |
554 | MlirBlock block2 = mlirBlockCreate(nArgs: 1, args: &i2, locs: &loc); |
555 | MlirBlock block3 = mlirBlockCreate(nArgs: 1, args: &i3, locs: &loc); |
556 | MlirBlock block4 = mlirBlockCreate(nArgs: 1, args: &i4, locs: &loc); |
557 | MlirBlock block5 = mlirBlockCreate(nArgs: 1, args: &i5, locs: &loc); |
558 | // Insert blocks so as to obtain the 1-2-3-4 order, |
559 | mlirRegionInsertOwnedBlockBefore(region, reference: nullBlock, block: block3); |
560 | mlirRegionInsertOwnedBlockBefore(region, reference: block3, block: block2); |
561 | mlirRegionInsertOwnedBlockAfter(region, reference: nullBlock, block: block1); |
562 | mlirRegionInsertOwnedBlockAfter(region, reference: block3, block: block4); |
563 | mlirRegionInsertOwnedBlockBefore(region, reference: block3, block: block5); |
564 | |
565 | MlirOperationState op1State = |
566 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op1"), loc); |
567 | MlirOperationState op2State = |
568 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op2"), loc); |
569 | MlirOperationState op3State = |
570 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op3"), loc); |
571 | MlirOperationState op4State = |
572 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op4"), loc); |
573 | MlirOperationState op5State = |
574 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op5"), loc); |
575 | MlirOperationState op6State = |
576 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op6"), loc); |
577 | MlirOperationState op7State = |
578 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op7"), loc); |
579 | MlirOperationState op8State = |
580 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op8"), loc); |
581 | MlirOperation op1 = mlirOperationCreate(state: &op1State); |
582 | MlirOperation op2 = mlirOperationCreate(state: &op2State); |
583 | MlirOperation op3 = mlirOperationCreate(state: &op3State); |
584 | MlirOperation op4 = mlirOperationCreate(state: &op4State); |
585 | MlirOperation op5 = mlirOperationCreate(state: &op5State); |
586 | MlirOperation op6 = mlirOperationCreate(state: &op6State); |
587 | MlirOperation op7 = mlirOperationCreate(state: &op7State); |
588 | MlirOperation op8 = mlirOperationCreate(state: &op8State); |
589 | |
590 | // Insert operations in the first block so as to obtain the 1-2-3-4 order. |
591 | MlirOperation nullOperation = mlirBlockGetFirstOperation(block: block1); |
592 | assert(mlirOperationIsNull(nullOperation)); |
593 | mlirBlockInsertOwnedOperationBefore(block: block1, reference: nullOperation, operation: op3); |
594 | mlirBlockInsertOwnedOperationBefore(block: block1, reference: op3, operation: op2); |
595 | mlirBlockInsertOwnedOperationAfter(block: block1, reference: nullOperation, operation: op1); |
596 | mlirBlockInsertOwnedOperationAfter(block: block1, reference: op3, operation: op4); |
597 | |
598 | // Append operations to the rest of blocks to make them non-empty and thus |
599 | // printable. |
600 | mlirBlockAppendOwnedOperation(block: block2, operation: op5); |
601 | mlirBlockAppendOwnedOperation(block: block3, operation: op6); |
602 | mlirBlockAppendOwnedOperation(block: block4, operation: op7); |
603 | mlirBlockAppendOwnedOperation(block: block5, operation: op8); |
604 | |
605 | // Remove block5. |
606 | mlirBlockDetach(block: block5); |
607 | mlirBlockDestroy(block: block5); |
608 | |
609 | mlirOperationDump(op); |
610 | mlirOperationDestroy(op); |
611 | mlirContextSetAllowUnregisteredDialects(context: ctx, false); |
612 | // clang-format off |
613 | // CHECK-LABEL: "insertion.order.test" |
614 | // CHECK: ^{{.*}}(%{{.*}}: i1 |
615 | // CHECK: "dummy.op1" |
616 | // CHECK-NEXT: "dummy.op2" |
617 | // CHECK-NEXT: "dummy.op3" |
618 | // CHECK-NEXT: "dummy.op4" |
619 | // CHECK: ^{{.*}}(%{{.*}}: i2 |
620 | // CHECK: "dummy.op5" |
621 | // CHECK-NOT: ^{{.*}}(%{{.*}}: i5 |
622 | // CHECK-NOT: "dummy.op8" |
623 | // CHECK: ^{{.*}}(%{{.*}}: i3 |
624 | // CHECK: "dummy.op6" |
625 | // CHECK: ^{{.*}}(%{{.*}}: i4 |
626 | // CHECK: "dummy.op7" |
627 | // clang-format on |
628 | } |
629 | |
630 | /// Creates operations with type inference and tests various failure modes. |
631 | static int createOperationWithTypeInference(MlirContext ctx) { |
632 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
633 | MlirAttribute iAttr = mlirIntegerAttrGet(type: mlirIntegerTypeGet(ctx, bitwidth: 32), value: 4); |
634 | |
635 | // The shape.const_size op implements result type inference and is only used |
636 | // for that reason. |
637 | MlirOperationState state = mlirOperationStateGet( |
638 | name: mlirStringRefCreateFromCString(str: "shape.const_size"), loc); |
639 | MlirNamedAttribute valueAttr = mlirNamedAttributeGet( |
640 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value")), attr: iAttr); |
641 | mlirOperationStateAddAttributes(state: &state, n: 1, attributes: &valueAttr); |
642 | mlirOperationStateEnableResultTypeInference(state: &state); |
643 | |
644 | // Expect result type inference to succeed. |
645 | MlirOperation op = mlirOperationCreate(state: &state); |
646 | if (mlirOperationIsNull(op)) { |
647 | fprintf(stderr, format: "ERROR: Result type inference unexpectedly failed"); |
648 | return 1; |
649 | } |
650 | |
651 | // CHECK: RESULT_TYPE_INFERENCE: !shape.size |
652 | fprintf(stderr, format: "RESULT_TYPE_INFERENCE: "); |
653 | mlirTypeDump(type: mlirValueGetType(value: mlirOperationGetResult(op, pos: 0))); |
654 | fprintf(stderr, format: "\n"); |
655 | mlirOperationDestroy(op); |
656 | return 0; |
657 | } |
658 | |
659 | /// Dumps instances of all builtin types to check that C API works correctly. |
660 | /// Additionally, performs simple identity checks that a builtin type |
661 | /// constructed with C API can be inspected and has the expected type. The |
662 | /// latter achieves full coverage of C API for builtin types. Returns 0 on |
663 | /// success and a non-zero error code on failure. |
664 | static int printBuiltinTypes(MlirContext ctx) { |
665 | // Integer types. |
666 | MlirType i32 = mlirIntegerTypeGet(ctx, bitwidth: 32); |
667 | MlirType si32 = mlirIntegerTypeSignedGet(ctx, bitwidth: 32); |
668 | MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32); |
669 | if (!mlirTypeIsAInteger(type: i32) || mlirTypeIsAF32(type: i32)) |
670 | return 1; |
671 | if (!mlirTypeIsAInteger(type: si32) || !mlirIntegerTypeIsSigned(type: si32)) |
672 | return 2; |
673 | if (!mlirTypeIsAInteger(type: ui32) || !mlirIntegerTypeIsUnsigned(type: ui32)) |
674 | return 3; |
675 | if (mlirTypeEqual(t1: i32, t2: ui32) || mlirTypeEqual(t1: i32, t2: si32)) |
676 | return 4; |
677 | if (mlirIntegerTypeGetWidth(type: i32) != mlirIntegerTypeGetWidth(type: si32)) |
678 | return 5; |
679 | fprintf(stderr, format: "@types\n"); |
680 | mlirTypeDump(type: i32); |
681 | fprintf(stderr, format: "\n"); |
682 | mlirTypeDump(type: si32); |
683 | fprintf(stderr, format: "\n"); |
684 | mlirTypeDump(type: ui32); |
685 | fprintf(stderr, format: "\n"); |
686 | // CHECK-LABEL: @types |
687 | // CHECK: i32 |
688 | // CHECK: si32 |
689 | // CHECK: ui32 |
690 | |
691 | // Index type. |
692 | MlirType index = mlirIndexTypeGet(ctx); |
693 | if (!mlirTypeIsAIndex(type: index)) |
694 | return 6; |
695 | mlirTypeDump(type: index); |
696 | fprintf(stderr, format: "\n"); |
697 | // CHECK: index |
698 | |
699 | // Floating-point types. |
700 | MlirType bf16 = mlirBF16TypeGet(ctx); |
701 | MlirType f16 = mlirF16TypeGet(ctx); |
702 | MlirType f32 = mlirF32TypeGet(ctx); |
703 | MlirType f64 = mlirF64TypeGet(ctx); |
704 | if (!mlirTypeIsABF16(type: bf16)) |
705 | return 7; |
706 | if (!mlirTypeIsAF16(type: f16)) |
707 | return 9; |
708 | if (!mlirTypeIsAF32(type: f32)) |
709 | return 10; |
710 | if (!mlirTypeIsAF64(type: f64)) |
711 | return 11; |
712 | mlirTypeDump(type: bf16); |
713 | fprintf(stderr, format: "\n"); |
714 | mlirTypeDump(type: f16); |
715 | fprintf(stderr, format: "\n"); |
716 | mlirTypeDump(type: f32); |
717 | fprintf(stderr, format: "\n"); |
718 | mlirTypeDump(type: f64); |
719 | fprintf(stderr, format: "\n"); |
720 | // CHECK: bf16 |
721 | // CHECK: f16 |
722 | // CHECK: f32 |
723 | // CHECK: f64 |
724 | |
725 | // None type. |
726 | MlirType none = mlirNoneTypeGet(ctx); |
727 | if (!mlirTypeIsANone(type: none)) |
728 | return 12; |
729 | mlirTypeDump(type: none); |
730 | fprintf(stderr, format: "\n"); |
731 | // CHECK: none |
732 | |
733 | // Complex type. |
734 | MlirType cplx = mlirComplexTypeGet(elementType: f32); |
735 | if (!mlirTypeIsAComplex(type: cplx) || |
736 | !mlirTypeEqual(t1: mlirComplexTypeGetElementType(type: cplx), t2: f32)) |
737 | return 13; |
738 | mlirTypeDump(type: cplx); |
739 | fprintf(stderr, format: "\n"); |
740 | // CHECK: complex<f32> |
741 | |
742 | // Vector (and Shaped) type. ShapedType is a common base class for vectors, |
743 | // memrefs and tensors, one cannot create instances of this class so it is |
744 | // tested on an instance of vector type. |
745 | int64_t shape[] = {2, 3}; |
746 | MlirType vector = |
747 | mlirVectorTypeGet(rank: sizeof(shape) / sizeof(int64_t), shape, elementType: f32); |
748 | if (!mlirTypeIsAVector(type: vector) || !mlirTypeIsAShaped(type: vector)) |
749 | return 14; |
750 | if (!mlirTypeEqual(t1: mlirShapedTypeGetElementType(type: vector), t2: f32) || |
751 | !mlirShapedTypeHasRank(type: vector) || mlirShapedTypeGetRank(type: vector) != 2 || |
752 | mlirShapedTypeGetDimSize(type: vector, dim: 0) != 2 || |
753 | mlirShapedTypeIsDynamicDim(type: vector, dim: 0) || |
754 | mlirShapedTypeGetDimSize(type: vector, dim: 1) != 3 || |
755 | !mlirShapedTypeHasStaticShape(type: vector)) |
756 | return 15; |
757 | mlirTypeDump(type: vector); |
758 | fprintf(stderr, format: "\n"); |
759 | // CHECK: vector<2x3xf32> |
760 | |
761 | // Scalable vector type. |
762 | bool scalable[] = {false, true}; |
763 | MlirType scalableVector = mlirVectorTypeGetScalable( |
764 | rank: sizeof(shape) / sizeof(int64_t), shape, scalable, elementType: f32); |
765 | if (!mlirTypeIsAVector(type: scalableVector)) |
766 | return 16; |
767 | if (!mlirVectorTypeIsScalable(type: scalableVector) || |
768 | mlirVectorTypeIsDimScalable(type: scalableVector, dim: 0) || |
769 | !mlirVectorTypeIsDimScalable(type: scalableVector, dim: 1)) |
770 | return 17; |
771 | mlirTypeDump(type: scalableVector); |
772 | fprintf(stderr, format: "\n"); |
773 | // CHECK: vector<2x[3]xf32> |
774 | |
775 | // Ranked tensor type. |
776 | MlirType rankedTensor = mlirRankedTensorTypeGet( |
777 | rank: sizeof(shape) / sizeof(int64_t), shape, elementType: f32, encoding: mlirAttributeGetNull()); |
778 | if (!mlirTypeIsATensor(type: rankedTensor) || |
779 | !mlirTypeIsARankedTensor(type: rankedTensor) || |
780 | !mlirAttributeIsNull(attr: mlirRankedTensorTypeGetEncoding(type: rankedTensor))) |
781 | return 18; |
782 | mlirTypeDump(type: rankedTensor); |
783 | fprintf(stderr, format: "\n"); |
784 | // CHECK: tensor<2x3xf32> |
785 | |
786 | // Unranked tensor type. |
787 | MlirType unrankedTensor = mlirUnrankedTensorTypeGet(elementType: f32); |
788 | if (!mlirTypeIsATensor(type: unrankedTensor) || |
789 | !mlirTypeIsAUnrankedTensor(type: unrankedTensor) || |
790 | mlirShapedTypeHasRank(type: unrankedTensor)) |
791 | return 19; |
792 | mlirTypeDump(type: unrankedTensor); |
793 | fprintf(stderr, format: "\n"); |
794 | // CHECK: tensor<*xf32> |
795 | |
796 | // MemRef type. |
797 | MlirAttribute memSpace2 = mlirIntegerAttrGet(type: mlirIntegerTypeGet(ctx, bitwidth: 64), value: 2); |
798 | MlirType memRef = mlirMemRefTypeContiguousGet( |
799 | elementType: f32, rank: sizeof(shape) / sizeof(int64_t), shape, memorySpace: memSpace2); |
800 | if (!mlirTypeIsAMemRef(type: memRef) || |
801 | !mlirAttributeEqual(a1: mlirMemRefTypeGetMemorySpace(type: memRef), a2: memSpace2)) |
802 | return 20; |
803 | mlirTypeDump(type: memRef); |
804 | fprintf(stderr, format: "\n"); |
805 | // CHECK: memref<2x3xf32, 2> |
806 | |
807 | // Unranked MemRef type. |
808 | MlirAttribute memSpace4 = mlirIntegerAttrGet(type: mlirIntegerTypeGet(ctx, bitwidth: 64), value: 4); |
809 | MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(elementType: f32, memorySpace: memSpace4); |
810 | if (!mlirTypeIsAUnrankedMemRef(type: unrankedMemRef) || |
811 | mlirTypeIsAMemRef(type: unrankedMemRef) || |
812 | !mlirAttributeEqual(a1: mlirUnrankedMemrefGetMemorySpace(type: unrankedMemRef), |
813 | a2: memSpace4)) |
814 | return 21; |
815 | mlirTypeDump(type: unrankedMemRef); |
816 | fprintf(stderr, format: "\n"); |
817 | // CHECK: memref<*xf32, 4> |
818 | |
819 | // Tuple type. |
820 | MlirType types[] = {unrankedMemRef, f32}; |
821 | MlirType tuple = mlirTupleTypeGet(ctx, numElements: 2, elements: types); |
822 | if (!mlirTypeIsATuple(type: tuple) || mlirTupleTypeGetNumTypes(type: tuple) != 2 || |
823 | !mlirTypeEqual(t1: mlirTupleTypeGetType(type: tuple, pos: 0), t2: unrankedMemRef) || |
824 | !mlirTypeEqual(t1: mlirTupleTypeGetType(type: tuple, pos: 1), t2: f32)) |
825 | return 22; |
826 | mlirTypeDump(type: tuple); |
827 | fprintf(stderr, format: "\n"); |
828 | // CHECK: tuple<memref<*xf32, 4>, f32> |
829 | |
830 | // Function type. |
831 | MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, bitwidth: 1)}; |
832 | MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, bitwidth: 16), |
833 | mlirIntegerTypeGet(ctx, bitwidth: 32), |
834 | mlirIntegerTypeGet(ctx, bitwidth: 64)}; |
835 | MlirType funcType = mlirFunctionTypeGet(ctx, numInputs: 2, inputs: funcInputs, numResults: 3, results: funcResults); |
836 | if (mlirFunctionTypeGetNumInputs(type: funcType) != 2) |
837 | return 23; |
838 | if (mlirFunctionTypeGetNumResults(type: funcType) != 3) |
839 | return 24; |
840 | if (!mlirTypeEqual(t1: funcInputs[0], t2: mlirFunctionTypeGetInput(type: funcType, pos: 0)) || |
841 | !mlirTypeEqual(t1: funcInputs[1], t2: mlirFunctionTypeGetInput(type: funcType, pos: 1))) |
842 | return 25; |
843 | if (!mlirTypeEqual(t1: funcResults[0], t2: mlirFunctionTypeGetResult(type: funcType, pos: 0)) || |
844 | !mlirTypeEqual(t1: funcResults[1], t2: mlirFunctionTypeGetResult(type: funcType, pos: 1)) || |
845 | !mlirTypeEqual(t1: funcResults[2], t2: mlirFunctionTypeGetResult(type: funcType, pos: 2))) |
846 | return 26; |
847 | mlirTypeDump(type: funcType); |
848 | fprintf(stderr, format: "\n"); |
849 | // CHECK: (index, i1) -> (i16, i32, i64) |
850 | |
851 | // Opaque type. |
852 | MlirStringRef namespace = mlirStringRefCreate(str: "dialect", length: 7); |
853 | MlirStringRef data = mlirStringRefCreate(str: "type", length: 4); |
854 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
855 | MlirType opaque = mlirOpaqueTypeGet(ctx, dialectNamespace: namespace, typeData: data); |
856 | mlirContextSetAllowUnregisteredDialects(context: ctx, false); |
857 | if (!mlirTypeIsAOpaque(type: opaque) || |
858 | !mlirStringRefEqual(string: mlirOpaqueTypeGetDialectNamespace(type: opaque), |
859 | other: namespace) || |
860 | !mlirStringRefEqual(string: mlirOpaqueTypeGetData(type: opaque), other: data)) |
861 | return 27; |
862 | mlirTypeDump(type: opaque); |
863 | fprintf(stderr, format: "\n"); |
864 | // CHECK: !dialect.type |
865 | |
866 | return 0; |
867 | } |
868 | |
/// String callback (data, length, user data) that copies the first `len`
/// bytes of `data` into the character buffer passed through `userData`,
/// always starting at the beginning of that buffer.
///
/// The copy has strncpy semantics. It stops early at a NUL byte in `data`
/// and zero-fills the rest of the `len` bytes. It never writes more than
/// `len` bytes, so no terminator is appended when `data` has no NUL in that
/// range. The buffer size is not known here; the caller must provide room
/// for at least `len` bytes.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  char *const destination = userData;
  const size_t byteCount = (size_t)len;
  strncpy(destination, data, byteCount);
}
873 | |
874 | bool stringIsEqual(const char *lhs, MlirStringRef rhs) { |
875 | if (strlen(s: lhs) != rhs.length) { |
876 | return false; |
877 | } |
878 | return !strncmp(s1: lhs, s2: rhs.data, n: rhs.length); |
879 | } |
880 | |
881 | int printBuiltinAttributes(MlirContext ctx) { |
882 | MlirAttribute floating = |
883 | mlirFloatAttrDoubleGet(ctx, type: mlirF64TypeGet(ctx), value: 2.0); |
884 | if (!mlirAttributeIsAFloat(attr: floating) || |
885 | fabs(x: mlirFloatAttrGetValueDouble(attr: floating) - 2.0) > 1E-6) |
886 | return 1; |
887 | fprintf(stderr, format: "@attrs\n"); |
888 | mlirAttributeDump(attr: floating); |
889 | // CHECK-LABEL: @attrs |
890 | // CHECK: 2.000000e+00 : f64 |
891 | |
892 | // Exercise mlirAttributeGetType() just for the first one. |
893 | MlirType floatingType = mlirAttributeGetType(attribute: floating); |
894 | mlirTypeDump(type: floatingType); |
895 | // CHECK: f64 |
896 | |
897 | MlirAttribute integer = mlirIntegerAttrGet(type: mlirIntegerTypeGet(ctx, bitwidth: 32), value: 42); |
898 | MlirAttribute signedInteger = |
899 | mlirIntegerAttrGet(type: mlirIntegerTypeSignedGet(ctx, bitwidth: 8), value: -1); |
900 | MlirAttribute unsignedInteger = |
901 | mlirIntegerAttrGet(type: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 8), value: 255); |
902 | if (!mlirAttributeIsAInteger(attr: integer) || |
903 | mlirIntegerAttrGetValueInt(attr: integer) != 42 || |
904 | mlirIntegerAttrGetValueSInt(attr: signedInteger) != -1 || |
905 | mlirIntegerAttrGetValueUInt(attr: unsignedInteger) != 255) |
906 | return 2; |
907 | mlirAttributeDump(attr: integer); |
908 | mlirAttributeDump(attr: signedInteger); |
909 | mlirAttributeDump(attr: unsignedInteger); |
910 | // CHECK: 42 : i32 |
911 | // CHECK: -1 : si8 |
912 | // CHECK: 255 : ui8 |
913 | |
914 | MlirAttribute boolean = mlirBoolAttrGet(ctx, value: 1); |
915 | if (!mlirAttributeIsABool(attr: boolean) || !mlirBoolAttrGetValue(attr: boolean)) |
916 | return 3; |
917 | mlirAttributeDump(attr: boolean); |
918 | // CHECK: true |
919 | |
920 | const char data[] = "abcdefghijklmnopqestuvwxyz"; |
921 | MlirAttribute opaque = |
922 | mlirOpaqueAttrGet(ctx, dialectNamespace: mlirStringRefCreateFromCString(str: "func"), dataLength: 3, data, |
923 | type: mlirNoneTypeGet(ctx)); |
924 | if (!mlirAttributeIsAOpaque(attr: opaque) || |
925 | !stringIsEqual(lhs: "func", rhs: mlirOpaqueAttrGetDialectNamespace(attr: opaque))) |
926 | return 4; |
927 | |
928 | MlirStringRef opaqueData = mlirOpaqueAttrGetData(attr: opaque); |
929 | if (opaqueData.length != 3 || |
930 | strncmp(s1: data, s2: opaqueData.data, n: opaqueData.length)) |
931 | return 5; |
932 | mlirAttributeDump(attr: opaque); |
933 | // CHECK: #func.abc |
934 | |
935 | MlirAttribute string = |
936 | mlirStringAttrGet(ctx, str: mlirStringRefCreate(str: data + 3, length: 2)); |
937 | if (!mlirAttributeIsAString(attr: string)) |
938 | return 6; |
939 | |
940 | MlirStringRef stringValue = mlirStringAttrGetValue(attr: string); |
941 | if (stringValue.length != 2 || |
942 | strncmp(s1: data + 3, s2: stringValue.data, n: stringValue.length)) |
943 | return 7; |
944 | mlirAttributeDump(attr: string); |
945 | // CHECK: "de" |
946 | |
947 | MlirAttribute flatSymbolRef = |
948 | mlirFlatSymbolRefAttrGet(ctx, symbol: mlirStringRefCreate(str: data + 5, length: 3)); |
949 | if (!mlirAttributeIsAFlatSymbolRef(attr: flatSymbolRef)) |
950 | return 8; |
951 | |
952 | MlirStringRef flatSymbolRefValue = |
953 | mlirFlatSymbolRefAttrGetValue(attr: flatSymbolRef); |
954 | if (flatSymbolRefValue.length != 3 || |
955 | strncmp(s1: data + 5, s2: flatSymbolRefValue.data, n: flatSymbolRefValue.length)) |
956 | return 9; |
957 | mlirAttributeDump(attr: flatSymbolRef); |
958 | // CHECK: @fgh |
959 | |
960 | MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef}; |
961 | MlirAttribute symbolRef = |
962 | mlirSymbolRefAttrGet(ctx, symbol: mlirStringRefCreate(str: data + 8, length: 2), numReferences: 2, references: symbols); |
963 | if (!mlirAttributeIsASymbolRef(attr: symbolRef) || |
964 | mlirSymbolRefAttrGetNumNestedReferences(attr: symbolRef) != 2 || |
965 | !mlirAttributeEqual(a1: mlirSymbolRefAttrGetNestedReference(attr: symbolRef, pos: 0), |
966 | a2: flatSymbolRef) || |
967 | !mlirAttributeEqual(a1: mlirSymbolRefAttrGetNestedReference(attr: symbolRef, pos: 1), |
968 | a2: flatSymbolRef)) |
969 | return 10; |
970 | |
971 | MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(attr: symbolRef); |
972 | MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(attr: symbolRef); |
973 | if (symbolRefLeaf.length != 3 || |
974 | strncmp(s1: data + 5, s2: symbolRefLeaf.data, n: symbolRefLeaf.length) || |
975 | symbolRefRoot.length != 2 || |
976 | strncmp(s1: data + 8, s2: symbolRefRoot.data, n: symbolRefRoot.length)) |
977 | return 11; |
978 | mlirAttributeDump(attr: symbolRef); |
979 | // CHECK: @ij::@fgh::@fgh |
980 | |
981 | MlirAttribute type = mlirTypeAttrGet(type: mlirF32TypeGet(ctx)); |
982 | if (!mlirAttributeIsAType(attr: type) || |
983 | !mlirTypeEqual(t1: mlirF32TypeGet(ctx), t2: mlirTypeAttrGetValue(attr: type))) |
984 | return 12; |
985 | mlirAttributeDump(attr: type); |
986 | // CHECK: f32 |
987 | |
988 | MlirAttribute unit = mlirUnitAttrGet(ctx); |
989 | if (!mlirAttributeIsAUnit(attr: unit)) |
990 | return 13; |
991 | mlirAttributeDump(attr: unit); |
992 | // CHECK: unit |
993 | |
994 | int64_t shape[] = {1, 2}; |
995 | |
996 | int bools[] = {0, 1}; |
997 | uint8_t uints8[] = {0u, 1u}; |
998 | int8_t ints8[] = {0, 1}; |
999 | uint16_t uints16[] = {0u, 1u}; |
1000 | int16_t ints16[] = {0, 1}; |
1001 | uint32_t uints32[] = {0u, 1u}; |
1002 | int32_t ints32[] = {0, 1}; |
1003 | uint64_t uints64[] = {0u, 1u}; |
1004 | int64_t ints64[] = {0, 1}; |
1005 | float floats[] = {0.0f, 1.0f}; |
1006 | double doubles[] = {0.0, 1.0}; |
1007 | uint16_t bf16s[] = {0x0, 0x3f80}; |
1008 | uint16_t f16s[] = {0x0, 0x3c00}; |
1009 | MlirAttribute encoding = mlirAttributeGetNull(); |
1010 | MlirAttribute boolElements = mlirDenseElementsAttrBoolGet( |
1011 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 1), encoding), |
1012 | numElements: 2, elements: bools); |
1013 | MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get( |
1014 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 8), |
1015 | encoding), |
1016 | numElements: 2, elements: uints8); |
1017 | MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get( |
1018 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 8), encoding), |
1019 | numElements: 2, elements: ints8); |
1020 | MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get( |
1021 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 16), |
1022 | encoding), |
1023 | numElements: 2, elements: uints16); |
1024 | MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get( |
1025 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 16), encoding), |
1026 | numElements: 2, elements: ints16); |
1027 | MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get( |
1028 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32), |
1029 | encoding), |
1030 | numElements: 2, elements: uints32); |
1031 | MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get( |
1032 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 32), encoding), |
1033 | numElements: 2, elements: ints32); |
1034 | MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get( |
1035 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 64), |
1036 | encoding), |
1037 | numElements: 2, elements: uints64); |
1038 | MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get( |
1039 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1040 | numElements: 2, elements: ints64); |
1041 | MlirAttribute floatElements = mlirDenseElementsAttrFloatGet( |
1042 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF32TypeGet(ctx), encoding), numElements: 2, |
1043 | elements: floats); |
1044 | MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet( |
1045 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF64TypeGet(ctx), encoding), numElements: 2, |
1046 | elements: doubles); |
1047 | MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get( |
1048 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirBF16TypeGet(ctx), encoding), numElements: 2, |
1049 | elements: bf16s); |
1050 | MlirAttribute f16Elements = mlirDenseElementsAttrFloat16Get( |
1051 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF16TypeGet(ctx), encoding), numElements: 2, |
1052 | elements: f16s); |
1053 | |
1054 | if (!mlirAttributeIsADenseElements(attr: boolElements) || |
1055 | !mlirAttributeIsADenseElements(attr: uint8Elements) || |
1056 | !mlirAttributeIsADenseElements(attr: int8Elements) || |
1057 | !mlirAttributeIsADenseElements(attr: uint32Elements) || |
1058 | !mlirAttributeIsADenseElements(attr: int32Elements) || |
1059 | !mlirAttributeIsADenseElements(attr: uint64Elements) || |
1060 | !mlirAttributeIsADenseElements(attr: int64Elements) || |
1061 | !mlirAttributeIsADenseElements(attr: floatElements) || |
1062 | !mlirAttributeIsADenseElements(attr: doubleElements) || |
1063 | !mlirAttributeIsADenseElements(attr: bf16Elements) || |
1064 | !mlirAttributeIsADenseElements(attr: f16Elements)) |
1065 | return 14; |
1066 | |
1067 | if (mlirDenseElementsAttrGetBoolValue(attr: boolElements, pos: 1) != 1 || |
1068 | mlirDenseElementsAttrGetUInt8Value(attr: uint8Elements, pos: 1) != 1 || |
1069 | mlirDenseElementsAttrGetInt8Value(attr: int8Elements, pos: 1) != 1 || |
1070 | mlirDenseElementsAttrGetUInt16Value(attr: uint16Elements, pos: 1) != 1 || |
1071 | mlirDenseElementsAttrGetInt16Value(attr: int16Elements, pos: 1) != 1 || |
1072 | mlirDenseElementsAttrGetUInt32Value(attr: uint32Elements, pos: 1) != 1 || |
1073 | mlirDenseElementsAttrGetInt32Value(attr: int32Elements, pos: 1) != 1 || |
1074 | mlirDenseElementsAttrGetUInt64Value(attr: uint64Elements, pos: 1) != 1 || |
1075 | mlirDenseElementsAttrGetInt64Value(attr: int64Elements, pos: 1) != 1 || |
1076 | fabsf(x: mlirDenseElementsAttrGetFloatValue(attr: floatElements, pos: 1) - 1.0f) > |
1077 | 1E-6f || |
1078 | fabs(x: mlirDenseElementsAttrGetDoubleValue(attr: doubleElements, pos: 1) - 1.0) > 1E-6) |
1079 | return 15; |
1080 | |
1081 | mlirAttributeDump(attr: boolElements); |
1082 | mlirAttributeDump(attr: uint8Elements); |
1083 | mlirAttributeDump(attr: int8Elements); |
1084 | mlirAttributeDump(attr: uint32Elements); |
1085 | mlirAttributeDump(attr: int32Elements); |
1086 | mlirAttributeDump(attr: uint64Elements); |
1087 | mlirAttributeDump(attr: int64Elements); |
1088 | mlirAttributeDump(attr: floatElements); |
1089 | mlirAttributeDump(attr: doubleElements); |
1090 | mlirAttributeDump(attr: bf16Elements); |
1091 | mlirAttributeDump(attr: f16Elements); |
1092 | // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1> |
1093 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8> |
1094 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8> |
1095 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32> |
1096 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32> |
1097 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64> |
1098 | // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64> |
1099 | // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32> |
1100 | // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64> |
1101 | // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16> |
1102 | // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf16> |
1103 | |
1104 | MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet( |
1105 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 1), encoding), |
1106 | element: 1); |
1107 | MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet( |
1108 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 8), |
1109 | encoding), |
1110 | element: 1); |
1111 | MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet( |
1112 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 8), encoding), |
1113 | element: 1); |
1114 | MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet( |
1115 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32), |
1116 | encoding), |
1117 | element: 1); |
1118 | MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet( |
1119 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 32), encoding), |
1120 | element: 1); |
1121 | MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet( |
1122 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 64), |
1123 | encoding), |
1124 | element: 1); |
1125 | MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet( |
1126 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1127 | element: 1); |
1128 | MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet( |
1129 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF32TypeGet(ctx), encoding), element: 1.0f); |
1130 | MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet( |
1131 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF64TypeGet(ctx), encoding), element: 1.0); |
1132 | |
1133 | if (!mlirAttributeIsADenseElements(attr: splatBool) || |
1134 | !mlirDenseElementsAttrIsSplat(attr: splatBool) || |
1135 | !mlirAttributeIsADenseElements(attr: splatUInt8) || |
1136 | !mlirDenseElementsAttrIsSplat(attr: splatUInt8) || |
1137 | !mlirAttributeIsADenseElements(attr: splatInt8) || |
1138 | !mlirDenseElementsAttrIsSplat(attr: splatInt8) || |
1139 | !mlirAttributeIsADenseElements(attr: splatUInt32) || |
1140 | !mlirDenseElementsAttrIsSplat(attr: splatUInt32) || |
1141 | !mlirAttributeIsADenseElements(attr: splatInt32) || |
1142 | !mlirDenseElementsAttrIsSplat(attr: splatInt32) || |
1143 | !mlirAttributeIsADenseElements(attr: splatUInt64) || |
1144 | !mlirDenseElementsAttrIsSplat(attr: splatUInt64) || |
1145 | !mlirAttributeIsADenseElements(attr: splatInt64) || |
1146 | !mlirDenseElementsAttrIsSplat(attr: splatInt64) || |
1147 | !mlirAttributeIsADenseElements(attr: splatFloat) || |
1148 | !mlirDenseElementsAttrIsSplat(attr: splatFloat) || |
1149 | !mlirAttributeIsADenseElements(attr: splatDouble) || |
1150 | !mlirDenseElementsAttrIsSplat(attr: splatDouble)) |
1151 | return 16; |
1152 | |
1153 | if (mlirDenseElementsAttrGetBoolSplatValue(attr: splatBool) != 1 || |
1154 | mlirDenseElementsAttrGetUInt8SplatValue(attr: splatUInt8) != 1 || |
1155 | mlirDenseElementsAttrGetInt8SplatValue(attr: splatInt8) != 1 || |
1156 | mlirDenseElementsAttrGetUInt32SplatValue(attr: splatUInt32) != 1 || |
1157 | mlirDenseElementsAttrGetInt32SplatValue(attr: splatInt32) != 1 || |
1158 | mlirDenseElementsAttrGetUInt64SplatValue(attr: splatUInt64) != 1 || |
1159 | mlirDenseElementsAttrGetInt64SplatValue(attr: splatInt64) != 1 || |
1160 | fabsf(x: mlirDenseElementsAttrGetFloatSplatValue(attr: splatFloat) - 1.0f) > |
1161 | 1E-6f || |
1162 | fabs(x: mlirDenseElementsAttrGetDoubleSplatValue(attr: splatDouble) - 1.0) > 1E-6) |
1163 | return 17; |
1164 | |
1165 | const uint8_t *uint8RawData = |
1166 | (const uint8_t *)mlirDenseElementsAttrGetRawData(attr: uint8Elements); |
1167 | const int8_t *int8RawData = |
1168 | (const int8_t *)mlirDenseElementsAttrGetRawData(attr: int8Elements); |
1169 | const uint32_t *uint32RawData = |
1170 | (const uint32_t *)mlirDenseElementsAttrGetRawData(attr: uint32Elements); |
1171 | const int32_t *int32RawData = |
1172 | (const int32_t *)mlirDenseElementsAttrGetRawData(attr: int32Elements); |
1173 | const uint64_t *uint64RawData = |
1174 | (const uint64_t *)mlirDenseElementsAttrGetRawData(attr: uint64Elements); |
1175 | const int64_t *int64RawData = |
1176 | (const int64_t *)mlirDenseElementsAttrGetRawData(attr: int64Elements); |
1177 | const float *floatRawData = |
1178 | (const float *)mlirDenseElementsAttrGetRawData(attr: floatElements); |
1179 | const double *doubleRawData = |
1180 | (const double *)mlirDenseElementsAttrGetRawData(attr: doubleElements); |
1181 | const uint16_t *bf16RawData = |
1182 | (const uint16_t *)mlirDenseElementsAttrGetRawData(attr: bf16Elements); |
1183 | const uint16_t *f16RawData = |
1184 | (const uint16_t *)mlirDenseElementsAttrGetRawData(attr: f16Elements); |
1185 | if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 || |
1186 | int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u || |
1187 | int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u || |
1188 | uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 || |
1189 | floatRawData[0] != 0.0f || floatRawData[1] != 1.0f || |
1190 | doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 || |
1191 | bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80 || f16RawData[0] != 0 || |
1192 | f16RawData[1] != 0x3c00) |
1193 | return 18; |
1194 | |
1195 | mlirAttributeDump(attr: splatBool); |
1196 | mlirAttributeDump(attr: splatUInt8); |
1197 | mlirAttributeDump(attr: splatInt8); |
1198 | mlirAttributeDump(attr: splatUInt32); |
1199 | mlirAttributeDump(attr: splatInt32); |
1200 | mlirAttributeDump(attr: splatUInt64); |
1201 | mlirAttributeDump(attr: splatInt64); |
1202 | mlirAttributeDump(attr: splatFloat); |
1203 | mlirAttributeDump(attr: splatDouble); |
1204 | // CHECK: dense<true> : tensor<1x2xi1> |
1205 | // CHECK: dense<1> : tensor<1x2xui8> |
1206 | // CHECK: dense<1> : tensor<1x2xi8> |
1207 | // CHECK: dense<1> : tensor<1x2xui32> |
1208 | // CHECK: dense<1> : tensor<1x2xi32> |
1209 | // CHECK: dense<1> : tensor<1x2xui64> |
1210 | // CHECK: dense<1> : tensor<1x2xi64> |
1211 | // CHECK: dense<1.000000e+00> : tensor<1x2xf32> |
1212 | // CHECK: dense<1.000000e+00> : tensor<1x2xf64> |
1213 | |
1214 | mlirAttributeDump(attr: mlirElementsAttrGetValue(attr: floatElements, rank: 2, idxs: uints64)); |
1215 | mlirAttributeDump(attr: mlirElementsAttrGetValue(attr: doubleElements, rank: 2, idxs: uints64)); |
1216 | mlirAttributeDump(attr: mlirElementsAttrGetValue(attr: bf16Elements, rank: 2, idxs: uints64)); |
1217 | mlirAttributeDump(attr: mlirElementsAttrGetValue(attr: f16Elements, rank: 2, idxs: uints64)); |
1218 | // CHECK: 1.000000e+00 : f32 |
1219 | // CHECK: 1.000000e+00 : f64 |
1220 | // CHECK: 1.000000e+00 : bf16 |
1221 | // CHECK: 1.000000e+00 : f16 |
1222 | |
1223 | int64_t indices[] = {0, 1}; |
1224 | int64_t one = 1; |
1225 | MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get( |
1226 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1227 | numElements: 2, elements: indices); |
1228 | MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet( |
1229 | shapedType: mlirRankedTensorTypeGet(rank: 1, shape: &one, elementType: mlirF32TypeGet(ctx), encoding), numElements: 1, |
1230 | elements: floats); |
1231 | MlirAttribute sparseAttr = mlirSparseElementsAttribute( |
1232 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF32TypeGet(ctx), encoding), |
1233 | denseIndices: indicesAttr, denseValues: valuesAttr); |
1234 | mlirAttributeDump(attr: sparseAttr); |
1235 | // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32> |
1236 | |
1237 | MlirAttribute boolArray = mlirDenseBoolArrayGet(ctx, size: 2, values: bools); |
1238 | MlirAttribute int8Array = mlirDenseI8ArrayGet(ctx, size: 2, values: ints8); |
1239 | MlirAttribute int16Array = mlirDenseI16ArrayGet(ctx, size: 2, values: ints16); |
1240 | MlirAttribute int32Array = mlirDenseI32ArrayGet(ctx, size: 2, values: ints32); |
1241 | MlirAttribute int64Array = mlirDenseI64ArrayGet(ctx, size: 2, values: ints64); |
1242 | MlirAttribute floatArray = mlirDenseF32ArrayGet(ctx, size: 2, values: floats); |
1243 | MlirAttribute doubleArray = mlirDenseF64ArrayGet(ctx, size: 2, values: doubles); |
1244 | if (!mlirAttributeIsADenseBoolArray(attr: boolArray) || |
1245 | !mlirAttributeIsADenseI8Array(attr: int8Array) || |
1246 | !mlirAttributeIsADenseI16Array(attr: int16Array) || |
1247 | !mlirAttributeIsADenseI32Array(attr: int32Array) || |
1248 | !mlirAttributeIsADenseI64Array(attr: int64Array) || |
1249 | !mlirAttributeIsADenseF32Array(attr: floatArray) || |
1250 | !mlirAttributeIsADenseF64Array(attr: doubleArray)) |
1251 | return 19; |
1252 | |
1253 | if (mlirDenseArrayGetNumElements(attr: boolArray) != 2 || |
1254 | mlirDenseArrayGetNumElements(attr: int8Array) != 2 || |
1255 | mlirDenseArrayGetNumElements(attr: int16Array) != 2 || |
1256 | mlirDenseArrayGetNumElements(attr: int32Array) != 2 || |
1257 | mlirDenseArrayGetNumElements(attr: int64Array) != 2 || |
1258 | mlirDenseArrayGetNumElements(attr: floatArray) != 2 || |
1259 | mlirDenseArrayGetNumElements(attr: doubleArray) != 2) |
1260 | return 20; |
1261 | |
1262 | if (mlirDenseBoolArrayGetElement(attr: boolArray, pos: 1) != 1 || |
1263 | mlirDenseI8ArrayGetElement(attr: int8Array, pos: 1) != 1 || |
1264 | mlirDenseI16ArrayGetElement(attr: int16Array, pos: 1) != 1 || |
1265 | mlirDenseI32ArrayGetElement(attr: int32Array, pos: 1) != 1 || |
1266 | mlirDenseI64ArrayGetElement(attr: int64Array, pos: 1) != 1 || |
1267 | fabsf(x: mlirDenseF32ArrayGetElement(attr: floatArray, pos: 1) - 1.0f) > 1E-6f || |
1268 | fabs(x: mlirDenseF64ArrayGetElement(attr: doubleArray, pos: 1) - 1.0) > 1E-6) |
1269 | return 21; |
1270 | |
1271 | int64_t layoutStrides[3] = {5, 7, 13}; |
1272 | MlirAttribute stridedLayoutAttr = |
1273 | mlirStridedLayoutAttrGet(ctx, offset: 42, numStrides: 3, strides: &layoutStrides[0]); |
1274 | |
1275 | // CHECK: strided<[5, 7, 13], offset: 42> |
1276 | mlirAttributeDump(attr: stridedLayoutAttr); |
1277 | |
1278 | if (mlirStridedLayoutAttrGetOffset(attr: stridedLayoutAttr) != 42 || |
1279 | mlirStridedLayoutAttrGetNumStrides(attr: stridedLayoutAttr) != 3 || |
1280 | mlirStridedLayoutAttrGetStride(attr: stridedLayoutAttr, pos: 0) != 5 || |
1281 | mlirStridedLayoutAttrGetStride(attr: stridedLayoutAttr, pos: 1) != 7 || |
1282 | mlirStridedLayoutAttrGetStride(attr: stridedLayoutAttr, pos: 2) != 13) |
1283 | return 22; |
1284 | |
1285 | MlirAttribute uint8Blob = mlirUnmanagedDenseUInt8ResourceElementsAttrGet( |
1286 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 8), |
1287 | encoding), |
1288 | name: mlirStringRefCreateFromCString(str: "resource_ui8"), numElements: 2, elements: uints8); |
1289 | MlirAttribute uint16Blob = mlirUnmanagedDenseUInt16ResourceElementsAttrGet( |
1290 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 16), |
1291 | encoding), |
1292 | name: mlirStringRefCreateFromCString(str: "resource_ui16"), numElements: 2, elements: uints16); |
1293 | MlirAttribute uint32Blob = mlirUnmanagedDenseUInt32ResourceElementsAttrGet( |
1294 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32), |
1295 | encoding), |
1296 | name: mlirStringRefCreateFromCString(str: "resource_ui32"), numElements: 2, elements: uints32); |
1297 | MlirAttribute uint64Blob = mlirUnmanagedDenseUInt64ResourceElementsAttrGet( |
1298 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeUnsignedGet(ctx, bitwidth: 64), |
1299 | encoding), |
1300 | name: mlirStringRefCreateFromCString(str: "resource_ui64"), numElements: 2, elements: uints64); |
1301 | MlirAttribute int8Blob = mlirUnmanagedDenseInt8ResourceElementsAttrGet( |
1302 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 8), encoding), |
1303 | name: mlirStringRefCreateFromCString(str: "resource_i8"), numElements: 2, elements: ints8); |
1304 | MlirAttribute int16Blob = mlirUnmanagedDenseInt16ResourceElementsAttrGet( |
1305 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 16), encoding), |
1306 | name: mlirStringRefCreateFromCString(str: "resource_i16"), numElements: 2, elements: ints16); |
1307 | MlirAttribute int32Blob = mlirUnmanagedDenseInt32ResourceElementsAttrGet( |
1308 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 32), encoding), |
1309 | name: mlirStringRefCreateFromCString(str: "resource_i32"), numElements: 2, elements: ints32); |
1310 | MlirAttribute int64Blob = mlirUnmanagedDenseInt64ResourceElementsAttrGet( |
1311 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1312 | name: mlirStringRefCreateFromCString(str: "resource_i64"), numElements: 2, elements: ints64); |
1313 | MlirAttribute floatsBlob = mlirUnmanagedDenseFloatResourceElementsAttrGet( |
1314 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF32TypeGet(ctx), encoding), |
1315 | name: mlirStringRefCreateFromCString(str: "resource_f32"), numElements: 2, elements: floats); |
1316 | MlirAttribute doublesBlob = mlirUnmanagedDenseDoubleResourceElementsAttrGet( |
1317 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirF64TypeGet(ctx), encoding), |
1318 | name: mlirStringRefCreateFromCString(str: "resource_f64"), numElements: 2, elements: doubles); |
1319 | MlirAttribute blobBlob = mlirUnmanagedDenseResourceElementsAttrGet( |
1320 | shapedType: mlirRankedTensorTypeGet(rank: 2, shape, elementType: mlirIntegerTypeGet(ctx, bitwidth: 64), encoding), |
1321 | name: mlirStringRefCreateFromCString(str: "resource_i64_blob"), /*data=*/uints64, |
1322 | /*dataLength=*/sizeof(uints64), |
1323 | /*dataAlignment=*/_Alignof(uint64_t), |
1324 | /*dataIsMutable=*/false, |
1325 | /*deleter=*/reportResourceDelete, |
1326 | /*userData=*/(void *)&resourceI64BlobUserData); |
1327 | |
1328 | mlirAttributeDump(attr: uint8Blob); |
1329 | mlirAttributeDump(attr: uint16Blob); |
1330 | mlirAttributeDump(attr: uint32Blob); |
1331 | mlirAttributeDump(attr: uint64Blob); |
1332 | mlirAttributeDump(attr: int8Blob); |
1333 | mlirAttributeDump(attr: int16Blob); |
1334 | mlirAttributeDump(attr: int32Blob); |
1335 | mlirAttributeDump(attr: int64Blob); |
1336 | mlirAttributeDump(attr: floatsBlob); |
1337 | mlirAttributeDump(attr: doublesBlob); |
1338 | mlirAttributeDump(attr: blobBlob); |
1339 | // CHECK: dense_resource<resource_ui8> : tensor<1x2xui8> |
1340 | // CHECK: dense_resource<resource_ui16> : tensor<1x2xui16> |
1341 | // CHECK: dense_resource<resource_ui32> : tensor<1x2xui32> |
1342 | // CHECK: dense_resource<resource_ui64> : tensor<1x2xui64> |
1343 | // CHECK: dense_resource<resource_i8> : tensor<1x2xi8> |
1344 | // CHECK: dense_resource<resource_i16> : tensor<1x2xi16> |
1345 | // CHECK: dense_resource<resource_i32> : tensor<1x2xi32> |
1346 | // CHECK: dense_resource<resource_i64> : tensor<1x2xi64> |
1347 | // CHECK: dense_resource<resource_f32> : tensor<1x2xf32> |
1348 | // CHECK: dense_resource<resource_f64> : tensor<1x2xf64> |
1349 | // CHECK: dense_resource<resource_i64_blob> : tensor<1x2xi64> |
1350 | |
1351 | if (mlirDenseUInt8ResourceElementsAttrGetValue(attr: uint8Blob, pos: 1) != 1 || |
1352 | mlirDenseUInt16ResourceElementsAttrGetValue(attr: uint16Blob, pos: 1) != 1 || |
1353 | mlirDenseUInt32ResourceElementsAttrGetValue(attr: uint32Blob, pos: 1) != 1 || |
1354 | mlirDenseUInt64ResourceElementsAttrGetValue(attr: uint64Blob, pos: 1) != 1 || |
1355 | mlirDenseInt8ResourceElementsAttrGetValue(attr: int8Blob, pos: 1) != 1 || |
1356 | mlirDenseInt16ResourceElementsAttrGetValue(attr: int16Blob, pos: 1) != 1 || |
1357 | mlirDenseInt32ResourceElementsAttrGetValue(attr: int32Blob, pos: 1) != 1 || |
1358 | mlirDenseInt64ResourceElementsAttrGetValue(attr: int64Blob, pos: 1) != 1 || |
1359 | fabsf(x: mlirDenseF32ArrayGetElement(attr: floatArray, pos: 1) - 1.0f) > 1E-6f || |
1360 | fabsf(x: mlirDenseFloatResourceElementsAttrGetValue(attr: floatsBlob, pos: 1) - 1.0f) > |
1361 | 1e-6 || |
1362 | fabs(x: mlirDenseDoubleResourceElementsAttrGetValue(attr: doublesBlob, pos: 1) - 1.0f) > |
1363 | 1e-6 || |
1364 | mlirDenseUInt64ResourceElementsAttrGetValue(attr: blobBlob, pos: 1) != 1) |
1365 | return 23; |
1366 | |
1367 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
1368 | MlirAttribute locAttr = mlirLocationGetAttribute(location: loc); |
1369 | if (!mlirAttributeIsALocation(attr: locAttr)) |
1370 | return 24; |
1371 | |
1372 | return 0; |
1373 | } |
1374 | |
1375 | int printAffineMap(MlirContext ctx) { |
1376 | MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx); |
1377 | MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, dimCount: 3, symbolCount: 2); |
1378 | MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, val: 2); |
1379 | MlirAffineMap multiDimIdentityAffineMap = |
1380 | mlirAffineMapMultiDimIdentityGet(ctx, numDims: 3); |
1381 | MlirAffineMap minorIdentityAffineMap = |
1382 | mlirAffineMapMinorIdentityGet(ctx, dims: 3, results: 2); |
1383 | unsigned permutation[] = {1, 2, 0}; |
1384 | MlirAffineMap permutationAffineMap = mlirAffineMapPermutationGet( |
1385 | ctx, size: sizeof(permutation) / sizeof(unsigned), permutation); |
1386 | |
1387 | fprintf(stderr, format: "@affineMap\n"); |
1388 | mlirAffineMapDump(affineMap: emptyAffineMap); |
1389 | mlirAffineMapDump(affineMap); |
1390 | mlirAffineMapDump(affineMap: constAffineMap); |
1391 | mlirAffineMapDump(affineMap: multiDimIdentityAffineMap); |
1392 | mlirAffineMapDump(affineMap: minorIdentityAffineMap); |
1393 | mlirAffineMapDump(affineMap: permutationAffineMap); |
1394 | // CHECK-LABEL: @affineMap |
1395 | // CHECK: () -> () |
1396 | // CHECK: (d0, d1, d2)[s0, s1] -> () |
1397 | // CHECK: () -> (2) |
1398 | // CHECK: (d0, d1, d2) -> (d0, d1, d2) |
1399 | // CHECK: (d0, d1, d2) -> (d1, d2) |
1400 | // CHECK: (d0, d1, d2) -> (d1, d2, d0) |
1401 | |
1402 | if (!mlirAffineMapIsIdentity(affineMap: emptyAffineMap) || |
1403 | mlirAffineMapIsIdentity(affineMap) || |
1404 | mlirAffineMapIsIdentity(affineMap: constAffineMap) || |
1405 | !mlirAffineMapIsIdentity(affineMap: multiDimIdentityAffineMap) || |
1406 | mlirAffineMapIsIdentity(affineMap: minorIdentityAffineMap) || |
1407 | mlirAffineMapIsIdentity(affineMap: permutationAffineMap)) |
1408 | return 1; |
1409 | |
1410 | if (!mlirAffineMapIsMinorIdentity(affineMap: emptyAffineMap) || |
1411 | mlirAffineMapIsMinorIdentity(affineMap) || |
1412 | !mlirAffineMapIsMinorIdentity(affineMap: multiDimIdentityAffineMap) || |
1413 | !mlirAffineMapIsMinorIdentity(affineMap: minorIdentityAffineMap) || |
1414 | mlirAffineMapIsMinorIdentity(affineMap: permutationAffineMap)) |
1415 | return 2; |
1416 | |
1417 | if (!mlirAffineMapIsEmpty(affineMap: emptyAffineMap) || |
1418 | mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(affineMap: constAffineMap) || |
1419 | mlirAffineMapIsEmpty(affineMap: multiDimIdentityAffineMap) || |
1420 | mlirAffineMapIsEmpty(affineMap: minorIdentityAffineMap) || |
1421 | mlirAffineMapIsEmpty(affineMap: permutationAffineMap)) |
1422 | return 3; |
1423 | |
1424 | if (mlirAffineMapIsSingleConstant(affineMap: emptyAffineMap) || |
1425 | mlirAffineMapIsSingleConstant(affineMap) || |
1426 | !mlirAffineMapIsSingleConstant(affineMap: constAffineMap) || |
1427 | mlirAffineMapIsSingleConstant(affineMap: multiDimIdentityAffineMap) || |
1428 | mlirAffineMapIsSingleConstant(affineMap: minorIdentityAffineMap) || |
1429 | mlirAffineMapIsSingleConstant(affineMap: permutationAffineMap)) |
1430 | return 4; |
1431 | |
1432 | if (mlirAffineMapGetSingleConstantResult(affineMap: constAffineMap) != 2) |
1433 | return 5; |
1434 | |
1435 | if (mlirAffineMapGetNumDims(affineMap: emptyAffineMap) != 0 || |
1436 | mlirAffineMapGetNumDims(affineMap) != 3 || |
1437 | mlirAffineMapGetNumDims(affineMap: constAffineMap) != 0 || |
1438 | mlirAffineMapGetNumDims(affineMap: multiDimIdentityAffineMap) != 3 || |
1439 | mlirAffineMapGetNumDims(affineMap: minorIdentityAffineMap) != 3 || |
1440 | mlirAffineMapGetNumDims(affineMap: permutationAffineMap) != 3) |
1441 | return 6; |
1442 | |
1443 | if (mlirAffineMapGetNumSymbols(affineMap: emptyAffineMap) != 0 || |
1444 | mlirAffineMapGetNumSymbols(affineMap) != 2 || |
1445 | mlirAffineMapGetNumSymbols(affineMap: constAffineMap) != 0 || |
1446 | mlirAffineMapGetNumSymbols(affineMap: multiDimIdentityAffineMap) != 0 || |
1447 | mlirAffineMapGetNumSymbols(affineMap: minorIdentityAffineMap) != 0 || |
1448 | mlirAffineMapGetNumSymbols(affineMap: permutationAffineMap) != 0) |
1449 | return 7; |
1450 | |
1451 | if (mlirAffineMapGetNumResults(affineMap: emptyAffineMap) != 0 || |
1452 | mlirAffineMapGetNumResults(affineMap) != 0 || |
1453 | mlirAffineMapGetNumResults(affineMap: constAffineMap) != 1 || |
1454 | mlirAffineMapGetNumResults(affineMap: multiDimIdentityAffineMap) != 3 || |
1455 | mlirAffineMapGetNumResults(affineMap: minorIdentityAffineMap) != 2 || |
1456 | mlirAffineMapGetNumResults(affineMap: permutationAffineMap) != 3) |
1457 | return 8; |
1458 | |
1459 | if (mlirAffineMapGetNumInputs(affineMap: emptyAffineMap) != 0 || |
1460 | mlirAffineMapGetNumInputs(affineMap) != 5 || |
1461 | mlirAffineMapGetNumInputs(affineMap: constAffineMap) != 0 || |
1462 | mlirAffineMapGetNumInputs(affineMap: multiDimIdentityAffineMap) != 3 || |
1463 | mlirAffineMapGetNumInputs(affineMap: minorIdentityAffineMap) != 3 || |
1464 | mlirAffineMapGetNumInputs(affineMap: permutationAffineMap) != 3) |
1465 | return 9; |
1466 | |
1467 | if (!mlirAffineMapIsProjectedPermutation(affineMap: emptyAffineMap) || |
1468 | !mlirAffineMapIsPermutation(affineMap: emptyAffineMap) || |
1469 | mlirAffineMapIsProjectedPermutation(affineMap) || |
1470 | mlirAffineMapIsPermutation(affineMap) || |
1471 | mlirAffineMapIsProjectedPermutation(affineMap: constAffineMap) || |
1472 | mlirAffineMapIsPermutation(affineMap: constAffineMap) || |
1473 | !mlirAffineMapIsProjectedPermutation(affineMap: multiDimIdentityAffineMap) || |
1474 | !mlirAffineMapIsPermutation(affineMap: multiDimIdentityAffineMap) || |
1475 | !mlirAffineMapIsProjectedPermutation(affineMap: minorIdentityAffineMap) || |
1476 | mlirAffineMapIsPermutation(affineMap: minorIdentityAffineMap) || |
1477 | !mlirAffineMapIsProjectedPermutation(affineMap: permutationAffineMap) || |
1478 | !mlirAffineMapIsPermutation(affineMap: permutationAffineMap)) |
1479 | return 10; |
1480 | |
1481 | intptr_t sub[] = {1}; |
1482 | |
1483 | MlirAffineMap subMap = mlirAffineMapGetSubMap( |
1484 | affineMap: multiDimIdentityAffineMap, size: sizeof(sub) / sizeof(intptr_t), resultPos: sub); |
1485 | MlirAffineMap majorSubMap = |
1486 | mlirAffineMapGetMajorSubMap(affineMap: multiDimIdentityAffineMap, numResults: 1); |
1487 | MlirAffineMap minorSubMap = |
1488 | mlirAffineMapGetMinorSubMap(affineMap: multiDimIdentityAffineMap, numResults: 1); |
1489 | |
1490 | mlirAffineMapDump(affineMap: subMap); |
1491 | mlirAffineMapDump(affineMap: majorSubMap); |
1492 | mlirAffineMapDump(affineMap: minorSubMap); |
1493 | // CHECK: (d0, d1, d2) -> (d1) |
1494 | // CHECK: (d0, d1, d2) -> (d0) |
1495 | // CHECK: (d0, d1, d2) -> (d2) |
1496 | |
1497 | // CHECK: distinct[0]<"foo"> |
1498 | mlirAttributeDump(attr: mlirDisctinctAttrCreate( |
1499 | referencedAttr: mlirStringAttrGet(ctx, str: mlirStringRefCreateFromCString(str: "foo")))); |
1500 | |
1501 | return 0; |
1502 | } |
1503 | |
1504 | int printAffineExpr(MlirContext ctx) { |
1505 | MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, position: 5); |
1506 | MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, position: 5); |
1507 | MlirAffineExpr affineConstantExpr = mlirAffineConstantExprGet(ctx, constant: 5); |
1508 | MlirAffineExpr affineAddExpr = |
1509 | mlirAffineAddExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1510 | MlirAffineExpr affineMulExpr = |
1511 | mlirAffineMulExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1512 | MlirAffineExpr affineModExpr = |
1513 | mlirAffineModExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1514 | MlirAffineExpr affineFloorDivExpr = |
1515 | mlirAffineFloorDivExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1516 | MlirAffineExpr affineCeilDivExpr = |
1517 | mlirAffineCeilDivExprGet(lhs: affineDimExpr, rhs: affineSymbolExpr); |
1518 | |
1519 | // Tests mlirAffineExprDump. |
1520 | fprintf(stderr, format: "@affineExpr\n"); |
1521 | mlirAffineExprDump(affineExpr: affineDimExpr); |
1522 | mlirAffineExprDump(affineExpr: affineSymbolExpr); |
1523 | mlirAffineExprDump(affineExpr: affineConstantExpr); |
1524 | mlirAffineExprDump(affineExpr: affineAddExpr); |
1525 | mlirAffineExprDump(affineExpr: affineMulExpr); |
1526 | mlirAffineExprDump(affineExpr: affineModExpr); |
1527 | mlirAffineExprDump(affineExpr: affineFloorDivExpr); |
1528 | mlirAffineExprDump(affineExpr: affineCeilDivExpr); |
1529 | // CHECK-LABEL: @affineExpr |
1530 | // CHECK: d5 |
1531 | // CHECK: s5 |
1532 | // CHECK: 5 |
1533 | // CHECK: d5 + s5 |
1534 | // CHECK: d5 * s5 |
1535 | // CHECK: d5 mod s5 |
1536 | // CHECK: d5 floordiv s5 |
1537 | // CHECK: d5 ceildiv s5 |
1538 | |
1539 | // Tests methods of affine binary operation expression, takes add expression |
1540 | // as an example. |
1541 | mlirAffineExprDump(affineExpr: mlirAffineBinaryOpExprGetLHS(affineExpr: affineAddExpr)); |
1542 | mlirAffineExprDump(affineExpr: mlirAffineBinaryOpExprGetRHS(affineExpr: affineAddExpr)); |
1543 | // CHECK: d5 |
1544 | // CHECK: s5 |
1545 | |
1546 | // Tests methods of affine dimension expression. |
1547 | if (mlirAffineDimExprGetPosition(affineExpr: affineDimExpr) != 5) |
1548 | return 1; |
1549 | |
1550 | // Tests methods of affine symbol expression. |
1551 | if (mlirAffineSymbolExprGetPosition(affineExpr: affineSymbolExpr) != 5) |
1552 | return 2; |
1553 | |
1554 | // Tests methods of affine constant expression. |
1555 | if (mlirAffineConstantExprGetValue(affineExpr: affineConstantExpr) != 5) |
1556 | return 3; |
1557 | |
1558 | // Tests methods of affine expression. |
1559 | if (mlirAffineExprIsSymbolicOrConstant(affineExpr: affineDimExpr) || |
1560 | !mlirAffineExprIsSymbolicOrConstant(affineExpr: affineSymbolExpr) || |
1561 | !mlirAffineExprIsSymbolicOrConstant(affineExpr: affineConstantExpr) || |
1562 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineAddExpr) || |
1563 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineMulExpr) || |
1564 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineModExpr) || |
1565 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineFloorDivExpr) || |
1566 | mlirAffineExprIsSymbolicOrConstant(affineExpr: affineCeilDivExpr)) |
1567 | return 4; |
1568 | |
1569 | if (!mlirAffineExprIsPureAffine(affineExpr: affineDimExpr) || |
1570 | !mlirAffineExprIsPureAffine(affineExpr: affineSymbolExpr) || |
1571 | !mlirAffineExprIsPureAffine(affineExpr: affineConstantExpr) || |
1572 | !mlirAffineExprIsPureAffine(affineExpr: affineAddExpr) || |
1573 | mlirAffineExprIsPureAffine(affineExpr: affineMulExpr) || |
1574 | mlirAffineExprIsPureAffine(affineExpr: affineModExpr) || |
1575 | mlirAffineExprIsPureAffine(affineExpr: affineFloorDivExpr) || |
1576 | mlirAffineExprIsPureAffine(affineExpr: affineCeilDivExpr)) |
1577 | return 5; |
1578 | |
1579 | if (mlirAffineExprGetLargestKnownDivisor(affineExpr: affineDimExpr) != 1 || |
1580 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineSymbolExpr) != 1 || |
1581 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineConstantExpr) != 5 || |
1582 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineAddExpr) != 1 || |
1583 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineMulExpr) != 1 || |
1584 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineModExpr) != 1 || |
1585 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineFloorDivExpr) != 1 || |
1586 | mlirAffineExprGetLargestKnownDivisor(affineExpr: affineCeilDivExpr) != 1) |
1587 | return 6; |
1588 | |
1589 | if (!mlirAffineExprIsMultipleOf(affineExpr: affineDimExpr, factor: 1) || |
1590 | !mlirAffineExprIsMultipleOf(affineExpr: affineSymbolExpr, factor: 1) || |
1591 | !mlirAffineExprIsMultipleOf(affineExpr: affineConstantExpr, factor: 5) || |
1592 | !mlirAffineExprIsMultipleOf(affineExpr: affineAddExpr, factor: 1) || |
1593 | !mlirAffineExprIsMultipleOf(affineExpr: affineMulExpr, factor: 1) || |
1594 | !mlirAffineExprIsMultipleOf(affineExpr: affineModExpr, factor: 1) || |
1595 | !mlirAffineExprIsMultipleOf(affineExpr: affineFloorDivExpr, factor: 1) || |
1596 | !mlirAffineExprIsMultipleOf(affineExpr: affineCeilDivExpr, factor: 1)) |
1597 | return 7; |
1598 | |
1599 | if (!mlirAffineExprIsFunctionOfDim(affineExpr: affineDimExpr, position: 5) || |
1600 | mlirAffineExprIsFunctionOfDim(affineExpr: affineSymbolExpr, position: 5) || |
1601 | mlirAffineExprIsFunctionOfDim(affineExpr: affineConstantExpr, position: 5) || |
1602 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineAddExpr, position: 5) || |
1603 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineMulExpr, position: 5) || |
1604 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineModExpr, position: 5) || |
1605 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineFloorDivExpr, position: 5) || |
1606 | !mlirAffineExprIsFunctionOfDim(affineExpr: affineCeilDivExpr, position: 5)) |
1607 | return 8; |
1608 | |
1609 | // Tests 'IsA' methods of affine binary operation expression. |
1610 | if (!mlirAffineExprIsAAdd(affineExpr: affineAddExpr)) |
1611 | return 9; |
1612 | |
1613 | if (!mlirAffineExprIsAMul(affineExpr: affineMulExpr)) |
1614 | return 10; |
1615 | |
1616 | if (!mlirAffineExprIsAMod(affineExpr: affineModExpr)) |
1617 | return 11; |
1618 | |
1619 | if (!mlirAffineExprIsAFloorDiv(affineExpr: affineFloorDivExpr)) |
1620 | return 12; |
1621 | |
1622 | if (!mlirAffineExprIsACeilDiv(affineExpr: affineCeilDivExpr)) |
1623 | return 13; |
1624 | |
1625 | if (!mlirAffineExprIsABinary(affineExpr: affineAddExpr)) |
1626 | return 14; |
1627 | |
1628 | // Test other 'IsA' method on affine expressions. |
1629 | if (!mlirAffineExprIsAConstant(affineExpr: affineConstantExpr)) |
1630 | return 15; |
1631 | |
1632 | if (!mlirAffineExprIsADim(affineExpr: affineDimExpr)) |
1633 | return 16; |
1634 | |
1635 | if (!mlirAffineExprIsASymbol(affineExpr: affineSymbolExpr)) |
1636 | return 17; |
1637 | |
1638 | // Test equality and nullity. |
1639 | MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, position: 5); |
1640 | if (!mlirAffineExprEqual(lhs: affineDimExpr, rhs: otherDimExpr)) |
1641 | return 18; |
1642 | |
1643 | if (mlirAffineExprIsNull(affineExpr: affineDimExpr)) |
1644 | return 19; |
1645 | |
1646 | return 0; |
1647 | } |
1648 | |
1649 | int affineMapFromExprs(MlirContext ctx) { |
1650 | MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, position: 0); |
1651 | MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, position: 1); |
1652 | MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr}; |
1653 | MlirAffineMap map = mlirAffineMapGet(ctx, dimCount: 3, symbolCount: 3, nAffineExprs: 2, affineExprs: exprs); |
1654 | |
1655 | // CHECK-LABEL: @affineMapFromExprs |
1656 | fprintf(stderr, format: "@affineMapFromExprs"); |
1657 | // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1) |
1658 | mlirAffineMapDump(affineMap: map); |
1659 | |
1660 | if (mlirAffineMapGetNumResults(affineMap: map) != 2) |
1661 | return 1; |
1662 | |
1663 | if (!mlirAffineExprEqual(lhs: mlirAffineMapGetResult(affineMap: map, pos: 0), rhs: affineDimExpr)) |
1664 | return 2; |
1665 | |
1666 | if (!mlirAffineExprEqual(lhs: mlirAffineMapGetResult(affineMap: map, pos: 1), rhs: affineSymbolExpr)) |
1667 | return 3; |
1668 | |
1669 | MlirAffineExpr affineDim2Expr = mlirAffineDimExprGet(ctx, position: 1); |
1670 | MlirAffineExpr composed = mlirAffineExprCompose(affineExpr: affineDim2Expr, affineMap: map); |
1671 | // CHECK: s1 |
1672 | mlirAffineExprDump(affineExpr: composed); |
1673 | if (!mlirAffineExprEqual(lhs: composed, rhs: affineSymbolExpr)) |
1674 | return 4; |
1675 | |
1676 | return 0; |
1677 | } |
1678 | |
1679 | int printIntegerSet(MlirContext ctx) { |
1680 | MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(context: ctx, numDims: 2, numSymbols: 1); |
1681 | |
1682 | // CHECK-LABEL: @printIntegerSet |
1683 | fprintf(stderr, format: "@printIntegerSet"); |
1684 | |
1685 | // CHECK: (d0, d1)[s0] : (1 == 0) |
1686 | mlirIntegerSetDump(set: emptySet); |
1687 | |
1688 | if (!mlirIntegerSetIsCanonicalEmpty(set: emptySet)) |
1689 | return 1; |
1690 | |
1691 | MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(context: ctx, numDims: 2, numSymbols: 1); |
1692 | if (!mlirIntegerSetEqual(s1: emptySet, s2: anotherEmptySet)) |
1693 | return 2; |
1694 | |
1695 | // Construct a set constrained by: |
1696 | // d0 - s0 == 0, |
1697 | // d1 - 42 >= 0. |
1698 | MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, constant: -1); |
1699 | MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, constant: -42); |
1700 | MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, position: 0); |
1701 | MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, position: 1); |
1702 | MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, position: 0); |
1703 | MlirAffineExpr negS0 = mlirAffineMulExprGet(lhs: negOne, rhs: s0); |
1704 | MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(lhs: d0, rhs: negS0); |
1705 | MlirAffineExpr d1minus42 = mlirAffineAddExprGet(lhs: d1, rhs: negFortyTwo); |
1706 | MlirAffineExpr constraints[] = {d0minusS0, d1minus42}; |
1707 | bool flags[] = {true, false}; |
1708 | |
1709 | MlirIntegerSet set = mlirIntegerSetGet(context: ctx, numDims: 2, numSymbols: 1, numConstraints: 2, constraints, eqFlags: flags); |
1710 | // CHECK: (d0, d1)[s0] : ( |
1711 | // CHECK-DAG: d0 - s0 == 0 |
1712 | // CHECK-DAG: d1 - 42 >= 0 |
1713 | mlirIntegerSetDump(set); |
1714 | |
1715 | // Transform d1 into s0. |
1716 | MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, position: 1); |
1717 | MlirAffineExpr repl[] = {d0, s1}; |
1718 | MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, dimReplacements: repl, symbolReplacements: &s0, numResultDims: 1, numResultSymbols: 2); |
1719 | // CHECK: (d0)[s0, s1] : ( |
1720 | // CHECK-DAG: d0 - s0 == 0 |
1721 | // CHECK-DAG: s1 - 42 >= 0 |
1722 | mlirIntegerSetDump(set: replaced); |
1723 | |
1724 | if (mlirIntegerSetGetNumDims(set) != 2) |
1725 | return 3; |
1726 | if (mlirIntegerSetGetNumDims(set: replaced) != 1) |
1727 | return 4; |
1728 | |
1729 | if (mlirIntegerSetGetNumSymbols(set) != 1) |
1730 | return 5; |
1731 | if (mlirIntegerSetGetNumSymbols(set: replaced) != 2) |
1732 | return 6; |
1733 | |
1734 | if (mlirIntegerSetGetNumInputs(set) != 3) |
1735 | return 7; |
1736 | |
1737 | if (mlirIntegerSetGetNumConstraints(set) != 2) |
1738 | return 8; |
1739 | |
1740 | if (mlirIntegerSetGetNumEqualities(set) != 1) |
1741 | return 9; |
1742 | |
1743 | if (mlirIntegerSetGetNumInequalities(set) != 1) |
1744 | return 10; |
1745 | |
1746 | MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, pos: 0); |
1747 | MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, pos: 1); |
1748 | bool isEq1 = mlirIntegerSetIsConstraintEq(set, pos: 0); |
1749 | bool isEq2 = mlirIntegerSetIsConstraintEq(set, pos: 1); |
1750 | if (!mlirAffineExprEqual(lhs: cstr1, rhs: isEq1 ? d0minusS0 : d1minus42)) |
1751 | return 11; |
1752 | if (!mlirAffineExprEqual(lhs: cstr2, rhs: isEq2 ? d0minusS0 : d1minus42)) |
1753 | return 12; |
1754 | |
1755 | return 0; |
1756 | } |
1757 | |
1758 | int registerOnlyStd(void) { |
1759 | MlirContext ctx = mlirContextCreate(); |
1760 | // The built-in dialect is always loaded. |
1761 | if (mlirContextGetNumLoadedDialects(context: ctx) != 1) |
1762 | return 1; |
1763 | |
1764 | MlirDialectHandle stdHandle = mlirGetDialectHandle__func__(); |
1765 | |
1766 | MlirDialect std = mlirContextGetOrLoadDialect( |
1767 | context: ctx, name: mlirDialectHandleGetNamespace(stdHandle)); |
1768 | if (!mlirDialectIsNull(dialect: std)) |
1769 | return 2; |
1770 | |
1771 | mlirDialectHandleRegisterDialect(stdHandle, ctx); |
1772 | |
1773 | std = mlirContextGetOrLoadDialect(context: ctx, |
1774 | name: mlirDialectHandleGetNamespace(stdHandle)); |
1775 | if (mlirDialectIsNull(dialect: std)) |
1776 | return 3; |
1777 | |
1778 | MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx); |
1779 | if (!mlirDialectEqual(dialect1: std, dialect2: alsoStd)) |
1780 | return 4; |
1781 | |
1782 | MlirStringRef stdNs = mlirDialectGetNamespace(dialect: std); |
1783 | MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle); |
1784 | if (stdNs.length != alsoStdNs.length || |
1785 | strncmp(s1: stdNs.data, s2: alsoStdNs.data, n: stdNs.length)) |
1786 | return 5; |
1787 | |
1788 | fprintf(stderr, format: "@registration\n"); |
1789 | // CHECK-LABEL: @registration |
1790 | |
1791 | // CHECK: func.call is_registered: 1 |
1792 | fprintf(stderr, format: "func.call is_registered: %d\n", |
1793 | mlirContextIsRegisteredOperation( |
1794 | context: ctx, name: mlirStringRefCreateFromCString(str: "func.call"))); |
1795 | |
1796 | // CHECK: func.not_existing_op is_registered: 0 |
1797 | fprintf(stderr, format: "func.not_existing_op is_registered: %d\n", |
1798 | mlirContextIsRegisteredOperation( |
1799 | context: ctx, name: mlirStringRefCreateFromCString(str: "func.not_existing_op"))); |
1800 | |
1801 | // CHECK: not_existing_dialect.not_existing_op is_registered: 0 |
1802 | fprintf(stderr, format: "not_existing_dialect.not_existing_op is_registered: %d\n", |
1803 | mlirContextIsRegisteredOperation( |
1804 | context: ctx, name: mlirStringRefCreateFromCString( |
1805 | str: "not_existing_dialect.not_existing_op"))); |
1806 | |
1807 | mlirContextDestroy(context: ctx); |
1808 | return 0; |
1809 | } |
1810 | |
1811 | /// Tests backreference APIs |
1812 | static int testBackreferences(void) { |
1813 | fprintf(stderr, format: "@test_backreferences\n"); |
1814 | |
1815 | MlirContext ctx = mlirContextCreate(); |
1816 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
1817 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
1818 | |
1819 | MlirOperationState opState = |
1820 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "invalid.op"), loc); |
1821 | MlirRegion region = mlirRegionCreate(); |
1822 | MlirBlock block = mlirBlockCreate(nArgs: 0, NULL, NULL); |
1823 | mlirRegionAppendOwnedBlock(region, block); |
1824 | mlirOperationStateAddOwnedRegions(state: &opState, n: 1, regions: ®ion); |
1825 | MlirOperation op = mlirOperationCreate(state: &opState); |
1826 | MlirIdentifier ident = |
1827 | mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "identifier")); |
1828 | |
1829 | if (!mlirContextEqual(ctx1: ctx, ctx2: mlirOperationGetContext(op))) { |
1830 | fprintf(stderr, format: "ERROR: Getting context from operation failed\n"); |
1831 | return 1; |
1832 | } |
1833 | if (!mlirOperationEqual(op, other: mlirBlockGetParentOperation(block))) { |
1834 | fprintf(stderr, format: "ERROR: Getting parent operation from block failed\n"); |
1835 | return 2; |
1836 | } |
1837 | if (!mlirContextEqual(ctx1: ctx, ctx2: mlirIdentifierGetContext(ident))) { |
1838 | fprintf(stderr, format: "ERROR: Getting context from identifier failed\n"); |
1839 | return 3; |
1840 | } |
1841 | |
1842 | mlirOperationDestroy(op); |
1843 | mlirContextDestroy(context: ctx); |
1844 | |
1845 | // CHECK-LABEL: @test_backreferences |
1846 | return 0; |
1847 | } |
1848 | |
1849 | /// Tests operand APIs. |
1850 | int testOperands(void) { |
1851 | fprintf(stderr, format: "@testOperands\n"); |
1852 | // CHECK-LABEL: @testOperands |
1853 | |
1854 | MlirContext ctx = mlirContextCreate(); |
1855 | registerAllUpstreamDialects(ctx); |
1856 | |
1857 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "arith")); |
1858 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "test")); |
1859 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
1860 | MlirType indexType = mlirIndexTypeGet(ctx); |
1861 | |
1862 | // Create some constants to use as operands. |
1863 | MlirAttribute indexZeroLiteral = |
1864 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "0 : index")); |
1865 | MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( |
1866 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value")), |
1867 | attr: indexZeroLiteral); |
1868 | MlirOperationState constZeroState = mlirOperationStateGet( |
1869 | name: mlirStringRefCreateFromCString(str: "arith.constant"), loc); |
1870 | mlirOperationStateAddResults(state: &constZeroState, n: 1, results: &indexType); |
1871 | mlirOperationStateAddAttributes(state: &constZeroState, n: 1, attributes: &indexZeroValueAttr); |
1872 | MlirOperation constZero = mlirOperationCreate(state: &constZeroState); |
1873 | MlirValue constZeroValue = mlirOperationGetResult(op: constZero, pos: 0); |
1874 | |
1875 | MlirAttribute indexOneLiteral = |
1876 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "1 : index")); |
1877 | MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet( |
1878 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value")), |
1879 | attr: indexOneLiteral); |
1880 | MlirOperationState constOneState = mlirOperationStateGet( |
1881 | name: mlirStringRefCreateFromCString(str: "arith.constant"), loc); |
1882 | mlirOperationStateAddResults(state: &constOneState, n: 1, results: &indexType); |
1883 | mlirOperationStateAddAttributes(state: &constOneState, n: 1, attributes: &indexOneValueAttr); |
1884 | MlirOperation constOne = mlirOperationCreate(state: &constOneState); |
1885 | MlirValue constOneValue = mlirOperationGetResult(op: constOne, pos: 0); |
1886 | |
1887 | // Create the operation under test. |
1888 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
1889 | MlirOperationState opState = |
1890 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op"), loc); |
1891 | MlirValue initialOperands[] = {constZeroValue}; |
1892 | mlirOperationStateAddOperands(state: &opState, n: 1, operands: initialOperands); |
1893 | MlirOperation op = mlirOperationCreate(state: &opState); |
1894 | |
1895 | // Test operand APIs. |
1896 | intptr_t numOperands = mlirOperationGetNumOperands(op); |
1897 | fprintf(stderr, format: "Num Operands: %"PRIdPTR "\n", numOperands); |
1898 | // CHECK: Num Operands: 1 |
1899 | |
1900 | MlirValue opOperand1 = mlirOperationGetOperand(op, pos: 0); |
1901 | fprintf(stderr, format: "Original operand: "); |
1902 | mlirValuePrint(value: opOperand1, callback: printToStderr, NULL); |
1903 | // CHECK: Original operand: {{.+}} arith.constant 0 : index |
1904 | |
1905 | mlirOperationSetOperand(op, pos: 0, newValue: constOneValue); |
1906 | MlirValue opOperand2 = mlirOperationGetOperand(op, pos: 0); |
1907 | fprintf(stderr, format: "Updated operand: "); |
1908 | mlirValuePrint(value: opOperand2, callback: printToStderr, NULL); |
1909 | // CHECK: Updated operand: {{.+}} arith.constant 1 : index |
1910 | |
1911 | // Test op operand APIs. |
1912 | MlirOpOperand use1 = mlirValueGetFirstUse(value: opOperand1); |
1913 | if (!mlirOpOperandIsNull(opOperand: use1)) { |
1914 | fprintf(stderr, format: "ERROR: Use should be null\n"); |
1915 | return 1; |
1916 | } |
1917 | |
1918 | MlirOpOperand use2 = mlirValueGetFirstUse(value: opOperand2); |
1919 | if (mlirOpOperandIsNull(opOperand: use2)) { |
1920 | fprintf(stderr, format: "ERROR: Use should not be null\n"); |
1921 | return 2; |
1922 | } |
1923 | |
1924 | fprintf(stderr, format: "Use owner: "); |
1925 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use2), callback: printToStderr, NULL); |
1926 | fprintf(stderr, format: "\n"); |
1927 | // CHECK: Use owner: "dummy.op" |
1928 | |
1929 | fprintf(stderr, format: "Use operandNumber: %d\n", |
1930 | mlirOpOperandGetOperandNumber(opOperand: use2)); |
1931 | // CHECK: Use operandNumber: 0 |
1932 | |
1933 | use2 = mlirOpOperandGetNextUse(opOperand: use2); |
1934 | if (!mlirOpOperandIsNull(opOperand: use2)) { |
1935 | fprintf(stderr, format: "ERROR: Next use should be null\n"); |
1936 | return 3; |
1937 | } |
1938 | |
1939 | MlirOperationState op2State = |
1940 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op2"), loc); |
1941 | MlirValue initialOperands2[] = {constOneValue}; |
1942 | mlirOperationStateAddOperands(state: &op2State, n: 1, operands: initialOperands2); |
1943 | MlirOperation op2 = mlirOperationCreate(state: &op2State); |
1944 | |
1945 | MlirOpOperand use3 = mlirValueGetFirstUse(value: constOneValue); |
1946 | fprintf(stderr, format: "First use owner: "); |
1947 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use3), callback: printToStderr, NULL); |
1948 | fprintf(stderr, format: "\n"); |
1949 | // CHECK: First use owner: "dummy.op2" |
1950 | |
1951 | use3 = mlirOpOperandGetNextUse(opOperand: mlirValueGetFirstUse(value: constOneValue)); |
1952 | fprintf(stderr, format: "Second use owner: "); |
1953 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use3), callback: printToStderr, NULL); |
1954 | fprintf(stderr, format: "\n"); |
1955 | // CHECK: Second use owner: "dummy.op" |
1956 | |
1957 | MlirAttribute indexTwoLiteral = |
1958 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "2 : index")); |
1959 | MlirNamedAttribute indexTwoValueAttr = mlirNamedAttributeGet( |
1960 | name: mlirIdentifierGet(context: ctx, str: mlirStringRefCreateFromCString(str: "value")), |
1961 | attr: indexTwoLiteral); |
1962 | MlirOperationState constTwoState = mlirOperationStateGet( |
1963 | name: mlirStringRefCreateFromCString(str: "arith.constant"), loc); |
1964 | mlirOperationStateAddResults(state: &constTwoState, n: 1, results: &indexType); |
1965 | mlirOperationStateAddAttributes(state: &constTwoState, n: 1, attributes: &indexTwoValueAttr); |
1966 | MlirOperation constTwo = mlirOperationCreate(state: &constTwoState); |
1967 | MlirValue constTwoValue = mlirOperationGetResult(op: constTwo, pos: 0); |
1968 | |
1969 | mlirValueReplaceAllUsesOfWith(of: constOneValue, with: constTwoValue); |
1970 | |
1971 | use3 = mlirValueGetFirstUse(value: constOneValue); |
1972 | if (!mlirOpOperandIsNull(opOperand: use3)) { |
1973 | fprintf(stderr, format: "ERROR: Use should be null\n"); |
1974 | return 4; |
1975 | } |
1976 | |
1977 | MlirOpOperand use4 = mlirValueGetFirstUse(value: constTwoValue); |
1978 | fprintf(stderr, format: "First replacement use owner: "); |
1979 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use4), callback: printToStderr, NULL); |
1980 | fprintf(stderr, format: "\n"); |
1981 | // CHECK: First replacement use owner: "dummy.op" |
1982 | |
1983 | use4 = mlirOpOperandGetNextUse(opOperand: mlirValueGetFirstUse(value: constTwoValue)); |
1984 | fprintf(stderr, format: "Second replacement use owner: "); |
1985 | mlirOperationPrint(op: mlirOpOperandGetOwner(opOperand: use4), callback: printToStderr, NULL); |
1986 | fprintf(stderr, format: "\n"); |
1987 | // CHECK: Second replacement use owner: "dummy.op2" |
1988 | |
1989 | MlirOpOperand use5 = mlirValueGetFirstUse(value: constTwoValue); |
1990 | MlirOpOperand use6 = mlirOpOperandGetNextUse(opOperand: use5); |
1991 | if (!mlirValueEqual(value1: mlirOpOperandGetValue(opOperand: use5), |
1992 | value2: mlirOpOperandGetValue(opOperand: use6))) { |
1993 | fprintf(stderr, |
1994 | format: "ERROR: First and second operand should share the same value\n"); |
1995 | return 5; |
1996 | } |
1997 | |
1998 | mlirOperationDestroy(op); |
1999 | mlirOperationDestroy(op: op2); |
2000 | mlirOperationDestroy(op: constZero); |
2001 | mlirOperationDestroy(op: constOne); |
2002 | mlirOperationDestroy(op: constTwo); |
2003 | mlirContextDestroy(context: ctx); |
2004 | |
2005 | return 0; |
2006 | } |
2007 | |
2008 | /// Tests clone APIs. |
2009 | int testClone(void) { |
2010 | fprintf(stderr, format: "@testClone\n"); |
2011 | // CHECK-LABEL: @testClone |
2012 | |
2013 | MlirContext ctx = mlirContextCreate(); |
2014 | registerAllUpstreamDialects(ctx); |
2015 | |
2016 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "func")); |
2017 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "arith")); |
2018 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
2019 | MlirType indexType = mlirIndexTypeGet(ctx); |
2020 | MlirStringRef valueStringRef = mlirStringRefCreateFromCString(str: "value"); |
2021 | |
2022 | MlirAttribute indexZeroLiteral = |
2023 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "0 : index")); |
2024 | MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( |
2025 | name: mlirIdentifierGet(context: ctx, str: valueStringRef), attr: indexZeroLiteral); |
2026 | MlirOperationState constZeroState = mlirOperationStateGet( |
2027 | name: mlirStringRefCreateFromCString(str: "arith.constant"), loc); |
2028 | mlirOperationStateAddResults(state: &constZeroState, n: 1, results: &indexType); |
2029 | mlirOperationStateAddAttributes(state: &constZeroState, n: 1, attributes: &indexZeroValueAttr); |
2030 | MlirOperation constZero = mlirOperationCreate(state: &constZeroState); |
2031 | |
2032 | MlirAttribute indexOneLiteral = |
2033 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "1 : index")); |
2034 | MlirOperation constOne = mlirOperationClone(op: constZero); |
2035 | mlirOperationSetAttributeByName(op: constOne, name: valueStringRef, attr: indexOneLiteral); |
2036 | |
2037 | mlirOperationPrint(op: constZero, callback: printToStderr, NULL); |
2038 | mlirOperationPrint(op: constOne, callback: printToStderr, NULL); |
2039 | // CHECK: arith.constant 0 : index |
2040 | // CHECK: arith.constant 1 : index |
2041 | |
2042 | mlirOperationDestroy(op: constZero); |
2043 | mlirOperationDestroy(op: constOne); |
2044 | mlirContextDestroy(context: ctx); |
2045 | return 0; |
2046 | } |
2047 | |
2048 | // Wraps a diagnostic into additional text we can match against. |
2049 | MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) { |
2050 | fprintf(stderr, format: "processing diagnostic (userData: %"PRIdPTR ") <<\n", |
2051 | (intptr_t)userData); |
2052 | mlirDiagnosticPrint(diagnostic, callback: printToStderr, NULL); |
2053 | fprintf(stderr, format: "\n"); |
2054 | MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); |
2055 | mlirLocationPrint(location: loc, callback: printToStderr, NULL); |
2056 | assert(mlirDiagnosticGetNumNotes(diagnostic) == 0); |
2057 | fprintf(stderr, format: "\n>> end of diagnostic (userData: %"PRIdPTR ")\n", |
2058 | (intptr_t)userData); |
2059 | return mlirLogicalResultSuccess(); |
2060 | } |
2061 | |
// User-data deleter passed with errorHandler to
// mlirContextAttachDiagnosticHandler in testDiagnostics. It frees nothing; it
// only logs the userData value it receives so FileCheck can observe when it
// is invoked.
static void deleteUserData(void *userData) {
  const intptr_t tag = (intptr_t)userData;
  fprintf(stderr, "deleting user data (userData: %" PRIdPTR ")\n", tag);
}
2067 | |
2068 | int testTypeID(MlirContext ctx) { |
2069 | fprintf(stderr, format: "@testTypeID\n"); |
2070 | |
2071 | // Test getting and comparing type and attribute type ids. |
2072 | MlirType i32 = mlirIntegerTypeGet(ctx, bitwidth: 32); |
2073 | MlirTypeID i32ID = mlirTypeGetTypeID(type: i32); |
2074 | MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, bitwidth: 32); |
2075 | MlirTypeID ui32ID = mlirTypeGetTypeID(type: ui32); |
2076 | MlirType f32 = mlirF32TypeGet(ctx); |
2077 | MlirTypeID f32ID = mlirTypeGetTypeID(type: f32); |
2078 | MlirAttribute i32Attr = mlirIntegerAttrGet(type: i32, value: 1); |
2079 | MlirTypeID i32AttrID = mlirAttributeGetTypeID(attribute: i32Attr); |
2080 | |
2081 | if (mlirTypeIDIsNull(typeID: i32ID) || mlirTypeIDIsNull(typeID: ui32ID) || |
2082 | mlirTypeIDIsNull(typeID: f32ID) || mlirTypeIDIsNull(typeID: i32AttrID)) { |
2083 | fprintf(stderr, format: "ERROR: Expected type ids to be present\n"); |
2084 | return 1; |
2085 | } |
2086 | |
2087 | if (!mlirTypeIDEqual(typeID1: i32ID, typeID2: ui32ID) || |
2088 | mlirTypeIDHashValue(typeID: i32ID) != mlirTypeIDHashValue(typeID: ui32ID)) { |
2089 | fprintf( |
2090 | stderr, |
2091 | format: "ERROR: Expected different integer types to have the same type id\n"); |
2092 | return 2; |
2093 | } |
2094 | |
2095 | if (mlirTypeIDEqual(typeID1: i32ID, typeID2: f32ID)) { |
2096 | fprintf(stderr, |
2097 | format: "ERROR: Expected integer type id to not equal float type id\n"); |
2098 | return 3; |
2099 | } |
2100 | |
2101 | if (mlirTypeIDEqual(typeID1: i32ID, typeID2: i32AttrID)) { |
2102 | fprintf(stderr, format: "ERROR: Expected integer type id to not equal integer " |
2103 | "attribute type id\n"); |
2104 | return 4; |
2105 | } |
2106 | |
2107 | MlirLocation loc = mlirLocationUnknownGet(context: ctx); |
2108 | MlirType indexType = mlirIndexTypeGet(ctx); |
2109 | MlirStringRef valueStringRef = mlirStringRefCreateFromCString(str: "value"); |
2110 | |
2111 | // Create a registered operation, which should have a type id. |
2112 | MlirAttribute indexZeroLiteral = |
2113 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: "0 : index")); |
2114 | MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( |
2115 | name: mlirIdentifierGet(context: ctx, str: valueStringRef), attr: indexZeroLiteral); |
2116 | MlirOperationState constZeroState = mlirOperationStateGet( |
2117 | name: mlirStringRefCreateFromCString(str: "arith.constant"), loc); |
2118 | mlirOperationStateAddResults(state: &constZeroState, n: 1, results: &indexType); |
2119 | mlirOperationStateAddAttributes(state: &constZeroState, n: 1, attributes: &indexZeroValueAttr); |
2120 | MlirOperation constZero = mlirOperationCreate(state: &constZeroState); |
2121 | |
2122 | if (!mlirOperationVerify(op: constZero)) { |
2123 | fprintf(stderr, format: "ERROR: Expected operation to verify correctly\n"); |
2124 | return 5; |
2125 | } |
2126 | |
2127 | if (mlirOperationIsNull(op: constZero)) { |
2128 | fprintf(stderr, format: "ERROR: Expected registered operation to be present\n"); |
2129 | return 6; |
2130 | } |
2131 | |
2132 | MlirTypeID registeredOpID = mlirOperationGetTypeID(op: constZero); |
2133 | |
2134 | if (mlirTypeIDIsNull(typeID: registeredOpID)) { |
2135 | fprintf(stderr, |
2136 | format: "ERROR: Expected registered operation type id to be present\n"); |
2137 | return 7; |
2138 | } |
2139 | |
2140 | // Create an unregistered operation, which should not have a type id. |
2141 | mlirContextSetAllowUnregisteredDialects(context: ctx, true); |
2142 | MlirOperationState opState = |
2143 | mlirOperationStateGet(name: mlirStringRefCreateFromCString(str: "dummy.op"), loc); |
2144 | MlirOperation unregisteredOp = mlirOperationCreate(state: &opState); |
2145 | if (mlirOperationIsNull(op: unregisteredOp)) { |
2146 | fprintf(stderr, format: "ERROR: Expected unregistered operation to be present\n"); |
2147 | return 8; |
2148 | } |
2149 | |
2150 | MlirTypeID unregisteredOpID = mlirOperationGetTypeID(op: unregisteredOp); |
2151 | |
2152 | if (!mlirTypeIDIsNull(typeID: unregisteredOpID)) { |
2153 | fprintf(stderr, |
2154 | format: "ERROR: Expected unregistered operation type id to be null\n"); |
2155 | return 9; |
2156 | } |
2157 | |
2158 | mlirOperationDestroy(op: constZero); |
2159 | mlirOperationDestroy(op: unregisteredOp); |
2160 | |
2161 | return 0; |
2162 | } |
2163 | |
2164 | int testSymbolTable(MlirContext ctx) { |
2165 | fprintf(stderr, format: "@testSymbolTable\n"); |
2166 | |
2167 | const char *moduleString = "func.func private @foo()" |
2168 | "func.func private @bar()"; |
2169 | const char *otherModuleString = "func.func private @qux()" |
2170 | "func.func private @foo()"; |
2171 | |
2172 | MlirModule module = |
2173 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
2174 | MlirModule otherModule = mlirModuleCreateParse( |
2175 | context: ctx, module: mlirStringRefCreateFromCString(str: otherModuleString)); |
2176 | |
2177 | MlirSymbolTable symbolTable = |
2178 | mlirSymbolTableCreate(operation: mlirModuleGetOperation(module)); |
2179 | |
2180 | MlirOperation funcFoo = |
2181 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "foo")); |
2182 | if (mlirOperationIsNull(op: funcFoo)) |
2183 | return 1; |
2184 | |
2185 | MlirOperation funcBar = |
2186 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "bar")); |
2187 | if (mlirOperationEqual(op: funcFoo, other: funcBar)) |
2188 | return 2; |
2189 | |
2190 | MlirOperation missing = |
2191 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "qux")); |
2192 | if (!mlirOperationIsNull(op: missing)) |
2193 | return 3; |
2194 | |
2195 | MlirBlock moduleBody = mlirModuleGetBody(module); |
2196 | MlirBlock otherModuleBody = mlirModuleGetBody(module: otherModule); |
2197 | MlirOperation operation = mlirBlockGetFirstOperation(block: otherModuleBody); |
2198 | mlirOperationRemoveFromParent(op: operation); |
2199 | mlirBlockAppendOwnedOperation(block: moduleBody, operation); |
2200 | |
2201 | // At this moment, the operation is still missing from the symbol table. |
2202 | MlirOperation stillMissing = |
2203 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "qux")); |
2204 | if (!mlirOperationIsNull(op: stillMissing)) |
2205 | return 4; |
2206 | |
2207 | // After it is added to the symbol table, and not only the operation with |
2208 | // which the table is associated, it can be looked up. |
2209 | mlirSymbolTableInsert(symbolTable, operation); |
2210 | MlirOperation funcQux = |
2211 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "qux")); |
2212 | if (!mlirOperationEqual(op: operation, other: funcQux)) |
2213 | return 5; |
2214 | |
2215 | // Erasing from the symbol table also removes the operation. |
2216 | mlirSymbolTableErase(symbolTable, operation: funcBar); |
2217 | MlirOperation nowMissing = |
2218 | mlirSymbolTableLookup(symbolTable, name: mlirStringRefCreateFromCString(str: "bar")); |
2219 | if (!mlirOperationIsNull(op: nowMissing)) |
2220 | return 6; |
2221 | |
2222 | // Adding a symbol with the same name to the table should rename. |
2223 | MlirOperation duplicateNameOp = mlirBlockGetFirstOperation(block: otherModuleBody); |
2224 | mlirOperationRemoveFromParent(op: duplicateNameOp); |
2225 | mlirBlockAppendOwnedOperation(block: moduleBody, operation: duplicateNameOp); |
2226 | MlirAttribute newName = mlirSymbolTableInsert(symbolTable, operation: duplicateNameOp); |
2227 | MlirStringRef newNameStr = mlirStringAttrGetValue(attr: newName); |
2228 | if (mlirStringRefEqual(string: newNameStr, other: mlirStringRefCreateFromCString(str: "foo"))) |
2229 | return 7; |
2230 | MlirAttribute updatedName = mlirOperationGetAttributeByName( |
2231 | op: duplicateNameOp, name: mlirSymbolTableGetSymbolAttributeName()); |
2232 | if (!mlirAttributeEqual(a1: updatedName, a2: newName)) |
2233 | return 8; |
2234 | |
2235 | mlirOperationDump(op: mlirModuleGetOperation(module)); |
2236 | mlirOperationDump(op: mlirModuleGetOperation(module: otherModule)); |
2237 | // clang-format off |
2238 | // CHECK-LABEL: @testSymbolTable |
2239 | // CHECK: module |
2240 | // CHECK: func private @foo |
2241 | // CHECK: func private @qux |
2242 | // CHECK: func private @foo{{.+}} |
2243 | // CHECK: module |
2244 | // CHECK-NOT: @qux |
2245 | // CHECK-NOT: @foo |
2246 | // clang-format on |
2247 | |
2248 | mlirSymbolTableDestroy(symbolTable); |
2249 | mlirModuleDestroy(module); |
2250 | mlirModuleDestroy(module: otherModule); |
2251 | |
2252 | return 0; |
2253 | } |
2254 | |
// User data passed through mlirOperationWalk to the walk callbacks: `x` is the
// label each callback prints in front of the visited operation's name.
typedef struct CallBackData {
  const char *x; // Label printed before each operation name.
} callBackData;
2258 | |
2259 | MlirWalkResult walkCallBack(MlirOperation op, void *rootOpVoid) { |
2260 | fprintf(stderr, format: "%s: %s\n", ((callBackData *)(rootOpVoid))->x, |
2261 | mlirIdentifierStr(ident: mlirOperationGetName(op)).data); |
2262 | return MlirWalkResultAdvance; |
2263 | } |
2264 | |
2265 | MlirWalkResult walkCallBackTestWalkResult(MlirOperation op, void *rootOpVoid) { |
2266 | fprintf(stderr, format: "%s: %s\n", ((callBackData *)(rootOpVoid))->x, |
2267 | mlirIdentifierStr(ident: mlirOperationGetName(op)).data); |
2268 | if (strcmp(s1: mlirIdentifierStr(ident: mlirOperationGetName(op)).data, s2: "func.func") == |
2269 | 0) |
2270 | return MlirWalkResultSkip; |
2271 | if (strcmp(s1: mlirIdentifierStr(ident: mlirOperationGetName(op)).data, s2: "arith.addi") == |
2272 | 0) |
2273 | return MlirWalkResultInterrupt; |
2274 | return MlirWalkResultAdvance; |
2275 | } |
2276 | |
2277 | int testOperationWalk(MlirContext ctx) { |
2278 | // CHECK-LABEL: @testOperationWalk |
2279 | fprintf(stderr, format: "@testOperationWalk\n"); |
2280 | |
2281 | const char *moduleString = "module {\n" |
2282 | " func.func @foo() {\n" |
2283 | " %1 = arith.constant 10: i32\n" |
2284 | " arith.addi %1, %1: i32\n" |
2285 | " return\n" |
2286 | " }\n" |
2287 | " func.func @bar() {\n" |
2288 | " return\n" |
2289 | " }\n" |
2290 | "}"; |
2291 | MlirModule module = |
2292 | mlirModuleCreateParse(context: ctx, module: mlirStringRefCreateFromCString(str: moduleString)); |
2293 | |
2294 | callBackData data; |
2295 | data.x = "i love you"; |
2296 | |
2297 | // CHECK-NEXT: i love you: arith.constant |
2298 | // CHECK-NEXT: i love you: arith.addi |
2299 | // CHECK-NEXT: i love you: func.return |
2300 | // CHECK-NEXT: i love you: func.func |
2301 | // CHECK-NEXT: i love you: func.return |
2302 | // CHECK-NEXT: i love you: func.func |
2303 | // CHECK-NEXT: i love you: builtin.module |
2304 | mlirOperationWalk(op: mlirModuleGetOperation(module), callback: walkCallBack, |
2305 | userData: (void *)(&data), walkOrder: MlirWalkPostOrder); |
2306 | |
2307 | data.x = "i don't love you"; |
2308 | // CHECK-NEXT: i don't love you: builtin.module |
2309 | // CHECK-NEXT: i don't love you: func.func |
2310 | // CHECK-NEXT: i don't love you: arith.constant |
2311 | // CHECK-NEXT: i don't love you: arith.addi |
2312 | // CHECK-NEXT: i don't love you: func.return |
2313 | // CHECK-NEXT: i don't love you: func.func |
2314 | // CHECK-NEXT: i don't love you: func.return |
2315 | mlirOperationWalk(op: mlirModuleGetOperation(module), callback: walkCallBack, |
2316 | userData: (void *)(&data), walkOrder: MlirWalkPreOrder); |
2317 | |
2318 | data.x = "interrupt"; |
2319 | // Interrupted at `arith.addi` |
2320 | // CHECK-NEXT: interrupt: arith.constant |
2321 | // CHECK-NEXT: interrupt: arith.addi |
2322 | mlirOperationWalk(op: mlirModuleGetOperation(module), callback: walkCallBackTestWalkResult, |
2323 | userData: (void *)(&data), walkOrder: MlirWalkPostOrder); |
2324 | |
2325 | data.x = "skip"; |
2326 | // Skip at `func.func` |
2327 | // CHECK-NEXT: skip: builtin.module |
2328 | // CHECK-NEXT: skip: func.func |
2329 | // CHECK-NEXT: skip: func.func |
2330 | mlirOperationWalk(op: mlirModuleGetOperation(module), callback: walkCallBackTestWalkResult, |
2331 | userData: (void *)(&data), walkOrder: MlirWalkPreOrder); |
2332 | |
2333 | mlirModuleDestroy(module); |
2334 | return 0; |
2335 | } |
2336 | |
2337 | int testDialectRegistry(void) { |
2338 | fprintf(stderr, format: "@testDialectRegistry\n"); |
2339 | |
2340 | MlirDialectRegistry registry = mlirDialectRegistryCreate(); |
2341 | if (mlirDialectRegistryIsNull(registry)) { |
2342 | fprintf(stderr, format: "ERROR: Expected registry to be present\n"); |
2343 | return 1; |
2344 | } |
2345 | |
2346 | MlirDialectHandle stdHandle = mlirGetDialectHandle__func__(); |
2347 | mlirDialectHandleInsertDialect(stdHandle, registry); |
2348 | |
2349 | MlirContext ctx = mlirContextCreate(); |
2350 | if (mlirContextGetNumRegisteredDialects(context: ctx) != 0) { |
2351 | fprintf(stderr, |
2352 | format: "ERROR: Expected no dialects to be registered to new context\n"); |
2353 | } |
2354 | |
2355 | mlirContextAppendDialectRegistry(ctx, registry); |
2356 | if (mlirContextGetNumRegisteredDialects(context: ctx) != 1) { |
2357 | fprintf(stderr, format: "ERROR: Expected the dialect in the registry to be " |
2358 | "registered to the context\n"); |
2359 | } |
2360 | |
2361 | mlirContextDestroy(context: ctx); |
2362 | mlirDialectRegistryDestroy(registry); |
2363 | |
2364 | return 0; |
2365 | } |
2366 | |
2367 | void testExplicitThreadPools(void) { |
2368 | MlirLlvmThreadPool threadPool = mlirLlvmThreadPoolCreate(); |
2369 | MlirDialectRegistry registry = mlirDialectRegistryCreate(); |
2370 | mlirRegisterAllDialects(registry); |
2371 | MlirContext context = |
2372 | mlirContextCreateWithRegistry(registry, /*threadingEnabled=*/false); |
2373 | mlirContextSetThreadPool(context, threadPool); |
2374 | mlirContextDestroy(context); |
2375 | mlirDialectRegistryDestroy(registry); |
2376 | mlirLlvmThreadPoolDestroy(pool: threadPool); |
2377 | } |
2378 | |
2379 | void testDiagnostics(void) { |
2380 | MlirContext ctx = mlirContextCreate(); |
2381 | MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( |
2382 | context: ctx, handler: errorHandler, userData: (void *)42, deleteUserData); |
2383 | fprintf(stderr, format: "@test_diagnostics\n"); |
2384 | MlirLocation unknownLoc = mlirLocationUnknownGet(context: ctx); |
2385 | mlirEmitError(location: unknownLoc, message: "test diagnostics"); |
2386 | MlirAttribute unknownAttr = mlirLocationGetAttribute(location: unknownLoc); |
2387 | MlirLocation unknownClone = mlirLocationFromAttribute(attribute: unknownAttr); |
2388 | mlirEmitError(location: unknownClone, message: "test clone"); |
2389 | MlirLocation fileLineColLoc = mlirLocationFileLineColGet( |
2390 | context: ctx, filename: mlirStringRefCreateFromCString(str: "file.c"), line: 1, col: 2); |
2391 | mlirEmitError(location: fileLineColLoc, message: "test diagnostics"); |
2392 | MlirLocation fileLineColRange = mlirLocationFileLineColRangeGet( |
2393 | context: ctx, filename: mlirStringRefCreateFromCString(str: "other-file.c"), start_line: 1, start_col: 2, end_line: 3, end_col: 4); |
2394 | mlirEmitError(location: fileLineColRange, message: "test diagnostics"); |
2395 | MlirLocation callSiteLoc = mlirLocationCallSiteGet( |
2396 | callee: mlirLocationFileLineColGet( |
2397 | context: ctx, filename: mlirStringRefCreateFromCString(str: "other-file.c"), line: 2, col: 3), |
2398 | caller: fileLineColLoc); |
2399 | mlirEmitError(location: callSiteLoc, message: "test diagnostics"); |
2400 | MlirLocation null = {0}; |
2401 | MlirLocation nameLoc = |
2402 | mlirLocationNameGet(context: ctx, name: mlirStringRefCreateFromCString(str: "named"), childLoc: null); |
2403 | mlirEmitError(location: nameLoc, message: "test diagnostics"); |
2404 | MlirLocation locs[2] = {nameLoc, callSiteLoc}; |
2405 | MlirAttribute nullAttr = {0}; |
2406 | MlirLocation fusedLoc = mlirLocationFusedGet(ctx, nLocations: 2, locations: locs, metadata: nullAttr); |
2407 | mlirEmitError(location: fusedLoc, message: "test diagnostics"); |
2408 | mlirContextDetachDiagnosticHandler(context: ctx, id); |
2409 | mlirEmitError(location: unknownLoc, message: "more test diagnostics"); |
2410 | // CHECK-LABEL: @test_diagnostics |
2411 | // CHECK: processing diagnostic (userData: 42) << |
2412 | // CHECK: test diagnostics |
2413 | // CHECK: loc(unknown) |
2414 | // CHECK: processing diagnostic (userData: 42) << |
2415 | // CHECK: test clone |
2416 | // CHECK: loc(unknown) |
2417 | // CHECK: >> end of diagnostic (userData: 42) |
2418 | // CHECK: processing diagnostic (userData: 42) << |
2419 | // CHECK: test diagnostics |
2420 | // CHECK: loc("file.c":1:2) |
2421 | // CHECK: >> end of diagnostic (userData: 42) |
2422 | // CHECK: processing diagnostic (userData: 42) << |
2423 | // CHECK: test diagnostics |
2424 | // CHECK: loc("other-file.c":1:2 to 3:4) |
2425 | // CHECK: >> end of diagnostic (userData: 42) |
2426 | // CHECK: processing diagnostic (userData: 42) << |
2427 | // CHECK: test diagnostics |
2428 | // CHECK: loc(callsite("other-file.c":2:3 at "file.c":1:2)) |
2429 | // CHECK: >> end of diagnostic (userData: 42) |
2430 | // CHECK: processing diagnostic (userData: 42) << |
2431 | // CHECK: test diagnostics |
2432 | // CHECK: loc("named") |
2433 | // CHECK: >> end of diagnostic (userData: 42) |
2434 | // CHECK: processing diagnostic (userData: 42) << |
2435 | // CHECK: test diagnostics |
2436 | // CHECK: loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)]) |
2437 | // CHECK: deleting user data (userData: 42) |
2438 | // CHECK-NOT: processing diagnostic |
2439 | // CHECK: more test diagnostics |
2440 | mlirContextDestroy(context: ctx); |
2441 | } |
2442 | |
2443 | int main(void) { |
2444 | MlirContext ctx = mlirContextCreate(); |
2445 | registerAllUpstreamDialects(ctx); |
2446 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "func")); |
2447 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "memref")); |
2448 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "shape")); |
2449 | mlirContextGetOrLoadDialect(context: ctx, name: mlirStringRefCreateFromCString(str: "scf")); |
2450 | |
2451 | if (constructAndTraverseIr(ctx)) |
2452 | return 1; |
2453 | buildWithInsertionsAndPrint(ctx); |
2454 | if (createOperationWithTypeInference(ctx)) |
2455 | return 2; |
2456 | |
2457 | if (printBuiltinTypes(ctx)) |
2458 | return 3; |
2459 | if (printBuiltinAttributes(ctx)) |
2460 | return 4; |
2461 | if (printAffineMap(ctx)) |
2462 | return 5; |
2463 | if (printAffineExpr(ctx)) |
2464 | return 6; |
2465 | if (affineMapFromExprs(ctx)) |
2466 | return 7; |
2467 | if (printIntegerSet(ctx)) |
2468 | return 8; |
2469 | if (registerOnlyStd()) |
2470 | return 9; |
2471 | if (testBackreferences()) |
2472 | return 10; |
2473 | if (testOperands()) |
2474 | return 11; |
2475 | if (testClone()) |
2476 | return 12; |
2477 | if (testTypeID(ctx)) |
2478 | return 13; |
2479 | if (testSymbolTable(ctx)) |
2480 | return 14; |
2481 | if (testDialectRegistry()) |
2482 | return 15; |
2483 | if (testOperationWalk(ctx)) |
2484 | return 16; |
2485 | |
2486 | testExplicitThreadPools(); |
2487 | testDiagnostics(); |
2488 | |
2489 | // CHECK: DESTROY MAIN CONTEXT |
2490 | // CHECK: reportResourceDelete: resource_i64_blob |
2491 | fprintf(stderr, format: "DESTROY MAIN CONTEXT\n"); |
2492 | mlirContextDestroy(context: ctx); |
2493 | |
2494 | return 0; |
2495 | } |
2496 |
Definitions
- registerAllUpstreamDialects
- ResourceDeleteUserData
- resourceI64BlobUserData
- reportResourceDelete
- populateLoopBody
- makeAndDumpAdd
- OpListNode
- ModuleStats
- collectStatsSingle
- collectStats
- printToStderr
- printFirstOfEach
- constructAndTraverseIr
- buildWithInsertionsAndPrint
- createOperationWithTypeInference
- printBuiltinTypes
- callbackSetFixedLengthString
- stringIsEqual
- printBuiltinAttributes
- printAffineMap
- printAffineExpr
- affineMapFromExprs
- printIntegerSet
- registerOnlyStd
- testBackreferences
- testOperands
- testClone
- errorHandler
- deleteUserData
- testTypeID
- testSymbolTable
- walkCallBack
- walkCallBackTestWalkResult
- testOperationWalk
- testDialectRegistry
- testExplicitThreadPools
- testDiagnostics
Improve your Profiling and Debugging skills
Find out more