1//===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements an algorithm for eliminating symbol operations that are
10// known to be dead.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/Passes.h"
15
16#include "mlir/IR/SymbolTable.h"
17
18namespace mlir {
19#define GEN_PASS_DEF_SYMBOLDCE
20#include "mlir/Transforms/Passes.h.inc"
21} // namespace mlir
22
23using namespace mlir;
24
25namespace {
26struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
27 void runOnOperation() override;
28
29 /// Compute the liveness of the symbols within the given symbol table.
30 /// `symbolTableIsHidden` is true if this symbol table is known to be
31 /// unaccessible from operations in its parent regions.
32 LogicalResult computeLiveness(Operation *symbolTableOp,
33 SymbolTableCollection &symbolTable,
34 bool symbolTableIsHidden,
35 DenseSet<Operation *> &liveSymbols);
36};
37} // namespace
38
39void SymbolDCE::runOnOperation() {
40 Operation *symbolTableOp = getOperation();
41
42 // SymbolDCE should only be run on operations that define a symbol table.
43 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
44 symbolTableOp->emitOpError()
45 << " was scheduled to run under SymbolDCE, but does not define a "
46 "symbol table";
47 return signalPassFailure();
48 }
49
50 // A flag that signals if the top level symbol table is hidden, i.e. not
51 // accessible from parent scopes.
52 bool symbolTableIsHidden = true;
53 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
54 if (symbolTableOp->getParentOp() && symbol)
55 symbolTableIsHidden = symbol.isPrivate();
56
57 // Compute the set of live symbols within the symbol table.
58 DenseSet<Operation *> liveSymbols;
59 SymbolTableCollection symbolTable;
60 if (failed(result: computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
61 liveSymbols)))
62 return signalPassFailure();
63
64 // After computing the liveness, delete all of the symbols that were found to
65 // be dead.
66 symbolTableOp->walk(callback: [&](Operation *nestedSymbolTable) {
67 if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
68 return;
69 for (auto &block : nestedSymbolTable->getRegion(index: 0)) {
70 for (Operation &op : llvm::make_early_inc_range(Range&: block)) {
71 if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
72 op.erase();
73 ++numDCE;
74 }
75 }
76 }
77 });
78}
79
80/// Compute the liveness of the symbols within the given symbol table.
81/// `symbolTableIsHidden` is true if this symbol table is known to be
82/// unaccessible from operations in its parent regions.
83LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
84 SymbolTableCollection &symbolTable,
85 bool symbolTableIsHidden,
86 DenseSet<Operation *> &liveSymbols) {
87 // A worklist of live operations to propagate uses from.
88 SmallVector<Operation *, 16> worklist;
89
90 // Walk the symbols within the current symbol table, marking the symbols that
91 // are known to be live.
92 for (auto &block : symbolTableOp->getRegion(index: 0)) {
93 // Add all non-symbols or symbols that can't be discarded.
94 for (Operation &op : block) {
95 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
96 if (!symbol) {
97 worklist.push_back(Elt: &op);
98 continue;
99 }
100 bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
101 symbol.canDiscardOnUseEmpty();
102 if (!isDiscardable && liveSymbols.insert(V: &op).second)
103 worklist.push_back(Elt: &op);
104 }
105 }
106
107 // Process the set of symbols that were known to be live, adding new symbols
108 // that are referenced within.
109 while (!worklist.empty()) {
110 Operation *op = worklist.pop_back_val();
111
112 // If this is a symbol table, recursively compute its liveness.
113 if (op->hasTrait<OpTrait::SymbolTable>()) {
114 // The internal symbol table is hidden if the parent is, if its not a
115 // symbol, or if it is a private symbol.
116 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
117 bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
118 if (failed(result: computeLiveness(symbolTableOp: op, symbolTable, symbolTableIsHidden: symIsHidden, liveSymbols)))
119 return failure();
120 }
121
122 // Collect the uses held by this operation.
123 std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(from: op);
124 if (!uses) {
125 return op->emitError()
126 << "operation contains potentially unknown symbol table, "
127 "meaning that we can't reliable compute symbol uses";
128 }
129
130 SmallVector<Operation *, 4> resolvedSymbols;
131 for (const SymbolTable::SymbolUse &use : *uses) {
132 // Lookup the symbols referenced by this use.
133 resolvedSymbols.clear();
134 if (failed(result: symbolTable.lookupSymbolIn(
135 symbolTableOp: op->getParentOp(), name: use.getSymbolRef(), symbols&: resolvedSymbols)))
136 // Ignore references to unknown symbols.
137 continue;
138
139 // Mark each of the resolved symbols as live.
140 for (Operation *resolvedSymbol : resolvedSymbols)
141 if (liveSymbols.insert(V: resolvedSymbol).second)
142 worklist.push_back(Elt: resolvedSymbol);
143 }
144 }
145
146 return success();
147}
148
149std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
150 return std::make_unique<SymbolDCE>();
151}
152

source code of mlir/lib/Transforms/SymbolDCE.cpp