1//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/SymbolTable.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/IR/OpImplementation.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallPtrSet.h"
14#include "llvm/ADT/SmallString.h"
15#include "llvm/ADT/StringSwitch.h"
16#include <optional>
17
18using namespace mlir;
19
20/// Return true if the given operation is unknown and may potentially define a
21/// symbol table.
22static bool isPotentiallyUnknownSymbolTable(Operation *op) {
23 return op->getNumRegions() == 1 && !op->getDialect();
24}
25
26/// Returns the string name of the given symbol, or null if this is not a
27/// symbol.
28static StringAttr getNameIfSymbol(Operation *op) {
29 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
30}
31static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
32 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
33}
34
35/// Computes the nested symbol reference attribute for the symbol 'symbolName'
36/// that are usable within the symbol table operations from 'symbol' as far up
37/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38/// Returns success if all references up to 'within' could be computed.
39static LogicalResult
40collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
41 Operation *within,
42 SmallVectorImpl<SymbolRefAttr> &results) {
43 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
44 MLIRContext *ctx = symbol->getContext();
45
46 auto leafRef = FlatSymbolRefAttr::get(symbolName);
47 results.push_back(leafRef);
48
49 // Early exit for when 'within' is the parent of 'symbol'.
50 Operation *symbolTableOp = symbol->getParentOp();
51 if (within == symbolTableOp)
52 return success();
53
54 // Collect references until 'symbolTableOp' reaches 'within'.
55 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
56 StringAttr symbolNameId =
57 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
58 do {
59 // Each parent of 'symbol' should define a symbol table.
60 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
61 return failure();
62 // Each parent of 'symbol' should also be a symbol.
63 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
64 if (!symbolTableName)
65 return failure();
66 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
67
68 symbolTableOp = symbolTableOp->getParentOp();
69 if (symbolTableOp == within)
70 break;
71 nestedRefs.insert(nestedRefs.begin(),
72 FlatSymbolRefAttr::get(symbolTableName));
73 } while (true);
74 return success();
75}
76
77/// Walk all of the operations within the given set of regions, without
78/// traversing into any nested symbol tables. Stops walking if the result of the
79/// callback is anything other than `WalkResult::advance`.
80static std::optional<WalkResult>
81walkSymbolTable(MutableArrayRef<Region> regions,
82 function_ref<std::optional<WalkResult>(Operation *)> callback) {
83 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(Range&: regions));
84 while (!worklist.empty()) {
85 for (Operation &op : worklist.pop_back_val()->getOps()) {
86 std::optional<WalkResult> result = callback(&op);
87 if (result != WalkResult::advance())
88 return result;
89
90 // If this op defines a new symbol table scope, we can't traverse. Any
91 // symbol references nested within 'op' are different semantically.
92 if (!op.hasTrait<OpTrait::SymbolTable>()) {
93 for (Region &region : op.getRegions())
94 worklist.push_back(Elt: &region);
95 }
96 }
97 }
98 return WalkResult::advance();
99}
100
101/// Walk all of the operations nested under, and including, the given operation,
102/// without traversing into any nested symbol tables. Stops walking if the
103/// result of the callback is anything other than `WalkResult::advance`.
104static std::optional<WalkResult>
105walkSymbolTable(Operation *op,
106 function_ref<std::optional<WalkResult>(Operation *)> callback) {
107 std::optional<WalkResult> result = callback(op);
108 if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
109 return result;
110 return walkSymbolTable(regions: op->getRegions(), callback);
111}
112
113//===----------------------------------------------------------------------===//
114// SymbolTable
115//===----------------------------------------------------------------------===//
116
117/// Build a symbol table with the symbols within the given operation.
118SymbolTable::SymbolTable(Operation *symbolTableOp)
119 : symbolTableOp(symbolTableOp) {
120 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
121 "expected operation to have SymbolTable trait");
122 assert(symbolTableOp->getNumRegions() == 1 &&
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
125 "expected operation to have a single block");
126
127 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 SymbolTable::getSymbolAttrName());
129 for (auto &op : symbolTableOp->getRegion(index: 0).front()) {
130 StringAttr name = getNameIfSymbol(&op, symbolNameId);
131 if (!name)
132 continue;
133
134 auto inserted = symbolTable.insert({name, &op});
135 (void)inserted;
136 assert(inserted.second &&
137 "expected region to contain uniquely named symbol operations");
138 }
139}
140
141/// Look up a symbol with the specified name, returning null if no such name
142/// exists. Names never include the @ on them.
143Operation *SymbolTable::lookup(StringRef name) const {
144 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
145}
146Operation *SymbolTable::lookup(StringAttr name) const {
147 return symbolTable.lookup(Val: name);
148}
149
150void SymbolTable::remove(Operation *op) {
151 StringAttr name = getNameIfSymbol(op);
152 assert(name && "expected valid 'name' attribute");
153 assert(op->getParentOp() == symbolTableOp &&
154 "expected this operation to be inside of the operation with this "
155 "SymbolTable");
156
157 auto it = symbolTable.find(name);
158 if (it != symbolTable.end() && it->second == op)
159 symbolTable.erase(it);
160}
161
162void SymbolTable::erase(Operation *symbol) {
163 remove(op: symbol);
164 symbol->erase();
165}
166
167// TODO: Consider if this should be renamed to something like insertOrUpdate
168/// Insert a new symbol into the table and associated operation if not already
169/// there and rename it as necessary to avoid collisions. Return the name of
170/// the symbol after insertion as attribute.
171StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
172 // The symbol cannot be the child of another op and must be the child of the
173 // symbolTableOp after this.
174 //
175 // TODO: consider if SymbolTable's constructor should behave the same.
176 if (!symbol->getParentOp()) {
177 auto &body = symbolTableOp->getRegion(index: 0).front();
178 if (insertPt == Block::iterator()) {
179 insertPt = Block::iterator(body.end());
180 } else {
181 assert((insertPt == body.end() ||
182 insertPt->getParentOp() == symbolTableOp) &&
183 "expected insertPt to be in the associated module operation");
184 }
185 // Insert before the terminator, if any.
186 if (insertPt == Block::iterator(body.end()) && !body.empty() &&
187 std::prev(x: body.end())->hasTrait<OpTrait::IsTerminator>())
188 insertPt = std::prev(x: body.end());
189
190 body.getOperations().insert(where: insertPt, New: symbol);
191 }
192 assert(symbol->getParentOp() == symbolTableOp &&
193 "symbol is already inserted in another op");
194
195 // Add this symbol to the symbol table, uniquing the name if a conflict is
196 // detected.
197 StringAttr name = getSymbolName(symbol);
198 if (symbolTable.insert({name, symbol}).second)
199 return name;
200 // If the symbol was already in the table, also return.
201 if (symbolTable.lookup(Val: name) == symbol)
202 return name;
203
204 MLIRContext *context = symbol->getContext();
205 SmallString<128> nameBuffer = generateSymbolName<128>(
206 name.getValue(),
207 [&](StringRef candidate) {
208 return !symbolTable
209 .insert({StringAttr::get(context, candidate), symbol})
210 .second;
211 },
212 uniquingCounter);
213 setSymbolName(symbol, name: nameBuffer);
214 return getSymbolName(symbol);
215}
216
217LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
218 Operation *op = lookup(from);
219 return rename(op, to);
220}
221
222LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
223 StringAttr from = getNameIfSymbol(op);
224 (void)from;
225
226 assert(from && "expected valid 'name' attribute");
227 assert(op->getParentOp() == symbolTableOp &&
228 "expected this operation to be inside of the operation with this "
229 "SymbolTable");
230 assert(lookup(from) == op && "current name does not resolve to op");
231 assert(lookup(to) == nullptr && "new name already exists");
232
233 if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
234 return failure();
235
236 // Remove op with old name, change name, add with new name. The order is
237 // important here due to how `remove` and `insert` rely on the op name.
238 remove(op);
239 setSymbolName(op, to);
240 insert(op);
241
242 assert(lookup(to) == op && "new name does not resolve to renamed op");
243 assert(lookup(from) == nullptr && "old name still exists");
244
245 return success();
246}
247
248LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
249 auto toAttr = StringAttr::get(getOp()->getContext(), to);
250 return rename(from, toAttr);
251}
252
253LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
254 auto toAttr = StringAttr::get(getOp()->getContext(), to);
255 return rename(op, toAttr);
256}
257
258FailureOr<StringAttr>
259SymbolTable::renameToUnique(StringAttr oldName,
260 ArrayRef<SymbolTable *> others) {
261
262 // Determine new name that is unique in all symbol tables.
263 StringAttr newName;
264 {
265 MLIRContext *context = oldName.getContext();
266 SmallString<64> prefix = oldName.getValue();
267 int uniqueId = 0;
268 prefix.push_back(Elt: '_');
269 while (true) {
270 newName = StringAttr::get(context, prefix + Twine(uniqueId++));
271 auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
272 if (!lookupNewName(this) && llvm::none_of(Range&: others, P: lookupNewName)) {
273 break;
274 }
275 }
276 }
277
278 // Apply renaming.
279 if (failed(rename(oldName, newName)))
280 return failure();
281 return newName;
282}
283
284FailureOr<StringAttr>
285SymbolTable::renameToUnique(Operation *op, ArrayRef<SymbolTable *> others) {
286 StringAttr from = getNameIfSymbol(op);
287 assert(from && "expected valid 'name' attribute");
288 return renameToUnique(from, others);
289}
290
291/// Returns the name of the given symbol operation.
292StringAttr SymbolTable::getSymbolName(Operation *symbol) {
293 StringAttr name = getNameIfSymbol(symbol);
294 assert(name && "expected valid symbol name");
295 return name;
296}
297
298/// Sets the name of the given symbol operation.
299void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
300 symbol->setAttr(getSymbolAttrName(), name);
301}
302
303/// Returns the visibility of the given symbol operation.
304SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
305 // If the attribute doesn't exist, assume public.
306 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
307 if (!vis)
308 return Visibility::Public;
309
310 // Otherwise, switch on the string value.
311 return StringSwitch<Visibility>(vis.getValue())
312 .Case(S: "private", Value: Visibility::Private)
313 .Case(S: "nested", Value: Visibility::Nested)
314 .Case(S: "public", Value: Visibility::Public);
315}
316/// Sets the visibility of the given symbol operation.
317void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
318 MLIRContext *ctx = symbol->getContext();
319
320 // If the visibility is public, just drop the attribute as this is the
321 // default.
322 if (vis == Visibility::Public) {
323 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
324 return;
325 }
326
327 // Otherwise, update the attribute.
328 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
329 "unknown symbol visibility kind");
330
331 StringRef visName = vis == Visibility::Private ? "private" : "nested";
332 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
333}
334
335/// Returns the nearest symbol table from a given operation `from`. Returns
336/// nullptr if no valid parent symbol table could be found.
337Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
338 assert(from && "expected valid operation");
339 if (isPotentiallyUnknownSymbolTable(op: from))
340 return nullptr;
341
342 while (!from->hasTrait<OpTrait::SymbolTable>()) {
343 from = from->getParentOp();
344
345 // Check that this is a valid op and isn't an unknown symbol table.
346 if (!from || isPotentiallyUnknownSymbolTable(op: from))
347 return nullptr;
348 }
349 return from;
350}
351
352/// Walks all symbol table operations nested within, and including, `op`. For
353/// each symbol table operation, the provided callback is invoked with the op
354/// and a boolean signifying if the symbols within that symbol table can be
355/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
356/// all of the symbol uses of symbols within `op` are visible.
357void SymbolTable::walkSymbolTables(
358 Operation *op, bool allSymUsesVisible,
359 function_ref<void(Operation *, bool)> callback) {
360 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
361 if (isSymbolTable) {
362 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
363 allSymUsesVisible |= !symbol || symbol.isPrivate();
364 } else {
365 // Otherwise if 'op' is not a symbol table, any nested symbols are
366 // guaranteed to be hidden.
367 allSymUsesVisible = true;
368 }
369
370 for (Region &region : op->getRegions())
371 for (Block &block : region)
372 for (Operation &nestedOp : block)
373 walkSymbolTables(op: &nestedOp, allSymUsesVisible, callback);
374
375 // If 'op' had the symbol table trait, visit it after any nested symbol
376 // tables.
377 if (isSymbolTable)
378 callback(op, allSymUsesVisible);
379}
380
381/// Returns the operation registered with the given symbol name with the
382/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
383/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
384/// was found.
385Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
386 StringAttr symbol) {
387 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
388 Region &region = symbolTableOp->getRegion(index: 0);
389 if (region.empty())
390 return nullptr;
391
392 // Look for a symbol with the given name.
393 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
394 SymbolTable::getSymbolAttrName());
395 for (auto &op : region.front())
396 if (getNameIfSymbol(&op, symbolNameId) == symbol)
397 return &op;
398 return nullptr;
399}
400Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
401 SymbolRefAttr symbol) {
402 SmallVector<Operation *, 4> resolvedSymbols;
403 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
404 return nullptr;
405 return resolvedSymbols.back();
406}
407
408/// Internal implementation of `lookupSymbolIn` that allows for specialized
409/// implementations of the lookup function.
410static LogicalResult lookupSymbolInImpl(
411 Operation *symbolTableOp, SymbolRefAttr symbol,
412 SmallVectorImpl<Operation *> &symbols,
413 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
414 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
415
416 // Lookup the root reference for this symbol.
417 symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
418 if (!symbolTableOp)
419 return failure();
420 symbols.push_back(Elt: symbolTableOp);
421
422 // If there are no nested references, just return the root symbol directly.
423 ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
424 if (nestedRefs.empty())
425 return success();
426
427 // Verify that the root is also a symbol table.
428 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
429 return failure();
430
431 // Otherwise, lookup each of the nested non-leaf references and ensure that
432 // each corresponds to a valid symbol table.
433 for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
434 symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
435 if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
436 return failure();
437 symbols.push_back(symbolTableOp);
438 }
439 symbols.push_back(Elt: lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
440 return success(IsSuccess: symbols.back());
441}
442
443LogicalResult
444SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
445 SmallVectorImpl<Operation *> &symbols) {
446 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
447 return lookupSymbolIn(symbolTableOp, symbol);
448 };
449 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
450}
451
452/// Returns the operation registered with the given symbol name within the
453/// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
454/// nullptr if no valid symbol was found.
455Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
456 StringAttr symbol) {
457 Operation *symbolTableOp = getNearestSymbolTable(from);
458 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
459}
460Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
461 SymbolRefAttr symbol) {
462 Operation *symbolTableOp = getNearestSymbolTable(from);
463 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
464}
465
466raw_ostream &mlir::operator<<(raw_ostream &os,
467 SymbolTable::Visibility visibility) {
468 switch (visibility) {
469 case SymbolTable::Visibility::Public:
470 return os << "public";
471 case SymbolTable::Visibility::Private:
472 return os << "private";
473 case SymbolTable::Visibility::Nested:
474 return os << "nested";
475 }
476 llvm_unreachable("Unexpected visibility");
477}
478
479//===----------------------------------------------------------------------===//
480// SymbolTable Trait Types
481//===----------------------------------------------------------------------===//
482
483LogicalResult detail::verifySymbolTable(Operation *op) {
484 if (op->getNumRegions() != 1)
485 return op->emitOpError()
486 << "Operations with a 'SymbolTable' must have exactly one region";
487 if (!llvm::hasSingleElement(C&: op->getRegion(index: 0)))
488 return op->emitOpError()
489 << "Operations with a 'SymbolTable' must have exactly one block";
490
491 // Check that all symbols are uniquely named within child regions.
492 DenseMap<Attribute, Location> nameToOrigLoc;
493 for (auto &block : op->getRegion(index: 0)) {
494 for (auto &op : block) {
495 // Check for a symbol name attribute.
496 auto nameAttr =
497 op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName());
498 if (!nameAttr)
499 continue;
500
501 // Try to insert this symbol into the table.
502 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
503 if (!it.second)
504 return op.emitError()
505 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
506 .attachNote(it.first->second)
507 .append("see existing symbol definition here");
508 }
509 }
510
511 // Verify any nested symbol user operations.
512 SymbolTableCollection symbolTable;
513 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
514 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
515 return WalkResult(user.verifySymbolUses(symbolTable));
516 return WalkResult::advance();
517 };
518
519 std::optional<WalkResult> result =
520 walkSymbolTable(regions: op->getRegions(), callback: verifySymbolUserFn);
521 return success(IsSuccess: result && !result->wasInterrupted());
522}
523
524LogicalResult detail::verifySymbol(Operation *op) {
525 // Verify the name attribute.
526 if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
527 return op->emitOpError() << "requires string attribute '"
528 << mlir::SymbolTable::getSymbolAttrName() << "'";
529
530 // Verify the visibility attribute.
531 if (Attribute vis = op->getAttr(name: mlir::SymbolTable::getVisibilityAttrName())) {
532 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
533 if (!visStrAttr)
534 return op->emitOpError() << "requires visibility attribute '"
535 << mlir::SymbolTable::getVisibilityAttrName()
536 << "' to be a string attribute, but got " << vis;
537
538 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
539 visStrAttr.getValue()))
540 return op->emitOpError()
541 << "visibility expected to be one of [\"public\", \"private\", "
542 "\"nested\"], but got "
543 << visStrAttr;
544 }
545 return success();
546}
547
548//===----------------------------------------------------------------------===//
549// Symbol Use Lists
550//===----------------------------------------------------------------------===//
551
552/// Walk all of the symbol references within the given operation, invoking the
553/// provided callback for each found use. The callbacks takes the use of the
554/// symbol.
555static WalkResult
556walkSymbolRefs(Operation *op,
557 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
558 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
559 [&](SymbolRefAttr symbolRef) {
560 if (callback({op, symbolRef}).wasInterrupted())
561 return WalkResult::interrupt();
562
563 // Don't walk nested references.
564 return WalkResult::skip();
565 });
566}
567
568/// Walk all of the uses, for any symbol, that are nested within the given
569/// regions, invoking the provided callback for each. This does not traverse
570/// into any nested symbol tables.
571static std::optional<WalkResult>
572walkSymbolUses(MutableArrayRef<Region> regions,
573 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
574 return walkSymbolTable(regions,
575 callback: [&](Operation *op) -> std::optional<WalkResult> {
576 // Check that this isn't a potentially unknown symbol
577 // table.
578 if (isPotentiallyUnknownSymbolTable(op))
579 return std::nullopt;
580
581 return walkSymbolRefs(op, callback);
582 });
583}
584/// Walk all of the uses, for any symbol, that are nested within the given
585/// operation 'from', invoking the provided callback for each. This does not
586/// traverse into any nested symbol tables.
587static std::optional<WalkResult>
588walkSymbolUses(Operation *from,
589 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
590 // If this operation has regions, and it, as well as its dialect, isn't
591 // registered then conservatively fail. The operation may define a
592 // symbol table, so we can't opaquely know if we should traverse to find
593 // nested uses.
594 if (isPotentiallyUnknownSymbolTable(op: from))
595 return std::nullopt;
596
597 // Walk the uses on this operation.
598 if (walkSymbolRefs(op: from, callback).wasInterrupted())
599 return WalkResult::interrupt();
600
601 // Only recurse if this operation is not a symbol table. A symbol table
602 // defines a new scope, so we can't walk the attributes from within the symbol
603 // table op.
604 if (!from->hasTrait<OpTrait::SymbolTable>())
605 return walkSymbolUses(regions: from->getRegions(), callback);
606 return WalkResult::advance();
607}
608
609namespace {
610/// This class represents a single symbol scope. A symbol scope represents the
611/// set of operations nested within a symbol table that may reference symbols
612/// within that table. A symbol scope does not contain the symbol table
613/// operation itself, just its contained operations. A scope ends at leaf
614/// operations or another symbol table operation.
615struct SymbolScope {
616 /// Walk the symbol uses within this scope, invoking the given callback.
617 /// This variant is used when the callback type matches that expected by
618 /// 'walkSymbolUses'.
619 template <typename CallbackT,
620 std::enable_if_t<!std::is_same<
621 typename llvm::function_traits<CallbackT>::result_t,
622 void>::value> * = nullptr>
623 std::optional<WalkResult> walk(CallbackT cback) {
624 if (Region *region = llvm::dyn_cast_if_present<Region *>(Val&: limit))
625 return walkSymbolUses(*region, cback);
626 return walkSymbolUses(cast<Operation *>(Val&: limit), cback);
627 }
628 /// This variant is used when the callback type matches a stripped down type:
629 /// void(SymbolTable::SymbolUse use)
630 template <typename CallbackT,
631 std::enable_if_t<std::is_same<
632 typename llvm::function_traits<CallbackT>::result_t,
633 void>::value> * = nullptr>
634 std::optional<WalkResult> walk(CallbackT cback) {
635 return walk([=](SymbolTable::SymbolUse use) {
636 return cback(use), WalkResult::advance();
637 });
638 }
639
640 /// Walk all of the operations nested under the current scope without
641 /// traversing into any nested symbol tables.
642 template <typename CallbackT>
643 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
644 if (Region *region = llvm::dyn_cast_if_present<Region *>(Val&: limit))
645 return ::walkSymbolTable(*region, cback);
646 return ::walkSymbolTable(cast<Operation *>(Val&: limit), cback);
647 }
648
649 /// The representation of the symbol within this scope.
650 SymbolRefAttr symbol;
651
652 /// The IR unit representing this scope.
653 llvm::PointerUnion<Operation *, Region *> limit;
654};
655} // namespace
656
657/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
658static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
659 Operation *limit) {
660 StringAttr symName = SymbolTable::getSymbolName(symbol);
661 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
662
663 // Compute the ancestors of 'limit'.
664 SetVector<Operation *, SmallVector<Operation *, 4>,
665 SmallPtrSet<Operation *, 4>>
666 limitAncestors;
667 Operation *limitAncestor = limit;
668 do {
669 // Check to see if 'symbol' is an ancestor of 'limit'.
670 if (limitAncestor == symbol) {
671 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
672 // doesn't support parent references.
673 if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
674 symbol->getParentOp())
675 return {{SymbolRefAttr::get(symName), limit}};
676 return {};
677 }
678
679 limitAncestors.insert(X: limitAncestor);
680 } while ((limitAncestor = limitAncestor->getParentOp()));
681
682 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
683 Operation *commonAncestor = symbol->getParentOp();
684 do {
685 if (limitAncestors.count(key: commonAncestor))
686 break;
687 } while ((commonAncestor = commonAncestor->getParentOp()));
688 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
689
690 // Compute the set of valid nested references for 'symbol' as far up to the
691 // common ancestor as possible.
692 SmallVector<SymbolRefAttr, 2> references;
693 bool collectedAllReferences = succeeded(
694 collectValidReferencesFor(symbol, symName, commonAncestor, references));
695
696 // Handle the case where the common ancestor is 'limit'.
697 if (commonAncestor == limit) {
698 SmallVector<SymbolScope, 2> scopes;
699
700 // Walk each of the ancestors of 'symbol', calling the compute function for
701 // each one.
702 Operation *limitIt = symbol->getParentOp();
703 for (size_t i = 0, e = references.size(); i != e;
704 ++i, limitIt = limitIt->getParentOp()) {
705 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
706 scopes.push_back(Elt: {references[i], &limitIt->getRegion(index: 0)});
707 }
708 return scopes;
709 }
710
711 // Otherwise, we just need the symbol reference for 'symbol' that will be
712 // used within 'limit'. This is the last reference in the list we computed
713 // above if we were able to collect all references.
714 if (!collectedAllReferences)
715 return {};
716 return {{references.back(), limit}};
717}
718static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
719 Region *limit) {
720 auto scopes = collectSymbolScopes(symbol, limit: limit->getParentOp());
721
722 // If we collected some scopes to walk, make sure to constrain the one for
723 // limit to the specific region requested.
724 if (!scopes.empty())
725 scopes.back().limit = limit;
726 return scopes;
727}
728static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
729 Region *limit) {
730 return {{SymbolRefAttr::get(symbol), limit}};
731}
732
733static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
734 Operation *limit) {
735 SmallVector<SymbolScope, 1> scopes;
736 auto symbolRef = SymbolRefAttr::get(symbol);
737 for (auto &region : limit->getRegions())
738 scopes.push_back(Elt: {symbolRef, &region});
739 return scopes;
740}
741
742/// Returns true if the given reference 'SubRef' is a sub reference of the
743/// reference 'ref', i.e. 'ref' is a further qualified reference.
744static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
745 if (ref == subRef)
746 return true;
747
748 // If the references are not pointer equal, check to see if `subRef` is a
749 // prefix of `ref`.
750 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
751 ref.getRootReference() != subRef.getRootReference())
752 return false;
753
754 auto refLeafs = ref.getNestedReferences();
755 auto subRefLeafs = subRef.getNestedReferences();
756 return subRefLeafs.size() < refLeafs.size() &&
757 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
758}
759
760//===----------------------------------------------------------------------===//
761// SymbolTable::getSymbolUses
762//===----------------------------------------------------------------------===//
763
764/// The implementation of SymbolTable::getSymbolUses below.
765template <typename FromT>
766static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
767 std::vector<SymbolTable::SymbolUse> uses;
768 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
769 uses.push_back(x: symbolUse);
770 return WalkResult::advance();
771 };
772 auto result = walkSymbolUses(from, walkFn);
773 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
774 : std::nullopt;
775}
776
777/// Get an iterator range for all of the uses, for any symbol, that are nested
778/// within the given operation 'from'. This does not traverse into any nested
779/// symbol tables, and will also only return uses on 'from' if it does not
780/// also define a symbol table. This is because we treat the region as the
781/// boundary of the symbol table, and not the op itself. This function returns
782/// std::nullopt if there are any unknown operations that may potentially be
783/// symbol tables.
784auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
785 return getSymbolUsesImpl(from);
786}
787auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
788 return getSymbolUsesImpl(from: MutableArrayRef<Region>(*from));
789}
790
791//===----------------------------------------------------------------------===//
// SymbolTable::getSymbolUses (uses of a specific symbol)
793//===----------------------------------------------------------------------===//
794
795/// The implementation of SymbolTable::getSymbolUses below.
796template <typename SymbolT, typename IRUnitT>
797static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
798 IRUnitT *limit) {
799 std::vector<SymbolTable::SymbolUse> uses;
800 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
801 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
802 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
803 uses.push_back(x: symbolUse);
804 }))
805 return std::nullopt;
806 }
807 return SymbolTable::UseRange(std::move(uses));
808}
809
810/// Get all of the uses of the given symbol that are nested within the given
811/// operation 'from'. This does not traverse into any nested symbol tables.
812/// This function returns std::nullopt if there are any unknown operations that
813/// may potentially be symbol tables.
814auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
815 -> std::optional<UseRange> {
816 return getSymbolUsesImpl(symbol, from);
817}
818auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
819 -> std::optional<UseRange> {
820 return getSymbolUsesImpl(symbol, limit: from);
821}
822auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
823 -> std::optional<UseRange> {
824 return getSymbolUsesImpl(symbol, from);
825}
826auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
827 -> std::optional<UseRange> {
828 return getSymbolUsesImpl(symbol, limit: from);
829}
830
831//===----------------------------------------------------------------------===//
832// SymbolTable::symbolKnownUseEmpty
833//===----------------------------------------------------------------------===//
834
835/// The implementation of SymbolTable::symbolKnownUseEmpty below.
836template <typename SymbolT, typename IRUnitT>
837static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
838 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
839 // Walk all of the symbol uses looking for a reference to 'symbol'.
840 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
841 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
842 ? WalkResult::interrupt()
843 : WalkResult::advance();
844 }) != WalkResult::advance())
845 return false;
846 }
847 return true;
848}
849
850/// Return if the given symbol is known to have no uses that are nested within
851/// the given operation 'from'. This does not traverse into any nested symbol
852/// tables. This function will also return false if there are any unknown
853/// operations that may potentially be symbol tables.
854bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
855 return symbolKnownUseEmptyImpl(symbol, from);
856}
857bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
858 return symbolKnownUseEmptyImpl(symbol, limit: from);
859}
860bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
861 return symbolKnownUseEmptyImpl(symbol, from);
862}
863bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
864 return symbolKnownUseEmptyImpl(symbol, limit: from);
865}
866
867//===----------------------------------------------------------------------===//
868// SymbolTable::replaceAllSymbolUses
869//===----------------------------------------------------------------------===//
870
871/// Generates a new symbol reference attribute with a new leaf reference.
872static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
873 FlatSymbolRefAttr newLeafAttr) {
874 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
875 return newLeafAttr;
876 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
877 nestedRefs.back() = newLeafAttr;
878 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
879}
880
881/// The implementation of SymbolTable::replaceAllSymbolUses below.
882template <typename SymbolT, typename IRUnitT>
883static LogicalResult
884replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
885 // Generate a new attribute to replace the given attribute.
886 FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
887 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
888 SymbolRefAttr oldAttr = scope.symbol;
889 SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
890 AttrTypeReplacer replacer;
891 replacer.addReplacement(
892 [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
893 // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
894 // want to accidentally replace an inner reference.
895 if (attr == oldAttr)
896 return {newAttr, WalkResult::skip()};
897 // Handle prefix matches.
898 if (isReferencePrefixOf(oldAttr, attr)) {
899 auto oldNestedRefs = oldAttr.getNestedReferences();
900 auto nestedRefs = attr.getNestedReferences();
901 if (oldNestedRefs.empty())
902 return {SymbolRefAttr::get(newSymbol, nestedRefs),
903 WalkResult::skip()};
904
905 auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
906 newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
907 return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
908 WalkResult::skip()};
909 }
910 return {attr, WalkResult::skip()};
911 });
912
913 auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
914 replacer.replaceElementsIn(op);
915 return WalkResult::advance();
916 };
917 if (!scope.walkSymbolTable(walkFn))
918 return failure();
919 }
920 return success();
921}
922
923/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
924/// provided symbol 'newSymbol' that are nested within the given operation
925/// 'from'. This does not traverse into any nested symbol tables. If there are
926/// any unknown operations that may potentially be symbol tables, no uses are
927/// replaced and failure is returned.
928LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
929 StringAttr newSymbol,
930 Operation *from) {
931 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
932}
933LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
934 StringAttr newSymbol,
935 Operation *from) {
936 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
937}
938LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
939 StringAttr newSymbol,
940 Region *from) {
941 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
942}
943LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
944 StringAttr newSymbol,
945 Region *from) {
946 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
947}
948
949//===----------------------------------------------------------------------===//
950// SymbolTableCollection
951//===----------------------------------------------------------------------===//
952
953Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
954 StringAttr symbol) {
955 return getSymbolTable(op: symbolTableOp).lookup(symbol);
956}
957Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
958 SymbolRefAttr name) {
959 SmallVector<Operation *, 4> symbols;
960 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
961 return nullptr;
962 return symbols.back();
963}
964/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
965/// a given SymbolRefAttr. Returns failure if any of the nested references could
966/// not be resolved.
967LogicalResult
968SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
969 SymbolRefAttr name,
970 SmallVectorImpl<Operation *> &symbols) {
971 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
972 return lookupSymbolIn(symbolTableOp, symbol);
973 };
974 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
975}
976
977/// Returns the operation registered with the given symbol name within the
978/// closest parent operation of, or including, 'from' with the
979/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
980/// found.
981Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
982 StringAttr symbol) {
983 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
984 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
985}
986Operation *
987SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
988 SymbolRefAttr symbol) {
989 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
990 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
991}
992
993/// Lookup, or create, a symbol table for an operation.
994SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
995 auto it = symbolTables.try_emplace(Key: op, Args: nullptr);
996 if (it.second)
997 it.first->second = std::make_unique<SymbolTable>(args&: op);
998 return *it.first->second;
999}
1000
1001void SymbolTableCollection::invalidateSymbolTable(Operation *op) {
1002 symbolTables.erase(Val: op);
1003}
1004
1005//===----------------------------------------------------------------------===//
1006// LockedSymbolTableCollection
1007//===----------------------------------------------------------------------===//
1008
1009Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1010 StringAttr symbol) {
1011 return getSymbolTable(symbolTableOp).lookup(symbol);
1012}
1013
1014Operation *
1015LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1016 FlatSymbolRefAttr symbol) {
1017 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1018}
1019
1020Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1021 SymbolRefAttr name) {
1022 SmallVector<Operation *> symbols;
1023 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1024 return nullptr;
1025 return symbols.back();
1026}
1027
1028LogicalResult LockedSymbolTableCollection::lookupSymbolIn(
1029 Operation *symbolTableOp, SymbolRefAttr name,
1030 SmallVectorImpl<Operation *> &symbols) {
1031 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1032 return lookupSymbolIn(symbolTableOp, symbol);
1033 };
1034 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1035}
1036
1037SymbolTable &
1038LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1039 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1040 // Try to find an existing symbol table.
1041 {
1042 llvm::sys::SmartScopedReader<true> lock(mutex);
1043 auto it = collection.symbolTables.find(Val: symbolTableOp);
1044 if (it != collection.symbolTables.end())
1045 return *it->second;
1046 }
1047 // Create a symbol table for the operation. Perform construction outside of
1048 // the critical section.
1049 auto symbolTable = std::make_unique<SymbolTable>(args&: symbolTableOp);
1050 // Insert the constructed symbol table.
1051 llvm::sys::SmartScopedWriter<true> lock(mutex);
1052 return *collection.symbolTables
1053 .insert(KV: {symbolTableOp, std::move(symbolTable)})
1054 .first->second;
1055}
1056
1057//===----------------------------------------------------------------------===//
1058// SymbolUserMap
1059//===----------------------------------------------------------------------===//
1060
1061SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
1062 Operation *symbolTableOp)
1063 : symbolTable(symbolTable) {
1064 // Walk each of the symbol tables looking for discardable callgraph nodes.
1065 SmallVector<Operation *> symbols;
1066 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1067 for (Operation &nestedOp : symbolTableOp->getRegion(index: 0).getOps()) {
1068 auto symbolUses = SymbolTable::getSymbolUses(from: &nestedOp);
1069 assert(symbolUses && "expected uses to be valid");
1070
1071 for (const SymbolTable::SymbolUse &use : *symbolUses) {
1072 symbols.clear();
1073 (void)symbolTable.lookupSymbolIn(symbolTableOp, name: use.getSymbolRef(),
1074 symbols);
1075 for (Operation *symbolOp : symbols)
1076 symbolToUsers[symbolOp].insert(X: use.getUser());
1077 }
1078 }
1079 };
1080 // We just set `allSymUsesVisible` to false here because it isn't necessary
1081 // for building the user map.
1082 SymbolTable::walkSymbolTables(op: symbolTableOp, /*allSymUsesVisible=*/false,
1083 callback: walkFn);
1084}
1085
1086void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
1087 StringAttr newSymbolName) {
1088 auto it = symbolToUsers.find(Val: symbol);
1089 if (it == symbolToUsers.end())
1090 return;
1091
1092 // Replace the uses within the users of `symbol`.
1093 for (Operation *user : it->second)
1094 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1095
1096 // Move the current users of `symbol` to the new symbol if it is in the
1097 // symbol table.
1098 Operation *newSymbol =
1099 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1100 if (newSymbol != symbol) {
1101 // Transfer over the users to the new symbol. The reference to the old one
1102 // is fetched again as the iterator is invalidated during the insertion.
1103 auto newIt = symbolToUsers.try_emplace(Key: newSymbol);
1104 auto oldIt = symbolToUsers.find(Val: symbol);
1105 assert(oldIt != symbolToUsers.end() && "missing old users list");
1106 if (newIt.second)
1107 newIt.first->second = std::move(oldIt->second);
1108 else
1109 newIt.first->second.set_union(oldIt->second);
1110 symbolToUsers.erase(I: oldIt);
1111 }
1112}
1113
1114//===----------------------------------------------------------------------===//
1115// Visibility parsing implementation.
1116//===----------------------------------------------------------------------===//
1117
1118ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
1119 NamedAttrList &attrs) {
1120 StringRef visibility;
1121 if (parser.parseOptionalKeyword(keyword: &visibility, allowedValues: {"public", "private", "nested"}))
1122 return failure();
1123
1124 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1125 attrs.push_back(newAttribute: parser.getBuilder().getNamedAttr(
1126 name: SymbolTable::getVisibilityAttrName(), val: visibilityAttr));
1127 return success();
1128}
1129
1130//===----------------------------------------------------------------------===//
1131// Symbol Interfaces
1132//===----------------------------------------------------------------------===//
1133
1134/// Include the generated symbol interfaces.
1135#include "mlir/IR/SymbolInterfaces.cpp.inc"
1136

source code of mlir/lib/IR/SymbolTable.cpp