//===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
#include "mlir-c/Pass.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/Pass/PassManager.h"

#include <cassert>
#include <optional>

using namespace mlir;
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // PassManager/OpPassManager APIs. |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | MlirPassManager mlirPassManagerCreate(MlirContext ctx) { |
25 | return wrap(cpp: new PassManager(unwrap(c: ctx))); |
26 | } |
27 | |
28 | MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, |
29 | MlirStringRef anchorOp) { |
30 | return wrap(cpp: new PassManager(unwrap(c: ctx), unwrap(ref: anchorOp))); |
31 | } |
32 | |
33 | void mlirPassManagerDestroy(MlirPassManager passManager) { |
34 | delete unwrap(c: passManager); |
35 | } |
36 | |
37 | MlirOpPassManager |
38 | mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) { |
39 | return wrap(cpp: static_cast<OpPassManager *>(unwrap(c: passManager))); |
40 | } |
41 | |
42 | MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, |
43 | MlirOperation op) { |
44 | return wrap(res: unwrap(c: passManager)->run(op: unwrap(c: op))); |
45 | } |
46 | |
47 | void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { |
48 | return unwrap(c: passManager)->enableIRPrinting(); |
49 | } |
50 | |
51 | void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { |
52 | unwrap(c: passManager)->enableVerifier(enabled: enable); |
53 | } |
54 | |
55 | MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, |
56 | MlirStringRef operationName) { |
57 | return wrap(cpp: &unwrap(c: passManager)->nest(nestedName: unwrap(ref: operationName))); |
58 | } |
59 | |
60 | MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, |
61 | MlirStringRef operationName) { |
62 | return wrap(cpp: &unwrap(c: passManager)->nest(nestedName: unwrap(ref: operationName))); |
63 | } |
64 | |
65 | void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) { |
66 | unwrap(c: passManager)->addPass(pass: std::unique_ptr<Pass>(unwrap(c: pass))); |
67 | } |
68 | |
69 | void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, |
70 | MlirPass pass) { |
71 | unwrap(c: passManager)->addPass(pass: std::unique_ptr<Pass>(unwrap(c: pass))); |
72 | } |
73 | |
74 | MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, |
75 | MlirStringRef pipelineElements, |
76 | MlirStringCallback callback, |
77 | void *userData) { |
78 | detail::CallbackOstream stream(callback, userData); |
79 | return wrap(res: parsePassPipeline(pipeline: unwrap(ref: pipelineElements), pm&: *unwrap(c: passManager), |
80 | errorStream&: stream)); |
81 | } |
82 | |
83 | void mlirPrintPassPipeline(MlirOpPassManager passManager, |
84 | MlirStringCallback callback, void *userData) { |
85 | detail::CallbackOstream stream(callback, userData); |
86 | unwrap(c: passManager)->printAsTextualPipeline(os&: stream); |
87 | } |
88 | |
89 | MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, |
90 | MlirStringRef pipeline, |
91 | MlirStringCallback callback, |
92 | void *userData) { |
93 | detail::CallbackOstream stream(callback, userData); |
94 | FailureOr<OpPassManager> pm = parsePassPipeline(pipeline: unwrap(ref: pipeline), errorStream&: stream); |
95 | if (succeeded(result: pm)) |
96 | *unwrap(c: passManager) = std::move(*pm); |
97 | return wrap(res: pm); |
98 | } |
99 | |
//===----------------------------------------------------------------------===//
// External Pass API.
//===----------------------------------------------------------------------===//
103 | |
104 | namespace mlir { |
105 | class ExternalPass; |
106 | } // namespace mlir |
107 | DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass) |
108 | |
109 | namespace mlir { |
110 | /// This pass class wraps external passes defined in other languages using the |
111 | /// MLIR C-interface |
112 | class ExternalPass : public Pass { |
113 | public: |
114 | ExternalPass(TypeID passID, StringRef name, StringRef argument, |
115 | StringRef description, std::optional<StringRef> opName, |
116 | ArrayRef<MlirDialectHandle> dependentDialects, |
117 | MlirExternalPassCallbacks callbacks, void *userData) |
118 | : Pass(passID, opName), id(passID), name(name), argument(argument), |
119 | description(description), dependentDialects(dependentDialects), |
120 | callbacks(callbacks), userData(userData) { |
121 | callbacks.construct(userData); |
122 | } |
123 | |
124 | ~ExternalPass() override { callbacks.destruct(userData); } |
125 | |
126 | StringRef getName() const override { return name; } |
127 | StringRef getArgument() const override { return argument; } |
128 | StringRef getDescription() const override { return description; } |
129 | |
130 | void getDependentDialects(DialectRegistry ®istry) const override { |
131 | MlirDialectRegistry cRegistry = wrap(cpp: ®istry); |
132 | for (MlirDialectHandle dialect : dependentDialects) |
133 | mlirDialectHandleInsertDialect(dialect, cRegistry); |
134 | } |
135 | |
136 | void signalPassFailure() { Pass::signalPassFailure(); } |
137 | |
138 | protected: |
139 | LogicalResult initialize(MLIRContext *ctx) override { |
140 | if (callbacks.initialize) |
141 | return unwrap(res: callbacks.initialize(wrap(cpp: ctx), userData)); |
142 | return success(); |
143 | } |
144 | |
145 | bool canScheduleOn(RegisteredOperationName opName) const override { |
146 | if (std::optional<StringRef> specifiedOpName = getOpName()) |
147 | return opName.getStringRef() == specifiedOpName; |
148 | return true; |
149 | } |
150 | |
151 | void runOnOperation() override { |
152 | callbacks.run(wrap(cpp: getOperation()), wrap(cpp: this), userData); |
153 | } |
154 | |
155 | std::unique_ptr<Pass> clonePass() const override { |
156 | void *clonedUserData = callbacks.clone(userData); |
157 | return std::make_unique<ExternalPass>(args: id, args: name, args: argument, args: description, |
158 | args: getOpName(), args: dependentDialects, |
159 | args: callbacks, args&: clonedUserData); |
160 | } |
161 | |
162 | private: |
163 | TypeID id; |
164 | std::string name; |
165 | std::string argument; |
166 | std::string description; |
167 | std::vector<MlirDialectHandle> dependentDialects; |
168 | MlirExternalPassCallbacks callbacks; |
169 | void *userData; |
170 | }; |
171 | } // namespace mlir |
172 | |
173 | MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, |
174 | MlirStringRef argument, |
175 | MlirStringRef description, MlirStringRef opName, |
176 | intptr_t nDependentDialects, |
177 | MlirDialectHandle *dependentDialects, |
178 | MlirExternalPassCallbacks callbacks, |
179 | void *userData) { |
180 | return wrap(cpp: static_cast<mlir::Pass *>(new mlir::ExternalPass( |
181 | unwrap(c: passID), unwrap(ref: name), unwrap(ref: argument), unwrap(ref: description), |
182 | opName.length > 0 ? std::optional<StringRef>(unwrap(ref: opName)) |
183 | : std::nullopt, |
184 | {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks, |
185 | userData))); |
186 | } |
187 | |
188 | void mlirExternalPassSignalFailure(MlirExternalPass pass) { |
189 | unwrap(c: pass)->signalPassFailure(); |
190 | } |
191 | |