1//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/SymbolTable.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/IR/OpImplementation.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallPtrSet.h"
14#include "llvm/ADT/SmallString.h"
15#include "llvm/ADT/StringSwitch.h"
16#include <optional>
17
18using namespace mlir;
19
20/// Return true if the given operation is unknown and may potentially define a
21/// symbol table.
22static bool isPotentiallyUnknownSymbolTable(Operation *op) {
23 return op->getNumRegions() == 1 && !op->getDialect();
24}
25
26/// Returns the string name of the given symbol, or null if this is not a
27/// symbol.
28static StringAttr getNameIfSymbol(Operation *op) {
29 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
30}
31static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
32 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
33}
34
35/// Computes the nested symbol reference attribute for the symbol 'symbolName'
36/// that are usable within the symbol table operations from 'symbol' as far up
37/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
38/// Returns success if all references up to 'within' could be computed.
39static LogicalResult
40collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
41 Operation *within,
42 SmallVectorImpl<SymbolRefAttr> &results) {
43 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
44 MLIRContext *ctx = symbol->getContext();
45
46 auto leafRef = FlatSymbolRefAttr::get(symbolName);
47 results.push_back(leafRef);
48
49 // Early exit for when 'within' is the parent of 'symbol'.
50 Operation *symbolTableOp = symbol->getParentOp();
51 if (within == symbolTableOp)
52 return success();
53
54 // Collect references until 'symbolTableOp' reaches 'within'.
55 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
56 StringAttr symbolNameId =
57 StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
58 do {
59 // Each parent of 'symbol' should define a symbol table.
60 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
61 return failure();
62 // Each parent of 'symbol' should also be a symbol.
63 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
64 if (!symbolTableName)
65 return failure();
66 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
67
68 symbolTableOp = symbolTableOp->getParentOp();
69 if (symbolTableOp == within)
70 break;
71 nestedRefs.insert(nestedRefs.begin(),
72 FlatSymbolRefAttr::get(symbolTableName));
73 } while (true);
74 return success();
75}
76
77/// Walk all of the operations within the given set of regions, without
78/// traversing into any nested symbol tables. Stops walking if the result of the
79/// callback is anything other than `WalkResult::advance`.
80static std::optional<WalkResult>
81walkSymbolTable(MutableArrayRef<Region> regions,
82 function_ref<std::optional<WalkResult>(Operation *)> callback) {
83 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(Range&: regions));
84 while (!worklist.empty()) {
85 for (Operation &op : worklist.pop_back_val()->getOps()) {
86 std::optional<WalkResult> result = callback(&op);
87 if (result != WalkResult::advance())
88 return result;
89
90 // If this op defines a new symbol table scope, we can't traverse. Any
91 // symbol references nested within 'op' are different semantically.
92 if (!op.hasTrait<OpTrait::SymbolTable>()) {
93 for (Region &region : op.getRegions())
94 worklist.push_back(Elt: &region);
95 }
96 }
97 }
98 return WalkResult::advance();
99}
100
101/// Walk all of the operations nested under, and including, the given operation,
102/// without traversing into any nested symbol tables. Stops walking if the
103/// result of the callback is anything other than `WalkResult::advance`.
104static std::optional<WalkResult>
105walkSymbolTable(Operation *op,
106 function_ref<std::optional<WalkResult>(Operation *)> callback) {
107 std::optional<WalkResult> result = callback(op);
108 if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>())
109 return result;
110 return walkSymbolTable(regions: op->getRegions(), callback);
111}
112
113//===----------------------------------------------------------------------===//
114// SymbolTable
115//===----------------------------------------------------------------------===//
116
117/// Build a symbol table with the symbols within the given operation.
118SymbolTable::SymbolTable(Operation *symbolTableOp)
119 : symbolTableOp(symbolTableOp) {
120 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
121 "expected operation to have SymbolTable trait");
122 assert(symbolTableOp->getNumRegions() == 1 &&
123 "expected operation to have a single region");
124 assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
125 "expected operation to have a single block");
126
127 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
128 SymbolTable::getSymbolAttrName());
129 for (auto &op : symbolTableOp->getRegion(index: 0).front()) {
130 StringAttr name = getNameIfSymbol(&op, symbolNameId);
131 if (!name)
132 continue;
133
134 auto inserted = symbolTable.insert({name, &op});
135 (void)inserted;
136 assert(inserted.second &&
137 "expected region to contain uniquely named symbol operations");
138 }
139}
140
141/// Look up a symbol with the specified name, returning null if no such name
142/// exists. Names never include the @ on them.
143Operation *SymbolTable::lookup(StringRef name) const {
144 return lookup(StringAttr::get(symbolTableOp->getContext(), name));
145}
146Operation *SymbolTable::lookup(StringAttr name) const {
147 return symbolTable.lookup(Val: name);
148}
149
150void SymbolTable::remove(Operation *op) {
151 StringAttr name = getNameIfSymbol(op);
152 assert(name && "expected valid 'name' attribute");
153 assert(op->getParentOp() == symbolTableOp &&
154 "expected this operation to be inside of the operation with this "
155 "SymbolTable");
156
157 auto it = symbolTable.find(name);
158 if (it != symbolTable.end() && it->second == op)
159 symbolTable.erase(it);
160}
161
162void SymbolTable::erase(Operation *symbol) {
163 remove(op: symbol);
164 symbol->erase();
165}
166
167// TODO: Consider if this should be renamed to something like insertOrUpdate
168/// Insert a new symbol into the table and associated operation if not already
169/// there and rename it as necessary to avoid collisions. Return the name of
170/// the symbol after insertion as attribute.
171StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
172 // The symbol cannot be the child of another op and must be the child of the
173 // symbolTableOp after this.
174 //
175 // TODO: consider if SymbolTable's constructor should behave the same.
176 if (!symbol->getParentOp()) {
177 auto &body = symbolTableOp->getRegion(index: 0).front();
178 if (insertPt == Block::iterator()) {
179 insertPt = Block::iterator(body.end());
180 } else {
181 assert((insertPt == body.end() ||
182 insertPt->getParentOp() == symbolTableOp) &&
183 "expected insertPt to be in the associated module operation");
184 }
185 // Insert before the terminator, if any.
186 if (insertPt == Block::iterator(body.end()) && !body.empty() &&
187 std::prev(x: body.end())->hasTrait<OpTrait::IsTerminator>())
188 insertPt = std::prev(x: body.end());
189
190 body.getOperations().insert(where: insertPt, New: symbol);
191 }
192 assert(symbol->getParentOp() == symbolTableOp &&
193 "symbol is already inserted in another op");
194
195 // Add this symbol to the symbol table, uniquing the name if a conflict is
196 // detected.
197 StringAttr name = getSymbolName(symbol);
198 if (symbolTable.insert({name, symbol}).second)
199 return name;
200 // If the symbol was already in the table, also return.
201 if (symbolTable.lookup(Val: name) == symbol)
202 return name;
203
204 MLIRContext *context = symbol->getContext();
205 SmallString<128> nameBuffer = generateSymbolName<128>(
206 name.getValue(),
207 [&](StringRef candidate) {
208 return !symbolTable
209 .insert({StringAttr::get(context, candidate), symbol})
210 .second;
211 },
212 uniquingCounter);
213 setSymbolName(symbol, name: nameBuffer);
214 return getSymbolName(symbol);
215}
216
217LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) {
218 Operation *op = lookup(from);
219 return rename(op, to);
220}
221
222LogicalResult SymbolTable::rename(Operation *op, StringAttr to) {
223 StringAttr from = getNameIfSymbol(op);
224 (void)from;
225
226 assert(from && "expected valid 'name' attribute");
227 assert(op->getParentOp() == symbolTableOp &&
228 "expected this operation to be inside of the operation with this "
229 "SymbolTable");
230 assert(lookup(from) == op && "current name does not resolve to op");
231 assert(lookup(to) == nullptr && "new name already exists");
232
233 if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp())))
234 return failure();
235
236 // Remove op with old name, change name, add with new name. The order is
237 // important here due to how `remove` and `insert` rely on the op name.
238 remove(op);
239 setSymbolName(op, to);
240 insert(op);
241
242 assert(lookup(to) == op && "new name does not resolve to renamed op");
243 assert(lookup(from) == nullptr && "old name still exists");
244
245 return success();
246}
247
248LogicalResult SymbolTable::rename(StringAttr from, StringRef to) {
249 auto toAttr = StringAttr::get(getOp()->getContext(), to);
250 return rename(from, toAttr);
251}
252
253LogicalResult SymbolTable::rename(Operation *op, StringRef to) {
254 auto toAttr = StringAttr::get(getOp()->getContext(), to);
255 return rename(op, toAttr);
256}
257
258FailureOr<StringAttr>
259SymbolTable::renameToUnique(StringAttr oldName,
260 ArrayRef<SymbolTable *> others) {
261
262 // Determine new name that is unique in all symbol tables.
263 StringAttr newName;
264 {
265 MLIRContext *context = oldName.getContext();
266 SmallString<64> prefix = oldName.getValue();
267 int uniqueId = 0;
268 prefix.push_back(Elt: '_');
269 while (true) {
270 newName = StringAttr::get(context, prefix + Twine(uniqueId++));
271 auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); };
272 if (!lookupNewName(this) && llvm::none_of(Range&: others, P: lookupNewName)) {
273 break;
274 }
275 }
276 }
277
278 // Apply renaming.
279 if (failed(rename(oldName, newName)))
280 return failure();
281 return newName;
282}
283
284FailureOr<StringAttr>
285SymbolTable::renameToUnique(Operation *op, ArrayRef<SymbolTable *> others) {
286 StringAttr from = getNameIfSymbol(op);
287 assert(from && "expected valid 'name' attribute");
288 return renameToUnique(from, others);
289}
290
291/// Returns the name of the given symbol operation.
292StringAttr SymbolTable::getSymbolName(Operation *symbol) {
293 StringAttr name = getNameIfSymbol(symbol);
294 assert(name && "expected valid symbol name");
295 return name;
296}
297
298/// Sets the name of the given symbol operation.
299void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
300 symbol->setAttr(getSymbolAttrName(), name);
301}
302
303/// Returns the visibility of the given symbol operation.
304SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
305 // If the attribute doesn't exist, assume public.
306 StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
307 if (!vis)
308 return Visibility::Public;
309
310 // Otherwise, switch on the string value.
311 return StringSwitch<Visibility>(vis.getValue())
312 .Case(S: "private", Value: Visibility::Private)
313 .Case(S: "nested", Value: Visibility::Nested)
314 .Case(S: "public", Value: Visibility::Public);
315}
316/// Sets the visibility of the given symbol operation.
317void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
318 MLIRContext *ctx = symbol->getContext();
319
320 // If the visibility is public, just drop the attribute as this is the
321 // default.
322 if (vis == Visibility::Public) {
323 symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
324 return;
325 }
326
327 // Otherwise, update the attribute.
328 assert((vis == Visibility::Private || vis == Visibility::Nested) &&
329 "unknown symbol visibility kind");
330
331 StringRef visName = vis == Visibility::Private ? "private" : "nested";
332 symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
333}
334
335/// Returns the nearest symbol table from a given operation `from`. Returns
336/// nullptr if no valid parent symbol table could be found.
337Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
338 assert(from && "expected valid operation");
339 if (isPotentiallyUnknownSymbolTable(op: from))
340 return nullptr;
341
342 while (!from->hasTrait<OpTrait::SymbolTable>()) {
343 from = from->getParentOp();
344
345 // Check that this is a valid op and isn't an unknown symbol table.
346 if (!from || isPotentiallyUnknownSymbolTable(op: from))
347 return nullptr;
348 }
349 return from;
350}
351
352/// Walks all symbol table operations nested within, and including, `op`. For
353/// each symbol table operation, the provided callback is invoked with the op
354/// and a boolean signifying if the symbols within that symbol table can be
355/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
356/// all of the symbol uses of symbols within `op` are visible.
357void SymbolTable::walkSymbolTables(
358 Operation *op, bool allSymUsesVisible,
359 function_ref<void(Operation *, bool)> callback) {
360 bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
361 if (isSymbolTable) {
362 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
363 allSymUsesVisible |= !symbol || symbol.isPrivate();
364 } else {
365 // Otherwise if 'op' is not a symbol table, any nested symbols are
366 // guaranteed to be hidden.
367 allSymUsesVisible = true;
368 }
369
370 for (Region &region : op->getRegions())
371 for (Block &block : region)
372 for (Operation &nestedOp : block)
373 walkSymbolTables(op: &nestedOp, allSymUsesVisible, callback);
374
375 // If 'op' had the symbol table trait, visit it after any nested symbol
376 // tables.
377 if (isSymbolTable)
378 callback(op, allSymUsesVisible);
379}
380
381/// Returns the operation registered with the given symbol name with the
382/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
383/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
384/// was found.
385Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
386 StringAttr symbol) {
387 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
388 Region &region = symbolTableOp->getRegion(index: 0);
389 if (region.empty())
390 return nullptr;
391
392 // Look for a symbol with the given name.
393 StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
394 SymbolTable::getSymbolAttrName());
395 for (auto &op : region.front())
396 if (getNameIfSymbol(&op, symbolNameId) == symbol)
397 return &op;
398 return nullptr;
399}
400Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
401 SymbolRefAttr symbol) {
402 SmallVector<Operation *, 4> resolvedSymbols;
403 if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
404 return nullptr;
405 return resolvedSymbols.back();
406}
407
408/// Internal implementation of `lookupSymbolIn` that allows for specialized
409/// implementations of the lookup function.
410static LogicalResult lookupSymbolInImpl(
411 Operation *symbolTableOp, SymbolRefAttr symbol,
412 SmallVectorImpl<Operation *> &symbols,
413 function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
414 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
415
416 // Lookup the root reference for this symbol.
417 symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
418 if (!symbolTableOp)
419 return failure();
420 symbols.push_back(Elt: symbolTableOp);
421
422 // If there are no nested references, just return the root symbol directly.
423 ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
424 if (nestedRefs.empty())
425 return success();
426
427 // Verify that the root is also a symbol table.
428 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
429 return failure();
430
431 // Otherwise, lookup each of the nested non-leaf references and ensure that
432 // each corresponds to a valid symbol table.
433 for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
434 symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
435 if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
436 return failure();
437 symbols.push_back(symbolTableOp);
438 }
439 symbols.push_back(Elt: lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
440 return success(isSuccess: symbols.back());
441}
442
443LogicalResult
444SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
445 SmallVectorImpl<Operation *> &symbols) {
446 auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
447 return lookupSymbolIn(symbolTableOp, symbol);
448 };
449 return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
450}
451
452/// Returns the operation registered with the given symbol name within the
453/// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
454/// nullptr if no valid symbol was found.
455Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
456 StringAttr symbol) {
457 Operation *symbolTableOp = getNearestSymbolTable(from);
458 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
459}
460Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
461 SymbolRefAttr symbol) {
462 Operation *symbolTableOp = getNearestSymbolTable(from);
463 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
464}
465
466raw_ostream &mlir::operator<<(raw_ostream &os,
467 SymbolTable::Visibility visibility) {
468 switch (visibility) {
469 case SymbolTable::Visibility::Public:
470 return os << "public";
471 case SymbolTable::Visibility::Private:
472 return os << "private";
473 case SymbolTable::Visibility::Nested:
474 return os << "nested";
475 }
476 llvm_unreachable("Unexpected visibility");
477}
478
479//===----------------------------------------------------------------------===//
480// SymbolTable Trait Types
481//===----------------------------------------------------------------------===//
482
483LogicalResult detail::verifySymbolTable(Operation *op) {
484 if (op->getNumRegions() != 1)
485 return op->emitOpError()
486 << "Operations with a 'SymbolTable' must have exactly one region";
487 if (!llvm::hasSingleElement(C&: op->getRegion(index: 0)))
488 return op->emitOpError()
489 << "Operations with a 'SymbolTable' must have exactly one block";
490
491 // Check that all symbols are uniquely named within child regions.
492 DenseMap<Attribute, Location> nameToOrigLoc;
493 for (auto &block : op->getRegion(index: 0)) {
494 for (auto &op : block) {
495 // Check for a symbol name attribute.
496 auto nameAttr =
497 op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName());
498 if (!nameAttr)
499 continue;
500
501 // Try to insert this symbol into the table.
502 auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
503 if (!it.second)
504 return op.emitError()
505 .append("redefinition of symbol named '", nameAttr.getValue(), "'")
506 .attachNote(it.first->second)
507 .append("see existing symbol definition here");
508 }
509 }
510
511 // Verify any nested symbol user operations.
512 SymbolTableCollection symbolTable;
513 auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
514 if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
515 return WalkResult(user.verifySymbolUses(symbolTable));
516 return WalkResult::advance();
517 };
518
519 std::optional<WalkResult> result =
520 walkSymbolTable(regions: op->getRegions(), callback: verifySymbolUserFn);
521 return success(isSuccess: result && !result->wasInterrupted());
522}
523
524LogicalResult detail::verifySymbol(Operation *op) {
525 // Verify the name attribute.
526 if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
527 return op->emitOpError() << "requires string attribute '"
528 << mlir::SymbolTable::getSymbolAttrName() << "'";
529
530 // Verify the visibility attribute.
531 if (Attribute vis = op->getAttr(name: mlir::SymbolTable::getVisibilityAttrName())) {
532 StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
533 if (!visStrAttr)
534 return op->emitOpError() << "requires visibility attribute '"
535 << mlir::SymbolTable::getVisibilityAttrName()
536 << "' to be a string attribute, but got " << vis;
537
538 if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
539 visStrAttr.getValue()))
540 return op->emitOpError()
541 << "visibility expected to be one of [\"public\", \"private\", "
542 "\"nested\"], but got "
543 << visStrAttr;
544 }
545 return success();
546}
547
548//===----------------------------------------------------------------------===//
549// Symbol Use Lists
550//===----------------------------------------------------------------------===//
551
552/// Walk all of the symbol references within the given operation, invoking the
553/// provided callback for each found use. The callbacks takes the use of the
554/// symbol.
555static WalkResult
556walkSymbolRefs(Operation *op,
557 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
558 return op->getAttrDictionary().walk<WalkOrder::PreOrder>(
559 [&](SymbolRefAttr symbolRef) {
560 if (callback({op, symbolRef}).wasInterrupted())
561 return WalkResult::interrupt();
562
563 // Don't walk nested references.
564 return WalkResult::skip();
565 });
566}
567
568/// Walk all of the uses, for any symbol, that are nested within the given
569/// regions, invoking the provided callback for each. This does not traverse
570/// into any nested symbol tables.
571static std::optional<WalkResult>
572walkSymbolUses(MutableArrayRef<Region> regions,
573 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
574 return walkSymbolTable(regions,
575 callback: [&](Operation *op) -> std::optional<WalkResult> {
576 // Check that this isn't a potentially unknown symbol
577 // table.
578 if (isPotentiallyUnknownSymbolTable(op))
579 return std::nullopt;
580
581 return walkSymbolRefs(op, callback);
582 });
583}
584/// Walk all of the uses, for any symbol, that are nested within the given
585/// operation 'from', invoking the provided callback for each. This does not
586/// traverse into any nested symbol tables.
587static std::optional<WalkResult>
588walkSymbolUses(Operation *from,
589 function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
590 // If this operation has regions, and it, as well as its dialect, isn't
591 // registered then conservatively fail. The operation may define a
592 // symbol table, so we can't opaquely know if we should traverse to find
593 // nested uses.
594 if (isPotentiallyUnknownSymbolTable(op: from))
595 return std::nullopt;
596
597 // Walk the uses on this operation.
598 if (walkSymbolRefs(op: from, callback).wasInterrupted())
599 return WalkResult::interrupt();
600
601 // Only recurse if this operation is not a symbol table. A symbol table
602 // defines a new scope, so we can't walk the attributes from within the symbol
603 // table op.
604 if (!from->hasTrait<OpTrait::SymbolTable>())
605 return walkSymbolUses(regions: from->getRegions(), callback);
606 return WalkResult::advance();
607}
608
609namespace {
610/// This class represents a single symbol scope. A symbol scope represents the
611/// set of operations nested within a symbol table that may reference symbols
612/// within that table. A symbol scope does not contain the symbol table
613/// operation itself, just its contained operations. A scope ends at leaf
614/// operations or another symbol table operation.
615struct SymbolScope {
616 /// Walk the symbol uses within this scope, invoking the given callback.
617 /// This variant is used when the callback type matches that expected by
618 /// 'walkSymbolUses'.
619 template <typename CallbackT,
620 std::enable_if_t<!std::is_same<
621 typename llvm::function_traits<CallbackT>::result_t,
622 void>::value> * = nullptr>
623 std::optional<WalkResult> walk(CallbackT cback) {
624 if (Region *region = llvm::dyn_cast_if_present<Region *>(Val&: limit))
625 return walkSymbolUses(*region, cback);
626 return walkSymbolUses(limit.get<Operation *>(), cback);
627 }
628 /// This variant is used when the callback type matches a stripped down type:
629 /// void(SymbolTable::SymbolUse use)
630 template <typename CallbackT,
631 std::enable_if_t<std::is_same<
632 typename llvm::function_traits<CallbackT>::result_t,
633 void>::value> * = nullptr>
634 std::optional<WalkResult> walk(CallbackT cback) {
635 return walk([=](SymbolTable::SymbolUse use) {
636 return cback(use), WalkResult::advance();
637 });
638 }
639
640 /// Walk all of the operations nested under the current scope without
641 /// traversing into any nested symbol tables.
642 template <typename CallbackT>
643 std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
644 if (Region *region = llvm::dyn_cast_if_present<Region *>(Val&: limit))
645 return ::walkSymbolTable(*region, cback);
646 return ::walkSymbolTable(limit.get<Operation *>(), cback);
647 }
648
649 /// The representation of the symbol within this scope.
650 SymbolRefAttr symbol;
651
652 /// The IR unit representing this scope.
653 llvm::PointerUnion<Operation *, Region *> limit;
654};
655} // namespace
656
657/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
658static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
659 Operation *limit) {
660 StringAttr symName = SymbolTable::getSymbolName(symbol);
661 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
662
663 // Compute the ancestors of 'limit'.
664 SetVector<Operation *, SmallVector<Operation *, 4>,
665 SmallPtrSet<Operation *, 4>>
666 limitAncestors;
667 Operation *limitAncestor = limit;
668 do {
669 // Check to see if 'symbol' is an ancestor of 'limit'.
670 if (limitAncestor == symbol) {
671 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
672 // doesn't support parent references.
673 if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
674 symbol->getParentOp())
675 return {{SymbolRefAttr::get(symName), limit}};
676 return {};
677 }
678
679 limitAncestors.insert(X: limitAncestor);
680 } while ((limitAncestor = limitAncestor->getParentOp()));
681
682 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
683 Operation *commonAncestor = symbol->getParentOp();
684 do {
685 if (limitAncestors.count(key: commonAncestor))
686 break;
687 } while ((commonAncestor = commonAncestor->getParentOp()));
688 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
689
690 // Compute the set of valid nested references for 'symbol' as far up to the
691 // common ancestor as possible.
692 SmallVector<SymbolRefAttr, 2> references;
693 bool collectedAllReferences = succeeded(
694 collectValidReferencesFor(symbol, symName, commonAncestor, references));
695
696 // Handle the case where the common ancestor is 'limit'.
697 if (commonAncestor == limit) {
698 SmallVector<SymbolScope, 2> scopes;
699
700 // Walk each of the ancestors of 'symbol', calling the compute function for
701 // each one.
702 Operation *limitIt = symbol->getParentOp();
703 for (size_t i = 0, e = references.size(); i != e;
704 ++i, limitIt = limitIt->getParentOp()) {
705 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
706 scopes.push_back(Elt: {references[i], &limitIt->getRegion(index: 0)});
707 }
708 return scopes;
709 }
710
711 // Otherwise, we just need the symbol reference for 'symbol' that will be
712 // used within 'limit'. This is the last reference in the list we computed
713 // above if we were able to collect all references.
714 if (!collectedAllReferences)
715 return {};
716 return {{references.back(), limit}};
717}
718static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
719 Region *limit) {
720 auto scopes = collectSymbolScopes(symbol, limit: limit->getParentOp());
721
722 // If we collected some scopes to walk, make sure to constrain the one for
723 // limit to the specific region requested.
724 if (!scopes.empty())
725 scopes.back().limit = limit;
726 return scopes;
727}
728static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
729 Region *limit) {
730 return {{SymbolRefAttr::get(symbol), limit}};
731}
732
733static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
734 Operation *limit) {
735 SmallVector<SymbolScope, 1> scopes;
736 auto symbolRef = SymbolRefAttr::get(symbol);
737 for (auto &region : limit->getRegions())
738 scopes.push_back(Elt: {symbolRef, &region});
739 return scopes;
740}
741
742/// Returns true if the given reference 'SubRef' is a sub reference of the
743/// reference 'ref', i.e. 'ref' is a further qualified reference.
744static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
745 if (ref == subRef)
746 return true;
747
748 // If the references are not pointer equal, check to see if `subRef` is a
749 // prefix of `ref`.
750 if (llvm::isa<FlatSymbolRefAttr>(ref) ||
751 ref.getRootReference() != subRef.getRootReference())
752 return false;
753
754 auto refLeafs = ref.getNestedReferences();
755 auto subRefLeafs = subRef.getNestedReferences();
756 return subRefLeafs.size() < refLeafs.size() &&
757 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
758}
759
760//===----------------------------------------------------------------------===//
761// SymbolTable::getSymbolUses
762
763/// The implementation of SymbolTable::getSymbolUses below.
764template <typename FromT>
765static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
766 std::vector<SymbolTable::SymbolUse> uses;
767 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
768 uses.push_back(x: symbolUse);
769 return WalkResult::advance();
770 };
771 auto result = walkSymbolUses(from, walkFn);
772 return result ? std::optional<SymbolTable::UseRange>(std::move(uses))
773 : std::nullopt;
774}
775
776/// Get an iterator range for all of the uses, for any symbol, that are nested
777/// within the given operation 'from'. This does not traverse into any nested
778/// symbol tables, and will also only return uses on 'from' if it does not
779/// also define a symbol table. This is because we treat the region as the
780/// boundary of the symbol table, and not the op itself. This function returns
781/// std::nullopt if there are any unknown operations that may potentially be
782/// symbol tables.
783auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> {
784 return getSymbolUsesImpl(from);
785}
786auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> {
787 return getSymbolUsesImpl(from: MutableArrayRef<Region>(*from));
788}
789
790//===----------------------------------------------------------------------===//
791// SymbolTable::getSymbolUses
792
793/// The implementation of SymbolTable::getSymbolUses below.
794template <typename SymbolT, typename IRUnitT>
795static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
796 IRUnitT *limit) {
797 std::vector<SymbolTable::SymbolUse> uses;
798 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
799 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
800 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
801 uses.push_back(x: symbolUse);
802 }))
803 return std::nullopt;
804 }
805 return SymbolTable::UseRange(std::move(uses));
806}
807
808/// Get all of the uses of the given symbol that are nested within the given
809/// operation 'from', invoking the provided callback for each. This does not
810/// traverse into any nested symbol tables. This function returns std::nullopt
811/// if there are any unknown operations that may potentially be symbol tables.
812auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
813 -> std::optional<UseRange> {
814 return getSymbolUsesImpl(symbol, from);
815}
816auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
817 -> std::optional<UseRange> {
818 return getSymbolUsesImpl(symbol, limit: from);
819}
820auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
821 -> std::optional<UseRange> {
822 return getSymbolUsesImpl(symbol, from);
823}
824auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
825 -> std::optional<UseRange> {
826 return getSymbolUsesImpl(symbol, limit: from);
827}
828
829//===----------------------------------------------------------------------===//
830// SymbolTable::symbolKnownUseEmpty
831
832/// The implementation of SymbolTable::symbolKnownUseEmpty below.
833template <typename SymbolT, typename IRUnitT>
834static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
835 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
836 // Walk all of the symbol uses looking for a reference to 'symbol'.
837 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
838 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())
839 ? WalkResult::interrupt()
840 : WalkResult::advance();
841 }) != WalkResult::advance())
842 return false;
843 }
844 return true;
845}
846
847/// Return if the given symbol is known to have no uses that are nested within
848/// the given operation 'from'. This does not traverse into any nested symbol
849/// tables. This function will also return false if there are any unknown
850/// operations that may potentially be symbol tables.
851bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
852 return symbolKnownUseEmptyImpl(symbol, from);
853}
854bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
855 return symbolKnownUseEmptyImpl(symbol, limit: from);
856}
857bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
858 return symbolKnownUseEmptyImpl(symbol, from);
859}
860bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
861 return symbolKnownUseEmptyImpl(symbol, limit: from);
862}
863
864//===----------------------------------------------------------------------===//
865// SymbolTable::replaceAllSymbolUses
866
867/// Generates a new symbol reference attribute with a new leaf reference.
868static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
869 FlatSymbolRefAttr newLeafAttr) {
870 if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
871 return newLeafAttr;
872 auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
873 nestedRefs.back() = newLeafAttr;
874 return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
875}
876
/// The implementation of SymbolTable::replaceAllSymbolUses below.
///
/// For each scope produced by collectSymbolScopes(symbol, limit), rewrites the
/// SymbolRefAttrs found on every operation visited by that scope's
/// walkSymbolTable:
///   * a reference equal to the scope's reference to `symbol` becomes the same
///     reference with its leaf renamed to `newSymbol`;
///   * a reference for which the scope's reference is a prefix keeps its
///     trailing nested references, with the component at the old leaf's
///     position renamed to `newSymbol`;
///   * any other reference is left unchanged.
/// Replacement happens in place while walking. Returns failure if any scope's
/// walkSymbolTable produces no result, success otherwise.
template <typename SymbolT, typename IRUnitT>
static LogicalResult
replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
  // Generate a new attribute to replace the given attribute.
  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
    // `oldAttr` is how `symbol` is referenced from within this scope;
    // `newAttr` is that same reference with its leaf renamed.
    SymbolRefAttr oldAttr = scope.symbol;
    SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
    AttrTypeReplacer replacer;
    replacer.addReplacement(
        [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
          // Regardless of the match, don't walk nested SymbolRefAttrs, we don't
          // want to accidentally replace an inner reference.
          if (attr == oldAttr)
            return {newAttr, WalkResult::skip()};
          // Handle prefix matches.
          if (isReferencePrefixOf(oldAttr, attr)) {
            auto oldNestedRefs = oldAttr.getNestedReferences();
            auto nestedRefs = attr.getNestedReferences();
            // `oldAttr` is flat, so it names the root of `attr`: rename the
            // root and keep every nested reference.
            if (oldNestedRefs.empty())
              return {SymbolRefAttr::get(newSymbol, nestedRefs),
                      WalkResult::skip()};

            // Otherwise the old leaf sits at index size()-1 of the nested
            // references. Because `oldAttr` is a prefix of `attr`, `attr`
            // has at least that many nested references, so the index is in
            // bounds.
            auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
            newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
            return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs),
                    WalkResult::skip()};
          }
          return {attr, WalkResult::skip()};
        });

    // Rewrite the attributes of each visited operation in place.
    // NOTE(review): this callback always returns WalkResult::advance() and
    // performs no isPotentiallyUnknownSymbolTable check, so the failure below
    // can only arise if scope.walkSymbolTable itself yields std::nullopt.
    // The "no uses are replaced and failure is returned" contract documented
    // on SymbolTable::replaceAllSymbolUses may therefore not be enforced here;
    // also, ops visited before any failure have already been rewritten.
    // Confirm against walkSymbolTable.
    auto walkFn = [&](Operation *op) -> std::optional<WalkResult> {
      replacer.replaceElementsIn(op);
      return WalkResult::advance();
    };
    if (!scope.walkSymbolTable(walkFn))
      return failure();
  }
  return success();
}
918
919/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
920/// provided symbol 'newSymbol' that are nested within the given operation
921/// 'from'. This does not traverse into any nested symbol tables. If there are
922/// any unknown operations that may potentially be symbol tables, no uses are
923/// replaced and failure is returned.
924LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
925 StringAttr newSymbol,
926 Operation *from) {
927 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
928}
929LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
930 StringAttr newSymbol,
931 Operation *from) {
932 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
933}
934LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
935 StringAttr newSymbol,
936 Region *from) {
937 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
938}
939LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
940 StringAttr newSymbol,
941 Region *from) {
942 return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
943}
944
945//===----------------------------------------------------------------------===//
946// SymbolTableCollection
947//===----------------------------------------------------------------------===//
948
949Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
950 StringAttr symbol) {
951 return getSymbolTable(op: symbolTableOp).lookup(symbol);
952}
953Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
954 SymbolRefAttr name) {
955 SmallVector<Operation *, 4> symbols;
956 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
957 return nullptr;
958 return symbols.back();
959}
960/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
961/// a given SymbolRefAttr. Returns failure if any of the nested references could
962/// not be resolved.
963LogicalResult
964SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
965 SymbolRefAttr name,
966 SmallVectorImpl<Operation *> &symbols) {
967 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
968 return lookupSymbolIn(symbolTableOp, symbol);
969 };
970 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
971}
972
973/// Returns the operation registered with the given symbol name within the
974/// closest parent operation of, or including, 'from' with the
975/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
976/// found.
977Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
978 StringAttr symbol) {
979 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
980 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
981}
982Operation *
983SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
984 SymbolRefAttr symbol) {
985 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
986 return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
987}
988
989/// Lookup, or create, a symbol table for an operation.
990SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
991 auto it = symbolTables.try_emplace(Key: op, Args: nullptr);
992 if (it.second)
993 it.first->second = std::make_unique<SymbolTable>(args&: op);
994 return *it.first->second;
995}
996
997//===----------------------------------------------------------------------===//
998// LockedSymbolTableCollection
999//===----------------------------------------------------------------------===//
1000
1001Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1002 StringAttr symbol) {
1003 return getSymbolTable(symbolTableOp).lookup(symbol);
1004}
1005
1006Operation *
1007LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1008 FlatSymbolRefAttr symbol) {
1009 return lookupSymbolIn(symbolTableOp, symbol.getAttr());
1010}
1011
1012Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
1013 SymbolRefAttr name) {
1014 SmallVector<Operation *> symbols;
1015 if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
1016 return nullptr;
1017 return symbols.back();
1018}
1019
1020LogicalResult LockedSymbolTableCollection::lookupSymbolIn(
1021 Operation *symbolTableOp, SymbolRefAttr name,
1022 SmallVectorImpl<Operation *> &symbols) {
1023 auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
1024 return lookupSymbolIn(symbolTableOp, symbol);
1025 };
1026 return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
1027}
1028
1029SymbolTable &
1030LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
1031 assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
1032 // Try to find an existing symbol table.
1033 {
1034 llvm::sys::SmartScopedReader<true> lock(mutex);
1035 auto it = collection.symbolTables.find(Val: symbolTableOp);
1036 if (it != collection.symbolTables.end())
1037 return *it->second;
1038 }
1039 // Create a symbol table for the operation. Perform construction outside of
1040 // the critical section.
1041 auto symbolTable = std::make_unique<SymbolTable>(args&: symbolTableOp);
1042 // Insert the constructed symbol table.
1043 llvm::sys::SmartScopedWriter<true> lock(mutex);
1044 return *collection.symbolTables
1045 .insert(KV: {symbolTableOp, std::move(symbolTable)})
1046 .first->second;
1047}
1048
1049//===----------------------------------------------------------------------===//
1050// SymbolUserMap
1051//===----------------------------------------------------------------------===//
1052
1053SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
1054 Operation *symbolTableOp)
1055 : symbolTable(symbolTable) {
1056 // Walk each of the symbol tables looking for discardable callgraph nodes.
1057 SmallVector<Operation *> symbols;
1058 auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
1059 for (Operation &nestedOp : symbolTableOp->getRegion(index: 0).getOps()) {
1060 auto symbolUses = SymbolTable::getSymbolUses(from: &nestedOp);
1061 assert(symbolUses && "expected uses to be valid");
1062
1063 for (const SymbolTable::SymbolUse &use : *symbolUses) {
1064 symbols.clear();
1065 (void)symbolTable.lookupSymbolIn(symbolTableOp, name: use.getSymbolRef(),
1066 symbols);
1067 for (Operation *symbolOp : symbols)
1068 symbolToUsers[symbolOp].insert(X: use.getUser());
1069 }
1070 }
1071 };
1072 // We just set `allSymUsesVisible` to false here because it isn't necessary
1073 // for building the user map.
1074 SymbolTable::walkSymbolTables(op: symbolTableOp, /*allSymUsesVisible=*/false,
1075 callback: walkFn);
1076}
1077
1078void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
1079 StringAttr newSymbolName) {
1080 auto it = symbolToUsers.find(Val: symbol);
1081 if (it == symbolToUsers.end())
1082 return;
1083
1084 // Replace the uses within the users of `symbol`.
1085 for (Operation *user : it->second)
1086 (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
1087
1088 // Move the current users of `symbol` to the new symbol if it is in the
1089 // symbol table.
1090 Operation *newSymbol =
1091 symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
1092 if (newSymbol != symbol) {
1093 // Transfer over the users to the new symbol. The reference to the old one
1094 // is fetched again as the iterator is invalidated during the insertion.
1095 auto newIt = symbolToUsers.try_emplace(Key: newSymbol, Args: SetVector<Operation *>{});
1096 auto oldIt = symbolToUsers.find(Val: symbol);
1097 assert(oldIt != symbolToUsers.end() && "missing old users list");
1098 if (newIt.second)
1099 newIt.first->second = std::move(oldIt->second);
1100 else
1101 newIt.first->second.set_union(oldIt->second);
1102 symbolToUsers.erase(I: oldIt);
1103 }
1104}
1105
1106//===----------------------------------------------------------------------===//
1107// Visibility parsing implementation.
1108//===----------------------------------------------------------------------===//
1109
1110ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
1111 NamedAttrList &attrs) {
1112 StringRef visibility;
1113 if (parser.parseOptionalKeyword(keyword: &visibility, allowedValues: {"public", "private", "nested"}))
1114 return failure();
1115
1116 StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility);
1117 attrs.push_back(newAttribute: parser.getBuilder().getNamedAttr(
1118 name: SymbolTable::getVisibilityAttrName(), val: visibilityAttr));
1119 return success();
1120}
1121
1122//===----------------------------------------------------------------------===//
1123// Symbol Interfaces
1124//===----------------------------------------------------------------------===//
1125
1126/// Include the generated symbol interfaces.
1127#include "mlir/IR/SymbolInterfaces.cpp.inc"
1128

source code of mlir/lib/IR/SymbolTable.cpp