1//===- TestDenseDataFlowAnalysis.h - Dataflow test utilities ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10#include "mlir/Analysis/DataFlowFramework.h"
11#include "mlir/IR/Value.h"
12#include "llvm/ADT/DenseMap.h"
13#include "llvm/Support/raw_ostream.h"
14#include <optional>
15
16namespace mlir {
17namespace dataflow {
18namespace test {
19
20/// This lattice represents a single underlying value for an SSA value.
21class UnderlyingValue {
22public:
23 /// Create an underlying value state with a known underlying value.
24 explicit UnderlyingValue(std::optional<Value> underlyingValue = std::nullopt)
25 : underlyingValue(underlyingValue) {}
26
27 /// Whether the state is uninitialized.
28 bool isUninitialized() const { return !underlyingValue.has_value(); }
29
30 /// Returns the underlying value.
31 Value getUnderlyingValue() const {
32 assert(!isUninitialized());
33 return *underlyingValue;
34 }
35
36 /// Join two underlying values. If there are conflicting underlying values,
37 /// go to the pessimistic value.
38 static UnderlyingValue join(const UnderlyingValue &lhs,
39 const UnderlyingValue &rhs) {
40 if (lhs.isUninitialized())
41 return rhs;
42 if (rhs.isUninitialized())
43 return lhs;
44 return lhs.underlyingValue == rhs.underlyingValue
45 ? lhs
46 : UnderlyingValue(Value{});
47 }
48
49 /// Compare underlying values.
50 bool operator==(const UnderlyingValue &rhs) const {
51 return underlyingValue == rhs.underlyingValue;
52 }
53
54 void print(raw_ostream &os) const { os << underlyingValue; }
55
56private:
57 std::optional<Value> underlyingValue;
58};
59
60class AdjacentAccess {
61public:
62 using DeterministicSetVector =
63 SetVector<Operation *, SmallVector<Operation *, 2>,
64 SmallPtrSet<Operation *, 2>>;
65
66 ArrayRef<Operation *> get() const { return accesses.getArrayRef(); }
67 bool isKnown() const { return !unknown; }
68
69 ChangeResult merge(const AdjacentAccess &other) {
70 if (unknown)
71 return ChangeResult::NoChange;
72 if (other.unknown) {
73 unknown = true;
74 accesses.clear();
75 return ChangeResult::Change;
76 }
77
78 size_t sizeBefore = accesses.size();
79 accesses.insert(Start: other.accesses.begin(), End: other.accesses.end());
80 return accesses.size() == sizeBefore ? ChangeResult::NoChange
81 : ChangeResult::Change;
82 }
83
84 ChangeResult set(Operation *op) {
85 if (!unknown && accesses.size() == 1 && *accesses.begin() == op)
86 return ChangeResult::NoChange;
87
88 unknown = false;
89 accesses.clear();
90 accesses.insert(X: op);
91 return ChangeResult::Change;
92 }
93
94 ChangeResult setUnknown() {
95 if (unknown)
96 return ChangeResult::NoChange;
97
98 accesses.clear();
99 unknown = true;
100 return ChangeResult::Change;
101 }
102
103 bool operator==(const AdjacentAccess &other) const {
104 return unknown == other.unknown && accesses == other.accesses;
105 }
106
107 bool operator!=(const AdjacentAccess &other) const {
108 return !operator==(other);
109 }
110
111private:
112 bool unknown = false;
113 DeterministicSetVector accesses;
114};
115
116/// This lattice represents, for a given memory resource, the potential last
117/// operations that modified the resource.
118class AccessLatticeBase {
119public:
120 /// Clear all modifications.
121 ChangeResult reset() {
122 if (adjAccesses.empty())
123 return ChangeResult::NoChange;
124 adjAccesses.clear();
125 return ChangeResult::Change;
126 }
127
128 /// Join the last modifications.
129 ChangeResult merge(const AccessLatticeBase &rhs) {
130 ChangeResult result = ChangeResult::NoChange;
131 for (const auto &mod : rhs.adjAccesses) {
132 AdjacentAccess &lhsMod = adjAccesses[mod.first];
133 result |= lhsMod.merge(other: mod.second);
134 }
135 return result;
136 }
137
138 /// Set the last modification of a value.
139 ChangeResult set(Value value, Operation *op) {
140 AdjacentAccess &lastMod = adjAccesses[value];
141 return lastMod.set(op);
142 }
143
144 ChangeResult setKnownToUnknown() {
145 ChangeResult result = ChangeResult::NoChange;
146 for (auto &[value, adjacent] : adjAccesses)
147 result |= adjacent.setUnknown();
148 return result;
149 }
150
151 /// Get the adjacent accesses to a value. Returns std::nullopt if they
152 /// are not known.
153 const AdjacentAccess *getAdjacentAccess(Value value) const {
154 auto it = adjAccesses.find(Val: value);
155 if (it == adjAccesses.end())
156 return nullptr;
157 return &it->getSecond();
158 }
159
160 void print(raw_ostream &os) const {
161 for (const auto &lastMod : adjAccesses) {
162 os << lastMod.first << ":\n";
163 if (!lastMod.second.isKnown()) {
164 os << " <unknown>\n";
165 return;
166 }
167 for (Operation *op : lastMod.second.get())
168 os << " " << *op << "\n";
169 }
170 }
171
172private:
173 /// The potential adjacent accesses to a memory resource. Use a set vector to
174 /// keep the results deterministic.
175 DenseMap<Value, AdjacentAccess> adjAccesses;
176};
177
178/// Define the lattice class explicitly to provide a type ID.
179struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
180 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
181 using Lattice::Lattice;
182};
183
184/// An analysis that uses forwarding of values along control-flow and callgraph
185/// edges to determine single underlying values for block arguments. This
186/// analysis exists so that the test analysis and pass can test the behaviour of
187/// the dense data-flow analysis on the callgraph.
188class UnderlyingValueAnalysis
189 : public SparseForwardDataFlowAnalysis<UnderlyingValueLattice> {
190public:
191 using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
192
193 /// The underlying value of the results of an operation are not known.
194 void visitOperation(Operation *op,
195 ArrayRef<const UnderlyingValueLattice *> operands,
196 ArrayRef<UnderlyingValueLattice *> results) override {
197 setAllToEntryStates(results);
198 }
199
200 /// At an entry point, the underlying value of a value is itself.
201 void setToEntryState(UnderlyingValueLattice *lattice) override {
202 propagateIfChanged(lattice,
203 lattice->join(rhs: UnderlyingValue{lattice->getPoint()}));
204 }
205
206 /// Look for the most underlying value of a value.
207 static std::optional<Value>
208 getMostUnderlyingValue(Value value,
209 function_ref<const UnderlyingValueLattice *(Value)>
210 getUnderlyingValueFn) {
211 const UnderlyingValueLattice *underlying;
212 do {
213 underlying = getUnderlyingValueFn(value);
214 if (!underlying || underlying->getValue().isUninitialized())
215 return std::nullopt;
216 Value underlyingValue = underlying->getValue().getUnderlyingValue();
217 if (underlyingValue == value)
218 break;
219 value = underlyingValue;
220 } while (true);
221 return value;
222 }
223};
224
225} // namespace test
226} // namespace dataflow
227} // namespace mlir
228

source code of mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h