1 | //===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
10 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
11 | #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
12 | #include "mlir/Analysis/DataFlowFramework.h" |
13 | #include "mlir/IR/Attributes.h" |
14 | #include "mlir/IR/Block.h" |
15 | #include "mlir/IR/Diagnostics.h" |
16 | #include "mlir/IR/Location.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/IR/SymbolTable.h" |
19 | #include "mlir/IR/Value.h" |
20 | #include "mlir/IR/ValueRange.h" |
21 | #include "mlir/Interfaces/CallInterfaces.h" |
22 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
23 | #include "mlir/Support/LLVM.h" |
24 | #include "mlir/Support/LogicalResult.h" |
25 | #include "llvm/Support/Casting.h" |
26 | #include <cassert> |
27 | #include <optional> |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::dataflow; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // Executable |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | ChangeResult Executable::setToLive() { |
37 | if (live) |
38 | return ChangeResult::NoChange; |
39 | live = true; |
40 | return ChangeResult::Change; |
41 | } |
42 | |
43 | void Executable::print(raw_ostream &os) const { |
44 | os << (live ? "live" : "dead" ); |
45 | } |
46 | |
47 | void Executable::onUpdate(DataFlowSolver *solver) const { |
48 | AnalysisState::onUpdate(solver); |
49 | |
50 | if (auto *block = llvm::dyn_cast_if_present<Block *>(Val: point)) { |
51 | // Re-invoke the analyses on the block itself. |
52 | for (DataFlowAnalysis *analysis : subscribers) |
53 | solver->enqueue(item: {block, analysis}); |
54 | // Re-invoke the analyses on all operations in the block. |
55 | for (DataFlowAnalysis *analysis : subscribers) |
56 | for (Operation &op : *block) |
57 | solver->enqueue(item: {&op, analysis}); |
58 | } else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(Val: point)) { |
59 | // Re-invoke the analysis on the successor block. |
60 | if (auto *edge = dyn_cast<CFGEdge>(Val: programPoint)) { |
61 | for (DataFlowAnalysis *analysis : subscribers) |
62 | solver->enqueue(item: {edge->getTo(), analysis}); |
63 | } |
64 | } |
65 | } |
66 | |
67 | //===----------------------------------------------------------------------===// |
68 | // PredecessorState |
69 | //===----------------------------------------------------------------------===// |
70 | |
71 | void PredecessorState::print(raw_ostream &os) const { |
72 | if (allPredecessorsKnown()) |
73 | os << "(all) " ; |
74 | os << "predecessors:\n" ; |
75 | for (Operation *op : getKnownPredecessors()) |
76 | os << " " << *op << "\n" ; |
77 | } |
78 | |
79 | ChangeResult PredecessorState::join(Operation *predecessor) { |
80 | return knownPredecessors.insert(X: predecessor) ? ChangeResult::Change |
81 | : ChangeResult::NoChange; |
82 | } |
83 | |
84 | ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) { |
85 | ChangeResult result = join(predecessor); |
86 | if (!inputs.empty()) { |
87 | ValueRange &curInputs = successorInputs[predecessor]; |
88 | if (curInputs != inputs) { |
89 | curInputs = inputs; |
90 | result |= ChangeResult::Change; |
91 | } |
92 | } |
93 | return result; |
94 | } |
95 | |
96 | //===----------------------------------------------------------------------===// |
97 | // CFGEdge |
98 | //===----------------------------------------------------------------------===// |
99 | |
100 | Location CFGEdge::getLoc() const { |
101 | return FusedLoc::get( |
102 | getFrom()->getParent()->getContext(), |
103 | {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()}); |
104 | } |
105 | |
106 | void CFGEdge::print(raw_ostream &os) const { |
107 | getFrom()->print(os); |
108 | os << "\n -> \n" ; |
109 | getTo()->print(os); |
110 | } |
111 | |
112 | //===----------------------------------------------------------------------===// |
113 | // DeadCodeAnalysis |
114 | //===----------------------------------------------------------------------===// |
115 | |
116 | DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver) |
117 | : DataFlowAnalysis(solver) { |
118 | registerPointKind<CFGEdge>(); |
119 | } |
120 | |
121 | LogicalResult DeadCodeAnalysis::initialize(Operation *top) { |
122 | // Mark the top-level blocks as executable. |
123 | for (Region ®ion : top->getRegions()) { |
124 | if (region.empty()) |
125 | continue; |
126 | auto *state = getOrCreate<Executable>(point: ®ion.front()); |
127 | propagateIfChanged(state, changed: state->setToLive()); |
128 | } |
129 | |
130 | // Mark as overdefined the predecessors of symbol callables with potentially |
131 | // unknown predecessors. |
132 | initializeSymbolCallables(top); |
133 | |
134 | return initializeRecursively(op: top); |
135 | } |
136 | |
137 | void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { |
138 | analysisScope = top; |
139 | auto walkFn = [&](Operation *symTable, bool allUsesVisible) { |
140 | Region &symbolTableRegion = symTable->getRegion(index: 0); |
141 | Block *symbolTableBlock = &symbolTableRegion.front(); |
142 | |
143 | bool foundSymbolCallable = false; |
144 | for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) { |
145 | Region *callableRegion = callable.getCallableRegion(); |
146 | if (!callableRegion) |
147 | continue; |
148 | auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation()); |
149 | if (!symbol) |
150 | continue; |
151 | |
152 | // Public symbol callables or those for which we can't see all uses have |
153 | // potentially unknown callsites. |
154 | if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { |
155 | auto *state = getOrCreate<PredecessorState>(callable); |
156 | propagateIfChanged(state, state->setHasUnknownPredecessors()); |
157 | } |
158 | foundSymbolCallable = true; |
159 | } |
160 | |
161 | // Exit early if no eligible symbol callables were found in the table. |
162 | if (!foundSymbolCallable) |
163 | return; |
164 | |
165 | // Walk the symbol table to check for non-call uses of symbols. |
166 | std::optional<SymbolTable::UseRange> uses = |
167 | SymbolTable::getSymbolUses(from: &symbolTableRegion); |
168 | if (!uses) { |
169 | // If we couldn't gather the symbol uses, conservatively assume that |
170 | // we can't track information for any nested symbols. |
171 | return top->walk([&](CallableOpInterface callable) { |
172 | auto *state = getOrCreate<PredecessorState>(callable); |
173 | propagateIfChanged(state, state->setHasUnknownPredecessors()); |
174 | }); |
175 | } |
176 | |
177 | for (const SymbolTable::SymbolUse &use : *uses) { |
178 | if (isa<CallOpInterface>(Val: use.getUser())) |
179 | continue; |
180 | // If a callable symbol has a non-call use, then we can't be guaranteed to |
181 | // know all callsites. |
182 | Operation *symbol = symbolTable.lookupSymbolIn(symbolTableOp: top, name: use.getSymbolRef()); |
183 | auto *state = getOrCreate<PredecessorState>(point: symbol); |
184 | propagateIfChanged(state, changed: state->setHasUnknownPredecessors()); |
185 | } |
186 | }; |
187 | SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(), |
188 | walkFn); |
189 | } |
190 | |
191 | /// Returns true if the operation is a returning terminator in region |
192 | /// control-flow or the terminator of a callable region. |
193 | static bool isRegionOrCallableReturn(Operation *op) { |
194 | return !op->getNumSuccessors() && |
195 | isa<RegionBranchOpInterface, CallableOpInterface>(Val: op->getParentOp()) && |
196 | op->getBlock()->getTerminator() == op; |
197 | } |
198 | |
199 | LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { |
200 | // Initialize the analysis by visiting every op with control-flow semantics. |
201 | if (op->getNumRegions() || op->getNumSuccessors() || |
202 | isRegionOrCallableReturn(op) || isa<CallOpInterface>(Val: op)) { |
203 | // When the liveness of the parent block changes, make sure to re-invoke the |
204 | // analysis on the op. |
205 | if (op->getBlock()) |
206 | getOrCreate<Executable>(point: op->getBlock())->blockContentSubscribe(analysis: this); |
207 | // Visit the op. |
208 | if (failed(result: visit(point: op))) |
209 | return failure(); |
210 | } |
211 | // Recurse on nested operations. |
212 | for (Region ®ion : op->getRegions()) |
213 | for (Operation &op : region.getOps()) |
214 | if (failed(result: initializeRecursively(op: &op))) |
215 | return failure(); |
216 | return success(); |
217 | } |
218 | |
219 | void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { |
220 | auto *state = getOrCreate<Executable>(point: to); |
221 | propagateIfChanged(state, changed: state->setToLive()); |
222 | auto *edgeState = getOrCreate<Executable>(point: getProgramPoint<CFGEdge>(args&: from, args&: to)); |
223 | propagateIfChanged(state: edgeState, changed: edgeState->setToLive()); |
224 | } |
225 | |
226 | void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { |
227 | for (Region ®ion : op->getRegions()) { |
228 | if (region.empty()) |
229 | continue; |
230 | auto *state = getOrCreate<Executable>(point: ®ion.front()); |
231 | propagateIfChanged(state, changed: state->setToLive()); |
232 | } |
233 | } |
234 | |
235 | LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { |
236 | if (point.is<Block *>()) |
237 | return success(); |
238 | auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: point); |
239 | if (!op) |
240 | return emitError(loc: point.getLoc(), message: "unknown program point kind" ); |
241 | |
242 | // If the parent block is not executable, there is nothing to do. |
243 | if (!getOrCreate<Executable>(point: op->getBlock())->isLive()) |
244 | return success(); |
245 | |
246 | // We have a live call op. Add this as a live predecessor of the callee. |
247 | if (auto call = dyn_cast<CallOpInterface>(op)) |
248 | visitCallOperation(call: call); |
249 | |
250 | // Visit the regions. |
251 | if (op->getNumRegions()) { |
252 | // Check if we can reason about the region control-flow. |
253 | if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { |
254 | visitRegionBranchOperation(branch: branch); |
255 | |
256 | // Check if this is a callable operation. |
257 | } else if (auto callable = dyn_cast<CallableOpInterface>(op)) { |
258 | const auto *callsites = getOrCreateFor<PredecessorState>(op, callable); |
259 | |
260 | // If the callsites could not be resolved or are known to be non-empty, |
261 | // mark the callable as executable. |
262 | if (!callsites->allPredecessorsKnown() || |
263 | !callsites->getKnownPredecessors().empty()) |
264 | markEntryBlocksLive(op: callable); |
265 | |
266 | // Otherwise, conservatively mark all entry blocks as executable. |
267 | } else { |
268 | markEntryBlocksLive(op); |
269 | } |
270 | } |
271 | |
272 | if (isRegionOrCallableReturn(op)) { |
273 | if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { |
274 | // Visit the exiting terminator of a region. |
275 | visitRegionTerminator(op, branch: branch); |
276 | } else if (auto callable = |
277 | dyn_cast<CallableOpInterface>(op->getParentOp())) { |
278 | // Visit the exiting terminator of a callable. |
279 | visitCallableTerminator(op, callable: callable); |
280 | } |
281 | } |
282 | // Visit the successors. |
283 | if (op->getNumSuccessors()) { |
284 | // Check if we can reason about the control-flow. |
285 | if (auto branch = dyn_cast<BranchOpInterface>(op)) { |
286 | visitBranchOperation(branch: branch); |
287 | |
288 | // Otherwise, conservatively mark all successors as exectuable. |
289 | } else { |
290 | for (Block *successor : op->getSuccessors()) |
291 | markEdgeLive(from: op->getBlock(), to: successor); |
292 | } |
293 | } |
294 | |
295 | return success(); |
296 | } |
297 | |
298 | void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { |
299 | Operation *callableOp = call.resolveCallable(&symbolTable); |
300 | |
301 | // A call to a externally-defined callable has unknown predecessors. |
302 | const auto isExternalCallable = [this](Operation *op) { |
303 | // A callable outside the analysis scope is an external callable. |
304 | if (!analysisScope->isAncestor(other: op)) |
305 | return true; |
306 | // Otherwise, check if the callable region is defined. |
307 | if (auto callable = dyn_cast<CallableOpInterface>(op)) |
308 | return !callable.getCallableRegion(); |
309 | return false; |
310 | }; |
311 | |
312 | // TODO: Add support for non-symbol callables when necessary. If the |
313 | // callable has non-call uses we would mark as having reached pessimistic |
314 | // fixpoint, otherwise allow for propagating the return values out. |
315 | if (isa_and_nonnull<SymbolOpInterface>(callableOp) && |
316 | !isExternalCallable(callableOp)) { |
317 | // Add the live callsite. |
318 | auto *callsites = getOrCreate<PredecessorState>(point: callableOp); |
319 | propagateIfChanged(state: callsites, changed: callsites->join(call)); |
320 | } else { |
321 | // Mark this call op's predecessors as overdefined. |
322 | auto *predecessors = getOrCreate<PredecessorState>(call); |
323 | propagateIfChanged(state: predecessors, changed: predecessors->setHasUnknownPredecessors()); |
324 | } |
325 | } |
326 | |
327 | /// Get the constant values of the operands of an operation. If any of the |
328 | /// constant value lattices are uninitialized, return std::nullopt to indicate |
329 | /// the analysis should bail out. |
330 | static std::optional<SmallVector<Attribute>> getOperandValuesImpl( |
331 | Operation *op, |
332 | function_ref<const Lattice<ConstantValue> *(Value)> getLattice) { |
333 | SmallVector<Attribute> operands; |
334 | operands.reserve(N: op->getNumOperands()); |
335 | for (Value operand : op->getOperands()) { |
336 | const Lattice<ConstantValue> *cv = getLattice(operand); |
337 | // If any of the operands' values are uninitialized, bail out. |
338 | if (cv->getValue().isUninitialized()) |
339 | return {}; |
340 | operands.push_back(Elt: cv->getValue().getConstantValue()); |
341 | } |
342 | return operands; |
343 | } |
344 | |
345 | std::optional<SmallVector<Attribute>> |
346 | DeadCodeAnalysis::getOperandValues(Operation *op) { |
347 | return getOperandValuesImpl(op, getLattice: [&](Value value) { |
348 | auto *lattice = getOrCreate<Lattice<ConstantValue>>(point: value); |
349 | lattice->useDefSubscribe(analysis: this); |
350 | return lattice; |
351 | }); |
352 | } |
353 | |
354 | void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { |
355 | // Try to deduce a single successor for the branch. |
356 | std::optional<SmallVector<Attribute>> operands = getOperandValues(op: branch); |
357 | if (!operands) |
358 | return; |
359 | |
360 | if (Block *successor = branch.getSuccessorForOperands(*operands)) { |
361 | markEdgeLive(from: branch->getBlock(), to: successor); |
362 | } else { |
363 | // Otherwise, mark all successors as executable and outgoing edges. |
364 | for (Block *successor : branch->getSuccessors()) |
365 | markEdgeLive(branch->getBlock(), successor); |
366 | } |
367 | } |
368 | |
369 | void DeadCodeAnalysis::visitRegionBranchOperation( |
370 | RegionBranchOpInterface branch) { |
371 | // Try to deduce which regions are executable. |
372 | std::optional<SmallVector<Attribute>> operands = getOperandValues(op: branch); |
373 | if (!operands) |
374 | return; |
375 | |
376 | SmallVector<RegionSuccessor> successors; |
377 | branch.getEntrySuccessorRegions(*operands, successors); |
378 | for (const RegionSuccessor &successor : successors) { |
379 | // The successor can be either an entry block or the parent operation. |
380 | ProgramPoint point = successor.getSuccessor() |
381 | ? &successor.getSuccessor()->front() |
382 | : ProgramPoint(branch); |
383 | // Mark the entry block as executable. |
384 | auto *state = getOrCreate<Executable>(point); |
385 | propagateIfChanged(state: state, changed: state->setToLive()); |
386 | // Add the parent op as a predecessor. |
387 | auto *predecessors = getOrCreate<PredecessorState>(point); |
388 | propagateIfChanged( |
389 | state: predecessors, |
390 | changed: predecessors->join(branch, successor.getSuccessorInputs())); |
391 | } |
392 | } |
393 | |
394 | void DeadCodeAnalysis::visitRegionTerminator(Operation *op, |
395 | RegionBranchOpInterface branch) { |
396 | std::optional<SmallVector<Attribute>> operands = getOperandValues(op); |
397 | if (!operands) |
398 | return; |
399 | |
400 | SmallVector<RegionSuccessor> successors; |
401 | if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) |
402 | terminator.getSuccessorRegions(*operands, successors); |
403 | else |
404 | branch.getSuccessorRegions(op->getParentRegion(), successors); |
405 | |
406 | // Mark successor region entry blocks as executable and add this op to the |
407 | // list of predecessors. |
408 | for (const RegionSuccessor &successor : successors) { |
409 | PredecessorState *predecessors; |
410 | if (Region *region = successor.getSuccessor()) { |
411 | auto *state = getOrCreate<Executable>(point: ®ion->front()); |
412 | propagateIfChanged(state, changed: state->setToLive()); |
413 | predecessors = getOrCreate<PredecessorState>(point: ®ion->front()); |
414 | } else { |
415 | // Add this terminator as a predecessor to the parent op. |
416 | predecessors = getOrCreate<PredecessorState>(branch); |
417 | } |
418 | propagateIfChanged(state: predecessors, |
419 | changed: predecessors->join(predecessor: op, inputs: successor.getSuccessorInputs())); |
420 | } |
421 | } |
422 | |
423 | void DeadCodeAnalysis::visitCallableTerminator(Operation *op, |
424 | CallableOpInterface callable) { |
425 | // Add as predecessors to all callsites this return op. |
426 | auto *callsites = getOrCreateFor<PredecessorState>(op, callable); |
427 | bool canResolve = op->hasTrait<OpTrait::ReturnLike>(); |
428 | for (Operation *predecessor : callsites->getKnownPredecessors()) { |
429 | assert(isa<CallOpInterface>(predecessor)); |
430 | auto *predecessors = getOrCreate<PredecessorState>(predecessor); |
431 | if (canResolve) { |
432 | propagateIfChanged(predecessors, predecessors->join(op)); |
433 | } else { |
434 | // If the terminator is not a return-like, then conservatively assume we |
435 | // can't resolve the predecessor. |
436 | propagateIfChanged(predecessors, |
437 | predecessors->setHasUnknownPredecessors()); |
438 | } |
439 | } |
440 | } |
441 | |