1//===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements an algorithm for eliminating symbol operations that are
10// known to be dead.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/Passes.h"
15
16#include "mlir/IR/SymbolTable.h"
17#include "llvm/Support/Debug.h"
18
namespace mlir {
// Pull in the pass definition for SymbolDCE from the generated Passes.h.inc;
// presumably this provides `impl::SymbolDCEBase`, which `SymbolDCE` derives
// from below (TODO confirm against the tablegen output).
#define GEN_PASS_DEF_SYMBOLDCE
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

// Debug category name used by the LLVM_DEBUG output in this file.
#define DEBUG_TYPE "symbol-dce"
27
28namespace {
29struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
30 void runOnOperation() override;
31
32 /// Compute the liveness of the symbols within the given symbol table.
33 /// `symbolTableIsHidden` is true if this symbol table is known to be
34 /// unaccessible from operations in its parent regions.
35 LogicalResult computeLiveness(Operation *symbolTableOp,
36 SymbolTableCollection &symbolTable,
37 bool symbolTableIsHidden,
38 DenseSet<Operation *> &liveSymbols);
39};
40} // namespace
41
42void SymbolDCE::runOnOperation() {
43 Operation *symbolTableOp = getOperation();
44
45 // SymbolDCE should only be run on operations that define a symbol table.
46 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
47 symbolTableOp->emitOpError()
48 << " was scheduled to run under SymbolDCE, but does not define a "
49 "symbol table";
50 return signalPassFailure();
51 }
52
53 // A flag that signals if the top level symbol table is hidden, i.e. not
54 // accessible from parent scopes.
55 bool symbolTableIsHidden = true;
56 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(Val: symbolTableOp);
57 if (symbolTableOp->getParentOp() && symbol)
58 symbolTableIsHidden = symbol.isPrivate();
59
60 // Compute the set of live symbols within the symbol table.
61 DenseSet<Operation *> liveSymbols;
62 SymbolTableCollection symbolTable;
63 if (failed(Result: computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
64 liveSymbols)))
65 return signalPassFailure();
66
67 // After computing the liveness, delete all of the symbols that were found to
68 // be dead.
69 symbolTableOp->walk(callback: [&](Operation *nestedSymbolTable) {
70 if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
71 return;
72 for (auto &block : nestedSymbolTable->getRegion(index: 0)) {
73 for (Operation &op : llvm::make_early_inc_range(Range&: block)) {
74 if (isa<SymbolOpInterface>(Val: &op) && !liveSymbols.count(V: &op)) {
75 op.erase();
76 ++numDCE;
77 }
78 }
79 }
80 });
81}
82
83/// Compute the liveness of the symbols within the given symbol table.
84/// `symbolTableIsHidden` is true if this symbol table is known to be
85/// unaccessible from operations in its parent regions.
86LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
87 SymbolTableCollection &symbolTable,
88 bool symbolTableIsHidden,
89 DenseSet<Operation *> &liveSymbols) {
90 LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName()
91 << "\n");
92 // A worklist of live operations to propagate uses from.
93 SmallVector<Operation *, 16> worklist;
94
95 // Walk the symbols within the current symbol table, marking the symbols that
96 // are known to be live.
97 for (auto &block : symbolTableOp->getRegion(index: 0)) {
98 // Add all non-symbols or symbols that can't be discarded.
99 for (Operation &op : block) {
100 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(Val: &op);
101 if (!symbol) {
102 worklist.push_back(Elt: &op);
103 continue;
104 }
105 bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
106 symbol.canDiscardOnUseEmpty();
107 if (!isDiscardable && liveSymbols.insert(V: &op).second)
108 worklist.push_back(Elt: &op);
109 }
110 }
111
112 // Process the set of symbols that were known to be live, adding new symbols
113 // that are referenced within. For operations that are not symbol tables, it
114 // considers the liveness with respect to the op itself rather than scope of
115 // nested symbol tables by enqueuing all the top level operations for
116 // consideration.
117 while (!worklist.empty()) {
118 Operation *op = worklist.pop_back_val();
119 LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n");
120
121 // If this is a symbol table, recursively compute its liveness.
122 if (op->hasTrait<OpTrait::SymbolTable>()) {
123 // The internal symbol table is hidden if the parent is, if its not a
124 // symbol, or if it is a private symbol.
125 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(Val: op);
126 bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
127 LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName()
128 << " is hidden: " << symIsHidden << "\n");
129 if (failed(Result: computeLiveness(symbolTableOp: op, symbolTable, symbolTableIsHidden: symIsHidden, liveSymbols)))
130 return failure();
131 } else {
132 LLVM_DEBUG(llvm::dbgs()
133 << "\tnon-symbol table: " << op->getName() << "\n");
134 // If the op is not a symbol table, then, unless op itself is dead which
135 // would be handled by DCE, we need to check all the regions and blocks
136 // within the op to find the uses (e.g., consider visibility within op as
137 // if top level rather than relying on pure symbol table visibility). This
138 // is more conservative than SymbolTable::walkSymbolTables in the case
139 // where there is again SymbolTable information to take advantage of.
140 for (auto &region : op->getRegions())
141 for (auto &block : region.getBlocks())
142 for (Operation &op : block)
143 if (op.getNumRegions())
144 worklist.push_back(Elt: &op);
145 }
146
147 // Get the first parent symbol table op. Note: due to enqueueing of
148 // top-level ops, we may not have a symbol table parent here, but if we do
149 // not, then we also don't have a symbol.
150 Operation *parentOp = op->getParentOp();
151 if (!parentOp->hasTrait<OpTrait::SymbolTable>())
152 continue;
153
154 // Collect the uses held by this operation.
155 std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(from: op);
156 if (!uses) {
157 return op->emitError()
158 << "operation contains potentially unknown symbol table, meaning "
159 << "that we can't reliable compute symbol uses";
160 }
161
162 SmallVector<Operation *, 4> resolvedSymbols;
163 LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n");
164 for (const SymbolTable::SymbolUse &use : *uses) {
165 LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n");
166 // Lookup the symbols referenced by this use.
167 resolvedSymbols.clear();
168 if (failed(Result: symbolTable.lookupSymbolIn(symbolTableOp: parentOp, name: use.getSymbolRef(),
169 symbols&: resolvedSymbols)))
170 // Ignore references to unknown symbols.
171 continue;
172 LLVM_DEBUG({
173 llvm::dbgs() << "\t\tresolved symbols: ";
174 llvm::interleaveComma(resolvedSymbols, llvm::dbgs());
175 llvm::dbgs() << "\n";
176 });
177
178 // Mark each of the resolved symbols as live.
179 for (Operation *resolvedSymbol : resolvedSymbols)
180 if (liveSymbols.insert(V: resolvedSymbol).second)
181 worklist.push_back(Elt: resolvedSymbol);
182 }
183 }
184
185 return success();
186}
187
188std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
189 return std::make_unique<SymbolDCE>();
190}
191

source code of mlir/lib/Transforms/SymbolDCE.cpp