1 | //===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Test pass for backward dense dataflow analysis. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "TestDenseDataFlowAnalysis.h" |
14 | #include "TestDialect.h" |
15 | #include "TestOps.h" |
16 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
17 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
18 | #include "mlir/Analysis/DataFlow/DenseAnalysis.h" |
19 | #include "mlir/Analysis/DataFlowFramework.h" |
20 | #include "mlir/IR/Builders.h" |
21 | #include "mlir/IR/SymbolTable.h" |
22 | #include "mlir/Interfaces/CallInterfaces.h" |
23 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
24 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
25 | #include "mlir/Pass/Pass.h" |
26 | #include "mlir/Support/TypeID.h" |
27 | #include "llvm/Support/raw_ostream.h" |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::dataflow; |
31 | using namespace mlir::dataflow::test; |
32 | |
33 | namespace { |
34 | |
35 | class NextAccess : public AbstractDenseLattice, public AccessLatticeBase { |
36 | public: |
37 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess) |
38 | |
39 | using dataflow::AbstractDenseLattice::AbstractDenseLattice; |
40 | |
41 | ChangeResult meet(const AbstractDenseLattice &lattice) override { |
42 | return AccessLatticeBase::merge(rhs: static_cast<AccessLatticeBase>( |
43 | static_cast<const NextAccess &>(lattice))); |
44 | } |
45 | |
46 | void print(raw_ostream &os) const override { |
47 | return AccessLatticeBase::print(os); |
48 | } |
49 | }; |
50 | |
51 | class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> { |
52 | public: |
53 | NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, |
54 | bool assumeFuncReads = false) |
55 | : DenseBackwardDataFlowAnalysis(solver, symbolTable), |
56 | assumeFuncReads(assumeFuncReads) {} |
57 | |
58 | LogicalResult visitOperation(Operation *op, const NextAccess &after, |
59 | NextAccess *before) override; |
60 | |
61 | void visitCallControlFlowTransfer(CallOpInterface call, |
62 | CallControlFlowAction action, |
63 | const NextAccess &after, |
64 | NextAccess *before) override; |
65 | |
66 | void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, |
67 | RegionBranchPoint regionFrom, |
68 | RegionBranchPoint regionTo, |
69 | const NextAccess &after, |
70 | NextAccess *before) override; |
71 | |
72 | // TODO: this isn't ideal for the analysis. When there is no next access, it |
73 | // means "we don't know what the next access is" rather than "there is no next |
74 | // access". But it's unclear how to differentiate the two cases... |
75 | void setToExitState(NextAccess *lattice) override { |
76 | propagateIfChanged(state: lattice, changed: lattice->setKnownToUnknown()); |
77 | } |
78 | |
79 | /// Visit an operation. If this analysis can confirm that lattice content |
80 | /// of lattice anchors around operation are necessarily identical, join |
81 | /// them into the same equivalent class. |
82 | void buildOperationEquivalentLatticeAnchor(Operation *op) override; |
83 | |
84 | const bool assumeFuncReads; |
85 | }; |
86 | } // namespace |
87 | |
88 | LogicalResult NextAccessAnalysis::visitOperation(Operation *op, |
89 | const NextAccess &after, |
90 | NextAccess *before) { |
91 | auto memory = dyn_cast<MemoryEffectOpInterface>(op); |
92 | // If we can't reason about the memory effects, conservatively assume we can't |
93 | // say anything about the next access. |
94 | if (!memory) { |
95 | setToExitState(before); |
96 | return success(); |
97 | } |
98 | |
99 | SmallVector<MemoryEffects::EffectInstance> effects; |
100 | memory.getEffects(effects); |
101 | |
102 | // First, check if all underlying values are already known. Otherwise, avoid |
103 | // propagating and stay in the "undefined" state to avoid incorrectly |
104 | // propagating values that may be overwritten later on as that could be |
105 | // problematic for convergence based on monotonicity of lattice updates. |
106 | SmallVector<Value> underlyingValues; |
107 | underlyingValues.reserve(N: effects.size()); |
108 | for (const MemoryEffects::EffectInstance &effect : effects) { |
109 | Value value = effect.getValue(); |
110 | |
111 | // Effects with unspecified value are treated conservatively and we cannot |
112 | // assume anything about the next access. |
113 | if (!value) { |
114 | setToExitState(before); |
115 | return success(); |
116 | } |
117 | |
118 | // If cannot find the most underlying value, we cannot assume anything about |
119 | // the next accesses. |
120 | std::optional<Value> underlyingValue = |
121 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
122 | value, getUnderlyingValueFn: [&](Value value) { |
123 | return getOrCreateFor<UnderlyingValueLattice>( |
124 | dependent: getProgramPointBefore(op), anchor: value); |
125 | }); |
126 | |
127 | // If the underlying value is not known yet, don't propagate. |
128 | if (!underlyingValue) |
129 | return success(); |
130 | |
131 | underlyingValues.push_back(Elt: *underlyingValue); |
132 | } |
133 | |
134 | // Update the state if all underlying values are known. |
135 | ChangeResult result = before->meet(lattice: after); |
136 | for (const auto &[effect, value] : llvm::zip(t&: effects, u&: underlyingValues)) { |
137 | // If the underlying value is known to be unknown, set to fixpoint. |
138 | if (!value) { |
139 | setToExitState(before); |
140 | return success(); |
141 | } |
142 | |
143 | result |= before->set(value, op); |
144 | } |
145 | propagateIfChanged(state: before, changed: result); |
146 | return success(); |
147 | } |
148 | |
149 | void NextAccessAnalysis::buildOperationEquivalentLatticeAnchor(Operation *op) { |
150 | if (isMemoryEffectFree(op)) { |
151 | unionLatticeAnchors<NextAccess>(anchor: getProgramPointBefore(op), |
152 | other: getProgramPointAfter(op)); |
153 | } |
154 | } |
155 | |
156 | void NextAccessAnalysis::visitCallControlFlowTransfer( |
157 | CallOpInterface call, CallControlFlowAction action, const NextAccess &after, |
158 | NextAccess *before) { |
159 | if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) { |
160 | SmallVector<Value> underlyingValues; |
161 | underlyingValues.reserve(N: call->getNumOperands()); |
162 | for (Value operand : call.getArgOperands()) { |
163 | std::optional<Value> underlyingValue = |
164 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
165 | operand, [&](Value value) { |
166 | return getOrCreateFor<UnderlyingValueLattice>( |
167 | getProgramPointBefore(call.getOperation()), value); |
168 | }); |
169 | if (!underlyingValue) |
170 | return; |
171 | underlyingValues.push_back(*underlyingValue); |
172 | } |
173 | |
174 | ChangeResult result = before->meet(lattice: after); |
175 | for (Value operand : underlyingValues) { |
176 | result |= before->set(value: operand, op: call); |
177 | } |
178 | return propagateIfChanged(state: before, changed: result); |
179 | } |
180 | auto testCallAndStore = |
181 | dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); |
182 | if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && |
183 | testCallAndStore.getStoreBeforeCall()) || |
184 | (action == CallControlFlowAction::ExitCallee && |
185 | !testCallAndStore.getStoreBeforeCall()))) { |
186 | (void)visitOperation(op: call, after, before); |
187 | } else { |
188 | AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer( |
189 | call: call, action, after, before); |
190 | } |
191 | } |
192 | |
193 | void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( |
194 | RegionBranchOpInterface branch, RegionBranchPoint regionFrom, |
195 | RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { |
196 | auto testStoreWithARegion = |
197 | dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); |
198 | |
199 | if (testStoreWithARegion && |
200 | ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) || |
201 | (regionFrom.isParent() && |
202 | testStoreWithARegion.getStoreBeforeRegion()))) { |
203 | (void)visitOperation(op: branch, after: static_cast<const NextAccess &>(after), |
204 | before: static_cast<NextAccess *>(before)); |
205 | } else { |
206 | propagateIfChanged(state: before, changed: before->meet(lattice: after)); |
207 | } |
208 | } |
209 | |
210 | namespace { |
211 | struct TestNextAccessPass |
212 | : public PassWrapper<TestNextAccessPass, OperationPass<>> { |
213 | TestNextAccessPass() = default; |
214 | TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) { |
215 | interprocedural = other.interprocedural; |
216 | assumeFuncReads = other.assumeFuncReads; |
217 | } |
218 | |
219 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) |
220 | |
221 | StringRef getArgument() const override { return "test-next-access"; } |
222 | |
223 | Option<bool> interprocedural{ |
224 | *this, "interprocedural", llvm::cl::init(Val: true), |
225 | llvm::cl::desc("perform interprocedural analysis")}; |
226 | Option<bool> assumeFuncReads{ |
227 | *this, "assume-func-reads", llvm::cl::init(Val: false), |
228 | llvm::cl::desc( |
229 | "assume external functions have read effect on all arguments")}; |
230 | |
231 | static constexpr llvm::StringLiteral kTagAttrName = "name"; |
232 | static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access"; |
233 | static constexpr llvm::StringLiteral kAtEntryPointAttrName = |
234 | "next_at_entry_point"; |
235 | |
236 | static Attribute makeNextAccessAttribute(Operation *op, |
237 | const DataFlowSolver &solver, |
238 | const NextAccess *nextAccess) { |
239 | if (!nextAccess) |
240 | return StringAttr::get(op->getContext(), "not computed"); |
241 | |
242 | // Note that if the underlying value could not be computed or is unknown, we |
243 | // conservatively treat the result also unknown. |
244 | SmallVector<Attribute> attrs; |
245 | for (Value operand : op->getOperands()) { |
246 | std::optional<Value> underlyingValue = |
247 | UnderlyingValueAnalysis::getMostUnderlyingValue( |
248 | value: operand, getUnderlyingValueFn: [&](Value value) { |
249 | return solver.lookupState<UnderlyingValueLattice>(anchor: value); |
250 | }); |
251 | if (!underlyingValue) { |
252 | attrs.push_back(StringAttr::get(op->getContext(), "unknown")); |
253 | continue; |
254 | } |
255 | Value value = *underlyingValue; |
256 | const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value); |
257 | if (!nextAcc || !nextAcc->isKnown()) { |
258 | attrs.push_back(StringAttr::get(op->getContext(), "unknown")); |
259 | continue; |
260 | } |
261 | |
262 | SmallVector<Attribute> innerAttrs; |
263 | innerAttrs.reserve(N: nextAcc->get().size()); |
264 | for (Operation *nextAccOp : nextAcc->get()) { |
265 | if (auto nextAccTag = |
266 | nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) { |
267 | innerAttrs.push_back(Elt: nextAccTag); |
268 | continue; |
269 | } |
270 | std::string repr; |
271 | llvm::raw_string_ostream os(repr); |
272 | nextAccOp->print(os); |
273 | innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); |
274 | } |
275 | attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); |
276 | } |
277 | return ArrayAttr::get(op->getContext(), attrs); |
278 | } |
279 | |
280 | void runOnOperation() override { |
281 | Operation *op = getOperation(); |
282 | SymbolTableCollection symbolTable; |
283 | |
284 | auto config = DataFlowConfig().setInterprocedural(interprocedural); |
285 | DataFlowSolver solver(config); |
286 | solver.load<DeadCodeAnalysis>(); |
287 | solver.load<NextAccessAnalysis>(args&: symbolTable, args&: assumeFuncReads); |
288 | solver.load<SparseConstantPropagation>(); |
289 | solver.load<UnderlyingValueAnalysis>(); |
290 | if (failed(Result: solver.initializeAndRun(top: op))) { |
291 | emitError(loc: op->getLoc(), message: "dataflow solver failed"); |
292 | return signalPassFailure(); |
293 | } |
294 | op->walk(callback: [&](Operation *op) { |
295 | auto tag = op->getAttrOfType<StringAttr>(kTagAttrName); |
296 | if (!tag) |
297 | return; |
298 | |
299 | const NextAccess *nextAccess = |
300 | solver.lookupState<NextAccess>(anchor: solver.getProgramPointAfter(op)); |
301 | op->setAttr(name: kNextAccessAttrName, |
302 | value: makeNextAccessAttribute(op, solver, nextAccess)); |
303 | |
304 | auto iface = dyn_cast<RegionBranchOpInterface>(op); |
305 | if (!iface) |
306 | return; |
307 | |
308 | SmallVector<Attribute> entryPointNextAccess; |
309 | SmallVector<RegionSuccessor> regionSuccessors; |
310 | iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors); |
311 | for (const RegionSuccessor &successor : regionSuccessors) { |
312 | if (!successor.getSuccessor() || successor.getSuccessor()->empty()) |
313 | continue; |
314 | Block &successorBlock = successor.getSuccessor()->front(); |
315 | ProgramPoint *successorPoint = |
316 | solver.getProgramPointBefore(block: &successorBlock); |
317 | entryPointNextAccess.push_back(Elt: makeNextAccessAttribute( |
318 | op, solver, nextAccess: solver.lookupState<NextAccess>(anchor: successorPoint))); |
319 | } |
320 | op->setAttr(kAtEntryPointAttrName, |
321 | ArrayAttr::get(op->getContext(), entryPointNextAccess)); |
322 | }); |
323 | } |
324 | }; |
325 | } // namespace |
326 | |
327 | namespace mlir::test { |
328 | void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); } |
329 | } // namespace mlir::test |
330 |
Definitions
- NextAccess
- meet
- NextAccessAnalysis
- NextAccessAnalysis
- setToExitState
- visitOperation
- buildOperationEquivalentLatticeAnchor
- visitCallControlFlowTransfer
- visitRegionBranchControlFlowTransfer
- TestNextAccessPass
- TestNextAccessPass
- TestNextAccessPass
- getArgument
- kTagAttrName
- kNextAccessAttrName
- kAtEntryPointAttrName
- makeNextAccessAttribute
- runOnOperation
Improve your Profiling and Debugging skills
Find out more