| 1 | //===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/IR/SymbolTable.h" |
| 10 | #include "mlir/IR/Builders.h" |
| 11 | #include "mlir/IR/OpImplementation.h" |
| 12 | #include "llvm/ADT/SetVector.h" |
| 13 | #include "llvm/ADT/SmallPtrSet.h" |
| 14 | #include "llvm/ADT/SmallString.h" |
| 15 | #include "llvm/ADT/StringSwitch.h" |
| 16 | #include <optional> |
| 17 | |
| 18 | using namespace mlir; |
| 19 | |
| 20 | /// Return true if the given operation is unknown and may potentially define a |
| 21 | /// symbol table. |
| 22 | static bool isPotentiallyUnknownSymbolTable(Operation *op) { |
| 23 | return op->getNumRegions() == 1 && !op->getDialect(); |
| 24 | } |
| 25 | |
| 26 | /// Returns the string name of the given symbol, or null if this is not a |
| 27 | /// symbol. |
| 28 | static StringAttr getNameIfSymbol(Operation *op) { |
| 29 | return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); |
| 30 | } |
| 31 | static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) { |
| 32 | return op->getAttrOfType<StringAttr>(symbolAttrNameId); |
| 33 | } |
| 34 | |
| 35 | /// Computes the nested symbol reference attribute for the symbol 'symbolName' |
| 36 | /// that are usable within the symbol table operations from 'symbol' as far up |
| 37 | /// to the given operation 'within', where 'within' is an ancestor of 'symbol'. |
| 38 | /// Returns success if all references up to 'within' could be computed. |
| 39 | static LogicalResult |
| 40 | collectValidReferencesFor(Operation *symbol, StringAttr symbolName, |
| 41 | Operation *within, |
| 42 | SmallVectorImpl<SymbolRefAttr> &results) { |
| 43 | assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor" ); |
| 44 | MLIRContext *ctx = symbol->getContext(); |
| 45 | |
| 46 | auto leafRef = FlatSymbolRefAttr::get(symbolName); |
| 47 | results.push_back(leafRef); |
| 48 | |
| 49 | // Early exit for when 'within' is the parent of 'symbol'. |
| 50 | Operation *symbolTableOp = symbol->getParentOp(); |
| 51 | if (within == symbolTableOp) |
| 52 | return success(); |
| 53 | |
| 54 | // Collect references until 'symbolTableOp' reaches 'within'. |
| 55 | SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef); |
| 56 | StringAttr symbolNameId = |
| 57 | StringAttr::get(ctx, SymbolTable::getSymbolAttrName()); |
| 58 | do { |
| 59 | // Each parent of 'symbol' should define a symbol table. |
| 60 | if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) |
| 61 | return failure(); |
| 62 | // Each parent of 'symbol' should also be a symbol. |
| 63 | StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId); |
| 64 | if (!symbolTableName) |
| 65 | return failure(); |
| 66 | results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs)); |
| 67 | |
| 68 | symbolTableOp = symbolTableOp->getParentOp(); |
| 69 | if (symbolTableOp == within) |
| 70 | break; |
| 71 | nestedRefs.insert(nestedRefs.begin(), |
| 72 | FlatSymbolRefAttr::get(symbolTableName)); |
| 73 | } while (true); |
| 74 | return success(); |
| 75 | } |
| 76 | |
| 77 | /// Walk all of the operations within the given set of regions, without |
| 78 | /// traversing into any nested symbol tables. Stops walking if the result of the |
| 79 | /// callback is anything other than `WalkResult::advance`. |
| 80 | static std::optional<WalkResult> |
| 81 | walkSymbolTable(MutableArrayRef<Region> regions, |
| 82 | function_ref<std::optional<WalkResult>(Operation *)> callback) { |
| 83 | SmallVector<Region *, 1> worklist(llvm::make_pointer_range(Range&: regions)); |
| 84 | while (!worklist.empty()) { |
| 85 | for (Operation &op : worklist.pop_back_val()->getOps()) { |
| 86 | std::optional<WalkResult> result = callback(&op); |
| 87 | if (result != WalkResult::advance()) |
| 88 | return result; |
| 89 | |
| 90 | // If this op defines a new symbol table scope, we can't traverse. Any |
| 91 | // symbol references nested within 'op' are different semantically. |
| 92 | if (!op.hasTrait<OpTrait::SymbolTable>()) { |
| 93 | for (Region ®ion : op.getRegions()) |
| 94 | worklist.push_back(Elt: ®ion); |
| 95 | } |
| 96 | } |
| 97 | } |
| 98 | return WalkResult::advance(); |
| 99 | } |
| 100 | |
| 101 | /// Walk all of the operations nested under, and including, the given operation, |
| 102 | /// without traversing into any nested symbol tables. Stops walking if the |
| 103 | /// result of the callback is anything other than `WalkResult::advance`. |
| 104 | static std::optional<WalkResult> |
| 105 | walkSymbolTable(Operation *op, |
| 106 | function_ref<std::optional<WalkResult>(Operation *)> callback) { |
| 107 | std::optional<WalkResult> result = callback(op); |
| 108 | if (result != WalkResult::advance() || op->hasTrait<OpTrait::SymbolTable>()) |
| 109 | return result; |
| 110 | return walkSymbolTable(regions: op->getRegions(), callback); |
| 111 | } |
| 112 | |
| 113 | //===----------------------------------------------------------------------===// |
| 114 | // SymbolTable |
| 115 | //===----------------------------------------------------------------------===// |
| 116 | |
| 117 | /// Build a symbol table with the symbols within the given operation. |
| 118 | SymbolTable::SymbolTable(Operation *symbolTableOp) |
| 119 | : symbolTableOp(symbolTableOp) { |
| 120 | assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() && |
| 121 | "expected operation to have SymbolTable trait" ); |
| 122 | assert(symbolTableOp->getNumRegions() == 1 && |
| 123 | "expected operation to have a single region" ); |
| 124 | assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) && |
| 125 | "expected operation to have a single block" ); |
| 126 | |
| 127 | StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(), |
| 128 | SymbolTable::getSymbolAttrName()); |
| 129 | for (auto &op : symbolTableOp->getRegion(index: 0).front()) { |
| 130 | StringAttr name = getNameIfSymbol(&op, symbolNameId); |
| 131 | if (!name) |
| 132 | continue; |
| 133 | |
| 134 | auto inserted = symbolTable.insert({name, &op}); |
| 135 | (void)inserted; |
| 136 | assert(inserted.second && |
| 137 | "expected region to contain uniquely named symbol operations" ); |
| 138 | } |
| 139 | } |
| 140 | |
| 141 | /// Look up a symbol with the specified name, returning null if no such name |
| 142 | /// exists. Names never include the @ on them. |
| 143 | Operation *SymbolTable::lookup(StringRef name) const { |
| 144 | return lookup(StringAttr::get(symbolTableOp->getContext(), name)); |
| 145 | } |
| 146 | Operation *SymbolTable::lookup(StringAttr name) const { |
| 147 | return symbolTable.lookup(Val: name); |
| 148 | } |
| 149 | |
| 150 | void SymbolTable::remove(Operation *op) { |
| 151 | StringAttr name = getNameIfSymbol(op); |
| 152 | assert(name && "expected valid 'name' attribute" ); |
| 153 | assert(op->getParentOp() == symbolTableOp && |
| 154 | "expected this operation to be inside of the operation with this " |
| 155 | "SymbolTable" ); |
| 156 | |
| 157 | auto it = symbolTable.find(name); |
| 158 | if (it != symbolTable.end() && it->second == op) |
| 159 | symbolTable.erase(it); |
| 160 | } |
| 161 | |
| 162 | void SymbolTable::erase(Operation *symbol) { |
| 163 | remove(op: symbol); |
| 164 | symbol->erase(); |
| 165 | } |
| 166 | |
| 167 | // TODO: Consider if this should be renamed to something like insertOrUpdate |
| 168 | /// Insert a new symbol into the table and associated operation if not already |
| 169 | /// there and rename it as necessary to avoid collisions. Return the name of |
| 170 | /// the symbol after insertion as attribute. |
| 171 | StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { |
| 172 | // The symbol cannot be the child of another op and must be the child of the |
| 173 | // symbolTableOp after this. |
| 174 | // |
| 175 | // TODO: consider if SymbolTable's constructor should behave the same. |
| 176 | if (!symbol->getParentOp()) { |
| 177 | auto &body = symbolTableOp->getRegion(index: 0).front(); |
| 178 | if (insertPt == Block::iterator()) { |
| 179 | insertPt = Block::iterator(body.end()); |
| 180 | } else { |
| 181 | assert((insertPt == body.end() || |
| 182 | insertPt->getParentOp() == symbolTableOp) && |
| 183 | "expected insertPt to be in the associated module operation" ); |
| 184 | } |
| 185 | // Insert before the terminator, if any. |
| 186 | if (insertPt == Block::iterator(body.end()) && !body.empty() && |
| 187 | std::prev(x: body.end())->hasTrait<OpTrait::IsTerminator>()) |
| 188 | insertPt = std::prev(x: body.end()); |
| 189 | |
| 190 | body.getOperations().insert(where: insertPt, New: symbol); |
| 191 | } |
| 192 | assert(symbol->getParentOp() == symbolTableOp && |
| 193 | "symbol is already inserted in another op" ); |
| 194 | |
| 195 | // Add this symbol to the symbol table, uniquing the name if a conflict is |
| 196 | // detected. |
| 197 | StringAttr name = getSymbolName(symbol); |
| 198 | if (symbolTable.insert({name, symbol}).second) |
| 199 | return name; |
| 200 | // If the symbol was already in the table, also return. |
| 201 | if (symbolTable.lookup(Val: name) == symbol) |
| 202 | return name; |
| 203 | |
| 204 | MLIRContext *context = symbol->getContext(); |
| 205 | SmallString<128> nameBuffer = generateSymbolName<128>( |
| 206 | name.getValue(), |
| 207 | [&](StringRef candidate) { |
| 208 | return !symbolTable |
| 209 | .insert({StringAttr::get(context, candidate), symbol}) |
| 210 | .second; |
| 211 | }, |
| 212 | uniquingCounter); |
| 213 | setSymbolName(symbol, name: nameBuffer); |
| 214 | return getSymbolName(symbol); |
| 215 | } |
| 216 | |
| 217 | LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) { |
| 218 | Operation *op = lookup(from); |
| 219 | return rename(op, to); |
| 220 | } |
| 221 | |
| 222 | LogicalResult SymbolTable::rename(Operation *op, StringAttr to) { |
| 223 | StringAttr from = getNameIfSymbol(op); |
| 224 | (void)from; |
| 225 | |
| 226 | assert(from && "expected valid 'name' attribute" ); |
| 227 | assert(op->getParentOp() == symbolTableOp && |
| 228 | "expected this operation to be inside of the operation with this " |
| 229 | "SymbolTable" ); |
| 230 | assert(lookup(from) == op && "current name does not resolve to op" ); |
| 231 | assert(lookup(to) == nullptr && "new name already exists" ); |
| 232 | |
| 233 | if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp()))) |
| 234 | return failure(); |
| 235 | |
| 236 | // Remove op with old name, change name, add with new name. The order is |
| 237 | // important here due to how `remove` and `insert` rely on the op name. |
| 238 | remove(op); |
| 239 | setSymbolName(op, to); |
| 240 | insert(op); |
| 241 | |
| 242 | assert(lookup(to) == op && "new name does not resolve to renamed op" ); |
| 243 | assert(lookup(from) == nullptr && "old name still exists" ); |
| 244 | |
| 245 | return success(); |
| 246 | } |
| 247 | |
| 248 | LogicalResult SymbolTable::rename(StringAttr from, StringRef to) { |
| 249 | auto toAttr = StringAttr::get(getOp()->getContext(), to); |
| 250 | return rename(from, toAttr); |
| 251 | } |
| 252 | |
| 253 | LogicalResult SymbolTable::rename(Operation *op, StringRef to) { |
| 254 | auto toAttr = StringAttr::get(getOp()->getContext(), to); |
| 255 | return rename(op, toAttr); |
| 256 | } |
| 257 | |
| 258 | FailureOr<StringAttr> |
| 259 | SymbolTable::renameToUnique(StringAttr oldName, |
| 260 | ArrayRef<SymbolTable *> others) { |
| 261 | |
| 262 | // Determine new name that is unique in all symbol tables. |
| 263 | StringAttr newName; |
| 264 | { |
| 265 | MLIRContext *context = oldName.getContext(); |
| 266 | SmallString<64> prefix = oldName.getValue(); |
| 267 | int uniqueId = 0; |
| 268 | prefix.push_back(Elt: '_'); |
| 269 | while (true) { |
| 270 | newName = StringAttr::get(context, prefix + Twine(uniqueId++)); |
| 271 | auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); }; |
| 272 | if (!lookupNewName(this) && llvm::none_of(Range&: others, P: lookupNewName)) { |
| 273 | break; |
| 274 | } |
| 275 | } |
| 276 | } |
| 277 | |
| 278 | // Apply renaming. |
| 279 | if (failed(rename(oldName, newName))) |
| 280 | return failure(); |
| 281 | return newName; |
| 282 | } |
| 283 | |
| 284 | FailureOr<StringAttr> |
| 285 | SymbolTable::renameToUnique(Operation *op, ArrayRef<SymbolTable *> others) { |
| 286 | StringAttr from = getNameIfSymbol(op); |
| 287 | assert(from && "expected valid 'name' attribute" ); |
| 288 | return renameToUnique(from, others); |
| 289 | } |
| 290 | |
| 291 | /// Returns the name of the given symbol operation. |
| 292 | StringAttr SymbolTable::getSymbolName(Operation *symbol) { |
| 293 | StringAttr name = getNameIfSymbol(symbol); |
| 294 | assert(name && "expected valid symbol name" ); |
| 295 | return name; |
| 296 | } |
| 297 | |
| 298 | /// Sets the name of the given symbol operation. |
| 299 | void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) { |
| 300 | symbol->setAttr(getSymbolAttrName(), name); |
| 301 | } |
| 302 | |
| 303 | /// Returns the visibility of the given symbol operation. |
| 304 | SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) { |
| 305 | // If the attribute doesn't exist, assume public. |
| 306 | StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName()); |
| 307 | if (!vis) |
| 308 | return Visibility::Public; |
| 309 | |
| 310 | // Otherwise, switch on the string value. |
| 311 | return StringSwitch<Visibility>(vis.getValue()) |
| 312 | .Case(S: "private" , Value: Visibility::Private) |
| 313 | .Case(S: "nested" , Value: Visibility::Nested) |
| 314 | .Case(S: "public" , Value: Visibility::Public); |
| 315 | } |
| 316 | /// Sets the visibility of the given symbol operation. |
| 317 | void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) { |
| 318 | MLIRContext *ctx = symbol->getContext(); |
| 319 | |
| 320 | // If the visibility is public, just drop the attribute as this is the |
| 321 | // default. |
| 322 | if (vis == Visibility::Public) { |
| 323 | symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName())); |
| 324 | return; |
| 325 | } |
| 326 | |
| 327 | // Otherwise, update the attribute. |
| 328 | assert((vis == Visibility::Private || vis == Visibility::Nested) && |
| 329 | "unknown symbol visibility kind" ); |
| 330 | |
| 331 | StringRef visName = vis == Visibility::Private ? "private" : "nested" ; |
| 332 | symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName)); |
| 333 | } |
| 334 | |
| 335 | /// Returns the nearest symbol table from a given operation `from`. Returns |
| 336 | /// nullptr if no valid parent symbol table could be found. |
| 337 | Operation *SymbolTable::getNearestSymbolTable(Operation *from) { |
| 338 | assert(from && "expected valid operation" ); |
| 339 | if (isPotentiallyUnknownSymbolTable(op: from)) |
| 340 | return nullptr; |
| 341 | |
| 342 | while (!from->hasTrait<OpTrait::SymbolTable>()) { |
| 343 | from = from->getParentOp(); |
| 344 | |
| 345 | // Check that this is a valid op and isn't an unknown symbol table. |
| 346 | if (!from || isPotentiallyUnknownSymbolTable(op: from)) |
| 347 | return nullptr; |
| 348 | } |
| 349 | return from; |
| 350 | } |
| 351 | |
| 352 | /// Walks all symbol table operations nested within, and including, `op`. For |
| 353 | /// each symbol table operation, the provided callback is invoked with the op |
| 354 | /// and a boolean signifying if the symbols within that symbol table can be |
| 355 | /// treated as if all uses are visible. `allSymUsesVisible` identifies whether |
| 356 | /// all of the symbol uses of symbols within `op` are visible. |
| 357 | void SymbolTable::walkSymbolTables( |
| 358 | Operation *op, bool allSymUsesVisible, |
| 359 | function_ref<void(Operation *, bool)> callback) { |
| 360 | bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>(); |
| 361 | if (isSymbolTable) { |
| 362 | SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op); |
| 363 | allSymUsesVisible |= !symbol || symbol.isPrivate(); |
| 364 | } else { |
| 365 | // Otherwise if 'op' is not a symbol table, any nested symbols are |
| 366 | // guaranteed to be hidden. |
| 367 | allSymUsesVisible = true; |
| 368 | } |
| 369 | |
| 370 | for (Region ®ion : op->getRegions()) |
| 371 | for (Block &block : region) |
| 372 | for (Operation &nestedOp : block) |
| 373 | walkSymbolTables(op: &nestedOp, allSymUsesVisible, callback); |
| 374 | |
| 375 | // If 'op' had the symbol table trait, visit it after any nested symbol |
| 376 | // tables. |
| 377 | if (isSymbolTable) |
| 378 | callback(op, allSymUsesVisible); |
| 379 | } |
| 380 | |
| 381 | /// Returns the operation registered with the given symbol name with the |
| 382 | /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation |
| 383 | /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol |
| 384 | /// was found. |
| 385 | Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, |
| 386 | StringAttr symbol) { |
| 387 | assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); |
| 388 | Region ®ion = symbolTableOp->getRegion(index: 0); |
| 389 | if (region.empty()) |
| 390 | return nullptr; |
| 391 | |
| 392 | // Look for a symbol with the given name. |
| 393 | StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(), |
| 394 | SymbolTable::getSymbolAttrName()); |
| 395 | for (auto &op : region.front()) |
| 396 | if (getNameIfSymbol(&op, symbolNameId) == symbol) |
| 397 | return &op; |
| 398 | return nullptr; |
| 399 | } |
| 400 | Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, |
| 401 | SymbolRefAttr symbol) { |
| 402 | SmallVector<Operation *, 4> resolvedSymbols; |
| 403 | if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols))) |
| 404 | return nullptr; |
| 405 | return resolvedSymbols.back(); |
| 406 | } |
| 407 | |
| 408 | /// Internal implementation of `lookupSymbolIn` that allows for specialized |
| 409 | /// implementations of the lookup function. |
| 410 | static LogicalResult lookupSymbolInImpl( |
| 411 | Operation *symbolTableOp, SymbolRefAttr symbol, |
| 412 | SmallVectorImpl<Operation *> &symbols, |
| 413 | function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) { |
| 414 | assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); |
| 415 | |
| 416 | // Lookup the root reference for this symbol. |
| 417 | symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference()); |
| 418 | if (!symbolTableOp) |
| 419 | return failure(); |
| 420 | symbols.push_back(Elt: symbolTableOp); |
| 421 | |
| 422 | // If there are no nested references, just return the root symbol directly. |
| 423 | ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences(); |
| 424 | if (nestedRefs.empty()) |
| 425 | return success(); |
| 426 | |
| 427 | // Verify that the root is also a symbol table. |
| 428 | if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) |
| 429 | return failure(); |
| 430 | |
| 431 | // Otherwise, lookup each of the nested non-leaf references and ensure that |
| 432 | // each corresponds to a valid symbol table. |
| 433 | for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) { |
| 434 | symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr()); |
| 435 | if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>()) |
| 436 | return failure(); |
| 437 | symbols.push_back(symbolTableOp); |
| 438 | } |
| 439 | symbols.push_back(Elt: lookupSymbolFn(symbolTableOp, symbol.getLeafReference())); |
| 440 | return success(IsSuccess: symbols.back()); |
| 441 | } |
| 442 | |
| 443 | LogicalResult |
| 444 | SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol, |
| 445 | SmallVectorImpl<Operation *> &symbols) { |
| 446 | auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) { |
| 447 | return lookupSymbolIn(symbolTableOp, symbol); |
| 448 | }; |
| 449 | return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn); |
| 450 | } |
| 451 | |
| 452 | /// Returns the operation registered with the given symbol name within the |
| 453 | /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns |
| 454 | /// nullptr if no valid symbol was found. |
| 455 | Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, |
| 456 | StringAttr symbol) { |
| 457 | Operation *symbolTableOp = getNearestSymbolTable(from); |
| 458 | return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; |
| 459 | } |
| 460 | Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, |
| 461 | SymbolRefAttr symbol) { |
| 462 | Operation *symbolTableOp = getNearestSymbolTable(from); |
| 463 | return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; |
| 464 | } |
| 465 | |
| 466 | raw_ostream &mlir::operator<<(raw_ostream &os, |
| 467 | SymbolTable::Visibility visibility) { |
| 468 | switch (visibility) { |
| 469 | case SymbolTable::Visibility::Public: |
| 470 | return os << "public" ; |
| 471 | case SymbolTable::Visibility::Private: |
| 472 | return os << "private" ; |
| 473 | case SymbolTable::Visibility::Nested: |
| 474 | return os << "nested" ; |
| 475 | } |
| 476 | llvm_unreachable("Unexpected visibility" ); |
| 477 | } |
| 478 | |
| 479 | //===----------------------------------------------------------------------===// |
| 480 | // SymbolTable Trait Types |
| 481 | //===----------------------------------------------------------------------===// |
| 482 | |
| 483 | LogicalResult detail::verifySymbolTable(Operation *op) { |
| 484 | if (op->getNumRegions() != 1) |
| 485 | return op->emitOpError() |
| 486 | << "Operations with a 'SymbolTable' must have exactly one region" ; |
| 487 | if (!llvm::hasSingleElement(C&: op->getRegion(index: 0))) |
| 488 | return op->emitOpError() |
| 489 | << "Operations with a 'SymbolTable' must have exactly one block" ; |
| 490 | |
| 491 | // Check that all symbols are uniquely named within child regions. |
| 492 | DenseMap<Attribute, Location> nameToOrigLoc; |
| 493 | for (auto &block : op->getRegion(index: 0)) { |
| 494 | for (auto &op : block) { |
| 495 | // Check for a symbol name attribute. |
| 496 | auto nameAttr = |
| 497 | op.getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()); |
| 498 | if (!nameAttr) |
| 499 | continue; |
| 500 | |
| 501 | // Try to insert this symbol into the table. |
| 502 | auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc()); |
| 503 | if (!it.second) |
| 504 | return op.emitError() |
| 505 | .append("redefinition of symbol named '" , nameAttr.getValue(), "'" ) |
| 506 | .attachNote(it.first->second) |
| 507 | .append("see existing symbol definition here" ); |
| 508 | } |
| 509 | } |
| 510 | |
| 511 | // Verify any nested symbol user operations. |
| 512 | SymbolTableCollection symbolTable; |
| 513 | auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> { |
| 514 | if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op)) |
| 515 | return WalkResult(user.verifySymbolUses(symbolTable)); |
| 516 | return WalkResult::advance(); |
| 517 | }; |
| 518 | |
| 519 | std::optional<WalkResult> result = |
| 520 | walkSymbolTable(regions: op->getRegions(), callback: verifySymbolUserFn); |
| 521 | return success(IsSuccess: result && !result->wasInterrupted()); |
| 522 | } |
| 523 | |
| 524 | LogicalResult detail::verifySymbol(Operation *op) { |
| 525 | // Verify the name attribute. |
| 526 | if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName())) |
| 527 | return op->emitOpError() << "requires string attribute '" |
| 528 | << mlir::SymbolTable::getSymbolAttrName() << "'" ; |
| 529 | |
| 530 | // Verify the visibility attribute. |
| 531 | if (Attribute vis = op->getAttr(name: mlir::SymbolTable::getVisibilityAttrName())) { |
| 532 | StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis); |
| 533 | if (!visStrAttr) |
| 534 | return op->emitOpError() << "requires visibility attribute '" |
| 535 | << mlir::SymbolTable::getVisibilityAttrName() |
| 536 | << "' to be a string attribute, but got " << vis; |
| 537 | |
| 538 | if (!llvm::is_contained(ArrayRef<StringRef>{"public" , "private" , "nested" }, |
| 539 | visStrAttr.getValue())) |
| 540 | return op->emitOpError() |
| 541 | << "visibility expected to be one of [\"public\", \"private\", " |
| 542 | "\"nested\"], but got " |
| 543 | << visStrAttr; |
| 544 | } |
| 545 | return success(); |
| 546 | } |
| 547 | |
| 548 | //===----------------------------------------------------------------------===// |
| 549 | // Symbol Use Lists |
| 550 | //===----------------------------------------------------------------------===// |
| 551 | |
| 552 | /// Walk all of the symbol references within the given operation, invoking the |
| 553 | /// provided callback for each found use. The callbacks takes the use of the |
| 554 | /// symbol. |
| 555 | static WalkResult |
| 556 | walkSymbolRefs(Operation *op, |
| 557 | function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { |
| 558 | return op->getAttrDictionary().walk<WalkOrder::PreOrder>( |
| 559 | [&](SymbolRefAttr symbolRef) { |
| 560 | if (callback({op, symbolRef}).wasInterrupted()) |
| 561 | return WalkResult::interrupt(); |
| 562 | |
| 563 | // Don't walk nested references. |
| 564 | return WalkResult::skip(); |
| 565 | }); |
| 566 | } |
| 567 | |
| 568 | /// Walk all of the uses, for any symbol, that are nested within the given |
| 569 | /// regions, invoking the provided callback for each. This does not traverse |
| 570 | /// into any nested symbol tables. |
| 571 | static std::optional<WalkResult> |
| 572 | walkSymbolUses(MutableArrayRef<Region> regions, |
| 573 | function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { |
| 574 | return walkSymbolTable(regions, |
| 575 | callback: [&](Operation *op) -> std::optional<WalkResult> { |
| 576 | // Check that this isn't a potentially unknown symbol |
| 577 | // table. |
| 578 | if (isPotentiallyUnknownSymbolTable(op)) |
| 579 | return std::nullopt; |
| 580 | |
| 581 | return walkSymbolRefs(op, callback); |
| 582 | }); |
| 583 | } |
| 584 | /// Walk all of the uses, for any symbol, that are nested within the given |
| 585 | /// operation 'from', invoking the provided callback for each. This does not |
| 586 | /// traverse into any nested symbol tables. |
| 587 | static std::optional<WalkResult> |
| 588 | walkSymbolUses(Operation *from, |
| 589 | function_ref<WalkResult(SymbolTable::SymbolUse)> callback) { |
| 590 | // If this operation has regions, and it, as well as its dialect, isn't |
| 591 | // registered then conservatively fail. The operation may define a |
| 592 | // symbol table, so we can't opaquely know if we should traverse to find |
| 593 | // nested uses. |
| 594 | if (isPotentiallyUnknownSymbolTable(op: from)) |
| 595 | return std::nullopt; |
| 596 | |
| 597 | // Walk the uses on this operation. |
| 598 | if (walkSymbolRefs(op: from, callback).wasInterrupted()) |
| 599 | return WalkResult::interrupt(); |
| 600 | |
| 601 | // Only recurse if this operation is not a symbol table. A symbol table |
| 602 | // defines a new scope, so we can't walk the attributes from within the symbol |
| 603 | // table op. |
| 604 | if (!from->hasTrait<OpTrait::SymbolTable>()) |
| 605 | return walkSymbolUses(regions: from->getRegions(), callback); |
| 606 | return WalkResult::advance(); |
| 607 | } |
| 608 | |
| 609 | namespace { |
| 610 | /// This class represents a single symbol scope. A symbol scope represents the |
| 611 | /// set of operations nested within a symbol table that may reference symbols |
| 612 | /// within that table. A symbol scope does not contain the symbol table |
| 613 | /// operation itself, just its contained operations. A scope ends at leaf |
| 614 | /// operations or another symbol table operation. |
| 615 | struct SymbolScope { |
| 616 | /// Walk the symbol uses within this scope, invoking the given callback. |
| 617 | /// This variant is used when the callback type matches that expected by |
| 618 | /// 'walkSymbolUses'. |
| 619 | template <typename CallbackT, |
| 620 | std::enable_if_t<!std::is_same< |
| 621 | typename llvm::function_traits<CallbackT>::result_t, |
| 622 | void>::value> * = nullptr> |
| 623 | std::optional<WalkResult> walk(CallbackT cback) { |
| 624 | if (Region *region = llvm::dyn_cast_if_present<Region *>(Val&: limit)) |
| 625 | return walkSymbolUses(*region, cback); |
| 626 | return walkSymbolUses(cast<Operation *>(Val&: limit), cback); |
| 627 | } |
| 628 | /// This variant is used when the callback type matches a stripped down type: |
| 629 | /// void(SymbolTable::SymbolUse use) |
| 630 | template <typename CallbackT, |
| 631 | std::enable_if_t<std::is_same< |
| 632 | typename llvm::function_traits<CallbackT>::result_t, |
| 633 | void>::value> * = nullptr> |
| 634 | std::optional<WalkResult> walk(CallbackT cback) { |
| 635 | return walk([=](SymbolTable::SymbolUse use) { |
| 636 | return cback(use), WalkResult::advance(); |
| 637 | }); |
| 638 | } |
| 639 | |
| 640 | /// Walk all of the operations nested under the current scope without |
| 641 | /// traversing into any nested symbol tables. |
| 642 | template <typename CallbackT> |
| 643 | std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) { |
| 644 | if (Region *region = llvm::dyn_cast_if_present<Region *>(Val&: limit)) |
| 645 | return ::walkSymbolTable(*region, cback); |
| 646 | return ::walkSymbolTable(cast<Operation *>(Val&: limit), cback); |
| 647 | } |
| 648 | |
| 649 | /// The representation of the symbol within this scope. |
| 650 | SymbolRefAttr symbol; |
| 651 | |
| 652 | /// The IR unit representing this scope. |
| 653 | llvm::PointerUnion<Operation *, Region *> limit; |
| 654 | }; |
| 655 | } // namespace |
| 656 | |
| 657 | /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'. |
| 658 | static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, |
| 659 | Operation *limit) { |
| 660 | StringAttr symName = SymbolTable::getSymbolName(symbol); |
| 661 | assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit); |
| 662 | |
| 663 | // Compute the ancestors of 'limit'. |
| 664 | SetVector<Operation *, SmallVector<Operation *, 4>, |
| 665 | SmallPtrSet<Operation *, 4>> |
| 666 | limitAncestors; |
| 667 | Operation *limitAncestor = limit; |
| 668 | do { |
| 669 | // Check to see if 'symbol' is an ancestor of 'limit'. |
| 670 | if (limitAncestor == symbol) { |
| 671 | // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr |
| 672 | // doesn't support parent references. |
| 673 | if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == |
| 674 | symbol->getParentOp()) |
| 675 | return {{SymbolRefAttr::get(symName), limit}}; |
| 676 | return {}; |
| 677 | } |
| 678 | |
| 679 | limitAncestors.insert(X: limitAncestor); |
| 680 | } while ((limitAncestor = limitAncestor->getParentOp())); |
| 681 | |
| 682 | // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'. |
| 683 | Operation *commonAncestor = symbol->getParentOp(); |
| 684 | do { |
| 685 | if (limitAncestors.count(key: commonAncestor)) |
| 686 | break; |
| 687 | } while ((commonAncestor = commonAncestor->getParentOp())); |
| 688 | assert(commonAncestor && "'limit' and 'symbol' have no common ancestor" ); |
| 689 | |
| 690 | // Compute the set of valid nested references for 'symbol' as far up to the |
| 691 | // common ancestor as possible. |
| 692 | SmallVector<SymbolRefAttr, 2> references; |
| 693 | bool collectedAllReferences = succeeded( |
| 694 | collectValidReferencesFor(symbol, symName, commonAncestor, references)); |
| 695 | |
| 696 | // Handle the case where the common ancestor is 'limit'. |
| 697 | if (commonAncestor == limit) { |
| 698 | SmallVector<SymbolScope, 2> scopes; |
| 699 | |
| 700 | // Walk each of the ancestors of 'symbol', calling the compute function for |
| 701 | // each one. |
| 702 | Operation *limitIt = symbol->getParentOp(); |
| 703 | for (size_t i = 0, e = references.size(); i != e; |
| 704 | ++i, limitIt = limitIt->getParentOp()) { |
| 705 | assert(limitIt->hasTrait<OpTrait::SymbolTable>()); |
| 706 | scopes.push_back(Elt: {references[i], &limitIt->getRegion(index: 0)}); |
| 707 | } |
| 708 | return scopes; |
| 709 | } |
| 710 | |
| 711 | // Otherwise, we just need the symbol reference for 'symbol' that will be |
| 712 | // used within 'limit'. This is the last reference in the list we computed |
| 713 | // above if we were able to collect all references. |
| 714 | if (!collectedAllReferences) |
| 715 | return {}; |
| 716 | return {{references.back(), limit}}; |
| 717 | } |
| 718 | static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, |
| 719 | Region *limit) { |
| 720 | auto scopes = collectSymbolScopes(symbol, limit: limit->getParentOp()); |
| 721 | |
| 722 | // If we collected some scopes to walk, make sure to constrain the one for |
| 723 | // limit to the specific region requested. |
| 724 | if (!scopes.empty()) |
| 725 | scopes.back().limit = limit; |
| 726 | return scopes; |
| 727 | } |
| 728 | static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, |
| 729 | Region *limit) { |
| 730 | return {{SymbolRefAttr::get(symbol), limit}}; |
| 731 | } |
| 732 | |
| 733 | static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, |
| 734 | Operation *limit) { |
| 735 | SmallVector<SymbolScope, 1> scopes; |
| 736 | auto symbolRef = SymbolRefAttr::get(symbol); |
| 737 | for (auto ®ion : limit->getRegions()) |
| 738 | scopes.push_back(Elt: {symbolRef, ®ion}); |
| 739 | return scopes; |
| 740 | } |
| 741 | |
| 742 | /// Returns true if the given reference 'SubRef' is a sub reference of the |
| 743 | /// reference 'ref', i.e. 'ref' is a further qualified reference. |
| 744 | static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) { |
| 745 | if (ref == subRef) |
| 746 | return true; |
| 747 | |
| 748 | // If the references are not pointer equal, check to see if `subRef` is a |
| 749 | // prefix of `ref`. |
| 750 | if (llvm::isa<FlatSymbolRefAttr>(ref) || |
| 751 | ref.getRootReference() != subRef.getRootReference()) |
| 752 | return false; |
| 753 | |
| 754 | auto refLeafs = ref.getNestedReferences(); |
| 755 | auto subRefLeafs = subRef.getNestedReferences(); |
| 756 | return subRefLeafs.size() < refLeafs.size() && |
| 757 | subRefLeafs == refLeafs.take_front(subRefLeafs.size()); |
| 758 | } |
| 759 | |
| 760 | //===----------------------------------------------------------------------===// |
| 761 | // SymbolTable::getSymbolUses |
| 762 | //===----------------------------------------------------------------------===// |
| 763 | |
| 764 | /// The implementation of SymbolTable::getSymbolUses below. |
| 765 | template <typename FromT> |
| 766 | static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) { |
| 767 | std::vector<SymbolTable::SymbolUse> uses; |
| 768 | auto walkFn = [&](SymbolTable::SymbolUse symbolUse) { |
| 769 | uses.push_back(x: symbolUse); |
| 770 | return WalkResult::advance(); |
| 771 | }; |
| 772 | auto result = walkSymbolUses(from, walkFn); |
| 773 | return result ? std::optional<SymbolTable::UseRange>(std::move(uses)) |
| 774 | : std::nullopt; |
| 775 | } |
| 776 | |
| 777 | /// Get an iterator range for all of the uses, for any symbol, that are nested |
| 778 | /// within the given operation 'from'. This does not traverse into any nested |
| 779 | /// symbol tables, and will also only return uses on 'from' if it does not |
| 780 | /// also define a symbol table. This is because we treat the region as the |
| 781 | /// boundary of the symbol table, and not the op itself. This function returns |
| 782 | /// std::nullopt if there are any unknown operations that may potentially be |
| 783 | /// symbol tables. |
| 784 | auto SymbolTable::getSymbolUses(Operation *from) -> std::optional<UseRange> { |
| 785 | return getSymbolUsesImpl(from); |
| 786 | } |
| 787 | auto SymbolTable::getSymbolUses(Region *from) -> std::optional<UseRange> { |
| 788 | return getSymbolUsesImpl(from: MutableArrayRef<Region>(*from)); |
| 789 | } |
| 790 | |
| 791 | //===----------------------------------------------------------------------===// |
// SymbolTable::getSymbolUses (uses of a specific symbol)
| 793 | //===----------------------------------------------------------------------===// |
| 794 | |
| 795 | /// The implementation of SymbolTable::getSymbolUses below. |
| 796 | template <typename SymbolT, typename IRUnitT> |
| 797 | static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, |
| 798 | IRUnitT *limit) { |
| 799 | std::vector<SymbolTable::SymbolUse> uses; |
| 800 | for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { |
| 801 | if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) { |
| 802 | if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())) |
| 803 | uses.push_back(x: symbolUse); |
| 804 | })) |
| 805 | return std::nullopt; |
| 806 | } |
| 807 | return SymbolTable::UseRange(std::move(uses)); |
| 808 | } |
| 809 | |
| 810 | /// Get all of the uses of the given symbol that are nested within the given |
| 811 | /// operation 'from'. This does not traverse into any nested symbol tables. |
| 812 | /// This function returns std::nullopt if there are any unknown operations that |
| 813 | /// may potentially be symbol tables. |
| 814 | auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from) |
| 815 | -> std::optional<UseRange> { |
| 816 | return getSymbolUsesImpl(symbol, from); |
| 817 | } |
| 818 | auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from) |
| 819 | -> std::optional<UseRange> { |
| 820 | return getSymbolUsesImpl(symbol, limit: from); |
| 821 | } |
| 822 | auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from) |
| 823 | -> std::optional<UseRange> { |
| 824 | return getSymbolUsesImpl(symbol, from); |
| 825 | } |
| 826 | auto SymbolTable::getSymbolUses(Operation *symbol, Region *from) |
| 827 | -> std::optional<UseRange> { |
| 828 | return getSymbolUsesImpl(symbol, limit: from); |
| 829 | } |
| 830 | |
| 831 | //===----------------------------------------------------------------------===// |
| 832 | // SymbolTable::symbolKnownUseEmpty |
| 833 | //===----------------------------------------------------------------------===// |
| 834 | |
| 835 | /// The implementation of SymbolTable::symbolKnownUseEmpty below. |
| 836 | template <typename SymbolT, typename IRUnitT> |
| 837 | static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) { |
| 838 | for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { |
| 839 | // Walk all of the symbol uses looking for a reference to 'symbol'. |
| 840 | if (scope.walk([&](SymbolTable::SymbolUse symbolUse) { |
| 841 | return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) |
| 842 | ? WalkResult::interrupt() |
| 843 | : WalkResult::advance(); |
| 844 | }) != WalkResult::advance()) |
| 845 | return false; |
| 846 | } |
| 847 | return true; |
| 848 | } |
| 849 | |
| 850 | /// Return if the given symbol is known to have no uses that are nested within |
| 851 | /// the given operation 'from'. This does not traverse into any nested symbol |
| 852 | /// tables. This function will also return false if there are any unknown |
| 853 | /// operations that may potentially be symbol tables. |
| 854 | bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) { |
| 855 | return symbolKnownUseEmptyImpl(symbol, from); |
| 856 | } |
| 857 | bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { |
| 858 | return symbolKnownUseEmptyImpl(symbol, limit: from); |
| 859 | } |
| 860 | bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) { |
| 861 | return symbolKnownUseEmptyImpl(symbol, from); |
| 862 | } |
| 863 | bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) { |
| 864 | return symbolKnownUseEmptyImpl(symbol, limit: from); |
| 865 | } |
| 866 | |
| 867 | //===----------------------------------------------------------------------===// |
| 868 | // SymbolTable::replaceAllSymbolUses |
| 869 | //===----------------------------------------------------------------------===// |
| 870 | |
| 871 | /// Generates a new symbol reference attribute with a new leaf reference. |
| 872 | static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, |
| 873 | FlatSymbolRefAttr newLeafAttr) { |
| 874 | if (llvm::isa<FlatSymbolRefAttr>(oldAttr)) |
| 875 | return newLeafAttr; |
| 876 | auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); |
| 877 | nestedRefs.back() = newLeafAttr; |
| 878 | return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs); |
| 879 | } |
| 880 | |
| 881 | /// The implementation of SymbolTable::replaceAllSymbolUses below. |
| 882 | template <typename SymbolT, typename IRUnitT> |
| 883 | static LogicalResult |
| 884 | replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { |
| 885 | // Generate a new attribute to replace the given attribute. |
| 886 | FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol); |
| 887 | for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { |
| 888 | SymbolRefAttr oldAttr = scope.symbol; |
| 889 | SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); |
| 890 | AttrTypeReplacer replacer; |
| 891 | replacer.addReplacement( |
| 892 | [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> { |
| 893 | // Regardless of the match, don't walk nested SymbolRefAttrs, we don't |
| 894 | // want to accidentally replace an inner reference. |
| 895 | if (attr == oldAttr) |
| 896 | return {newAttr, WalkResult::skip()}; |
| 897 | // Handle prefix matches. |
| 898 | if (isReferencePrefixOf(oldAttr, attr)) { |
| 899 | auto oldNestedRefs = oldAttr.getNestedReferences(); |
| 900 | auto nestedRefs = attr.getNestedReferences(); |
| 901 | if (oldNestedRefs.empty()) |
| 902 | return {SymbolRefAttr::get(newSymbol, nestedRefs), |
| 903 | WalkResult::skip()}; |
| 904 | |
| 905 | auto newNestedRefs = llvm::to_vector<4>(nestedRefs); |
| 906 | newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr; |
| 907 | return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs), |
| 908 | WalkResult::skip()}; |
| 909 | } |
| 910 | return {attr, WalkResult::skip()}; |
| 911 | }); |
| 912 | |
| 913 | auto walkFn = [&](Operation *op) -> std::optional<WalkResult> { |
| 914 | replacer.replaceElementsIn(op); |
| 915 | return WalkResult::advance(); |
| 916 | }; |
| 917 | if (!scope.walkSymbolTable(walkFn)) |
| 918 | return failure(); |
| 919 | } |
| 920 | return success(); |
| 921 | } |
| 922 | |
| 923 | /// Attempt to replace all uses of the given symbol 'oldSymbol' with the |
| 924 | /// provided symbol 'newSymbol' that are nested within the given operation |
| 925 | /// 'from'. This does not traverse into any nested symbol tables. If there are |
| 926 | /// any unknown operations that may potentially be symbol tables, no uses are |
| 927 | /// replaced and failure is returned. |
| 928 | LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, |
| 929 | StringAttr newSymbol, |
| 930 | Operation *from) { |
| 931 | return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); |
| 932 | } |
| 933 | LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, |
| 934 | StringAttr newSymbol, |
| 935 | Operation *from) { |
| 936 | return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); |
| 937 | } |
| 938 | LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, |
| 939 | StringAttr newSymbol, |
| 940 | Region *from) { |
| 941 | return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); |
| 942 | } |
| 943 | LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, |
| 944 | StringAttr newSymbol, |
| 945 | Region *from) { |
| 946 | return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); |
| 947 | } |
| 948 | |
| 949 | //===----------------------------------------------------------------------===// |
| 950 | // SymbolTableCollection |
| 951 | //===----------------------------------------------------------------------===// |
| 952 | |
| 953 | Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
| 954 | StringAttr symbol) { |
| 955 | return getSymbolTable(op: symbolTableOp).lookup(symbol); |
| 956 | } |
| 957 | Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
| 958 | SymbolRefAttr name) { |
| 959 | SmallVector<Operation *, 4> symbols; |
| 960 | if (failed(lookupSymbolIn(symbolTableOp, name, symbols))) |
| 961 | return nullptr; |
| 962 | return symbols.back(); |
| 963 | } |
| 964 | /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by |
| 965 | /// a given SymbolRefAttr. Returns failure if any of the nested references could |
| 966 | /// not be resolved. |
| 967 | LogicalResult |
| 968 | SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
| 969 | SymbolRefAttr name, |
| 970 | SmallVectorImpl<Operation *> &symbols) { |
| 971 | auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) { |
| 972 | return lookupSymbolIn(symbolTableOp, symbol); |
| 973 | }; |
| 974 | return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn); |
| 975 | } |
| 976 | |
| 977 | /// Returns the operation registered with the given symbol name within the |
| 978 | /// closest parent operation of, or including, 'from' with the |
| 979 | /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was |
| 980 | /// found. |
| 981 | Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, |
| 982 | StringAttr symbol) { |
| 983 | Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); |
| 984 | return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; |
| 985 | } |
| 986 | Operation * |
| 987 | SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, |
| 988 | SymbolRefAttr symbol) { |
| 989 | Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); |
| 990 | return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; |
| 991 | } |
| 992 | |
| 993 | /// Lookup, or create, a symbol table for an operation. |
| 994 | SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) { |
| 995 | auto it = symbolTables.try_emplace(Key: op, Args: nullptr); |
| 996 | if (it.second) |
| 997 | it.first->second = std::make_unique<SymbolTable>(args&: op); |
| 998 | return *it.first->second; |
| 999 | } |
| 1000 | |
| 1001 | void SymbolTableCollection::invalidateSymbolTable(Operation *op) { |
| 1002 | symbolTables.erase(Val: op); |
| 1003 | } |
| 1004 | |
| 1005 | //===----------------------------------------------------------------------===// |
| 1006 | // LockedSymbolTableCollection |
| 1007 | //===----------------------------------------------------------------------===// |
| 1008 | |
| 1009 | Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
| 1010 | StringAttr symbol) { |
| 1011 | return getSymbolTable(symbolTableOp).lookup(symbol); |
| 1012 | } |
| 1013 | |
| 1014 | Operation * |
| 1015 | LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
| 1016 | FlatSymbolRefAttr symbol) { |
| 1017 | return lookupSymbolIn(symbolTableOp, symbol.getAttr()); |
| 1018 | } |
| 1019 | |
| 1020 | Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, |
| 1021 | SymbolRefAttr name) { |
| 1022 | SmallVector<Operation *> symbols; |
| 1023 | if (failed(lookupSymbolIn(symbolTableOp, name, symbols))) |
| 1024 | return nullptr; |
| 1025 | return symbols.back(); |
| 1026 | } |
| 1027 | |
| 1028 | LogicalResult LockedSymbolTableCollection::lookupSymbolIn( |
| 1029 | Operation *symbolTableOp, SymbolRefAttr name, |
| 1030 | SmallVectorImpl<Operation *> &symbols) { |
| 1031 | auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) { |
| 1032 | return lookupSymbolIn(symbolTableOp, symbol); |
| 1033 | }; |
| 1034 | return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn); |
| 1035 | } |
| 1036 | |
| 1037 | SymbolTable & |
| 1038 | LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) { |
| 1039 | assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); |
| 1040 | // Try to find an existing symbol table. |
| 1041 | { |
| 1042 | llvm::sys::SmartScopedReader<true> lock(mutex); |
| 1043 | auto it = collection.symbolTables.find(Val: symbolTableOp); |
| 1044 | if (it != collection.symbolTables.end()) |
| 1045 | return *it->second; |
| 1046 | } |
| 1047 | // Create a symbol table for the operation. Perform construction outside of |
| 1048 | // the critical section. |
| 1049 | auto symbolTable = std::make_unique<SymbolTable>(args&: symbolTableOp); |
| 1050 | // Insert the constructed symbol table. |
| 1051 | llvm::sys::SmartScopedWriter<true> lock(mutex); |
| 1052 | return *collection.symbolTables |
| 1053 | .insert(KV: {symbolTableOp, std::move(symbolTable)}) |
| 1054 | .first->second; |
| 1055 | } |
| 1056 | |
| 1057 | //===----------------------------------------------------------------------===// |
| 1058 | // SymbolUserMap |
| 1059 | //===----------------------------------------------------------------------===// |
| 1060 | |
| 1061 | SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable, |
| 1062 | Operation *symbolTableOp) |
| 1063 | : symbolTable(symbolTable) { |
| 1064 | // Walk each of the symbol tables looking for discardable callgraph nodes. |
| 1065 | SmallVector<Operation *> symbols; |
| 1066 | auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) { |
| 1067 | for (Operation &nestedOp : symbolTableOp->getRegion(index: 0).getOps()) { |
| 1068 | auto symbolUses = SymbolTable::getSymbolUses(from: &nestedOp); |
| 1069 | assert(symbolUses && "expected uses to be valid" ); |
| 1070 | |
| 1071 | for (const SymbolTable::SymbolUse &use : *symbolUses) { |
| 1072 | symbols.clear(); |
| 1073 | (void)symbolTable.lookupSymbolIn(symbolTableOp, name: use.getSymbolRef(), |
| 1074 | symbols); |
| 1075 | for (Operation *symbolOp : symbols) |
| 1076 | symbolToUsers[symbolOp].insert(X: use.getUser()); |
| 1077 | } |
| 1078 | } |
| 1079 | }; |
| 1080 | // We just set `allSymUsesVisible` to false here because it isn't necessary |
| 1081 | // for building the user map. |
| 1082 | SymbolTable::walkSymbolTables(op: symbolTableOp, /*allSymUsesVisible=*/false, |
| 1083 | callback: walkFn); |
| 1084 | } |
| 1085 | |
| 1086 | void SymbolUserMap::replaceAllUsesWith(Operation *symbol, |
| 1087 | StringAttr newSymbolName) { |
| 1088 | auto it = symbolToUsers.find(Val: symbol); |
| 1089 | if (it == symbolToUsers.end()) |
| 1090 | return; |
| 1091 | |
| 1092 | // Replace the uses within the users of `symbol`. |
| 1093 | for (Operation *user : it->second) |
| 1094 | (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user); |
| 1095 | |
| 1096 | // Move the current users of `symbol` to the new symbol if it is in the |
| 1097 | // symbol table. |
| 1098 | Operation *newSymbol = |
| 1099 | symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName); |
| 1100 | if (newSymbol != symbol) { |
| 1101 | // Transfer over the users to the new symbol. The reference to the old one |
| 1102 | // is fetched again as the iterator is invalidated during the insertion. |
| 1103 | auto newIt = symbolToUsers.try_emplace(Key: newSymbol); |
| 1104 | auto oldIt = symbolToUsers.find(Val: symbol); |
| 1105 | assert(oldIt != symbolToUsers.end() && "missing old users list" ); |
| 1106 | if (newIt.second) |
| 1107 | newIt.first->second = std::move(oldIt->second); |
| 1108 | else |
| 1109 | newIt.first->second.set_union(oldIt->second); |
| 1110 | symbolToUsers.erase(I: oldIt); |
| 1111 | } |
| 1112 | } |
| 1113 | |
| 1114 | //===----------------------------------------------------------------------===// |
| 1115 | // Visibility parsing implementation. |
| 1116 | //===----------------------------------------------------------------------===// |
| 1117 | |
| 1118 | ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser, |
| 1119 | NamedAttrList &attrs) { |
| 1120 | StringRef visibility; |
| 1121 | if (parser.parseOptionalKeyword(keyword: &visibility, allowedValues: {"public" , "private" , "nested" })) |
| 1122 | return failure(); |
| 1123 | |
| 1124 | StringAttr visibilityAttr = parser.getBuilder().getStringAttr(visibility); |
| 1125 | attrs.push_back(newAttribute: parser.getBuilder().getNamedAttr( |
| 1126 | name: SymbolTable::getVisibilityAttrName(), val: visibilityAttr)); |
| 1127 | return success(); |
| 1128 | } |
| 1129 | |
| 1130 | //===----------------------------------------------------------------------===// |
| 1131 | // Symbol Interfaces |
| 1132 | //===----------------------------------------------------------------------===// |
| 1133 | |
| 1134 | /// Include the generated symbol interfaces. |
| 1135 | #include "mlir/IR/SymbolInterfaces.cpp.inc" |
| 1136 | |