1 | //===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Implementation of test passes exercising dense forward data flow analysis. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "TestDenseDataFlowAnalysis.h" |
14 | #include "TestDialect.h" |
15 | #include "TestOps.h" |
16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
17 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
18 | #include "mlir/Analysis/DataFlow/DenseAnalysis.h" |
19 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include "mlir/Support/LLVM.h" |
22 | #include "llvm/ADT/TypeSwitch.h" |
23 | #include <optional> |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::dataflow; |
27 | using namespace mlir::dataflow::test; |
28 | |
29 | namespace { |
30 | |
31 | /// This lattice represents, for a given memory resource, the potential last |
32 | /// operations that modified the resource. |
33 | class LastModification : public AbstractDenseLattice, public AccessLatticeBase { |
34 | public: |
35 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification) |
36 | |
37 | using AbstractDenseLattice::AbstractDenseLattice; |
38 | |
39 | /// Join the last modifications. |
40 | ChangeResult join(const AbstractDenseLattice &lattice) override { |
41 | return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>( |
42 | static_cast<const LastModification &>(lattice))); |
43 | } |
44 | |
45 | void print(raw_ostream &os) const override { |
46 | return AccessLatticeBase::print(os); |
47 | } |
48 | }; |
49 | |
50 | class LastModifiedAnalysis |
51 | : public DenseForwardDataFlowAnalysis<LastModification> { |
52 | public: |
53 | explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites) |
54 | : DenseForwardDataFlowAnalysis(solver), |
55 | assumeFuncWrites(assumeFuncWrites) {} |
56 | |
57 | /// Visit an operation. If the operation has no memory effects, then the state |
58 | /// is propagated with no change. If the operation allocates a resource, then |
59 | /// its reaching definitions is set to empty. If the operation writes to a |
60 | /// resource, then its reaching definition is set to the written value. |
61 | LogicalResult visitOperation(Operation *op, const LastModification &before, |
62 | LastModification *after) override; |
63 | |
64 | void visitCallControlFlowTransfer(CallOpInterface call, |
65 | CallControlFlowAction action, |
66 | const LastModification &before, |
67 | LastModification *after) override; |
68 | |
69 | void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, |
70 | std::optional<unsigned> regionFrom, |
71 | std::optional<unsigned> regionTo, |
72 | const LastModification &before, |
73 | LastModification *after) override; |
74 | |
75 | /// Visit an operation. If this analysis can confirm that lattice content |
76 | /// of lattice anchors around operation are necessarily identical, join |
77 | /// them into the same equivalent class. |
78 | void buildOperationEquivalentLatticeAnchor(Operation *op) override; |
79 | |
80 | /// At an entry point, the last modifications of all memory resources are |
81 | /// unknown. |
82 | void setToEntryState(LastModification *lattice) override { |
83 | propagateIfChanged(state: lattice, changed: lattice->reset()); |
84 | } |
85 | |
86 | private: |
87 | const bool assumeFuncWrites; |
88 | }; |
89 | } // end anonymous namespace |
90 | |
91 | LogicalResult LastModifiedAnalysis::visitOperation( |
92 | Operation *op, const LastModification &before, LastModification *after) { |
93 | auto memory = dyn_cast<MemoryEffectOpInterface>(op); |
94 | // If we can't reason about the memory effects, then conservatively assume we |
95 | // can't deduce anything about the last modifications. |
96 | if (!memory) { |
97 | setToEntryState(after); |
98 | return success(); |
99 | } |
100 | |
101 | SmallVector<MemoryEffects::EffectInstance> effects; |
102 | memory.getEffects(effects); |
103 | |
104 | // First, check if all underlying values are already known. Otherwise, avoid |
105 | // propagating and stay in the "undefined" state to avoid incorrectly |
106 | // propagating values that may be overwritten later on as that could be |
107 | // problematic for convergence based on monotonicity of lattice updates. |
108 | SmallVector<Value> underlyingValues; |
109 | underlyingValues.reserve(N: effects.size()); |
110 | for (const auto &effect : effects) { |
111 | Value value = effect.getValue(); |
112 | |
113 | // If we see an effect on anything other than a value, assume we can't |
114 | // deduce anything about the last modifications. |
115 | if (!value) { |
116 | setToEntryState(after); |
117 | return success(); |
118 | } |
119 | |
120 | // If we cannot find the underlying value, we shouldn't just propagate the |
121 | // effects through, return the pessimistic state. |
122 | std::optional<Value> underlyingValue = |
123 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
124 | value, getUnderlyingValueFn: [&](Value value) { |
125 | return getOrCreateFor<UnderlyingValueLattice>( |
126 | dependent: getProgramPointAfter(op), anchor: value); |
127 | }); |
128 | |
129 | // If the underlying value is not yet known, don't propagate yet. |
130 | if (!underlyingValue) |
131 | return success(); |
132 | |
133 | underlyingValues.push_back(Elt: *underlyingValue); |
134 | } |
135 | |
136 | // Update the state when all underlying values are known. |
137 | ChangeResult result = after->join(lattice: before); |
138 | for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) { |
139 | // If the underlying value is known to be unknown, set to fixpoint state. |
140 | if (!value) { |
141 | setToEntryState(after); |
142 | return success(); |
143 | } |
144 | |
145 | // Nothing to do for reads. |
146 | if (isa<MemoryEffects::Read>(Val: effect.getEffect())) |
147 | continue; |
148 | |
149 | result |= after->set(value, op); |
150 | } |
151 | propagateIfChanged(state: after, changed: result); |
152 | return success(); |
153 | } |
154 | |
155 | void LastModifiedAnalysis::buildOperationEquivalentLatticeAnchor( |
156 | Operation *op) { |
157 | if (isMemoryEffectFree(op)) { |
158 | unionLatticeAnchors<LastModification>(anchor: getProgramPointBefore(op), |
159 | other: getProgramPointAfter(op)); |
160 | } |
161 | } |
162 | |
163 | void LastModifiedAnalysis::visitCallControlFlowTransfer( |
164 | CallOpInterface call, CallControlFlowAction action, |
165 | const LastModification &before, LastModification *after) { |
166 | if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) { |
167 | SmallVector<Value> underlyingValues; |
168 | underlyingValues.reserve(N: call->getNumOperands()); |
169 | for (Value operand : call.getArgOperands()) { |
170 | std::optional<Value> underlyingValue = |
171 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
172 | operand, [&](Value value) { |
173 | return getOrCreateFor<UnderlyingValueLattice>( |
174 | getProgramPointAfter(call.getOperation()), value); |
175 | }); |
176 | if (!underlyingValue) |
177 | return; |
178 | underlyingValues.push_back(*underlyingValue); |
179 | } |
180 | |
181 | ChangeResult result = after->join(lattice: before); |
182 | for (Value operand : underlyingValues) |
183 | result |= after->set(value: operand, op: call); |
184 | return propagateIfChanged(state: after, changed: result); |
185 | } |
186 | auto testCallAndStore = |
187 | dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); |
188 | if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && |
189 | testCallAndStore.getStoreBeforeCall()) || |
190 | (action == CallControlFlowAction::ExitCallee && |
191 | !testCallAndStore.getStoreBeforeCall()))) { |
192 | (void)visitOperation(op: call, before, after); |
193 | return; |
194 | } |
195 | AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer( |
196 | call: call, action, before, after); |
197 | } |
198 | |
199 | void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer( |
200 | RegionBranchOpInterface branch, std::optional<unsigned> regionFrom, |
201 | std::optional<unsigned> regionTo, const LastModification &before, |
202 | LastModification *after) { |
203 | auto defaultHandling = [&]() { |
204 | AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer( |
205 | branch: branch, regionFrom, regionTo, before, after); |
206 | }; |
207 | TypeSwitch<Operation *>(branch.getOperation()) |
208 | .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>( |
209 | [=](auto storeWithRegion) { |
210 | if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) || |
211 | (!regionFrom && storeWithRegion.getStoreBeforeRegion())) |
212 | (void)visitOperation(branch, before, after); |
213 | defaultHandling(); |
214 | }) |
215 | .Default([=](auto) { defaultHandling(); }); |
216 | } |
217 | |
218 | namespace { |
219 | struct TestLastModifiedPass |
220 | : public PassWrapper<TestLastModifiedPass, OperationPass<>> { |
221 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass) |
222 | |
223 | TestLastModifiedPass() = default; |
224 | TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) { |
225 | interprocedural = other.interprocedural; |
226 | assumeFuncWrites = other.assumeFuncWrites; |
227 | } |
228 | |
229 | StringRef getArgument() const override { return "test-last-modified" ; } |
230 | |
231 | Option<bool> interprocedural{ |
232 | *this, "interprocedural" , llvm::cl::init(Val: true), |
233 | llvm::cl::desc("perform interprocedural analysis" )}; |
234 | Option<bool> assumeFuncWrites{ |
235 | *this, "assume-func-writes" , llvm::cl::init(Val: false), |
236 | llvm::cl::desc( |
237 | "assume external functions have write effect on all arguments" )}; |
238 | |
239 | void runOnOperation() override { |
240 | Operation *op = getOperation(); |
241 | |
242 | DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); |
243 | solver.load<DeadCodeAnalysis>(); |
244 | solver.load<SparseConstantPropagation>(); |
245 | solver.load<LastModifiedAnalysis>(args&: assumeFuncWrites); |
246 | solver.load<UnderlyingValueAnalysis>(); |
247 | if (failed(Result: solver.initializeAndRun(top: op))) |
248 | return signalPassFailure(); |
249 | |
250 | raw_ostream &os = llvm::errs(); |
251 | |
252 | // Note that if the underlying value could not be computed or is unknown, we |
253 | // conservatively treat the result also unknown. |
254 | op->walk(callback: [&](Operation *op) { |
255 | auto tag = op->getAttrOfType<StringAttr>("tag" ); |
256 | if (!tag) |
257 | return; |
258 | os << "test_tag: " << tag.getValue() << ":\n" ; |
259 | const LastModification *lastMods = |
260 | solver.lookupState<LastModification>(anchor: solver.getProgramPointAfter(op)); |
261 | assert(lastMods && "expected a dense lattice" ); |
262 | for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) { |
263 | os << " operand #" << index << "\n" ; |
264 | std::optional<Value> underlyingValue = |
265 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
266 | value: operand, getUnderlyingValueFn: [&](Value value) { |
267 | return solver.lookupState<UnderlyingValueLattice>(anchor: value); |
268 | }); |
269 | if (!underlyingValue) { |
270 | os << " - <unknown>\n" ; |
271 | continue; |
272 | } |
273 | Value value = *underlyingValue; |
274 | assert(value && "expected an underlying value" ); |
275 | if (const AdjacentAccess *lastMod = |
276 | lastMods->getAdjacentAccess(value)) { |
277 | if (!lastMod->isKnown()) { |
278 | os << " - <unknown>\n" ; |
279 | } else { |
280 | for (Operation *lastModifier : lastMod->get()) { |
281 | if (auto tagName = |
282 | lastModifier->getAttrOfType<StringAttr>("tag_name" )) { |
283 | os << " - " << tagName.getValue() << "\n" ; |
284 | } else { |
285 | os << " - " << lastModifier->getName() << "\n" ; |
286 | } |
287 | } |
288 | } |
289 | } else { |
290 | os << " - <unknown>\n" ; |
291 | } |
292 | } |
293 | }); |
294 | } |
295 | }; |
296 | } // end anonymous namespace |
297 | |
298 | namespace mlir { |
299 | namespace test { |
300 | void registerTestLastModifiedPass() { |
301 | PassRegistration<TestLastModifiedPass>(); |
302 | } |
303 | } // end namespace test |
304 | } // end namespace mlir |
305 | |