1//===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Implementation of test passes exercising dense forward data flow analysis.
10//
11//===----------------------------------------------------------------------===//
12
13#include "TestDenseDataFlowAnalysis.h"
14#include "TestDialect.h"
15#include "TestOps.h"
16#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
17#include "mlir/Analysis/DataFlow/Utils.h"
18#include "mlir/Interfaces/SideEffectInterfaces.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Support/LLVM.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include <optional>
23
24using namespace mlir;
25using namespace mlir::dataflow;
26using namespace mlir::dataflow::test;
27
28namespace {
29
30/// This lattice represents, for a given memory resource, the potential last
31/// operations that modified the resource.
32class LastModification : public AbstractDenseLattice, public AccessLatticeBase {
33public:
34 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
35
36 using AbstractDenseLattice::AbstractDenseLattice;
37
38 /// Join the last modifications.
39 ChangeResult join(const AbstractDenseLattice &lattice) override {
40 return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>(
41 static_cast<const LastModification &>(lattice)));
42 }
43
44 void print(raw_ostream &os) const override {
45 return AccessLatticeBase::print(os);
46 }
47};
48
49class LastModifiedAnalysis
50 : public DenseForwardDataFlowAnalysis<LastModification> {
51public:
52 explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites)
53 : DenseForwardDataFlowAnalysis(solver),
54 assumeFuncWrites(assumeFuncWrites) {}
55
56 /// Visit an operation. If the operation has no memory effects, then the state
57 /// is propagated with no change. If the operation allocates a resource, then
58 /// its reaching definitions is set to empty. If the operation writes to a
59 /// resource, then its reaching definition is set to the written value.
60 LogicalResult visitOperation(Operation *op, const LastModification &before,
61 LastModification *after) override;
62
63 void visitCallControlFlowTransfer(CallOpInterface call,
64 CallControlFlowAction action,
65 const LastModification &before,
66 LastModification *after) override;
67
68 void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
69 std::optional<unsigned> regionFrom,
70 std::optional<unsigned> regionTo,
71 const LastModification &before,
72 LastModification *after) override;
73
74 /// Visit an operation. If this analysis can confirm that lattice content
75 /// of lattice anchors around operation are necessarily identical, join
76 /// them into the same equivalent class.
77 void buildOperationEquivalentLatticeAnchor(Operation *op) override;
78
79 /// At an entry point, the last modifications of all memory resources are
80 /// unknown.
81 void setToEntryState(LastModification *lattice) override {
82 propagateIfChanged(state: lattice, changed: lattice->reset());
83 }
84
85private:
86 const bool assumeFuncWrites;
87};
88} // end anonymous namespace
89
90LogicalResult LastModifiedAnalysis::visitOperation(
91 Operation *op, const LastModification &before, LastModification *after) {
92 auto memory = dyn_cast<MemoryEffectOpInterface>(Val: op);
93 // If we can't reason about the memory effects, then conservatively assume we
94 // can't deduce anything about the last modifications.
95 if (!memory) {
96 setToEntryState(after);
97 return success();
98 }
99
100 SmallVector<MemoryEffects::EffectInstance> effects;
101 memory.getEffects(effects);
102
103 // First, check if all underlying values are already known. Otherwise, avoid
104 // propagating and stay in the "undefined" state to avoid incorrectly
105 // propagating values that may be overwritten later on as that could be
106 // problematic for convergence based on monotonicity of lattice updates.
107 SmallVector<Value> underlyingValues;
108 underlyingValues.reserve(N: effects.size());
109 for (const auto &effect : effects) {
110 Value value = effect.getValue();
111
112 // If we see an effect on anything other than a value, assume we can't
113 // deduce anything about the last modifications.
114 if (!value) {
115 setToEntryState(after);
116 return success();
117 }
118
119 // If we cannot find the underlying value, we shouldn't just propagate the
120 // effects through, return the pessimistic state.
121 std::optional<Value> underlyingValue =
122 UnderlyingValueAnalysis::getMostUnderlyingValue(
123 value, getUnderlyingValueFn: [&](Value value) {
124 return getOrCreateFor<UnderlyingValueLattice>(
125 dependent: getProgramPointAfter(op), anchor: value);
126 });
127
128 // If the underlying value is not yet known, don't propagate yet.
129 if (!underlyingValue)
130 return success();
131
132 underlyingValues.push_back(Elt: *underlyingValue);
133 }
134
135 // Update the state when all underlying values are known.
136 ChangeResult result = after->join(lattice: before);
137 for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) {
138 // If the underlying value is known to be unknown, set to fixpoint state.
139 if (!value) {
140 setToEntryState(after);
141 return success();
142 }
143
144 // Nothing to do for reads.
145 if (isa<MemoryEffects::Read>(Val: effect.getEffect()))
146 continue;
147
148 result |= after->set(value, op);
149 }
150 propagateIfChanged(state: after, changed: result);
151 return success();
152}
153
154void LastModifiedAnalysis::buildOperationEquivalentLatticeAnchor(
155 Operation *op) {
156 if (isMemoryEffectFree(op)) {
157 unionLatticeAnchors<LastModification>(anchor: getProgramPointBefore(op),
158 other: getProgramPointAfter(op));
159 }
160}
161
162void LastModifiedAnalysis::visitCallControlFlowTransfer(
163 CallOpInterface call, CallControlFlowAction action,
164 const LastModification &before, LastModification *after) {
165 if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) {
166 SmallVector<Value> underlyingValues;
167 underlyingValues.reserve(N: call->getNumOperands());
168 for (Value operand : call.getArgOperands()) {
169 std::optional<Value> underlyingValue =
170 UnderlyingValueAnalysis::getMostUnderlyingValue(
171 value: operand, getUnderlyingValueFn: [&](Value value) {
172 return getOrCreateFor<UnderlyingValueLattice>(
173 dependent: getProgramPointAfter(op: call.getOperation()), anchor: value);
174 });
175 if (!underlyingValue)
176 return;
177 underlyingValues.push_back(Elt: *underlyingValue);
178 }
179
180 ChangeResult result = after->join(lattice: before);
181 for (Value operand : underlyingValues)
182 result |= after->set(value: operand, op: call);
183 return propagateIfChanged(state: after, changed: result);
184 }
185 auto testCallAndStore =
186 dyn_cast<::test::TestCallAndStoreOp>(Val: call.getOperation());
187 if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
188 testCallAndStore.getStoreBeforeCall()) ||
189 (action == CallControlFlowAction::ExitCallee &&
190 !testCallAndStore.getStoreBeforeCall()))) {
191 (void)visitOperation(op: call, before, after);
192 return;
193 }
194 AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer(
195 call, action, before, after);
196}
197
198void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
199 RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
200 std::optional<unsigned> regionTo, const LastModification &before,
201 LastModification *after) {
202 auto defaultHandling = [&]() {
203 AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
204 branch, regionFrom, regionTo, before, after);
205 };
206 TypeSwitch<Operation *>(branch.getOperation())
207 .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
208 caseFn: [=](auto storeWithRegion) {
209 if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) ||
210 (!regionFrom && storeWithRegion.getStoreBeforeRegion()))
211 (void)visitOperation(op: branch, before, after);
212 defaultHandling();
213 })
214 .Default(defaultFn: [=](auto) { defaultHandling(); });
215}
216
217namespace {
218struct TestLastModifiedPass
219 : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
220 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
221
222 TestLastModifiedPass() = default;
223 TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) {
224 interprocedural = other.interprocedural;
225 assumeFuncWrites = other.assumeFuncWrites;
226 }
227
228 StringRef getArgument() const override { return "test-last-modified"; }
229
230 Option<bool> interprocedural{
231 *this, "interprocedural", llvm::cl::init(Val: true),
232 llvm::cl::desc("perform interprocedural analysis")};
233 Option<bool> assumeFuncWrites{
234 *this, "assume-func-writes", llvm::cl::init(Val: false),
235 llvm::cl::desc(
236 "assume external functions have write effect on all arguments")};
237
238 void runOnOperation() override {
239 Operation *op = getOperation();
240
241 DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
242 loadBaselineAnalyses(solver);
243 solver.load<LastModifiedAnalysis>(args&: assumeFuncWrites);
244 solver.load<UnderlyingValueAnalysis>();
245 if (failed(Result: solver.initializeAndRun(top: op)))
246 return signalPassFailure();
247
248 raw_ostream &os = llvm::errs();
249
250 // Note that if the underlying value could not be computed or is unknown, we
251 // conservatively treat the result also unknown.
252 op->walk(callback: [&](Operation *op) {
253 auto tag = op->getAttrOfType<StringAttr>(name: "tag");
254 if (!tag)
255 return;
256 os << "test_tag: " << tag.getValue() << ":\n";
257 const LastModification *lastMods =
258 solver.lookupState<LastModification>(anchor: solver.getProgramPointAfter(op));
259 assert(lastMods && "expected a dense lattice");
260 for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) {
261 os << " operand #" << index << "\n";
262 std::optional<Value> underlyingValue =
263 UnderlyingValueAnalysis::getMostUnderlyingValue(
264 value: operand, getUnderlyingValueFn: [&](Value value) {
265 return solver.lookupState<UnderlyingValueLattice>(anchor: value);
266 });
267 if (!underlyingValue) {
268 os << " - <unknown>\n";
269 continue;
270 }
271 Value value = *underlyingValue;
272 assert(value && "expected an underlying value");
273 if (const AdjacentAccess *lastMod =
274 lastMods->getAdjacentAccess(value)) {
275 if (!lastMod->isKnown()) {
276 os << " - <unknown>\n";
277 } else {
278 for (Operation *lastModifier : lastMod->get()) {
279 if (auto tagName =
280 lastModifier->getAttrOfType<StringAttr>(name: "tag_name")) {
281 os << " - " << tagName.getValue() << "\n";
282 } else {
283 os << " - " << lastModifier->getName() << "\n";
284 }
285 }
286 }
287 } else {
288 os << " - <unknown>\n";
289 }
290 }
291 });
292 }
293};
294} // end anonymous namespace
295
296namespace mlir {
297namespace test {
298void registerTestLastModifiedPass() {
299 PassRegistration<TestLastModifiedPass>();
300}
301} // end namespace test
302} // end namespace mlir
303

source code of mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp