1 | //===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/DataFlowFramework.h" |
10 | #include "mlir/IR/Location.h" |
11 | #include "mlir/IR/Operation.h" |
12 | #include "mlir/IR/Value.h" |
13 | #include "mlir/Support/LogicalResult.h" |
14 | #include "llvm/ADT/iterator.h" |
15 | #include "llvm/Config/abi-breaking.h" |
16 | #include "llvm/Support/Casting.h" |
17 | #include "llvm/Support/Debug.h" |
18 | #include "llvm/Support/raw_ostream.h" |
19 | |
// Component name consulted by LLVM_DEBUG (e.g. selectable with
// `-debug-only=dataflow` in builds that support debug output).
#define DEBUG_TYPE "dataflow"
// DATAFLOW_DEBUG(X) forwards to LLVM_DEBUG(X) when ABI-breaking checks are
// enabled, and expands to nothing otherwise. The debug statements in this file
// read `debugName` members. Presumably those members only exist when
// LLVM_ENABLE_ABI_BREAKING_CHECKS is set (TODO confirm against the header),
// which is why plain LLVM_DEBUG is not used directly.
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
#define DATAFLOW_DEBUG(X) LLVM_DEBUG(X)
#else
#define DATAFLOW_DEBUG(X)
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

using namespace mlir;
28 | |
29 | //===----------------------------------------------------------------------===// |
30 | // GenericProgramPoint |
31 | //===----------------------------------------------------------------------===// |
32 | |
// Defaulted destructor, defined out-of-line in this .cpp file rather than in
// the header (presumably so this translation unit anchors the vtable).
GenericProgramPoint::~GenericProgramPoint() = default;
34 | |
35 | //===----------------------------------------------------------------------===// |
36 | // AnalysisState |
37 | //===----------------------------------------------------------------------===// |
38 | |
// Defaulted destructor, defined out-of-line in this .cpp file rather than in
// the header (presumably so this translation unit anchors the vtable).
AnalysisState::~AnalysisState() = default;
40 | |
41 | void AnalysisState::addDependency(ProgramPoint dependent, |
42 | DataFlowAnalysis *analysis) { |
43 | auto inserted = dependents.insert(X: {dependent, analysis}); |
44 | (void)inserted; |
45 | DATAFLOW_DEBUG({ |
46 | if (inserted) { |
47 | llvm::dbgs() << "Creating dependency between " << debugName << " of " |
48 | << point << "\nand " << debugName << " on " << dependent |
49 | << "\n" ; |
50 | } |
51 | }); |
52 | } |
53 | |
54 | void AnalysisState::dump() const { print(os&: llvm::errs()); } |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // ProgramPoint |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | void ProgramPoint::print(raw_ostream &os) const { |
61 | if (isNull()) { |
62 | os << "<NULL POINT>" ; |
63 | return; |
64 | } |
65 | if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(Val: *this)) |
66 | return programPoint->print(os); |
67 | if (auto *op = llvm::dyn_cast<Operation *>(Val: *this)) |
68 | return op->print(os, flags: OpPrintingFlags().skipRegions()); |
69 | if (auto value = llvm::dyn_cast<Value>(Val: *this)) |
70 | return value.print(os, flags: OpPrintingFlags().skipRegions()); |
71 | return get<Block *>()->print(os); |
72 | } |
73 | |
74 | Location ProgramPoint::getLoc() const { |
75 | if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(Val: *this)) |
76 | return programPoint->getLoc(); |
77 | if (auto *op = llvm::dyn_cast<Operation *>(Val: *this)) |
78 | return op->getLoc(); |
79 | if (auto value = llvm::dyn_cast<Value>(Val: *this)) |
80 | return value.getLoc(); |
81 | return get<Block *>()->getParent()->getLoc(); |
82 | } |
83 | |
84 | //===----------------------------------------------------------------------===// |
85 | // DataFlowSolver |
86 | //===----------------------------------------------------------------------===// |
87 | |
88 | LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { |
89 | // Initialize the analyses. |
90 | for (DataFlowAnalysis &analysis : llvm::make_pointee_range(Range&: childAnalyses)) { |
91 | DATAFLOW_DEBUG(llvm::dbgs() |
92 | << "Priming analysis: " << analysis.debugName << "\n" ); |
93 | if (failed(result: analysis.initialize(top))) |
94 | return failure(); |
95 | } |
96 | |
97 | // Run the analysis until fixpoint. |
98 | do { |
99 | // Exhaust the worklist. |
100 | while (!worklist.empty()) { |
101 | auto [point, analysis] = worklist.front(); |
102 | worklist.pop(); |
103 | |
104 | DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName |
105 | << "' on: " << point << "\n" ); |
106 | if (failed(result: analysis->visit(point))) |
107 | return failure(); |
108 | } |
109 | |
110 | // Iterate until all states are in some initialized state and the worklist |
111 | // is exhausted. |
112 | } while (!worklist.empty()); |
113 | |
114 | return success(); |
115 | } |
116 | |
117 | void DataFlowSolver::propagateIfChanged(AnalysisState *state, |
118 | ChangeResult changed) { |
119 | if (changed == ChangeResult::Change) { |
120 | DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName |
121 | << " of " << state->point << "\n" |
122 | << "Value: " << *state << "\n" ); |
123 | state->onUpdate(solver: this); |
124 | } |
125 | } |
126 | |
127 | //===----------------------------------------------------------------------===// |
128 | // DataFlowAnalysis |
129 | //===----------------------------------------------------------------------===// |
130 | |
// Defaulted destructor, defined out-of-line in this .cpp file rather than in
// the header (presumably so this translation unit anchors the vtable).
DataFlowAnalysis::~DataFlowAnalysis() = default;
132 | |
133 | DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {} |
134 | |
135 | void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) { |
136 | state->addDependency(dependent: point, analysis: this); |
137 | } |
138 | |
139 | void DataFlowAnalysis::propagateIfChanged(AnalysisState *state, |
140 | ChangeResult changed) { |
141 | solver.propagateIfChanged(state, changed); |
142 | } |
143 | |