| 1 | //===- BytecodeWriter.cpp - MLIR Bytecode Writer --------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Bytecode/BytecodeWriter.h" |
| 10 | #include "IRNumbering.h" |
| 11 | #include "mlir/Bytecode/BytecodeImplementation.h" |
| 12 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
| 13 | #include "mlir/Bytecode/Encoding.h" |
| 14 | #include "mlir/IR/Attributes.h" |
| 15 | #include "mlir/IR/Diagnostics.h" |
| 16 | #include "mlir/IR/OpImplementation.h" |
| 17 | #include "llvm/ADT/ArrayRef.h" |
| 18 | #include "llvm/ADT/CachedHashString.h" |
| 19 | #include "llvm/ADT/MapVector.h" |
| 20 | #include "llvm/ADT/SmallVector.h" |
| 21 | #include "llvm/Support/Debug.h" |
| 22 | #include "llvm/Support/Endian.h" |
| 23 | #include "llvm/Support/raw_ostream.h" |
| 24 | #include <optional> |
| 25 | |
| 26 | #define DEBUG_TYPE "mlir-bytecode-writer" |
| 27 | |
| 28 | using namespace mlir; |
| 29 | using namespace mlir::bytecode::detail; |
| 30 | |
| 31 | //===----------------------------------------------------------------------===// |
| 32 | // BytecodeWriterConfig |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | |
| 35 | struct BytecodeWriterConfig::Impl { |
| 36 | Impl(StringRef producer) : producer(producer) {} |
| 37 | |
| 38 | /// Version to use when writing. |
| 39 | /// Note: This only differs from kVersion if a specific version is set. |
| 40 | int64_t bytecodeVersion = bytecode::kVersion; |
| 41 | |
| 42 | /// A flag specifying whether to elide emission of resources into the bytecode |
| 43 | /// file. |
| 44 | bool shouldElideResourceData = false; |
| 45 | |
| 46 | /// A map containing dialect version information for each dialect to emit. |
| 47 | llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap; |
| 48 | |
| 49 | /// The producer of the bytecode. |
| 50 | StringRef producer; |
| 51 | |
| 52 | /// Printer callbacks used to emit custom type and attribute encodings. |
| 53 | llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> |
| 54 | attributeWriterCallbacks; |
| 55 | llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> |
| 56 | typeWriterCallbacks; |
| 57 | |
| 58 | /// A collection of non-dialect resource printers. |
| 59 | SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; |
| 60 | }; |
| 61 | |
| 62 | BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer) |
| 63 | : impl(std::make_unique<Impl>(args&: producer)) {} |
| 64 | BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map, |
| 65 | StringRef producer) |
| 66 | : BytecodeWriterConfig(producer) { |
| 67 | attachFallbackResourcePrinter(map); |
| 68 | } |
| 69 | BytecodeWriterConfig::BytecodeWriterConfig(BytecodeWriterConfig &&config) |
| 70 | : impl(std::move(config.impl)) {} |
| 71 | |
| 72 | BytecodeWriterConfig::~BytecodeWriterConfig() = default; |
| 73 | |
| 74 | ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> |
| 75 | BytecodeWriterConfig::getAttributeWriterCallbacks() const { |
| 76 | return impl->attributeWriterCallbacks; |
| 77 | } |
| 78 | |
| 79 | ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> |
| 80 | BytecodeWriterConfig::getTypeWriterCallbacks() const { |
| 81 | return impl->typeWriterCallbacks; |
| 82 | } |
| 83 | |
| 84 | void BytecodeWriterConfig::attachAttributeCallback( |
| 85 | std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) { |
| 86 | impl->attributeWriterCallbacks.emplace_back(Args: std::move(callback)); |
| 87 | } |
| 88 | |
| 89 | void BytecodeWriterConfig::attachTypeCallback( |
| 90 | std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) { |
| 91 | impl->typeWriterCallbacks.emplace_back(Args: std::move(callback)); |
| 92 | } |
| 93 | |
| 94 | void BytecodeWriterConfig::attachResourcePrinter( |
| 95 | std::unique_ptr<AsmResourcePrinter> printer) { |
| 96 | impl->externalResourcePrinters.emplace_back(Args: std::move(printer)); |
| 97 | } |
| 98 | |
| 99 | void BytecodeWriterConfig::setElideResourceDataFlag( |
| 100 | bool shouldElideResourceData) { |
| 101 | impl->shouldElideResourceData = shouldElideResourceData; |
| 102 | } |
| 103 | |
| 104 | void BytecodeWriterConfig::setDesiredBytecodeVersion(int64_t bytecodeVersion) { |
| 105 | impl->bytecodeVersion = bytecodeVersion; |
| 106 | } |
| 107 | |
| 108 | int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const { |
| 109 | return impl->bytecodeVersion; |
| 110 | } |
| 111 | |
| 112 | llvm::StringMap<std::unique_ptr<DialectVersion>> & |
| 113 | BytecodeWriterConfig::getDialectVersionMap() const { |
| 114 | return impl->dialectVersionMap; |
| 115 | } |
| 116 | |
| 117 | void BytecodeWriterConfig::setDialectVersion( |
| 118 | llvm::StringRef dialectName, |
| 119 | std::unique_ptr<DialectVersion> dialectVersion) const { |
| 120 | assert(!impl->dialectVersionMap.contains(dialectName) && |
| 121 | "cannot override a previously set dialect version" ); |
| 122 | impl->dialectVersionMap.insert(KV: {dialectName, std::move(dialectVersion)}); |
| 123 | } |
| 124 | |
| 125 | //===----------------------------------------------------------------------===// |
| 126 | // EncodingEmitter |
| 127 | //===----------------------------------------------------------------------===// |
| 128 | |
| 129 | namespace { |
| 130 | /// This class functions as the underlying encoding emitter for the bytecode |
| 131 | /// writer. This class is a bit different compared to other types of encoders; |
| 132 | /// it does not use a single buffer, but instead may contain several buffers |
| 133 | /// (some owned by the writer, and some not) that get concatted during the final |
| 134 | /// emission. |
| 135 | class EncodingEmitter { |
| 136 | public: |
| 137 | EncodingEmitter() = default; |
| 138 | EncodingEmitter(const EncodingEmitter &) = delete; |
| 139 | EncodingEmitter &operator=(const EncodingEmitter &) = delete; |
| 140 | |
| 141 | /// Write the current contents to the provided stream. |
| 142 | void writeTo(raw_ostream &os) const; |
| 143 | |
| 144 | /// Return the current size of the encoded buffer. |
| 145 | size_t size() const { return prevResultSize + currentResult.size(); } |
| 146 | |
| 147 | //===--------------------------------------------------------------------===// |
| 148 | // Emission |
| 149 | //===--------------------------------------------------------------------===// |
| 150 | |
| 151 | /// Backpatch a byte in the result buffer at the given offset. |
| 152 | void patchByte(uint64_t offset, uint8_t value, StringLiteral desc) { |
| 153 | LLVM_DEBUG(llvm::dbgs() << "patchByte(" << offset << ',' << uint64_t(value) |
| 154 | << ")\t" << desc << '\n'); |
| 155 | assert(offset < size() && offset >= prevResultSize && |
| 156 | "cannot patch previously emitted data" ); |
| 157 | currentResult[offset - prevResultSize] = value; |
| 158 | } |
| 159 | |
| 160 | /// Emit the provided blob of data, which is owned by the caller and is |
| 161 | /// guaranteed to not die before the end of the bytecode process. |
| 162 | void emitOwnedBlob(ArrayRef<uint8_t> data, StringLiteral desc) { |
| 163 | LLVM_DEBUG(llvm::dbgs() |
| 164 | << "emitOwnedBlob(" << data.size() << "b)\t" << desc << '\n'); |
| 165 | // Push the current buffer before adding the provided data. |
| 166 | appendResult(result: std::move(currentResult)); |
| 167 | appendOwnedResult(result: data); |
| 168 | } |
| 169 | |
| 170 | /// Emit the provided blob of data that has the given alignment, which is |
| 171 | /// owned by the caller and is guaranteed to not die before the end of the |
| 172 | /// bytecode process. The alignment value is also encoded, making it available |
| 173 | /// on load. |
| 174 | void emitOwnedBlobAndAlignment(ArrayRef<uint8_t> data, uint32_t alignment, |
| 175 | StringLiteral desc) { |
| 176 | emitVarInt(value: alignment, desc); |
| 177 | emitVarInt(value: data.size(), desc); |
| 178 | |
| 179 | alignTo(alignment); |
| 180 | emitOwnedBlob(data, desc); |
| 181 | } |
| 182 | void emitOwnedBlobAndAlignment(ArrayRef<char> data, uint32_t alignment, |
| 183 | StringLiteral desc) { |
| 184 | ArrayRef<uint8_t> castedData(reinterpret_cast<const uint8_t *>(data.data()), |
| 185 | data.size()); |
| 186 | emitOwnedBlobAndAlignment(data: castedData, alignment, desc); |
| 187 | } |
| 188 | |
| 189 | /// Align the emitter to the given alignment. |
| 190 | void alignTo(unsigned alignment) { |
| 191 | if (alignment < 2) |
| 192 | return; |
| 193 | assert(llvm::isPowerOf2_32(alignment) && "expected valid alignment" ); |
| 194 | |
| 195 | // Check to see if we need to emit any padding bytes to meet the desired |
| 196 | // alignment. |
| 197 | size_t curOffset = size(); |
| 198 | size_t paddingSize = llvm::alignTo(Value: curOffset, Align: alignment) - curOffset; |
| 199 | while (paddingSize--) |
| 200 | emitByte(byte: bytecode::kAlignmentByte, desc: "alignment byte" ); |
| 201 | |
| 202 | // Keep track of the maximum required alignment. |
| 203 | requiredAlignment = std::max(a: requiredAlignment, b: alignment); |
| 204 | } |
| 205 | |
| 206 | //===--------------------------------------------------------------------===// |
| 207 | // Integer Emission |
| 208 | |
| 209 | /// Emit a single byte. |
| 210 | template <typename T> |
| 211 | void emitByte(T byte, StringLiteral desc) { |
| 212 | LLVM_DEBUG(llvm::dbgs() |
| 213 | << "emitByte(" << uint64_t(byte) << ")\t" << desc << '\n'); |
| 214 | currentResult.push_back(x: static_cast<uint8_t>(byte)); |
| 215 | } |
| 216 | |
| 217 | /// Emit a range of bytes. |
| 218 | void emitBytes(ArrayRef<uint8_t> bytes, StringLiteral desc) { |
| 219 | LLVM_DEBUG(llvm::dbgs() |
| 220 | << "emitBytes(" << bytes.size() << "b)\t" << desc << '\n'); |
| 221 | llvm::append_range(C&: currentResult, R&: bytes); |
| 222 | } |
| 223 | |
| 224 | /// Emit a variable length integer. The first encoded byte contains a prefix |
| 225 | /// in the low bits indicating the encoded length of the value. This length |
| 226 | /// prefix is a bit sequence of '0's followed by a '1'. The number of '0' bits |
| 227 | /// indicate the number of _additional_ bytes (not including the prefix byte). |
| 228 | /// All remaining bits in the first byte, along with all of the bits in |
| 229 | /// additional bytes, provide the value of the integer encoded in |
| 230 | /// little-endian order. |
| 231 | void emitVarInt(uint64_t value, StringLiteral desc) { |
| 232 | LLVM_DEBUG(llvm::dbgs() << "emitVarInt(" << value << ")\t" << desc << '\n'); |
| 233 | |
| 234 | // In the most common case, the value can be represented in a single byte. |
| 235 | // Given how hot this case is, explicitly handle that here. |
| 236 | if ((value >> 7) == 0) |
| 237 | return emitByte(byte: (value << 1) | 0x1, desc); |
| 238 | emitMultiByteVarInt(value, desc); |
| 239 | } |
| 240 | |
| 241 | /// Emit a signed variable length integer. Signed varints are encoded using |
| 242 | /// a varint with zigzag encoding, meaning that we use the low bit of the |
| 243 | /// value to indicate the sign of the value. This allows for more efficient |
| 244 | /// encoding of negative values by limiting the number of active bits |
| 245 | void emitSignedVarInt(uint64_t value, StringLiteral desc) { |
| 246 | emitVarInt(value: (value << 1) ^ (uint64_t)((int64_t)value >> 63), desc); |
| 247 | } |
| 248 | |
| 249 | /// Emit a variable length integer whose low bit is used to encode the |
| 250 | /// provided flag, i.e. encoded as: (value << 1) | (flag ? 1 : 0). |
| 251 | void emitVarIntWithFlag(uint64_t value, bool flag, StringLiteral desc) { |
| 252 | emitVarInt(value: (value << 1) | (flag ? 1 : 0), desc); |
| 253 | } |
| 254 | |
| 255 | //===--------------------------------------------------------------------===// |
| 256 | // String Emission |
| 257 | |
| 258 | /// Emit the given string as a nul terminated string. |
| 259 | void emitNulTerminatedString(StringRef str, StringLiteral desc) { |
| 260 | emitString(str, desc); |
| 261 | emitByte(byte: 0, desc: "null terminator" ); |
| 262 | } |
| 263 | |
| 264 | /// Emit the given string without a nul terminator. |
| 265 | void emitString(StringRef str, StringLiteral desc) { |
| 266 | emitBytes(bytes: {reinterpret_cast<const uint8_t *>(str.data()), str.size()}, |
| 267 | desc); |
| 268 | } |
| 269 | |
| 270 | //===--------------------------------------------------------------------===// |
| 271 | // Section Emission |
| 272 | |
| 273 | /// Emit a nested section of the given code, whose contents are encoded in the |
| 274 | /// provided emitter. |
| 275 | void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) { |
| 276 | // Emit the section code and length. The high bit of the code is used to |
| 277 | // indicate whether the section alignment is present, so save an offset to |
| 278 | // it. |
| 279 | uint64_t codeOffset = currentResult.size(); |
| 280 | emitByte(byte: code, desc: "section code" ); |
| 281 | emitVarInt(value: emitter.size(), desc: "section size" ); |
| 282 | |
| 283 | // Integrate the alignment of the section into this emitter if necessary. |
| 284 | unsigned emitterAlign = emitter.requiredAlignment; |
| 285 | if (emitterAlign > 1) { |
| 286 | if (size() & (emitterAlign - 1)) { |
| 287 | emitVarInt(value: emitterAlign, desc: "section alignment" ); |
| 288 | alignTo(alignment: emitterAlign); |
| 289 | |
| 290 | // Indicate that we needed to align the section, the high bit of the |
| 291 | // code field is used for this. |
| 292 | currentResult[codeOffset] |= 0b10000000; |
| 293 | } else { |
| 294 | // Otherwise, if we happen to be at a compatible offset, we just |
| 295 | // remember that we need this alignment. |
| 296 | requiredAlignment = std::max(a: requiredAlignment, b: emitterAlign); |
| 297 | } |
| 298 | } |
| 299 | |
| 300 | // Push our current buffer and then merge the provided section body into |
| 301 | // ours. |
| 302 | appendResult(result: std::move(currentResult)); |
| 303 | for (std::vector<uint8_t> &result : emitter.prevResultStorage) |
| 304 | prevResultStorage.push_back(x: std::move(result)); |
| 305 | llvm::append_range(C&: prevResultList, R&: emitter.prevResultList); |
| 306 | prevResultSize += emitter.prevResultSize; |
| 307 | appendResult(result: std::move(emitter.currentResult)); |
| 308 | } |
| 309 | |
| 310 | private: |
| 311 | /// Emit the given value using a variable width encoding. This method is a |
| 312 | /// fallback when the number of bytes needed to encode the value is greater |
| 313 | /// than 1. We mark it noinline here so that the single byte hot path isn't |
| 314 | /// pessimized. |
| 315 | LLVM_ATTRIBUTE_NOINLINE void emitMultiByteVarInt(uint64_t value, |
| 316 | StringLiteral desc); |
| 317 | |
| 318 | /// Append a new result buffer to the current contents. |
| 319 | void appendResult(std::vector<uint8_t> &&result) { |
| 320 | if (result.empty()) |
| 321 | return; |
| 322 | prevResultStorage.emplace_back(args: std::move(result)); |
| 323 | appendOwnedResult(result: prevResultStorage.back()); |
| 324 | } |
| 325 | void appendOwnedResult(ArrayRef<uint8_t> result) { |
| 326 | if (result.empty()) |
| 327 | return; |
| 328 | prevResultSize += result.size(); |
| 329 | prevResultList.emplace_back(args&: result); |
| 330 | } |
| 331 | |
| 332 | /// The result of the emitter currently being built. We refrain from building |
| 333 | /// a single buffer to simplify emitting sections, large data, and more. The |
| 334 | /// result is thus represented using multiple distinct buffers, some of which |
| 335 | /// we own (via prevResultStorage), and some of which are just pointers into |
| 336 | /// externally owned buffers. |
| 337 | std::vector<uint8_t> currentResult; |
| 338 | std::vector<ArrayRef<uint8_t>> prevResultList; |
| 339 | std::vector<std::vector<uint8_t>> prevResultStorage; |
| 340 | |
| 341 | /// An up-to-date total size of all of the buffers within `prevResultList`. |
| 342 | /// This enables O(1) size checks of the current encoding. |
| 343 | size_t prevResultSize = 0; |
| 344 | |
| 345 | /// The highest required alignment for the start of this section. |
| 346 | unsigned requiredAlignment = 1; |
| 347 | }; |
| 348 | |
| 349 | //===----------------------------------------------------------------------===// |
| 350 | // StringSectionBuilder |
| 351 | //===----------------------------------------------------------------------===// |
| 352 | |
| 353 | namespace { |
| 354 | /// This class is used to simplify the process of emitting the string section. |
| 355 | class StringSectionBuilder { |
| 356 | public: |
| 357 | /// Add the given string to the string section, and return the index of the |
| 358 | /// string within the section. |
| 359 | size_t insert(StringRef str) { |
| 360 | auto it = strings.insert(KV: {llvm::CachedHashStringRef(str), strings.size()}); |
| 361 | return it.first->second; |
| 362 | } |
| 363 | |
| 364 | /// Write the current set of strings to the given emitter. |
| 365 | void write(EncodingEmitter &emitter) { |
| 366 | emitter.emitVarInt(value: strings.size(), desc: "string section size" ); |
| 367 | |
| 368 | // Emit the sizes in reverse order, so that we don't need to backpatch an |
| 369 | // offset to the string data or have a separate section. |
| 370 | for (const auto &it : llvm::reverse(C&: strings)) |
| 371 | emitter.emitVarInt(value: it.first.size() + 1, desc: "string size" ); |
| 372 | // Emit the string data itself. |
| 373 | for (const auto &it : strings) |
| 374 | emitter.emitNulTerminatedString(str: it.first.val(), desc: "string" ); |
| 375 | } |
| 376 | |
| 377 | private: |
| 378 | /// A set of strings referenced within the bytecode. The value of the map is |
| 379 | /// unused. |
| 380 | llvm::MapVector<llvm::CachedHashStringRef, size_t> strings; |
| 381 | }; |
| 382 | } // namespace |
| 383 | |
| 384 | class DialectWriter : public DialectBytecodeWriter { |
| 385 | using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>; |
| 386 | |
| 387 | public: |
| 388 | DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter, |
| 389 | IRNumberingState &numberingState, |
| 390 | StringSectionBuilder &stringSection, |
| 391 | const DialectVersionMapT &dialectVersionMap) |
| 392 | : bytecodeVersion(bytecodeVersion), emitter(emitter), |
| 393 | numberingState(numberingState), stringSection(stringSection), |
| 394 | dialectVersionMap(dialectVersionMap) {} |
| 395 | |
| 396 | //===--------------------------------------------------------------------===// |
| 397 | // IR |
| 398 | //===--------------------------------------------------------------------===// |
| 399 | |
| 400 | void writeAttribute(Attribute attr) override { |
| 401 | emitter.emitVarInt(value: numberingState.getNumber(attr), desc: "dialect attr" ); |
| 402 | } |
| 403 | void writeOptionalAttribute(Attribute attr) override { |
| 404 | if (!attr) { |
| 405 | emitter.emitVarInt(value: 0, desc: "dialect optional attr none" ); |
| 406 | return; |
| 407 | } |
| 408 | emitter.emitVarIntWithFlag(value: numberingState.getNumber(attr), flag: true, |
| 409 | desc: "dialect optional attr" ); |
| 410 | } |
| 411 | |
| 412 | void writeType(Type type) override { |
| 413 | emitter.emitVarInt(value: numberingState.getNumber(type), desc: "dialect type" ); |
| 414 | } |
| 415 | |
| 416 | void writeResourceHandle(const AsmDialectResourceHandle &resource) override { |
| 417 | emitter.emitVarInt(value: numberingState.getNumber(resource), desc: "dialect resource" ); |
| 418 | } |
| 419 | |
| 420 | //===--------------------------------------------------------------------===// |
| 421 | // Primitives |
| 422 | //===--------------------------------------------------------------------===// |
| 423 | |
| 424 | void writeVarInt(uint64_t value) override { |
| 425 | emitter.emitVarInt(value, desc: "dialect writer" ); |
| 426 | } |
| 427 | |
| 428 | void writeSignedVarInt(int64_t value) override { |
| 429 | emitter.emitSignedVarInt(value, desc: "dialect writer" ); |
| 430 | } |
| 431 | |
| 432 | void writeAPIntWithKnownWidth(const APInt &value) override { |
| 433 | size_t bitWidth = value.getBitWidth(); |
| 434 | |
| 435 | // If the value is a single byte, just emit it directly without going |
| 436 | // through a varint. |
| 437 | if (bitWidth <= 8) |
| 438 | return emitter.emitByte(byte: value.getLimitedValue(), desc: "dialect APInt" ); |
| 439 | |
| 440 | // If the value fits within a single varint, emit it directly. |
| 441 | if (bitWidth <= 64) |
| 442 | return emitter.emitSignedVarInt(value: value.getLimitedValue(), desc: "dialect APInt" ); |
| 443 | |
| 444 | // Otherwise, we need to encode a variable number of active words. We use |
| 445 | // active words instead of the number of total words under the observation |
| 446 | // that smaller values will be more common. |
| 447 | unsigned numActiveWords = value.getActiveWords(); |
| 448 | emitter.emitVarInt(value: numActiveWords, desc: "dialect APInt word count" ); |
| 449 | |
| 450 | const uint64_t *rawValueData = value.getRawData(); |
| 451 | for (unsigned i = 0; i < numActiveWords; ++i) |
| 452 | emitter.emitSignedVarInt(value: rawValueData[i], desc: "dialect APInt word" ); |
| 453 | } |
| 454 | |
| 455 | void writeAPFloatWithKnownSemantics(const APFloat &value) override { |
| 456 | writeAPIntWithKnownWidth(value: value.bitcastToAPInt()); |
| 457 | } |
| 458 | |
| 459 | void writeOwnedString(StringRef str) override { |
| 460 | emitter.emitVarInt(value: stringSection.insert(str), desc: "dialect string" ); |
| 461 | } |
| 462 | |
| 463 | void writeOwnedBlob(ArrayRef<char> blob) override { |
| 464 | emitter.emitVarInt(value: blob.size(), desc: "dialect blob" ); |
| 465 | emitter.emitOwnedBlob( |
| 466 | data: ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(blob.data()), |
| 467 | blob.size()), |
| 468 | desc: "dialect blob" ); |
| 469 | } |
| 470 | |
| 471 | void writeOwnedBool(bool value) override { |
| 472 | emitter.emitByte(byte: value, desc: "dialect bool" ); |
| 473 | } |
| 474 | |
| 475 | int64_t getBytecodeVersion() const override { return bytecodeVersion; } |
| 476 | |
| 477 | FailureOr<const DialectVersion *> |
| 478 | getDialectVersion(StringRef dialectName) const override { |
| 479 | auto dialectEntry = dialectVersionMap.find(Key: dialectName); |
| 480 | if (dialectEntry == dialectVersionMap.end()) |
| 481 | return failure(); |
| 482 | return dialectEntry->getValue().get(); |
| 483 | } |
| 484 | |
| 485 | private: |
| 486 | int64_t bytecodeVersion; |
| 487 | EncodingEmitter &emitter; |
| 488 | IRNumberingState &numberingState; |
| 489 | StringSectionBuilder &stringSection; |
| 490 | const DialectVersionMapT &dialectVersionMap; |
| 491 | }; |
| 492 | |
| 493 | namespace { |
| 494 | class PropertiesSectionBuilder { |
| 495 | public: |
| 496 | PropertiesSectionBuilder(IRNumberingState &numberingState, |
| 497 | StringSectionBuilder &stringSection, |
| 498 | const BytecodeWriterConfig::Impl &config) |
| 499 | : numberingState(numberingState), stringSection(stringSection), |
| 500 | config(config) {} |
| 501 | |
| 502 | /// Emit the op properties in the properties section and return the index of |
| 503 | /// the properties within the section. Return -1 if no properties was emitted. |
| 504 | std::optional<ssize_t> emit(Operation *op) { |
| 505 | EncodingEmitter propertiesEmitter; |
| 506 | if (!op->getPropertiesStorageSize()) |
| 507 | return std::nullopt; |
| 508 | if (!op->isRegistered()) { |
| 509 | // Unregistered op are storing properties as an optional attribute. |
| 510 | Attribute prop = *op->getPropertiesStorage().as<Attribute *>(); |
| 511 | if (!prop) |
| 512 | return std::nullopt; |
| 513 | EncodingEmitter sizeEmitter; |
| 514 | sizeEmitter.emitVarInt(value: numberingState.getNumber(attr: prop), desc: "properties size" ); |
| 515 | scratch.clear(); |
| 516 | llvm::raw_svector_ostream os(scratch); |
| 517 | sizeEmitter.writeTo(os); |
| 518 | return emit(rawProperties: scratch); |
| 519 | } |
| 520 | |
| 521 | EncodingEmitter emitter; |
| 522 | DialectWriter propertiesWriter(config.bytecodeVersion, emitter, |
| 523 | numberingState, stringSection, |
| 524 | config.dialectVersionMap); |
| 525 | auto iface = cast<BytecodeOpInterface>(op); |
| 526 | iface.writeProperties(propertiesWriter); |
| 527 | scratch.clear(); |
| 528 | llvm::raw_svector_ostream os(scratch); |
| 529 | emitter.writeTo(os); |
| 530 | return emit(rawProperties: scratch); |
| 531 | } |
| 532 | |
| 533 | /// Write the current set of properties to the given emitter. |
| 534 | void write(EncodingEmitter &emitter) { |
| 535 | emitter.emitVarInt(value: propertiesStorage.size(), desc: "properties size" ); |
| 536 | if (propertiesStorage.empty()) |
| 537 | return; |
| 538 | for (const auto &storage : propertiesStorage) { |
| 539 | if (storage.empty()) { |
| 540 | emitter.emitBytes(bytes: ArrayRef<uint8_t>(), desc: "empty properties" ); |
| 541 | continue; |
| 542 | } |
| 543 | emitter.emitBytes(bytes: ArrayRef(reinterpret_cast<const uint8_t *>(&storage[0]), |
| 544 | storage.size()), |
| 545 | desc: "property" ); |
| 546 | } |
| 547 | } |
| 548 | |
| 549 | /// Returns true if the section is empty. |
| 550 | bool empty() { return propertiesStorage.empty(); } |
| 551 | |
| 552 | private: |
| 553 | /// Emit raw data and returns the offset in the internal buffer. |
| 554 | /// Data are deduplicated and will be copied in the internal buffer only if |
| 555 | /// they don't exist there already. |
| 556 | ssize_t emit(ArrayRef<char> rawProperties) { |
| 557 | // Populate a scratch buffer with the properties size. |
| 558 | SmallVector<char> sizeScratch; |
| 559 | { |
| 560 | EncodingEmitter sizeEmitter; |
| 561 | sizeEmitter.emitVarInt(value: rawProperties.size(), desc: "properties" ); |
| 562 | llvm::raw_svector_ostream os(sizeScratch); |
| 563 | sizeEmitter.writeTo(os); |
| 564 | } |
| 565 | // Append a new storage to the table now. |
| 566 | size_t index = propertiesStorage.size(); |
| 567 | propertiesStorage.emplace_back(); |
| 568 | std::vector<char> &newStorage = propertiesStorage.back(); |
| 569 | size_t propertiesSize = sizeScratch.size() + rawProperties.size(); |
| 570 | newStorage.reserve(n: propertiesSize); |
| 571 | llvm::append_range(C&: newStorage, R&: sizeScratch); |
| 572 | llvm::append_range(C&: newStorage, R&: rawProperties); |
| 573 | |
| 574 | // Try to de-duplicate the new serialized properties. |
| 575 | // If the properties is a duplicate, pop it back from the storage. |
| 576 | auto inserted = propertiesUniquing.insert( |
| 577 | KV: std::make_pair(x: ArrayRef<char>(newStorage), y&: index)); |
| 578 | if (!inserted.second) |
| 579 | propertiesStorage.pop_back(); |
| 580 | return inserted.first->getSecond(); |
| 581 | } |
| 582 | |
| 583 | /// Storage for properties. |
| 584 | std::vector<std::vector<char>> propertiesStorage; |
| 585 | SmallVector<char> scratch; |
| 586 | DenseMap<ArrayRef<char>, int64_t> propertiesUniquing; |
| 587 | IRNumberingState &numberingState; |
| 588 | StringSectionBuilder &stringSection; |
| 589 | const BytecodeWriterConfig::Impl &config; |
| 590 | }; |
| 591 | } // namespace |
| 592 | |
| 593 | /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need |
| 594 | /// to go through an intermediate buffer when interacting with code that wants a |
| 595 | /// raw_ostream. |
| 596 | class RawEmitterOstream : public raw_ostream { |
| 597 | public: |
| 598 | explicit RawEmitterOstream(EncodingEmitter &emitter) : emitter(emitter) { |
| 599 | SetUnbuffered(); |
| 600 | } |
| 601 | |
| 602 | private: |
| 603 | void write_impl(const char *ptr, size_t size) override { |
| 604 | emitter.emitBytes(bytes: {reinterpret_cast<const uint8_t *>(ptr), size}, |
| 605 | desc: "raw emitter" ); |
| 606 | } |
| 607 | uint64_t current_pos() const override { return emitter.size(); } |
| 608 | |
| 609 | /// The section being emitted to. |
| 610 | EncodingEmitter &emitter; |
| 611 | }; |
| 612 | } // namespace |
| 613 | |
| 614 | void EncodingEmitter::writeTo(raw_ostream &os) const { |
| 615 | // Reserve space in the ostream for the encoded contents. |
| 616 | os.reserveExtraSpace(ExtraSize: size()); |
| 617 | |
| 618 | for (auto &prevResult : prevResultList) |
| 619 | os.write(Ptr: (const char *)prevResult.data(), Size: prevResult.size()); |
| 620 | os.write(Ptr: (const char *)currentResult.data(), Size: currentResult.size()); |
| 621 | } |
| 622 | |
| 623 | void EncodingEmitter::emitMultiByteVarInt(uint64_t value, StringLiteral desc) { |
| 624 | // Compute the number of bytes needed to encode the value. Each byte can hold |
| 625 | // up to 7-bits of data. We only check up to the number of bits we can encode |
| 626 | // in the first byte (8). |
| 627 | uint64_t it = value >> 7; |
| 628 | for (size_t numBytes = 2; numBytes < 9; ++numBytes) { |
| 629 | if (LLVM_LIKELY(it >>= 7) == 0) { |
| 630 | uint64_t encodedValue = (value << 1) | 0x1; |
| 631 | encodedValue <<= (numBytes - 1); |
| 632 | llvm::support::ulittle64_t encodedValueLE(encodedValue); |
| 633 | emitBytes(bytes: {reinterpret_cast<uint8_t *>(&encodedValueLE), numBytes}, desc); |
| 634 | return; |
| 635 | } |
| 636 | } |
| 637 | |
| 638 | // If the value is too large to encode in a single byte, emit a special all |
| 639 | // zero marker byte and splat the value directly. |
| 640 | emitByte(byte: 0, desc); |
| 641 | llvm::support::ulittle64_t valueLE(value); |
| 642 | emitBytes(bytes: {reinterpret_cast<uint8_t *>(&valueLE), sizeof(valueLE)}, desc); |
| 643 | } |
| 644 | |
| 645 | //===----------------------------------------------------------------------===// |
| 646 | // Bytecode Writer |
| 647 | //===----------------------------------------------------------------------===// |
| 648 | |
| 649 | namespace { |
| 650 | class BytecodeWriter { |
| 651 | public: |
| 652 | BytecodeWriter(Operation *op, const BytecodeWriterConfig &config) |
| 653 | : numberingState(op, config), config(config.getImpl()), |
| 654 | propertiesSection(numberingState, stringSection, config.getImpl()) {} |
| 655 | |
| 656 | /// Write the bytecode for the given root operation. |
| 657 | LogicalResult write(Operation *rootOp, raw_ostream &os); |
| 658 | |
| 659 | private: |
| 660 | //===--------------------------------------------------------------------===// |
| 661 | // Dialects |
| 662 | |
| 663 | void writeDialectSection(EncodingEmitter &emitter); |
| 664 | |
| 665 | //===--------------------------------------------------------------------===// |
| 666 | // Attributes and Types |
| 667 | |
| 668 | void writeAttrTypeSection(EncodingEmitter &emitter); |
| 669 | |
| 670 | //===--------------------------------------------------------------------===// |
| 671 | // Operations |
| 672 | |
| 673 | LogicalResult writeBlock(EncodingEmitter &emitter, Block *block); |
| 674 | LogicalResult writeOp(EncodingEmitter &emitter, Operation *op); |
| 675 | LogicalResult writeRegion(EncodingEmitter &emitter, Region *region); |
| 676 | LogicalResult writeIRSection(EncodingEmitter &emitter, Operation *op); |
| 677 | |
| 678 | LogicalResult writeRegions(EncodingEmitter &emitter, |
| 679 | MutableArrayRef<Region> regions) { |
| 680 | return success(IsSuccess: llvm::all_of(Range&: regions, P: [&](Region ®ion) { |
| 681 | return succeeded(Result: writeRegion(emitter, region: ®ion)); |
| 682 | })); |
| 683 | } |
| 684 | |
| 685 | //===--------------------------------------------------------------------===// |
| 686 | // Resources |
| 687 | |
| 688 | void writeResourceSection(Operation *op, EncodingEmitter &emitter); |
| 689 | |
| 690 | //===--------------------------------------------------------------------===// |
| 691 | // Strings |
| 692 | |
| 693 | void writeStringSection(EncodingEmitter &emitter); |
| 694 | |
| 695 | //===--------------------------------------------------------------------===// |
| 696 | // Properties |
| 697 | |
| 698 | void writePropertiesSection(EncodingEmitter &emitter); |
| 699 | |
| 700 | //===--------------------------------------------------------------------===// |
| 701 | // Helpers |
| 702 | |
| 703 | void writeUseListOrders(EncodingEmitter &emitter, uint8_t &opEncodingMask, |
| 704 | ValueRange range); |
| 705 | |
| 706 | //===--------------------------------------------------------------------===// |
| 707 | // Fields |
| 708 | |
| 709 | /// The builder used for the string section. |
| 710 | StringSectionBuilder stringSection; |
| 711 | |
| 712 | /// The IR numbering state generated for the root operation. |
| 713 | IRNumberingState numberingState; |
| 714 | |
| 715 | /// Configuration dictating bytecode emission. |
| 716 | const BytecodeWriterConfig::Impl &config; |
| 717 | |
| 718 | /// Storage for the properties section |
| 719 | PropertiesSectionBuilder propertiesSection; |
| 720 | }; |
| 721 | } // namespace |
| 722 | |
| 723 | LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { |
| 724 | EncodingEmitter emitter; |
| 725 | |
| 726 | // Emit the bytecode file header. This is how we identify the output as a |
| 727 | // bytecode file. |
| 728 | emitter.emitString(str: "ML\xefR" , desc: "bytecode header" ); |
| 729 | |
| 730 | // Emit the bytecode version. |
| 731 | if (config.bytecodeVersion < bytecode::kMinSupportedVersion || |
| 732 | config.bytecodeVersion > bytecode::kVersion) |
| 733 | return rootOp->emitError() |
| 734 | << "unsupported version requested " << config.bytecodeVersion |
| 735 | << ", must be in range [" |
| 736 | << static_cast<int64_t>(bytecode::kMinSupportedVersion) << ", " |
| 737 | << static_cast<int64_t>(bytecode::kVersion) << ']'; |
| 738 | emitter.emitVarInt(value: config.bytecodeVersion, desc: "bytecode version" ); |
| 739 | |
| 740 | // Emit the producer. |
| 741 | emitter.emitNulTerminatedString(str: config.producer, desc: "bytecode producer" ); |
| 742 | |
| 743 | // Emit the dialect section. |
| 744 | writeDialectSection(emitter); |
| 745 | |
| 746 | // Emit the attributes and types section. |
| 747 | writeAttrTypeSection(emitter); |
| 748 | |
| 749 | // Emit the IR section. |
| 750 | if (failed(Result: writeIRSection(emitter, op: rootOp))) |
| 751 | return failure(); |
| 752 | |
| 753 | // Emit the resources section. |
| 754 | writeResourceSection(op: rootOp, emitter); |
| 755 | |
| 756 | // Emit the string section. |
| 757 | writeStringSection(emitter); |
| 758 | |
| 759 | // Emit the properties section. |
| 760 | if (config.bytecodeVersion >= bytecode::kNativePropertiesEncoding) |
| 761 | writePropertiesSection(emitter); |
| 762 | else if (!propertiesSection.empty()) |
| 763 | return rootOp->emitError( |
| 764 | message: "unexpected properties emitted incompatible with bytecode <5" ); |
| 765 | |
| 766 | // Write the generated bytecode to the provided output stream. |
| 767 | emitter.writeTo(os); |
| 768 | |
| 769 | return success(); |
| 770 | } |
| 771 | |
| 772 | //===----------------------------------------------------------------------===// |
| 773 | // Dialects |
| 774 | //===----------------------------------------------------------------------===// |
| 775 | |
| 776 | /// Write the given entries in contiguous groups with the same parent dialect. |
| 777 | /// Each dialect sub-group is encoded with the parent dialect and number of |
| 778 | /// elements, followed by the encoding for the entries. The given callback is |
| 779 | /// invoked to encode each individual entry. |
| 780 | template <typename EntriesT, typename EntryCallbackT> |
| 781 | static void writeDialectGrouping(EncodingEmitter &emitter, EntriesT &&entries, |
| 782 | EntryCallbackT &&callback) { |
| 783 | for (auto it = entries.begin(), e = entries.end(); it != e;) { |
| 784 | auto groupStart = it++; |
| 785 | |
| 786 | // Find the end of the group that shares the same parent dialect. |
| 787 | DialectNumbering *currentDialect = groupStart->dialect; |
| 788 | it = std::find_if(it, e, [&](const auto &entry) { |
| 789 | return entry.dialect != currentDialect; |
| 790 | }); |
| 791 | |
| 792 | // Emit the dialect and number of elements. |
| 793 | emitter.emitVarInt(value: currentDialect->number, desc: "dialect number" ); |
| 794 | emitter.emitVarInt(value: std::distance(groupStart, it), desc: "dialect offset" ); |
| 795 | |
| 796 | // Emit the entries within the group. |
| 797 | for (auto &entry : llvm::make_range(groupStart, it)) |
| 798 | callback(entry); |
| 799 | } |
| 800 | } |
| 801 | |
| 802 | void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) { |
| 803 | EncodingEmitter dialectEmitter; |
| 804 | |
| 805 | // Emit the referenced dialects. |
| 806 | auto dialects = numberingState.getDialects(); |
| 807 | dialectEmitter.emitVarInt(value: llvm::size(Range&: dialects), desc: "dialects count" ); |
| 808 | for (DialectNumbering &dialect : dialects) { |
| 809 | // Write the string section and get the ID. |
| 810 | size_t nameID = stringSection.insert(str: dialect.name); |
| 811 | |
| 812 | if (config.bytecodeVersion < bytecode::kDialectVersioning) { |
| 813 | dialectEmitter.emitVarInt(value: nameID, desc: "dialect name ID" ); |
| 814 | continue; |
| 815 | } |
| 816 | |
| 817 | // Try writing the version to the versionEmitter. |
| 818 | EncodingEmitter versionEmitter; |
| 819 | if (dialect.interface) { |
| 820 | // The writer used when emitting using a custom bytecode encoding. |
| 821 | DialectWriter versionWriter(config.bytecodeVersion, versionEmitter, |
| 822 | numberingState, stringSection, |
| 823 | config.dialectVersionMap); |
| 824 | dialect.interface->writeVersion(writer&: versionWriter); |
| 825 | } |
| 826 | |
| 827 | // If the version emitter is empty, version is not available. We can encode |
| 828 | // this in the dialect ID, so if there is no version, we don't write the |
| 829 | // section. |
| 830 | size_t versionAvailable = versionEmitter.size() > 0; |
| 831 | dialectEmitter.emitVarIntWithFlag(value: nameID, flag: versionAvailable, |
| 832 | desc: "dialect version" ); |
| 833 | if (versionAvailable) |
| 834 | dialectEmitter.emitSection(code: bytecode::Section::kDialectVersions, |
| 835 | emitter: std::move(versionEmitter)); |
| 836 | } |
| 837 | |
| 838 | if (config.bytecodeVersion >= bytecode::kElideUnknownBlockArgLocation) |
| 839 | dialectEmitter.emitVarInt(value: size(Range: numberingState.getOpNames()), |
| 840 | desc: "op names count" ); |
| 841 | |
| 842 | // Emit the referenced operation names grouped by dialect. |
| 843 | auto emitOpName = [&](OpNameNumbering &name) { |
| 844 | size_t stringId = stringSection.insert(str: name.name.stripDialect()); |
| 845 | if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding) |
| 846 | dialectEmitter.emitVarInt(value: stringId, desc: "dialect op name" ); |
| 847 | else |
| 848 | dialectEmitter.emitVarIntWithFlag(value: stringId, flag: name.name.isRegistered(), |
| 849 | desc: "dialect op name" ); |
| 850 | }; |
| 851 | writeDialectGrouping(emitter&: dialectEmitter, entries: numberingState.getOpNames(), callback&: emitOpName); |
| 852 | |
| 853 | emitter.emitSection(code: bytecode::Section::kDialect, emitter: std::move(dialectEmitter)); |
| 854 | } |
| 855 | |
| 856 | //===----------------------------------------------------------------------===// |
| 857 | // Attributes and Types |
| 858 | //===----------------------------------------------------------------------===// |
| 859 | |
| 860 | void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { |
| 861 | EncodingEmitter attrTypeEmitter; |
| 862 | EncodingEmitter offsetEmitter; |
| 863 | offsetEmitter.emitVarInt(value: llvm::size(Range: numberingState.getAttributes()), |
| 864 | desc: "attributes count" ); |
| 865 | offsetEmitter.emitVarInt(value: llvm::size(Range: numberingState.getTypes()), |
| 866 | desc: "types count" ); |
| 867 | |
| 868 | // A functor used to emit an attribute or type entry. |
| 869 | uint64_t prevOffset = 0; |
| 870 | auto emitAttrOrType = [&](auto &entry) { |
| 871 | auto entryValue = entry.getValue(); |
| 872 | |
| 873 | auto emitAttrOrTypeRawImpl = [&]() -> void { |
| 874 | RawEmitterOstream(attrTypeEmitter) << entryValue; |
| 875 | attrTypeEmitter.emitByte(byte: 0, desc: "attr/type separator" ); |
| 876 | }; |
| 877 | auto emitAttrOrTypeImpl = [&]() -> bool { |
| 878 | // TODO: We don't currently support custom encoded mutable types and |
| 879 | // attributes. |
| 880 | if (entryValue.template hasTrait<TypeTrait::IsMutable>() || |
| 881 | entryValue.template hasTrait<AttributeTrait::IsMutable>()) { |
| 882 | emitAttrOrTypeRawImpl(); |
| 883 | return false; |
| 884 | } |
| 885 | |
| 886 | DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, |
| 887 | numberingState, stringSection, |
| 888 | config.dialectVersionMap); |
| 889 | if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) { |
| 890 | for (const auto &callback : config.typeWriterCallbacks) { |
| 891 | if (succeeded(callback->write(entryValue, dialectWriter))) |
| 892 | return true; |
| 893 | } |
| 894 | if (const BytecodeDialectInterface *interface = |
| 895 | entry.dialect->interface) { |
| 896 | if (succeeded(interface->writeType(type: entryValue, writer&: dialectWriter))) |
| 897 | return true; |
| 898 | } |
| 899 | } else { |
| 900 | for (const auto &callback : config.attributeWriterCallbacks) { |
| 901 | if (succeeded(callback->write(entryValue, dialectWriter))) |
| 902 | return true; |
| 903 | } |
| 904 | if (const BytecodeDialectInterface *interface = |
| 905 | entry.dialect->interface) { |
| 906 | if (succeeded(interface->writeAttribute(attr: entryValue, writer&: dialectWriter))) |
| 907 | return true; |
| 908 | } |
| 909 | } |
| 910 | |
| 911 | // If the entry was not emitted using a callback or a dialect interface, |
| 912 | // emit it using the textual format. |
| 913 | emitAttrOrTypeRawImpl(); |
| 914 | return false; |
| 915 | }; |
| 916 | |
| 917 | bool hasCustomEncoding = emitAttrOrTypeImpl(); |
| 918 | |
| 919 | // Record the offset of this entry. |
| 920 | uint64_t curOffset = attrTypeEmitter.size(); |
| 921 | offsetEmitter.emitVarIntWithFlag(value: curOffset - prevOffset, flag: hasCustomEncoding, |
| 922 | desc: "attr/type offset" ); |
| 923 | prevOffset = curOffset; |
| 924 | }; |
| 925 | |
| 926 | // Emit the attribute and type entries for each dialect. |
| 927 | writeDialectGrouping(emitter&: offsetEmitter, entries: numberingState.getAttributes(), |
| 928 | callback&: emitAttrOrType); |
| 929 | writeDialectGrouping(emitter&: offsetEmitter, entries: numberingState.getTypes(), |
| 930 | callback&: emitAttrOrType); |
| 931 | |
| 932 | // Emit the sections to the stream. |
| 933 | emitter.emitSection(code: bytecode::Section::kAttrTypeOffset, |
| 934 | emitter: std::move(offsetEmitter)); |
| 935 | emitter.emitSection(code: bytecode::Section::kAttrType, emitter: std::move(attrTypeEmitter)); |
| 936 | } |
| 937 | |
| 938 | //===----------------------------------------------------------------------===// |
| 939 | // Operations |
| 940 | //===----------------------------------------------------------------------===// |
| 941 | |
| 942 | LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter, |
| 943 | Block *block) { |
| 944 | ArrayRef<BlockArgument> args = block->getArguments(); |
| 945 | bool hasArgs = !args.empty(); |
| 946 | |
| 947 | // Emit the number of operations in this block, and if it has arguments. We |
| 948 | // use the low bit of the operation count to indicate if the block has |
| 949 | // arguments. |
| 950 | unsigned numOps = numberingState.getOperationCount(block); |
| 951 | emitter.emitVarIntWithFlag(value: numOps, flag: hasArgs, desc: "block num ops" ); |
| 952 | |
| 953 | // Emit the arguments of the block. |
| 954 | if (hasArgs) { |
| 955 | emitter.emitVarInt(value: args.size(), desc: "block args count" ); |
| 956 | for (BlockArgument arg : args) { |
| 957 | Location argLoc = arg.getLoc(); |
| 958 | if (config.bytecodeVersion >= bytecode::kElideUnknownBlockArgLocation) { |
| 959 | emitter.emitVarIntWithFlag(value: numberingState.getNumber(type: arg.getType()), |
| 960 | flag: !isa<UnknownLoc>(Val: argLoc), desc: "block arg type" ); |
| 961 | if (!isa<UnknownLoc>(Val: argLoc)) |
| 962 | emitter.emitVarInt(value: numberingState.getNumber(attr: argLoc), |
| 963 | desc: "block arg location" ); |
| 964 | } else { |
| 965 | emitter.emitVarInt(value: numberingState.getNumber(type: arg.getType()), |
| 966 | desc: "block arg type" ); |
| 967 | emitter.emitVarInt(value: numberingState.getNumber(attr: argLoc), |
| 968 | desc: "block arg location" ); |
| 969 | } |
| 970 | } |
| 971 | if (config.bytecodeVersion >= bytecode::kUseListOrdering) { |
| 972 | uint64_t maskOffset = emitter.size(); |
| 973 | uint8_t encodingMask = 0; |
| 974 | emitter.emitByte(byte: 0, desc: "use-list separator" ); |
| 975 | writeUseListOrders(emitter, opEncodingMask&: encodingMask, range: args); |
| 976 | if (encodingMask) |
| 977 | emitter.patchByte(offset: maskOffset, value: encodingMask, desc: "block patch encoding" ); |
| 978 | } |
| 979 | } |
| 980 | |
| 981 | // Emit the operations within the block. |
| 982 | for (Operation &op : *block) |
| 983 | if (failed(Result: writeOp(emitter, op: &op))) |
| 984 | return failure(); |
| 985 | return success(); |
| 986 | } |
| 987 | |
| 988 | LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { |
| 989 | emitter.emitVarInt(value: numberingState.getNumber(opName: op->getName()), desc: "op name ID" ); |
| 990 | |
| 991 | // Emit a mask for the operation components. We need to fill this in later |
| 992 | // (when we actually know what needs to be emitted), so emit a placeholder for |
| 993 | // now. |
| 994 | uint64_t maskOffset = emitter.size(); |
| 995 | uint8_t opEncodingMask = 0; |
| 996 | emitter.emitByte(byte: 0, desc: "op separator" ); |
| 997 | |
| 998 | // Emit the location for this operation. |
| 999 | emitter.emitVarInt(value: numberingState.getNumber(attr: op->getLoc()), desc: "op location" ); |
| 1000 | |
| 1001 | // Emit the attributes of this operation. |
| 1002 | DictionaryAttr attrs = op->getDiscardableAttrDictionary(); |
| 1003 | // Allow deployment to version <kNativePropertiesEncoding by merging inherent |
| 1004 | // attribute with the discardable ones. We should fail if there are any |
| 1005 | // conflicts. When properties are not used by the op, also store everything as |
| 1006 | // attributes. |
| 1007 | if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding || |
| 1008 | !op->getPropertiesStorage()) { |
| 1009 | attrs = op->getAttrDictionary(); |
| 1010 | } |
| 1011 | if (!attrs.empty()) { |
| 1012 | opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs; |
| 1013 | emitter.emitVarInt(value: numberingState.getNumber(attrs), desc: "op attrs count" ); |
| 1014 | } |
| 1015 | |
| 1016 | // Emit the properties of this operation, for now we still support deployment |
| 1017 | // to version <kNativePropertiesEncoding. |
| 1018 | if (config.bytecodeVersion >= bytecode::kNativePropertiesEncoding) { |
| 1019 | std::optional<ssize_t> propertiesId = propertiesSection.emit(op); |
| 1020 | if (propertiesId.has_value()) { |
| 1021 | opEncodingMask |= bytecode::OpEncodingMask::kHasProperties; |
| 1022 | emitter.emitVarInt(value: *propertiesId, desc: "op properties ID" ); |
| 1023 | } |
| 1024 | } |
| 1025 | |
| 1026 | // Emit the result types of the operation. |
| 1027 | if (unsigned numResults = op->getNumResults()) { |
| 1028 | opEncodingMask |= bytecode::OpEncodingMask::kHasResults; |
| 1029 | emitter.emitVarInt(value: numResults, desc: "op results count" ); |
| 1030 | for (Type type : op->getResultTypes()) |
| 1031 | emitter.emitVarInt(value: numberingState.getNumber(type), desc: "op result type" ); |
| 1032 | } |
| 1033 | |
| 1034 | // Emit the operands of the operation. |
| 1035 | if (unsigned numOperands = op->getNumOperands()) { |
| 1036 | opEncodingMask |= bytecode::OpEncodingMask::kHasOperands; |
| 1037 | emitter.emitVarInt(value: numOperands, desc: "op operands count" ); |
| 1038 | for (Value operand : op->getOperands()) |
| 1039 | emitter.emitVarInt(value: numberingState.getNumber(value: operand), desc: "op operand types" ); |
| 1040 | } |
| 1041 | |
| 1042 | // Emit the successors of the operation. |
| 1043 | if (unsigned numSuccessors = op->getNumSuccessors()) { |
| 1044 | opEncodingMask |= bytecode::OpEncodingMask::kHasSuccessors; |
| 1045 | emitter.emitVarInt(value: numSuccessors, desc: "op successors count" ); |
| 1046 | for (Block *successor : op->getSuccessors()) |
| 1047 | emitter.emitVarInt(value: numberingState.getNumber(block: successor), desc: "op successor" ); |
| 1048 | } |
| 1049 | |
| 1050 | // Emit the use-list orders to bytecode, so we can reconstruct the same order |
| 1051 | // at parsing. |
| 1052 | if (config.bytecodeVersion >= bytecode::kUseListOrdering) |
| 1053 | writeUseListOrders(emitter, opEncodingMask, range: ValueRange(op->getResults())); |
| 1054 | |
| 1055 | // Check for regions. |
| 1056 | unsigned numRegions = op->getNumRegions(); |
| 1057 | if (numRegions) |
| 1058 | opEncodingMask |= bytecode::OpEncodingMask::kHasInlineRegions; |
| 1059 | |
| 1060 | // Update the mask for the operation. |
| 1061 | emitter.patchByte(offset: maskOffset, value: opEncodingMask, desc: "op encoding mask" ); |
| 1062 | |
| 1063 | // With the mask emitted, we can now emit the regions of the operation. We do |
| 1064 | // this after mask emission to avoid offset complications that may arise by |
| 1065 | // emitting the regions first (e.g. if the regions are huge, backpatching the |
| 1066 | // op encoding mask is more annoying). |
| 1067 | if (numRegions) { |
| 1068 | bool isIsolatedFromAbove = numberingState.isIsolatedFromAbove(op); |
| 1069 | emitter.emitVarIntWithFlag(value: numRegions, flag: isIsolatedFromAbove, |
| 1070 | desc: "op regions count" ); |
| 1071 | |
| 1072 | // If the region is not isolated from above, or we are emitting bytecode |
| 1073 | // targeting version <kLazyLoading, we don't use a section. |
| 1074 | if (isIsolatedFromAbove && |
| 1075 | config.bytecodeVersion >= bytecode::kLazyLoading) { |
| 1076 | EncodingEmitter regionEmitter; |
| 1077 | if (failed(Result: writeRegions(emitter&: regionEmitter, regions: op->getRegions()))) |
| 1078 | return failure(); |
| 1079 | emitter.emitSection(code: bytecode::Section::kIR, emitter: std::move(regionEmitter)); |
| 1080 | |
| 1081 | } else if (failed(Result: writeRegions(emitter, regions: op->getRegions()))) { |
| 1082 | return failure(); |
| 1083 | } |
| 1084 | } |
| 1085 | return success(); |
| 1086 | } |
| 1087 | |
| 1088 | void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter, |
| 1089 | uint8_t &opEncodingMask, |
| 1090 | ValueRange range) { |
| 1091 | // Loop over the results and store the use-list order per result index. |
| 1092 | DenseMap<unsigned, llvm::SmallVector<unsigned>> map; |
| 1093 | for (auto item : llvm::enumerate(First&: range)) { |
| 1094 | auto value = item.value(); |
| 1095 | // No need to store a custom use-list order if the result does not have |
| 1096 | // multiple uses. |
| 1097 | if (value.use_empty() || value.hasOneUse()) |
| 1098 | continue; |
| 1099 | |
| 1100 | // For each result, assemble the list of pairs (use-list-index, |
| 1101 | // global-value-index). While doing so, detect if the global-value-index is |
| 1102 | // already ordered with respect to the use-list-index. |
| 1103 | bool alreadyOrdered = true; |
| 1104 | auto &firstUse = *value.use_begin(); |
| 1105 | uint64_t prevID = bytecode::getUseID( |
| 1106 | val&: firstUse, ownerID: numberingState.getNumber(op: firstUse.getOwner())); |
| 1107 | llvm::SmallVector<std::pair<unsigned, uint64_t>> useListPairs( |
| 1108 | {{0, prevID}}); |
| 1109 | |
| 1110 | for (auto use : llvm::drop_begin(RangeOrContainer: llvm::enumerate(First: value.getUses()))) { |
| 1111 | uint64_t currentID = bytecode::getUseID( |
| 1112 | val&: use.value(), ownerID: numberingState.getNumber(op: use.value().getOwner())); |
| 1113 | // The use-list order achieved when building the IR at parsing always |
| 1114 | // pushes new uses on front. Hence, if the order by unique ID is |
| 1115 | // monotonically decreasing, a roundtrip to bytecode preserves such order. |
| 1116 | alreadyOrdered &= (prevID > currentID); |
| 1117 | useListPairs.push_back(Elt: {use.index(), currentID}); |
| 1118 | prevID = currentID; |
| 1119 | } |
| 1120 | |
| 1121 | // Do not emit if the order is already sorted. |
| 1122 | if (alreadyOrdered) |
| 1123 | continue; |
| 1124 | |
| 1125 | // Sort the use indices by the unique ID indices in descending order. |
| 1126 | std::sort( |
| 1127 | first: useListPairs.begin(), last: useListPairs.end(), |
| 1128 | comp: [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); |
| 1129 | |
| 1130 | map.try_emplace(Key: item.index(), Args: llvm::map_range(C&: useListPairs, F: [](auto elem) { |
| 1131 | return elem.first; |
| 1132 | })); |
| 1133 | } |
| 1134 | |
| 1135 | if (map.empty()) |
| 1136 | return; |
| 1137 | |
| 1138 | opEncodingMask |= bytecode::OpEncodingMask::kHasUseListOrders; |
| 1139 | // Emit the number of results that have a custom use-list order if the number |
| 1140 | // of results is greater than one. |
| 1141 | if (range.size() != 1) { |
| 1142 | emitter.emitVarInt(value: map.size(), desc: "custom use-list size" ); |
| 1143 | } |
| 1144 | |
| 1145 | for (const auto &item : map) { |
| 1146 | auto resultIdx = item.getFirst(); |
| 1147 | auto useListOrder = item.getSecond(); |
| 1148 | |
| 1149 | // Compute the number of uses that are actually shuffled. If those are less |
| 1150 | // than half of the total uses, encoding the index pair `(src, dst)` is more |
| 1151 | // space efficient. |
| 1152 | size_t shuffledElements = |
| 1153 | llvm::count_if(Range: llvm::enumerate(First&: useListOrder), |
| 1154 | P: [](auto item) { return item.index() != item.value(); }); |
| 1155 | bool indexPairEncoding = shuffledElements < (useListOrder.size() / 2); |
| 1156 | |
| 1157 | // For single result, we don't need to store the result index. |
| 1158 | if (range.size() != 1) |
| 1159 | emitter.emitVarInt(value: resultIdx, desc: "use-list result index" ); |
| 1160 | |
| 1161 | if (indexPairEncoding) { |
| 1162 | emitter.emitVarIntWithFlag(value: shuffledElements * 2, flag: indexPairEncoding, |
| 1163 | desc: "use-list index pair size" ); |
| 1164 | for (auto pair : llvm::enumerate(First&: useListOrder)) { |
| 1165 | if (pair.index() != pair.value()) { |
| 1166 | emitter.emitVarInt(value: pair.value(), desc: "use-list index pair first" ); |
| 1167 | emitter.emitVarInt(value: pair.index(), desc: "use-list index pair second" ); |
| 1168 | } |
| 1169 | } |
| 1170 | } else { |
| 1171 | emitter.emitVarIntWithFlag(value: useListOrder.size(), flag: indexPairEncoding, |
| 1172 | desc: "use-list size" ); |
| 1173 | for (const auto &index : useListOrder) |
| 1174 | emitter.emitVarInt(value: index, desc: "use-list order" ); |
| 1175 | } |
| 1176 | } |
| 1177 | } |
| 1178 | |
| 1179 | LogicalResult BytecodeWriter::writeRegion(EncodingEmitter &emitter, |
| 1180 | Region *region) { |
| 1181 | // If the region is empty, we only need to emit the number of blocks (which is |
| 1182 | // zero). |
| 1183 | if (region->empty()) { |
| 1184 | emitter.emitVarInt(/*numBlocks*/ value: 0, desc: "region block count empty" ); |
| 1185 | return success(); |
| 1186 | } |
| 1187 | |
| 1188 | // Emit the number of blocks and values within the region. |
| 1189 | unsigned numBlocks, numValues; |
| 1190 | std::tie(args&: numBlocks, args&: numValues) = numberingState.getBlockValueCount(region); |
| 1191 | emitter.emitVarInt(value: numBlocks, desc: "region block count" ); |
| 1192 | emitter.emitVarInt(value: numValues, desc: "region value count" ); |
| 1193 | |
| 1194 | // Emit the blocks within the region. |
| 1195 | for (Block &block : *region) |
| 1196 | if (failed(Result: writeBlock(emitter, block: &block))) |
| 1197 | return failure(); |
| 1198 | return success(); |
| 1199 | } |
| 1200 | |
| 1201 | LogicalResult BytecodeWriter::writeIRSection(EncodingEmitter &emitter, |
| 1202 | Operation *op) { |
| 1203 | EncodingEmitter irEmitter; |
| 1204 | |
| 1205 | // Write the IR section the same way as a block with no arguments. Note that |
| 1206 | // the low-bit of the operation count for a block is used to indicate if the |
| 1207 | // block has arguments, which in this case is always false. |
| 1208 | irEmitter.emitVarIntWithFlag(/*numOps*/ value: 1, /*hasArgs*/ flag: false, desc: "ir section" ); |
| 1209 | |
| 1210 | // Emit the operations. |
| 1211 | if (failed(Result: writeOp(emitter&: irEmitter, op))) |
| 1212 | return failure(); |
| 1213 | |
| 1214 | emitter.emitSection(code: bytecode::Section::kIR, emitter: std::move(irEmitter)); |
| 1215 | return success(); |
| 1216 | } |
| 1217 | |
| 1218 | //===----------------------------------------------------------------------===// |
| 1219 | // Resources |
| 1220 | //===----------------------------------------------------------------------===// |
| 1221 | |
| 1222 | namespace { |
| 1223 | /// This class represents a resource builder implementation for the MLIR |
| 1224 | /// bytecode format. |
| 1225 | class ResourceBuilder : public AsmResourceBuilder { |
| 1226 | public: |
| 1227 | using PostProcessFn = function_ref<void(StringRef, AsmResourceEntryKind)>; |
| 1228 | |
| 1229 | ResourceBuilder(EncodingEmitter &emitter, StringSectionBuilder &stringSection, |
| 1230 | PostProcessFn postProcessFn, bool shouldElideData) |
| 1231 | : emitter(emitter), stringSection(stringSection), |
| 1232 | postProcessFn(postProcessFn), shouldElideData(shouldElideData) {} |
| 1233 | ~ResourceBuilder() override = default; |
| 1234 | |
| 1235 | void buildBlob(StringRef key, ArrayRef<char> data, |
| 1236 | uint32_t dataAlignment) final { |
| 1237 | if (!shouldElideData) |
| 1238 | emitter.emitOwnedBlobAndAlignment(data, alignment: dataAlignment, desc: "resource blob" ); |
| 1239 | postProcessFn(key, AsmResourceEntryKind::Blob); |
| 1240 | } |
| 1241 | void buildBool(StringRef key, bool data) final { |
| 1242 | if (!shouldElideData) |
| 1243 | emitter.emitByte(byte: data, desc: "resource bool" ); |
| 1244 | postProcessFn(key, AsmResourceEntryKind::Bool); |
| 1245 | } |
| 1246 | void buildString(StringRef key, StringRef data) final { |
| 1247 | if (!shouldElideData) |
| 1248 | emitter.emitVarInt(value: stringSection.insert(str: data), desc: "resource string" ); |
| 1249 | postProcessFn(key, AsmResourceEntryKind::String); |
| 1250 | } |
| 1251 | |
| 1252 | private: |
| 1253 | EncodingEmitter &emitter; |
| 1254 | StringSectionBuilder &stringSection; |
| 1255 | PostProcessFn postProcessFn; |
| 1256 | bool shouldElideData = false; |
| 1257 | }; |
| 1258 | } // namespace |
| 1259 | |
| 1260 | void BytecodeWriter::writeResourceSection(Operation *op, |
| 1261 | EncodingEmitter &emitter) { |
| 1262 | EncodingEmitter resourceEmitter; |
| 1263 | EncodingEmitter resourceOffsetEmitter; |
| 1264 | uint64_t prevOffset = 0; |
| 1265 | SmallVector<std::tuple<StringRef, AsmResourceEntryKind, uint64_t>> |
| 1266 | curResourceEntries; |
| 1267 | |
| 1268 | // Functor used to process the offset for a resource of `kind` defined by |
| 1269 | // 'key'. |
| 1270 | auto appendResourceOffset = [&](StringRef key, AsmResourceEntryKind kind) { |
| 1271 | uint64_t curOffset = resourceEmitter.size(); |
| 1272 | curResourceEntries.emplace_back(Args&: key, Args&: kind, Args: curOffset - prevOffset); |
| 1273 | prevOffset = curOffset; |
| 1274 | }; |
| 1275 | |
| 1276 | // Functor used to emit a resource group defined by 'key'. |
| 1277 | auto emitResourceGroup = [&](uint64_t key) { |
| 1278 | resourceOffsetEmitter.emitVarInt(value: key, desc: "resource group key" ); |
| 1279 | resourceOffsetEmitter.emitVarInt(value: curResourceEntries.size(), |
| 1280 | desc: "resource group size" ); |
| 1281 | for (auto [key, kind, size] : curResourceEntries) { |
| 1282 | resourceOffsetEmitter.emitVarInt(value: stringSection.insert(str: key), |
| 1283 | desc: "resource key" ); |
| 1284 | resourceOffsetEmitter.emitVarInt(value: size, desc: "resource size" ); |
| 1285 | resourceOffsetEmitter.emitByte(byte: kind, desc: "resource kind" ); |
| 1286 | } |
| 1287 | }; |
| 1288 | |
| 1289 | // Builder used to emit resources. |
| 1290 | ResourceBuilder entryBuilder(resourceEmitter, stringSection, |
| 1291 | appendResourceOffset, |
| 1292 | config.shouldElideResourceData); |
| 1293 | |
| 1294 | // Emit the external resource entries. |
| 1295 | resourceOffsetEmitter.emitVarInt(value: config.externalResourcePrinters.size(), |
| 1296 | desc: "external resource printer count" ); |
| 1297 | for (const auto &printer : config.externalResourcePrinters) { |
| 1298 | curResourceEntries.clear(); |
| 1299 | printer->buildResources(op, builder&: entryBuilder); |
| 1300 | emitResourceGroup(stringSection.insert(str: printer->getName())); |
| 1301 | } |
| 1302 | |
| 1303 | // Emit the dialect resource entries. |
| 1304 | for (DialectNumbering &dialect : numberingState.getDialects()) { |
| 1305 | if (!dialect.asmInterface) |
| 1306 | continue; |
| 1307 | curResourceEntries.clear(); |
| 1308 | dialect.asmInterface->buildResources(op, referencedResources: dialect.resources, builder&: entryBuilder); |
| 1309 | |
| 1310 | // Emit the declaration resources for this dialect, these didn't get emitted |
| 1311 | // by the interface. These resources don't have data attached, so just use a |
| 1312 | // "blob" kind as a placeholder. |
| 1313 | for (const auto &resource : dialect.resourceMap) |
| 1314 | if (resource.second->isDeclaration) |
| 1315 | appendResourceOffset(resource.first, AsmResourceEntryKind::Blob); |
| 1316 | |
| 1317 | // Emit the resource group for this dialect. |
| 1318 | if (!curResourceEntries.empty()) |
| 1319 | emitResourceGroup(dialect.number); |
| 1320 | } |
| 1321 | |
| 1322 | // If we didn't emit any resource groups, elide the resource sections. |
| 1323 | if (resourceOffsetEmitter.size() == 0) |
| 1324 | return; |
| 1325 | |
| 1326 | emitter.emitSection(code: bytecode::Section::kResourceOffset, |
| 1327 | emitter: std::move(resourceOffsetEmitter)); |
| 1328 | emitter.emitSection(code: bytecode::Section::kResource, emitter: std::move(resourceEmitter)); |
| 1329 | } |
| 1330 | |
| 1331 | //===----------------------------------------------------------------------===// |
| 1332 | // Strings |
| 1333 | //===----------------------------------------------------------------------===// |
| 1334 | |
| 1335 | void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) { |
| 1336 | EncodingEmitter stringEmitter; |
| 1337 | stringSection.write(emitter&: stringEmitter); |
| 1338 | emitter.emitSection(code: bytecode::Section::kString, emitter: std::move(stringEmitter)); |
| 1339 | } |
| 1340 | |
| 1341 | //===----------------------------------------------------------------------===// |
| 1342 | // Properties |
| 1343 | //===----------------------------------------------------------------------===// |
| 1344 | |
| 1345 | void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) { |
| 1346 | EncodingEmitter propertiesEmitter; |
| 1347 | propertiesSection.write(emitter&: propertiesEmitter); |
| 1348 | emitter.emitSection(code: bytecode::Section::kProperties, |
| 1349 | emitter: std::move(propertiesEmitter)); |
| 1350 | } |
| 1351 | |
| 1352 | //===----------------------------------------------------------------------===// |
| 1353 | // Entry Points |
| 1354 | //===----------------------------------------------------------------------===// |
| 1355 | |
| 1356 | LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os, |
| 1357 | const BytecodeWriterConfig &config) { |
| 1358 | BytecodeWriter writer(op, config); |
| 1359 | return writer.write(rootOp: op, os); |
| 1360 | } |
| 1361 | |