1//===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Implementation of test passes exercising dense forward data flow analysis.
10//
11//===----------------------------------------------------------------------===//
12
13#include "TestDenseDataFlowAnalysis.h"
14#include "TestDialect.h"
15#include "TestOps.h"
16#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
18#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Support/LLVM.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::dataflow;
27using namespace mlir::dataflow::test;
28
29namespace {
30
31/// This lattice represents, for a given memory resource, the potential last
32/// operations that modified the resource.
33class LastModification : public AbstractDenseLattice, public AccessLatticeBase {
34public:
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
36
37 using AbstractDenseLattice::AbstractDenseLattice;
38
39 /// Join the last modifications.
40 ChangeResult join(const AbstractDenseLattice &lattice) override {
41 return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>(
42 static_cast<const LastModification &>(lattice)));
43 }
44
45 void print(raw_ostream &os) const override {
46 return AccessLatticeBase::print(os);
47 }
48};
49
50class LastModifiedAnalysis
51 : public DenseForwardDataFlowAnalysis<LastModification> {
52public:
53 explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites)
54 : DenseForwardDataFlowAnalysis(solver),
55 assumeFuncWrites(assumeFuncWrites) {}
56
57 /// Visit an operation. If the operation has no memory effects, then the state
58 /// is propagated with no change. If the operation allocates a resource, then
59 /// its reaching definitions is set to empty. If the operation writes to a
60 /// resource, then its reaching definition is set to the written value.
61 void visitOperation(Operation *op, const LastModification &before,
62 LastModification *after) override;
63
64 void visitCallControlFlowTransfer(CallOpInterface call,
65 CallControlFlowAction action,
66 const LastModification &before,
67 LastModification *after) override;
68
69 void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
70 std::optional<unsigned> regionFrom,
71 std::optional<unsigned> regionTo,
72 const LastModification &before,
73 LastModification *after) override;
74
75 /// At an entry point, the last modifications of all memory resources are
76 /// unknown.
77 void setToEntryState(LastModification *lattice) override {
78 propagateIfChanged(state: lattice, changed: lattice->reset());
79 }
80
81private:
82 const bool assumeFuncWrites;
83};
84} // end anonymous namespace
85
86void LastModifiedAnalysis::visitOperation(Operation *op,
87 const LastModification &before,
88 LastModification *after) {
89 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
90 // If we can't reason about the memory effects, then conservatively assume we
91 // can't deduce anything about the last modifications.
92 if (!memory)
93 return setToEntryState(after);
94
95 SmallVector<MemoryEffects::EffectInstance> effects;
96 memory.getEffects(effects);
97
98 // First, check if all underlying values are already known. Otherwise, avoid
99 // propagating and stay in the "undefined" state to avoid incorrectly
100 // propagating values that may be overwritten later on as that could be
101 // problematic for convergence based on monotonicity of lattice updates.
102 SmallVector<Value> underlyingValues;
103 underlyingValues.reserve(N: effects.size());
104 for (const auto &effect : effects) {
105 Value value = effect.getValue();
106
107 // If we see an effect on anything other than a value, assume we can't
108 // deduce anything about the last modifications.
109 if (!value)
110 return setToEntryState(after);
111
112 // If we cannot find the underlying value, we shouldn't just propagate the
113 // effects through, return the pessimistic state.
114 std::optional<Value> underlyingValue =
115 UnderlyingValueAnalysis::getMostUnderlyingValue(
116 value, getUnderlyingValueFn: [&](Value value) {
117 return getOrCreateFor<UnderlyingValueLattice>(dependent: op, point: value);
118 });
119
120 // If the underlying value is not yet known, don't propagate yet.
121 if (!underlyingValue)
122 return;
123
124 underlyingValues.push_back(Elt: *underlyingValue);
125 }
126
127 // Update the state when all underlying values are known.
128 ChangeResult result = after->join(lattice: before);
129 for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) {
130 // If the underlying value is known to be unknown, set to fixpoint state.
131 if (!value)
132 return setToEntryState(after);
133
134 // Nothing to do for reads.
135 if (isa<MemoryEffects::Read>(Val: effect.getEffect()))
136 continue;
137
138 result |= after->set(value, op);
139 }
140 propagateIfChanged(state: after, changed: result);
141}
142
143void LastModifiedAnalysis::visitCallControlFlowTransfer(
144 CallOpInterface call, CallControlFlowAction action,
145 const LastModification &before, LastModification *after) {
146 if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) {
147 SmallVector<Value> underlyingValues;
148 underlyingValues.reserve(N: call->getNumOperands());
149 for (Value operand : call.getArgOperands()) {
150 std::optional<Value> underlyingValue =
151 UnderlyingValueAnalysis::getMostUnderlyingValue(
152 operand, [&](Value value) {
153 return getOrCreateFor<UnderlyingValueLattice>(
154 call.getOperation(), value);
155 });
156 if (!underlyingValue)
157 return;
158 underlyingValues.push_back(*underlyingValue);
159 }
160
161 ChangeResult result = after->join(lattice: before);
162 for (Value operand : underlyingValues)
163 result |= after->set(value: operand, op: call);
164 return propagateIfChanged(state: after, changed: result);
165 }
166 auto testCallAndStore =
167 dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
168 if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
169 testCallAndStore.getStoreBeforeCall()) ||
170 (action == CallControlFlowAction::ExitCallee &&
171 !testCallAndStore.getStoreBeforeCall()))) {
172 return visitOperation(op: call, before, after);
173 }
174 AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer(
175 call: call, action, before, after);
176}
177
178void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
179 RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
180 std::optional<unsigned> regionTo, const LastModification &before,
181 LastModification *after) {
182 auto defaultHandling = [&]() {
183 AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
184 branch: branch, regionFrom, regionTo, before, after);
185 };
186 TypeSwitch<Operation *>(branch.getOperation())
187 .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
188 [=](auto storeWithRegion) {
189 if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) ||
190 (!regionFrom && storeWithRegion.getStoreBeforeRegion()))
191 visitOperation(branch, before, after);
192 defaultHandling();
193 })
194 .Default([=](auto) { defaultHandling(); });
195}
196
197namespace {
198struct TestLastModifiedPass
199 : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
200 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
201
202 TestLastModifiedPass() = default;
203 TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) {
204 interprocedural = other.interprocedural;
205 assumeFuncWrites = other.assumeFuncWrites;
206 }
207
208 StringRef getArgument() const override { return "test-last-modified"; }
209
210 Option<bool> interprocedural{
211 *this, "interprocedural", llvm::cl::init(Val: true),
212 llvm::cl::desc("perform interprocedural analysis")};
213 Option<bool> assumeFuncWrites{
214 *this, "assume-func-writes", llvm::cl::init(Val: false),
215 llvm::cl::desc(
216 "assume external functions have write effect on all arguments")};
217
218 void runOnOperation() override {
219 Operation *op = getOperation();
220
221 DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
222 solver.load<DeadCodeAnalysis>();
223 solver.load<SparseConstantPropagation>();
224 solver.load<LastModifiedAnalysis>(args&: assumeFuncWrites);
225 solver.load<UnderlyingValueAnalysis>();
226 if (failed(result: solver.initializeAndRun(top: op)))
227 return signalPassFailure();
228
229 raw_ostream &os = llvm::errs();
230
231 // Note that if the underlying value could not be computed or is unknown, we
232 // conservatively treat the result also unknown.
233 op->walk(callback: [&](Operation *op) {
234 auto tag = op->getAttrOfType<StringAttr>("tag");
235 if (!tag)
236 return;
237 os << "test_tag: " << tag.getValue() << ":\n";
238 const LastModification *lastMods =
239 solver.lookupState<LastModification>(point: op);
240 assert(lastMods && "expected a dense lattice");
241 for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) {
242 os << " operand #" << index << "\n";
243 std::optional<Value> underlyingValue =
244 UnderlyingValueAnalysis::getMostUnderlyingValue(
245 value: operand, getUnderlyingValueFn: [&](Value value) {
246 return solver.lookupState<UnderlyingValueLattice>(point: value);
247 });
248 if (!underlyingValue) {
249 os << " - <unknown>\n";
250 continue;
251 }
252 Value value = *underlyingValue;
253 assert(value && "expected an underlying value");
254 if (const AdjacentAccess *lastMod =
255 lastMods->getAdjacentAccess(value)) {
256 if (!lastMod->isKnown()) {
257 os << " - <unknown>\n";
258 } else {
259 for (Operation *lastModifier : lastMod->get()) {
260 if (auto tagName =
261 lastModifier->getAttrOfType<StringAttr>("tag_name")) {
262 os << " - " << tagName.getValue() << "\n";
263 } else {
264 os << " - " << lastModifier->getName() << "\n";
265 }
266 }
267 }
268 } else {
269 os << " - <unknown>\n";
270 }
271 }
272 });
273 }
274};
275} // end anonymous namespace
276
277namespace mlir {
278namespace test {
279void registerTestLastModifiedPass() {
280 PassRegistration<TestLastModifiedPass>();
281}
282} // end namespace test
283} // end namespace mlir
284

source code of mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp