| 1 | //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file defines the Opt Reduction Pass Wrapper. It creates an MLIR pass
// that runs an arbitrary optimization pass pipeline on a copy of the input
// module, and only replaces the module with the transformed version if the
// result is smaller and still interesting.
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Pass/PassManager.h" |
| 16 | #include "mlir/Pass/PassRegistry.h" |
| 17 | #include "mlir/Reducer/Passes.h" |
| 18 | #include "mlir/Reducer/Tester.h" |
| 19 | |
| 20 | #include "llvm/Support/Debug.h" |
| 21 | |
| 22 | namespace mlir { |
| 23 | #define GEN_PASS_DEF_OPTREDUCTIONPASS |
| 24 | #include "mlir/Reducer/Passes.h.inc" |
| 25 | } // namespace mlir |
| 26 | |
| 27 | #define DEBUG_TYPE "mlir-reduce" |
| 28 | |
| 29 | using namespace mlir; |
| 30 | |
| 31 | namespace { |
| 32 | |
| 33 | class OptReductionPass : public impl::OptReductionPassBase<OptReductionPass> { |
| 34 | public: |
| 35 | using Base::Base; |
| 36 | |
| 37 | /// Runs the pass instance in the pass pipeline. |
| 38 | void runOnOperation() override; |
| 39 | }; |
| 40 | |
| 41 | } // namespace |
| 42 | |
| 43 | /// Runs the pass instance in the pass pipeline. |
| 44 | void OptReductionPass::runOnOperation() { |
| 45 | LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: " ); |
| 46 | |
| 47 | Tester test(testerName, testerArgs); |
| 48 | |
| 49 | ModuleOp module = this->getOperation(); |
| 50 | ModuleOp moduleVariant = module.clone(); |
| 51 | |
| 52 | OpPassManager passManager("builtin.module" ); |
| 53 | if (failed(parsePassPipeline(optPass, passManager))) { |
| 54 | module.emitError() << "\nfailed to parse pass pipeline" ; |
| 55 | return signalPassFailure(); |
| 56 | } |
| 57 | |
| 58 | std::pair<Tester::Interestingness, int> original = test.isInteresting(module); |
| 59 | if (original.first != Tester::Interestingness::True) { |
| 60 | module.emitError() << "\nthe original input is not interested" ; |
| 61 | return signalPassFailure(); |
| 62 | } |
| 63 | |
| 64 | // Temporarily push the variant under the main module and execute the pipeline |
| 65 | // on it. |
| 66 | module.getBody()->push_back(moduleVariant); |
| 67 | LogicalResult pipelineResult = runPipeline(passManager, moduleVariant); |
| 68 | moduleVariant->remove(); |
| 69 | |
| 70 | if (failed(Result: pipelineResult)) { |
| 71 | module.emitError() << "\nfailed to run pass pipeline" ; |
| 72 | return signalPassFailure(); |
| 73 | } |
| 74 | |
| 75 | std::pair<Tester::Interestingness, int> reduced = |
| 76 | test.isInteresting(moduleVariant); |
| 77 | |
| 78 | if (reduced.first == Tester::Interestingness::True && |
| 79 | reduced.second < original.second) { |
| 80 | module.getBody()->clear(); |
| 81 | module.getBody()->getOperations().splice( |
| 82 | module.getBody()->begin(), moduleVariant.getBody()->getOperations()); |
| 83 | LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n" ); |
| 84 | } else { |
| 85 | LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n" ); |
| 86 | } |
| 87 | |
| 88 | moduleVariant->destroy(); |
| 89 | |
| 90 | LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n" ); |
| 91 | } |
| 92 | |