1 | //===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/DataFlowFramework.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/Pass/Pass.h" |
12 | #include <optional> |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace { |
17 | /// This analysis state represents an integer that is XOR'd with other states. |
18 | class FooState : public AnalysisState { |
19 | public: |
20 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState) |
21 | |
22 | using AnalysisState::AnalysisState; |
23 | |
24 | /// Returns true if the state is uninitialized. |
25 | bool isUninitialized() const { return !state; } |
26 | |
27 | /// Print the integer value or "none" if uninitialized. |
28 | void print(raw_ostream &os) const override { |
29 | if (state) |
30 | os << *state; |
31 | else |
32 | os << "none" ; |
33 | } |
34 | |
35 | /// Join the state with another. If either is unintialized, take the |
36 | /// initialized value. Otherwise, XOR the integer values. |
37 | ChangeResult join(const FooState &rhs) { |
38 | if (rhs.isUninitialized()) |
39 | return ChangeResult::NoChange; |
40 | return join(value: *rhs.state); |
41 | } |
42 | ChangeResult join(uint64_t value) { |
43 | if (isUninitialized()) { |
44 | state = value; |
45 | return ChangeResult::Change; |
46 | } |
47 | uint64_t before = *state; |
48 | state = before ^ value; |
49 | return before == *state ? ChangeResult::NoChange : ChangeResult::Change; |
50 | } |
51 | |
52 | /// Set the value of the state directly. |
53 | ChangeResult set(const FooState &rhs) { |
54 | if (state == rhs.state) |
55 | return ChangeResult::NoChange; |
56 | state = rhs.state; |
57 | return ChangeResult::Change; |
58 | } |
59 | |
60 | /// Returns the integer value of the state. |
61 | uint64_t getValue() const { return *state; } |
62 | |
63 | private: |
64 | /// An optional integer value. |
65 | std::optional<uint64_t> state; |
66 | }; |
67 | |
68 | /// This analysis computes `FooState` across operations and control-flow edges. |
69 | /// If an op specifies a `foo` integer attribute, the contained value is XOR'd |
70 | /// with the value before the operation. |
71 | class FooAnalysis : public DataFlowAnalysis { |
72 | public: |
73 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis) |
74 | |
75 | using DataFlowAnalysis::DataFlowAnalysis; |
76 | |
77 | LogicalResult initialize(Operation *top) override; |
78 | LogicalResult visit(ProgramPoint point) override; |
79 | |
80 | private: |
81 | void visitBlock(Block *block); |
82 | void visitOperation(Operation *op); |
83 | }; |
84 | |
85 | struct TestFooAnalysisPass |
86 | : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> { |
87 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass) |
88 | |
89 | StringRef getArgument() const override { return "test-foo-analysis" ; } |
90 | |
91 | void runOnOperation() override; |
92 | }; |
93 | } // namespace |
94 | |
95 | LogicalResult FooAnalysis::initialize(Operation *top) { |
96 | if (top->getNumRegions() != 1) |
97 | return top->emitError(message: "expected a single region top-level op" ); |
98 | |
99 | if (top->getRegion(index: 0).getBlocks().empty()) |
100 | return top->emitError(message: "expected at least one block in the region" ); |
101 | |
102 | // Initialize the top-level state. |
103 | (void)getOrCreate<FooState>(point: &top->getRegion(index: 0).front())->join(value: 0); |
104 | |
105 | // Visit all nested blocks and operations. |
106 | for (Block &block : top->getRegion(index: 0)) { |
107 | visitBlock(block: &block); |
108 | for (Operation &op : block) { |
109 | if (op.getNumRegions()) |
110 | return op.emitError(message: "unexpected op with regions" ); |
111 | visitOperation(op: &op); |
112 | } |
113 | } |
114 | return success(); |
115 | } |
116 | |
117 | LogicalResult FooAnalysis::visit(ProgramPoint point) { |
118 | if (auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: point)) { |
119 | visitOperation(op); |
120 | return success(); |
121 | } |
122 | if (auto *block = llvm::dyn_cast_if_present<Block *>(Val&: point)) { |
123 | visitBlock(block); |
124 | return success(); |
125 | } |
126 | return emitError(loc: point.getLoc(), message: "unknown point kind" ); |
127 | } |
128 | |
129 | void FooAnalysis::visitBlock(Block *block) { |
130 | if (block->isEntryBlock()) { |
131 | // This is the initial state. Let the framework default-initialize it. |
132 | return; |
133 | } |
134 | FooState *state = getOrCreate<FooState>(point: block); |
135 | ChangeResult result = ChangeResult::NoChange; |
136 | for (Block *pred : block->getPredecessors()) { |
137 | // Join the state at the terminators of all predecessors. |
138 | const FooState *predState = |
139 | getOrCreateFor<FooState>(dependent: block, point: pred->getTerminator()); |
140 | result |= state->join(rhs: *predState); |
141 | } |
142 | propagateIfChanged(state, changed: result); |
143 | } |
144 | |
145 | void FooAnalysis::visitOperation(Operation *op) { |
146 | FooState *state = getOrCreate<FooState>(point: op); |
147 | ChangeResult result = ChangeResult::NoChange; |
148 | |
149 | // Copy the state across the operation. |
150 | const FooState *prevState; |
151 | if (Operation *prev = op->getPrevNode()) |
152 | prevState = getOrCreateFor<FooState>(dependent: op, point: prev); |
153 | else |
154 | prevState = getOrCreateFor<FooState>(dependent: op, point: op->getBlock()); |
155 | result |= state->set(*prevState); |
156 | |
157 | // Modify the state with the attribute, if specified. |
158 | if (auto attr = op->getAttrOfType<IntegerAttr>("foo" )) { |
159 | uint64_t value = attr.getUInt(); |
160 | result |= state->join(value); |
161 | } |
162 | propagateIfChanged(state, changed: result); |
163 | } |
164 | |
165 | void TestFooAnalysisPass::runOnOperation() { |
166 | func::FuncOp func = getOperation(); |
167 | DataFlowSolver solver; |
168 | solver.load<FooAnalysis>(); |
169 | if (failed(solver.initializeAndRun(top: func))) |
170 | return signalPassFailure(); |
171 | |
172 | raw_ostream &os = llvm::errs(); |
173 | os << "function: @" << func.getSymName() << "\n" ; |
174 | |
175 | func.walk([&](Operation *op) { |
176 | auto tag = op->getAttrOfType<StringAttr>("tag" ); |
177 | if (!tag) |
178 | return; |
179 | const FooState *state = solver.lookupState<FooState>(point: op); |
180 | assert(state && !state->isUninitialized()); |
181 | os << tag.getValue() << " -> " << state->getValue() << "\n" ; |
182 | }); |
183 | } |
184 | |
185 | namespace mlir { |
186 | namespace test { |
187 | void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); } |
188 | } // namespace test |
189 | } // namespace mlir |
190 | |