1 | //===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Test pass for backward dense dataflow analysis. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "TestDenseDataFlowAnalysis.h" |
14 | #include "TestDialect.h" |
15 | #include "TestOps.h" |
16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
17 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
18 | #include "mlir/Analysis/DataFlow/DenseAnalysis.h" |
19 | #include "mlir/Analysis/DataFlowFramework.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/SymbolTable.h" |
22 | #include "mlir/Interfaces/CallInterfaces.h" |
23 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
24 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
25 | #include "mlir/Pass/Pass.h" |
26 | #include "mlir/Support/TypeID.h" |
27 | #include "llvm/Support/raw_ostream.h" |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::dataflow; |
31 | using namespace mlir::dataflow::test; |
32 | |
33 | namespace { |
34 | |
35 | class NextAccess : public AbstractDenseLattice, public AccessLatticeBase { |
36 | public: |
37 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess) |
38 | |
39 | using dataflow::AbstractDenseLattice::AbstractDenseLattice; |
40 | |
41 | ChangeResult meet(const AbstractDenseLattice &lattice) override { |
42 | return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>( |
43 | static_cast<const NextAccess &>(lattice))); |
44 | } |
45 | |
46 | void print(raw_ostream &os) const override { |
47 | return AccessLatticeBase::print(os); |
48 | } |
49 | }; |
50 | |
51 | class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> { |
52 | public: |
53 | NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, |
54 | bool assumeFuncReads = false) |
55 | : DenseBackwardDataFlowAnalysis(solver, symbolTable), |
56 | assumeFuncReads(assumeFuncReads) {} |
57 | |
58 | void visitOperation(Operation *op, const NextAccess &after, |
59 | NextAccess *before) override; |
60 | |
61 | void visitCallControlFlowTransfer(CallOpInterface call, |
62 | CallControlFlowAction action, |
63 | const NextAccess &after, |
64 | NextAccess *before) override; |
65 | |
66 | void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, |
67 | RegionBranchPoint regionFrom, |
68 | RegionBranchPoint regionTo, |
69 | const NextAccess &after, |
70 | NextAccess *before) override; |
71 | |
72 | // TODO: this isn't ideal for the analysis. When there is no next access, it |
73 | // means "we don't know what the next access is" rather than "there is no next |
74 | // access". But it's unclear how to differentiate the two cases... |
75 | void setToExitState(NextAccess *lattice) override { |
76 | propagateIfChanged(state: lattice, changed: lattice->setKnownToUnknown()); |
77 | } |
78 | |
79 | const bool assumeFuncReads; |
80 | }; |
81 | } // namespace |
82 | |
83 | void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after, |
84 | NextAccess *before) { |
85 | auto memory = dyn_cast<MemoryEffectOpInterface>(op); |
86 | // If we can't reason about the memory effects, conservatively assume we can't |
87 | // say anything about the next access. |
88 | if (!memory) |
89 | return setToExitState(before); |
90 | |
91 | SmallVector<MemoryEffects::EffectInstance> effects; |
92 | memory.getEffects(effects); |
93 | |
94 | // First, check if all underlying values are already known. Otherwise, avoid |
95 | // propagating and stay in the "undefined" state to avoid incorrectly |
96 | // propagating values that may be overwritten later on as that could be |
97 | // problematic for convergence based on monotonicity of lattice updates. |
98 | SmallVector<Value> underlyingValues; |
99 | underlyingValues.reserve(N: effects.size()); |
100 | for (const MemoryEffects::EffectInstance &effect : effects) { |
101 | Value value = effect.getValue(); |
102 | |
103 | // Effects with unspecified value are treated conservatively and we cannot |
104 | // assume anything about the next access. |
105 | if (!value) |
106 | return setToExitState(before); |
107 | |
108 | // If cannot find the most underlying value, we cannot assume anything about |
109 | // the next accesses. |
110 | std::optional<Value> underlyingValue = |
111 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
112 | value, getUnderlyingValueFn: [&](Value value) { |
113 | return getOrCreateFor<UnderlyingValueLattice>(dependent: op, point: value); |
114 | }); |
115 | |
116 | // If the underlying value is not known yet, don't propagate. |
117 | if (!underlyingValue) |
118 | return; |
119 | |
120 | underlyingValues.push_back(Elt: *underlyingValue); |
121 | } |
122 | |
123 | // Update the state if all underlying values are known. |
124 | ChangeResult result = before->meet(lattice: after); |
125 | for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) { |
126 | // If the underlying value is known to be unknown, set to fixpoint. |
127 | if (!value) |
128 | return setToExitState(before); |
129 | |
130 | result |= before->set(value, op); |
131 | } |
132 | propagateIfChanged(state: before, changed: result); |
133 | } |
134 | |
135 | void NextAccessAnalysis::visitCallControlFlowTransfer( |
136 | CallOpInterface call, CallControlFlowAction action, const NextAccess &after, |
137 | NextAccess *before) { |
138 | if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) { |
139 | SmallVector<Value> underlyingValues; |
140 | underlyingValues.reserve(N: call->getNumOperands()); |
141 | for (Value operand : call.getArgOperands()) { |
142 | std::optional<Value> underlyingValue = |
143 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
144 | operand, [&](Value value) { |
145 | return getOrCreateFor<UnderlyingValueLattice>( |
146 | call.getOperation(), value); |
147 | }); |
148 | if (!underlyingValue) |
149 | return; |
150 | underlyingValues.push_back(*underlyingValue); |
151 | } |
152 | |
153 | ChangeResult result = before->meet(lattice: after); |
154 | for (Value operand : underlyingValues) { |
155 | result |= before->set(value: operand, op: call); |
156 | } |
157 | return propagateIfChanged(state: before, changed: result); |
158 | } |
159 | auto testCallAndStore = |
160 | dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); |
161 | if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && |
162 | testCallAndStore.getStoreBeforeCall()) || |
163 | (action == CallControlFlowAction::ExitCallee && |
164 | !testCallAndStore.getStoreBeforeCall()))) { |
165 | visitOperation(op: call, after, before); |
166 | } else { |
167 | AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer( |
168 | call: call, action, after, before); |
169 | } |
170 | } |
171 | |
172 | void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( |
173 | RegionBranchOpInterface branch, RegionBranchPoint regionFrom, |
174 | RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { |
175 | auto testStoreWithARegion = |
176 | dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); |
177 | |
178 | if (testStoreWithARegion && |
179 | ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) || |
180 | (regionFrom.isParent() && |
181 | testStoreWithARegion.getStoreBeforeRegion()))) { |
182 | visitOperation(op: branch, after: static_cast<const NextAccess &>(after), |
183 | before: static_cast<NextAccess *>(before)); |
184 | } else { |
185 | propagateIfChanged(state: before, changed: before->meet(lattice: after)); |
186 | } |
187 | } |
188 | |
189 | namespace { |
190 | struct TestNextAccessPass |
191 | : public PassWrapper<TestNextAccessPass, OperationPass<>> { |
192 | TestNextAccessPass() = default; |
193 | TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) { |
194 | interprocedural = other.interprocedural; |
195 | assumeFuncReads = other.assumeFuncReads; |
196 | } |
197 | |
198 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) |
199 | |
200 | StringRef getArgument() const override { return "test-next-access" ; } |
201 | |
202 | Option<bool> interprocedural{ |
203 | *this, "interprocedural" , llvm::cl::init(Val: true), |
204 | llvm::cl::desc("perform interprocedural analysis" )}; |
205 | Option<bool> assumeFuncReads{ |
206 | *this, "assume-func-reads" , llvm::cl::init(Val: false), |
207 | llvm::cl::desc( |
208 | "assume external functions have read effect on all arguments" )}; |
209 | |
210 | static constexpr llvm::StringLiteral kTagAttrName = "name" ; |
211 | static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access" ; |
212 | static constexpr llvm::StringLiteral kAtEntryPointAttrName = |
213 | "next_at_entry_point" ; |
214 | |
215 | static Attribute makeNextAccessAttribute(Operation *op, |
216 | const DataFlowSolver &solver, |
217 | const NextAccess *nextAccess) { |
218 | if (!nextAccess) |
219 | return StringAttr::get(op->getContext(), "not computed" ); |
220 | |
221 | // Note that if the underlying value could not be computed or is unknown, we |
222 | // conservatively treat the result also unknown. |
223 | SmallVector<Attribute> attrs; |
224 | for (Value operand : op->getOperands()) { |
225 | std::optional<Value> underlyingValue = |
226 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
227 | value: operand, getUnderlyingValueFn: [&](Value value) { |
228 | return solver.lookupState<UnderlyingValueLattice>(point: value); |
229 | }); |
230 | if (!underlyingValue) { |
231 | attrs.push_back(StringAttr::get(op->getContext(), "unknown" )); |
232 | continue; |
233 | } |
234 | Value value = *underlyingValue; |
235 | const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value); |
236 | if (!nextAcc || !nextAcc->isKnown()) { |
237 | attrs.push_back(StringAttr::get(op->getContext(), "unknown" )); |
238 | continue; |
239 | } |
240 | |
241 | SmallVector<Attribute> innerAttrs; |
242 | innerAttrs.reserve(N: nextAcc->get().size()); |
243 | for (Operation *nextAccOp : nextAcc->get()) { |
244 | if (auto nextAccTag = |
245 | nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) { |
246 | innerAttrs.push_back(Elt: nextAccTag); |
247 | continue; |
248 | } |
249 | std::string repr; |
250 | llvm::raw_string_ostream os(repr); |
251 | nextAccOp->print(os); |
252 | innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); |
253 | } |
254 | attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); |
255 | } |
256 | return ArrayAttr::get(op->getContext(), attrs); |
257 | } |
258 | |
259 | void runOnOperation() override { |
260 | Operation *op = getOperation(); |
261 | SymbolTableCollection symbolTable; |
262 | |
263 | auto config = DataFlowConfig().setInterprocedural(interprocedural); |
264 | DataFlowSolver solver(config); |
265 | solver.load<DeadCodeAnalysis>(); |
266 | solver.load<NextAccessAnalysis>(args&: symbolTable, args&: assumeFuncReads); |
267 | solver.load<SparseConstantPropagation>(); |
268 | solver.load<UnderlyingValueAnalysis>(); |
269 | if (failed(result: solver.initializeAndRun(top: op))) { |
270 | emitError(loc: op->getLoc(), message: "dataflow solver failed" ); |
271 | return signalPassFailure(); |
272 | } |
273 | op->walk(callback: [&](Operation *op) { |
274 | auto tag = op->getAttrOfType<StringAttr>(kTagAttrName); |
275 | if (!tag) |
276 | return; |
277 | |
278 | const NextAccess *nextAccess = solver.lookupState<NextAccess>( |
279 | op->getNextNode() == nullptr ? ProgramPoint(op->getBlock()) |
280 | : op->getNextNode()); |
281 | op->setAttr(name: kNextAccessAttrName, |
282 | value: makeNextAccessAttribute(op, solver, nextAccess)); |
283 | |
284 | auto iface = dyn_cast<RegionBranchOpInterface>(op); |
285 | if (!iface) |
286 | return; |
287 | |
288 | SmallVector<Attribute> entryPointNextAccess; |
289 | SmallVector<RegionSuccessor> regionSuccessors; |
290 | iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors); |
291 | for (const RegionSuccessor &successor : regionSuccessors) { |
292 | if (!successor.getSuccessor() || successor.getSuccessor()->empty()) |
293 | continue; |
294 | Block &successorBlock = successor.getSuccessor()->front(); |
295 | ProgramPoint successorPoint = successorBlock.empty() |
296 | ? ProgramPoint(&successorBlock) |
297 | : &successorBlock.front(); |
298 | entryPointNextAccess.push_back(Elt: makeNextAccessAttribute( |
299 | op, solver, nextAccess: solver.lookupState<NextAccess>(point: successorPoint))); |
300 | } |
301 | op->setAttr(kAtEntryPointAttrName, |
302 | ArrayAttr::get(op->getContext(), entryPointNextAccess)); |
303 | }); |
304 | } |
305 | }; |
306 | } // namespace |
307 | |
308 | namespace mlir::test { |
309 | void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); } |
310 | } // namespace mlir::test |
311 | |