1//===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation of tests passes exercising dense forward data flow analysis.
10//
11//===----------------------------------------------------------------------===//
12
13#include "TestDenseDataFlowAnalysis.h"
14#include "TestDialect.h"
15#include "TestOps.h"
16#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
18#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
19#include "mlir/Interfaces/SideEffectInterfaces.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Support/LLVM.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::dataflow;
27using namespace mlir::dataflow::test;
28
29namespace {
30
31/// This lattice represents, for a given memory resource, the potential last
32/// operations that modified the resource.
33class LastModification : public AbstractDenseLattice, public AccessLatticeBase {
34public:
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
36
37 using AbstractDenseLattice::AbstractDenseLattice;
38
39 /// Join the last modifications.
40 ChangeResult join(const AbstractDenseLattice &lattice) override {
41 return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>(
42 static_cast<const LastModification &>(lattice)));
43 }
44
45 void print(raw_ostream &os) const override {
46 return AccessLatticeBase::print(os);
47 }
48};
49
50class LastModifiedAnalysis
51 : public DenseForwardDataFlowAnalysis<LastModification> {
52public:
53 explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites)
54 : DenseForwardDataFlowAnalysis(solver),
55 assumeFuncWrites(assumeFuncWrites) {}
56
57 /// Visit an operation. If the operation has no memory effects, then the state
58 /// is propagated with no change. If the operation allocates a resource, then
59 /// its reaching definitions is set to empty. If the operation writes to a
60 /// resource, then its reaching definition is set to the written value.
61 LogicalResult visitOperation(Operation *op, const LastModification &before,
62 LastModification *after) override;
63
64 void visitCallControlFlowTransfer(CallOpInterface call,
65 CallControlFlowAction action,
66 const LastModification &before,
67 LastModification *after) override;
68
69 void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
70 std::optional<unsigned> regionFrom,
71 std::optional<unsigned> regionTo,
72 const LastModification &before,
73 LastModification *after) override;
74
75 /// Visit an operation. If this analysis can confirm that lattice content
76 /// of lattice anchors around operation are necessarily identical, join
77 /// them into the same equivalent class.
78 void buildOperationEquivalentLatticeAnchor(Operation *op) override;
79
80 /// At an entry point, the last modifications of all memory resources are
81 /// unknown.
82 void setToEntryState(LastModification *lattice) override {
83 propagateIfChanged(state: lattice, changed: lattice->reset());
84 }
85
86private:
87 const bool assumeFuncWrites;
88};
89} // end anonymous namespace
90
91LogicalResult LastModifiedAnalysis::visitOperation(
92 Operation *op, const LastModification &before, LastModification *after) {
93 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
94 // If we can't reason about the memory effects, then conservatively assume we
95 // can't deduce anything about the last modifications.
96 if (!memory) {
97 setToEntryState(after);
98 return success();
99 }
100
101 SmallVector<MemoryEffects::EffectInstance> effects;
102 memory.getEffects(effects);
103
104 // First, check if all underlying values are already known. Otherwise, avoid
105 // propagating and stay in the "undefined" state to avoid incorrectly
106 // propagating values that may be overwritten later on as that could be
107 // problematic for convergence based on monotonicity of lattice updates.
108 SmallVector<Value> underlyingValues;
109 underlyingValues.reserve(N: effects.size());
110 for (const auto &effect : effects) {
111 Value value = effect.getValue();
112
113 // If we see an effect on anything other than a value, assume we can't
114 // deduce anything about the last modifications.
115 if (!value) {
116 setToEntryState(after);
117 return success();
118 }
119
120 // If we cannot find the underlying value, we shouldn't just propagate the
121 // effects through, return the pessimistic state.
122 std::optional<Value> underlyingValue =
123 UnderlyingValueAnalysis::getMostUnderlyingValue(
124 value, getUnderlyingValueFn: [&](Value value) {
125 return getOrCreateFor<UnderlyingValueLattice>(
126 dependent: getProgramPointAfter(op), anchor: value);
127 });
128
129 // If the underlying value is not yet known, don't propagate yet.
130 if (!underlyingValue)
131 return success();
132
133 underlyingValues.push_back(Elt: *underlyingValue);
134 }
135
136 // Update the state when all underlying values are known.
137 ChangeResult result = after->join(lattice: before);
138 for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) {
139 // If the underlying value is known to be unknown, set to fixpoint state.
140 if (!value) {
141 setToEntryState(after);
142 return success();
143 }
144
145 // Nothing to do for reads.
146 if (isa<MemoryEffects::Read>(Val: effect.getEffect()))
147 continue;
148
149 result |= after->set(value, op);
150 }
151 propagateIfChanged(state: after, changed: result);
152 return success();
153}
154
155void LastModifiedAnalysis::buildOperationEquivalentLatticeAnchor(
156 Operation *op) {
157 if (isMemoryEffectFree(op)) {
158 unionLatticeAnchors<LastModification>(anchor: getProgramPointBefore(op),
159 other: getProgramPointAfter(op));
160 }
161}
162
163void LastModifiedAnalysis::visitCallControlFlowTransfer(
164 CallOpInterface call, CallControlFlowAction action,
165 const LastModification &before, LastModification *after) {
166 if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) {
167 SmallVector<Value> underlyingValues;
168 underlyingValues.reserve(N: call->getNumOperands());
169 for (Value operand : call.getArgOperands()) {
170 std::optional<Value> underlyingValue =
171 UnderlyingValueAnalysis::getMostUnderlyingValue(
172 operand, [&](Value value) {
173 return getOrCreateFor<UnderlyingValueLattice>(
174 getProgramPointAfter(call.getOperation()), value);
175 });
176 if (!underlyingValue)
177 return;
178 underlyingValues.push_back(*underlyingValue);
179 }
180
181 ChangeResult result = after->join(lattice: before);
182 for (Value operand : underlyingValues)
183 result |= after->set(value: operand, op: call);
184 return propagateIfChanged(state: after, changed: result);
185 }
186 auto testCallAndStore =
187 dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
188 if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
189 testCallAndStore.getStoreBeforeCall()) ||
190 (action == CallControlFlowAction::ExitCallee &&
191 !testCallAndStore.getStoreBeforeCall()))) {
192 (void)visitOperation(op: call, before, after);
193 return;
194 }
195 AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer(
196 call: call, action, before, after);
197}
198
199void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
200 RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
201 std::optional<unsigned> regionTo, const LastModification &before,
202 LastModification *after) {
203 auto defaultHandling = [&]() {
204 AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
205 branch: branch, regionFrom, regionTo, before, after);
206 };
207 TypeSwitch<Operation *>(branch.getOperation())
208 .Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
209 [=](auto storeWithRegion) {
210 if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) ||
211 (!regionFrom && storeWithRegion.getStoreBeforeRegion()))
212 (void)visitOperation(branch, before, after);
213 defaultHandling();
214 })
215 .Default([=](auto) { defaultHandling(); });
216}
217
218namespace {
219struct TestLastModifiedPass
220 : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
221 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
222
223 TestLastModifiedPass() = default;
224 TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) {
225 interprocedural = other.interprocedural;
226 assumeFuncWrites = other.assumeFuncWrites;
227 }
228
229 StringRef getArgument() const override { return "test-last-modified"; }
230
231 Option<bool> interprocedural{
232 *this, "interprocedural", llvm::cl::init(Val: true),
233 llvm::cl::desc("perform interprocedural analysis")};
234 Option<bool> assumeFuncWrites{
235 *this, "assume-func-writes", llvm::cl::init(Val: false),
236 llvm::cl::desc(
237 "assume external functions have write effect on all arguments")};
238
239 void runOnOperation() override {
240 Operation *op = getOperation();
241
242 DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
243 solver.load<DeadCodeAnalysis>();
244 solver.load<SparseConstantPropagation>();
245 solver.load<LastModifiedAnalysis>(args&: assumeFuncWrites);
246 solver.load<UnderlyingValueAnalysis>();
247 if (failed(Result: solver.initializeAndRun(top: op)))
248 return signalPassFailure();
249
250 raw_ostream &os = llvm::errs();
251
252 // Note that if the underlying value could not be computed or is unknown, we
253 // conservatively treat the result also unknown.
254 op->walk(callback: [&](Operation *op) {
255 auto tag = op->getAttrOfType<StringAttr>("tag");
256 if (!tag)
257 return;
258 os << "test_tag: " << tag.getValue() << ":\n";
259 const LastModification *lastMods =
260 solver.lookupState<LastModification>(anchor: solver.getProgramPointAfter(op));
261 assert(lastMods && "expected a dense lattice");
262 for (auto [index, operand] : llvm::enumerate(First: op->getOperands())) {
263 os << " operand #" << index << "\n";
264 std::optional<Value> underlyingValue =
265 UnderlyingValueAnalysis::getMostUnderlyingValue(
266 value: operand, getUnderlyingValueFn: [&](Value value) {
267 return solver.lookupState<UnderlyingValueLattice>(anchor: value);
268 });
269 if (!underlyingValue) {
270 os << " - <unknown>\n";
271 continue;
272 }
273 Value value = *underlyingValue;
274 assert(value && "expected an underlying value");
275 if (const AdjacentAccess *lastMod =
276 lastMods->getAdjacentAccess(value)) {
277 if (!lastMod->isKnown()) {
278 os << " - <unknown>\n";
279 } else {
280 for (Operation *lastModifier : lastMod->get()) {
281 if (auto tagName =
282 lastModifier->getAttrOfType<StringAttr>("tag_name")) {
283 os << " - " << tagName.getValue() << "\n";
284 } else {
285 os << " - " << lastModifier->getName() << "\n";
286 }
287 }
288 }
289 } else {
290 os << " - <unknown>\n";
291 }
292 }
293 });
294 }
295};
296} // end anonymous namespace
297
298namespace mlir {
299namespace test {
300void registerTestLastModifiedPass() {
301 PassRegistration<TestLastModifiedPass>();
302}
303} // end namespace test
304} // end namespace mlir
305

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp