1//===- DeadCodeAnalysis.cpp - Dead code analysis --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
11#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12#include "mlir/Analysis/DataFlowFramework.h"
13#include "mlir/IR/Attributes.h"
14#include "mlir/IR/Block.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Location.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/IR/Value.h"
20#include "mlir/IR/ValueRange.h"
21#include "mlir/Interfaces/CallInterfaces.h"
22#include "mlir/Interfaces/ControlFlowInterfaces.h"
23#include "mlir/Support/LLVM.h"
24#include "mlir/Support/LogicalResult.h"
25#include "llvm/Support/Casting.h"
26#include <cassert>
27#include <optional>
28
29using namespace mlir;
30using namespace mlir::dataflow;
31
32//===----------------------------------------------------------------------===//
33// Executable
34//===----------------------------------------------------------------------===//
35
36ChangeResult Executable::setToLive() {
37 if (live)
38 return ChangeResult::NoChange;
39 live = true;
40 return ChangeResult::Change;
41}
42
43void Executable::print(raw_ostream &os) const {
44 os << (live ? "live" : "dead");
45}
46
47void Executable::onUpdate(DataFlowSolver *solver) const {
48 AnalysisState::onUpdate(solver);
49
50 if (auto *block = llvm::dyn_cast_if_present<Block *>(Val: point)) {
51 // Re-invoke the analyses on the block itself.
52 for (DataFlowAnalysis *analysis : subscribers)
53 solver->enqueue(item: {block, analysis});
54 // Re-invoke the analyses on all operations in the block.
55 for (DataFlowAnalysis *analysis : subscribers)
56 for (Operation &op : *block)
57 solver->enqueue(item: {&op, analysis});
58 } else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(Val: point)) {
59 // Re-invoke the analysis on the successor block.
60 if (auto *edge = dyn_cast<CFGEdge>(Val: programPoint)) {
61 for (DataFlowAnalysis *analysis : subscribers)
62 solver->enqueue(item: {edge->getTo(), analysis});
63 }
64 }
65}
66
67//===----------------------------------------------------------------------===//
68// PredecessorState
69//===----------------------------------------------------------------------===//
70
71void PredecessorState::print(raw_ostream &os) const {
72 if (allPredecessorsKnown())
73 os << "(all) ";
74 os << "predecessors:\n";
75 for (Operation *op : getKnownPredecessors())
76 os << " " << *op << "\n";
77}
78
79ChangeResult PredecessorState::join(Operation *predecessor) {
80 return knownPredecessors.insert(X: predecessor) ? ChangeResult::Change
81 : ChangeResult::NoChange;
82}
83
84ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
85 ChangeResult result = join(predecessor);
86 if (!inputs.empty()) {
87 ValueRange &curInputs = successorInputs[predecessor];
88 if (curInputs != inputs) {
89 curInputs = inputs;
90 result |= ChangeResult::Change;
91 }
92 }
93 return result;
94}
95
96//===----------------------------------------------------------------------===//
97// CFGEdge
98//===----------------------------------------------------------------------===//
99
100Location CFGEdge::getLoc() const {
101 return FusedLoc::get(
102 getFrom()->getParent()->getContext(),
103 {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()});
104}
105
106void CFGEdge::print(raw_ostream &os) const {
107 getFrom()->print(os);
108 os << "\n -> \n";
109 getTo()->print(os);
110}
111
112//===----------------------------------------------------------------------===//
113// DeadCodeAnalysis
114//===----------------------------------------------------------------------===//
115
116DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
117 : DataFlowAnalysis(solver) {
118 registerPointKind<CFGEdge>();
119}
120
121LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
122 // Mark the top-level blocks as executable.
123 for (Region &region : top->getRegions()) {
124 if (region.empty())
125 continue;
126 auto *state = getOrCreate<Executable>(point: &region.front());
127 propagateIfChanged(state, changed: state->setToLive());
128 }
129
130 // Mark as overdefined the predecessors of symbol callables with potentially
131 // unknown predecessors.
132 initializeSymbolCallables(top);
133
134 return initializeRecursively(op: top);
135}
136
137void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
138 analysisScope = top;
139 auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
140 Region &symbolTableRegion = symTable->getRegion(index: 0);
141 Block *symbolTableBlock = &symbolTableRegion.front();
142
143 bool foundSymbolCallable = false;
144 for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
145 Region *callableRegion = callable.getCallableRegion();
146 if (!callableRegion)
147 continue;
148 auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
149 if (!symbol)
150 continue;
151
152 // Public symbol callables or those for which we can't see all uses have
153 // potentially unknown callsites.
154 if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
155 auto *state = getOrCreate<PredecessorState>(callable);
156 propagateIfChanged(state, state->setHasUnknownPredecessors());
157 }
158 foundSymbolCallable = true;
159 }
160
161 // Exit early if no eligible symbol callables were found in the table.
162 if (!foundSymbolCallable)
163 return;
164
165 // Walk the symbol table to check for non-call uses of symbols.
166 std::optional<SymbolTable::UseRange> uses =
167 SymbolTable::getSymbolUses(from: &symbolTableRegion);
168 if (!uses) {
169 // If we couldn't gather the symbol uses, conservatively assume that
170 // we can't track information for any nested symbols.
171 return top->walk([&](CallableOpInterface callable) {
172 auto *state = getOrCreate<PredecessorState>(callable);
173 propagateIfChanged(state, state->setHasUnknownPredecessors());
174 });
175 }
176
177 for (const SymbolTable::SymbolUse &use : *uses) {
178 if (isa<CallOpInterface>(Val: use.getUser()))
179 continue;
180 // If a callable symbol has a non-call use, then we can't be guaranteed to
181 // know all callsites.
182 Operation *symbol = symbolTable.lookupSymbolIn(symbolTableOp: top, name: use.getSymbolRef());
183 auto *state = getOrCreate<PredecessorState>(point: symbol);
184 propagateIfChanged(state, changed: state->setHasUnknownPredecessors());
185 }
186 };
187 SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
188 walkFn);
189}
190
191/// Returns true if the operation is a returning terminator in region
192/// control-flow or the terminator of a callable region.
193static bool isRegionOrCallableReturn(Operation *op) {
194 return !op->getNumSuccessors() &&
195 isa<RegionBranchOpInterface, CallableOpInterface>(Val: op->getParentOp()) &&
196 op->getBlock()->getTerminator() == op;
197}
198
199LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
200 // Initialize the analysis by visiting every op with control-flow semantics.
201 if (op->getNumRegions() || op->getNumSuccessors() ||
202 isRegionOrCallableReturn(op) || isa<CallOpInterface>(Val: op)) {
203 // When the liveness of the parent block changes, make sure to re-invoke the
204 // analysis on the op.
205 if (op->getBlock())
206 getOrCreate<Executable>(point: op->getBlock())->blockContentSubscribe(analysis: this);
207 // Visit the op.
208 if (failed(result: visit(point: op)))
209 return failure();
210 }
211 // Recurse on nested operations.
212 for (Region &region : op->getRegions())
213 for (Operation &op : region.getOps())
214 if (failed(result: initializeRecursively(op: &op)))
215 return failure();
216 return success();
217}
218
219void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
220 auto *state = getOrCreate<Executable>(point: to);
221 propagateIfChanged(state, changed: state->setToLive());
222 auto *edgeState = getOrCreate<Executable>(point: getProgramPoint<CFGEdge>(args&: from, args&: to));
223 propagateIfChanged(state: edgeState, changed: edgeState->setToLive());
224}
225
226void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
227 for (Region &region : op->getRegions()) {
228 if (region.empty())
229 continue;
230 auto *state = getOrCreate<Executable>(point: &region.front());
231 propagateIfChanged(state, changed: state->setToLive());
232 }
233}
234
235LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
236 if (point.is<Block *>())
237 return success();
238 auto *op = llvm::dyn_cast_if_present<Operation *>(Val&: point);
239 if (!op)
240 return emitError(loc: point.getLoc(), message: "unknown program point kind");
241
242 // If the parent block is not executable, there is nothing to do.
243 if (!getOrCreate<Executable>(point: op->getBlock())->isLive())
244 return success();
245
246 // We have a live call op. Add this as a live predecessor of the callee.
247 if (auto call = dyn_cast<CallOpInterface>(op))
248 visitCallOperation(call: call);
249
250 // Visit the regions.
251 if (op->getNumRegions()) {
252 // Check if we can reason about the region control-flow.
253 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
254 visitRegionBranchOperation(branch: branch);
255
256 // Check if this is a callable operation.
257 } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
258 const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
259
260 // If the callsites could not be resolved or are known to be non-empty,
261 // mark the callable as executable.
262 if (!callsites->allPredecessorsKnown() ||
263 !callsites->getKnownPredecessors().empty())
264 markEntryBlocksLive(op: callable);
265
266 // Otherwise, conservatively mark all entry blocks as executable.
267 } else {
268 markEntryBlocksLive(op);
269 }
270 }
271
272 if (isRegionOrCallableReturn(op)) {
273 if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
274 // Visit the exiting terminator of a region.
275 visitRegionTerminator(op, branch: branch);
276 } else if (auto callable =
277 dyn_cast<CallableOpInterface>(op->getParentOp())) {
278 // Visit the exiting terminator of a callable.
279 visitCallableTerminator(op, callable: callable);
280 }
281 }
282 // Visit the successors.
283 if (op->getNumSuccessors()) {
284 // Check if we can reason about the control-flow.
285 if (auto branch = dyn_cast<BranchOpInterface>(op)) {
286 visitBranchOperation(branch: branch);
287
288 // Otherwise, conservatively mark all successors as exectuable.
289 } else {
290 for (Block *successor : op->getSuccessors())
291 markEdgeLive(from: op->getBlock(), to: successor);
292 }
293 }
294
295 return success();
296}
297
298void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
299 Operation *callableOp = call.resolveCallable(&symbolTable);
300
301 // A call to a externally-defined callable has unknown predecessors.
302 const auto isExternalCallable = [this](Operation *op) {
303 // A callable outside the analysis scope is an external callable.
304 if (!analysisScope->isAncestor(other: op))
305 return true;
306 // Otherwise, check if the callable region is defined.
307 if (auto callable = dyn_cast<CallableOpInterface>(op))
308 return !callable.getCallableRegion();
309 return false;
310 };
311
312 // TODO: Add support for non-symbol callables when necessary. If the
313 // callable has non-call uses we would mark as having reached pessimistic
314 // fixpoint, otherwise allow for propagating the return values out.
315 if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
316 !isExternalCallable(callableOp)) {
317 // Add the live callsite.
318 auto *callsites = getOrCreate<PredecessorState>(point: callableOp);
319 propagateIfChanged(state: callsites, changed: callsites->join(call));
320 } else {
321 // Mark this call op's predecessors as overdefined.
322 auto *predecessors = getOrCreate<PredecessorState>(call);
323 propagateIfChanged(state: predecessors, changed: predecessors->setHasUnknownPredecessors());
324 }
325}
326
327/// Get the constant values of the operands of an operation. If any of the
328/// constant value lattices are uninitialized, return std::nullopt to indicate
329/// the analysis should bail out.
330static std::optional<SmallVector<Attribute>> getOperandValuesImpl(
331 Operation *op,
332 function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
333 SmallVector<Attribute> operands;
334 operands.reserve(N: op->getNumOperands());
335 for (Value operand : op->getOperands()) {
336 const Lattice<ConstantValue> *cv = getLattice(operand);
337 // If any of the operands' values are uninitialized, bail out.
338 if (cv->getValue().isUninitialized())
339 return {};
340 operands.push_back(Elt: cv->getValue().getConstantValue());
341 }
342 return operands;
343}
344
345std::optional<SmallVector<Attribute>>
346DeadCodeAnalysis::getOperandValues(Operation *op) {
347 return getOperandValuesImpl(op, getLattice: [&](Value value) {
348 auto *lattice = getOrCreate<Lattice<ConstantValue>>(point: value);
349 lattice->useDefSubscribe(analysis: this);
350 return lattice;
351 });
352}
353
354void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
355 // Try to deduce a single successor for the branch.
356 std::optional<SmallVector<Attribute>> operands = getOperandValues(op: branch);
357 if (!operands)
358 return;
359
360 if (Block *successor = branch.getSuccessorForOperands(*operands)) {
361 markEdgeLive(from: branch->getBlock(), to: successor);
362 } else {
363 // Otherwise, mark all successors as executable and outgoing edges.
364 for (Block *successor : branch->getSuccessors())
365 markEdgeLive(branch->getBlock(), successor);
366 }
367}
368
369void DeadCodeAnalysis::visitRegionBranchOperation(
370 RegionBranchOpInterface branch) {
371 // Try to deduce which regions are executable.
372 std::optional<SmallVector<Attribute>> operands = getOperandValues(op: branch);
373 if (!operands)
374 return;
375
376 SmallVector<RegionSuccessor> successors;
377 branch.getEntrySuccessorRegions(*operands, successors);
378 for (const RegionSuccessor &successor : successors) {
379 // The successor can be either an entry block or the parent operation.
380 ProgramPoint point = successor.getSuccessor()
381 ? &successor.getSuccessor()->front()
382 : ProgramPoint(branch);
383 // Mark the entry block as executable.
384 auto *state = getOrCreate<Executable>(point);
385 propagateIfChanged(state: state, changed: state->setToLive());
386 // Add the parent op as a predecessor.
387 auto *predecessors = getOrCreate<PredecessorState>(point);
388 propagateIfChanged(
389 state: predecessors,
390 changed: predecessors->join(branch, successor.getSuccessorInputs()));
391 }
392}
393
394void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
395 RegionBranchOpInterface branch) {
396 std::optional<SmallVector<Attribute>> operands = getOperandValues(op);
397 if (!operands)
398 return;
399
400 SmallVector<RegionSuccessor> successors;
401 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op))
402 terminator.getSuccessorRegions(*operands, successors);
403 else
404 branch.getSuccessorRegions(op->getParentRegion(), successors);
405
406 // Mark successor region entry blocks as executable and add this op to the
407 // list of predecessors.
408 for (const RegionSuccessor &successor : successors) {
409 PredecessorState *predecessors;
410 if (Region *region = successor.getSuccessor()) {
411 auto *state = getOrCreate<Executable>(point: &region->front());
412 propagateIfChanged(state, changed: state->setToLive());
413 predecessors = getOrCreate<PredecessorState>(point: &region->front());
414 } else {
415 // Add this terminator as a predecessor to the parent op.
416 predecessors = getOrCreate<PredecessorState>(branch);
417 }
418 propagateIfChanged(state: predecessors,
419 changed: predecessors->join(predecessor: op, inputs: successor.getSuccessorInputs()));
420 }
421}
422
423void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
424 CallableOpInterface callable) {
425 // Add as predecessors to all callsites this return op.
426 auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
427 bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
428 for (Operation *predecessor : callsites->getKnownPredecessors()) {
429 assert(isa<CallOpInterface>(predecessor));
430 auto *predecessors = getOrCreate<PredecessorState>(predecessor);
431 if (canResolve) {
432 propagateIfChanged(predecessors, predecessors->join(op));
433 } else {
434 // If the terminator is not a return-like, then conservatively assume we
435 // can't resolve the predecessor.
436 propagateIfChanged(predecessors,
437 predecessors->setHasUnknownPredecessors());
438 }
439 }
440}
441

source code of mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp