1 | //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/SymbolTable.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/IR/OpImplementation.h" |
12 | #include "llvm/ADT/SetVector.h" |
13 | #include "llvm/ADT/SmallPtrSet.h" |
14 | #include "llvm/ADT/SmallString.h" |
15 | #include "llvm/ADT/StringSwitch.h" |
16 | #include <optional> |
17 | |
18 | using namespace mlir; |
19 | |
20 | /// Return true if the given operation is unknown and may potentially define a |
21 | /// symbol table. |
22 | static bool isPotentiallyUnknownSymbolTable(Operation *op) { |
23 | return op->getNumRegions() == 1 && !op->getDialect(); |
24 | } |
25 | |
26 | /// Returns the string name of the given symbol, or null if this is not a |
27 | /// symbol. |
28 | static StringAttr getNameIfSymbol(Operation *op) { |
29 | return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); |
30 | } |
31 | static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) { |
32 | return op->getAttrOfType<StringAttr>(symbolAttrNameId); |
33 | } |
34 | |
35 | /// Computes the nested symbol reference attribute for the symbol 'symbolName' |
36 | /// that are usable within the symbol table operations from 'symbol' as far up |
37 | /// to the given operation 'within', where 'within' is an ancestor of 'symbol'. |
38 | /// Returns success if all references up to 'within' could be computed. |
39 | static LogicalResult |
40 | collectValidReferencesFor(Operation *symbol, StringAttr symbolName, |
41 | Operation *within, |
42 | SmallVectorImpl<SymbolRefAttr> &results) { |
43 | assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor" ); |
44 | MLIRContext *ctx = symbol->getContext(); |
45 | |
46 | auto leafRef = FlatSymbolRefAttr::get(symbolName); |
47 | results.push_back(leafRef); |
48 | |
49 | // Early exit for when 'within' is the parent of 'symbol'. |
50 | Operation *symbolTableOp = symbol->getParentOp(); |
51 | if (within == symbolTableOp) |
52 | return success(); |
53 | |
54 | // Collect references until 'symbolTableOp' reaches 'within'. |
55 | SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef); |
56 | StringAttr symbolNameId = |
57 | StringAttr::get(ctx, SymbolTable::getSymbolAttrName()); |
58 | do { |
59 | // Each parent of 'symbol' should define a symbol table. |
60 | if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) |
61 | return failure(); |
62 | // Each parent of 'symbol' should also be a symbol. |
63 | StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId); |
64 | if (!symbolTableName) |
65 | return failure(); |
66 | results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs)); |
67 | |
68 | symbolTableOp = symbolTableOp->getParentOp(); |
69 | if (symbolTableOp == within) |
70 | break; |
71 | nestedRefs.insert(nestedRefs.begin(), |
72 | FlatSymbolRefAttr::get(symbolTableName)); |
73 | } while (true); |
74 | return success(); |
75 | } |
76 | |
77 | /// Walk all of the operations within the given set of regions, without |
78 | /// traversing into any nested symbol tables. Stops walking if the result of the |
79 | /// callback is anything other than `WalkResult::advance`. |
80 | static std::optional<WalkResult> |
81 | walkSymbolTable(MutableArrayRef<Region> regions, |
82 | function_ref<std::optional<WalkResult>(Operation *)> callback) { |
83 | SmallVector<Region *, 1> worklist(llvm::make_pointer_range(Range&: regions)); |
84 | while (!worklist.empty()) { |
85 | for (Operation &op : worklist.pop_back_val()->getOps()) { |
86 | std::optional<WalkResult> result = callback(&op); |
87 | if (result != WalkResult::advance()) |
88 | return result; |
89 | |
90 | // If this op defines a new symbol table scope, we can't traverse. Any |
91 | // symbol references nested within 'op' are different semantically. |
92 | if (!op.hasTrait<OpTrait::SymbolTable>()) { |
93 | for (Region ®ion : op.getRegions()) |
94 | worklist.push_back(Elt: ®ion); |
95 | } |
96 | } |
97 | } |
98 | return WalkResult::advance(); |
99 | } |
100 | |
101 | /// Walk all of the operations nested under, and including, the given operation, |
102 | /// without traversing into any nested symbol tables. Stops walking if the |
103 | /// result of the callback is anything other than `WalkResult::advance`. |
104 | static std::optional<WalkResult> |
105 | walkSymbolTable(Operation *op, |
106 | function_ref<std::optional<WalkResult>(Operation *)> callback) { |
107 | std::optional<WalkResult> result = callback(op); |
108 | if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>()) |
109 | return result; |
110 | return walkSymbolTable(regions: op->getRegions(), callback); |
111 | } |
112 | |
113 | //===----------------------------------------------------------------------===// |
114 | // SymbolTable |
115 | //===----------------------------------------------------------------------===// |
116 | |
117 | /// Build a symbol table with the symbols within the given operation. |
118 | SymbolTable::SymbolTable(Operation *symbolTableOp) |
119 | : symbolTableOp(symbolTableOp) { |
120 | assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() && |
121 | "expected operation to have SymbolTable trait" ); |
122 | assert(symbolTableOp->getNumRegions() == 1 && |
123 | "expected operation to have a single region" ); |
124 | assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) && |
125 | "expected operation to have a single block" ); |
126 | |
127 | StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(), |
128 | SymbolTable::getSymbolAttrName()); |
129 | for (auto &op : symbolTableOp->getRegion(index: 0).front()) { |
130 | StringAttr name = getNameIfSymbol(&op, symbolNameId); |
131 | if (!name) |
132 | continue; |
133 | |
134 | auto inserted = symbolTable.insert({name, &op}); |
135 | (void)inserted; |
136 | assert(inserted.second && |
137 | "expected region to contain uniquely named symbol operations" ); |
138 | } |
139 | } |
140 | |
141 | /// Look up a symbol with the specified name, returning null if no such name |
142 | /// exists. Names never include the @ on them. |
143 | Operation *SymbolTable::lookup(StringRef name) const { |
144 | return lookup(StringAttr::get(symbolTableOp->getContext(), name)); |
145 | } |
146 | Operation *SymbolTable::lookup(StringAttr name) const { |
147 | return symbolTable.lookup(Val: name); |
148 | } |
149 | |
150 | void SymbolTable::remove(Operation *op) { |
151 | StringAttr name = getNameIfSymbol(op); |
152 | assert(name && "expected valid 'name' attribute" ); |
153 | assert(op->getParentOp() == symbolTableOp && |
154 | "expected this operation to be inside of the operation with this " |
155 | "SymbolTable" ); |
156 | |
157 | auto it = symbolTable.find(name); |
158 | if (it != symbolTable.end() && it->second == op) |
159 | symbolTable.erase(it); |
160 | } |
161 | |
162 | void SymbolTable::erase(Operation *symbol) { |
163 | remove(op: symbol); |
164 | symbol->erase(); |
165 | } |
166 | |
167 | // TODO: Consider if this should be renamed to something like insertOrUpdate |
168 | /// Insert a new symbol into the table and associated operation if not already |
169 | /// there and rename it as necessary to avoid collisions. Return the name of |
170 | /// the symbol after insertion as attribute. |
171 | StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { |
172 | // The symbol cannot be the child of another op and must be the child of the |
173 | // symbolTableOp after this. |
174 | // |
175 | // TODO: consider if SymbolTable's constructor should behave the same. |
176 | if (!symbol->getParentOp()) { |
177 | auto &body = symbolTableOp->getRegion(index: 0).front(); |
178 | if (insertPt == Block::iterator()) { |
179 | insertPt = Block::iterator(body.end()); |
180 | } else { |
181 | assert((insertPt == body.end() || |
182 | insertPt->getParentOp() == symbolTableOp) && |
183 | "expected insertPt to be in the associated module operation" ); |
184 | } |
185 | // Insert before the terminator, if any. |
186 | if (insertPt == Block::iterator(body.end()) && !body.empty() && |
187 | std::prev(x: body.end())->hasTrait<OpTrait::IsTerminator>()) |
188 | insertPt = std::prev(x: body.end()); |
189 | |
190 | body.getOperations().insert(where: insertPt, New: symbol); |
191 | } |
192 | assert(symbol->getParentOp() == symbolTableOp && |
193 | "symbol is already inserted in another op" ); |
194 | |
195 | // Add this symbol to the symbol table, uniquing the name if a conflict is |
196 | // detected. |
197 | StringAttr name = getSymbolName(symbol); |
198 | if (symbolTable.insert({name, symbol}).second) |
199 | return name; |
200 | // If the symbol was already in the table, also return. |
201 | if (symbolTable.lookup(Val: name) == symbol) |
202 | return name; |
203 | |
204 | MLIRContext *context = symbol->getContext(); |
205 | SmallString<128> nameBuffer = generateSymbolName<128>( |
206 | name.getValue(), |
207 | [&](StringRef candidate) { |
208 | return !symbolTable |
209 | .insert({StringAttr::get(context, candidate), symbol}) |
210 | .second; |
211 | }, |
212 | uniquingCounter); |
213 | setSymbolName(symbol, name: nameBuffer); |
214 | return getSymbolName(symbol); |
215 | } |
216 | |
217 | LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) { |
218 | Operation *op = lookup(from); |
219 | return rename(op, to); |
220 | } |
221 | |
222 | LogicalResult SymbolTable::rename(Operation *op, StringAttr to) { |
223 | StringAttr from = getNameIfSymbol(op); |
224 | (void)from; |
225 | |
226 | assert(from && "expected valid 'name' attribute" ); |
227 | assert(op->getParentOp() == symbolTableOp && |
228 | "expected this operation to be inside of the operation with this " |
229 | "SymbolTable" ); |
230 | assert(lookup(from) == op && "current name does not resolve to op" ); |
231 | assert(lookup(to) == nullptr && "new name already exists" ); |
232 | |
233 | if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp()))) |
234 | return failure(); |
235 | |
236 | // Remove op with old name, change name, add with new name. The order is |
237 | // important here due to how `remove` and `insert` rely on the op name. |
238 | remove(op); |
239 | setSymbolName(op, to); |
240 | insert(op); |
241 | |
242 | assert(lookup(to) == op && "new name does not resolve to renamed op" ); |
243 | assert(lookup(from) == nullptr && "old name still exists" ); |
244 | |
245 | return success(); |
246 | } |
247 | |
248 | LogicalResult SymbolTable::rename(StringAttr from, StringRef to) { |
249 | auto toAttr = StringAttr::get(getOp()->getContext(), to); |
250 | return rename(from, toAttr); |
251 | } |
252 | |
253 | LogicalResult SymbolTable::rename(Operation *op, StringRef to) { |
254 | auto toAttr = StringAttr::get(getOp()->getContext(), to); |
255 | return rename(op, toAttr); |
256 | } |
257 | |
258 | FailureOr<StringAttr> |
259 | SymbolTable::renameToUnique(StringAttr oldName, |
260 | ArrayRef<SymbolTable *> others) { |
261 | |
262 | // Determine new name that is unique in all symbol tables. |
263 | StringAttr newName; |
264 | { |
265 | MLIRContext *context = oldName.getContext(); |
266 | SmallString<64> prefix = oldName.getValue(); |
267 | int uniqueId = 0; |
268 | prefix.push_back(Elt: '_'); |
269 | while (true) { |
270 | newName = StringAttr::get(context, prefix + Twine(uniqueId++)); |
271 | auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); }; |
272 | if (!lookupNewName(this) && llvm::none_of(Range&: others, P: lookupNewName)) { |
273 | break; |
274 | } |
275 | } |
276 | } |
277 | |
278 | // Apply renaming. |
279 | if (failed(rename(oldName, newName))) |
280 | return failure(); |
281 | return newName; |
282 | } |
283 | |
284 | FailureOr<StringAttr> |
285 | SymbolTable::renameToUnique(Operation *op, ArrayRef<SymbolTable *> others) { |
286 | StringAttr from = getNameIfSymbol(op); |
287 | assert(from && "expected valid 'name' attribute" ); |
288 | return renameToUnique(from, others); |
289 | } |
290 | |
291 | /// Returns the name of the given symbol operation. |
292 | StringAttr SymbolTable::getSymbolName(Operation *symbol) { |
293 | StringAttr name = getNameIfSymbol(symbol); |
294 | assert(name && "expected valid symbol name" ); |
295 | return name; |
296 | } |
297 | |
298 | /// Sets the name of the given symbol operation. |
299 | void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) { |
300 | symbol->setAttr(getSymbolAttrName(), name); |
301 | } |
302 | |
303 | /// Returns the visibility of the given symbol operation. |
304 | SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) { |
305 | // If the attribute doesn't exist, assume public. |
306 | StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName()); |
307 | if (!vis) |
308 | return Visibility::Public; |
309 | |
310 | // Otherwise, switch on the string value. |
311 | return StringSwitch<Visibility>(vis.getValue()) |
312 | .Case(S: "private" , Value: Visibility::Private) |
313 | .Case(S: "nested" , Value: Visibility::Nested) |
314 | .Case(S: "public" , Value: Visibility::Public); |
315 | } |
316 | /// Sets the visibility of the given symbol operation. |
317 | void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) { |
318 | MLIRContext *ctx = symbol->getContext(); |
319 | |
320 | // If the visibility is public, just drop the attribute as this is the |
321 | // default. |
322 | if (vis == Visibility::Public) { |
323 | symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName())); |
324 | return; |
325 | } |
326 | |
327 | // Otherwise, update the attribute. |
328 | assert((vis == Visibility::Private || vis == Visibility::Nested) && |
329 | "unknown symbol visibility kind" ); |
330 | |
331 | StringRef visName = vis == Visibility::Private ? "private" : "nested" ; |
332 | symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName)); |
333 | } |
334 | |
335 | /// Returns the nearest symbol table from a given operation `from`. Returns |
336 | /// nullptr if no valid parent symbol table could be found. |
337 | Operation *SymbolTable::getNearestSymbolTable(Operation *from) { |
338 | assert(from && "expected valid operation" ); |
339 | if (isPotentiallyUnknownSymbolTable(op: from)) |
340 | return nullptr; |
341 | |
342 | while (!from->hasTrait<OpTrait::SymbolTable>()) { |
343 | from = from->getParentOp(); |
344 | |
345 | // Check that this is a valid op and isn't an unknown symbol table. |
346 | if (!from || isPotentiallyUnknownSymbolTable(op: from)) |
347 | return nullptr; |
348 | } |
349 | return from; |
350 | } |
351 | |
352 | /// Walks all symbol table operations nested within, and including, `op`. For |
353 | /// each symbol table operation, the provided callback is invoked with the op |
354 | /// and a boolean signifying if the symbols within that symbol table can be |
355 | /// treated as if all uses are visible. `allSymUsesVisible` identifies whether |
356 | /// all of the symbol uses of symbols within `op` are visible. |
357 | void SymbolTable::walkSymbolTables( |
358 | Operation *op, bool allSymUsesVisible, |
359 | function_ref<void(Operation *, bool)> callback) { |
360 | bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>(); |
361 | if (isSymbolTable) { |
362 | SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op); |
363 | allSymUsesVisible |= !symbol || symbol.isPrivate(); |
364 | } else { |
365 | // Otherwise if 'op' is not a symbol table, any nested symbols are |
366 | // guaranteed to be hidden. |
367 | allSymUsesVisible = true; |
368 | } |
369 | |
370 | for (Region ®ion : op->getRegions()) |
371 | for (Block &block : region) |
372 | for (Operation &nestedOp : block) |
373 | walkSymbolTables(op: &nestedOp, allSymUsesVisible, callback); |
374 | |
375 | // If 'op' had the symbol table trait, visit it after any nested symbol |
376 | // tables. |
377 | if (isSymbolTable) |
378 | callback(op, allSymUsesVisible); |
379 | } |
380 | |
381 | /// Returns the operation registered with the given symbol name with the |
382 | /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation |
383 | /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol |
384 | /// was found. |
385 | Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, |
386 | StringAttr symbol) { |
387 | assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); |
388 | Region ®ion = symbolTableOp->getRegion(index: 0); |
389 | if (region.empty()) |
390 | return nullptr; |
391 | |
392 | // Look for a symbol with the given name. |
393 | StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(), |
394 | SymbolTable::getSymbolAttrName()); |
395 | for (auto &op : region.front()) |
396 | if (getNameIfSymbol(&op, symbolNameId) == symbol) |
397 | return &op; |
398 | return nullptr; |
399 | } |
400 | Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, |
401 | SymbolRefAttr symbol) { |
402 | SmallVector<Operation *, 4> resolvedSymbols; |
403 | if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols))) |
404 | return nullptr; |
405 | return resolvedSymbols.back(); |
406 | } |
407 | |
408 | /// Internal implementation of `lookupSymbolIn` that allows for specialized |
409 | /// implementations of the lookup function. |
410 | static LogicalResult lookupSymbolInImpl( |
411 | Operation *symbolTableOp, SymbolRefAttr symbol, |
412 | SmallVectorImpl<Operation *> &symbols, |
413 | function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) { |
414 | assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); |
415 | |
416 | // Lookup the root reference for this symbol. |
417 | symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference()); |
418 | if (!symbolTableOp) |
419 | return failure(); |
420 | symbols.push_back(Elt: symbolTableOp); |
421 | |
422 | // If there are no nested references, just return the root symbol directly. |
423 | ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences(); |
424 | if (nestedRefs.empty()) |
425 | return success(); |
426 | |
427 | // Verify that the root is also a symbol table. |
428 | if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) |
429 | return failure(); |
430 | |
431 | // Otherwise, lookup each of the nested non-leaf references and ensure that |
432 | // each corresponds to a valid symbol table. |
433 | for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) { |
434 | symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr()); |
435 | if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>()) |
436 | return failure(); |
437 | symbols.push_back(symbolTableOp); |
438 | } |
439 | symbols.push_back(Elt: lookupSymbolFn(symbolTableOp, symbol.getLeafReference())); |
440 | return success(isSuccess: symbols.back()); |
441 | } |
442 | |
443 | LogicalResult |
444 | SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol, |
445 | SmallVectorImpl<Operation *> &symbols) { |
446 | auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) { |
447 | return lookupSymbolIn(symbolTableOp, symbol); |
448 | }; |
449 | return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn); |
450 | } |
451 | |
452 | /// Returns the operation registered with the given symbol name within the |
453 | /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns |
454 | /// nullptr if no valid symbol was found. |
455 | Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, |
456 | StringAttr symbol) { |
457 | Operation *symbolTableOp = getNearestSymbolTable(from); |
458 | return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; |
459 | } |
460 | Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, |
461 | SymbolRefAttr symbol) { |
462 | Operation *symbolTableOp = getNearestSymbolTable(from); |
463 | return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; |
464 | } |
465 | |
466 | raw_ostream &mlir::operator<<(raw_ostream &os, |
467 | SymbolTable::Visibility visibility) { |
468 | switch (visibility) { |
469 | case SymbolTable::Visibility::Public: |
470 | return os << "public" ; |
471 | case SymbolTable::Visibility::Private: |
472 | return os << "private" ; |
473 | case SymbolTable::Visibility::Nested: |
474 | return os << "nested" ; |
475 | } |
476 | llvm_unreachable("Unexpected visibility" ); |
477 | } |
478 | |
479 | //===----------------------------------------------------------------------===// |
480 | // SymbolTable Trait Types |
481 | //===----------------------------------------------------------------------===// |
482 | |
483 | LogicalResult detail::verifySymbolTable(Operation *op) { |
484 | if (op->getNumRegions() != 1) |
485 | return op->emitOpError() |
486 | << "Operations with a 'SymbolTable' must have exactly one region" ; |
487 | if (!llvm::hasSingleElement(C&: op->getRegion(index: 0))) |
488 | return op->emitOpError() |
489 | << "Operations with a 'SymbolTable' must have exactly one block" ; |
490 | |
491 | // Check that all symbols are uniquely named within child regions. |
492 | DenseMap<Attribute, Location> nameToOrigLoc; |
493 | for (auto &block : op->getRegion(index: 0)) { |
494 | for (auto &op : block) { |
495 | // Check for a symbol name attribute. |
496 | auto nameAttr = |
497 | op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()); |
498 | if (!nameAttr) |
499 | continue; |
500 | |
501 | // Try to insert this symbol into the table. |
502 | auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc()); |
503 | if (!it.second) |
504 | return op.emitError() |
505 | .append("redefinition of symbol named '" , nameAttr.getValue(), "'" ) |
506 | .attachNote(it.first->second) |
507 | .append("see existing symbol definition here" ); |
508 | } |
509 | } |
510 | |
511 | // Verify any nested symbol user operations. |
512 | SymbolTableCollection symbolTable; |
513 | auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> { |
514 | if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op)) |
515 | return WalkResult(user.verifySymbolUses(symbolTable)); |
516 | return WalkResult::advance(); |
517 | }; |
518 | |
519 | std::optional<WalkResult> result = |
520 | walkSymbolTable(regions: op->getRegions(), callback: verifySymbolUserFn); |
521 | return success(isSuccess: result && !result->wasInterrupted()); |
522 | } |
523 | |
524 | LogicalResult detail::verifySymbol(Operation *op) { |
525 | // Verify the name attribute. |
526 | if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName())) |
527 | return op->emitOpError() << "requires string attribute '" |
528 | << mlir::SymbolTable::getSymbolAttrName() << "'" ; |
529 | |
530 | // Verify the visibility attribute. |
531 | if (Attribute vis = op->getAttr(name: mlir::SymbolTable::getVisibilityAttrName())) { |
532 | StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis); |
533 | if (!visStrAttr) |
534 | return op->emitOpError() << "requires visibility attribute '" |
535 | << mlir::SymbolTable::getVisibilityAttrName() |
536 | << "' to be a string attribute, but got " << vis; |
537 | |
538 | if (!llvm::is_contained(ArrayRef<StringRef>{"public" , "private" , "nested" }, |
539 | visStrAttr.getValue())) |
540 | return op->emitOpError() |
541 | << "visibility expected to be one of [\"public\", \"private\", " |
542 | "\"nested\"], but got " |
543 | << visStrAttr; |
544 | } |
545 | return success(); |
546 | } |
547 | |
548 | //===----------------------------------------------------------------------===// |
549 | // Symbol Use Lists |
550 | //===----------------------------------------------------------------------===// |
551 | |
552 | /// Walk all of the symbol references within the given operation, invoking the |
553 | /// provided callback for each found use. The callbacks takes the use of the |
554 | /// symbol. |
555 | static WalkResult |
556 | walkSymbolRefs(Operation *op, |
557 | function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { |
558 | return op->getAttrDictionary().walk<WalkOrder::PreOrder>( |
559 | [&](SymbolRefAttr symbolRef) { |
560 | if (callback({op, symbolRef}).wasInterrupted()) |
561 | return WalkResult::interrupt(); |
562 | |
563 | // Don't walk nested references. |
564 | return WalkResult::skip(); |
565 | }); |
566 | } |
567 | |
568 | /// Walk all of the uses, for any symbol, that are nested within the given |
569 | /// regions, invoking the provided callback for each. This does not traverse |
570 | /// into any nested symbol tables. |
571 | static std::optional<WalkResult> |
572 | walkSymbolUses(MutableArrayRef<Region> regions, |
573 | function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { |
574 | return walkSymbolTable(regions, |
575 | callback: [&](Operation *op) -> std::optional<WalkResult> { |
576 | // Check that this isn't a potentially unknown symbol |
577 | // table. |
578 | if (isPotentiallyUnknownSymbolTable(op)) |
579 | return std::nullopt; |
580 | |
581 | return walkSymbolRefs(op, callback); |
582 | }); |
583 | } |
584 | /// Walk all of the uses, for any symbol, that are nested within the given |
585 | /// operation 'from', invoking the provided callback for each. This does not |
586 | /// traverse into any nested symbol tables. |
587 | static std::optional<WalkResult> |
588 | walkSymbolUses(Operation *from, |
589 | function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { |
590 | // If this operation has regions, and it, as well as its dialect, isn't |
591 | // registered then conservatively fail. The operation may define a |
592 | // symbol table, so we can't opaquely know if we should traverse to find |
593 | // nested uses. |
594 | if (isPotentiallyUnknownSymbolTable(op: from)) |
595 | return std::nullopt; |
596 | |
597 | // Walk the uses on this operation. |
598 | if (walkSymbolRefs(op: from, callback).wasInterrupted()) |
599 | return WalkResult::interrupt(); |
600 | |
601 | // Only recurse if this operation is not a symbol table. A symbol table |
602 | // defines a new scope, so we can't walk the attributes from within the symbol |
603 | // table op. |
604 | if (!from->hasTrait<OpTrait::SymbolTable>()) |
605 | return walkSymbolUses(regions: from->getRegions(), callback); |
606 | return WalkResult::advance(); |
607 | } |
608 | |
609 | namespace { |
610 | /// This class represents a single symbol scope. A symbol scope represents the |
611 | /// set of operations nested within a symbol table that may reference symbols |
612 | /// within that table. A symbol scope does not contain the symbol table |
613 | /// operation itself, just its contained operations. A scope ends at leaf |
614 | /// operations or another symbol table operation. |
615 | struct SymbolScope { |
616 | /// Walk the symbol uses within this scope, invoking the given callback. |
617 | /// This variant is used when the callback type matches that expected by |
618 | /// 'walkSymbolUses'. |
619 | template <typename CallbackT, |
620 | std::enable_if_t<!std::is_same< |
621 | typename llvm::function_traits<CallbackT>::result_t, |
622 | void>::value> * = nullptr> |
623 | std::optional<WalkResult> walk(CallbackT cback) { |
624 | if (Region *region = llvm::dyn_cast_if_present<Region *>(Val&: limit)) |
625 | return walkSymbolUses(*region, cback); |
626 | return walkSymbolUses(limit.get<Operation *>(), cback); |
627 | } |
628 | /// This variant is used when the callback type matches a stripped down type: |
629 | /// void(SymbolTable::SymbolUse use) |
630 | template <typename CallbackT, |
631 | std::enable_if_t<std::is_same< |
632 | typename llvm::function_traits<CallbackT>::result_t, |
633 | void>::value> * = nullptr> |
634 | std::optional<WalkResult> walk(CallbackT cback) { |
635 | return walk([=](SymbolTable::SymbolUse use) { |
636 | return cback(use), WalkResult::advance(); |
637 | }); |
638 | } |
639 | |
640 | /// Walk all of the operations nested under the current scope without |
641 | /// traversing into any nested symbol tables. |
642 | template <typename CallbackT> |
643 | std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) { |
644 | if (Region *region = llvm::dyn_cast_if_present<Region *>(Val&: limit)) |
645 | return ::walkSymbolTable(*region, cback); |
646 | return ::walkSymbolTable(limit.get<Operation *>(), cback); |
647 | } |
648 | |
649 | /// The representation of the symbol within this scope. |
650 | SymbolRefAttr symbol; |
651 | |
652 | /// The IR unit representing this scope. |
653 | llvm::PointerUnion<Operation *, Region *> limit; |
654 | }; |
655 | } // namespace |
656 | |
657 | /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'. |
658 | static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, |
659 | Operation *limit) { |
660 | StringAttr symName = SymbolTable::getSymbolName(symbol); |
661 | assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit); |
662 | |
663 | // Compute the ancestors of 'limit'. |
664 | SetVector<Operation *, SmallVector<Operation *, 4>, |
665 | SmallPtrSet<Operation *, 4>> |
666 | limitAncestors; |
667 | Operation *limitAncestor = limit; |
668 | do { |
669 | // Check to see if 'symbol' is an ancestor of 'limit'. |
670 | if (limitAncestor == symbol) { |
671 | // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr |
672 | // doesn't support parent references. |
673 | if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == |
674 | symbol->getParentOp()) |
675 | return {{SymbolRefAttr::get(symName), limit}}; |
676 | return {}; |
677 | } |
678 | |
679 | limitAncestors.insert(X: limitAncestor); |
680 | } while ((limitAncestor = limitAncestor->getParentOp())); |
681 | |
682 | // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'. |
683 | Operation *commonAncestor = symbol->getParentOp(); |
684 | do { |
685 | if (limitAncestors.count(key: commonAncestor)) |
686 | break; |
687 | } while ((commonAncestor = commonAncestor->getParentOp())); |
688 | assert(commonAncestor && "'limit' and 'symbol' have no common ancestor" ); |
689 | |
690 | // Compute the set of valid nested references for 'symbol' as far up to the |
691 | // common ancestor as possible. |
692 | SmallVector<SymbolRefAttr, 2> references; |
693 | bool collectedAllReferences = succeeded( |
694 | collectValidReferencesFor(symbol, symName, commonAncestor, references)); |
695 | |
696 | // Handle the case where the common ancestor is 'limit'. |
697 | if (commonAncestor == limit) { |
698 | SmallVector<SymbolScope, 2> scopes; |
699 | |
700 | // Walk each of the ancestors of 'symbol', calling the compute function for |
701 | // each one. |
702 | Operation *limitIt = symbol->getParentOp(); |
703 | for (size_t i = 0, e = references.size(); i != e; |
704 | ++i, limitIt = limitIt->getParentOp()) { |
705 | assert(limitIt->hasTrait<OpTrait::SymbolTable>()); |
706 | scopes.push_back(Elt: {references[i], &limitIt->getRegion(index: 0)}); |
707 | } |
708 | return scopes; |
709 | } |
710 | |
711 | // Otherwise, we just need the symbol reference for 'symbol' that will be |
712 | // used within 'limit'. This is the last reference in the list we computed |
713 | // above if we were able to collect all references. |
714 | if (!collectedAllReferences) |
715 | return {}; |
716 | return {{references.back(), limit}}; |
717 | } |
718 | static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, |
719 | Region *limit) { |
720 | auto scopes = collectSymbolScopes(symbol, limit: limit->getParentOp()); |
721 | |
722 | // If we collected some scopes to walk, make sure to constrain the one for |
723 | // limit to the specific region requested. |
724 | if (!scopes.empty()) |
725 | scopes.back().limit = limit; |
726 | return scopes; |
727 | } |
728 | static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, |
729 | Region *limit) { |
730 | return {{SymbolRefAttr::get(symbol), limit}}; |
731 | } |
732 | |
733 | static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, |
734 | Operation *limit) { |
735 | SmallVector<SymbolScope, 1> scopes; |
736 | auto symbolRef = SymbolRefAttr::get(symbol); |
737 | for (auto ®ion : limit->getRegions()) |
738 | scopes.push_back(Elt: {symbolRef, ®ion}); |
739 | return scopes; |
740 | } |
741 | |
742 | /// Returns true if the given reference 'SubRef' is a sub reference of the |
743 | /// reference 'ref', i.e. 'ref' is a further qualified reference. |
744 | static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) { |
745 | if (ref == subRef) |
746 | return true; |
747 | |
748 | // If the references are not pointer equal, check to see if `subRef` is a |
749 | // prefix of `ref`. |
750 | if (llvm::isa<FlatSymbolRefAttr>(ref) || |
751 | ref.getRootReference() != subRef.getRootReference()) |
752 | return false; |
753 | |
754 | auto refLeafs = ref.getNestedReferences(); |
755 | auto subRefLeafs = subRef.getNestedReferences(); |
756 | return subRefLeafs.size() < refLeafs.size() && |
757 | subRefLeafs == refLeafs.take_front(subRefLeafs.size()); |
758 | } |
759 | |
760 | //===----------------------------------------------------------------------===// |
761 | // SymbolTable::getSymbolUses |
762 | |
763 | /// The implementation of SymbolTable::getSymbolUses below. |
764 | template <typename FromT> |
765 | static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) { |
766 | std::vector<SymbolTable::SymbolUse> uses; |
767 | auto walkFn = [&](SymbolTable::SymbolUse symbolUse) { |
768 | uses.push_back(x: symbolUse); |
769 | return WalkResult::advance(); |
770 | }; |
771 | auto result = walkSymbolUses(from, walkFn); |
772 | return result ? std::optional<SymbolTable::UseRange>(std::move(uses)) |
773 | : std::nullopt; |
774 | } |
775 | |
776 | /// Get an iterator range for all of the uses, for any symbol, that are nested |
777 | /// within the given operation 'from'. This does not traverse into any nested |
778 | /// symbol tables, and will also only return uses on 'from' if it does not |
779 | /// also define a symbol table. This is because we treat the region as the |
780 | /// boundary of the symbol table, and not the op itself. This function returns |
781 | /// std::nullopt if there are any unknown operations that may potentially be |
782 | /// symbol tables. |
783 | auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> { |
784 | return getSymbolUsesImpl(from); |
785 | } |
786 | auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> { |
787 | return getSymbolUsesImpl(from: MutableArrayRef<Region>(*from)); |
788 | } |
789 | |
790 | //===----------------------------------------------------------------------===// |
791 | // SymbolTable::getSymbolUses |
792 | |
793 | /// The implementation of SymbolTable::getSymbolUses below. |
794 | template <typename SymbolT, typename IRUnitT> |
795 | static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, |
796 | IRUnitT *limit) { |
797 | std::vector<SymbolTable::SymbolUse> uses; |
798 | for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { |
799 | if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) { |
800 | if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())) |
801 | uses.push_back(x: symbolUse); |
802 | })) |
803 | return std::nullopt; |
804 | } |
805 | return SymbolTable::UseRange(std::move(uses)); |
806 | } |
807 | |
808 | /// Get all of the uses of the given symbol that are nested within the given |
809 | /// operation 'from', invoking the provided callback for each. This does not |
810 | /// traverse into any nested symbol tables. This function returns std::nullopt |
811 | /// if there are any unknown operations that may potentially be symbol tables. |
812 | auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from) |
813 | -> std::optional<UseRange> { |
814 | return getSymbolUsesImpl(symbol, from); |
815 | } |
816 | auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from) |
817 | -> std::optional<UseRange> { |
818 | return getSymbolUsesImpl(symbol, limit: from); |
819 | } |
820 | auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from) |
821 | -> std::optional<UseRange> { |
822 | return getSymbolUsesImpl(symbol, from); |
823 | } |
824 | auto SymbolTable::getSymbolUses(Operation *symbol, Region *from) |
825 | -> std::optional<UseRange> { |
826 | return getSymbolUsesImpl(symbol, limit: from); |
827 | } |
828 | |
829 | //===----------------------------------------------------------------------===// |
830 | // SymbolTable::symbolKnownUseEmpty |
831 | |
832 | /// The implementation of SymbolTable::symbolKnownUseEmpty below. |
833 | template <typename SymbolT, typename IRUnitT> |
834 | static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) { |
835 | for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { |
836 | // Walk all of the symbol uses looking for a reference to 'symbol'. |
837 | if (scope.walk([&](SymbolTable::SymbolUse symbolUse) { |
838 | return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) |
839 | ? WalkResult::interrupt() |
840 | : WalkResult::advance(); |
841 | }) != WalkResult::advance()) |
842 | return false; |
843 | } |
844 | return true; |
845 | } |
846 | |
847 | /// Return if the given symbol is known to have no uses that are nested within |
848 | /// the given operation 'from'. This does not traverse into any nested symbol |
849 | /// tables. This function will also return false if there are any unknown |
850 | /// operations that may potentially be symbol tables. |
851 | bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) { |
852 | return symbolKnownUseEmptyImpl(symbol, from); |
853 | } |
854 | bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { |
855 | return symbolKnownUseEmptyImpl(symbol, limit: from); |
856 | } |
857 | bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) { |
858 | return symbolKnownUseEmptyImpl(symbol, from); |
859 | } |
860 | bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) { |
861 | return symbolKnownUseEmptyImpl(symbol, limit: from); |
862 | } |
863 | |
864 | //===----------------------------------------------------------------------===// |
865 | // SymbolTable::replaceAllSymbolUses |
866 | |
867 | /// Generates a new symbol reference attribute with a new leaf reference. |
868 | static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, |
869 | FlatSymbolRefAttr newLeafAttr) { |
870 | if (llvm::isa<FlatSymbolRefAttr>(oldAttr)) |
871 | return newLeafAttr; |
872 | auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); |
873 | nestedRefs.back() = newLeafAttr; |
874 | return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs); |
875 | } |
876 | |
877 | /// The implementation of SymbolTable::replaceAllSymbolUses below. |
878 | template <typename SymbolT, typename IRUnitT> |
879 | static LogicalResult |
880 | replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { |
881 | // Generate a new attribute to replace the given attribute. |
882 | FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol); |
883 | for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { |
884 | SymbolRefAttr oldAttr = scope.symbol; |
885 | SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); |
886 | AttrTypeReplacer replacer; |
887 | replacer.addReplacement( |
888 | [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> { |
889 | // Regardless of the match, don't walk nested SymbolRefAttrs, we don't |
890 | // want to accidentally replace an inner reference. |
891 | if (attr == oldAttr) |
892 | return {newAttr, WalkResult::skip()}; |
893 | // Handle prefix matches. |
894 | if (isReferencePrefixOf(oldAttr, attr)) { |
895 | auto oldNestedRefs = oldAttr.getNestedReferences(); |
896 | auto nestedRefs = attr.getNestedReferences(); |
897 | if (oldNestedRefs.empty()) |
898 | return {SymbolRefAttr::get(newSymbol, nestedRefs), |
899 | WalkResult::skip()}; |
900 | |
901 | auto newNestedRefs = llvm::to_vector<4>(nestedRefs); |
902 | newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr; |
903 | return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs), |
904 | WalkResult::skip()}; |
905 | } |
906 | return {attr, WalkResult::skip()}; |
907 | }); |
908 | |
909 | auto walkFn = [&](Operation *op) -> std::optional<WalkResult> { |
910 | replacer.replaceElementsIn(op); |
911 | return WalkResult::advance(); |
912 | }; |
913 | if (!scope.walkSymbolTable(walkFn)) |
914 | return failure(); |
915 | } |
916 | return success(); |
917 | } |
918 | |
919 | /// Attempt to replace all uses of the given symbol 'oldSymbol' with the |
920 | /// provided symbol 'newSymbol' that are nested within the given operation |
921 | /// 'from'. This does not traverse into any nested symbol tables. If there are |
922 | /// any unknown operations that may potentially be symbol tables, no uses are |
923 | /// replaced and failure is returned. |
924 | LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, |
925 | StringAttr newSymbol, |
926 | Operation *from) { |
927 | return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); |
928 | } |
929 | LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, |
930 | StringAttr newSymbol, |
931 | Operation *from) { |
932 | return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); |
933 | } |
934 | LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, |
935 | StringAttr newSymbol, |
936 | Region *from) { |
937 | return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); |
938 | } |
939 | LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, |
940 | StringAttr newSymbol, |
941 | Region *from) { |
942 | return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); |
943 | } |
944 | |
945 | //===----------------------------------------------------------------------===// |
946 | // SymbolTableCollection |
947 | //===----------------------------------------------------------------------===// |
948 | |
949 | Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
950 | StringAttr symbol) { |
951 | return getSymbolTable(op: symbolTableOp).lookup(symbol); |
952 | } |
953 | Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
954 | SymbolRefAttr name) { |
955 | SmallVector<Operation *, 4> symbols; |
956 | if (failed(lookupSymbolIn(symbolTableOp, name, symbols))) |
957 | return nullptr; |
958 | return symbols.back(); |
959 | } |
960 | /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by |
961 | /// a given SymbolRefAttr. Returns failure if any of the nested references could |
962 | /// not be resolved. |
963 | LogicalResult |
964 | SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
965 | SymbolRefAttr name, |
966 | SmallVectorImpl<Operation *> &symbols) { |
967 | auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) { |
968 | return lookupSymbolIn(symbolTableOp, symbol); |
969 | }; |
970 | return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn); |
971 | } |
972 | |
973 | /// Returns the operation registered with the given symbol name within the |
974 | /// closest parent operation of, or including, 'from' with the |
975 | /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was |
976 | /// found. |
977 | Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, |
978 | StringAttr symbol) { |
979 | Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); |
980 | return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; |
981 | } |
982 | Operation * |
983 | SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, |
984 | SymbolRefAttr symbol) { |
985 | Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); |
986 | return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; |
987 | } |
988 | |
989 | /// Lookup, or create, a symbol table for an operation. |
990 | SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) { |
991 | auto it = symbolTables.try_emplace(Key: op, Args: nullptr); |
992 | if (it.second) |
993 | it.first->second = std::make_unique<SymbolTable>(args&: op); |
994 | return *it.first->second; |
995 | } |
996 | |
997 | //===----------------------------------------------------------------------===// |
998 | // LockedSymbolTableCollection |
999 | //===----------------------------------------------------------------------===// |
1000 | |
1001 | Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
1002 | StringAttr symbol) { |
1003 | return getSymbolTable(symbolTableOp).lookup(symbol); |
1004 | } |
1005 | |
1006 | Operation * |
1007 | LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
1008 | FlatSymbolRefAttr symbol) { |
1009 | return lookupSymbolIn(symbolTableOp, symbol.getAttr()); |
1010 | } |
1011 | |
1012 | Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
1013 | SymbolRefAttr name) { |
1014 | SmallVector<Operation *> symbols; |
1015 | if (failed(lookupSymbolIn(symbolTableOp, name, symbols))) |
1016 | return nullptr; |
1017 | return symbols.back(); |
1018 | } |
1019 | |
1020 | LogicalResult LockedSymbolTableCollection::lookupSymbolIn( |
1021 | Operation *symbolTableOp, SymbolRefAttr name, |
1022 | SmallVectorImpl<Operation *> &symbols) { |
1023 | auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) { |
1024 | return lookupSymbolIn(symbolTableOp, symbol); |
1025 | }; |
1026 | return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn); |
1027 | } |
1028 | |
1029 | SymbolTable & |
1030 | LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) { |
1031 | assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); |
1032 | // Try to find an existing symbol table. |
1033 | { |
1034 | llvm::sys::SmartScopedReader<true> lock(mutex); |
1035 | auto it = collection.symbolTables.find(Val: symbolTableOp); |
1036 | if (it != collection.symbolTables.end()) |
1037 | return *it->second; |
1038 | } |
1039 | // Create a symbol table for the operation. Perform construction outside of |
1040 | // the critical section. |
1041 | auto symbolTable = std::make_unique<SymbolTable>(args&: symbolTableOp); |
1042 | // Insert the constructed symbol table. |
1043 | llvm::sys::SmartScopedWriter<true> lock(mutex); |
1044 | return *collection.symbolTables |
1045 | .insert(KV: {symbolTableOp, std::move(symbolTable)}) |
1046 | .first->second; |
1047 | } |
1048 | |
1049 | //===----------------------------------------------------------------------===// |
1050 | // SymbolUserMap |
1051 | //===----------------------------------------------------------------------===// |
1052 | |
1053 | SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable, |
1054 | Operation *symbolTableOp) |
1055 | : symbolTable(symbolTable) { |
1056 | // Walk each of the symbol tables looking for discardable callgraph nodes. |
1057 | SmallVector<Operation *> symbols; |
1058 | auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { |
1059 | for (Operation &nestedOp : symbolTableOp->getRegion(index: 0).getOps()) { |
1060 | auto symbolUses = SymbolTable::getSymbolUses(from: &nestedOp); |
1061 | assert(symbolUses && "expected uses to be valid" ); |
1062 | |
1063 | for (const SymbolTable::SymbolUse &use : *symbolUses) { |
1064 | symbols.clear(); |
1065 | (void)symbolTable.lookupSymbolIn(symbolTableOp, name: use.getSymbolRef(), |
1066 | symbols); |
1067 | for (Operation *symbolOp : symbols) |
1068 | symbolToUsers[symbolOp].insert(X: use.getUser()); |
1069 | } |
1070 | } |
1071 | }; |
1072 | // We just set `allSymUsesVisible` to false here because it isn't necessary |
1073 | // for building the user map. |
1074 | SymbolTable::walkSymbolTables(op: symbolTableOp, /*allSymUsesVisible=*/false, |
1075 | callback: walkFn); |
1076 | } |
1077 | |
1078 | void SymbolUserMap::replaceAllUsesWith(Operation *symbol, |
1079 | StringAttr newSymbolName) { |
1080 | auto it = symbolToUsers.find(Val: symbol); |
1081 | if (it == symbolToUsers.end()) |
1082 | return; |
1083 | |
1084 | // Replace the uses within the users of `symbol`. |
1085 | for (Operation *user : it->second) |
1086 | (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user); |
1087 | |
1088 | // Move the current users of `symbol` to the new symbol if it is in the |
1089 | // symbol table. |
1090 | Operation *newSymbol = |
1091 | symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName); |
1092 | if (newSymbol != symbol) { |
1093 | // Transfer over the users to the new symbol. The reference to the old one |
1094 | // is fetched again as the iterator is invalidated during the insertion. |
1095 | auto newIt = symbolToUsers.try_emplace(Key: newSymbol, Args: SetVector<Operation *>{}); |
1096 | auto oldIt = symbolToUsers.find(Val: symbol); |
1097 | assert(oldIt != symbolToUsers.end() && "missing old users list" ); |
1098 | if (newIt.second) |
1099 | newIt.first->second = std::move(oldIt->second); |
1100 | else |
1101 | newIt.first->second.set_union(oldIt->second); |
1102 | symbolToUsers.erase(I: oldIt); |
1103 | } |
1104 | } |
1105 | |
1106 | //===----------------------------------------------------------------------===// |
1107 | // Visibility parsing implementation. |
1108 | //===----------------------------------------------------------------------===// |
1109 | |
1110 | ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser, |
1111 | NamedAttrList &attrs) { |
1112 | StringRef visibility; |
1113 | if (parser.parseOptionalKeyword(keyword: &visibility, allowedValues: {"public" , "private" , "nested" })) |
1114 | return failure(); |
1115 | |
1116 | StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility); |
1117 | attrs.push_back(newAttribute: parser.getBuilder().getNamedAttr( |
1118 | name: SymbolTable::getVisibilityAttrName(), val: visibilityAttr)); |
1119 | return success(); |
1120 | } |
1121 | |
1122 | //===----------------------------------------------------------------------===// |
1123 | // Symbol Interfaces |
1124 | //===----------------------------------------------------------------------===// |
1125 | |
1126 | /// Include the generated symbol interfaces. |
1127 | #include "mlir/IR/SymbolInterfaces.cpp.inc" |
1128 | |