1//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_SYMBOLTABLE_H
10#define MLIR_IR_SYMBOLTABLE_H
11
12#include "mlir/IR/Attributes.h"
13#include "mlir/IR/OpDefinition.h"
14#include "llvm/ADT/SetVector.h"
15#include "llvm/ADT/StringMap.h"
16#include "llvm/Support/RWMutex.h"
17
18namespace mlir {
19
20/// This class allows for representing and managing the symbol table used by
21/// operations with the 'SymbolTable' trait. Inserting into and erasing from
22/// this SymbolTable will also insert and erase from the Operation given to it
23/// at construction.
24class SymbolTable {
25public:
26 /// Build a symbol table with the symbols within the given operation.
27 SymbolTable(Operation *symbolTableOp);
28
29 /// Look up a symbol with the specified name, returning null if no such
30 /// name exists. Names never include the @ on them.
31 Operation *lookup(StringRef name) const;
32 template <typename T>
33 T lookup(StringRef name) const {
34 return dyn_cast_or_null<T>(lookup(name));
35 }
36
37 /// Look up a symbol with the specified name, returning null if no such
38 /// name exists. Names never include the @ on them.
39 Operation *lookup(StringAttr name) const;
40 template <typename T>
41 T lookup(StringAttr name) const {
42 return dyn_cast_or_null<T>(lookup(name));
43 }
44
45 /// Remove the given symbol from the table, without deleting it.
46 void remove(Operation *op);
47
48 /// Erase the given symbol from the table and delete the operation.
49 void erase(Operation *symbol);
50
51 /// Insert a new symbol into the table, and rename it as necessary to avoid
52 /// collisions. Also insert at the specified location in the body of the
53 /// associated operation if it is not already there. It is asserted that the
54 /// symbol is not inside another operation. Return the name of the symbol
55 /// after insertion as attribute.
56 StringAttr insert(Operation *symbol, Block::iterator insertPt = {});
57
58 /// Renames the given op or the op refered to by the given name to the given
59 /// new name and updates the symbol table and all usages of the symbol
60 /// accordingly. Fails if the updating of the usages fails.
61 LogicalResult rename(StringAttr from, StringAttr to);
62 LogicalResult rename(Operation *op, StringAttr to);
63 LogicalResult rename(StringAttr from, StringRef to);
64 LogicalResult rename(Operation *op, StringRef to);
65
66 /// Renames the given op or the op refered to by the given name to the a name
67 /// that is unique within this and the provided other symbol tables and
68 /// updates the symbol table and all usages of the symbol accordingly. Returns
69 /// the new name or failure if the renaming fails.
70 FailureOr<StringAttr> renameToUnique(StringAttr from,
71 ArrayRef<SymbolTable *> others);
72 FailureOr<StringAttr> renameToUnique(Operation *op,
73 ArrayRef<SymbolTable *> others);
74
75 /// Return the name of the attribute used for symbol names.
76 static StringRef getSymbolAttrName() { return "sym_name"; }
77
78 /// Returns the associated operation.
79 Operation *getOp() const { return symbolTableOp; }
80
81 /// Return the name of the attribute used for symbol visibility.
82 static StringRef getVisibilityAttrName() { return "sym_visibility"; }
83
84 //===--------------------------------------------------------------------===//
85 // Symbol Utilities
86 //===--------------------------------------------------------------------===//
87
88 /// An enumeration detailing the different visibility types that a symbol may
89 /// have.
90 enum class Visibility {
91 /// The symbol is public and may be referenced anywhere internal or external
92 /// to the visible references in the IR.
93 Public,
94
95 /// The symbol is private and may only be referenced by SymbolRefAttrs local
96 /// to the operations within the current symbol table.
97 Private,
98
99 /// The symbol is visible to the current IR, which may include operations in
100 /// symbol tables above the one that owns the current symbol. `Nested`
101 /// visibility allows for referencing a symbol outside of its current symbol
102 /// table, while retaining the ability to observe all uses.
103 Nested,
104 };
105
106 /// Generate a unique symbol name. Iteratively increase uniquingCounter
107 /// and use it as a suffix for symbol names until uniqueChecker does not
108 /// detect any conflict.
109 template <unsigned N, typename UniqueChecker>
110 static SmallString<N> generateSymbolName(StringRef name,
111 UniqueChecker uniqueChecker,
112 unsigned &uniquingCounter) {
113 SmallString<N> nameBuffer(name);
114 unsigned originalLength = nameBuffer.size();
115 do {
116 nameBuffer.resize(originalLength);
117 nameBuffer += '_';
118 nameBuffer += std::to_string(val: uniquingCounter++);
119 } while (uniqueChecker(nameBuffer));
120
121 return nameBuffer;
122 }
123
124 /// Returns the name of the given symbol operation, aborting if no symbol is
125 /// present.
126 static StringAttr getSymbolName(Operation *symbol);
127
128 /// Sets the name of the given symbol operation.
129 static void setSymbolName(Operation *symbol, StringAttr name);
130 static void setSymbolName(Operation *symbol, StringRef name) {
131 setSymbolName(symbol, StringAttr::get(symbol->getContext(), name));
132 }
133
134 /// Returns the visibility of the given symbol operation.
135 static Visibility getSymbolVisibility(Operation *symbol);
136 /// Sets the visibility of the given symbol operation.
137 static void setSymbolVisibility(Operation *symbol, Visibility vis);
138
139 /// Returns the nearest symbol table from a given operation `from`. Returns
140 /// nullptr if no valid parent symbol table could be found.
141 static Operation *getNearestSymbolTable(Operation *from);
142
143 /// Walks all symbol table operations nested within, and including, `op`. For
144 /// each symbol table operation, the provided callback is invoked with the op
145 /// and a boolean signifying if the symbols within that symbol table can be
146 /// treated as if all uses within the IR are visible to the caller.
147 /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
148 /// within `op` are visible.
149 static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
150 function_ref<void(Operation *, bool)> callback);
151
152 /// Returns the operation registered with the given symbol name with the
153 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
154 /// with the 'OpTrait::SymbolTable' trait.
155 static Operation *lookupSymbolIn(Operation *op, StringAttr symbol);
156 static Operation *lookupSymbolIn(Operation *op, StringRef symbol) {
157 return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol));
158 }
159 static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
160 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
161 /// by a given SymbolRefAttr. Returns failure if any of the nested references
162 /// could not be resolved.
163 static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
164 SmallVectorImpl<Operation *> &symbols);
165
166 /// Returns the operation registered with the given symbol name within the
167 /// closest parent operation of, or including, 'from' with the
168 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
169 /// found.
170 static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
171 static Operation *lookupNearestSymbolFrom(Operation *from,
172 SymbolRefAttr symbol);
173 template <typename T>
174 static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
175 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
176 }
177 template <typename T>
178 static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
179 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
180 }
181
182 /// This class represents a specific symbol use.
183 class SymbolUse {
184 public:
185 SymbolUse(Operation *op, SymbolRefAttr symbolRef)
186 : owner(op), symbolRef(symbolRef) {}
187
188 /// Return the operation user of this symbol reference.
189 Operation *getUser() const { return owner; }
190
191 /// Return the symbol reference that this use represents.
192 SymbolRefAttr getSymbolRef() const { return symbolRef; }
193
194 private:
195 /// The operation that this access is held by.
196 Operation *owner;
197
198 /// The symbol reference that this use represents.
199 SymbolRefAttr symbolRef;
200 };
201
202 /// This class implements a range of SymbolRef uses.
203 class UseRange {
204 public:
205 UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
206
207 using iterator = std::vector<SymbolUse>::const_iterator;
208 iterator begin() const { return uses.begin(); }
209 iterator end() const { return uses.end(); }
210 bool empty() const { return uses.empty(); }
211
212 private:
213 std::vector<SymbolUse> uses;
214 };
215
216 /// Get an iterator range for all of the uses, for any symbol, that are nested
217 /// within the given operation 'from'. This does not traverse into any nested
218 /// symbol tables. This function returns std::nullopt if there are any unknown
219 /// operations that may potentially be symbol tables.
220 static std::optional<UseRange> getSymbolUses(Operation *from);
221 static std::optional<UseRange> getSymbolUses(Region *from);
222
223 /// Get all of the uses of the given symbol that are nested within the given
224 /// operation 'from'. This does not traverse into any nested symbol tables.
225 /// This function returns std::nullopt if there are any unknown operations
226 /// that may potentially be symbol tables.
227 static std::optional<UseRange> getSymbolUses(StringAttr symbol,
228 Operation *from);
229 static std::optional<UseRange> getSymbolUses(Operation *symbol,
230 Operation *from);
231 static std::optional<UseRange> getSymbolUses(StringAttr symbol, Region *from);
232 static std::optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
233
234 /// Return if the given symbol is known to have no uses that are nested
235 /// within the given operation 'from'. This does not traverse into any nested
236 /// symbol tables. This function will also return false if there are any
237 /// unknown operations that may potentially be symbol tables. This doesn't
238 /// necessarily mean that there are no uses, we just can't conservatively
239 /// prove it.
240 static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from);
241 static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
242 static bool symbolKnownUseEmpty(StringAttr symbol, Region *from);
243 static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
244
245 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
246 /// provided symbol 'newSymbol' that are nested within the given operation
247 /// 'from'. This does not traverse into any nested symbol tables. If there are
248 /// any unknown operations that may potentially be symbol tables, no uses are
249 /// replaced and failure is returned.
250 static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
251 StringAttr newSymbol,
252 Operation *from);
253 static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
254 StringAttr newSymbolName,
255 Operation *from);
256 static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
257 StringAttr newSymbol, Region *from);
258 static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
259 StringAttr newSymbolName,
260 Region *from);
261
262private:
263 Operation *symbolTableOp;
264
265 /// This is a mapping from a name to the symbol with that name. They key is
266 /// always known to be a StringAttr.
267 DenseMap<Attribute, Operation *> symbolTable;
268
269 /// This is used when name conflicts are detected.
270 unsigned uniquingCounter = 0;
271};
272
/// Stream output for `SymbolTable::Visibility`: writes a textual form of
/// `visibility` to `os`.
/// NOTE(review): the exact spelling is defined out of line (presumably the
/// public/private/nested keywords accepted by the visibility parser) — confirm.
raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility);
274
275//===----------------------------------------------------------------------===//
276// SymbolTableCollection
277//===----------------------------------------------------------------------===//
278
279/// This class represents a collection of `SymbolTable`s. This simplifies
280/// certain algorithms that run recursively on nested symbol tables. Symbol
281/// tables are constructed lazily to reduce the upfront cost of constructing
282/// unnecessary tables.
283class SymbolTableCollection {
284public:
285 /// Look up a symbol with the specified name within the specified symbol table
286 /// operation, returning null if no such name exists.
287 Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
288 Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
289 template <typename T, typename NameT>
290 T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
291 return dyn_cast_or_null<T>(
292 lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
293 }
294 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
295 /// by a given SymbolRefAttr when resolved within the provided symbol table
296 /// operation. Returns failure if any of the nested references could not be
297 /// resolved.
298 LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
299 SmallVectorImpl<Operation *> &symbols);
300
301 /// Returns the operation registered with the given symbol name within the
302 /// closest parent operation of, or including, 'from' with the
303 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
304 /// found.
305 Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
306 Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
307 template <typename T>
308 T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
309 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
310 }
311 template <typename T>
312 T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
313 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
314 }
315
316 /// Lookup, or create, a symbol table for an operation.
317 SymbolTable &getSymbolTable(Operation *op);
318
319private:
320 friend class LockedSymbolTableCollection;
321
322 /// The constructed symbol tables nested within this table.
323 DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
324};
325
326//===----------------------------------------------------------------------===//
327// LockedSymbolTableCollection
328//===----------------------------------------------------------------------===//
329
330/// This class implements a lock-based shared wrapper around a symbol table
331/// collection that allows shared access to the collection of symbol tables.
332/// This class does not protect shared access to individual symbol tables.
333/// `SymbolTableCollection` lazily instantiates `SymbolTable` instances for
334/// symbol table operations, making read operations not thread-safe. This class
335/// provides a thread-safe `lookupSymbolIn` implementation by synchronizing the
336/// lazy `SymbolTable` lookup.
337class LockedSymbolTableCollection : public SymbolTableCollection {
338public:
339 explicit LockedSymbolTableCollection(SymbolTableCollection &collection)
340 : collection(collection) {}
341
342 /// Look up a symbol with the specified name within the specified symbol table
343 /// operation, returning null if no such name exists.
344 Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
345 /// Look up a symbol with the specified name within the specified symbol table
346 /// operation, returning null if no such name exists.
347 Operation *lookupSymbolIn(Operation *symbolTableOp, FlatSymbolRefAttr symbol);
348 /// Look up a potentially nested symbol within the specified symbol table
349 /// operation, returning null if no such symbol exists.
350 Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
351
352 /// Lookup a symbol of a particular kind within the specified symbol table,
353 /// returning null if the symbol was not found.
354 template <typename T, typename NameT>
355 T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
356 return dyn_cast_or_null<T>(
357 lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
358 }
359
360 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
361 /// by a given SymbolRefAttr when resolved within the provided symbol table
362 /// operation. Returns failure if any of the nested references could not be
363 /// resolved.
364 LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
365 SmallVectorImpl<Operation *> &symbols);
366
367private:
368 /// Get the symbol table for the symbol table operation, constructing if it
369 /// does not exist. This function provides thread safety over `collection`
370 /// by locking when performing the lookup and when inserting
371 /// lazily-constructed symbol tables.
372 SymbolTable &getSymbolTable(Operation *symbolTableOp);
373
374 /// The symbol tables to manage.
375 SymbolTableCollection &collection;
376 /// The mutex protecting access to the symbol table collection.
377 llvm::sys::SmartRWMutex<true> mutex;
378};
379
380//===----------------------------------------------------------------------===//
381// SymbolUserMap
382//===----------------------------------------------------------------------===//
383
384/// This class represents a map of symbols to users, and provides efficient
385/// implementations of symbol queries related to users; such as collecting the
386/// users of a symbol, replacing all uses, etc.
387class SymbolUserMap {
388public:
389 /// Build a user map for all of the symbols defined in regions nested under
390 /// 'symbolTableOp'. A reference to the provided symbol table collection is
391 /// kept by the user map to ensure efficient lookups, thus the lifetime should
392 /// extend beyond that of this map.
393 SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp);
394
395 /// Return the users of the provided symbol operation.
396 ArrayRef<Operation *> getUsers(Operation *symbol) const {
397 auto it = symbolToUsers.find(Val: symbol);
398 return it != symbolToUsers.end() ? it->second.getArrayRef() : std::nullopt;
399 }
400
401 /// Return true if the given symbol has no uses.
402 bool useEmpty(Operation *symbol) const {
403 return !symbolToUsers.count(Val: symbol);
404 }
405
406 /// Replace all of the uses of the given symbol with `newSymbolName`.
407 void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName);
408
409private:
410 /// A reference to the symbol table used to construct this map.
411 SymbolTableCollection &symbolTable;
412
413 /// A map of symbol operations to symbol users.
414 DenseMap<Operation *, SetVector<Operation *>> symbolToUsers;
415};
416
417//===----------------------------------------------------------------------===//
418// SymbolTable Trait Types
419//===----------------------------------------------------------------------===//
420
namespace detail {
/// Out-of-line verifier for operations with the `OpTrait::SymbolTable` trait;
/// `OpTrait::SymbolTable::verifyRegionTrait` forwards to it.
LogicalResult verifySymbolTable(Operation *op);
/// Out-of-line verifier for a single symbol operation.
/// NOTE(review): presumably checks the `Symbol` constraints mentioned in the
/// `OpTrait::SymbolTable` documentation — confirm in SymbolTable.cpp.
LogicalResult verifySymbol(Operation *op);
} // namespace detail
425
426namespace OpTrait {
427/// A trait used to provide symbol table functionalities to a region operation.
428/// This operation must hold exactly 1 region. Once attached, all operations
429/// that are directly within the region, i.e not including those within child
430/// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
431/// be verified to ensure that the names are uniqued. These operations must also
432/// adhere to the constraints defined by the `Symbol` trait, even if they do not
433/// inherit from it.
434template <typename ConcreteType>
435class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
436public:
437 static LogicalResult verifyRegionTrait(Operation *op) {
438 return ::mlir::detail::verifySymbolTable(op);
439 }
440
441 /// Look up a symbol with the specified name, returning null if no such
442 /// name exists. Symbol names never include the @ on them. Note: This
443 /// performs a linear scan of held symbols.
444 Operation *lookupSymbol(StringAttr name) {
445 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
446 }
447 template <typename T>
448 T lookupSymbol(StringAttr name) {
449 return dyn_cast_or_null<T>(lookupSymbol(name));
450 }
451 Operation *lookupSymbol(SymbolRefAttr symbol) {
452 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
453 }
454 template <typename T>
455 T lookupSymbol(SymbolRefAttr symbol) {
456 return dyn_cast_or_null<T>(lookupSymbol(symbol));
457 }
458
459 Operation *lookupSymbol(StringRef name) {
460 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
461 }
462 template <typename T>
463 T lookupSymbol(StringRef name) {
464 return dyn_cast_or_null<T>(lookupSymbol(name));
465 }
466};
467
468} // namespace OpTrait
469
470//===----------------------------------------------------------------------===//
471// Visibility parsing implementation.
472//===----------------------------------------------------------------------===//
473
namespace impl {
/// Parse an optional visibility keyword (i.e., public, private, or nested),
/// written without quotes, from `parser`, and record it as a string attribute
/// in `attrs`.
/// NOTE(review): the attribute name is presumably
/// `SymbolTable::getVisibilityAttrName()` ("sym_visibility"). There is no
/// 'attrName' parameter despite the previous wording — confirm in the .cpp.
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,
                                           NamedAttrList &attrs);
} // namespace impl
480
481} // namespace mlir
482
483/// Include the generated symbol interfaces.
484#include "mlir/IR/SymbolInterfaces.h.inc"
485
486#endif // MLIR_IR_SYMBOLTABLE_H
487

// source code of mlir/include/mlir/IR/SymbolTable.h