1 | //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_IR_SYMBOLTABLE_H |
10 | #define MLIR_IR_SYMBOLTABLE_H |
11 | |
12 | #include "mlir/IR/Attributes.h" |
13 | #include "mlir/IR/OpDefinition.h" |
14 | #include "llvm/ADT/SetVector.h" |
15 | #include "llvm/ADT/StringMap.h" |
16 | #include "llvm/Support/RWMutex.h" |
17 | |
18 | namespace mlir { |
19 | |
20 | /// This class allows for representing and managing the symbol table used by |
21 | /// operations with the 'SymbolTable' trait. Inserting into and erasing from |
22 | /// this SymbolTable will also insert and erase from the Operation given to it |
23 | /// at construction. |
24 | class SymbolTable { |
25 | public: |
26 | /// Build a symbol table with the symbols within the given operation. |
27 | SymbolTable(Operation *symbolTableOp); |
28 | |
29 | /// Look up a symbol with the specified name, returning null if no such |
30 | /// name exists. Names never include the @ on them. |
31 | Operation *lookup(StringRef name) const; |
32 | template <typename T> |
33 | T lookup(StringRef name) const { |
34 | return dyn_cast_or_null<T>(lookup(name)); |
35 | } |
36 | |
37 | /// Look up a symbol with the specified name, returning null if no such |
38 | /// name exists. Names never include the @ on them. |
39 | Operation *lookup(StringAttr name) const; |
40 | template <typename T> |
41 | T lookup(StringAttr name) const { |
42 | return dyn_cast_or_null<T>(lookup(name)); |
43 | } |
44 | |
45 | /// Remove the given symbol from the table, without deleting it. |
46 | void remove(Operation *op); |
47 | |
48 | /// Erase the given symbol from the table and delete the operation. |
49 | void erase(Operation *symbol); |
50 | |
51 | /// Insert a new symbol into the table, and rename it as necessary to avoid |
52 | /// collisions. Also insert at the specified location in the body of the |
53 | /// associated operation if it is not already there. It is asserted that the |
54 | /// symbol is not inside another operation. Return the name of the symbol |
55 | /// after insertion as attribute. |
56 | StringAttr insert(Operation *symbol, Block::iterator insertPt = {}); |
57 | |
58 | /// Renames the given op or the op refered to by the given name to the given |
59 | /// new name and updates the symbol table and all usages of the symbol |
60 | /// accordingly. Fails if the updating of the usages fails. |
61 | LogicalResult rename(StringAttr from, StringAttr to); |
62 | LogicalResult rename(Operation *op, StringAttr to); |
63 | LogicalResult rename(StringAttr from, StringRef to); |
64 | LogicalResult rename(Operation *op, StringRef to); |
65 | |
66 | /// Renames the given op or the op refered to by the given name to the a name |
67 | /// that is unique within this and the provided other symbol tables and |
68 | /// updates the symbol table and all usages of the symbol accordingly. Returns |
69 | /// the new name or failure if the renaming fails. |
70 | FailureOr<StringAttr> renameToUnique(StringAttr from, |
71 | ArrayRef<SymbolTable *> others); |
72 | FailureOr<StringAttr> renameToUnique(Operation *op, |
73 | ArrayRef<SymbolTable *> others); |
74 | |
75 | /// Return the name of the attribute used for symbol names. |
76 | static StringRef getSymbolAttrName() { return "sym_name" ; } |
77 | |
78 | /// Returns the associated operation. |
79 | Operation *getOp() const { return symbolTableOp; } |
80 | |
81 | /// Return the name of the attribute used for symbol visibility. |
82 | static StringRef getVisibilityAttrName() { return "sym_visibility" ; } |
83 | |
84 | //===--------------------------------------------------------------------===// |
85 | // Symbol Utilities |
86 | //===--------------------------------------------------------------------===// |
87 | |
88 | /// An enumeration detailing the different visibility types that a symbol may |
89 | /// have. |
90 | enum class Visibility { |
91 | /// The symbol is public and may be referenced anywhere internal or external |
92 | /// to the visible references in the IR. |
93 | Public, |
94 | |
95 | /// The symbol is private and may only be referenced by SymbolRefAttrs local |
96 | /// to the operations within the current symbol table. |
97 | Private, |
98 | |
99 | /// The symbol is visible to the current IR, which may include operations in |
100 | /// symbol tables above the one that owns the current symbol. `Nested` |
101 | /// visibility allows for referencing a symbol outside of its current symbol |
102 | /// table, while retaining the ability to observe all uses. |
103 | Nested, |
104 | }; |
105 | |
106 | /// Generate a unique symbol name. Iteratively increase uniquingCounter |
107 | /// and use it as a suffix for symbol names until uniqueChecker does not |
108 | /// detect any conflict. |
109 | template <unsigned N, typename UniqueChecker> |
110 | static SmallString<N> generateSymbolName(StringRef name, |
111 | UniqueChecker uniqueChecker, |
112 | unsigned &uniquingCounter) { |
113 | SmallString<N> nameBuffer(name); |
114 | unsigned originalLength = nameBuffer.size(); |
115 | do { |
116 | nameBuffer.resize(originalLength); |
117 | nameBuffer += '_'; |
118 | nameBuffer += std::to_string(val: uniquingCounter++); |
119 | } while (uniqueChecker(nameBuffer)); |
120 | |
121 | return nameBuffer; |
122 | } |
123 | |
124 | /// Returns the name of the given symbol operation, aborting if no symbol is |
125 | /// present. |
126 | static StringAttr getSymbolName(Operation *symbol); |
127 | |
128 | /// Sets the name of the given symbol operation. |
129 | static void setSymbolName(Operation *symbol, StringAttr name); |
130 | static void setSymbolName(Operation *symbol, StringRef name) { |
131 | setSymbolName(symbol, StringAttr::get(symbol->getContext(), name)); |
132 | } |
133 | |
134 | /// Returns the visibility of the given symbol operation. |
135 | static Visibility getSymbolVisibility(Operation *symbol); |
136 | /// Sets the visibility of the given symbol operation. |
137 | static void setSymbolVisibility(Operation *symbol, Visibility vis); |
138 | |
139 | /// Returns the nearest symbol table from a given operation `from`. Returns |
140 | /// nullptr if no valid parent symbol table could be found. |
141 | static Operation *getNearestSymbolTable(Operation *from); |
142 | |
143 | /// Walks all symbol table operations nested within, and including, `op`. For |
144 | /// each symbol table operation, the provided callback is invoked with the op |
145 | /// and a boolean signifying if the symbols within that symbol table can be |
146 | /// treated as if all uses within the IR are visible to the caller. |
147 | /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols |
148 | /// within `op` are visible. |
149 | static void walkSymbolTables(Operation *op, bool allSymUsesVisible, |
150 | function_ref<void(Operation *, bool)> callback); |
151 | |
152 | /// Returns the operation registered with the given symbol name with the |
153 | /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation |
154 | /// with the 'OpTrait::SymbolTable' trait. |
155 | static Operation *lookupSymbolIn(Operation *op, StringAttr symbol); |
156 | static Operation *lookupSymbolIn(Operation *op, StringRef symbol) { |
157 | return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol)); |
158 | } |
159 | static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); |
160 | /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced |
161 | /// by a given SymbolRefAttr. Returns failure if any of the nested references |
162 | /// could not be resolved. |
163 | static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol, |
164 | SmallVectorImpl<Operation *> &symbols); |
165 | |
166 | /// Returns the operation registered with the given symbol name within the |
167 | /// closest parent operation of, or including, 'from' with the |
168 | /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was |
169 | /// found. |
170 | static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol); |
171 | static Operation *lookupNearestSymbolFrom(Operation *from, |
172 | SymbolRefAttr symbol); |
173 | template <typename T> |
174 | static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) { |
175 | return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); |
176 | } |
177 | template <typename T> |
178 | static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { |
179 | return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); |
180 | } |
181 | |
182 | /// This class represents a specific symbol use. |
183 | class SymbolUse { |
184 | public: |
185 | SymbolUse(Operation *op, SymbolRefAttr symbolRef) |
186 | : owner(op), symbolRef(symbolRef) {} |
187 | |
188 | /// Return the operation user of this symbol reference. |
189 | Operation *getUser() const { return owner; } |
190 | |
191 | /// Return the symbol reference that this use represents. |
192 | SymbolRefAttr getSymbolRef() const { return symbolRef; } |
193 | |
194 | private: |
195 | /// The operation that this access is held by. |
196 | Operation *owner; |
197 | |
198 | /// The symbol reference that this use represents. |
199 | SymbolRefAttr symbolRef; |
200 | }; |
201 | |
202 | /// This class implements a range of SymbolRef uses. |
203 | class UseRange { |
204 | public: |
205 | UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} |
206 | |
207 | using iterator = std::vector<SymbolUse>::const_iterator; |
208 | iterator begin() const { return uses.begin(); } |
209 | iterator end() const { return uses.end(); } |
210 | bool empty() const { return uses.empty(); } |
211 | |
212 | private: |
213 | std::vector<SymbolUse> uses; |
214 | }; |
215 | |
216 | /// Get an iterator range for all of the uses, for any symbol, that are nested |
217 | /// within the given operation 'from'. This does not traverse into any nested |
218 | /// symbol tables. This function returns std::nullopt if there are any unknown |
219 | /// operations that may potentially be symbol tables. |
220 | static std::optional<UseRange> getSymbolUses(Operation *from); |
221 | static std::optional<UseRange> getSymbolUses(Region *from); |
222 | |
223 | /// Get all of the uses of the given symbol that are nested within the given |
224 | /// operation 'from'. This does not traverse into any nested symbol tables. |
225 | /// This function returns std::nullopt if there are any unknown operations |
226 | /// that may potentially be symbol tables. |
227 | static std::optional<UseRange> getSymbolUses(StringAttr symbol, |
228 | Operation *from); |
229 | static std::optional<UseRange> getSymbolUses(Operation *symbol, |
230 | Operation *from); |
231 | static std::optional<UseRange> getSymbolUses(StringAttr symbol, Region *from); |
232 | static std::optional<UseRange> getSymbolUses(Operation *symbol, Region *from); |
233 | |
234 | /// Return if the given symbol is known to have no uses that are nested |
235 | /// within the given operation 'from'. This does not traverse into any nested |
236 | /// symbol tables. This function will also return false if there are any |
237 | /// unknown operations that may potentially be symbol tables. This doesn't |
238 | /// necessarily mean that there are no uses, we just can't conservatively |
239 | /// prove it. |
240 | static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from); |
241 | static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); |
242 | static bool symbolKnownUseEmpty(StringAttr symbol, Region *from); |
243 | static bool symbolKnownUseEmpty(Operation *symbol, Region *from); |
244 | |
245 | /// Attempt to replace all uses of the given symbol 'oldSymbol' with the |
246 | /// provided symbol 'newSymbol' that are nested within the given operation |
247 | /// 'from'. This does not traverse into any nested symbol tables. If there are |
248 | /// any unknown operations that may potentially be symbol tables, no uses are |
249 | /// replaced and failure is returned. |
250 | static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, |
251 | StringAttr newSymbol, |
252 | Operation *from); |
253 | static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, |
254 | StringAttr newSymbolName, |
255 | Operation *from); |
256 | static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, |
257 | StringAttr newSymbol, Region *from); |
258 | static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, |
259 | StringAttr newSymbolName, |
260 | Region *from); |
261 | |
262 | private: |
263 | Operation *symbolTableOp; |
264 | |
265 | /// This is a mapping from a name to the symbol with that name. They key is |
266 | /// always known to be a StringAttr. |
267 | DenseMap<Attribute, Operation *> symbolTable; |
268 | |
269 | /// This is used when name conflicts are detected. |
270 | unsigned uniquingCounter = 0; |
271 | }; |
272 | |
273 | raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility); |
274 | |
275 | //===----------------------------------------------------------------------===// |
276 | // SymbolTableCollection |
277 | //===----------------------------------------------------------------------===// |
278 | |
279 | /// This class represents a collection of `SymbolTable`s. This simplifies |
280 | /// certain algorithms that run recursively on nested symbol tables. Symbol |
281 | /// tables are constructed lazily to reduce the upfront cost of constructing |
282 | /// unnecessary tables. |
283 | class SymbolTableCollection { |
284 | public: |
285 | /// Look up a symbol with the specified name within the specified symbol table |
286 | /// operation, returning null if no such name exists. |
287 | Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol); |
288 | Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name); |
289 | template <typename T, typename NameT> |
290 | T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) { |
291 | return dyn_cast_or_null<T>( |
292 | lookupSymbolIn(symbolTableOp, std::forward<NameT>(name))); |
293 | } |
294 | /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced |
295 | /// by a given SymbolRefAttr when resolved within the provided symbol table |
296 | /// operation. Returns failure if any of the nested references could not be |
297 | /// resolved. |
298 | LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, |
299 | SmallVectorImpl<Operation *> &symbols); |
300 | |
301 | /// Returns the operation registered with the given symbol name within the |
302 | /// closest parent operation of, or including, 'from' with the |
303 | /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was |
304 | /// found. |
305 | Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol); |
306 | Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol); |
307 | template <typename T> |
308 | T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) { |
309 | return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); |
310 | } |
311 | template <typename T> |
312 | T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { |
313 | return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); |
314 | } |
315 | |
316 | /// Lookup, or create, a symbol table for an operation. |
317 | SymbolTable &getSymbolTable(Operation *op); |
318 | |
319 | private: |
320 | friend class LockedSymbolTableCollection; |
321 | |
322 | /// The constructed symbol tables nested within this table. |
323 | DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables; |
324 | }; |
325 | |
326 | //===----------------------------------------------------------------------===// |
327 | // LockedSymbolTableCollection |
328 | //===----------------------------------------------------------------------===// |
329 | |
330 | /// This class implements a lock-based shared wrapper around a symbol table |
331 | /// collection that allows shared access to the collection of symbol tables. |
332 | /// This class does not protect shared access to individual symbol tables. |
333 | /// `SymbolTableCollection` lazily instantiates `SymbolTable` instances for |
334 | /// symbol table operations, making read operations not thread-safe. This class |
335 | /// provides a thread-safe `lookupSymbolIn` implementation by synchronizing the |
336 | /// lazy `SymbolTable` lookup. |
337 | class LockedSymbolTableCollection : public SymbolTableCollection { |
338 | public: |
339 | explicit LockedSymbolTableCollection(SymbolTableCollection &collection) |
340 | : collection(collection) {} |
341 | |
342 | /// Look up a symbol with the specified name within the specified symbol table |
343 | /// operation, returning null if no such name exists. |
344 | Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol); |
345 | /// Look up a symbol with the specified name within the specified symbol table |
346 | /// operation, returning null if no such name exists. |
347 | Operation *lookupSymbolIn(Operation *symbolTableOp, FlatSymbolRefAttr symbol); |
348 | /// Look up a potentially nested symbol within the specified symbol table |
349 | /// operation, returning null if no such symbol exists. |
350 | Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name); |
351 | |
352 | /// Lookup a symbol of a particular kind within the specified symbol table, |
353 | /// returning null if the symbol was not found. |
354 | template <typename T, typename NameT> |
355 | T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) { |
356 | return dyn_cast_or_null<T>( |
357 | lookupSymbolIn(symbolTableOp, std::forward<NameT>(name))); |
358 | } |
359 | |
360 | /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced |
361 | /// by a given SymbolRefAttr when resolved within the provided symbol table |
362 | /// operation. Returns failure if any of the nested references could not be |
363 | /// resolved. |
364 | LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, |
365 | SmallVectorImpl<Operation *> &symbols); |
366 | |
367 | private: |
368 | /// Get the symbol table for the symbol table operation, constructing if it |
369 | /// does not exist. This function provides thread safety over `collection` |
370 | /// by locking when performing the lookup and when inserting |
371 | /// lazily-constructed symbol tables. |
372 | SymbolTable &getSymbolTable(Operation *symbolTableOp); |
373 | |
374 | /// The symbol tables to manage. |
375 | SymbolTableCollection &collection; |
376 | /// The mutex protecting access to the symbol table collection. |
377 | llvm::sys::SmartRWMutex<true> mutex; |
378 | }; |
379 | |
380 | //===----------------------------------------------------------------------===// |
381 | // SymbolUserMap |
382 | //===----------------------------------------------------------------------===// |
383 | |
384 | /// This class represents a map of symbols to users, and provides efficient |
385 | /// implementations of symbol queries related to users; such as collecting the |
386 | /// users of a symbol, replacing all uses, etc. |
387 | class SymbolUserMap { |
388 | public: |
389 | /// Build a user map for all of the symbols defined in regions nested under |
390 | /// 'symbolTableOp'. A reference to the provided symbol table collection is |
391 | /// kept by the user map to ensure efficient lookups, thus the lifetime should |
392 | /// extend beyond that of this map. |
393 | SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp); |
394 | |
395 | /// Return the users of the provided symbol operation. |
396 | ArrayRef<Operation *> getUsers(Operation *symbol) const { |
397 | auto it = symbolToUsers.find(Val: symbol); |
398 | return it != symbolToUsers.end() ? it->second.getArrayRef() : std::nullopt; |
399 | } |
400 | |
401 | /// Return true if the given symbol has no uses. |
402 | bool useEmpty(Operation *symbol) const { |
403 | return !symbolToUsers.count(Val: symbol); |
404 | } |
405 | |
406 | /// Replace all of the uses of the given symbol with `newSymbolName`. |
407 | void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName); |
408 | |
409 | private: |
410 | /// A reference to the symbol table used to construct this map. |
411 | SymbolTableCollection &symbolTable; |
412 | |
413 | /// A map of symbol operations to symbol users. |
414 | DenseMap<Operation *, SetVector<Operation *>> symbolToUsers; |
415 | }; |
416 | |
417 | //===----------------------------------------------------------------------===// |
418 | // SymbolTable Trait Types |
419 | //===----------------------------------------------------------------------===// |
420 | |
421 | namespace detail { |
422 | LogicalResult verifySymbolTable(Operation *op); |
423 | LogicalResult verifySymbol(Operation *op); |
424 | } // namespace detail |
425 | |
426 | namespace OpTrait { |
427 | /// A trait used to provide symbol table functionalities to a region operation. |
428 | /// This operation must hold exactly 1 region. Once attached, all operations |
429 | /// that are directly within the region, i.e not including those within child |
430 | /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will |
431 | /// be verified to ensure that the names are uniqued. These operations must also |
432 | /// adhere to the constraints defined by the `Symbol` trait, even if they do not |
433 | /// inherit from it. |
434 | template <typename ConcreteType> |
435 | class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { |
436 | public: |
437 | static LogicalResult verifyRegionTrait(Operation *op) { |
438 | return ::mlir::detail::verifySymbolTable(op); |
439 | } |
440 | |
441 | /// Look up a symbol with the specified name, returning null if no such |
442 | /// name exists. Symbol names never include the @ on them. Note: This |
443 | /// performs a linear scan of held symbols. |
444 | Operation *lookupSymbol(StringAttr name) { |
445 | return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); |
446 | } |
447 | template <typename T> |
448 | T lookupSymbol(StringAttr name) { |
449 | return dyn_cast_or_null<T>(lookupSymbol(name)); |
450 | } |
451 | Operation *lookupSymbol(SymbolRefAttr symbol) { |
452 | return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol); |
453 | } |
454 | template <typename T> |
455 | T lookupSymbol(SymbolRefAttr symbol) { |
456 | return dyn_cast_or_null<T>(lookupSymbol(symbol)); |
457 | } |
458 | |
459 | Operation *lookupSymbol(StringRef name) { |
460 | return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); |
461 | } |
462 | template <typename T> |
463 | T lookupSymbol(StringRef name) { |
464 | return dyn_cast_or_null<T>(lookupSymbol(name)); |
465 | } |
466 | }; |
467 | |
468 | } // namespace OpTrait |
469 | |
470 | //===----------------------------------------------------------------------===// |
471 | // Visibility parsing implementation. |
472 | //===----------------------------------------------------------------------===// |
473 | |
474 | namespace impl { |
475 | /// Parse an optional visibility attribute keyword (i.e., public, private, or |
476 | /// nested) without quotes in a string attribute named 'attrName'. |
477 | ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, |
478 | NamedAttrList &attrs); |
479 | } // namespace impl |
480 | |
481 | } // namespace mlir |
482 | |
483 | /// Include the generated symbol interfaces. |
484 | #include "mlir/IR/SymbolInterfaces.h.inc" |
485 | |
486 | #endif // MLIR_IR_SYMBOLTABLE_H |
487 | |