1 | //===- Pass.cpp - C Interface for General Pass Management APIs ------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/Pass.h" |
10 | |
11 | #include "mlir/CAPI/IR.h" |
12 | #include "mlir/CAPI/Pass.h" |
13 | #include "mlir/CAPI/Support.h" |
14 | #include "mlir/CAPI/Utils.h" |
15 | #include "mlir/Pass/PassManager.h" |
16 | #include <optional> |
17 | |
18 | using namespace mlir; |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // PassManager/OpPassManager APIs. |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | MlirPassManager mlirPassManagerCreate(MlirContext ctx) { |
25 | return wrap(cpp: new PassManager(unwrap(c: ctx))); |
26 | } |
27 | |
28 | MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, |
29 | MlirStringRef anchorOp) { |
30 | return wrap(cpp: new PassManager(unwrap(c: ctx), unwrap(ref: anchorOp))); |
31 | } |
32 | |
33 | void mlirPassManagerDestroy(MlirPassManager passManager) { |
34 | delete unwrap(c: passManager); |
35 | } |
36 | |
37 | MlirOpPassManager |
38 | mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) { |
39 | return wrap(cpp: static_cast<OpPassManager *>(unwrap(c: passManager))); |
40 | } |
41 | |
42 | MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, |
43 | MlirOperation op) { |
44 | return wrap(res: unwrap(c: passManager)->run(op: unwrap(c: op))); |
45 | } |
46 | |
47 | void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, |
48 | bool printBeforeAll, bool printAfterAll, |
49 | bool printModuleScope, |
50 | bool printAfterOnlyOnChange, |
51 | bool printAfterOnlyOnFailure, |
52 | MlirOpPrintingFlags flags, |
53 | MlirStringRef treePrintingPath) { |
54 | auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { |
55 | return printBeforeAll; |
56 | }; |
57 | auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { |
58 | return printAfterAll; |
59 | }; |
60 | if (unwrap(ref: treePrintingPath).empty()) |
61 | return unwrap(c: passManager) |
62 | ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, |
63 | printModuleScope, printAfterOnlyOnChange, |
64 | printAfterOnlyOnFailure, /*out=*/llvm::errs(), |
65 | opPrintingFlags: *unwrap(c: flags)); |
66 | |
67 | unwrap(c: passManager) |
68 | ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, |
69 | printModuleScope, printAfterOnlyOnChange, |
70 | printAfterOnlyOnFailure, |
71 | printTreeDir: unwrap(ref: treePrintingPath), opPrintingFlags: *unwrap(c: flags)); |
72 | } |
73 | |
74 | void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { |
75 | unwrap(c: passManager)->enableVerifier(enabled: enable); |
76 | } |
77 | |
78 | MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, |
79 | MlirStringRef operationName) { |
80 | return wrap(cpp: &unwrap(c: passManager)->nest(nestedName: unwrap(ref: operationName))); |
81 | } |
82 | |
83 | MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, |
84 | MlirStringRef operationName) { |
85 | return wrap(cpp: &unwrap(c: passManager)->nest(nestedName: unwrap(ref: operationName))); |
86 | } |
87 | |
88 | void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) { |
89 | unwrap(c: passManager)->addPass(pass: std::unique_ptr<Pass>(unwrap(c: pass))); |
90 | } |
91 | |
92 | void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, |
93 | MlirPass pass) { |
94 | unwrap(c: passManager)->addPass(pass: std::unique_ptr<Pass>(unwrap(c: pass))); |
95 | } |
96 | |
97 | MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, |
98 | MlirStringRef pipelineElements, |
99 | MlirStringCallback callback, |
100 | void *userData) { |
101 | detail::CallbackOstream stream(callback, userData); |
102 | return wrap(res: parsePassPipeline(pipeline: unwrap(ref: pipelineElements), pm&: *unwrap(c: passManager), |
103 | errorStream&: stream)); |
104 | } |
105 | |
106 | void mlirPrintPassPipeline(MlirOpPassManager passManager, |
107 | MlirStringCallback callback, void *userData) { |
108 | detail::CallbackOstream stream(callback, userData); |
109 | unwrap(c: passManager)->printAsTextualPipeline(os&: stream); |
110 | } |
111 | |
112 | MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, |
113 | MlirStringRef pipeline, |
114 | MlirStringCallback callback, |
115 | void *userData) { |
116 | detail::CallbackOstream stream(callback, userData); |
117 | FailureOr<OpPassManager> pm = parsePassPipeline(pipeline: unwrap(ref: pipeline), errorStream&: stream); |
118 | if (succeeded(Result: pm)) |
119 | *unwrap(c: passManager) = std::move(*pm); |
120 | return wrap(res: pm); |
121 | } |
122 | |
123 | //===----------------------------------------------------------------------===// |
124 | // External Pass API. |
125 | //===----------------------------------------------------------------------===// |
126 | |
127 | namespace mlir { |
128 | class ExternalPass; |
129 | } // namespace mlir |
130 | DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass) |
131 | |
132 | namespace mlir { |
133 | /// This pass class wraps external passes defined in other languages using the |
134 | /// MLIR C-interface |
135 | class ExternalPass : public Pass { |
136 | public: |
137 | ExternalPass(TypeID passID, StringRef name, StringRef argument, |
138 | StringRef description, std::optional<StringRef> opName, |
139 | ArrayRef<MlirDialectHandle> dependentDialects, |
140 | MlirExternalPassCallbacks callbacks, void *userData) |
141 | : Pass(passID, opName), id(passID), name(name), argument(argument), |
142 | description(description), dependentDialects(dependentDialects), |
143 | callbacks(callbacks), userData(userData) { |
144 | callbacks.construct(userData); |
145 | } |
146 | |
147 | ~ExternalPass() override { callbacks.destruct(userData); } |
148 | |
149 | StringRef getName() const override { return name; } |
150 | StringRef getArgument() const override { return argument; } |
151 | StringRef getDescription() const override { return description; } |
152 | |
153 | void getDependentDialects(DialectRegistry ®istry) const override { |
154 | MlirDialectRegistry cRegistry = wrap(cpp: ®istry); |
155 | for (MlirDialectHandle dialect : dependentDialects) |
156 | mlirDialectHandleInsertDialect(dialect, cRegistry); |
157 | } |
158 | |
159 | void signalPassFailure() { Pass::signalPassFailure(); } |
160 | |
161 | protected: |
162 | LogicalResult initialize(MLIRContext *ctx) override { |
163 | if (callbacks.initialize) |
164 | return unwrap(res: callbacks.initialize(wrap(cpp: ctx), userData)); |
165 | return success(); |
166 | } |
167 | |
168 | bool canScheduleOn(RegisteredOperationName opName) const override { |
169 | if (std::optional<StringRef> specifiedOpName = getOpName()) |
170 | return opName.getStringRef() == specifiedOpName; |
171 | return true; |
172 | } |
173 | |
174 | void runOnOperation() override { |
175 | callbacks.run(wrap(cpp: getOperation()), wrap(cpp: this), userData); |
176 | } |
177 | |
178 | std::unique_ptr<Pass> clonePass() const override { |
179 | void *clonedUserData = callbacks.clone(userData); |
180 | return std::make_unique<ExternalPass>(args: id, args: name, args: argument, args: description, |
181 | args: getOpName(), args: dependentDialects, |
182 | args: callbacks, args&: clonedUserData); |
183 | } |
184 | |
185 | private: |
186 | TypeID id; |
187 | std::string name; |
188 | std::string argument; |
189 | std::string description; |
190 | std::vector<MlirDialectHandle> dependentDialects; |
191 | MlirExternalPassCallbacks callbacks; |
192 | void *userData; |
193 | }; |
194 | } // namespace mlir |
195 | |
196 | MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, |
197 | MlirStringRef argument, |
198 | MlirStringRef description, MlirStringRef opName, |
199 | intptr_t nDependentDialects, |
200 | MlirDialectHandle *dependentDialects, |
201 | MlirExternalPassCallbacks callbacks, |
202 | void *userData) { |
203 | return wrap(cpp: static_cast<mlir::Pass *>(new mlir::ExternalPass( |
204 | unwrap(c: passID), unwrap(ref: name), unwrap(ref: argument), unwrap(ref: description), |
205 | opName.length > 0 ? std::optional<StringRef>(unwrap(ref: opName)) |
206 | : std::nullopt, |
207 | {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks, |
208 | userData))); |
209 | } |
210 | |
211 | void mlirExternalPassSignalFailure(MlirExternalPass pass) { |
212 | unwrap(c: pass)->signalPassFailure(); |
213 | } |
214 |
Definitions
- mlirPassManagerCreate
- mlirPassManagerCreateOnOperation
- mlirPassManagerDestroy
- mlirPassManagerGetAsOpPassManager
- mlirPassManagerRunOnOp
- mlirPassManagerEnableIRPrinting
- mlirPassManagerEnableVerifier
- mlirPassManagerGetNestedUnder
- mlirOpPassManagerGetNestedUnder
- mlirPassManagerAddOwnedPass
- mlirOpPassManagerAddOwnedPass
- mlirOpPassManagerAddPipeline
- mlirPrintPassPipeline
- mlirParsePassPipeline
- ExternalPass
- ExternalPass
- ~ExternalPass
- getName
- getArgument
- getDescription
- getDependentDialects
- signalPassFailure
- initialize
- canScheduleOn
- runOnOperation
- clonePass
- mlirCreateExternalPass
Learn to use CMake with our Intro Training
Find out more