1//===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
10// run any optimization pass within it and only replaces the output module with
11// the transformed version if it is smaller and interesting.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Pass/PassManager.h"
16#include "mlir/Pass/PassRegistry.h"
17#include "mlir/Reducer/Passes.h"
18#include "mlir/Reducer/Tester.h"
19#include "llvm/Support/Debug.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_OPTREDUCTION
23#include "mlir/Reducer/Passes.h.inc"
24} // namespace mlir
25
26#define DEBUG_TYPE "mlir-reduce"
27
28using namespace mlir;
29
30namespace {
31
32class OptReductionPass : public impl::OptReductionBase<OptReductionPass> {
33public:
34 /// Runs the pass instance in the pass pipeline.
35 void runOnOperation() override;
36};
37
38} // namespace
39
40/// Runs the pass instance in the pass pipeline.
41void OptReductionPass::runOnOperation() {
42 LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");
43
44 Tester test(testerName, testerArgs);
45
46 ModuleOp module = this->getOperation();
47 ModuleOp moduleVariant = module.clone();
48
49 OpPassManager passManager("builtin.module");
50 if (failed(parsePassPipeline(optPass, passManager))) {
51 module.emitError() << "\nfailed to parse pass pipeline";
52 return signalPassFailure();
53 }
54
55 std::pair<Tester::Interestingness, int> original = test.isInteresting(module);
56 if (original.first != Tester::Interestingness::True) {
57 module.emitError() << "\nthe original input is not interested";
58 return signalPassFailure();
59 }
60
61 // Temporarily push the variant under the main module and execute the pipeline
62 // on it.
63 module.getBody()->push_back(moduleVariant);
64 LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
65 moduleVariant->remove();
66
67 if (failed(result: pipelineResult)) {
68 module.emitError() << "\nfailed to run pass pipeline";
69 return signalPassFailure();
70 }
71
72 std::pair<Tester::Interestingness, int> reduced =
73 test.isInteresting(moduleVariant);
74
75 if (reduced.first == Tester::Interestingness::True &&
76 reduced.second < original.second) {
77 module.getBody()->clear();
78 module.getBody()->getOperations().splice(
79 module.getBody()->begin(), moduleVariant.getBody()->getOperations());
80 LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n");
81 } else {
82 LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n");
83 }
84
85 moduleVariant->destroy();
86
87 LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n");
88}
89
90std::unique_ptr<Pass> mlir::createOptReductionPass() {
91 return std::make_unique<OptReductionPass>();
92}
93

source code of mlir/lib/Reducer/OptReductionPass.cpp