//===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir-c/Pass.h"
10
11#include "mlir/CAPI/IR.h"
12#include "mlir/CAPI/Pass.h"
13#include "mlir/CAPI/Support.h"
14#include "mlir/CAPI/Utils.h"
15#include "mlir/Pass/PassManager.h"
16#include <optional>
17
18using namespace mlir;
19
//===----------------------------------------------------------------------===//
// PassManager/OpPassManager APIs.
//===----------------------------------------------------------------------===//
23
24MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25 return wrap(cpp: new PassManager(unwrap(c: ctx)));
26}
27
28MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29 MlirStringRef anchorOp) {
30 return wrap(cpp: new PassManager(unwrap(c: ctx), unwrap(ref: anchorOp)));
31}
32
33void mlirPassManagerDestroy(MlirPassManager passManager) {
34 delete unwrap(c: passManager);
35}
36
37MlirOpPassManager
38mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39 return wrap(cpp: static_cast<OpPassManager *>(unwrap(c: passManager)));
40}
41
42MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
43 MlirOperation op) {
44 return wrap(res: unwrap(c: passManager)->run(op: unwrap(c: op)));
45}
46
47void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48 bool printBeforeAll, bool printAfterAll,
49 bool printModuleScope,
50 bool printAfterOnlyOnChange,
51 bool printAfterOnlyOnFailure,
52 MlirOpPrintingFlags flags,
53 MlirStringRef treePrintingPath) {
54 auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
55 return printBeforeAll;
56 };
57 auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
58 return printAfterAll;
59 };
60 if (unwrap(ref: treePrintingPath).empty())
61 return unwrap(c: passManager)
62 ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
63 printModuleScope, printAfterOnlyOnChange,
64 printAfterOnlyOnFailure, /*out=*/llvm::errs(),
65 opPrintingFlags: *unwrap(c: flags));
66
67 unwrap(c: passManager)
68 ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
69 printModuleScope, printAfterOnlyOnChange,
70 printAfterOnlyOnFailure,
71 printTreeDir: unwrap(ref: treePrintingPath), opPrintingFlags: *unwrap(c: flags));
72}
73
74void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
75 unwrap(c: passManager)->enableVerifier(enabled: enable);
76}
77
78MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
79 MlirStringRef operationName) {
80 return wrap(cpp: &unwrap(c: passManager)->nest(nestedName: unwrap(ref: operationName)));
81}
82
83MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
84 MlirStringRef operationName) {
85 return wrap(cpp: &unwrap(c: passManager)->nest(nestedName: unwrap(ref: operationName)));
86}
87
88void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
89 unwrap(c: passManager)->addPass(pass: std::unique_ptr<Pass>(unwrap(c: pass)));
90}
91
92void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
93 MlirPass pass) {
94 unwrap(c: passManager)->addPass(pass: std::unique_ptr<Pass>(unwrap(c: pass)));
95}
96
97MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
98 MlirStringRef pipelineElements,
99 MlirStringCallback callback,
100 void *userData) {
101 detail::CallbackOstream stream(callback, userData);
102 return wrap(res: parsePassPipeline(pipeline: unwrap(ref: pipelineElements), pm&: *unwrap(c: passManager),
103 errorStream&: stream));
104}
105
106void mlirPrintPassPipeline(MlirOpPassManager passManager,
107 MlirStringCallback callback, void *userData) {
108 detail::CallbackOstream stream(callback, userData);
109 unwrap(c: passManager)->printAsTextualPipeline(os&: stream);
110}
111
112MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
113 MlirStringRef pipeline,
114 MlirStringCallback callback,
115 void *userData) {
116 detail::CallbackOstream stream(callback, userData);
117 FailureOr<OpPassManager> pm = parsePassPipeline(pipeline: unwrap(ref: pipeline), errorStream&: stream);
118 if (succeeded(Result: pm))
119 *unwrap(c: passManager) = std::move(*pm);
120 return wrap(res: pm);
121}
122
//===----------------------------------------------------------------------===//
// External Pass API.
//===----------------------------------------------------------------------===//
126
127namespace mlir {
128class ExternalPass;
129} // namespace mlir
130DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
131
132namespace mlir {
133/// This pass class wraps external passes defined in other languages using the
134/// MLIR C-interface
135class ExternalPass : public Pass {
136public:
137 ExternalPass(TypeID passID, StringRef name, StringRef argument,
138 StringRef description, std::optional<StringRef> opName,
139 ArrayRef<MlirDialectHandle> dependentDialects,
140 MlirExternalPassCallbacks callbacks, void *userData)
141 : Pass(passID, opName), id(passID), name(name), argument(argument),
142 description(description), dependentDialects(dependentDialects),
143 callbacks(callbacks), userData(userData) {
144 callbacks.construct(userData);
145 }
146
147 ~ExternalPass() override { callbacks.destruct(userData); }
148
149 StringRef getName() const override { return name; }
150 StringRef getArgument() const override { return argument; }
151 StringRef getDescription() const override { return description; }
152
153 void getDependentDialects(DialectRegistry &registry) const override {
154 MlirDialectRegistry cRegistry = wrap(cpp: &registry);
155 for (MlirDialectHandle dialect : dependentDialects)
156 mlirDialectHandleInsertDialect(dialect, cRegistry);
157 }
158
159 void signalPassFailure() { Pass::signalPassFailure(); }
160
161protected:
162 LogicalResult initialize(MLIRContext *ctx) override {
163 if (callbacks.initialize)
164 return unwrap(res: callbacks.initialize(wrap(cpp: ctx), userData));
165 return success();
166 }
167
168 bool canScheduleOn(RegisteredOperationName opName) const override {
169 if (std::optional<StringRef> specifiedOpName = getOpName())
170 return opName.getStringRef() == specifiedOpName;
171 return true;
172 }
173
174 void runOnOperation() override {
175 callbacks.run(wrap(cpp: getOperation()), wrap(cpp: this), userData);
176 }
177
178 std::unique_ptr<Pass> clonePass() const override {
179 void *clonedUserData = callbacks.clone(userData);
180 return std::make_unique<ExternalPass>(args: id, args: name, args: argument, args: description,
181 args: getOpName(), args: dependentDialects,
182 args: callbacks, args&: clonedUserData);
183 }
184
185private:
186 TypeID id;
187 std::string name;
188 std::string argument;
189 std::string description;
190 std::vector<MlirDialectHandle> dependentDialects;
191 MlirExternalPassCallbacks callbacks;
192 void *userData;
193};
194} // namespace mlir
195
196MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
197 MlirStringRef argument,
198 MlirStringRef description, MlirStringRef opName,
199 intptr_t nDependentDialects,
200 MlirDialectHandle *dependentDialects,
201 MlirExternalPassCallbacks callbacks,
202 void *userData) {
203 return wrap(cpp: static_cast<mlir::Pass *>(new mlir::ExternalPass(
204 unwrap(c: passID), unwrap(ref: name), unwrap(ref: argument), unwrap(ref: description),
205 opName.length > 0 ? std::optional<StringRef>(unwrap(ref: opName))
206 : std::nullopt,
207 {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
208 userData)));
209}
210
211void mlirExternalPassSignalFailure(MlirExternalPass pass) {
212 unwrap(c: pass)->signalPassFailure();
213}
214

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/CAPI/IR/Pass.cpp