| 1 | //===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "IRNumbering.h" |
| 10 | #include "mlir/Bytecode/BytecodeImplementation.h" |
| 11 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
| 12 | #include "mlir/Bytecode/BytecodeWriter.h" |
| 13 | #include "mlir/Bytecode/Encoding.h" |
| 14 | #include "mlir/IR/AsmState.h" |
| 15 | #include "mlir/IR/BuiltinTypes.h" |
| 16 | #include "mlir/IR/OpDefinition.h" |
| 17 | |
| 18 | using namespace mlir; |
| 19 | using namespace mlir::bytecode::detail; |
| 20 | |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | // NumberingDialectWriter |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | |
| 25 | struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { |
| 26 | NumberingDialectWriter( |
| 27 | IRNumberingState &state, |
| 28 | llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap) |
| 29 | : state(state), dialectVersionMap(dialectVersionMap) {} |
| 30 | |
| 31 | void writeAttribute(Attribute attr) override { state.number(attr); } |
| 32 | void writeOptionalAttribute(Attribute attr) override { |
| 33 | if (attr) |
| 34 | state.number(attr); |
| 35 | } |
| 36 | void writeType(Type type) override { state.number(type); } |
| 37 | void writeResourceHandle(const AsmDialectResourceHandle &resource) override { |
| 38 | state.number(dialect: resource.getDialect(), resources: resource); |
| 39 | } |
| 40 | |
| 41 | /// Stubbed out methods that are not used for numbering. |
| 42 | void writeVarInt(uint64_t) override {} |
| 43 | void writeSignedVarInt(int64_t value) override {} |
| 44 | void writeAPIntWithKnownWidth(const APInt &value) override {} |
| 45 | void writeAPFloatWithKnownSemantics(const APFloat &value) override {} |
| 46 | void writeOwnedString(StringRef) override { |
| 47 | // TODO: It might be nice to prenumber strings and sort by the number of |
| 48 | // references. This could potentially be useful for optimizing things like |
| 49 | // file locations. |
| 50 | } |
| 51 | void writeOwnedBlob(ArrayRef<char> blob) override {} |
| 52 | void writeOwnedBool(bool value) override {} |
| 53 | |
| 54 | int64_t getBytecodeVersion() const override { |
| 55 | return state.getDesiredBytecodeVersion(); |
| 56 | } |
| 57 | |
| 58 | FailureOr<const DialectVersion *> |
| 59 | getDialectVersion(StringRef dialectName) const override { |
| 60 | auto dialectEntry = dialectVersionMap.find(Key: dialectName); |
| 61 | if (dialectEntry == dialectVersionMap.end()) |
| 62 | return failure(); |
| 63 | return dialectEntry->getValue().get(); |
| 64 | } |
| 65 | |
| 66 | /// The parent numbering state that is populated by this writer. |
| 67 | IRNumberingState &state; |
| 68 | |
| 69 | /// A map containing dialect version information for each dialect to emit. |
| 70 | llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap; |
| 71 | }; |
| 72 | |
| 73 | //===----------------------------------------------------------------------===// |
| 74 | // IR Numbering |
| 75 | //===----------------------------------------------------------------------===// |
| 76 | |
| 77 | /// Group and sort the elements of the given range by their parent dialect. This |
| 78 | /// grouping is applied to sub-sections of the ranged defined by how many bytes |
| 79 | /// it takes to encode a varint index to that sub-section. |
| 80 | template <typename T> |
| 81 | static void groupByDialectPerByte(T range) { |
| 82 | if (range.empty()) |
| 83 | return; |
| 84 | |
| 85 | // A functor used to sort by a given dialect, with a desired dialect to be |
| 86 | // ordered first (to better enable sharing of dialects across byte groups). |
| 87 | auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs, |
| 88 | const auto &rhs) { |
| 89 | if (lhs->dialect->number == dialectToOrderFirst) |
| 90 | return rhs->dialect->number != dialectToOrderFirst; |
| 91 | if (rhs->dialect->number == dialectToOrderFirst) |
| 92 | return false; |
| 93 | return lhs->dialect->number < rhs->dialect->number; |
| 94 | }; |
| 95 | |
| 96 | unsigned dialectToOrderFirst = 0; |
| 97 | size_t elementsInByteGroup = 0; |
| 98 | auto iterRange = range; |
| 99 | for (unsigned i = 1; i < 9; ++i) { |
| 100 | // Update the number of elements in the current byte grouping. Reminder |
| 101 | // that varint encodes 7-bits per byte, so that's how we compute the |
| 102 | // number of elements in each byte grouping. |
| 103 | elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup; |
| 104 | |
| 105 | // Slice out the sub-set of elements that are in the current byte grouping |
| 106 | // to be sorted. |
| 107 | auto byteSubRange = iterRange.take_front(elementsInByteGroup); |
| 108 | iterRange = iterRange.drop_front(byteSubRange.size()); |
| 109 | |
| 110 | // Sort the sub range for this byte. |
| 111 | llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) { |
| 112 | return sortByDialect(dialectToOrderFirst, lhs, rhs); |
| 113 | }); |
| 114 | |
| 115 | // Update the dialect to order first to be the dialect at the end of the |
| 116 | // current grouping. This seeks to allow larger dialect groupings across |
| 117 | // byte boundaries. |
| 118 | dialectToOrderFirst = byteSubRange.back()->dialect->number; |
| 119 | |
| 120 | // If the data range is now empty, we are done. |
| 121 | if (iterRange.empty()) |
| 122 | break; |
| 123 | } |
| 124 | |
| 125 | // Assign the entry numbers based on the sort order. |
| 126 | for (auto [idx, value] : llvm::enumerate(range)) |
| 127 | value->number = idx; |
| 128 | } |
| 129 | |
| 130 | IRNumberingState::IRNumberingState(Operation *op, |
| 131 | const BytecodeWriterConfig &config) |
| 132 | : config(config) { |
| 133 | computeGlobalNumberingState(rootOp: op); |
| 134 | |
| 135 | // Number the root operation. |
| 136 | number(op&: *op); |
| 137 | |
| 138 | // A worklist of region contexts to number and the next value id before that |
| 139 | // region. |
| 140 | SmallVector<std::pair<Region *, unsigned>, 8> numberContext; |
| 141 | |
| 142 | // Functor to push the regions of the given operation onto the numbering |
| 143 | // context. |
| 144 | auto addOpRegionsToNumber = [&](Operation *op) { |
| 145 | MutableArrayRef<Region> regions = op->getRegions(); |
| 146 | if (regions.empty()) |
| 147 | return; |
| 148 | |
| 149 | // Isolated regions don't share value numbers with their parent, so we can |
| 150 | // start numbering these regions at zero. |
| 151 | unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID; |
| 152 | for (Region ®ion : regions) |
| 153 | numberContext.emplace_back(Args: ®ion, Args&: opFirstValueID); |
| 154 | }; |
| 155 | addOpRegionsToNumber(op); |
| 156 | |
| 157 | // Iteratively process each of the nested regions. |
| 158 | while (!numberContext.empty()) { |
| 159 | Region *region; |
| 160 | std::tie(args&: region, args&: nextValueID) = numberContext.pop_back_val(); |
| 161 | number(region&: *region); |
| 162 | |
| 163 | // Traverse into nested regions. |
| 164 | for (Operation &op : region->getOps()) |
| 165 | addOpRegionsToNumber(&op); |
| 166 | } |
| 167 | |
| 168 | // Number each of the dialects. For now this is just in the order they were |
| 169 | // found, given that the number of dialects on average is small enough to fit |
| 170 | // within a singly byte (128). If we ever have real world use cases that have |
| 171 | // a huge number of dialects, this could be made more intelligent. |
| 172 | for (auto [idx, dialect] : llvm::enumerate(First&: dialects)) |
| 173 | dialect.second->number = idx; |
| 174 | |
| 175 | // Number each of the recorded components within each dialect. |
| 176 | |
| 177 | // First sort by ref count so that the most referenced elements are first. We |
| 178 | // try to bias more heavily used elements to the front. This allows for more |
| 179 | // frequently referenced things to be encoded using smaller varints. |
| 180 | auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) { |
| 181 | return lhs->refCount > rhs->refCount; |
| 182 | }; |
| 183 | llvm::stable_sort(Range&: orderedAttrs, C: sortByRefCountFn); |
| 184 | llvm::stable_sort(Range&: orderedOpNames, C: sortByRefCountFn); |
| 185 | llvm::stable_sort(Range&: orderedTypes, C: sortByRefCountFn); |
| 186 | |
| 187 | // After that, we apply a secondary ordering based on the parent dialect. This |
| 188 | // ordering is applied to sub-sections of the element list defined by how many |
| 189 | // bytes it takes to encode a varint index to that sub-section. This allows |
| 190 | // for more efficiently encoding components of the same dialect (e.g. we only |
| 191 | // have to encode the dialect reference once). |
| 192 | groupByDialectPerByte(range: llvm::MutableArrayRef(orderedAttrs)); |
| 193 | groupByDialectPerByte(range: llvm::MutableArrayRef(orderedOpNames)); |
| 194 | groupByDialectPerByte(range: llvm::MutableArrayRef(orderedTypes)); |
| 195 | |
| 196 | // Finalize the numbering of the dialect resources. |
| 197 | finalizeDialectResourceNumberings(rootOp: op); |
| 198 | } |
| 199 | |
| 200 | void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) { |
| 201 | // A simple state struct tracking data used when walking operations. |
| 202 | struct StackState { |
| 203 | /// The operation currently being walked. |
| 204 | Operation *op; |
| 205 | |
| 206 | /// The numbering of the operation. |
| 207 | OperationNumbering *numbering; |
| 208 | |
| 209 | /// A flag indicating if the current state or one of its parents has |
| 210 | /// unresolved isolation status. This is tracked separately from the |
| 211 | /// isIsolatedFromAbove bit on `numbering` because we need to be able to |
| 212 | /// handle the given case: |
| 213 | /// top.op { |
| 214 | /// %value = ... |
| 215 | /// middle.op { |
| 216 | /// %value2 = ... |
| 217 | /// inner.op { |
| 218 | /// // Here we mark `inner.op` as not isolated. Note `middle.op` |
| 219 | /// // isn't known not isolated yet. |
| 220 | /// use.op %value2 |
| 221 | /// |
| 222 | /// // Here inner.op is already known to be non-isolated, but |
| 223 | /// // `middle.op` is now also discovered to be non-isolated. |
| 224 | /// use.op %value |
| 225 | /// } |
| 226 | /// } |
| 227 | /// } |
| 228 | bool hasUnresolvedIsolation; |
| 229 | }; |
| 230 | |
| 231 | // Compute a global operation ID numbering according to the pre-order walk of |
| 232 | // the IR. This is used as reference to construct use-list orders. |
| 233 | unsigned operationID = 0; |
| 234 | |
| 235 | // Walk each of the operations within the IR, tracking a stack of operations |
| 236 | // as we recurse into nested regions. This walk method hooks in at two stages |
| 237 | // during the walk: |
| 238 | // |
| 239 | // BeforeAllRegions: |
| 240 | // Here we generate a numbering for the operation and push it onto the |
| 241 | // stack if it has regions. We also compute the isolation status of parent |
| 242 | // regions at this stage. This is done by checking the parent regions of |
| 243 | // operands used by the operation, and marking each region between the |
| 244 | // the operand region and the current as not isolated. See |
| 245 | // StackState::hasUnresolvedIsolation above for an example. |
| 246 | // |
| 247 | // AfterAllRegions: |
| 248 | // Here we pop the operation from the stack, and if it hasn't been marked |
| 249 | // as non-isolated, we mark it as so. A non-isolated use would have been |
| 250 | // found while walking the regions, so it is safe to mark the operation at |
| 251 | // this point. |
| 252 | // |
| 253 | SmallVector<StackState> opStack; |
| 254 | rootOp->walk(callback: [&](Operation *op, const WalkStage &stage) { |
| 255 | // After visiting all nested regions, we pop the operation from the stack. |
| 256 | if (op->getNumRegions() && stage.isAfterAllRegions()) { |
| 257 | // If no non-isolated uses were found, we can safely mark this operation |
| 258 | // as isolated from above. |
| 259 | OperationNumbering *numbering = opStack.pop_back_val().numbering; |
| 260 | if (!numbering->isIsolatedFromAbove.has_value()) |
| 261 | numbering->isIsolatedFromAbove = true; |
| 262 | return; |
| 263 | } |
| 264 | |
| 265 | // When visiting before nested regions, we process "IsolatedFromAbove" |
| 266 | // checks and compute the number for this operation. |
| 267 | if (!stage.isBeforeAllRegions()) |
| 268 | return; |
| 269 | // Update the isolation status of parent regions if any have yet to be |
| 270 | // resolved. |
| 271 | if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) { |
| 272 | Region *parentRegion = op->getParentRegion(); |
| 273 | for (Value operand : op->getOperands()) { |
| 274 | Region *operandRegion = operand.getParentRegion(); |
| 275 | if (operandRegion == parentRegion) |
| 276 | continue; |
| 277 | // We've found a use of an operand outside of the current region, |
| 278 | // walk the operation stack searching for the parent operation, |
| 279 | // marking every region on the way as not isolated. |
| 280 | Operation *operandContainerOp = operandRegion->getParentOp(); |
| 281 | auto it = std::find_if( |
| 282 | first: opStack.rbegin(), last: opStack.rend(), pred: [=](const StackState &it) { |
| 283 | // We only need to mark up to the container region, or the first |
| 284 | // that has an unresolved status. |
| 285 | return !it.hasUnresolvedIsolation || it.op == operandContainerOp; |
| 286 | }); |
| 287 | assert(it != opStack.rend() && "expected to find the container" ); |
| 288 | for (auto &state : llvm::make_range(x: opStack.rbegin(), y: it)) { |
| 289 | // If we stopped at a region that knows its isolation status, we can |
| 290 | // stop updating the isolation status for the parent regions. |
| 291 | state.hasUnresolvedIsolation = it->hasUnresolvedIsolation; |
| 292 | state.numbering->isIsolatedFromAbove = false; |
| 293 | } |
| 294 | } |
| 295 | } |
| 296 | |
| 297 | // Compute the number for this op and push it onto the stack. |
| 298 | auto *numbering = |
| 299 | new (opAllocator.Allocate()) OperationNumbering(operationID++); |
| 300 | if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
| 301 | numbering->isIsolatedFromAbove = true; |
| 302 | operations.try_emplace(Key: op, Args&: numbering); |
| 303 | if (op->getNumRegions()) { |
| 304 | opStack.emplace_back(Args: StackState{ |
| 305 | .op: op, .numbering: numbering, .hasUnresolvedIsolation: !numbering->isIsolatedFromAbove.has_value()}); |
| 306 | } |
| 307 | }); |
| 308 | } |
| 309 | |
| 310 | void IRNumberingState::number(Attribute attr) { |
| 311 | auto it = attrs.try_emplace(Key: attr); |
| 312 | if (!it.second) { |
| 313 | ++it.first->second->refCount; |
| 314 | return; |
| 315 | } |
| 316 | auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr); |
| 317 | it.first->second = numbering; |
| 318 | orderedAttrs.push_back(x: numbering); |
| 319 | |
| 320 | // Check for OpaqueAttr, which is a dialect-specific attribute that didn't |
| 321 | // have a registered dialect when it got created. We don't want to encode this |
| 322 | // as the builtin OpaqueAttr, we want to encode it as if the dialect was |
| 323 | // actually loaded. |
| 324 | if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) { |
| 325 | numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace()); |
| 326 | return; |
| 327 | } |
| 328 | numbering->dialect = &numberDialect(dialect: &attr.getDialect()); |
| 329 | |
| 330 | // If this attribute will be emitted using the bytecode format, perform a |
| 331 | // dummy writing to number any nested components. |
| 332 | // TODO: We don't allow custom encodings for mutable attributes right now. |
| 333 | if (!attr.hasTrait<AttributeTrait::IsMutable>()) { |
| 334 | // Try overriding emission with callbacks. |
| 335 | for (const auto &callback : config.getAttributeWriterCallbacks()) { |
| 336 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
| 337 | // The client has the ability to override the group name through the |
| 338 | // callback. |
| 339 | std::optional<StringRef> groupNameOverride; |
| 340 | if (succeeded(Result: callback->write(entry: attr, name&: groupNameOverride, writer))) { |
| 341 | if (groupNameOverride.has_value()) |
| 342 | numbering->dialect = &numberDialect(dialect: *groupNameOverride); |
| 343 | return; |
| 344 | } |
| 345 | } |
| 346 | |
| 347 | if (const auto *interface = numbering->dialect->interface) { |
| 348 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
| 349 | if (succeeded(Result: interface->writeAttribute(attr, writer))) |
| 350 | return; |
| 351 | } |
| 352 | } |
| 353 | // If this attribute will be emitted using the fallback, number the nested |
| 354 | // dialect resources. We don't number everything (e.g. no nested |
| 355 | // attributes/types), because we don't want to encode things we won't decode |
| 356 | // (the textual format can't really share much). |
| 357 | AsmState tempState(attr.getContext()); |
| 358 | llvm::raw_null_ostream dummyOS; |
| 359 | attr.print(os&: dummyOS, state&: tempState); |
| 360 | |
| 361 | // Number the used dialect resources. |
| 362 | for (const auto &it : tempState.getDialectResources()) |
| 363 | number(dialect: it.getFirst(), resources: it.getSecond().getArrayRef()); |
| 364 | } |
| 365 | |
| 366 | void IRNumberingState::number(Block &block) { |
| 367 | // Number the arguments of the block. |
| 368 | for (BlockArgument arg : block.getArguments()) { |
| 369 | valueIDs.try_emplace(Key: arg, Args: nextValueID++); |
| 370 | number(attr: arg.getLoc()); |
| 371 | number(type: arg.getType()); |
| 372 | } |
| 373 | |
| 374 | // Number the operations in this block. |
| 375 | unsigned &numOps = blockOperationCounts[&block]; |
| 376 | for (Operation &op : block) { |
| 377 | number(op); |
| 378 | ++numOps; |
| 379 | } |
| 380 | } |
| 381 | |
| 382 | auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & { |
| 383 | DialectNumbering *&numbering = registeredDialects[dialect]; |
| 384 | if (!numbering) { |
| 385 | numbering = &numberDialect(dialect: dialect->getNamespace()); |
| 386 | numbering->interface = dyn_cast<BytecodeDialectInterface>(Val: dialect); |
| 387 | numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(Val: dialect); |
| 388 | } |
| 389 | return *numbering; |
| 390 | } |
| 391 | |
| 392 | auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & { |
| 393 | DialectNumbering *&numbering = dialects[dialect]; |
| 394 | if (!numbering) { |
| 395 | numbering = new (dialectAllocator.Allocate()) |
| 396 | DialectNumbering(dialect, dialects.size() - 1); |
| 397 | } |
| 398 | return *numbering; |
| 399 | } |
| 400 | |
| 401 | void IRNumberingState::number(Region ®ion) { |
| 402 | if (region.empty()) |
| 403 | return; |
| 404 | size_t firstValueID = nextValueID; |
| 405 | |
| 406 | // Number the blocks within this region. |
| 407 | size_t blockCount = 0; |
| 408 | for (auto it : llvm::enumerate(First&: region)) { |
| 409 | blockIDs.try_emplace(Key: &it.value(), Args: it.index()); |
| 410 | number(block&: it.value()); |
| 411 | ++blockCount; |
| 412 | } |
| 413 | |
| 414 | // Remember the number of blocks and values in this region. |
| 415 | regionBlockValueCounts.try_emplace(Key: ®ion, Args&: blockCount, |
| 416 | Args: nextValueID - firstValueID); |
| 417 | } |
| 418 | |
| 419 | void IRNumberingState::number(Operation &op) { |
| 420 | // Number the components of an operation that won't be numbered elsewhere |
| 421 | // (e.g. we don't number operands, regions, or successors here). |
| 422 | number(opName: op.getName()); |
| 423 | for (OpResult result : op.getResults()) { |
| 424 | valueIDs.try_emplace(Key: result, Args: nextValueID++); |
| 425 | number(type: result.getType()); |
| 426 | } |
| 427 | |
| 428 | // Prior to a version with native property encoding, or when properties are |
| 429 | // not used, we need to number also the merged dictionary containing both the |
| 430 | // inherent and discardable attribute. |
| 431 | DictionaryAttr dictAttr; |
| 432 | if (config.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding) |
| 433 | dictAttr = op.getRawDictionaryAttrs(); |
| 434 | else |
| 435 | dictAttr = op.getAttrDictionary(); |
| 436 | // Only number the operation's dictionary if it isn't empty. |
| 437 | if (!dictAttr.empty()) |
| 438 | number(dictAttr); |
| 439 | |
| 440 | // Visit the operation properties (if any) to make sure referenced attributes |
| 441 | // are numbered. |
| 442 | if (config.getDesiredBytecodeVersion() >= |
| 443 | bytecode::kNativePropertiesEncoding && |
| 444 | op.getPropertiesStorageSize()) { |
| 445 | if (op.isRegistered()) { |
| 446 | // Operation that have properties *must* implement this interface. |
| 447 | auto iface = cast<BytecodeOpInterface>(op); |
| 448 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
| 449 | iface.writeProperties(writer); |
| 450 | } else { |
| 451 | // Unregistered op are storing properties as an optional attribute. |
| 452 | if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>()) |
| 453 | number(attr: prop); |
| 454 | } |
| 455 | } |
| 456 | |
| 457 | number(attr: op.getLoc()); |
| 458 | } |
| 459 | |
| 460 | void IRNumberingState::number(OperationName opName) { |
| 461 | OpNameNumbering *&numbering = opNames[opName]; |
| 462 | if (numbering) { |
| 463 | ++numbering->refCount; |
| 464 | return; |
| 465 | } |
| 466 | DialectNumbering *dialectNumber = nullptr; |
| 467 | if (Dialect *dialect = opName.getDialect()) |
| 468 | dialectNumber = &numberDialect(dialect); |
| 469 | else |
| 470 | dialectNumber = &numberDialect(dialect: opName.getDialectNamespace()); |
| 471 | |
| 472 | numbering = |
| 473 | new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName); |
| 474 | orderedOpNames.push_back(x: numbering); |
| 475 | } |
| 476 | |
| 477 | void IRNumberingState::number(Type type) { |
| 478 | auto it = types.try_emplace(Key: type); |
| 479 | if (!it.second) { |
| 480 | ++it.first->second->refCount; |
| 481 | return; |
| 482 | } |
| 483 | auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type); |
| 484 | it.first->second = numbering; |
| 485 | orderedTypes.push_back(x: numbering); |
| 486 | |
| 487 | // Check for OpaqueType, which is a dialect-specific type that didn't have a |
| 488 | // registered dialect when it got created. We don't want to encode this as the |
| 489 | // builtin OpaqueType, we want to encode it as if the dialect was actually |
| 490 | // loaded. |
| 491 | if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) { |
| 492 | numbering->dialect = &numberDialect(opaqueType.getDialectNamespace()); |
| 493 | return; |
| 494 | } |
| 495 | numbering->dialect = &numberDialect(dialect: &type.getDialect()); |
| 496 | |
| 497 | // If this type will be emitted using the bytecode format, perform a dummy |
| 498 | // writing to number any nested components. |
| 499 | // TODO: We don't allow custom encodings for mutable types right now. |
| 500 | if (!type.hasTrait<TypeTrait::IsMutable>()) { |
| 501 | // Try overriding emission with callbacks. |
| 502 | for (const auto &callback : config.getTypeWriterCallbacks()) { |
| 503 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
| 504 | // The client has the ability to override the group name through the |
| 505 | // callback. |
| 506 | std::optional<StringRef> groupNameOverride; |
| 507 | if (succeeded(Result: callback->write(entry: type, name&: groupNameOverride, writer))) { |
| 508 | if (groupNameOverride.has_value()) |
| 509 | numbering->dialect = &numberDialect(dialect: *groupNameOverride); |
| 510 | return; |
| 511 | } |
| 512 | } |
| 513 | |
| 514 | // If this attribute will be emitted using the bytecode format, perform a |
| 515 | // dummy writing to number any nested components. |
| 516 | if (const auto *interface = numbering->dialect->interface) { |
| 517 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
| 518 | if (succeeded(Result: interface->writeType(type, writer))) |
| 519 | return; |
| 520 | } |
| 521 | } |
| 522 | // If this type will be emitted using the fallback, number the nested dialect |
| 523 | // resources. We don't number everything (e.g. no nested attributes/types), |
| 524 | // because we don't want to encode things we won't decode (the textual format |
| 525 | // can't really share much). |
| 526 | AsmState tempState(type.getContext()); |
| 527 | llvm::raw_null_ostream dummyOS; |
| 528 | type.print(os&: dummyOS, state&: tempState); |
| 529 | |
| 530 | // Number the used dialect resources. |
| 531 | for (const auto &it : tempState.getDialectResources()) |
| 532 | number(dialect: it.getFirst(), resources: it.getSecond().getArrayRef()); |
| 533 | } |
| 534 | |
| 535 | void IRNumberingState::number(Dialect *dialect, |
| 536 | ArrayRef<AsmDialectResourceHandle> resources) { |
| 537 | DialectNumbering &dialectNumber = numberDialect(dialect); |
| 538 | assert( |
| 539 | dialectNumber.asmInterface && |
| 540 | "expected dialect owning a resource to implement OpAsmDialectInterface" ); |
| 541 | |
| 542 | for (const auto &resource : resources) { |
| 543 | // Check if this is a newly seen resource. |
| 544 | if (!dialectNumber.resources.insert(X: resource)) |
| 545 | return; |
| 546 | |
| 547 | auto *numbering = |
| 548 | new (resourceAllocator.Allocate()) DialectResourceNumbering( |
| 549 | dialectNumber.asmInterface->getResourceKey(handle: resource)); |
| 550 | dialectNumber.resourceMap.insert(KV: {numbering->key, numbering}); |
| 551 | dialectResources.try_emplace(Key: resource, Args&: numbering); |
| 552 | } |
| 553 | } |
| 554 | |
| 555 | int64_t IRNumberingState::getDesiredBytecodeVersion() const { |
| 556 | return config.getDesiredBytecodeVersion(); |
| 557 | } |
| 558 | |
| 559 | namespace { |
| 560 | /// A dummy resource builder used to number dialect resources. |
| 561 | struct NumberingResourceBuilder : public AsmResourceBuilder { |
| 562 | NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID) |
| 563 | : dialect(dialect), nextResourceID(nextResourceID) {} |
| 564 | ~NumberingResourceBuilder() override = default; |
| 565 | |
| 566 | void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final { |
| 567 | numberEntry(key); |
| 568 | } |
| 569 | void buildBool(StringRef key, bool) final { numberEntry(key); } |
| 570 | void buildString(StringRef key, StringRef) final { |
| 571 | // TODO: We could pre-number the value string here as well. |
| 572 | numberEntry(key); |
| 573 | } |
| 574 | |
| 575 | /// Number the dialect entry for the given key. |
| 576 | void numberEntry(StringRef key) { |
| 577 | // TODO: We could pre-number resource key strings here as well. |
| 578 | |
| 579 | auto *it = dialect->resourceMap.find(Key: key); |
| 580 | if (it != dialect->resourceMap.end()) { |
| 581 | it->second->number = nextResourceID++; |
| 582 | it->second->isDeclaration = false; |
| 583 | } |
| 584 | } |
| 585 | |
| 586 | DialectNumbering *dialect; |
| 587 | unsigned &nextResourceID; |
| 588 | }; |
| 589 | } // namespace |
| 590 | |
| 591 | void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) { |
| 592 | unsigned nextResourceID = 0; |
| 593 | for (DialectNumbering &dialect : getDialects()) { |
| 594 | if (!dialect.asmInterface) |
| 595 | continue; |
| 596 | NumberingResourceBuilder entryBuilder(&dialect, nextResourceID); |
| 597 | dialect.asmInterface->buildResources(op: rootOp, referencedResources: dialect.resources, |
| 598 | builder&: entryBuilder); |
| 599 | |
| 600 | // Number any resources that weren't added by the dialect. This can happen |
| 601 | // if there was no backing data to the resource, but we still want these |
| 602 | // resource references to roundtrip, so we number them and indicate that the |
| 603 | // data is missing. |
| 604 | for (const auto &it : dialect.resourceMap) |
| 605 | if (it.second->isDeclaration) |
| 606 | it.second->number = nextResourceID++; |
| 607 | } |
| 608 | } |
| 609 | |