1//===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
10#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11#include "mlir/IR/Matchers.h"
12#include "mlir/Pass/Pass.h"
13
14using namespace mlir;
15using namespace mlir::dataflow;
16
17/// Print the liveness of every block, control-flow edge, and the predecessors
18/// of all regions, callables, and calls.
19static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
20 raw_ostream &os) {
21 op->walk(callback: [&](Operation *op) {
22 auto tag = op->getAttrOfType<StringAttr>("tag");
23 if (!tag)
24 return;
25 os << tag.getValue() << ":\n";
26 for (Region &region : op->getRegions()) {
27 os << " region #" << region.getRegionNumber() << "\n";
28 for (Block &block : region) {
29 os << " ";
30 block.printAsOperand(os);
31 os << " = ";
32 auto *live = solver.lookupState<Executable>(point: &block);
33 if (live)
34 os << *live;
35 else
36 os << "dead";
37 os << "\n";
38 for (Block *pred : block.getPredecessors()) {
39 os << " from ";
40 pred->printAsOperand(os);
41 os << " = ";
42 auto *live = solver.lookupState<Executable>(
43 point: solver.getProgramPoint<CFGEdge>(args&: pred, args: &block));
44 if (live)
45 os << *live;
46 else
47 os << "dead";
48 os << "\n";
49 }
50 }
51 if (!region.empty()) {
52 auto *preds = solver.lookupState<PredecessorState>(point: &region.front());
53 if (preds)
54 os << "region_preds: " << *preds << "\n";
55 }
56 }
57 auto *preds = solver.lookupState<PredecessorState>(point: op);
58 if (preds)
59 os << "op_preds: " << *preds << "\n";
60 });
61}
62
63namespace {
64/// This is a simple analysis that implements a transfer function for constant
65/// operations.
66struct ConstantAnalysis : public DataFlowAnalysis {
67 using DataFlowAnalysis::DataFlowAnalysis;
68
69 LogicalResult initialize(Operation *top) override {
70 WalkResult result = top->walk(callback: [&](Operation *op) {
71 if (failed(result: visit(point: op)))
72 return WalkResult::interrupt();
73 return WalkResult::advance();
74 });
75 return success(isSuccess: !result.wasInterrupted());
76 }
77
78 LogicalResult visit(ProgramPoint point) override {
79 Operation *op = point.get<Operation *>();
80 Attribute value;
81 if (matchPattern(op, pattern: m_Constant(bind_value: &value))) {
82 auto *constant = getOrCreate<Lattice<ConstantValue>>(point: op->getResult(idx: 0));
83 propagateIfChanged(
84 state: constant, changed: constant->join(rhs: ConstantValue(value, op->getDialect())));
85 return success();
86 }
87 setAllToUnknownConstants(op->getResults());
88 for (Region &region : op->getRegions())
89 setAllToUnknownConstants(region.getArguments());
90 return success();
91 }
92
93 /// Set all given values as not constants.
94 void setAllToUnknownConstants(ValueRange values) {
95 for (Value value : values) {
96 auto *constant = getOrCreate<Lattice<ConstantValue>>(point: value);
97 propagateIfChanged(state: constant,
98 changed: constant->join(rhs: ConstantValue::getUnknownConstant()));
99 }
100 }
101};
102
103/// This is a simple pass that runs dead code analysis with a constant value
104/// provider that only understands constant operations.
105struct TestDeadCodeAnalysisPass
106 : public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
107 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
108
109 StringRef getArgument() const override { return "test-dead-code-analysis"; }
110
111 void runOnOperation() override {
112 Operation *op = getOperation();
113
114 DataFlowSolver solver;
115 solver.load<DeadCodeAnalysis>();
116 solver.load<ConstantAnalysis>();
117 if (failed(result: solver.initializeAndRun(top: op)))
118 return signalPassFailure();
119 printAnalysisResults(solver, op, os&: llvm::errs());
120 }
121};
122} // end anonymous namespace
123
124namespace mlir {
125namespace test {
126void registerTestDeadCodeAnalysisPass() {
127 PassRegistration<TestDeadCodeAnalysisPass>();
128}
129} // end namespace test
130} // end namespace mlir
131

source code of mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp