1//===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
11#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12#include "mlir/Analysis/DataFlowFramework.h"
13#include "mlir/IR/Attributes.h"
14#include "mlir/IR/Block.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Location.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/IR/Value.h"
20#include "mlir/IR/ValueRange.h"
21#include "mlir/Interfaces/CallInterfaces.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Support/LLVM.h"
24#include "llvm/Support/Casting.h"
25#include "llvm/Support/Debug.h"
26#include <cassert>
27#include <optional>
28
29#define DEBUG_TYPE "dead-code-analysis"
30#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
31#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
32
33using namespace mlir;
34using namespace mlir::dataflow;
35
36//===----------------------------------------------------------------------===//
37// Executable
38//===----------------------------------------------------------------------===//
39
40ChangeResult Executable::setToLive() {
41 if (live)
42 return ChangeResult::NoChange;
43 live = true;
44 return ChangeResult::Change;
45}
46
47void Executable::print(raw_ostream &os) const {
48 os << (live ? "live" : "dead");
49}
50
51void Executable::onUpdate(DataFlowSolver *solver) const {
52 AnalysisState::onUpdate(solver);
53
54 if (ProgramPoint *pp = llvm::dyn_cast_if_present<ProgramPoint *>(Val: anchor)) {
55 if (pp->isBlockStart()) {
56 // Re-invoke the analyses on the block itself.
57 for (DataFlowAnalysis *analysis : subscribers)
58 solver->enqueue(item: {pp, analysis});
59 // Re-invoke the analyses on all operations in the block.
60 for (DataFlowAnalysis *analysis : subscribers)
61 for (Operation &op : *pp->getBlock())
62 solver->enqueue(item: {solver->getProgramPointAfter(op: &op), analysis});
63 }
64 } else if (auto *latticeAnchor =
65 llvm::dyn_cast_if_present<GenericLatticeAnchor *>(Val: anchor)) {
66 // Re-invoke the analysis on the successor block.
67 if (auto *edge = dyn_cast<CFGEdge>(Val: latticeAnchor)) {
68 for (DataFlowAnalysis *analysis : subscribers)
69 solver->enqueue(
70 item: {solver->getProgramPointBefore(block: edge->getTo()), analysis});
71 }
72 }
73}
74
75//===----------------------------------------------------------------------===//
76// PredecessorState
77//===----------------------------------------------------------------------===//
78
79void PredecessorState::print(raw_ostream &os) const {
80 if (allPredecessorsKnown())
81 os << "(all) ";
82 os << "predecessors:\n";
83 for (Operation *op : getKnownPredecessors())
84 os << " " << *op << "\n";
85}
86
87ChangeResult PredecessorState::join(Operation *predecessor) {
88 return knownPredecessors.insert(X: predecessor) ? ChangeResult::Change
89 : ChangeResult::NoChange;
90}
91
92ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
93 ChangeResult result = join(predecessor);
94 if (!inputs.empty()) {
95 ValueRange &curInputs = successorInputs[predecessor];
96 if (curInputs != inputs) {
97 curInputs = inputs;
98 result |= ChangeResult::Change;
99 }
100 }
101 return result;
102}
103
104//===----------------------------------------------------------------------===//
105// CFGEdge
106//===----------------------------------------------------------------------===//
107
108Location CFGEdge::getLoc() const {
109 return FusedLoc::get(
110 context: getFrom()->getParent()->getContext(),
111 locs: {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()});
112}
113
114void CFGEdge::print(raw_ostream &os) const {
115 getFrom()->print(os);
116 os << "\n -> \n";
117 getTo()->print(os);
118}
119
120//===----------------------------------------------------------------------===//
121// DeadCodeAnalysis
122//===----------------------------------------------------------------------===//
123
124DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
125 : DataFlowAnalysis(solver) {
126 registerAnchorKind<CFGEdge>();
127}
128
129LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
130 LDBG("Initializing DeadCodeAnalysis for top-level op: " << top->getName());
131 // Mark the top-level blocks as executable.
132 for (Region &region : top->getRegions()) {
133 if (region.empty())
134 continue;
135 auto *state =
136 getOrCreate<Executable>(anchor: getProgramPointBefore(block: &region.front()));
137 propagateIfChanged(state, changed: state->setToLive());
138 LDBG("Marked entry block live for region in op: " << top->getName());
139 }
140
141 // Mark as overdefined the predecessors of symbol callables with potentially
142 // unknown predecessors.
143 initializeSymbolCallables(top);
144
145 return initializeRecursively(op: top);
146}
147
148void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
149 LDBG("[init] Entering initializeSymbolCallables for top-level op: "
150 << top->getName());
151 analysisScope = top;
152 auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
153 LDBG("[init] Processing symbol table op: " << symTable->getName());
154 Region &symbolTableRegion = symTable->getRegion(index: 0);
155 Block *symbolTableBlock = &symbolTableRegion.front();
156
157 bool foundSymbolCallable = false;
158 for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
159 LDBG("[init] Found CallableOpInterface: "
160 << callable.getOperation()->getName());
161 Region *callableRegion = callable.getCallableRegion();
162 if (!callableRegion)
163 continue;
164 auto symbol = dyn_cast<SymbolOpInterface>(Val: callable.getOperation());
165 if (!symbol)
166 continue;
167
168 // Public symbol callables or those for which we can't see all uses have
169 // potentially unknown callsites.
170 if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
171 auto *state =
172 getOrCreate<PredecessorState>(anchor: getProgramPointAfter(op: callable));
173 propagateIfChanged(state, changed: state->setHasUnknownPredecessors());
174 LDBG("[init] Marked callable as having unknown predecessors: "
175 << callable.getOperation()->getName());
176 }
177 foundSymbolCallable = true;
178 }
179
180 // Exit early if no eligible symbol callables were found in the table.
181 if (!foundSymbolCallable)
182 return;
183
184 // Walk the symbol table to check for non-call uses of symbols.
185 std::optional<SymbolTable::UseRange> uses =
186 SymbolTable::getSymbolUses(from: &symbolTableRegion);
187 if (!uses) {
188 // If we couldn't gather the symbol uses, conservatively assume that
189 // we can't track information for any nested symbols.
190 LDBG("[init] Could not gather symbol uses, conservatively marking "
191 "all nested callables as having unknown predecessors");
192 return top->walk(callback: [&](CallableOpInterface callable) {
193 auto *state =
194 getOrCreate<PredecessorState>(anchor: getProgramPointAfter(op: callable));
195 propagateIfChanged(state, changed: state->setHasUnknownPredecessors());
196 LDBG("[init] Marked nested callable as "
197 "having unknown predecessors: "
198 << callable.getOperation()->getName());
199 });
200 }
201
202 for (const SymbolTable::SymbolUse &use : *uses) {
203 if (isa<CallOpInterface>(Val: use.getUser()))
204 continue;
205 // If a callable symbol has a non-call use, then we can't be guaranteed to
206 // know all callsites.
207 Operation *symbol = symbolTable.lookupSymbolIn(symbolTableOp: top, name: use.getSymbolRef());
208 if (!symbol)
209 continue;
210 auto *state = getOrCreate<PredecessorState>(anchor: getProgramPointAfter(op: symbol));
211 propagateIfChanged(state, changed: state->setHasUnknownPredecessors());
212 LDBG("[init] Found non-call use for symbol, "
213 "marked as having unknown predecessors: "
214 << symbol->getName());
215 }
216 };
217 SymbolTable::walkSymbolTables(op: top, /*allSymUsesVisible=*/!top->getBlock(),
218 callback: walkFn);
219 LDBG("[init] Finished initializeSymbolCallables for top-level op: "
220 << top->getName());
221}
222
223/// Returns true if the operation is a returning terminator in region
224/// control-flow or the terminator of a callable region.
225static bool isRegionOrCallableReturn(Operation *op) {
226 return op->getBlock() != nullptr && !op->getNumSuccessors() &&
227 isa<RegionBranchOpInterface, CallableOpInterface>(Val: op->getParentOp()) &&
228 op->getBlock()->getTerminator() == op;
229}
230
231LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
232 LDBG("[init] Entering initializeRecursively for op: " << op->getName()
233 << " at " << op);
234 // Initialize the analysis by visiting every op with control-flow semantics.
235 if (op->getNumRegions() || op->getNumSuccessors() ||
236 isRegionOrCallableReturn(op) || isa<CallOpInterface>(Val: op)) {
237 LDBG("[init] Visiting op with control-flow semantics: " << *op);
238 // When the liveness of the parent block changes, make sure to re-invoke the
239 // analysis on the op.
240 if (op->getBlock())
241 getOrCreate<Executable>(anchor: getProgramPointBefore(block: op->getBlock()))
242 ->blockContentSubscribe(analysis: this);
243 // Visit the op.
244 if (failed(Result: visit(point: getProgramPointAfter(op))))
245 return failure();
246 }
247 // Recurse on nested operations.
248 for (Region &region : op->getRegions()) {
249 LDBG("[init] Recursing into region of op: " << op->getName());
250 for (Operation &nestedOp : region.getOps()) {
251 LDBG("[init] Recursing into nested op: " << nestedOp.getName() << " at "
252 << &nestedOp);
253 if (failed(Result: initializeRecursively(op: &nestedOp)))
254 return failure();
255 }
256 }
257 LDBG("[init] Finished initializeRecursively for op: " << op->getName()
258 << " at " << op);
259 return success();
260}
261
262void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
263 LDBG("Marking edge live from block " << from << " to block " << to);
264 auto *state = getOrCreate<Executable>(anchor: getProgramPointBefore(block: to));
265 propagateIfChanged(state, changed: state->setToLive());
266 auto *edgeState =
267 getOrCreate<Executable>(anchor: getLatticeAnchor<CFGEdge>(args&: from, args&: to));
268 propagateIfChanged(state: edgeState, changed: edgeState->setToLive());
269}
270
271void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
272 LDBG("Marking entry blocks live for op: " << op->getName());
273 for (Region &region : op->getRegions()) {
274 if (region.empty())
275 continue;
276 auto *state =
277 getOrCreate<Executable>(anchor: getProgramPointBefore(block: &region.front()));
278 propagateIfChanged(state, changed: state->setToLive());
279 LDBG("Marked entry block live for region in op: " << op->getName());
280 }
281}
282
283LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
284 LDBG("Visiting program point: " << point << " " << *point);
285 if (point->isBlockStart())
286 return success();
287 Operation *op = point->getPrevOp();
288 LDBG("Visiting operation: " << *op);
289
290 // If the parent block is not executable, there is nothing to do.
291 if (op->getBlock() != nullptr &&
292 !getOrCreate<Executable>(anchor: getProgramPointBefore(block: op->getBlock()))
293 ->isLive()) {
294 LDBG("Parent block not live, skipping op: " << *op);
295 return success();
296 }
297
298 // We have a live call op. Add this as a live predecessor of the callee.
299 if (auto call = dyn_cast<CallOpInterface>(Val: op)) {
300 LDBG("Visiting call operation: " << *op);
301 visitCallOperation(call);
302 }
303
304 // Visit the regions.
305 if (op->getNumRegions()) {
306 // Check if we can reason about the region control-flow.
307 if (auto branch = dyn_cast<RegionBranchOpInterface>(Val: op)) {
308 LDBG("Visiting region branch operation: " << *op);
309 visitRegionBranchOperation(branch);
310
311 // Check if this is a callable operation.
312 } else if (auto callable = dyn_cast<CallableOpInterface>(Val: op)) {
313 LDBG("Visiting callable operation: " << *op);
314 const auto *callsites = getOrCreateFor<PredecessorState>(
315 dependent: getProgramPointAfter(op), anchor: getProgramPointAfter(op: callable));
316
317 // If the callsites could not be resolved or are known to be non-empty,
318 // mark the callable as executable.
319 if (!callsites->allPredecessorsKnown() ||
320 !callsites->getKnownPredecessors().empty())
321 markEntryBlocksLive(op: callable);
322
323 // Otherwise, conservatively mark all entry blocks as executable.
324 } else {
325 LDBG("Marking all entry blocks live for op: " << *op);
326 markEntryBlocksLive(op);
327 }
328 }
329
330 if (isRegionOrCallableReturn(op)) {
331 if (auto branch = dyn_cast<RegionBranchOpInterface>(Val: op->getParentOp())) {
332 LDBG("Visiting region terminator: " << *op);
333 // Visit the exiting terminator of a region.
334 visitRegionTerminator(op, branch);
335 } else if (auto callable =
336 dyn_cast<CallableOpInterface>(Val: op->getParentOp())) {
337 LDBG("Visiting callable terminator: " << *op);
338 // Visit the exiting terminator of a callable.
339 visitCallableTerminator(op, callable);
340 }
341 }
342 // Visit the successors.
343 if (op->getNumSuccessors()) {
344 // Check if we can reason about the control-flow.
345 if (auto branch = dyn_cast<BranchOpInterface>(Val: op)) {
346 LDBG("Visiting branch operation: " << *op);
347 visitBranchOperation(branch);
348
349 // Otherwise, conservatively mark all successors as exectuable.
350 } else {
351 LDBG("Marking all successors live for op: " << *op);
352 for (Block *successor : op->getSuccessors())
353 markEdgeLive(from: op->getBlock(), to: successor);
354 }
355 }
356
357 return success();
358}
359
360void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
361 LDBG("visitCallOperation: " << call.getOperation()->getName());
362 Operation *callableOp = call.resolveCallableInTable(symbolTable: &symbolTable);
363
364 // A call to a externally-defined callable has unknown predecessors.
365 const auto isExternalCallable = [this](Operation *op) {
366 // A callable outside the analysis scope is an external callable.
367 if (!analysisScope->isAncestor(other: op))
368 return true;
369 // Otherwise, check if the callable region is defined.
370 if (auto callable = dyn_cast<CallableOpInterface>(Val: op))
371 return !callable.getCallableRegion();
372 return false;
373 };
374
375 // TODO: Add support for non-symbol callables when necessary. If the
376 // callable has non-call uses we would mark as having reached pessimistic
377 // fixpoint, otherwise allow for propagating the return values out.
378 if (isa_and_nonnull<SymbolOpInterface>(Val: callableOp) &&
379 !isExternalCallable(callableOp)) {
380 // Add the live callsite.
381 auto *callsites =
382 getOrCreate<PredecessorState>(anchor: getProgramPointAfter(op: callableOp));
383 propagateIfChanged(state: callsites, changed: callsites->join(predecessor: call));
384 LDBG("Added callsite as predecessor for callable: "
385 << callableOp->getName());
386 } else {
387 // Mark this call op's predecessors as overdefined.
388 auto *predecessors =
389 getOrCreate<PredecessorState>(anchor: getProgramPointAfter(op: call));
390 propagateIfChanged(state: predecessors, changed: predecessors->setHasUnknownPredecessors());
391 LDBG("Marked call op's predecessors as unknown for: "
392 << call.getOperation()->getName());
393 }
394}
395
396/// Get the constant values of the operands of an operation. If any of the
397/// constant value lattices are uninitialized, return std::nullopt to indicate
398/// the analysis should bail out.
399static std::optional<SmallVector<Attribute>> getOperandValuesImpl(
400 Operation *op,
401 function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
402 SmallVector<Attribute> operands;
403 operands.reserve(N: op->getNumOperands());
404 for (Value operand : op->getOperands()) {
405 const Lattice<ConstantValue> *cv = getLattice(operand);
406 // If any of the operands' values are uninitialized, bail out.
407 if (cv->getValue().isUninitialized())
408 return {};
409 operands.push_back(Elt: cv->getValue().getConstantValue());
410 }
411 return operands;
412}
413
414std::optional<SmallVector<Attribute>>
415DeadCodeAnalysis::getOperandValues(Operation *op) {
416 return getOperandValuesImpl(op, getLattice: [&](Value value) {
417 auto *lattice = getOrCreate<Lattice<ConstantValue>>(anchor: value);
418 lattice->useDefSubscribe(analysis: this);
419 return lattice;
420 });
421}
422
423void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
424 LDBG("visitBranchOperation: " << branch.getOperation()->getName());
425 // Try to deduce a single successor for the branch.
426 std::optional<SmallVector<Attribute>> operands = getOperandValues(op: branch);
427 if (!operands)
428 return;
429
430 if (Block *successor = branch.getSuccessorForOperands(operands: *operands)) {
431 markEdgeLive(from: branch->getBlock(), to: successor);
432 LDBG("Branch has single successor: " << successor);
433 } else {
434 // Otherwise, mark all successors as executable and outgoing edges.
435 for (Block *successor : branch->getSuccessors())
436 markEdgeLive(from: branch->getBlock(), to: successor);
437 LDBG("Branch has multiple/all successors live");
438 }
439}
440
441void DeadCodeAnalysis::visitRegionBranchOperation(
442 RegionBranchOpInterface branch) {
443 LDBG("visitRegionBranchOperation: " << branch.getOperation()->getName());
444 // Try to deduce which regions are executable.
445 std::optional<SmallVector<Attribute>> operands = getOperandValues(op: branch);
446 if (!operands)
447 return;
448
449 SmallVector<RegionSuccessor> successors;
450 branch.getEntrySuccessorRegions(operands: *operands, regions&: successors);
451 for (const RegionSuccessor &successor : successors) {
452 // The successor can be either an entry block or the parent operation.
453 ProgramPoint *point =
454 successor.getSuccessor()
455 ? getProgramPointBefore(block: &successor.getSuccessor()->front())
456 : getProgramPointAfter(op: branch);
457 // Mark the entry block as executable.
458 auto *state = getOrCreate<Executable>(anchor: point);
459 propagateIfChanged(state, changed: state->setToLive());
460 LDBG("Marked region successor live: " << point);
461 // Add the parent op as a predecessor.
462 auto *predecessors = getOrCreate<PredecessorState>(anchor: point);
463 propagateIfChanged(
464 state: predecessors,
465 changed: predecessors->join(predecessor: branch, inputs: successor.getSuccessorInputs()));
466 LDBG("Added region branch as predecessor for successor: " << point);
467 }
468}
469
470void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
471 RegionBranchOpInterface branch) {
472 LDBG("visitRegionTerminator: " << *op);
473 std::optional<SmallVector<Attribute>> operands = getOperandValues(op);
474 if (!operands)
475 return;
476
477 SmallVector<RegionSuccessor> successors;
478 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(Val: op))
479 terminator.getSuccessorRegions(operands: *operands, regions&: successors);
480 else
481 branch.getSuccessorRegions(point: op->getParentRegion(), regions&: successors);
482
483 // Mark successor region entry blocks as executable and add this op to the
484 // list of predecessors.
485 for (const RegionSuccessor &successor : successors) {
486 PredecessorState *predecessors;
487 if (Region *region = successor.getSuccessor()) {
488 auto *state =
489 getOrCreate<Executable>(anchor: getProgramPointBefore(block: &region->front()));
490 propagateIfChanged(state, changed: state->setToLive());
491 LDBG("Marked region entry block live for region: " << region);
492 predecessors = getOrCreate<PredecessorState>(
493 anchor: getProgramPointBefore(block: &region->front()));
494 } else {
495 // Add this terminator as a predecessor to the parent op.
496 predecessors =
497 getOrCreate<PredecessorState>(anchor: getProgramPointAfter(op: branch));
498 }
499 propagateIfChanged(state: predecessors,
500 changed: predecessors->join(predecessor: op, inputs: successor.getSuccessorInputs()));
501 LDBG("Added region terminator as predecessor for successor: "
502 << (successor.getSuccessor() ? "region entry" : "parent op"));
503 }
504}
505
506void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
507 CallableOpInterface callable) {
508 LDBG("visitCallableTerminator: " << *op);
509 // Add as predecessors to all callsites this return op.
510 auto *callsites = getOrCreateFor<PredecessorState>(
511 dependent: getProgramPointAfter(op), anchor: getProgramPointAfter(op: callable));
512 bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
513 for (Operation *predecessor : callsites->getKnownPredecessors()) {
514 assert(isa<CallOpInterface>(predecessor));
515 auto *predecessors =
516 getOrCreate<PredecessorState>(anchor: getProgramPointAfter(op: predecessor));
517 if (canResolve) {
518 propagateIfChanged(state: predecessors, changed: predecessors->join(predecessor: op));
519 LDBG("Added callable terminator as predecessor for callsite: "
520 << predecessor->getName());
521 } else {
522 // If the terminator is not a return-like, then conservatively assume we
523 // can't resolve the predecessor.
524 propagateIfChanged(state: predecessors,
525 changed: predecessors->setHasUnknownPredecessors());
526 LDBG("Could not resolve callable terminator for callsite: "
527 << predecessor->getName());
528 }
529 }
530}
531

// Source: mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp