//===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlowFramework.h"
10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/Pass/Pass.h"
12#include <optional>
13
14using namespace mlir;
15
16namespace {
17/// This analysis state represents an integer that is XOR'd with other states.
18class FooState : public AnalysisState {
19public:
20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
21
22 using AnalysisState::AnalysisState;
23
24 /// Returns true if the state is uninitialized.
25 bool isUninitialized() const { return !state; }
26
27 /// Print the integer value or "none" if uninitialized.
28 void print(raw_ostream &os) const override {
29 if (state)
30 os << *state;
31 else
32 os << "none";
33 }
34
35 /// Join the state with another. If either is unintialized, take the
36 /// initialized value. Otherwise, XOR the integer values.
37 ChangeResult join(const FooState &rhs) {
38 if (rhs.isUninitialized())
39 return ChangeResult::NoChange;
40 return join(value: *rhs.state);
41 }
42 ChangeResult join(uint64_t value) {
43 if (isUninitialized()) {
44 state = value;
45 return ChangeResult::Change;
46 }
47 uint64_t before = *state;
48 state = before ^ value;
49 return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
50 }
51
52 /// Set the value of the state directly.
53 ChangeResult set(const FooState &rhs) {
54 if (state == rhs.state)
55 return ChangeResult::NoChange;
56 state = rhs.state;
57 return ChangeResult::Change;
58 }
59
60 /// Returns the integer value of the state.
61 uint64_t getValue() const { return *state; }
62
63private:
64 /// An optional integer value.
65 std::optional<uint64_t> state;
66};
67
68/// This analysis computes `FooState` across operations and control-flow edges.
69/// If an op specifies a `foo` integer attribute, the contained value is XOR'd
70/// with the value before the operation.
71class FooAnalysis : public DataFlowAnalysis {
72public:
73 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
74
75 using DataFlowAnalysis::DataFlowAnalysis;
76
77 LogicalResult initialize(Operation *top) override;
78 LogicalResult visit(ProgramPoint *point) override;
79
80private:
81 void visitBlock(Block *block);
82 void visitOperation(Operation *op);
83};
84
85struct TestFooAnalysisPass
86 : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
87 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
88
89 StringRef getArgument() const override { return "test-foo-analysis"; }
90
91 void runOnOperation() override;
92};
93} // namespace
94
95LogicalResult FooAnalysis::initialize(Operation *top) {
96 if (top->getNumRegions() != 1)
97 return top->emitError(message: "expected a single region top-level op");
98
99 if (top->getRegion(index: 0).getBlocks().empty())
100 return top->emitError(message: "expected at least one block in the region");
101
102 // Initialize the top-level state.
103 (void)getOrCreate<FooState>(anchor: getProgramPointBefore(block: &top->getRegion(index: 0).front()))
104 ->join(value: 0);
105
106 // Visit all nested blocks and operations.
107 for (Block &block : top->getRegion(index: 0)) {
108 visitBlock(block: &block);
109 for (Operation &op : block) {
110 if (op.getNumRegions())
111 return op.emitError(message: "unexpected op with regions");
112 visitOperation(op: &op);
113 }
114 }
115 return success();
116}
117
118LogicalResult FooAnalysis::visit(ProgramPoint *point) {
119 if (!point->isBlockStart())
120 visitOperation(op: point->getPrevOp());
121 else
122 visitBlock(block: point->getBlock());
123 return success();
124}
125
126void FooAnalysis::visitBlock(Block *block) {
127 if (block->isEntryBlock()) {
128 // This is the initial state. Let the framework default-initialize it.
129 return;
130 }
131 ProgramPoint *point = getProgramPointBefore(block);
132 FooState *state = getOrCreate<FooState>(anchor: point);
133 ChangeResult result = ChangeResult::NoChange;
134 for (Block *pred : block->getPredecessors()) {
135 // Join the state at the terminators of all predecessors.
136 const FooState *predState = getOrCreateFor<FooState>(
137 dependent: point, anchor: getProgramPointAfter(op: pred->getTerminator()));
138 result |= state->join(rhs: *predState);
139 }
140 propagateIfChanged(state, changed: result);
141}
142
143void FooAnalysis::visitOperation(Operation *op) {
144 ProgramPoint *point = getProgramPointAfter(op);
145 FooState *state = getOrCreate<FooState>(anchor: point);
146 ChangeResult result = ChangeResult::NoChange;
147
148 // Copy the state across the operation.
149 const FooState *prevState;
150 prevState = getOrCreateFor<FooState>(dependent: point, anchor: getProgramPointBefore(op));
151 result |= state->set(*prevState);
152
153 // Modify the state with the attribute, if specified.
154 if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
155 uint64_t value = attr.getUInt();
156 result |= state->join(value);
157 }
158 propagateIfChanged(state, changed: result);
159}
160
161void TestFooAnalysisPass::runOnOperation() {
162 func::FuncOp func = getOperation();
163 DataFlowSolver solver;
164 solver.load<FooAnalysis>();
165 if (failed(solver.initializeAndRun(top: func)))
166 return signalPassFailure();
167
168 raw_ostream &os = llvm::errs();
169 os << "function: @" << func.getSymName() << "\n";
170
171 func.walk([&](Operation *op) {
172 auto tag = op->getAttrOfType<StringAttr>("tag");
173 if (!tag)
174 return;
175 const FooState *state =
176 solver.lookupState<FooState>(anchor: solver.getProgramPointAfter(op));
177 assert(state && !state->isUninitialized());
178 os << tag.getValue() << " -> " << state->getValue() << "\n";
179 });
180}
181
182namespace mlir {
183namespace test {
184void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
185} // namespace test
186} // namespace mlir
187

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Analysis/TestDataFlowFramework.cpp