| 1 | //===- Pass.cpp - C Interface for General Pass Management APIs ------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir-c/Pass.h" |
| 10 | |
| 11 | #include "mlir/CAPI/IR.h" |
| 12 | #include "mlir/CAPI/Pass.h" |
| 13 | #include "mlir/CAPI/Support.h" |
| 14 | #include "mlir/CAPI/Utils.h" |
| 15 | #include "mlir/Pass/PassManager.h" |
| 16 | #include <optional> |
| 17 | |
| 18 | using namespace mlir; |
| 19 | |
| 20 | //===----------------------------------------------------------------------===// |
| 21 | // PassManager/OpPassManager APIs. |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | |
| 24 | MlirPassManager mlirPassManagerCreate(MlirContext ctx) { |
| 25 | return wrap(cpp: new PassManager(unwrap(c: ctx))); |
| 26 | } |
| 27 | |
| 28 | MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, |
| 29 | MlirStringRef anchorOp) { |
| 30 | return wrap(cpp: new PassManager(unwrap(c: ctx), unwrap(ref: anchorOp))); |
| 31 | } |
| 32 | |
| 33 | void mlirPassManagerDestroy(MlirPassManager passManager) { |
| 34 | delete unwrap(c: passManager); |
| 35 | } |
| 36 | |
| 37 | MlirOpPassManager |
| 38 | mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) { |
| 39 | return wrap(cpp: static_cast<OpPassManager *>(unwrap(c: passManager))); |
| 40 | } |
| 41 | |
| 42 | MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, |
| 43 | MlirOperation op) { |
| 44 | return wrap(res: unwrap(c: passManager)->run(op: unwrap(c: op))); |
| 45 | } |
| 46 | |
| 47 | void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, |
| 48 | bool printBeforeAll, bool printAfterAll, |
| 49 | bool printModuleScope, |
| 50 | bool printAfterOnlyOnChange, |
| 51 | bool printAfterOnlyOnFailure, |
| 52 | MlirOpPrintingFlags flags, |
| 53 | MlirStringRef treePrintingPath) { |
| 54 | auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { |
| 55 | return printBeforeAll; |
| 56 | }; |
| 57 | auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { |
| 58 | return printAfterAll; |
| 59 | }; |
| 60 | if (unwrap(ref: treePrintingPath).empty()) |
| 61 | return unwrap(c: passManager) |
| 62 | ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, |
| 63 | printModuleScope, printAfterOnlyOnChange, |
| 64 | printAfterOnlyOnFailure, /*out=*/llvm::errs(), |
| 65 | opPrintingFlags: *unwrap(c: flags)); |
| 66 | |
| 67 | unwrap(c: passManager) |
| 68 | ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, |
| 69 | printModuleScope, printAfterOnlyOnChange, |
| 70 | printAfterOnlyOnFailure, |
| 71 | printTreeDir: unwrap(ref: treePrintingPath), opPrintingFlags: *unwrap(c: flags)); |
| 72 | } |
| 73 | |
| 74 | void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { |
| 75 | unwrap(c: passManager)->enableVerifier(enabled: enable); |
| 76 | } |
| 77 | |
| 78 | MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, |
| 79 | MlirStringRef operationName) { |
| 80 | return wrap(cpp: &unwrap(c: passManager)->nest(nestedName: unwrap(ref: operationName))); |
| 81 | } |
| 82 | |
| 83 | MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, |
| 84 | MlirStringRef operationName) { |
| 85 | return wrap(cpp: &unwrap(c: passManager)->nest(nestedName: unwrap(ref: operationName))); |
| 86 | } |
| 87 | |
| 88 | void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) { |
| 89 | unwrap(c: passManager)->addPass(pass: std::unique_ptr<Pass>(unwrap(c: pass))); |
| 90 | } |
| 91 | |
| 92 | void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, |
| 93 | MlirPass pass) { |
| 94 | unwrap(c: passManager)->addPass(pass: std::unique_ptr<Pass>(unwrap(c: pass))); |
| 95 | } |
| 96 | |
| 97 | MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, |
| 98 | MlirStringRef pipelineElements, |
| 99 | MlirStringCallback callback, |
| 100 | void *userData) { |
| 101 | detail::CallbackOstream stream(callback, userData); |
| 102 | return wrap(res: parsePassPipeline(pipeline: unwrap(ref: pipelineElements), pm&: *unwrap(c: passManager), |
| 103 | errorStream&: stream)); |
| 104 | } |
| 105 | |
| 106 | void mlirPrintPassPipeline(MlirOpPassManager passManager, |
| 107 | MlirStringCallback callback, void *userData) { |
| 108 | detail::CallbackOstream stream(callback, userData); |
| 109 | unwrap(c: passManager)->printAsTextualPipeline(os&: stream); |
| 110 | } |
| 111 | |
| 112 | MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, |
| 113 | MlirStringRef pipeline, |
| 114 | MlirStringCallback callback, |
| 115 | void *userData) { |
| 116 | detail::CallbackOstream stream(callback, userData); |
| 117 | FailureOr<OpPassManager> pm = parsePassPipeline(pipeline: unwrap(ref: pipeline), errorStream&: stream); |
| 118 | if (succeeded(Result: pm)) |
| 119 | *unwrap(c: passManager) = std::move(*pm); |
| 120 | return wrap(res: pm); |
| 121 | } |
| 122 | |
| 123 | //===----------------------------------------------------------------------===// |
| 124 | // External Pass API. |
| 125 | //===----------------------------------------------------------------------===// |
| 126 | |
| 127 | namespace mlir { |
| 128 | class ExternalPass; |
| 129 | } // namespace mlir |
| 130 | DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass) |
| 131 | |
| 132 | namespace mlir { |
| 133 | /// This pass class wraps external passes defined in other languages using the |
| 134 | /// MLIR C-interface |
| 135 | class ExternalPass : public Pass { |
| 136 | public: |
| 137 | ExternalPass(TypeID passID, StringRef name, StringRef argument, |
| 138 | StringRef description, std::optional<StringRef> opName, |
| 139 | ArrayRef<MlirDialectHandle> dependentDialects, |
| 140 | MlirExternalPassCallbacks callbacks, void *userData) |
| 141 | : Pass(passID, opName), id(passID), name(name), argument(argument), |
| 142 | description(description), dependentDialects(dependentDialects), |
| 143 | callbacks(callbacks), userData(userData) { |
| 144 | callbacks.construct(userData); |
| 145 | } |
| 146 | |
| 147 | ~ExternalPass() override { callbacks.destruct(userData); } |
| 148 | |
| 149 | StringRef getName() const override { return name; } |
| 150 | StringRef getArgument() const override { return argument; } |
| 151 | StringRef getDescription() const override { return description; } |
| 152 | |
| 153 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 154 | MlirDialectRegistry cRegistry = wrap(cpp: ®istry); |
| 155 | for (MlirDialectHandle dialect : dependentDialects) |
| 156 | mlirDialectHandleInsertDialect(dialect, cRegistry); |
| 157 | } |
| 158 | |
| 159 | void signalPassFailure() { Pass::signalPassFailure(); } |
| 160 | |
| 161 | protected: |
| 162 | LogicalResult initialize(MLIRContext *ctx) override { |
| 163 | if (callbacks.initialize) |
| 164 | return unwrap(res: callbacks.initialize(wrap(cpp: ctx), userData)); |
| 165 | return success(); |
| 166 | } |
| 167 | |
| 168 | bool canScheduleOn(RegisteredOperationName opName) const override { |
| 169 | if (std::optional<StringRef> specifiedOpName = getOpName()) |
| 170 | return opName.getStringRef() == specifiedOpName; |
| 171 | return true; |
| 172 | } |
| 173 | |
| 174 | void runOnOperation() override { |
| 175 | callbacks.run(wrap(cpp: getOperation()), wrap(cpp: this), userData); |
| 176 | } |
| 177 | |
| 178 | std::unique_ptr<Pass> clonePass() const override { |
| 179 | void *clonedUserData = callbacks.clone(userData); |
| 180 | return std::make_unique<ExternalPass>(args: id, args: name, args: argument, args: description, |
| 181 | args: getOpName(), args: dependentDialects, |
| 182 | args: callbacks, args&: clonedUserData); |
| 183 | } |
| 184 | |
| 185 | private: |
| 186 | TypeID id; |
| 187 | std::string name; |
| 188 | std::string argument; |
| 189 | std::string description; |
| 190 | std::vector<MlirDialectHandle> dependentDialects; |
| 191 | MlirExternalPassCallbacks callbacks; |
| 192 | void *userData; |
| 193 | }; |
| 194 | } // namespace mlir |
| 195 | |
| 196 | MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, |
| 197 | MlirStringRef argument, |
| 198 | MlirStringRef description, MlirStringRef opName, |
| 199 | intptr_t nDependentDialects, |
| 200 | MlirDialectHandle *dependentDialects, |
| 201 | MlirExternalPassCallbacks callbacks, |
| 202 | void *userData) { |
| 203 | return wrap(cpp: static_cast<mlir::Pass *>(new mlir::ExternalPass( |
| 204 | unwrap(c: passID), unwrap(ref: name), unwrap(ref: argument), unwrap(ref: description), |
| 205 | opName.length > 0 ? std::optional<StringRef>(unwrap(ref: opName)) |
| 206 | : std::nullopt, |
| 207 | {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks, |
| 208 | userData))); |
| 209 | } |
| 210 | |
| 211 | void mlirExternalPassSignalFailure(MlirExternalPass pass) { |
| 212 | unwrap(c: pass)->signalPassFailure(); |
| 213 | } |
| 214 | |