1//===- pass.c - Simple test of C APIs -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10/* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11 */
12
13#include "mlir-c/Pass.h"
14#include "mlir-c/Dialect/Func.h"
15#include "mlir-c/IR.h"
16#include "mlir-c/RegisterEverything.h"
17#include "mlir-c/Transforms.h"
18
19#include <assert.h>
20#include <math.h>
21#include <stdio.h>
22#include <stdlib.h>
23#include <string.h>
24
25static void registerAllUpstreamDialects(MlirContext ctx) {
26 MlirDialectRegistry registry = mlirDialectRegistryCreate();
27 mlirRegisterAllDialects(registry);
28 mlirContextAppendDialectRegistry(ctx, registry);
29 mlirDialectRegistryDestroy(registry);
30}
31
32void testRunPassOnModule(void) {
33 MlirContext ctx = mlirContextCreate();
34 registerAllUpstreamDialects(ctx);
35
36 const char *funcAsm = //
37 "func.func @foo(%arg0 : i32) -> i32 { \n"
38 " %res = arith.addi %arg0, %arg0 : i32 \n"
39 " return %res : i32 \n"
40 "} \n";
41 MlirOperation func =
42 mlirOperationCreateParse(context: ctx, sourceStr: mlirStringRefCreateFromCString(str: funcAsm),
43 sourceName: mlirStringRefCreateFromCString(str: "funcAsm"));
44 if (mlirOperationIsNull(op: func)) {
45 fprintf(stderr, format: "Unexpected failure parsing asm.\n");
46 exit(EXIT_FAILURE);
47 }
48
49 // Run the print-op-stats pass on the top-level module:
50 // CHECK-LABEL: Operations encountered:
51 // CHECK: arith.addi , 1
52 // CHECK: func.func , 1
53 // CHECK: func.return , 1
54 {
55 MlirPassManager pm = mlirPassManagerCreate(ctx);
56 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
57 mlirPassManagerAddOwnedPass(passManager: pm, pass: printOpStatPass);
58 MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: func);
59 if (mlirLogicalResultIsFailure(res: success)) {
60 fprintf(stderr, format: "Unexpected failure running pass manager.\n");
61 exit(EXIT_FAILURE);
62 }
63 mlirPassManagerDestroy(passManager: pm);
64 }
65 mlirOperationDestroy(op: func);
66 mlirContextDestroy(context: ctx);
67}
68
69void testRunPassOnNestedModule(void) {
70 MlirContext ctx = mlirContextCreate();
71 registerAllUpstreamDialects(ctx);
72
73 const char *moduleAsm = //
74 "module { \n"
75 " func.func @foo(%arg0 : i32) -> i32 { \n"
76 " %res = arith.addi %arg0, %arg0 : i32 \n"
77 " return %res : i32 \n"
78 " } \n"
79 " module { \n"
80 " func.func @bar(%arg0 : f32) -> f32 { \n"
81 " %res = arith.addf %arg0, %arg0 : f32 \n"
82 " return %res : f32 \n"
83 " } \n"
84 " } \n"
85 "} \n";
86 MlirOperation module =
87 mlirOperationCreateParse(context: ctx, sourceStr: mlirStringRefCreateFromCString(str: moduleAsm),
88 sourceName: mlirStringRefCreateFromCString(str: "moduleAsm"));
89 if (mlirOperationIsNull(op: module))
90 exit(status: 1);
91
92 // Run the print-op-stats pass on functions under the top-level module:
93 // CHECK-LABEL: Operations encountered:
94 // CHECK: arith.addi , 1
95 // CHECK: func.func , 1
96 // CHECK: func.return , 1
97 {
98 MlirPassManager pm = mlirPassManagerCreate(ctx);
99 MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
100 passManager: pm, operationName: mlirStringRefCreateFromCString(str: "func.func"));
101 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
102 mlirOpPassManagerAddOwnedPass(passManager: nestedFuncPm, pass: printOpStatPass);
103 MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module);
104 if (mlirLogicalResultIsFailure(res: success))
105 exit(status: 2);
106 mlirPassManagerDestroy(passManager: pm);
107 }
108 // Run the print-op-stats pass on functions under the nested module:
109 // CHECK-LABEL: Operations encountered:
110 // CHECK: arith.addf , 1
111 // CHECK: func.func , 1
112 // CHECK: func.return , 1
113 {
114 MlirPassManager pm = mlirPassManagerCreate(ctx);
115 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
116 passManager: pm, operationName: mlirStringRefCreateFromCString(str: "builtin.module"));
117 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
118 passManager: nestedModulePm, operationName: mlirStringRefCreateFromCString(str: "func.func"));
119 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
120 mlirOpPassManagerAddOwnedPass(passManager: nestedFuncPm, pass: printOpStatPass);
121 MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module);
122 if (mlirLogicalResultIsFailure(res: success))
123 exit(status: 2);
124 mlirPassManagerDestroy(passManager: pm);
125 }
126
127 mlirOperationDestroy(op: module);
128 mlirContextDestroy(context: ctx);
129}
130
131static void printToStderr(MlirStringRef str, void *userData) {
132 (void)userData;
133 fwrite(ptr: str.data, size: 1, n: str.length, stderr);
134}
135
136static void dontPrint(MlirStringRef str, void *userData) {
137 (void)str;
138 (void)userData;
139}
140
141void testPrintPassPipeline(void) {
142 MlirContext ctx = mlirContextCreate();
143 MlirPassManager pm = mlirPassManagerCreateOnOperation(
144 ctx, anchorOp: mlirStringRefCreateFromCString(str: "any"));
145 // Populate the pass-manager
146 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
147 passManager: pm, operationName: mlirStringRefCreateFromCString(str: "builtin.module"));
148 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
149 passManager: nestedModulePm, operationName: mlirStringRefCreateFromCString(str: "func.func"));
150 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
151 mlirOpPassManagerAddOwnedPass(passManager: nestedFuncPm, pass: printOpStatPass);
152
153 // Print the top level pass manager
154 // CHECK: Top-level: any(
155 // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false}))
156 // CHECK-SAME: )
157 fprintf(stderr, format: "Top-level: ");
158 mlirPrintPassPipeline(passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), callback: printToStderr,
159 NULL);
160 fprintf(stderr, format: "\n");
161
162 // Print the pipeline nested one level down
163 // CHECK: Nested Module: builtin.module(func.func(print-op-stats{json=false}))
164 fprintf(stderr, format: "Nested Module: ");
165 mlirPrintPassPipeline(passManager: nestedModulePm, callback: printToStderr, NULL);
166 fprintf(stderr, format: "\n");
167
168 // Print the pipeline nested two levels down
169 // CHECK: Nested Module>Func: func.func(print-op-stats{json=false})
170 fprintf(stderr, format: "Nested Module>Func: ");
171 mlirPrintPassPipeline(passManager: nestedFuncPm, callback: printToStderr, NULL);
172 fprintf(stderr, format: "\n");
173
174 mlirPassManagerDestroy(passManager: pm);
175 mlirContextDestroy(context: ctx);
176}
177
178void testParsePassPipeline(void) {
179 MlirContext ctx = mlirContextCreate();
180 MlirPassManager pm = mlirPassManagerCreate(ctx);
181 // Try parse a pipeline.
182 MlirLogicalResult status = mlirParsePassPipeline(
183 passManager: mlirPassManagerGetAsOpPassManager(passManager: pm),
184 pipeline: mlirStringRefCreateFromCString(
185 str: "builtin.module(func.func(print-op-stats{json=false}))"),
186 callback: printToStderr, NULL);
187 // Expect a failure, we haven't registered the print-op-stats pass yet.
188 if (mlirLogicalResultIsSuccess(res: status)) {
189 fprintf(
190 stderr,
191 format: "Unexpected success parsing pipeline without registering the pass\n");
192 exit(EXIT_FAILURE);
193 }
194 // Try again after registrating the pass.
195 mlirRegisterTransformsPrintOpStats();
196 status = mlirParsePassPipeline(
197 passManager: mlirPassManagerGetAsOpPassManager(passManager: pm),
198 pipeline: mlirStringRefCreateFromCString(
199 str: "builtin.module(func.func(print-op-stats{json=false}))"),
200 callback: printToStderr, NULL);
201 // Expect a failure, we haven't registered the print-op-stats pass yet.
202 if (mlirLogicalResultIsFailure(res: status)) {
203 fprintf(stderr,
204 format: "Unexpected failure parsing pipeline after registering the pass\n");
205 exit(EXIT_FAILURE);
206 }
207
208 // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}))
209 fprintf(stderr, format: "Round-trip: ");
210 mlirPrintPassPipeline(passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), callback: printToStderr,
211 NULL);
212 fprintf(stderr, format: "\n");
213
214 // Try appending a pass:
215 status = mlirOpPassManagerAddPipeline(
216 passManager: mlirPassManagerGetAsOpPassManager(passManager: pm),
217 pipelineElements: mlirStringRefCreateFromCString(str: "func.func(print-op-stats{json=false})"),
218 callback: printToStderr, NULL);
219 if (mlirLogicalResultIsFailure(res: status)) {
220 fprintf(stderr, format: "Unexpected failure appending pipeline\n");
221 exit(EXIT_FAILURE);
222 }
223 // CHECK: Appended: builtin.module(
224 // CHECK-SAME: func.func(print-op-stats{json=false}),
225 // CHECK-SAME: func.func(print-op-stats{json=false})
226 // CHECK-SAME: )
227 fprintf(stderr, format: "Appended: ");
228 mlirPrintPassPipeline(passManager: mlirPassManagerGetAsOpPassManager(passManager: pm), callback: printToStderr,
229 NULL);
230 fprintf(stderr, format: "\n");
231
232 mlirPassManagerDestroy(passManager: pm);
233 mlirContextDestroy(context: ctx);
234}
235
236void testParseErrorCapture(void) {
237 // CHECK-LABEL: testParseErrorCapture:
238 fprintf(stderr, format: "\nTEST: testParseErrorCapture:\n");
239
240 MlirContext ctx = mlirContextCreate();
241 MlirPassManager pm = mlirPassManagerCreate(ctx);
242 MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(passManager: pm);
243 MlirStringRef invalidPipeline = mlirStringRefCreateFromCString(str: "invalid");
244
245 // CHECK: mlirParsePassPipeline:
246 // CHECK: expected pass pipeline to be wrapped with the anchor operation type
247 fprintf(stderr, format: "mlirParsePassPipeline:\n");
248 if (mlirLogicalResultIsSuccess(
249 res: mlirParsePassPipeline(passManager: opm, pipeline: invalidPipeline, callback: printToStderr, NULL)))
250 exit(EXIT_FAILURE);
251 fprintf(stderr, format: "\n");
252
253 // CHECK: mlirOpPassManagerAddPipeline:
254 // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
255 fprintf(stderr, format: "mlirOpPassManagerAddPipeline:\n");
256 if (mlirLogicalResultIsSuccess(res: mlirOpPassManagerAddPipeline(
257 passManager: opm, pipelineElements: invalidPipeline, callback: printToStderr, NULL)))
258 exit(EXIT_FAILURE);
259 fprintf(stderr, format: "\n");
260
261 // Make sure all output is going through the callback.
262 // CHECK: dontPrint: <>
263 fprintf(stderr, format: "dontPrint: <");
264 if (mlirLogicalResultIsSuccess(
265 res: mlirParsePassPipeline(passManager: opm, pipeline: invalidPipeline, callback: dontPrint, NULL)))
266 exit(EXIT_FAILURE);
267 if (mlirLogicalResultIsSuccess(
268 res: mlirOpPassManagerAddPipeline(passManager: opm, pipelineElements: invalidPipeline, callback: dontPrint, NULL)))
269 exit(EXIT_FAILURE);
270 fprintf(stderr, format: ">\n");
271
272 mlirPassManagerDestroy(passManager: pm);
273 mlirContextDestroy(context: ctx);
274}
275
// Per-pass invocation counters shared with the external pass callbacks through
// their `userData` pointer. Each callback increments exactly one field.
typedef struct TestExternalPassUserData {
  int constructCallCount;  // bumped by testConstructExternalPass
  int destructCallCount;   // bumped by testDestructExternalPass
  int initializeCallCount; // bumped by the testInitialize*ExternalPass hooks
  int cloneCallCount;      // bumped by testCloneExternalPass
  int runCallCount;        // bumped by the testRun*ExternalPass hooks
} TestExternalPassUserData;
284
285void testConstructExternalPass(void *userData) {
286 ++((TestExternalPassUserData *)userData)->constructCallCount;
287}
288
289void testDestructExternalPass(void *userData) {
290 ++((TestExternalPassUserData *)userData)->destructCallCount;
291}
292
293MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
294 ++((TestExternalPassUserData *)userData)->initializeCallCount;
295 return mlirLogicalResultSuccess();
296}
297
298MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
299 void *userData) {
300 ++((TestExternalPassUserData *)userData)->initializeCallCount;
301 return mlirLogicalResultFailure();
302}
303
304void *testCloneExternalPass(void *userData) {
305 ++((TestExternalPassUserData *)userData)->cloneCallCount;
306 return userData;
307}
308
309void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
310 void *userData) {
311 ++((TestExternalPassUserData *)userData)->runCallCount;
312}
313
314void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
315 void *userData) {
316 ++((TestExternalPassUserData *)userData)->runCallCount;
317 MlirStringRef opName = mlirIdentifierStr(ident: mlirOperationGetName(op));
318 if (!mlirStringRefEqual(string: opName,
319 other: mlirStringRefCreateFromCString(str: "func.func"))) {
320 mlirExternalPassSignalFailure(pass);
321 }
322}
323
324void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
325 void *userData) {
326 ++((TestExternalPassUserData *)userData)->runCallCount;
327 mlirExternalPassSignalFailure(pass);
328}
329
330MlirExternalPassCallbacks makeTestExternalPassCallbacks(
331 MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
332 void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
333 return (MlirExternalPassCallbacks){testConstructExternalPass,
334 testDestructExternalPass, initializePass,
335 testCloneExternalPass, runPass};
336}
337
338void testExternalPass(void) {
339 MlirContext ctx = mlirContextCreate();
340 registerAllUpstreamDialects(ctx);
341
342 const char *moduleAsm = //
343 "module { \n"
344 " func.func @foo(%arg0 : i32) -> i32 { \n"
345 " %res = arith.addi %arg0, %arg0 : i32 \n"
346 " return %res : i32 \n"
347 " } \n"
348 "}";
349 MlirOperation module =
350 mlirOperationCreateParse(context: ctx, sourceStr: mlirStringRefCreateFromCString(str: moduleAsm),
351 sourceName: mlirStringRefCreateFromCString(str: "moduleAsm"));
352 if (mlirOperationIsNull(op: module)) {
353 fprintf(stderr, format: "Unexpected failure parsing module.\n");
354 exit(EXIT_FAILURE);
355 }
356
357 MlirStringRef description = mlirStringRefCreateFromCString(str: "");
358 MlirStringRef emptyOpName = mlirStringRefCreateFromCString(str: "");
359
360 MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
361
362 // Run a generic pass
363 {
364 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator);
365 MlirStringRef name = mlirStringRefCreateFromCString(str: "TestExternalPass");
366 MlirStringRef argument =
367 mlirStringRefCreateFromCString(str: "test-external-pass");
368 TestExternalPassUserData userData = {0};
369
370 MlirPass externalPass = mlirCreateExternalPass(
371 passID, name, argument, description, opName: emptyOpName, nDependentDialects: 0, NULL,
372 callbacks: makeTestExternalPassCallbacks(NULL, runPass: testRunExternalPass), userData: &userData);
373
374 if (userData.constructCallCount != 1) {
375 fprintf(stderr, format: "Expected constructCallCount to be 1\n");
376 exit(EXIT_FAILURE);
377 }
378
379 MlirPassManager pm = mlirPassManagerCreate(ctx);
380 mlirPassManagerAddOwnedPass(passManager: pm, pass: externalPass);
381 MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module);
382 if (mlirLogicalResultIsFailure(res: success)) {
383 fprintf(stderr, format: "Unexpected failure running external pass.\n");
384 exit(EXIT_FAILURE);
385 }
386
387 if (userData.runCallCount != 1) {
388 fprintf(stderr, format: "Expected runCallCount to be 1\n");
389 exit(EXIT_FAILURE);
390 }
391
392 mlirPassManagerDestroy(passManager: pm);
393
394 if (userData.destructCallCount != userData.constructCallCount) {
395 fprintf(stderr, format: "Expected destructCallCount to be equal to "
396 "constructCallCount\n");
397 exit(EXIT_FAILURE);
398 }
399 }
400
401 // Run a func operation pass
402 {
403 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator);
404 MlirStringRef name = mlirStringRefCreateFromCString(str: "TestExternalFuncPass");
405 MlirStringRef argument =
406 mlirStringRefCreateFromCString(str: "test-external-func-pass");
407 TestExternalPassUserData userData = {0};
408 MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
409 MlirStringRef funcOpName = mlirStringRefCreateFromCString(str: "func.func");
410
411 MlirPass externalPass = mlirCreateExternalPass(
412 passID, name, argument, description, opName: funcOpName, nDependentDialects: 1, dependentDialects: &funcHandle,
413 callbacks: makeTestExternalPassCallbacks(NULL, runPass: testRunExternalFuncPass),
414 userData: &userData);
415
416 if (userData.constructCallCount != 1) {
417 fprintf(stderr, format: "Expected constructCallCount to be 1\n");
418 exit(EXIT_FAILURE);
419 }
420
421 MlirPassManager pm = mlirPassManagerCreate(ctx);
422 MlirOpPassManager nestedFuncPm =
423 mlirPassManagerGetNestedUnder(passManager: pm, operationName: funcOpName);
424 mlirOpPassManagerAddOwnedPass(passManager: nestedFuncPm, pass: externalPass);
425 MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module);
426 if (mlirLogicalResultIsFailure(res: success)) {
427 fprintf(stderr, format: "Unexpected failure running external operation pass.\n");
428 exit(EXIT_FAILURE);
429 }
430
431 // Since this is a nested pass, it can be cloned and run in parallel
432 if (userData.cloneCallCount != userData.constructCallCount - 1) {
433 fprintf(stderr, format: "Expected constructCallCount to be 1\n");
434 exit(EXIT_FAILURE);
435 }
436
437 // The pass should only be run once this there is only one func op
438 if (userData.runCallCount != 1) {
439 fprintf(stderr, format: "Expected runCallCount to be 1\n");
440 exit(EXIT_FAILURE);
441 }
442
443 mlirPassManagerDestroy(passManager: pm);
444
445 if (userData.destructCallCount != userData.constructCallCount) {
446 fprintf(stderr, format: "Expected destructCallCount to be equal to "
447 "constructCallCount\n");
448 exit(EXIT_FAILURE);
449 }
450 }
451
452 // Run a pass with `initialize` set
453 {
454 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator);
455 MlirStringRef name = mlirStringRefCreateFromCString(str: "TestExternalPass");
456 MlirStringRef argument =
457 mlirStringRefCreateFromCString(str: "test-external-pass");
458 TestExternalPassUserData userData = {0};
459
460 MlirPass externalPass = mlirCreateExternalPass(
461 passID, name, argument, description, opName: emptyOpName, nDependentDialects: 0, NULL,
462 callbacks: makeTestExternalPassCallbacks(initializePass: testInitializeExternalPass,
463 runPass: testRunExternalPass),
464 userData: &userData);
465
466 if (userData.constructCallCount != 1) {
467 fprintf(stderr, format: "Expected constructCallCount to be 1\n");
468 exit(EXIT_FAILURE);
469 }
470
471 MlirPassManager pm = mlirPassManagerCreate(ctx);
472 mlirPassManagerAddOwnedPass(passManager: pm, pass: externalPass);
473 MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module);
474 if (mlirLogicalResultIsFailure(res: success)) {
475 fprintf(stderr, format: "Unexpected failure running external pass.\n");
476 exit(EXIT_FAILURE);
477 }
478
479 if (userData.initializeCallCount != 1) {
480 fprintf(stderr, format: "Expected initializeCallCount to be 1\n");
481 exit(EXIT_FAILURE);
482 }
483
484 if (userData.runCallCount != 1) {
485 fprintf(stderr, format: "Expected runCallCount to be 1\n");
486 exit(EXIT_FAILURE);
487 }
488
489 mlirPassManagerDestroy(passManager: pm);
490
491 if (userData.destructCallCount != userData.constructCallCount) {
492 fprintf(stderr, format: "Expected destructCallCount to be equal to "
493 "constructCallCount\n");
494 exit(EXIT_FAILURE);
495 }
496 }
497
498 // Run a pass that fails during `initialize`
499 {
500 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator);
501 MlirStringRef name =
502 mlirStringRefCreateFromCString(str: "TestExternalFailingPass");
503 MlirStringRef argument =
504 mlirStringRefCreateFromCString(str: "test-external-failing-pass");
505 TestExternalPassUserData userData = {0};
506
507 MlirPass externalPass = mlirCreateExternalPass(
508 passID, name, argument, description, opName: emptyOpName, nDependentDialects: 0, NULL,
509 callbacks: makeTestExternalPassCallbacks(initializePass: testInitializeFailingExternalPass,
510 runPass: testRunExternalPass),
511 userData: &userData);
512
513 if (userData.constructCallCount != 1) {
514 fprintf(stderr, format: "Expected constructCallCount to be 1\n");
515 exit(EXIT_FAILURE);
516 }
517
518 MlirPassManager pm = mlirPassManagerCreate(ctx);
519 mlirPassManagerAddOwnedPass(passManager: pm, pass: externalPass);
520 MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module);
521 if (mlirLogicalResultIsSuccess(res: success)) {
522 fprintf(
523 stderr,
524 format: "Expected failure running pass manager on failing external pass.\n");
525 exit(EXIT_FAILURE);
526 }
527
528 if (userData.initializeCallCount != 1) {
529 fprintf(stderr, format: "Expected initializeCallCount to be 1\n");
530 exit(EXIT_FAILURE);
531 }
532
533 if (userData.runCallCount != 0) {
534 fprintf(stderr, format: "Expected runCallCount to be 0\n");
535 exit(EXIT_FAILURE);
536 }
537
538 mlirPassManagerDestroy(passManager: pm);
539
540 if (userData.destructCallCount != userData.constructCallCount) {
541 fprintf(stderr, format: "Expected destructCallCount to be equal to "
542 "constructCallCount\n");
543 exit(EXIT_FAILURE);
544 }
545 }
546
547 // Run a pass that fails during `run`
548 {
549 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(allocator: typeIDAllocator);
550 MlirStringRef name =
551 mlirStringRefCreateFromCString(str: "TestExternalFailingPass");
552 MlirStringRef argument =
553 mlirStringRefCreateFromCString(str: "test-external-failing-pass");
554 TestExternalPassUserData userData = {0};
555
556 MlirPass externalPass = mlirCreateExternalPass(
557 passID, name, argument, description, opName: emptyOpName, nDependentDialects: 0, NULL,
558 callbacks: makeTestExternalPassCallbacks(NULL, runPass: testRunFailingExternalPass),
559 userData: &userData);
560
561 if (userData.constructCallCount != 1) {
562 fprintf(stderr, format: "Expected constructCallCount to be 1\n");
563 exit(EXIT_FAILURE);
564 }
565
566 MlirPassManager pm = mlirPassManagerCreate(ctx);
567 mlirPassManagerAddOwnedPass(passManager: pm, pass: externalPass);
568 MlirLogicalResult success = mlirPassManagerRunOnOp(passManager: pm, op: module);
569 if (mlirLogicalResultIsSuccess(res: success)) {
570 fprintf(
571 stderr,
572 format: "Expected failure running pass manager on failing external pass.\n");
573 exit(EXIT_FAILURE);
574 }
575
576 if (userData.runCallCount != 1) {
577 fprintf(stderr, format: "Expected runCallCount to be 1\n");
578 exit(EXIT_FAILURE);
579 }
580
581 mlirPassManagerDestroy(passManager: pm);
582
583 if (userData.destructCallCount != userData.constructCallCount) {
584 fprintf(stderr, format: "Expected destructCallCount to be equal to "
585 "constructCallCount\n");
586 exit(EXIT_FAILURE);
587 }
588 }
589
590 mlirTypeIDAllocatorDestroy(allocator: typeIDAllocator);
591 mlirOperationDestroy(op: module);
592 mlirContextDestroy(context: ctx);
593}
594
595int main(void) {
596 testRunPassOnModule();
597 testRunPassOnNestedModule();
598 testPrintPassPipeline();
599 testParsePassPipeline();
600 testParseErrorCapture();
601 testExternalPass();
602 return 0;
603}
604

// Source: mlir/test/CAPI/pass.c