1//===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Test pass for backward dense dataflow analysis.
10//
11//===----------------------------------------------------------------------===//
12
13#include "TestDenseDataFlowAnalysis.h"
14#include "TestDialect.h"
15#include "TestOps.h"
16#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
18#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
19#include "mlir/Analysis/DataFlowFramework.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/SymbolTable.h"
22#include "mlir/Interfaces/CallInterfaces.h"
23#include "mlir/Interfaces/ControlFlowInterfaces.h"
24#include "mlir/Interfaces/SideEffectInterfaces.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Support/TypeID.h"
27#include "llvm/Support/raw_ostream.h"
28
29using namespace mlir;
30using namespace mlir::dataflow;
31using namespace mlir::dataflow::test;
32
33namespace {
34
35class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
36public:
37 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess)
38
39 using dataflow::AbstractDenseLattice::AbstractDenseLattice;
40
41 ChangeResult meet(const AbstractDenseLattice &lattice) override {
42 return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>(
43 static_cast<const NextAccess &>(lattice)));
44 }
45
46 void print(raw_ostream &os) const override {
47 return AccessLatticeBase::print(os);
48 }
49};
50
51class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
52public:
53 NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
54 bool assumeFuncReads = false)
55 : DenseBackwardDataFlowAnalysis(solver, symbolTable),
56 assumeFuncReads(assumeFuncReads) {}
57
58 void visitOperation(Operation *op, const NextAccess &after,
59 NextAccess *before) override;
60
61 void visitCallControlFlowTransfer(CallOpInterface call,
62 CallControlFlowAction action,
63 const NextAccess &after,
64 NextAccess *before) override;
65
66 void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
67 RegionBranchPoint regionFrom,
68 RegionBranchPoint regionTo,
69 const NextAccess &after,
70 NextAccess *before) override;
71
72 // TODO: this isn't ideal for the analysis. When there is no next access, it
73 // means "we don't know what the next access is" rather than "there is no next
74 // access". But it's unclear how to differentiate the two cases...
75 void setToExitState(NextAccess *lattice) override {
76 propagateIfChanged(state: lattice, changed: lattice->setKnownToUnknown());
77 }
78
79 const bool assumeFuncReads;
80};
81} // namespace
82
83void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after,
84 NextAccess *before) {
85 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
86 // If we can't reason about the memory effects, conservatively assume we can't
87 // say anything about the next access.
88 if (!memory)
89 return setToExitState(before);
90
91 SmallVector<MemoryEffects::EffectInstance> effects;
92 memory.getEffects(effects);
93
94 // First, check if all underlying values are already known. Otherwise, avoid
95 // propagating and stay in the "undefined" state to avoid incorrectly
96 // propagating values that may be overwritten later on as that could be
97 // problematic for convergence based on monotonicity of lattice updates.
98 SmallVector<Value> underlyingValues;
99 underlyingValues.reserve(N: effects.size());
100 for (const MemoryEffects::EffectInstance &effect : effects) {
101 Value value = effect.getValue();
102
103 // Effects with unspecified value are treated conservatively and we cannot
104 // assume anything about the next access.
105 if (!value)
106 return setToExitState(before);
107
108 // If cannot find the most underlying value, we cannot assume anything about
109 // the next accesses.
110 std::optional<Value> underlyingValue =
111 UnderlyingValueAnalysis::getMostUnderlyingValue(
112 value, getUnderlyingValueFn: [&](Value value) {
113 return getOrCreateFor<UnderlyingValueLattice>(dependent: op, point: value);
114 });
115
116 // If the underlying value is not known yet, don't propagate.
117 if (!underlyingValue)
118 return;
119
120 underlyingValues.push_back(Elt: *underlyingValue);
121 }
122
123 // Update the state if all underlying values are known.
124 ChangeResult result = before->meet(lattice: after);
125 for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) {
126 // If the underlying value is known to be unknown, set to fixpoint.
127 if (!value)
128 return setToExitState(before);
129
130 result |= before->set(value, op);
131 }
132 propagateIfChanged(state: before, changed: result);
133}
134
135void NextAccessAnalysis::visitCallControlFlowTransfer(
136 CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
137 NextAccess *before) {
138 if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) {
139 SmallVector<Value> underlyingValues;
140 underlyingValues.reserve(N: call->getNumOperands());
141 for (Value operand : call.getArgOperands()) {
142 std::optional<Value> underlyingValue =
143 UnderlyingValueAnalysis::getMostUnderlyingValue(
144 operand, [&](Value value) {
145 return getOrCreateFor<UnderlyingValueLattice>(
146 call.getOperation(), value);
147 });
148 if (!underlyingValue)
149 return;
150 underlyingValues.push_back(*underlyingValue);
151 }
152
153 ChangeResult result = before->meet(lattice: after);
154 for (Value operand : underlyingValues) {
155 result |= before->set(value: operand, op: call);
156 }
157 return propagateIfChanged(state: before, changed: result);
158 }
159 auto testCallAndStore =
160 dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
161 if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
162 testCallAndStore.getStoreBeforeCall()) ||
163 (action == CallControlFlowAction::ExitCallee &&
164 !testCallAndStore.getStoreBeforeCall()))) {
165 visitOperation(op: call, after, before);
166 } else {
167 AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
168 call: call, action, after, before);
169 }
170}
171
172void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
173 RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
174 RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
175 auto testStoreWithARegion =
176 dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
177
178 if (testStoreWithARegion &&
179 ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
180 (regionFrom.isParent() &&
181 testStoreWithARegion.getStoreBeforeRegion()))) {
182 visitOperation(op: branch, after: static_cast<const NextAccess &>(after),
183 before: static_cast<NextAccess *>(before));
184 } else {
185 propagateIfChanged(state: before, changed: before->meet(lattice: after));
186 }
187}
188
189namespace {
190struct TestNextAccessPass
191 : public PassWrapper<TestNextAccessPass, OperationPass<>> {
192 TestNextAccessPass() = default;
193 TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) {
194 interprocedural = other.interprocedural;
195 assumeFuncReads = other.assumeFuncReads;
196 }
197
198 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass)
199
200 StringRef getArgument() const override { return "test-next-access"; }
201
202 Option<bool> interprocedural{
203 *this, "interprocedural", llvm::cl::init(Val: true),
204 llvm::cl::desc("perform interprocedural analysis")};
205 Option<bool> assumeFuncReads{
206 *this, "assume-func-reads", llvm::cl::init(Val: false),
207 llvm::cl::desc(
208 "assume external functions have read effect on all arguments")};
209
210 static constexpr llvm::StringLiteral kTagAttrName = "name";
211 static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access";
212 static constexpr llvm::StringLiteral kAtEntryPointAttrName =
213 "next_at_entry_point";
214
215 static Attribute makeNextAccessAttribute(Operation *op,
216 const DataFlowSolver &solver,
217 const NextAccess *nextAccess) {
218 if (!nextAccess)
219 return StringAttr::get(op->getContext(), "not computed");
220
221 // Note that if the underlying value could not be computed or is unknown, we
222 // conservatively treat the result also unknown.
223 SmallVector<Attribute> attrs;
224 for (Value operand : op->getOperands()) {
225 std::optional<Value> underlyingValue =
226 UnderlyingValueAnalysis::getMostUnderlyingValue(
227 value: operand, getUnderlyingValueFn: [&](Value value) {
228 return solver.lookupState<UnderlyingValueLattice>(point: value);
229 });
230 if (!underlyingValue) {
231 attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
232 continue;
233 }
234 Value value = *underlyingValue;
235 const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value);
236 if (!nextAcc || !nextAcc->isKnown()) {
237 attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
238 continue;
239 }
240
241 SmallVector<Attribute> innerAttrs;
242 innerAttrs.reserve(N: nextAcc->get().size());
243 for (Operation *nextAccOp : nextAcc->get()) {
244 if (auto nextAccTag =
245 nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) {
246 innerAttrs.push_back(Elt: nextAccTag);
247 continue;
248 }
249 std::string repr;
250 llvm::raw_string_ostream os(repr);
251 nextAccOp->print(os);
252 innerAttrs.push_back(StringAttr::get(op->getContext(), os.str()));
253 }
254 attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs));
255 }
256 return ArrayAttr::get(op->getContext(), attrs);
257 }
258
259 void runOnOperation() override {
260 Operation *op = getOperation();
261 SymbolTableCollection symbolTable;
262
263 auto config = DataFlowConfig().setInterprocedural(interprocedural);
264 DataFlowSolver solver(config);
265 solver.load<DeadCodeAnalysis>();
266 solver.load<NextAccessAnalysis>(args&: symbolTable, args&: assumeFuncReads);
267 solver.load<SparseConstantPropagation>();
268 solver.load<UnderlyingValueAnalysis>();
269 if (failed(result: solver.initializeAndRun(top: op))) {
270 emitError(loc: op->getLoc(), message: "dataflow solver failed");
271 return signalPassFailure();
272 }
273 op->walk(callback: [&](Operation *op) {
274 auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
275 if (!tag)
276 return;
277
278 const NextAccess *nextAccess = solver.lookupState<NextAccess>(
279 op->getNextNode() == nullptr ? ProgramPoint(op->getBlock())
280 : op->getNextNode());
281 op->setAttr(name: kNextAccessAttrName,
282 value: makeNextAccessAttribute(op, solver, nextAccess));
283
284 auto iface = dyn_cast<RegionBranchOpInterface>(op);
285 if (!iface)
286 return;
287
288 SmallVector<Attribute> entryPointNextAccess;
289 SmallVector<RegionSuccessor> regionSuccessors;
290 iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
291 for (const RegionSuccessor &successor : regionSuccessors) {
292 if (!successor.getSuccessor() || successor.getSuccessor()->empty())
293 continue;
294 Block &successorBlock = successor.getSuccessor()->front();
295 ProgramPoint successorPoint = successorBlock.empty()
296 ? ProgramPoint(&successorBlock)
297 : &successorBlock.front();
298 entryPointNextAccess.push_back(Elt: makeNextAccessAttribute(
299 op, solver, nextAccess: solver.lookupState<NextAccess>(point: successorPoint)));
300 }
301 op->setAttr(kAtEntryPointAttrName,
302 ArrayAttr::get(op->getContext(), entryPointNextAccess));
303 });
304 }
305};
306} // namespace
307
308namespace mlir::test {
309void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); }
310} // namespace mlir::test
311

source code of mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp