1//===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Test pass for backward dense dataflow analysis.
10//
11//===----------------------------------------------------------------------===//
12
13#include "TestDenseDataFlowAnalysis.h"
14#include "TestDialect.h"
15#include "TestOps.h"
16#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
17#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
18#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
19#include "mlir/Analysis/DataFlowFramework.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/SymbolTable.h"
22#include "mlir/Interfaces/CallInterfaces.h"
23#include "mlir/Interfaces/ControlFlowInterfaces.h"
24#include "mlir/Interfaces/SideEffectInterfaces.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Support/TypeID.h"
27#include "llvm/Support/raw_ostream.h"
28
29using namespace mlir;
30using namespace mlir::dataflow;
31using namespace mlir::dataflow::test;
32
33namespace {
34
35class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
36public:
37 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess)
38
39 using dataflow::AbstractDenseLattice::AbstractDenseLattice;
40
41 ChangeResult meet(const AbstractDenseLattice &lattice) override {
42 return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>(
43 static_cast<const NextAccess &>(lattice)));
44 }
45
46 void print(raw_ostream &os) const override {
47 return AccessLatticeBase::print(os);
48 }
49};
50
51class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
52public:
53 NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
54 bool assumeFuncReads = false)
55 : DenseBackwardDataFlowAnalysis(solver, symbolTable),
56 assumeFuncReads(assumeFuncReads) {}
57
58 LogicalResult visitOperation(Operation *op, const NextAccess &after,
59 NextAccess *before) override;
60
61 void visitCallControlFlowTransfer(CallOpInterface call,
62 CallControlFlowAction action,
63 const NextAccess &after,
64 NextAccess *before) override;
65
66 void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
67 RegionBranchPoint regionFrom,
68 RegionBranchPoint regionTo,
69 const NextAccess &after,
70 NextAccess *before) override;
71
72 // TODO: this isn't ideal for the analysis. When there is no next access, it
73 // means "we don't know what the next access is" rather than "there is no next
74 // access". But it's unclear how to differentiate the two cases...
75 void setToExitState(NextAccess *lattice) override {
76 propagateIfChanged(state: lattice, changed: lattice->setKnownToUnknown());
77 }
78
79 /// Visit an operation. If this analysis can confirm that lattice content
80 /// of lattice anchors around operation are necessarily identical, join
81 /// them into the same equivalent class.
82 void buildOperationEquivalentLatticeAnchor(Operation *op) override;
83
84 const bool assumeFuncReads;
85};
86} // namespace
87
88LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
89 const NextAccess &after,
90 NextAccess *before) {
91 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
92 // If we can't reason about the memory effects, conservatively assume we can't
93 // say anything about the next access.
94 if (!memory) {
95 setToExitState(before);
96 return success();
97 }
98
99 SmallVector<MemoryEffects::EffectInstance> effects;
100 memory.getEffects(effects);
101
102 // First, check if all underlying values are already known. Otherwise, avoid
103 // propagating and stay in the "undefined" state to avoid incorrectly
104 // propagating values that may be overwritten later on as that could be
105 // problematic for convergence based on monotonicity of lattice updates.
106 SmallVector<Value> underlyingValues;
107 underlyingValues.reserve(N: effects.size());
108 for (const MemoryEffects::EffectInstance &effect : effects) {
109 Value value = effect.getValue();
110
111 // Effects with unspecified value are treated conservatively and we cannot
112 // assume anything about the next access.
113 if (!value) {
114 setToExitState(before);
115 return success();
116 }
117
118 // If cannot find the most underlying value, we cannot assume anything about
119 // the next accesses.
120 std::optional<Value> underlyingValue =
121 UnderlyingValueAnalysis::getMostUnderlyingValue(
122 value, getUnderlyingValueFn: [&](Value value) {
123 return getOrCreateFor<UnderlyingValueLattice>(
124 dependent: getProgramPointBefore(op), anchor: value);
125 });
126
127 // If the underlying value is not known yet, don't propagate.
128 if (!underlyingValue)
129 return success();
130
131 underlyingValues.push_back(Elt: *underlyingValue);
132 }
133
134 // Update the state if all underlying values are known.
135 ChangeResult result = before->meet(lattice: after);
136 for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) {
137 // If the underlying value is known to be unknown, set to fixpoint.
138 if (!value) {
139 setToExitState(before);
140 return success();
141 }
142
143 result |= before->set(value, op);
144 }
145 propagateIfChanged(state: before, changed: result);
146 return success();
147}
148
149void NextAccessAnalysis::buildOperationEquivalentLatticeAnchor(Operation *op) {
150 if (isMemoryEffectFree(op)) {
151 unionLatticeAnchors<NextAccess>(anchor: getProgramPointBefore(op),
152 other: getProgramPointAfter(op));
153 }
154}
155
156void NextAccessAnalysis::visitCallControlFlowTransfer(
157 CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
158 NextAccess *before) {
159 if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) {
160 SmallVector<Value> underlyingValues;
161 underlyingValues.reserve(N: call->getNumOperands());
162 for (Value operand : call.getArgOperands()) {
163 std::optional<Value> underlyingValue =
164 UnderlyingValueAnalysis::getMostUnderlyingValue(
165 operand, [&](Value value) {
166 return getOrCreateFor<UnderlyingValueLattice>(
167 getProgramPointBefore(call.getOperation()), value);
168 });
169 if (!underlyingValue)
170 return;
171 underlyingValues.push_back(*underlyingValue);
172 }
173
174 ChangeResult result = before->meet(lattice: after);
175 for (Value operand : underlyingValues) {
176 result |= before->set(value: operand, op: call);
177 }
178 return propagateIfChanged(state: before, changed: result);
179 }
180 auto testCallAndStore =
181 dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
182 if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
183 testCallAndStore.getStoreBeforeCall()) ||
184 (action == CallControlFlowAction::ExitCallee &&
185 !testCallAndStore.getStoreBeforeCall()))) {
186 (void)visitOperation(op: call, after, before);
187 } else {
188 AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
189 call: call, action, after, before);
190 }
191}
192
193void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
194 RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
195 RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
196 auto testStoreWithARegion =
197 dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
198
199 if (testStoreWithARegion &&
200 ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
201 (regionFrom.isParent() &&
202 testStoreWithARegion.getStoreBeforeRegion()))) {
203 (void)visitOperation(op: branch, after: static_cast<const NextAccess &>(after),
204 before: static_cast<NextAccess *>(before));
205 } else {
206 propagateIfChanged(state: before, changed: before->meet(lattice: after));
207 }
208}
209
210namespace {
211struct TestNextAccessPass
212 : public PassWrapper<TestNextAccessPass, OperationPass<>> {
213 TestNextAccessPass() = default;
214 TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) {
215 interprocedural = other.interprocedural;
216 assumeFuncReads = other.assumeFuncReads;
217 }
218
219 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass)
220
221 StringRef getArgument() const override { return "test-next-access"; }
222
223 Option<bool> interprocedural{
224 *this, "interprocedural", llvm::cl::init(Val: true),
225 llvm::cl::desc("perform interprocedural analysis")};
226 Option<bool> assumeFuncReads{
227 *this, "assume-func-reads", llvm::cl::init(Val: false),
228 llvm::cl::desc(
229 "assume external functions have read effect on all arguments")};
230
231 static constexpr llvm::StringLiteral kTagAttrName = "name";
232 static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access";
233 static constexpr llvm::StringLiteral kAtEntryPointAttrName =
234 "next_at_entry_point";
235
236 static Attribute makeNextAccessAttribute(Operation *op,
237 const DataFlowSolver &solver,
238 const NextAccess *nextAccess) {
239 if (!nextAccess)
240 return StringAttr::get(op->getContext(), "not computed");
241
242 // Note that if the underlying value could not be computed or is unknown, we
243 // conservatively treat the result also unknown.
244 SmallVector<Attribute> attrs;
245 for (Value operand : op->getOperands()) {
246 std::optional<Value> underlyingValue =
247 UnderlyingValueAnalysis::getMostUnderlyingValue(
248 value: operand, getUnderlyingValueFn: [&](Value value) {
249 return solver.lookupState<UnderlyingValueLattice>(anchor: value);
250 });
251 if (!underlyingValue) {
252 attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
253 continue;
254 }
255 Value value = *underlyingValue;
256 const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value);
257 if (!nextAcc || !nextAcc->isKnown()) {
258 attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
259 continue;
260 }
261
262 SmallVector<Attribute> innerAttrs;
263 innerAttrs.reserve(N: nextAcc->get().size());
264 for (Operation *nextAccOp : nextAcc->get()) {
265 if (auto nextAccTag =
266 nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) {
267 innerAttrs.push_back(Elt: nextAccTag);
268 continue;
269 }
270 std::string repr;
271 llvm::raw_string_ostream os(repr);
272 nextAccOp->print(os);
273 innerAttrs.push_back(StringAttr::get(op->getContext(), os.str()));
274 }
275 attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs));
276 }
277 return ArrayAttr::get(op->getContext(), attrs);
278 }
279
280 void runOnOperation() override {
281 Operation *op = getOperation();
282 SymbolTableCollection symbolTable;
283
284 auto config = DataFlowConfig().setInterprocedural(interprocedural);
285 DataFlowSolver solver(config);
286 solver.load<DeadCodeAnalysis>();
287 solver.load<NextAccessAnalysis>(args&: symbolTable, args&: assumeFuncReads);
288 solver.load<SparseConstantPropagation>();
289 solver.load<UnderlyingValueAnalysis>();
290 if (failed(Result: solver.initializeAndRun(top: op))) {
291 emitError(loc: op->getLoc(), message: "dataflow solver failed");
292 return signalPassFailure();
293 }
294 op->walk(callback: [&](Operation *op) {
295 auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
296 if (!tag)
297 return;
298
299 const NextAccess *nextAccess =
300 solver.lookupState<NextAccess>(anchor: solver.getProgramPointAfter(op));
301 op->setAttr(name: kNextAccessAttrName,
302 value: makeNextAccessAttribute(op, solver, nextAccess));
303
304 auto iface = dyn_cast<RegionBranchOpInterface>(op);
305 if (!iface)
306 return;
307
308 SmallVector<Attribute> entryPointNextAccess;
309 SmallVector<RegionSuccessor> regionSuccessors;
310 iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
311 for (const RegionSuccessor &successor : regionSuccessors) {
312 if (!successor.getSuccessor() || successor.getSuccessor()->empty())
313 continue;
314 Block &successorBlock = successor.getSuccessor()->front();
315 ProgramPoint *successorPoint =
316 solver.getProgramPointBefore(block: &successorBlock);
317 entryPointNextAccess.push_back(Elt: makeNextAccessAttribute(
318 op, solver, nextAccess: solver.lookupState<NextAccess>(anchor: successorPoint)));
319 }
320 op->setAttr(kAtEntryPointAttrName,
321 ArrayAttr::get(op->getContext(), entryPointNextAccess));
322 });
323 }
324};
325} // namespace
326
327namespace mlir::test {
328void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); }
329} // namespace mlir::test
330

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp