1//===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Test pass for backward dense dataflow analysis.
10//
11//===----------------------------------------------------------------------===//
12
13#include "TestDenseDataFlowAnalysis.h"
14#include "TestDialect.h"
15#include "TestOps.h"
16#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
17#include "mlir/Analysis/DataFlow/Utils.h"
18#include "mlir/Analysis/DataFlowFramework.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/SymbolTable.h"
21#include "mlir/Interfaces/CallInterfaces.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Interfaces/SideEffectInterfaces.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Support/TypeID.h"
26#include "llvm/Support/raw_ostream.h"
27
28using namespace mlir;
29using namespace mlir::dataflow;
30using namespace mlir::dataflow::test;
31
32namespace {
33
34class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
35public:
36 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess)
37
38 using dataflow::AbstractDenseLattice::AbstractDenseLattice;
39
40 ChangeResult meet(const AbstractDenseLattice &lattice) override {
41 return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>(
42 static_cast<const NextAccess &>(lattice)));
43 }
44
45 void print(raw_ostream &os) const override {
46 return AccessLatticeBase::print(os);
47 }
48};
49
50class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
51public:
52 NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
53 bool assumeFuncReads = false)
54 : DenseBackwardDataFlowAnalysis(solver, symbolTable),
55 assumeFuncReads(assumeFuncReads) {}
56
57 LogicalResult visitOperation(Operation *op, const NextAccess &after,
58 NextAccess *before) override;
59
60 void visitCallControlFlowTransfer(CallOpInterface call,
61 CallControlFlowAction action,
62 const NextAccess &after,
63 NextAccess *before) override;
64
65 void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
66 RegionBranchPoint regionFrom,
67 RegionBranchPoint regionTo,
68 const NextAccess &after,
69 NextAccess *before) override;
70
71 // TODO: this isn't ideal for the analysis. When there is no next access, it
72 // means "we don't know what the next access is" rather than "there is no next
73 // access". But it's unclear how to differentiate the two cases...
74 void setToExitState(NextAccess *lattice) override {
75 propagateIfChanged(state: lattice, changed: lattice->setKnownToUnknown());
76 }
77
78 /// Visit an operation. If this analysis can confirm that lattice content
79 /// of lattice anchors around operation are necessarily identical, join
80 /// them into the same equivalent class.
81 void buildOperationEquivalentLatticeAnchor(Operation *op) override;
82
83 const bool assumeFuncReads;
84};
85} // namespace
86
87LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
88 const NextAccess &after,
89 NextAccess *before) {
90 auto memory = dyn_cast<MemoryEffectOpInterface>(Val: op);
91 // If we can't reason about the memory effects, conservatively assume we can't
92 // say anything about the next access.
93 if (!memory) {
94 setToExitState(before);
95 return success();
96 }
97
98 SmallVector<MemoryEffects::EffectInstance> effects;
99 memory.getEffects(effects);
100
101 // First, check if all underlying values are already known. Otherwise, avoid
102 // propagating and stay in the "undefined" state to avoid incorrectly
103 // propagating values that may be overwritten later on as that could be
104 // problematic for convergence based on monotonicity of lattice updates.
105 SmallVector<Value> underlyingValues;
106 underlyingValues.reserve(N: effects.size());
107 for (const MemoryEffects::EffectInstance &effect : effects) {
108 Value value = effect.getValue();
109
110 // Effects with unspecified value are treated conservatively and we cannot
111 // assume anything about the next access.
112 if (!value) {
113 setToExitState(before);
114 return success();
115 }
116
117 // If cannot find the most underlying value, we cannot assume anything about
118 // the next accesses.
119 std::optional<Value> underlyingValue =
120 UnderlyingValueAnalysis::getMostUnderlyingValue(
121 value, getUnderlyingValueFn: [&](Value value) {
122 return getOrCreateFor<UnderlyingValueLattice>(
123 dependent: getProgramPointBefore(op), anchor: value);
124 });
125
126 // If the underlying value is not known yet, don't propagate.
127 if (!underlyingValue)
128 return success();
129
130 underlyingValues.push_back(Elt: *underlyingValue);
131 }
132
133 // Update the state if all underlying values are known.
134 ChangeResult result = before->meet(lattice: after);
135 for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) {
136 // If the underlying value is known to be unknown, set to fixpoint.
137 if (!value) {
138 setToExitState(before);
139 return success();
140 }
141
142 result |= before->set(value, op);
143 }
144 propagateIfChanged(state: before, changed: result);
145 return success();
146}
147
148void NextAccessAnalysis::buildOperationEquivalentLatticeAnchor(Operation *op) {
149 if (isMemoryEffectFree(op)) {
150 unionLatticeAnchors<NextAccess>(anchor: getProgramPointBefore(op),
151 other: getProgramPointAfter(op));
152 }
153}
154
155void NextAccessAnalysis::visitCallControlFlowTransfer(
156 CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
157 NextAccess *before) {
158 if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) {
159 SmallVector<Value> underlyingValues;
160 underlyingValues.reserve(N: call->getNumOperands());
161 for (Value operand : call.getArgOperands()) {
162 std::optional<Value> underlyingValue =
163 UnderlyingValueAnalysis::getMostUnderlyingValue(
164 value: operand, getUnderlyingValueFn: [&](Value value) {
165 return getOrCreateFor<UnderlyingValueLattice>(
166 dependent: getProgramPointBefore(op: call.getOperation()), anchor: value);
167 });
168 if (!underlyingValue)
169 return;
170 underlyingValues.push_back(Elt: *underlyingValue);
171 }
172
173 ChangeResult result = before->meet(lattice: after);
174 for (Value operand : underlyingValues) {
175 result |= before->set(value: operand, op: call);
176 }
177 return propagateIfChanged(state: before, changed: result);
178 }
179 auto testCallAndStore =
180 dyn_cast<::test::TestCallAndStoreOp>(Val: call.getOperation());
181 if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
182 testCallAndStore.getStoreBeforeCall()) ||
183 (action == CallControlFlowAction::ExitCallee &&
184 !testCallAndStore.getStoreBeforeCall()))) {
185 (void)visitOperation(op: call, after, before);
186 } else {
187 AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
188 call, action, after, before);
189 }
190}
191
192void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
193 RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
194 RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
195 auto testStoreWithARegion =
196 dyn_cast<::test::TestStoreWithARegion>(Val: branch.getOperation());
197
198 if (testStoreWithARegion &&
199 ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
200 (regionFrom.isParent() &&
201 testStoreWithARegion.getStoreBeforeRegion()))) {
202 (void)visitOperation(op: branch, after: static_cast<const NextAccess &>(after),
203 before: static_cast<NextAccess *>(before));
204 } else {
205 propagateIfChanged(state: before, changed: before->meet(lattice: after));
206 }
207}
208
209namespace {
210struct TestNextAccessPass
211 : public PassWrapper<TestNextAccessPass, OperationPass<>> {
212 TestNextAccessPass() = default;
213 TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) {
214 interprocedural = other.interprocedural;
215 assumeFuncReads = other.assumeFuncReads;
216 }
217
218 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass)
219
220 StringRef getArgument() const override { return "test-next-access"; }
221
222 Option<bool> interprocedural{
223 *this, "interprocedural", llvm::cl::init(Val: true),
224 llvm::cl::desc("perform interprocedural analysis")};
225 Option<bool> assumeFuncReads{
226 *this, "assume-func-reads", llvm::cl::init(Val: false),
227 llvm::cl::desc(
228 "assume external functions have read effect on all arguments")};
229
230 static constexpr llvm::StringLiteral kTagAttrName = "name";
231 static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access";
232 static constexpr llvm::StringLiteral kAtEntryPointAttrName =
233 "next_at_entry_point";
234
235 static Attribute makeNextAccessAttribute(Operation *op,
236 const DataFlowSolver &solver,
237 const NextAccess *nextAccess) {
238 if (!nextAccess)
239 return StringAttr::get(context: op->getContext(), bytes: "not computed");
240
241 // Note that if the underlying value could not be computed or is unknown, we
242 // conservatively treat the result also unknown.
243 SmallVector<Attribute> attrs;
244 for (Value operand : op->getOperands()) {
245 std::optional<Value> underlyingValue =
246 UnderlyingValueAnalysis::getMostUnderlyingValue(
247 value: operand, getUnderlyingValueFn: [&](Value value) {
248 return solver.lookupState<UnderlyingValueLattice>(anchor: value);
249 });
250 if (!underlyingValue) {
251 attrs.push_back(Elt: StringAttr::get(context: op->getContext(), bytes: "unknown"));
252 continue;
253 }
254 Value value = *underlyingValue;
255 const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value);
256 if (!nextAcc || !nextAcc->isKnown()) {
257 attrs.push_back(Elt: StringAttr::get(context: op->getContext(), bytes: "unknown"));
258 continue;
259 }
260
261 SmallVector<Attribute> innerAttrs;
262 innerAttrs.reserve(N: nextAcc->get().size());
263 for (Operation *nextAccOp : nextAcc->get()) {
264 if (auto nextAccTag =
265 nextAccOp->getAttrOfType<StringAttr>(name: kTagAttrName)) {
266 innerAttrs.push_back(Elt: nextAccTag);
267 continue;
268 }
269 std::string repr;
270 llvm::raw_string_ostream os(repr);
271 nextAccOp->print(os);
272 innerAttrs.push_back(Elt: StringAttr::get(context: op->getContext(), bytes: os.str()));
273 }
274 attrs.push_back(Elt: ArrayAttr::get(context: op->getContext(), value: innerAttrs));
275 }
276 return ArrayAttr::get(context: op->getContext(), value: attrs);
277 }
278
279 void runOnOperation() override {
280 Operation *op = getOperation();
281 SymbolTableCollection symbolTable;
282
283 auto config = DataFlowConfig().setInterprocedural(interprocedural);
284 DataFlowSolver solver(config);
285 loadBaselineAnalyses(solver);
286 solver.load<NextAccessAnalysis>(args&: symbolTable, args&: assumeFuncReads);
287 solver.load<UnderlyingValueAnalysis>();
288 if (failed(Result: solver.initializeAndRun(top: op))) {
289 emitError(loc: op->getLoc(), message: "dataflow solver failed");
290 return signalPassFailure();
291 }
292 op->walk(callback: [&](Operation *op) {
293 auto tag = op->getAttrOfType<StringAttr>(name: kTagAttrName);
294 if (!tag)
295 return;
296
297 const NextAccess *nextAccess =
298 solver.lookupState<NextAccess>(anchor: solver.getProgramPointAfter(op));
299 op->setAttr(name: kNextAccessAttrName,
300 value: makeNextAccessAttribute(op, solver, nextAccess));
301
302 auto iface = dyn_cast<RegionBranchOpInterface>(Val: op);
303 if (!iface)
304 return;
305
306 SmallVector<Attribute> entryPointNextAccess;
307 SmallVector<RegionSuccessor> regionSuccessors;
308 iface.getSuccessorRegions(point: RegionBranchPoint::parent(), regions&: regionSuccessors);
309 for (const RegionSuccessor &successor : regionSuccessors) {
310 if (!successor.getSuccessor() || successor.getSuccessor()->empty())
311 continue;
312 Block &successorBlock = successor.getSuccessor()->front();
313 ProgramPoint *successorPoint =
314 solver.getProgramPointBefore(block: &successorBlock);
315 entryPointNextAccess.push_back(Elt: makeNextAccessAttribute(
316 op, solver, nextAccess: solver.lookupState<NextAccess>(anchor: successorPoint)));
317 }
318 op->setAttr(name: kAtEntryPointAttrName,
319 value: ArrayAttr::get(context: op->getContext(), value: entryPointNextAccess));
320 });
321 }
322};
323} // namespace
324
325namespace mlir::test {
326void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); }
327} // namespace mlir::test
328

// Source: mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp