1 | //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file defines the Opt Reduction Pass Wrapper. It creates an MLIR pass
// that runs an arbitrary optimization pass pipeline on a copy of the module and
// only replaces the output module with the transformed version if it is smaller
// and still interesting.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Pass/PassManager.h" |
16 | #include "mlir/Pass/PassRegistry.h" |
17 | #include "mlir/Reducer/Passes.h" |
18 | #include "mlir/Reducer/Tester.h" |
19 | #include "llvm/Support/Debug.h" |
20 | |
21 | namespace mlir { |
22 | #define GEN_PASS_DEF_OPTREDUCTION |
23 | #include "mlir/Reducer/Passes.h.inc" |
24 | } // namespace mlir |
25 | |
26 | #define DEBUG_TYPE "mlir-reduce" |
27 | |
28 | using namespace mlir; |
29 | |
30 | namespace { |
31 | |
32 | class OptReductionPass : public impl::OptReductionBase<OptReductionPass> { |
33 | public: |
34 | /// Runs the pass instance in the pass pipeline. |
35 | void runOnOperation() override; |
36 | }; |
37 | |
38 | } // namespace |
39 | |
40 | /// Runs the pass instance in the pass pipeline. |
41 | void OptReductionPass::runOnOperation() { |
42 | LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: " ); |
43 | |
44 | Tester test(testerName, testerArgs); |
45 | |
46 | ModuleOp module = this->getOperation(); |
47 | ModuleOp moduleVariant = module.clone(); |
48 | |
49 | OpPassManager passManager("builtin.module" ); |
50 | if (failed(parsePassPipeline(optPass, passManager))) { |
51 | module.emitError() << "\nfailed to parse pass pipeline" ; |
52 | return signalPassFailure(); |
53 | } |
54 | |
55 | std::pair<Tester::Interestingness, int> original = test.isInteresting(module); |
56 | if (original.first != Tester::Interestingness::True) { |
57 | module.emitError() << "\nthe original input is not interested" ; |
58 | return signalPassFailure(); |
59 | } |
60 | |
61 | // Temporarily push the variant under the main module and execute the pipeline |
62 | // on it. |
63 | module.getBody()->push_back(moduleVariant); |
64 | LogicalResult pipelineResult = runPipeline(passManager, moduleVariant); |
65 | moduleVariant->remove(); |
66 | |
67 | if (failed(result: pipelineResult)) { |
68 | module.emitError() << "\nfailed to run pass pipeline" ; |
69 | return signalPassFailure(); |
70 | } |
71 | |
72 | std::pair<Tester::Interestingness, int> reduced = |
73 | test.isInteresting(moduleVariant); |
74 | |
75 | if (reduced.first == Tester::Interestingness::True && |
76 | reduced.second < original.second) { |
77 | module.getBody()->clear(); |
78 | module.getBody()->getOperations().splice( |
79 | module.getBody()->begin(), moduleVariant.getBody()->getOperations()); |
80 | LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n" ); |
81 | } else { |
82 | LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n" ); |
83 | } |
84 | |
85 | moduleVariant->destroy(); |
86 | |
87 | LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n" ); |
88 | } |
89 | |
90 | std::unique_ptr<Pass> mlir::createOptReductionPass() { |
91 | return std::make_unique<OptReductionPass>(); |
92 | } |
93 | |