//===- TestPassManager.cpp - Test pass manager functionality --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 8 | |
| 9 | #include "TestDialect.h" |
| 10 | #include "TestOps.h" |
| 11 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 12 | #include "mlir/IR/BuiltinOps.h" |
| 13 | #include "mlir/Pass/Pass.h" |
| 14 | #include "mlir/Pass/PassManager.h" |
| 15 | |
| 16 | using namespace mlir; |
| 17 | |
| 18 | namespace { |
| 19 | struct TestModulePass |
| 20 | : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> { |
| 21 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestModulePass) |
| 22 | |
| 23 | void runOnOperation() final {} |
| 24 | StringRef getArgument() const final { return "test-module-pass" ; } |
| 25 | StringRef getDescription() const final { |
| 26 | return "Test a module pass in the pass manager" ; |
| 27 | } |
| 28 | }; |
| 29 | struct TestFunctionPass |
| 30 | : public PassWrapper<TestFunctionPass, OperationPass<func::FuncOp>> { |
| 31 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFunctionPass) |
| 32 | |
| 33 | void runOnOperation() final {} |
| 34 | StringRef getArgument() const final { return "test-function-pass" ; } |
| 35 | StringRef getDescription() const final { |
| 36 | return "Test a function pass in the pass manager" ; |
| 37 | } |
| 38 | }; |
| 39 | struct TestInterfacePass |
| 40 | : public PassWrapper<TestInterfacePass, |
| 41 | InterfacePass<FunctionOpInterface>> { |
| 42 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestInterfacePass) |
| 43 | |
| 44 | void runOnOperation() final { |
| 45 | getOperation()->emitRemark() << "Executing interface pass on operation" ; |
| 46 | } |
| 47 | StringRef getArgument() const final { return "test-interface-pass" ; } |
| 48 | StringRef getDescription() const final { |
| 49 | return "Test an interface pass (running on FunctionOpInterface) in the " |
| 50 | "pass manager" ; |
| 51 | } |
| 52 | }; |
| 53 | struct TestOptionsPass |
| 54 | : public PassWrapper<TestOptionsPass, OperationPass<func::FuncOp>> { |
| 55 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPass) |
| 56 | |
| 57 | enum Enum { Zero, One, Two }; |
| 58 | |
| 59 | struct Options : public PassPipelineOptions<Options> { |
| 60 | ListOption<int> listOption{*this, "list" , |
| 61 | llvm::cl::desc("Example list option" )}; |
| 62 | ListOption<std::string> stringListOption{ |
| 63 | *this, "string-list" , llvm::cl::desc("Example string list option" )}; |
| 64 | Option<std::string> stringOption{*this, "string" , |
| 65 | llvm::cl::desc("Example string option" )}; |
| 66 | Option<Enum> enumOption{ |
| 67 | *this, "enum" , llvm::cl::desc("Example enum option" ), |
| 68 | llvm::cl::values(clEnumValN(0, "zero" , "Example zero value" ), |
| 69 | clEnumValN(1, "one" , "Example one value" ), |
| 70 | clEnumValN(2, "two" , "Example two value" ))}; |
| 71 | |
| 72 | Options() = default; |
| 73 | Options(const Options &rhs) { *this = rhs; } |
| 74 | Options &operator=(const Options &rhs) { |
| 75 | copyOptionValuesFrom(other: rhs); |
| 76 | return *this; |
| 77 | } |
| 78 | }; |
| 79 | TestOptionsPass() = default; |
| 80 | TestOptionsPass(const TestOptionsPass &) : PassWrapper() {} |
| 81 | TestOptionsPass(const Options &options) { |
| 82 | listOption = options.listOption; |
| 83 | stringOption = options.stringOption; |
| 84 | stringListOption = options.stringListOption; |
| 85 | enumOption = options.enumOption; |
| 86 | } |
| 87 | |
| 88 | void runOnOperation() final {} |
| 89 | StringRef getArgument() const final { return "test-options-pass" ; } |
| 90 | StringRef getDescription() const final { |
| 91 | return "Test options parsing capabilities" ; |
| 92 | } |
| 93 | |
| 94 | ListOption<int> listOption{*this, "list" , |
| 95 | llvm::cl::desc("Example list option" )}; |
| 96 | ListOption<std::string> stringListOption{ |
| 97 | *this, "string-list" , llvm::cl::desc("Example string list option" )}; |
| 98 | Option<std::string> stringOption{*this, "string" , |
| 99 | llvm::cl::desc("Example string option" )}; |
| 100 | Option<Enum> enumOption{ |
| 101 | *this, "enum" , llvm::cl::desc("Example enum option" ), |
| 102 | llvm::cl::values(clEnumValN(0, "zero" , "Example zero value" ), |
| 103 | clEnumValN(1, "one" , "Example one value" ), |
| 104 | clEnumValN(2, "two" , "Example two value" ))}; |
| 105 | }; |
| 106 | |
| 107 | struct TestOptionsSuperPass |
| 108 | : public PassWrapper<TestOptionsSuperPass, OperationPass<func::FuncOp>> { |
| 109 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsSuperPass) |
| 110 | |
| 111 | struct Options : public PassPipelineOptions<Options> { |
| 112 | ListOption<TestOptionsPass::Options> listOption{ |
| 113 | *this, "super-list" , |
| 114 | llvm::cl::desc("Example list of PassPipelineOptions option" )}; |
| 115 | |
| 116 | Options() = default; |
| 117 | }; |
| 118 | |
| 119 | TestOptionsSuperPass() = default; |
| 120 | TestOptionsSuperPass(const TestOptionsSuperPass &) : PassWrapper() {} |
| 121 | TestOptionsSuperPass(const Options &options) { |
| 122 | listOption = options.listOption; |
| 123 | } |
| 124 | |
| 125 | void runOnOperation() final {} |
| 126 | StringRef getArgument() const final { return "test-options-super-pass" ; } |
| 127 | StringRef getDescription() const final { |
| 128 | return "Test options of options parsing capabilities" ; |
| 129 | } |
| 130 | |
| 131 | ListOption<TestOptionsPass::Options> listOption{ |
| 132 | *this, "list" , |
| 133 | llvm::cl::desc("Example list of PassPipelineOptions option" )}; |
| 134 | }; |
| 135 | |
| 136 | /// A test pass that always aborts to enable testing the crash recovery |
| 137 | /// mechanism of the pass manager. |
| 138 | struct TestCrashRecoveryPass |
| 139 | : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> { |
| 140 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCrashRecoveryPass) |
| 141 | |
| 142 | void runOnOperation() final { abort(); } |
| 143 | StringRef getArgument() const final { return "test-pass-crash" ; } |
| 144 | StringRef getDescription() const final { |
| 145 | return "Test a pass in the pass manager that always crashes" ; |
| 146 | } |
| 147 | }; |
| 148 | |
| 149 | /// A test pass that always fails to enable testing the failure recovery |
| 150 | /// mechanisms of the pass manager. |
| 151 | struct TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> { |
| 152 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFailurePass) |
| 153 | |
| 154 | TestFailurePass() = default; |
| 155 | TestFailurePass(const TestFailurePass &other) : PassWrapper(other) {} |
| 156 | |
| 157 | void runOnOperation() final { |
| 158 | signalPassFailure(); |
| 159 | if (genDiagnostics) |
| 160 | mlir::emitError(loc: getOperation()->getLoc(), message: "illegal operation" ); |
| 161 | } |
| 162 | StringRef getArgument() const final { return "test-pass-failure" ; } |
| 163 | StringRef getDescription() const final { |
| 164 | return "Test a pass in the pass manager that always fails" ; |
| 165 | } |
| 166 | |
| 167 | Option<bool> genDiagnostics{*this, "gen-diagnostics" , |
| 168 | llvm::cl::desc("Generate a diagnostic message" )}; |
| 169 | }; |
| 170 | |
| 171 | /// A test pass that creates an invalid operation in a function body. |
| 172 | struct TestInvalidIRPass |
| 173 | : public PassWrapper<TestInvalidIRPass, |
| 174 | InterfacePass<FunctionOpInterface>> { |
| 175 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestInvalidIRPass) |
| 176 | |
| 177 | TestInvalidIRPass() = default; |
| 178 | TestInvalidIRPass(const TestInvalidIRPass &other) : PassWrapper(other) {} |
| 179 | |
| 180 | StringRef getArgument() const final { return "test-pass-create-invalid-ir" ; } |
| 181 | StringRef getDescription() const final { |
| 182 | return "Test pass that adds an invalid operation in a function body" ; |
| 183 | } |
| 184 | void getDependentDialects(DialectRegistry ®istry) const final { |
| 185 | registry.insert<test::TestDialect>(); |
| 186 | } |
| 187 | void runOnOperation() final { |
| 188 | if (signalFailure) |
| 189 | signalPassFailure(); |
| 190 | if (!emitInvalidIR) |
| 191 | return; |
| 192 | OpBuilder b(getOperation().getFunctionBody()); |
| 193 | OperationState state(b.getUnknownLoc(), "test.any_attr_of_i32_str" ); |
| 194 | b.create(state); |
| 195 | } |
| 196 | Option<bool> signalFailure{*this, "signal-pass-failure" , |
| 197 | llvm::cl::desc("Trigger a pass failure" )}; |
| 198 | Option<bool> emitInvalidIR{*this, "emit-invalid-ir" , llvm::cl::init(Val: true), |
| 199 | llvm::cl::desc("Emit invalid IR" )}; |
| 200 | }; |
| 201 | |
| 202 | /// A test pass that always fails to enable testing the failure recovery |
| 203 | /// mechanisms of the pass manager. |
| 204 | struct TestInvalidParentPass |
| 205 | : public PassWrapper<TestInvalidParentPass, |
| 206 | InterfacePass<FunctionOpInterface>> { |
| 207 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestInvalidParentPass) |
| 208 | |
| 209 | StringRef getArgument() const final { return "test-pass-invalid-parent" ; } |
| 210 | StringRef getDescription() const final { |
| 211 | return "Test a pass in the pass manager that makes the parent operation " |
| 212 | "invalid" ; |
| 213 | } |
| 214 | void getDependentDialects(DialectRegistry ®istry) const final { |
| 215 | registry.insert<test::TestDialect>(); |
| 216 | } |
| 217 | void runOnOperation() final { |
| 218 | FunctionOpInterface op = getOperation(); |
| 219 | OpBuilder b(op.getFunctionBody()); |
| 220 | b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func" , |
| 221 | ValueRange()); |
| 222 | } |
| 223 | }; |
| 224 | |
| 225 | /// A test pass that contains a statistic. |
| 226 | struct TestStatisticPass |
| 227 | : public PassWrapper<TestStatisticPass, OperationPass<>> { |
| 228 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStatisticPass) |
| 229 | |
| 230 | TestStatisticPass() = default; |
| 231 | TestStatisticPass(const TestStatisticPass &) : PassWrapper() {} |
| 232 | StringRef getArgument() const final { return "test-stats-pass" ; } |
| 233 | StringRef getDescription() const final { return "Test pass statistics" ; } |
| 234 | |
| 235 | // Use a couple of statistics to verify their ordering |
| 236 | // in the print out. The statistics are registered in the order |
| 237 | // of construction, so put "num-ops2" before "num-ops" and |
| 238 | // make sure that the order is reversed. |
| 239 | Statistic opCountDuplicate{this, "num-ops2" , |
| 240 | "Number of operations counted one more time" }; |
| 241 | Statistic opCount{this, "num-ops" , "Number of operations counted" }; |
| 242 | |
| 243 | void runOnOperation() final { |
| 244 | getOperation()->walk(callback: [&](Operation *) { ++opCount; }); |
| 245 | getOperation()->walk(callback: [&](Operation *) { ++opCountDuplicate; }); |
| 246 | } |
| 247 | }; |
| 248 | } // namespace |
| 249 | |
| 250 | static void testNestedPipeline(OpPassManager &pm) { |
| 251 | // Nest a module pipeline that contains: |
| 252 | /// A module pass. |
| 253 | auto &modulePM = pm.nest<ModuleOp>(); |
| 254 | modulePM.addPass(pass: std::make_unique<TestModulePass>()); |
| 255 | /// A nested function pass. |
| 256 | auto &nestedFunctionPM = modulePM.nest<func::FuncOp>(); |
| 257 | nestedFunctionPM.addPass(std::make_unique<TestFunctionPass>()); |
| 258 | |
| 259 | // Nest a function pipeline that contains a single pass. |
| 260 | auto &functionPM = pm.nest<func::FuncOp>(); |
| 261 | functionPM.addPass(std::make_unique<TestFunctionPass>()); |
| 262 | } |
| 263 | |
| 264 | static void testNestedPipelineTextual(OpPassManager &pm) { |
| 265 | (void)parsePassPipeline(pipeline: "test-pm-nested-pipeline" , pm); |
| 266 | } |
| 267 | |
| 268 | namespace mlir { |
| 269 | void registerPassManagerTestPass() { |
| 270 | PassRegistration<TestOptionsPass>(); |
| 271 | PassRegistration<TestOptionsSuperPass>(); |
| 272 | |
| 273 | PassRegistration<TestModulePass>(); |
| 274 | |
| 275 | PassRegistration<TestFunctionPass>(); |
| 276 | |
| 277 | PassRegistration<TestInterfacePass>(); |
| 278 | |
| 279 | PassRegistration<TestCrashRecoveryPass>(); |
| 280 | PassRegistration<TestFailurePass>(); |
| 281 | PassRegistration<TestInvalidIRPass>(); |
| 282 | PassRegistration<TestInvalidParentPass>(); |
| 283 | |
| 284 | PassRegistration<TestStatisticPass>(); |
| 285 | |
| 286 | PassPipelineRegistration<>("test-pm-nested-pipeline" , |
| 287 | "Test a nested pipeline in the pass manager" , |
| 288 | testNestedPipeline); |
| 289 | PassPipelineRegistration<>("test-textual-pm-nested-pipeline" , |
| 290 | "Test a nested pipeline in the pass manager" , |
| 291 | testNestedPipelineTextual); |
| 292 | |
| 293 | PassPipelineRegistration<TestOptionsPass::Options> |
| 294 | registerOptionsPassPipeline( |
| 295 | "test-options-pass-pipeline" , |
| 296 | "Parses options using pass pipeline registration" , |
| 297 | [](OpPassManager &pm, const TestOptionsPass::Options &options) { |
| 298 | pm.addPass(std::make_unique<TestOptionsPass>(args: options)); |
| 299 | }); |
| 300 | |
| 301 | PassPipelineRegistration<TestOptionsSuperPass::Options> |
| 302 | registerOptionsSuperPassPipeline( |
| 303 | "test-options-super-pass-pipeline" , |
| 304 | "Parses options of PassPipelineOptions using pass pipeline " |
| 305 | "registration" , |
| 306 | [](OpPassManager &pm, const TestOptionsSuperPass::Options &options) { |
| 307 | pm.addPass(std::make_unique<TestOptionsSuperPass>(args: options)); |
| 308 | }); |
| 309 | } |
| 310 | } // namespace mlir |
| 311 | |