1 | //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file defines the Opt Reduction Pass Wrapper. It creates an MLIR pass
// that runs an arbitrary optimization pipeline on a copy of the module and
// replaces the module with the transformed version only if the result is
// smaller and still interesting.
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Pass/PassManager.h" |
16 | #include "mlir/Pass/PassRegistry.h" |
17 | #include "mlir/Reducer/Passes.h" |
18 | #include "mlir/Reducer/Tester.h" |
19 | |
20 | #include "llvm/Support/Debug.h" |
21 | |
22 | namespace mlir { |
23 | #define GEN_PASS_DEF_OPTREDUCTIONPASS |
24 | #include "mlir/Reducer/Passes.h.inc" |
25 | } // namespace mlir |
26 | |
27 | #define DEBUG_TYPE "mlir-reduce" |
28 | |
29 | using namespace mlir; |
30 | |
31 | namespace { |
32 | |
33 | class OptReductionPass : public impl::OptReductionPassBase<OptReductionPass> { |
34 | public: |
35 | using Base::Base; |
36 | |
37 | /// Runs the pass instance in the pass pipeline. |
38 | void runOnOperation() override; |
39 | }; |
40 | |
41 | } // namespace |
42 | |
43 | /// Runs the pass instance in the pass pipeline. |
44 | void OptReductionPass::runOnOperation() { |
45 | LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: " ); |
46 | |
47 | Tester test(testerName, testerArgs); |
48 | |
49 | ModuleOp module = this->getOperation(); |
50 | ModuleOp moduleVariant = module.clone(); |
51 | |
52 | OpPassManager passManager("builtin.module" ); |
53 | if (failed(parsePassPipeline(optPass, passManager))) { |
54 | module.emitError() << "\nfailed to parse pass pipeline" ; |
55 | return signalPassFailure(); |
56 | } |
57 | |
58 | std::pair<Tester::Interestingness, int> original = test.isInteresting(module); |
59 | if (original.first != Tester::Interestingness::True) { |
60 | module.emitError() << "\nthe original input is not interested" ; |
61 | return signalPassFailure(); |
62 | } |
63 | |
64 | // Temporarily push the variant under the main module and execute the pipeline |
65 | // on it. |
66 | module.getBody()->push_back(moduleVariant); |
67 | LogicalResult pipelineResult = runPipeline(passManager, moduleVariant); |
68 | moduleVariant->remove(); |
69 | |
70 | if (failed(Result: pipelineResult)) { |
71 | module.emitError() << "\nfailed to run pass pipeline" ; |
72 | return signalPassFailure(); |
73 | } |
74 | |
75 | std::pair<Tester::Interestingness, int> reduced = |
76 | test.isInteresting(moduleVariant); |
77 | |
78 | if (reduced.first == Tester::Interestingness::True && |
79 | reduced.second < original.second) { |
80 | module.getBody()->clear(); |
81 | module.getBody()->getOperations().splice( |
82 | module.getBody()->begin(), moduleVariant.getBody()->getOperations()); |
83 | LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n" ); |
84 | } else { |
85 | LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n" ); |
86 | } |
87 | |
88 | moduleVariant->destroy(); |
89 | |
90 | LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n" ); |
91 | } |
92 | |