1 | //===- TestDenseDataFlowAnalysis.h - Dataflow test utilities ----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
10 | #include "mlir/Analysis/DataFlowFramework.h" |
11 | #include "mlir/IR/Value.h" |
12 | #include "llvm/ADT/DenseMap.h" |
13 | #include "llvm/Support/raw_ostream.h" |
14 | #include <optional> |
15 | |
16 | namespace mlir { |
17 | namespace dataflow { |
18 | namespace test { |
19 | |
20 | /// This lattice represents a single underlying value for an SSA value. |
21 | class UnderlyingValue { |
22 | public: |
23 | /// Create an underlying value state with a known underlying value. |
24 | explicit UnderlyingValue(std::optional<Value> underlyingValue = std::nullopt) |
25 | : underlyingValue(underlyingValue) {} |
26 | |
27 | /// Whether the state is uninitialized. |
28 | bool isUninitialized() const { return !underlyingValue.has_value(); } |
29 | |
30 | /// Returns the underlying value. |
31 | Value getUnderlyingValue() const { |
32 | assert(!isUninitialized()); |
33 | return *underlyingValue; |
34 | } |
35 | |
36 | /// Join two underlying values. If there are conflicting underlying values, |
37 | /// go to the pessimistic value. |
38 | static UnderlyingValue join(const UnderlyingValue &lhs, |
39 | const UnderlyingValue &rhs) { |
40 | if (lhs.isUninitialized()) |
41 | return rhs; |
42 | if (rhs.isUninitialized()) |
43 | return lhs; |
44 | return lhs.underlyingValue == rhs.underlyingValue |
45 | ? lhs |
46 | : UnderlyingValue(Value{}); |
47 | } |
48 | |
49 | /// Compare underlying values. |
50 | bool operator==(const UnderlyingValue &rhs) const { |
51 | return underlyingValue == rhs.underlyingValue; |
52 | } |
53 | |
54 | void print(raw_ostream &os) const { os << underlyingValue; } |
55 | |
56 | private: |
57 | std::optional<Value> underlyingValue; |
58 | }; |
59 | |
60 | class AdjacentAccess { |
61 | public: |
62 | using DeterministicSetVector = |
63 | SetVector<Operation *, SmallVector<Operation *, 2>, |
64 | SmallPtrSet<Operation *, 2>>; |
65 | |
66 | ArrayRef<Operation *> get() const { return accesses.getArrayRef(); } |
67 | bool isKnown() const { return !unknown; } |
68 | |
69 | ChangeResult merge(const AdjacentAccess &other) { |
70 | if (unknown) |
71 | return ChangeResult::NoChange; |
72 | if (other.unknown) { |
73 | unknown = true; |
74 | accesses.clear(); |
75 | return ChangeResult::Change; |
76 | } |
77 | |
78 | size_t sizeBefore = accesses.size(); |
79 | accesses.insert(Start: other.accesses.begin(), End: other.accesses.end()); |
80 | return accesses.size() == sizeBefore ? ChangeResult::NoChange |
81 | : ChangeResult::Change; |
82 | } |
83 | |
84 | ChangeResult set(Operation *op) { |
85 | if (!unknown && accesses.size() == 1 && *accesses.begin() == op) |
86 | return ChangeResult::NoChange; |
87 | |
88 | unknown = false; |
89 | accesses.clear(); |
90 | accesses.insert(X: op); |
91 | return ChangeResult::Change; |
92 | } |
93 | |
94 | ChangeResult setUnknown() { |
95 | if (unknown) |
96 | return ChangeResult::NoChange; |
97 | |
98 | accesses.clear(); |
99 | unknown = true; |
100 | return ChangeResult::Change; |
101 | } |
102 | |
103 | bool operator==(const AdjacentAccess &other) const { |
104 | return unknown == other.unknown && accesses == other.accesses; |
105 | } |
106 | |
107 | bool operator!=(const AdjacentAccess &other) const { |
108 | return !operator==(other); |
109 | } |
110 | |
111 | private: |
112 | bool unknown = false; |
113 | DeterministicSetVector accesses; |
114 | }; |
115 | |
116 | /// This lattice represents, for a given memory resource, the potential last |
117 | /// operations that modified the resource. |
118 | class AccessLatticeBase { |
119 | public: |
120 | /// Clear all modifications. |
121 | ChangeResult reset() { |
122 | if (adjAccesses.empty()) |
123 | return ChangeResult::NoChange; |
124 | adjAccesses.clear(); |
125 | return ChangeResult::Change; |
126 | } |
127 | |
128 | /// Join the last modifications. |
129 | ChangeResult merge(const AccessLatticeBase &rhs) { |
130 | ChangeResult result = ChangeResult::NoChange; |
131 | for (const auto &mod : rhs.adjAccesses) { |
132 | AdjacentAccess &lhsMod = adjAccesses[mod.first]; |
133 | result |= lhsMod.merge(other: mod.second); |
134 | } |
135 | return result; |
136 | } |
137 | |
138 | /// Set the last modification of a value. |
139 | ChangeResult set(Value value, Operation *op) { |
140 | AdjacentAccess &lastMod = adjAccesses[value]; |
141 | return lastMod.set(op); |
142 | } |
143 | |
144 | ChangeResult setKnownToUnknown() { |
145 | ChangeResult result = ChangeResult::NoChange; |
146 | for (auto &[value, adjacent] : adjAccesses) |
147 | result |= adjacent.setUnknown(); |
148 | return result; |
149 | } |
150 | |
151 | /// Get the adjacent accesses to a value. Returns std::nullopt if they |
152 | /// are not known. |
153 | const AdjacentAccess *getAdjacentAccess(Value value) const { |
154 | auto it = adjAccesses.find(Val: value); |
155 | if (it == adjAccesses.end()) |
156 | return nullptr; |
157 | return &it->getSecond(); |
158 | } |
159 | |
160 | void print(raw_ostream &os) const { |
161 | for (const auto &lastMod : adjAccesses) { |
162 | os << lastMod.first << ":\n" ; |
163 | if (!lastMod.second.isKnown()) { |
164 | os << " <unknown>\n" ; |
165 | return; |
166 | } |
167 | for (Operation *op : lastMod.second.get()) |
168 | os << " " << *op << "\n" ; |
169 | } |
170 | } |
171 | |
172 | private: |
173 | /// The potential adjacent accesses to a memory resource. Use a set vector to |
174 | /// keep the results deterministic. |
175 | DenseMap<Value, AdjacentAccess> adjAccesses; |
176 | }; |
177 | |
178 | /// Define the lattice class explicitly to provide a type ID. |
179 | struct UnderlyingValueLattice : public Lattice<UnderlyingValue> { |
180 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice) |
181 | using Lattice::Lattice; |
182 | }; |
183 | |
184 | /// An analysis that uses forwarding of values along control-flow and callgraph |
185 | /// edges to determine single underlying values for block arguments. This |
186 | /// analysis exists so that the test analysis and pass can test the behaviour of |
187 | /// the dense data-flow analysis on the callgraph. |
188 | class UnderlyingValueAnalysis |
189 | : public SparseForwardDataFlowAnalysis<UnderlyingValueLattice> { |
190 | public: |
191 | using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; |
192 | |
193 | /// The underlying value of the results of an operation are not known. |
194 | void visitOperation(Operation *op, |
195 | ArrayRef<const UnderlyingValueLattice *> operands, |
196 | ArrayRef<UnderlyingValueLattice *> results) override { |
197 | setAllToEntryStates(results); |
198 | } |
199 | |
200 | /// At an entry point, the underlying value of a value is itself. |
201 | void setToEntryState(UnderlyingValueLattice *lattice) override { |
202 | propagateIfChanged(lattice, |
203 | lattice->join(rhs: UnderlyingValue{lattice->getPoint()})); |
204 | } |
205 | |
206 | /// Look for the most underlying value of a value. |
207 | static std::optional<Value> |
208 | getMostUnderlyingValue(Value value, |
209 | function_ref<const UnderlyingValueLattice *(Value)> |
210 | getUnderlyingValueFn) { |
211 | const UnderlyingValueLattice *underlying; |
212 | do { |
213 | underlying = getUnderlyingValueFn(value); |
214 | if (!underlying || underlying->getValue().isUninitialized()) |
215 | return std::nullopt; |
216 | Value underlyingValue = underlying->getValue().getUnderlyingValue(); |
217 | if (underlyingValue == value) |
218 | break; |
219 | value = underlyingValue; |
220 | } while (true); |
221 | return value; |
222 | } |
223 | }; |
224 | |
225 | } // namespace test |
226 | } // namespace dataflow |
227 | } // namespace mlir |
228 | |