//===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlowFramework.h"
10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/Pass/Pass.h"
12#include <optional>
13
14using namespace mlir;
15
16namespace {
17/// This analysis state represents an integer that is XOR'd with other states.
18class FooState : public AnalysisState {
19public:
20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
21
22 using AnalysisState::AnalysisState;
23
24 /// Returns true if the state is uninitialized.
25 bool isUninitialized() const { return !state; }
26
27 /// Print the integer value or "none" if uninitialized.
28 void print(raw_ostream &os) const override {
29 if (state)
30 os << *state;
31 else
32 os << "none";
33 }
34
35 /// Join the state with another. If either is unintialized, take the
36 /// initialized value. Otherwise, XOR the integer values.
37 ChangeResult join(const FooState &rhs) {
38 if (rhs.isUninitialized())
39 return ChangeResult::NoChange;
40 return join(value: *rhs.state);
41 }
42 ChangeResult join(uint64_t value) {
43 if (isUninitialized()) {
44 state = value;
45 return ChangeResult::Change;
46 }
47 uint64_t before = *state;
48 state = before ^ value;
49 return before == *state ? ChangeResult::NoChange : ChangeResult::Change;
50 }
51
52 /// Set the value of the state directly.
53 ChangeResult set(const FooState &rhs) {
54 if (state == rhs.state)
55 return ChangeResult::NoChange;
56 state = rhs.state;
57 return ChangeResult::Change;
58 }
59
60 /// Returns the integer value of the state.
61 uint64_t getValue() const { return *state; }
62
63private:
64 /// An optional integer value.
65 std::optional<uint64_t> state;
66};
67
68/// This analysis computes `FooState` across operations and control-flow edges.
69/// If an op specifies a `foo` integer attribute, the contained value is XOR'd
70/// with the value before the operation.
71class FooAnalysis : public DataFlowAnalysis {
72public:
73 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis)
74
75 using DataFlowAnalysis::DataFlowAnalysis;
76
77 LogicalResult initialize(Operation *top) override;
78 LogicalResult visit(ProgramPoint point) override;
79
80private:
81 void visitBlock(Block *block);
82 void visitOperation(Operation *op);
83};
84
85struct TestFooAnalysisPass
86 : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
87 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
88
89 StringRef getArgument() const override { return "test-foo-analysis"; }
90
91 void runOnOperation() override;
92};
93} // namespace
94
95LogicalResult FooAnalysis::initialize(Operation *top) {
96 if (top->getNumRegions() != 1)
97 return top->emitError(message: "expected a single region top-level op");
98
99 if (top->getRegion(index: 0).getBlocks().empty())
100 return top->emitError(message: "expected at least one block in the region");
101
102 // Initialize the top-level state.
103 (void)getOrCreate<FooState>(point: &top->getRegion(index: 0).front())->join(value: 0);
104
105 // Visit all nested blocks and operations.
106 for (Block &block : top->getRegion(index: 0)) {
107 visitBlock(block: &block);
108 for (Operation &op : block) {
109 if (op.getNumRegions())
110 return op.emitError(message: "unexpected op with regions");
111 visitOperation(op: &op);
112 }
113 }
114 return success();
115}
116
117LogicalResult FooAnalysis::visit(ProgramPoint point) {
118 if (auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: point)) {
119 visitOperation(op);
120 return success();
121 }
122 if (auto *block = llvm::dyn_cast_if_present<Block *>(Val&: point)) {
123 visitBlock(block);
124 return success();
125 }
126 return emitError(loc: point.getLoc(), message: "unknown point kind");
127}
128
129void FooAnalysis::visitBlock(Block *block) {
130 if (block->isEntryBlock()) {
131 // This is the initial state. Let the framework default-initialize it.
132 return;
133 }
134 FooState *state = getOrCreate<FooState>(point: block);
135 ChangeResult result = ChangeResult::NoChange;
136 for (Block *pred : block->getPredecessors()) {
137 // Join the state at the terminators of all predecessors.
138 const FooState *predState =
139 getOrCreateFor<FooState>(dependent: block, point: pred->getTerminator());
140 result |= state->join(rhs: *predState);
141 }
142 propagateIfChanged(state, changed: result);
143}
144
145void FooAnalysis::visitOperation(Operation *op) {
146 FooState *state = getOrCreate<FooState>(point: op);
147 ChangeResult result = ChangeResult::NoChange;
148
149 // Copy the state across the operation.
150 const FooState *prevState;
151 if (Operation *prev = op->getPrevNode())
152 prevState = getOrCreateFor<FooState>(dependent: op, point: prev);
153 else
154 prevState = getOrCreateFor<FooState>(dependent: op, point: op->getBlock());
155 result |= state->set(*prevState);
156
157 // Modify the state with the attribute, if specified.
158 if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
159 uint64_t value = attr.getUInt();
160 result |= state->join(value);
161 }
162 propagateIfChanged(state, changed: result);
163}
164
165void TestFooAnalysisPass::runOnOperation() {
166 func::FuncOp func = getOperation();
167 DataFlowSolver solver;
168 solver.load<FooAnalysis>();
169 if (failed(solver.initializeAndRun(top: func)))
170 return signalPassFailure();
171
172 raw_ostream &os = llvm::errs();
173 os << "function: @" << func.getSymName() << "\n";
174
175 func.walk([&](Operation *op) {
176 auto tag = op->getAttrOfType<StringAttr>("tag");
177 if (!tag)
178 return;
179 const FooState *state = solver.lookupState<FooState>(point: op);
180 assert(state && !state->isUninitialized());
181 os << tag.getValue() << " -> " << state->getValue() << "\n";
182 });
183}
184
185namespace mlir {
186namespace test {
187void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
188} // namespace test
189} // namespace mlir
190

// source code of mlir/test/lib/Analysis/TestDataFlowFramework.cpp