1 | //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/DataFlowFramework.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/Pass/Pass.h" |
12 | #include <optional> |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace { |
17 | /// This analysis state represents an integer that is XOR'd with other states. |
18 | class FooState : public AnalysisState { |
19 | public: |
20 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState) |
21 | |
22 | using AnalysisState::AnalysisState; |
23 | |
24 | /// Returns true if the state is uninitialized. |
25 | bool isUninitialized() const { return !state; } |
26 | |
27 | /// Print the integer value or "none" if uninitialized. |
28 | void print(raw_ostream &os) const override { |
29 | if (state) |
30 | os << *state; |
31 | else |
32 | os << "none" ; |
33 | } |
34 | |
35 | /// Join the state with another. If either is unintialized, take the |
36 | /// initialized value. Otherwise, XOR the integer values. |
37 | ChangeResult join(const FooState &rhs) { |
38 | if (rhs.isUninitialized()) |
39 | return ChangeResult::NoChange; |
40 | return join(value: *rhs.state); |
41 | } |
42 | ChangeResult join(uint64_t value) { |
43 | if (isUninitialized()) { |
44 | state = value; |
45 | return ChangeResult::Change; |
46 | } |
47 | uint64_t before = *state; |
48 | state = before ^ value; |
49 | return before == *state ? ChangeResult::NoChange : ChangeResult::Change; |
50 | } |
51 | |
52 | /// Set the value of the state directly. |
53 | ChangeResult set(const FooState &rhs) { |
54 | if (state == rhs.state) |
55 | return ChangeResult::NoChange; |
56 | state = rhs.state; |
57 | return ChangeResult::Change; |
58 | } |
59 | |
60 | /// Returns the integer value of the state. |
61 | uint64_t getValue() const { return *state; } |
62 | |
63 | private: |
64 | /// An optional integer value. |
65 | std::optional<uint64_t> state; |
66 | }; |
67 | |
68 | /// This analysis computes `FooState` across operations and control-flow edges. |
69 | /// If an op specifies a `foo` integer attribute, the contained value is XOR'd |
70 | /// with the value before the operation. |
71 | class FooAnalysis : public DataFlowAnalysis { |
72 | public: |
73 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis) |
74 | |
75 | using DataFlowAnalysis::DataFlowAnalysis; |
76 | |
77 | LogicalResult initialize(Operation *top) override; |
78 | LogicalResult visit(ProgramPoint *point) override; |
79 | |
80 | private: |
81 | void visitBlock(Block *block); |
82 | void visitOperation(Operation *op); |
83 | }; |
84 | |
85 | struct TestFooAnalysisPass |
86 | : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> { |
87 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass) |
88 | |
89 | StringRef getArgument() const override { return "test-foo-analysis" ; } |
90 | |
91 | void runOnOperation() override; |
92 | }; |
93 | } // namespace |
94 | |
95 | LogicalResult FooAnalysis::initialize(Operation *top) { |
96 | if (top->getNumRegions() != 1) |
97 | return top->emitError(message: "expected a single region top-level op" ); |
98 | |
99 | if (top->getRegion(index: 0).getBlocks().empty()) |
100 | return top->emitError(message: "expected at least one block in the region" ); |
101 | |
102 | // Initialize the top-level state. |
103 | (void)getOrCreate<FooState>(anchor: getProgramPointBefore(block: &top->getRegion(index: 0).front())) |
104 | ->join(value: 0); |
105 | |
106 | // Visit all nested blocks and operations. |
107 | for (Block &block : top->getRegion(index: 0)) { |
108 | visitBlock(block: &block); |
109 | for (Operation &op : block) { |
110 | if (op.getNumRegions()) |
111 | return op.emitError(message: "unexpected op with regions" ); |
112 | visitOperation(op: &op); |
113 | } |
114 | } |
115 | return success(); |
116 | } |
117 | |
118 | LogicalResult FooAnalysis::visit(ProgramPoint *point) { |
119 | if (!point->isBlockStart()) |
120 | visitOperation(op: point->getPrevOp()); |
121 | else |
122 | visitBlock(block: point->getBlock()); |
123 | return success(); |
124 | } |
125 | |
126 | void FooAnalysis::visitBlock(Block *block) { |
127 | if (block->isEntryBlock()) { |
128 | // This is the initial state. Let the framework default-initialize it. |
129 | return; |
130 | } |
131 | ProgramPoint *point = getProgramPointBefore(block); |
132 | FooState *state = getOrCreate<FooState>(anchor: point); |
133 | ChangeResult result = ChangeResult::NoChange; |
134 | for (Block *pred : block->getPredecessors()) { |
135 | // Join the state at the terminators of all predecessors. |
136 | const FooState *predState = getOrCreateFor<FooState>( |
137 | dependent: point, anchor: getProgramPointAfter(op: pred->getTerminator())); |
138 | result |= state->join(rhs: *predState); |
139 | } |
140 | propagateIfChanged(state, changed: result); |
141 | } |
142 | |
143 | void FooAnalysis::visitOperation(Operation *op) { |
144 | ProgramPoint *point = getProgramPointAfter(op); |
145 | FooState *state = getOrCreate<FooState>(anchor: point); |
146 | ChangeResult result = ChangeResult::NoChange; |
147 | |
148 | // Copy the state across the operation. |
149 | const FooState *prevState; |
150 | prevState = getOrCreateFor<FooState>(dependent: point, anchor: getProgramPointBefore(op)); |
151 | result |= state->set(*prevState); |
152 | |
153 | // Modify the state with the attribute, if specified. |
154 | if (auto attr = op->getAttrOfType<IntegerAttr>("foo" )) { |
155 | uint64_t value = attr.getUInt(); |
156 | result |= state->join(value); |
157 | } |
158 | propagateIfChanged(state, changed: result); |
159 | } |
160 | |
161 | void TestFooAnalysisPass::runOnOperation() { |
162 | func::FuncOp func = getOperation(); |
163 | DataFlowSolver solver; |
164 | solver.load<FooAnalysis>(); |
165 | if (failed(solver.initializeAndRun(top: func))) |
166 | return signalPassFailure(); |
167 | |
168 | raw_ostream &os = llvm::errs(); |
169 | os << "function: @" << func.getSymName() << "\n" ; |
170 | |
171 | func.walk([&](Operation *op) { |
172 | auto tag = op->getAttrOfType<StringAttr>("tag" ); |
173 | if (!tag) |
174 | return; |
175 | const FooState *state = |
176 | solver.lookupState<FooState>(anchor: solver.getProgramPointAfter(op)); |
177 | assert(state && !state->isUninitialized()); |
178 | os << tag.getValue() << " -> " << state->getValue() << "\n" ; |
179 | }); |
180 | } |
181 | |
182 | namespace mlir { |
183 | namespace test { |
184 | void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); } |
185 | } // namespace test |
186 | } // namespace mlir |
187 | |