1 | //===- pass.c - Simple test of C APIs -------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
4 | // Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s |
11 | */ |
12 | |
13 | #include "mlir-c/Pass.h" |
14 | #include "mlir-c/Dialect/Func.h" |
15 | #include "mlir-c/IR.h" |
16 | #include "mlir-c/RegisterEverything.h" |
17 | #include "mlir-c/Transforms.h" |
18 | |
19 | #include <assert.h> |
20 | #include <math.h> |
21 | #include <stdio.h> |
22 | #include <stdlib.h> |
23 | #include <string.h> |
24 | |
25 | static void registerAllUpstreamDialects(MlirContext ctx) { |
26 | MlirDialectRegistry registry = mlirDialectRegistryCreate(); |
27 | mlirRegisterAllDialects(registry); |
28 | mlirContextAppendDialectRegistry(ctx, registry); |
29 | mlirDialectRegistryDestroy(registry); |
30 | } |
31 | |
32 | void testRunPassOnModule(void) { |
33 | MlirContext ctx = mlirContextCreate(); |
34 | registerAllUpstreamDialects(ctx); |
35 | |
36 | const char *funcAsm = // |
37 | "func.func @foo(%arg0 : i32) -> i32 { \n" |
38 | " %res = arith.addi %arg0, %arg0 : i32 \n" |
39 | " return %res : i32 \n" |
40 | "} \n" ; |
41 | MlirOperation func = |
42 | mlirOperationCreateParse(context: ctx, sourceStr: mlirStringRefCreateFromCString(str: funcAsm), |
43 | sourceName: mlirStringRefCreateFromCString(str: "funcAsm" )); |
44 | if (mlirOperationIsNull(op: func)) { |
45 | fprintf(stderr, format: "Unexpected failure parsing asm.\n" ); |
46 | exit(EXIT_FAILURE); |
47 | } |
48 | |
49 | // Run the print-op-stats pass on the top-level module: |
50 | // CHECK-LABEL: Operations encountered: |
51 | // CHECK: arith.addi , 1 |
52 | // CHECK: func.func , 1 |
53 | // CHECK: func.return , 1 |
54 | { |
55 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
56 | MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); |
57 | mlirPassManagerAddOwnedPass(passManager: pm, pass: printOpStatPass); |
58 | MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: func); |
59 | if (mlirLogicalResultIsFailure(res: success)) { |
60 | fprintf(stderr, format: "Unexpected failure running pass manager.\n" ); |
61 | exit(EXIT_FAILURE); |
62 | } |
63 | mlirPassManagerDestroy(passManager: pm); |
64 | } |
65 | mlirOperationDestroy(op: func); |
66 | mlirContextDestroy(context: ctx); |
67 | } |
68 | |
69 | void testRunPassOnNestedModule(void) { |
70 | MlirContext ctx = mlirContextCreate(); |
71 | registerAllUpstreamDialects(ctx); |
72 | |
73 | const char *moduleAsm = // |
74 | "module { \n" |
75 | " func.func @foo(%arg0 : i32) -> i32 { \n" |
76 | " %res = arith.addi %arg0, %arg0 : i32 \n" |
77 | " return %res : i32 \n" |
78 | " } \n" |
79 | " module { \n" |
80 | " func.func @bar(%arg0 : f32) -> f32 { \n" |
81 | " %res = arith.addf %arg0, %arg0 : f32 \n" |
82 | " return %res : f32 \n" |
83 | " } \n" |
84 | " } \n" |
85 | "} \n" ; |
86 | MlirOperation module = |
87 | mlirOperationCreateParse(context: ctx, sourceStr: mlirStringRefCreateFromCString(str: moduleAsm), |
88 | sourceName: mlirStringRefCreateFromCString(str: "moduleAsm" )); |
89 | if (mlirOperationIsNull(op: module)) |
90 | exit(status: 1); |
91 | |
92 | // Run the print-op-stats pass on functions under the top-level module: |
93 | // CHECK-LABEL: Operations encountered: |
94 | // CHECK: arith.addi , 1 |
95 | // CHECK: func.func , 1 |
96 | // CHECK: func.return , 1 |
97 | { |
98 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
99 | MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder( |
100 | passManager: pm, operationName: mlirStringRefCreateFromCString(str: "func.func" )); |
101 | MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); |
102 | mlirOpPassManagerAddOwnedPass(passManager: nestedFuncPm, pass: printOpStatPass); |
103 | MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module); |
104 | if (mlirLogicalResultIsFailure(res: success)) |
105 | exit(status: 2); |
106 | mlirPassManagerDestroy(passManager: pm); |
107 | } |
108 | // Run the print-op-stats pass on functions under the nested module: |
109 | // CHECK-LABEL: Operations encountered: |
110 | // CHECK: arith.addf , 1 |
111 | // CHECK: func.func , 1 |
112 | // CHECK: func.return , 1 |
113 | { |
114 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
115 | MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( |
116 | passManager: pm, operationName: mlirStringRefCreateFromCString(str: "builtin.module" )); |
117 | MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder( |
118 | passManager: nestedModulePm, operationName: mlirStringRefCreateFromCString(str: "func.func" )); |
119 | MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); |
120 | mlirOpPassManagerAddOwnedPass(passManager: nestedFuncPm, pass: printOpStatPass); |
121 | MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module); |
122 | if (mlirLogicalResultIsFailure(res: success)) |
123 | exit(status: 2); |
124 | mlirPassManagerDestroy(passManager: pm); |
125 | } |
126 | |
127 | mlirOperationDestroy(op: module); |
128 | mlirContextDestroy(context: ctx); |
129 | } |
130 | |
131 | static void printToStderr(MlirStringRef str, void *userData) { |
132 | (void)userData; |
133 | fwrite(ptr: str.data, size: 1, n: str.length, stderr); |
134 | } |
135 | |
136 | static void dontPrint(MlirStringRef str, void *userData) { |
137 | (void)str; |
138 | (void)userData; |
139 | } |
140 | |
141 | void testPrintPassPipeline(void) { |
142 | MlirContext ctx = mlirContextCreate(); |
143 | MlirPassManager pm = mlirPassManagerCreateOnOperation( |
144 | ctx, anchorOp: mlirStringRefCreateFromCString(str: "any" )); |
145 | // Populate the pass-manager |
146 | MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( |
147 | passManager: pm, operationName: mlirStringRefCreateFromCString(str: "builtin.module" )); |
148 | MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder( |
149 | passManager: nestedModulePm, operationName: mlirStringRefCreateFromCString(str: "func.func" )); |
150 | MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); |
151 | mlirOpPassManagerAddOwnedPass(passManager: nestedFuncPm, pass: printOpStatPass); |
152 | |
153 | // Print the top level pass manager |
154 | // CHECK: Top-level: any( |
155 | // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false})) |
156 | // CHECK-SAME: ) |
157 | fprintf(stderr, format: "Top-level: " ); |
158 | mlirPrintPassPipeline(passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), callback: printToStderr, |
159 | NULL); |
160 | fprintf(stderr, format: "\n" ); |
161 | |
162 | // Print the pipeline nested one level down |
163 | // CHECK: Nested Module: builtin.module(func.func(print-op-stats{json=false})) |
164 | fprintf(stderr, format: "Nested Module: " ); |
165 | mlirPrintPassPipeline(passManager: nestedModulePm, callback: printToStderr, NULL); |
166 | fprintf(stderr, format: "\n" ); |
167 | |
168 | // Print the pipeline nested two levels down |
169 | // CHECK: Nested Module>Func: func.func(print-op-stats{json=false}) |
170 | fprintf(stderr, format: "Nested Module>Func: " ); |
171 | mlirPrintPassPipeline(passManager: nestedFuncPm, callback: printToStderr, NULL); |
172 | fprintf(stderr, format: "\n" ); |
173 | |
174 | mlirPassManagerDestroy(passManager: pm); |
175 | mlirContextDestroy(context: ctx); |
176 | } |
177 | |
178 | void testParsePassPipeline(void) { |
179 | MlirContext ctx = mlirContextCreate(); |
180 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
181 | // Try parse a pipeline. |
182 | MlirLogicalResult status = mlirParsePassPipeline( |
183 | passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), |
184 | pipeline: mlirStringRefCreateFromCString( |
185 | str: "builtin.module(func.func(print-op-stats{json=false}))" ), |
186 | callback: printToStderr, NULL); |
187 | // Expect a failure, we haven't registered the print-op-stats pass yet. |
188 | if (mlirLogicalResultIsSuccess(res: status)) { |
189 | fprintf( |
190 | stderr, |
191 | format: "Unexpected success parsing pipeline without registering the pass\n" ); |
192 | exit(EXIT_FAILURE); |
193 | } |
194 | // Try again after registrating the pass. |
195 | mlirRegisterTransformsPrintOpStats(); |
196 | status = mlirParsePassPipeline( |
197 | passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), |
198 | pipeline: mlirStringRefCreateFromCString( |
199 | str: "builtin.module(func.func(print-op-stats{json=false}))" ), |
200 | callback: printToStderr, NULL); |
201 | // Expect a failure, we haven't registered the print-op-stats pass yet. |
202 | if (mlirLogicalResultIsFailure(res: status)) { |
203 | fprintf(stderr, |
204 | format: "Unexpected failure parsing pipeline after registering the pass\n" ); |
205 | exit(EXIT_FAILURE); |
206 | } |
207 | |
208 | // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false})) |
209 | fprintf(stderr, format: "Round-trip: " ); |
210 | mlirPrintPassPipeline(passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), callback: printToStderr, |
211 | NULL); |
212 | fprintf(stderr, format: "\n" ); |
213 | |
214 | // Try appending a pass: |
215 | status = mlirOpPassManagerAddPipeline( |
216 | passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), |
217 | pipelineElements: mlirStringRefCreateFromCString(str: "func.func(print-op-stats{json=false})" ), |
218 | callback: printToStderr, NULL); |
219 | if (mlirLogicalResultIsFailure(res: status)) { |
220 | fprintf(stderr, format: "Unexpected failure appending pipeline\n" ); |
221 | exit(EXIT_FAILURE); |
222 | } |
223 | // CHECK: Appended: builtin.module( |
224 | // CHECK-SAME: func.func(print-op-stats{json=false}), |
225 | // CHECK-SAME: func.func(print-op-stats{json=false}) |
226 | // CHECK-SAME: ) |
227 | fprintf(stderr, format: "Appended: " ); |
228 | mlirPrintPassPipeline(passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), callback: printToStderr, |
229 | NULL); |
230 | fprintf(stderr, format: "\n" ); |
231 | |
232 | mlirPassManagerDestroy(passManager: pm); |
233 | mlirContextDestroy(context: ctx); |
234 | } |
235 | |
236 | void testParseErrorCapture(void) { |
237 | // CHECK-LABEL: testParseErrorCapture: |
238 | fprintf(stderr, format: "\nTEST: testParseErrorCapture:\n" ); |
239 | |
240 | MlirContext ctx = mlirContextCreate(); |
241 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
242 | MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(passManager: pm); |
243 | MlirStringRef invalidPipeline = mlirStringRefCreateFromCString(str: "invalid" ); |
244 | |
245 | // CHECK: mlirParsePassPipeline: |
246 | // CHECK: expected pass pipeline to be wrapped with the anchor operation type |
247 | fprintf(stderr, format: "mlirParsePassPipeline:\n" ); |
248 | if (mlirLogicalResultIsSuccess( |
249 | res: mlirParsePassPipeline(passManager: opm, pipeline: invalidPipeline, callback: printToStderr, NULL))) |
250 | exit(EXIT_FAILURE); |
251 | fprintf(stderr, format: "\n" ); |
252 | |
253 | // CHECK: mlirOpPassManagerAddPipeline: |
254 | // CHECK: 'invalid' does not refer to a registered pass or pass pipeline |
255 | fprintf(stderr, format: "mlirOpPassManagerAddPipeline:\n" ); |
256 | if (mlirLogicalResultIsSuccess(res: mlirOpPassManagerAddPipeline( |
257 | passManager: opm, pipelineElements: invalidPipeline, callback: printToStderr, NULL))) |
258 | exit(EXIT_FAILURE); |
259 | fprintf(stderr, format: "\n" ); |
260 | |
261 | // Make sure all output is going through the callback. |
262 | // CHECK: dontPrint: <> |
263 | fprintf(stderr, format: "dontPrint: <" ); |
264 | if (mlirLogicalResultIsSuccess( |
265 | res: mlirParsePassPipeline(passManager: opm, pipeline: invalidPipeline, callback: dontPrint, NULL))) |
266 | exit(EXIT_FAILURE); |
267 | if (mlirLogicalResultIsSuccess( |
268 | res: mlirOpPassManagerAddPipeline(passManager: opm, pipelineElements: invalidPipeline, callback: dontPrint, NULL))) |
269 | exit(EXIT_FAILURE); |
270 | fprintf(stderr, format: ">\n" ); |
271 | |
272 | mlirPassManagerDestroy(passManager: pm); |
273 | mlirContextDestroy(context: ctx); |
274 | } |
275 | |
/// Call counters updated by the external-pass test callbacks. A pointer to a
/// single instance is handed to mlirCreateExternalPass as `userData`. The
/// clone callback returns that same pointer, so each field accumulates across
/// the original pass and all of its clones.
typedef struct TestExternalPassUserData {
  int constructCallCount;  // times the `construct` callback ran
  int destructCallCount;   // times the `destruct` callback ran
  int initializeCallCount; // times an `initialize` callback ran
  int cloneCallCount;      // times the `clone` callback ran
  int runCallCount;        // times a `run` callback ran
} TestExternalPassUserData;
284 | |
285 | void testConstructExternalPass(void *userData) { |
286 | ++((TestExternalPassUserData *)userData)->constructCallCount; |
287 | } |
288 | |
289 | void testDestructExternalPass(void *userData) { |
290 | ++((TestExternalPassUserData *)userData)->destructCallCount; |
291 | } |
292 | |
293 | MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) { |
294 | ++((TestExternalPassUserData *)userData)->initializeCallCount; |
295 | return mlirLogicalResultSuccess(); |
296 | } |
297 | |
298 | MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx, |
299 | void *userData) { |
300 | ++((TestExternalPassUserData *)userData)->initializeCallCount; |
301 | return mlirLogicalResultFailure(); |
302 | } |
303 | |
304 | void *testCloneExternalPass(void *userData) { |
305 | ++((TestExternalPassUserData *)userData)->cloneCallCount; |
306 | return userData; |
307 | } |
308 | |
309 | void testRunExternalPass(MlirOperation op, MlirExternalPass pass, |
310 | void *userData) { |
311 | ++((TestExternalPassUserData *)userData)->runCallCount; |
312 | } |
313 | |
314 | void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass, |
315 | void *userData) { |
316 | ++((TestExternalPassUserData *)userData)->runCallCount; |
317 | MlirStringRef opName = mlirIdentifierStr(ident: mlirOperationGetName(op)); |
318 | if (!mlirStringRefEqual(string: opName, |
319 | other: mlirStringRefCreateFromCString(str: "func.func" ))) { |
320 | mlirExternalPassSignalFailure(pass); |
321 | } |
322 | } |
323 | |
324 | void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass, |
325 | void *userData) { |
326 | ++((TestExternalPassUserData *)userData)->runCallCount; |
327 | mlirExternalPassSignalFailure(pass); |
328 | } |
329 | |
330 | MlirExternalPassCallbacks makeTestExternalPassCallbacks( |
331 | MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData), |
332 | void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) { |
333 | return (MlirExternalPassCallbacks){testConstructExternalPass, |
334 | testDestructExternalPass, initializePass, |
335 | testCloneExternalPass, runPass}; |
336 | } |
337 | |
338 | void testExternalPass(void) { |
339 | MlirContext ctx = mlirContextCreate(); |
340 | registerAllUpstreamDialects(ctx); |
341 | |
342 | const char *moduleAsm = // |
343 | "module { \n" |
344 | " func.func @foo(%arg0 : i32) -> i32 { \n" |
345 | " %res = arith.addi %arg0, %arg0 : i32 \n" |
346 | " return %res : i32 \n" |
347 | " } \n" |
348 | "}" ; |
349 | MlirOperation module = |
350 | mlirOperationCreateParse(context: ctx, sourceStr: mlirStringRefCreateFromCString(str: moduleAsm), |
351 | sourceName: mlirStringRefCreateFromCString(str: "moduleAsm" )); |
352 | if (mlirOperationIsNull(op: module)) { |
353 | fprintf(stderr, format: "Unexpected failure parsing module.\n" ); |
354 | exit(EXIT_FAILURE); |
355 | } |
356 | |
357 | MlirStringRef description = mlirStringRefCreateFromCString(str: "" ); |
358 | MlirStringRef emptyOpName = mlirStringRefCreateFromCString(str: "" ); |
359 | |
360 | MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); |
361 | |
362 | // Run a generic pass |
363 | { |
364 | MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator); |
365 | MlirStringRef name = mlirStringRefCreateFromCString(str: "TestExternalPass" ); |
366 | MlirStringRef argument = |
367 | mlirStringRefCreateFromCString(str: "test-external-pass" ); |
368 | TestExternalPassUserData userData = {0}; |
369 | |
370 | MlirPass externalPass = mlirCreateExternalPass( |
371 | passID, name, argument, description, opName: emptyOpName, nDependentDialects: 0, NULL, |
372 | callbacks: makeTestExternalPassCallbacks(NULL, runPass: testRunExternalPass), userData: &userData); |
373 | |
374 | if (userData.constructCallCount != 1) { |
375 | fprintf(stderr, format: "Expected constructCallCount to be 1\n" ); |
376 | exit(EXIT_FAILURE); |
377 | } |
378 | |
379 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
380 | mlirPassManagerAddOwnedPass(passManager: pm, pass: externalPass); |
381 | MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module); |
382 | if (mlirLogicalResultIsFailure(res: success)) { |
383 | fprintf(stderr, format: "Unexpected failure running external pass.\n" ); |
384 | exit(EXIT_FAILURE); |
385 | } |
386 | |
387 | if (userData.runCallCount != 1) { |
388 | fprintf(stderr, format: "Expected runCallCount to be 1\n" ); |
389 | exit(EXIT_FAILURE); |
390 | } |
391 | |
392 | mlirPassManagerDestroy(passManager: pm); |
393 | |
394 | if (userData.destructCallCount != userData.constructCallCount) { |
395 | fprintf(stderr, format: "Expected destructCallCount to be equal to " |
396 | "constructCallCount\n" ); |
397 | exit(EXIT_FAILURE); |
398 | } |
399 | } |
400 | |
401 | // Run a func operation pass |
402 | { |
403 | MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator); |
404 | MlirStringRef name = mlirStringRefCreateFromCString(str: "TestExternalFuncPass" ); |
405 | MlirStringRef argument = |
406 | mlirStringRefCreateFromCString(str: "test-external-func-pass" ); |
407 | TestExternalPassUserData userData = {0}; |
408 | MlirDialectHandle funcHandle = mlirGetDialectHandle__func__(); |
409 | MlirStringRef funcOpName = mlirStringRefCreateFromCString(str: "func.func" ); |
410 | |
411 | MlirPass externalPass = mlirCreateExternalPass( |
412 | passID, name, argument, description, opName: funcOpName, nDependentDialects: 1, dependentDialects: &funcHandle, |
413 | callbacks: makeTestExternalPassCallbacks(NULL, runPass: testRunExternalFuncPass), |
414 | userData: &userData); |
415 | |
416 | if (userData.constructCallCount != 1) { |
417 | fprintf(stderr, format: "Expected constructCallCount to be 1\n" ); |
418 | exit(EXIT_FAILURE); |
419 | } |
420 | |
421 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
422 | MlirOpPassManager nestedFuncPm = |
423 | mlirPassManagerGetNestedUnder(passManager: pm, operationName: funcOpName); |
424 | mlirOpPassManagerAddOwnedPass(passManager: nestedFuncPm, pass: externalPass); |
425 | MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module); |
426 | if (mlirLogicalResultIsFailure(res: success)) { |
427 | fprintf(stderr, format: "Unexpected failure running external operation pass.\n" ); |
428 | exit(EXIT_FAILURE); |
429 | } |
430 | |
431 | // Since this is a nested pass, it can be cloned and run in parallel |
432 | if (userData.cloneCallCount != userData.constructCallCount - 1) { |
433 | fprintf(stderr, format: "Expected constructCallCount to be 1\n" ); |
434 | exit(EXIT_FAILURE); |
435 | } |
436 | |
437 | // The pass should only be run once this there is only one func op |
438 | if (userData.runCallCount != 1) { |
439 | fprintf(stderr, format: "Expected runCallCount to be 1\n" ); |
440 | exit(EXIT_FAILURE); |
441 | } |
442 | |
443 | mlirPassManagerDestroy(passManager: pm); |
444 | |
445 | if (userData.destructCallCount != userData.constructCallCount) { |
446 | fprintf(stderr, format: "Expected destructCallCount to be equal to " |
447 | "constructCallCount\n" ); |
448 | exit(EXIT_FAILURE); |
449 | } |
450 | } |
451 | |
452 | // Run a pass with `initialize` set |
453 | { |
454 | MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator); |
455 | MlirStringRef name = mlirStringRefCreateFromCString(str: "TestExternalPass" ); |
456 | MlirStringRef argument = |
457 | mlirStringRefCreateFromCString(str: "test-external-pass" ); |
458 | TestExternalPassUserData userData = {0}; |
459 | |
460 | MlirPass externalPass = mlirCreateExternalPass( |
461 | passID, name, argument, description, opName: emptyOpName, nDependentDialects: 0, NULL, |
462 | callbacks: makeTestExternalPassCallbacks(initializePass: testInitializeExternalPass, |
463 | runPass: testRunExternalPass), |
464 | userData: &userData); |
465 | |
466 | if (userData.constructCallCount != 1) { |
467 | fprintf(stderr, format: "Expected constructCallCount to be 1\n" ); |
468 | exit(EXIT_FAILURE); |
469 | } |
470 | |
471 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
472 | mlirPassManagerAddOwnedPass(passManager: pm, pass: externalPass); |
473 | MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module); |
474 | if (mlirLogicalResultIsFailure(res: success)) { |
475 | fprintf(stderr, format: "Unexpected failure running external pass.\n" ); |
476 | exit(EXIT_FAILURE); |
477 | } |
478 | |
479 | if (userData.initializeCallCount != 1) { |
480 | fprintf(stderr, format: "Expected initializeCallCount to be 1\n" ); |
481 | exit(EXIT_FAILURE); |
482 | } |
483 | |
484 | if (userData.runCallCount != 1) { |
485 | fprintf(stderr, format: "Expected runCallCount to be 1\n" ); |
486 | exit(EXIT_FAILURE); |
487 | } |
488 | |
489 | mlirPassManagerDestroy(passManager: pm); |
490 | |
491 | if (userData.destructCallCount != userData.constructCallCount) { |
492 | fprintf(stderr, format: "Expected destructCallCount to be equal to " |
493 | "constructCallCount\n" ); |
494 | exit(EXIT_FAILURE); |
495 | } |
496 | } |
497 | |
498 | // Run a pass that fails during `initialize` |
499 | { |
500 | MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator); |
501 | MlirStringRef name = |
502 | mlirStringRefCreateFromCString(str: "TestExternalFailingPass" ); |
503 | MlirStringRef argument = |
504 | mlirStringRefCreateFromCString(str: "test-external-failing-pass" ); |
505 | TestExternalPassUserData userData = {0}; |
506 | |
507 | MlirPass externalPass = mlirCreateExternalPass( |
508 | passID, name, argument, description, opName: emptyOpName, nDependentDialects: 0, NULL, |
509 | callbacks: makeTestExternalPassCallbacks(initializePass: testInitializeFailingExternalPass, |
510 | runPass: testRunExternalPass), |
511 | userData: &userData); |
512 | |
513 | if (userData.constructCallCount != 1) { |
514 | fprintf(stderr, format: "Expected constructCallCount to be 1\n" ); |
515 | exit(EXIT_FAILURE); |
516 | } |
517 | |
518 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
519 | mlirPassManagerAddOwnedPass(passManager: pm, pass: externalPass); |
520 | MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module); |
521 | if (mlirLogicalResultIsSuccess(res: success)) { |
522 | fprintf( |
523 | stderr, |
524 | format: "Expected failure running pass manager on failing external pass.\n" ); |
525 | exit(EXIT_FAILURE); |
526 | } |
527 | |
528 | if (userData.initializeCallCount != 1) { |
529 | fprintf(stderr, format: "Expected initializeCallCount to be 1\n" ); |
530 | exit(EXIT_FAILURE); |
531 | } |
532 | |
533 | if (userData.runCallCount != 0) { |
534 | fprintf(stderr, format: "Expected runCallCount to be 0\n" ); |
535 | exit(EXIT_FAILURE); |
536 | } |
537 | |
538 | mlirPassManagerDestroy(passManager: pm); |
539 | |
540 | if (userData.destructCallCount != userData.constructCallCount) { |
541 | fprintf(stderr, format: "Expected destructCallCount to be equal to " |
542 | "constructCallCount\n" ); |
543 | exit(EXIT_FAILURE); |
544 | } |
545 | } |
546 | |
547 | // Run a pass that fails during `run` |
548 | { |
549 | MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator); |
550 | MlirStringRef name = |
551 | mlirStringRefCreateFromCString(str: "TestExternalFailingPass" ); |
552 | MlirStringRef argument = |
553 | mlirStringRefCreateFromCString(str: "test-external-failing-pass" ); |
554 | TestExternalPassUserData userData = {0}; |
555 | |
556 | MlirPass externalPass = mlirCreateExternalPass( |
557 | passID, name, argument, description, opName: emptyOpName, nDependentDialects: 0, NULL, |
558 | callbacks: makeTestExternalPassCallbacks(NULL, runPass: testRunFailingExternalPass), |
559 | userData: &userData); |
560 | |
561 | if (userData.constructCallCount != 1) { |
562 | fprintf(stderr, format: "Expected constructCallCount to be 1\n" ); |
563 | exit(EXIT_FAILURE); |
564 | } |
565 | |
566 | MlirPassManager pm = mlirPassManagerCreate(ctx); |
567 | mlirPassManagerAddOwnedPass(passManager: pm, pass: externalPass); |
568 | MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module); |
569 | if (mlirLogicalResultIsSuccess(res: success)) { |
570 | fprintf( |
571 | stderr, |
572 | format: "Expected failure running pass manager on failing external pass.\n" ); |
573 | exit(EXIT_FAILURE); |
574 | } |
575 | |
576 | if (userData.runCallCount != 1) { |
577 | fprintf(stderr, format: "Expected runCallCount to be 1\n" ); |
578 | exit(EXIT_FAILURE); |
579 | } |
580 | |
581 | mlirPassManagerDestroy(passManager: pm); |
582 | |
583 | if (userData.destructCallCount != userData.constructCallCount) { |
584 | fprintf(stderr, format: "Expected destructCallCount to be equal to " |
585 | "constructCallCount\n" ); |
586 | exit(EXIT_FAILURE); |
587 | } |
588 | } |
589 | |
590 | mlirTypeIDAllocatorDestroy(allocator: typeIDAllocator); |
591 | mlirOperationDestroy(op: module); |
592 | mlirContextDestroy(context: ctx); |
593 | } |
594 | |
595 | int main(void) { |
596 | testRunPassOnModule(); |
597 | testRunPassOnNestedModule(); |
598 | testPrintPassPipeline(); |
599 | testParsePassPipeline(); |
600 | testParseErrorCapture(); |
601 | testExternalPass(); |
602 | return 0; |
603 | } |
604 | |