1 | //===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Implementation of test passes exercising dense forward data flow analysis.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "TestDenseDataFlowAnalysis.h" |
14 | #include "TestDialect.h" |
15 | #include "TestOps.h" |
16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
17 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
18 | #include "mlir/Analysis/DataFlow/DenseAnalysis.h" |
19 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include "mlir/Support/LLVM.h" |
22 | #include "llvm/ADT/TypeSwitch.h" |
23 | #include <optional> |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::dataflow; |
27 | using namespace mlir::dataflow::test; |
28 | |
29 | namespace { |
30 | |
31 | /// This lattice represents, for a given memory resource, the potential last |
32 | /// operations that modified the resource. |
33 | class LastModification : public AbstractDenseLattice, public AccessLatticeBase { |
34 | public: |
35 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification) |
36 | |
37 | using AbstractDenseLattice::AbstractDenseLattice; |
38 | |
39 | /// Join the last modifications. |
40 | ChangeResult join(const AbstractDenseLattice &lattice) override { |
41 | return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>( |
42 | static_cast<const LastModification &>(lattice))); |
43 | } |
44 | |
45 | void print(raw_ostream &os) const override { |
46 | return AccessLatticeBase::print(os); |
47 | } |
48 | }; |
49 | |
50 | class LastModifiedAnalysis |
51 | : public DenseForwardDataFlowAnalysis<LastModification> { |
52 | public: |
53 | explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites) |
54 | : DenseForwardDataFlowAnalysis(solver), |
55 | assumeFuncWrites(assumeFuncWrites) {} |
56 | |
57 | /// Visit an operation. If the operation has no memory effects, then the state |
58 | /// is propagated with no change. If the operation allocates a resource, then |
59 | /// its reaching definitions is set to empty. If the operation writes to a |
60 | /// resource, then its reaching definition is set to the written value. |
61 | void visitOperation(Operation *op, const LastModification &before, |
62 | LastModification *after) override; |
63 | |
64 | void visitCallControlFlowTransfer(CallOpInterface call, |
65 | CallControlFlowAction action, |
66 | const LastModification &before, |
67 | LastModification *after) override; |
68 | |
69 | void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, |
70 | std::optional<unsigned> regionFrom, |
71 | std::optional<unsigned> regionTo, |
72 | const LastModification &before, |
73 | LastModification *after) override; |
74 | |
75 | /// At an entry point, the last modifications of all memory resources are |
76 | /// unknown. |
77 | void setToEntryState(LastModification *lattice) override { |
78 | propagateIfChanged(state: lattice, changed: lattice->reset()); |
79 | } |
80 | |
81 | private: |
82 | const bool assumeFuncWrites; |
83 | }; |
84 | } // end anonymous namespace |
85 | |
86 | void LastModifiedAnalysis::visitOperation(Operation *op, |
87 | const LastModification &before, |
88 | LastModification *after) { |
89 | auto memory = dyn_cast<MemoryEffectOpInterface>(op); |
90 | // If we can't reason about the memory effects, then conservatively assume we |
91 | // can't deduce anything about the last modifications. |
92 | if (!memory) |
93 | return setToEntryState(after); |
94 | |
95 | SmallVector<MemoryEffects::EffectInstance> effects; |
96 | memory.getEffects(effects); |
97 | |
98 | // First, check if all underlying values are already known. Otherwise, avoid |
99 | // propagating and stay in the "undefined" state to avoid incorrectly |
100 | // propagating values that may be overwritten later on as that could be |
101 | // problematic for convergence based on monotonicity of lattice updates. |
102 | SmallVector<Value> underlyingValues; |
103 | underlyingValues.reserve(N: effects.size()); |
104 | for (const auto &effect : effects) { |
105 | Value value = effect.getValue(); |
106 | |
107 | // If we see an effect on anything other than a value, assume we can't |
108 | // deduce anything about the last modifications. |
109 | if (!value) |
110 | return setToEntryState(after); |
111 | |
112 | // If we cannot find the underlying value, we shouldn't just propagate the |
113 | // effects through, return the pessimistic state. |
114 | std::optional<Value> underlyingValue = |
115 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
116 | value, getUnderlyingValueFn: [&](Value value) { |
117 | return getOrCreateFor<UnderlyingValueLattice>(dependent: op, point: value); |
118 | }); |
119 | |
120 | // If the underlying value is not yet known, don't propagate yet. |
121 | if (!underlyingValue) |
122 | return; |
123 | |
124 | underlyingValues.push_back(Elt: *underlyingValue); |
125 | } |
126 | |
127 | // Update the state when all underlying values are known. |
128 | ChangeResult result = after->join(lattice: before); |
129 | for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) { |
130 | // If the underlying value is known to be unknown, set to fixpoint state. |
131 | if (!value) |
132 | return setToEntryState(after); |
133 | |
134 | // Nothing to do for reads. |
135 | if (isa<MemoryEffects::Read>(Val: effect.getEffect())) |
136 | continue; |
137 | |
138 | result |= after->set(value, op); |
139 | } |
140 | propagateIfChanged(state: after, changed: result); |
141 | } |
142 | |
143 | void LastModifiedAnalysis::visitCallControlFlowTransfer( |
144 | CallOpInterface call, CallControlFlowAction action, |
145 | const LastModification &before, LastModification *after) { |
146 | if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) { |
147 | SmallVector<Value> underlyingValues; |
148 | underlyingValues.reserve(N: call->getNumOperands()); |
149 | for (Value operand : call.getArgOperands()) { |
150 | std::optional<Value> underlyingValue = |
151 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
152 | operand, [&](Value value) { |
153 | return getOrCreateFor<UnderlyingValueLattice>( |
154 | call.getOperation(), value); |
155 | }); |
156 | if (!underlyingValue) |
157 | return; |
158 | underlyingValues.push_back(*underlyingValue); |
159 | } |
160 | |
161 | ChangeResult result = after->join(lattice: before); |
162 | for (Value operand : underlyingValues) |
163 | result |= after->set(value: operand, op: call); |
164 | return propagateIfChanged(state: after, changed: result); |
165 | } |
166 | auto testCallAndStore = |
167 | dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); |
168 | if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && |
169 | testCallAndStore.getStoreBeforeCall()) || |
170 | (action == CallControlFlowAction::ExitCallee && |
171 | !testCallAndStore.getStoreBeforeCall()))) { |
172 | return visitOperation(op: call, before, after); |
173 | } |
174 | AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer( |
175 | call: call, action, before, after); |
176 | } |
177 | |
178 | void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer( |
179 | RegionBranchOpInterface branch, std::optional<unsigned> regionFrom, |
180 | std::optional<unsigned> regionTo, const LastModification &before, |
181 | LastModification *after) { |
182 | auto defaultHandling = [&]() { |
183 | AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer( |
184 | branch: branch, regionFrom, regionTo, before, after); |
185 | }; |
186 | TypeSwitch<Operation *>(branch.getOperation()) |
187 | .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>( |
188 | [=](auto storeWithRegion) { |
189 | if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) || |
190 | (!regionFrom && storeWithRegion.getStoreBeforeRegion())) |
191 | visitOperation(branch, before, after); |
192 | defaultHandling(); |
193 | }) |
194 | .Default([=](auto) { defaultHandling(); }); |
195 | } |
196 | |
197 | namespace { |
198 | struct TestLastModifiedPass |
199 | : public PassWrapper<TestLastModifiedPass, OperationPass<>> { |
200 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass) |
201 | |
202 | TestLastModifiedPass() = default; |
203 | TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) { |
204 | interprocedural = other.interprocedural; |
205 | assumeFuncWrites = other.assumeFuncWrites; |
206 | } |
207 | |
208 | StringRef getArgument() const override { return "test-last-modified" ; } |
209 | |
210 | Option<bool> interprocedural{ |
211 | *this, "interprocedural" , llvm::cl::init(Val: true), |
212 | llvm::cl::desc("perform interprocedural analysis" )}; |
213 | Option<bool> assumeFuncWrites{ |
214 | *this, "assume-func-writes" , llvm::cl::init(Val: false), |
215 | llvm::cl::desc( |
216 | "assume external functions have write effect on all arguments" )}; |
217 | |
218 | void runOnOperation() override { |
219 | Operation *op = getOperation(); |
220 | |
221 | DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); |
222 | solver.load<DeadCodeAnalysis>(); |
223 | solver.load<SparseConstantPropagation>(); |
224 | solver.load<LastModifiedAnalysis>(args&: assumeFuncWrites); |
225 | solver.load<UnderlyingValueAnalysis>(); |
226 | if (failed(result: solver.initializeAndRun(top: op))) |
227 | return signalPassFailure(); |
228 | |
229 | raw_ostream &os = llvm::errs(); |
230 | |
231 | // Note that if the underlying value could not be computed or is unknown, we |
232 | // conservatively treat the result also unknown. |
233 | op->walk(callback: [&](Operation *op) { |
234 | auto tag = op->getAttrOfType<StringAttr>("tag" ); |
235 | if (!tag) |
236 | return; |
237 | os << "test_tag: " << tag.getValue() << ":\n" ; |
238 | const LastModification *lastMods = |
239 | solver.lookupState<LastModification>(point: op); |
240 | assert(lastMods && "expected a dense lattice" ); |
241 | for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) { |
242 | os << " operand #" << index << "\n" ; |
243 | std::optional<Value> underlyingValue = |
244 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
245 | value: operand, getUnderlyingValueFn: [&](Value value) { |
246 | return solver.lookupState<UnderlyingValueLattice>(point: value); |
247 | }); |
248 | if (!underlyingValue) { |
249 | os << " - <unknown>\n" ; |
250 | continue; |
251 | } |
252 | Value value = *underlyingValue; |
253 | assert(value && "expected an underlying value" ); |
254 | if (const AdjacentAccess *lastMod = |
255 | lastMods->getAdjacentAccess(value)) { |
256 | if (!lastMod->isKnown()) { |
257 | os << " - <unknown>\n" ; |
258 | } else { |
259 | for (Operation *lastModifier : lastMod->get()) { |
260 | if (auto tagName = |
261 | lastModifier->getAttrOfType<StringAttr>("tag_name" )) { |
262 | os << " - " << tagName.getValue() << "\n" ; |
263 | } else { |
264 | os << " - " << lastModifier->getName() << "\n" ; |
265 | } |
266 | } |
267 | } |
268 | } else { |
269 | os << " - <unknown>\n" ; |
270 | } |
271 | } |
272 | }); |
273 | } |
274 | }; |
275 | } // end anonymous namespace |
276 | |
277 | namespace mlir { |
278 | namespace test { |
279 | void registerTestLastModifiedPass() { |
280 | PassRegistration<TestLastModifiedPass>(); |
281 | } |
282 | } // end namespace test |
283 | } // end namespace mlir |
284 | |