1//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Bytecode/BytecodeReader.h"
10#include "mlir/AsmParser/AsmParser.h"
11#include "mlir/Bytecode/BytecodeImplementation.h"
12#include "mlir/Bytecode/BytecodeOpInterface.h"
13#include "mlir/Bytecode/Encoding.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Verifier.h"
18#include "mlir/IR/Visitors.h"
19#include "mlir/Support/LLVM.h"
20#include "mlir/Support/LogicalResult.h"
21#include "llvm/ADT/ArrayRef.h"
22#include "llvm/ADT/ScopeExit.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/StringRef.h"
25#include "llvm/Support/Endian.h"
26#include "llvm/Support/MemoryBufferRef.h"
27#include "llvm/Support/SourceMgr.h"
28
29#include <cstddef>
30#include <list>
31#include <memory>
32#include <numeric>
33#include <optional>
34
35#define DEBUG_TYPE "mlir-bytecode-reader"
36
37using namespace mlir;
38
39/// Stringify the given section ID.
40static std::string toString(bytecode::Section::ID sectionID) {
41 switch (sectionID) {
42 case bytecode::Section::kString:
43 return "String (0)";
44 case bytecode::Section::kDialect:
45 return "Dialect (1)";
46 case bytecode::Section::kAttrType:
47 return "AttrType (2)";
48 case bytecode::Section::kAttrTypeOffset:
49 return "AttrTypeOffset (3)";
50 case bytecode::Section::kIR:
51 return "IR (4)";
52 case bytecode::Section::kResource:
53 return "Resource (5)";
54 case bytecode::Section::kResourceOffset:
55 return "ResourceOffset (6)";
56 case bytecode::Section::kDialectVersions:
57 return "DialectVersions (7)";
58 case bytecode::Section::kProperties:
59 return "Properties (8)";
60 default:
61 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
62 }
63}
64
65/// Returns true if the given top-level section ID is optional.
66static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
67 switch (sectionID) {
68 case bytecode::Section::kString:
69 case bytecode::Section::kDialect:
70 case bytecode::Section::kAttrType:
71 case bytecode::Section::kAttrTypeOffset:
72 case bytecode::Section::kIR:
73 return false;
74 case bytecode::Section::kResource:
75 case bytecode::Section::kResourceOffset:
76 case bytecode::Section::kDialectVersions:
77 return true;
78 case bytecode::Section::kProperties:
79 return version < bytecode::kNativePropertiesEncoding;
80 default:
81 llvm_unreachable("unknown section ID");
82 }
83}
84
85//===----------------------------------------------------------------------===//
86// EncodingReader
87//===----------------------------------------------------------------------===//
88
89namespace {
90class EncodingReader {
91public:
92 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
93 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
94 explicit EncodingReader(StringRef contents, Location fileLoc)
95 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
96 contents.size()},
97 fileLoc) {}
98
99 /// Returns true if the entire section has been read.
100 bool empty() const { return dataIt == buffer.end(); }
101
102 /// Returns the remaining size of the bytecode.
103 size_t size() const { return buffer.end() - dataIt; }
104
105 /// Align the current reader position to the specified alignment.
106 LogicalResult alignTo(unsigned alignment) {
107 if (!llvm::isPowerOf2_32(Value: alignment))
108 return emitError(args: "expected alignment to be a power-of-two");
109
110 auto isUnaligned = [&](const uint8_t *ptr) {
111 return ((uintptr_t)ptr & (alignment - 1)) != 0;
112 };
113
114 // Shift the reader position to the next alignment boundary.
115 while (isUnaligned(dataIt)) {
116 uint8_t padding;
117 if (failed(result: parseByte(value&: padding)))
118 return failure();
119 if (padding != bytecode::kAlignmentByte) {
120 return emitError(args: "expected alignment byte (0xCB), but got: '0x" +
121 llvm::utohexstr(X: padding) + "'");
122 }
123 }
124
125 // Ensure the data iterator is now aligned. This case is unlikely because we
126 // *just* went through the effort to align the data iterator.
127 if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
128 return emitError(args: "expected data iterator aligned to ", args&: alignment,
129 args: ", but got pointer: '0x" +
130 llvm::utohexstr(X: (uintptr_t)dataIt) + "'");
131 }
132
133 return success();
134 }
135
136 /// Emit an error using the given arguments.
137 template <typename... Args>
138 InFlightDiagnostic emitError(Args &&...args) const {
139 return ::emitError(loc: fileLoc).append(std::forward<Args>(args)...);
140 }
141 InFlightDiagnostic emitError() const { return ::emitError(loc: fileLoc); }
142
143 /// Parse a single byte from the stream.
144 template <typename T>
145 LogicalResult parseByte(T &value) {
146 if (empty())
147 return emitError(args: "attempting to parse a byte at the end of the bytecode");
148 value = static_cast<T>(*dataIt++);
149 return success();
150 }
151 /// Parse a range of bytes of 'length' into the given result.
152 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
153 if (length > size()) {
154 return emitError(args: "attempting to parse ", args&: length, args: " bytes when only ",
155 args: size(), args: " remain");
156 }
157 result = {dataIt, length};
158 dataIt += length;
159 return success();
160 }
161 /// Parse a range of bytes of 'length' into the given result, which can be
162 /// assumed to be large enough to hold `length`.
163 LogicalResult parseBytes(size_t length, uint8_t *result) {
164 if (length > size()) {
165 return emitError(args: "attempting to parse ", args&: length, args: " bytes when only ",
166 args: size(), args: " remain");
167 }
168 memcpy(dest: result, src: dataIt, n: length);
169 dataIt += length;
170 return success();
171 }
172
173 /// Parse an aligned blob of data, where the alignment was encoded alongside
174 /// the data.
175 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
176 uint64_t &alignment) {
177 uint64_t dataSize;
178 if (failed(result: parseVarInt(result&: alignment)) || failed(result: parseVarInt(result&: dataSize)) ||
179 failed(result: alignTo(alignment)))
180 return failure();
181 return parseBytes(length: dataSize, result&: data);
182 }
183
184 /// Parse a variable length encoded integer from the byte stream. The first
185 /// encoded byte contains a prefix in the low bits indicating the encoded
186 /// length of the value. This length prefix is a bit sequence of '0's followed
187 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
188 /// (not including the prefix byte). All remaining bits in the first byte,
189 /// along with all of the bits in additional bytes, provide the value of the
190 /// integer encoded in little-endian order.
191 LogicalResult parseVarInt(uint64_t &result) {
192 // Parse the first byte of the encoding, which contains the length prefix.
193 if (failed(result: parseByte(value&: result)))
194 return failure();
195
196 // Handle the overwhelmingly common case where the value is stored in a
197 // single byte. In this case, the first bit is the `1` marker bit.
198 if (LLVM_LIKELY(result & 1)) {
199 result >>= 1;
200 return success();
201 }
202
203 // Handle the overwhelming uncommon case where the value required all 8
204 // bytes (i.e. a really really big number). In this case, the marker byte is
205 // all zeros: `00000000`.
206 if (LLVM_UNLIKELY(result == 0)) {
207 llvm::support::ulittle64_t resultLE;
208 if (failed(result: parseBytes(length: sizeof(resultLE),
209 result: reinterpret_cast<uint8_t *>(&resultLE))))
210 return failure();
211 result = resultLE;
212 return success();
213 }
214 return parseMultiByteVarInt(result);
215 }
216
217 /// Parse a signed variable length encoded integer from the byte stream. A
218 /// signed varint is encoded as a normal varint with zigzag encoding applied,
219 /// i.e. the low bit of the value is used to indicate the sign.
220 LogicalResult parseSignedVarInt(uint64_t &result) {
221 if (failed(result: parseVarInt(result)))
222 return failure();
223 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
224 result = (result >> 1) ^ (~(result & 1) + 1);
225 return success();
226 }
227
228 /// Parse a variable length encoded integer whose low bit is used to encode an
229 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
230 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
231 if (failed(result: parseVarInt(result)))
232 return failure();
233 flag = result & 1;
234 result >>= 1;
235 return success();
236 }
237
238 /// Skip the first `length` bytes within the reader.
239 LogicalResult skipBytes(size_t length) {
240 if (length > size()) {
241 return emitError(args: "attempting to skip ", args&: length, args: " bytes when only ",
242 args: size(), args: " remain");
243 }
244 dataIt += length;
245 return success();
246 }
247
248 /// Parse a null-terminated string into `result` (without including the NUL
249 /// terminator).
250 LogicalResult parseNullTerminatedString(StringRef &result) {
251 const char *startIt = (const char *)dataIt;
252 const char *nulIt = (const char *)memchr(s: startIt, c: 0, n: size());
253 if (!nulIt)
254 return emitError(
255 args: "malformed null-terminated string, no null character found");
256
257 result = StringRef(startIt, nulIt - startIt);
258 dataIt = (const uint8_t *)nulIt + 1;
259 return success();
260 }
261
262 /// Parse a section header, placing the kind of section in `sectionID` and the
263 /// contents of the section in `sectionData`.
264 LogicalResult parseSection(bytecode::Section::ID &sectionID,
265 ArrayRef<uint8_t> &sectionData) {
266 uint8_t sectionIDAndHasAlignment;
267 uint64_t length;
268 if (failed(result: parseByte(value&: sectionIDAndHasAlignment)) ||
269 failed(result: parseVarInt(result&: length)))
270 return failure();
271
272 // Extract the section ID and whether the section is aligned. The high bit
273 // of the ID is the alignment flag.
274 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
275 0b01111111);
276 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
277
278 // Check that the section is actually valid before trying to process its
279 // data.
280 if (sectionID >= bytecode::Section::kNumSections)
281 return emitError(args: "invalid section ID: ", args: unsigned(sectionID));
282
283 // Process the section alignment if present.
284 if (hasAlignment) {
285 uint64_t alignment;
286 if (failed(result: parseVarInt(result&: alignment)) || failed(result: alignTo(alignment)))
287 return failure();
288 }
289
290 // Parse the actual section data.
291 return parseBytes(length: static_cast<size_t>(length), result&: sectionData);
292 }
293
294 Location getLoc() const { return fileLoc; }
295
296private:
297 /// Parse a variable length encoded integer from the byte stream. This method
298 /// is a fallback when the number of bytes used to encode the value is greater
299 /// than 1, but less than the max (9). The provided `result` value can be
300 /// assumed to already contain the first byte of the value.
301 /// NOTE: This method is marked noinline to avoid pessimizing the common case
302 /// of single byte encoding.
303 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
304 // Count the number of trailing zeros in the marker byte, this indicates the
305 // number of trailing bytes that are part of the value. We use `uint32_t`
306 // here because we only care about the first byte, and so that be actually
307 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
308 // implementation).
309 uint32_t numBytes = llvm::countr_zero<uint32_t>(Val: result);
310 assert(numBytes > 0 && numBytes <= 7 &&
311 "unexpected number of trailing zeros in varint encoding");
312
313 // Parse in the remaining bytes of the value.
314 llvm::support::ulittle64_t resultLE(result);
315 if (failed(
316 result: parseBytes(length: numBytes, result: reinterpret_cast<uint8_t *>(&resultLE) + 1)))
317 return failure();
318
319 // Shift out the low-order bits that were used to mark how the value was
320 // encoded.
321 result = resultLE >> (numBytes + 1);
322 return success();
323 }
324
325 /// The bytecode buffer.
326 ArrayRef<uint8_t> buffer;
327
328 /// The current iterator within the 'buffer'.
329 const uint8_t *dataIt;
330
331 /// A location for the bytecode used to report errors.
332 Location fileLoc;
333};
334} // namespace
335
336/// Resolve an index into the given entry list. `entry` may either be a
337/// reference, in which case it is assigned to the corresponding value in
338/// `entries`, or a pointer, in which case it is assigned to the address of the
339/// element in `entries`.
340template <typename RangeT, typename T>
341static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
342 uint64_t index, T &entry,
343 StringRef entryStr) {
344 if (index >= entries.size())
345 return reader.emitError(args: "invalid ", args&: entryStr, args: " index: ", args&: index);
346
347 // If the provided entry is a pointer, resolve to the address of the entry.
348 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
349 entry = entries[index];
350 else
351 entry = &entries[index];
352 return success();
353}
354
355/// Parse and resolve an index into the given entry list.
356template <typename RangeT, typename T>
357static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
358 T &entry, StringRef entryStr) {
359 uint64_t entryIdx;
360 if (failed(result: reader.parseVarInt(result&: entryIdx)))
361 return failure();
362 return resolveEntry(reader, entries, entryIdx, entry, entryStr);
363}
364
365//===----------------------------------------------------------------------===//
366// StringSectionReader
367//===----------------------------------------------------------------------===//
368
369namespace {
370/// This class is used to read references to the string section from the
371/// bytecode.
372class StringSectionReader {
373public:
374 /// Initialize the string section reader with the given section data.
375 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
376
377 /// Parse a shared string from the string section. The shared string is
378 /// encoded using an index to a corresponding string in the string section.
379 LogicalResult parseString(EncodingReader &reader, StringRef &result) {
380 return parseEntry(reader, entries&: strings, entry&: result, entryStr: "string");
381 }
382
383 /// Parse a shared string from the string section. The shared string is
384 /// encoded using an index to a corresponding string in the string section.
385 /// This variant parses a flag compressed with the index.
386 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
387 bool &flag) {
388 uint64_t entryIdx;
389 if (failed(result: reader.parseVarIntWithFlag(result&: entryIdx, flag)))
390 return failure();
391 return parseStringAtIndex(reader, index: entryIdx, result);
392 }
393
394 /// Parse a shared string from the string section. The shared string is
395 /// encoded using an index to a corresponding string in the string section.
396 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
397 StringRef &result) {
398 return resolveEntry(reader, entries&: strings, index, entry&: result, entryStr: "string");
399 }
400
401private:
402 /// The table of strings referenced within the bytecode file.
403 SmallVector<StringRef> strings;
404};
405} // namespace
406
407LogicalResult StringSectionReader::initialize(Location fileLoc,
408 ArrayRef<uint8_t> sectionData) {
409 EncodingReader stringReader(sectionData, fileLoc);
410
411 // Parse the number of strings in the section.
412 uint64_t numStrings;
413 if (failed(result: stringReader.parseVarInt(result&: numStrings)))
414 return failure();
415 strings.resize(N: numStrings);
416
417 // Parse each of the strings. The sizes of the strings are encoded in reverse
418 // order, so that's the order we populate the table.
419 size_t stringDataEndOffset = sectionData.size();
420 for (StringRef &string : llvm::reverse(C&: strings)) {
421 uint64_t stringSize;
422 if (failed(result: stringReader.parseVarInt(result&: stringSize)))
423 return failure();
424 if (stringDataEndOffset < stringSize) {
425 return stringReader.emitError(
426 args: "string size exceeds the available data size");
427 }
428
429 // Extract the string from the data, dropping the null character.
430 size_t stringOffset = stringDataEndOffset - stringSize;
431 string = StringRef(
432 reinterpret_cast<const char *>(sectionData.data() + stringOffset),
433 stringSize - 1);
434 stringDataEndOffset = stringOffset;
435 }
436
437 // Check that the only remaining data was for the strings, i.e. the reader
438 // should be at the same offset as the first string.
439 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
440 return stringReader.emitError(args: "unexpected trailing data between the "
441 "offsets for strings and their data");
442 }
443 return success();
444}
445
446//===----------------------------------------------------------------------===//
447// BytecodeDialect
448//===----------------------------------------------------------------------===//
449
450namespace {
451class DialectReader;
452
453/// This struct represents a dialect entry within the bytecode.
454struct BytecodeDialect {
455 /// Load the dialect into the provided context if it hasn't been loaded yet.
456 /// Returns failure if the dialect couldn't be loaded *and* the provided
457 /// context does not allow unregistered dialects. The provided reader is used
458 /// for error emission if necessary.
459 LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
460
461 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
462 /// only be called after `load`.
463 Dialect *getLoadedDialect() const {
464 assert(dialect &&
465 "expected `load` to be invoked before `getLoadedDialect`");
466 return *dialect;
467 }
468
469 /// The loaded dialect entry. This field is std::nullopt if we haven't
470 /// attempted to load, nullptr if we failed to load, otherwise the loaded
471 /// dialect.
472 std::optional<Dialect *> dialect;
473
474 /// The bytecode interface of the dialect, or nullptr if the dialect does not
475 /// implement the bytecode interface. This field should only be checked if the
476 /// `dialect` field is not std::nullopt.
477 const BytecodeDialectInterface *interface = nullptr;
478
479 /// The name of the dialect.
480 StringRef name;
481
482 /// A buffer containing the encoding of the dialect version parsed.
483 ArrayRef<uint8_t> versionBuffer;
484
485 /// Lazy loaded dialect version from the handle above.
486 std::unique_ptr<DialectVersion> loadedVersion;
487};
488
489/// This struct represents an operation name entry within the bytecode.
490struct BytecodeOperationName {
491 BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
492 std::optional<bool> wasRegistered)
493 : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
494
495 /// The loaded operation name, or std::nullopt if it hasn't been processed
496 /// yet.
497 std::optional<OperationName> opName;
498
499 /// The dialect that owns this operation name.
500 BytecodeDialect *dialect;
501
502 /// The name of the operation, without the dialect prefix.
503 StringRef name;
504
505 /// Whether this operation was registered when the bytecode was produced.
506 /// This flag is populated when bytecode version >=kNativePropertiesEncoding.
507 std::optional<bool> wasRegistered;
508};
509} // namespace
510
511/// Parse a single dialect group encoded in the byte stream.
512static LogicalResult parseDialectGrouping(
513 EncodingReader &reader,
514 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
515 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
516 // Parse the dialect and the number of entries in the group.
517 std::unique_ptr<BytecodeDialect> *dialect;
518 if (failed(result: parseEntry(reader, entries&: dialects, entry&: dialect, entryStr: "dialect")))
519 return failure();
520 uint64_t numEntries;
521 if (failed(result: reader.parseVarInt(result&: numEntries)))
522 return failure();
523
524 for (uint64_t i = 0; i < numEntries; ++i)
525 if (failed(result: entryCallback(dialect->get())))
526 return failure();
527 return success();
528}
529
530//===----------------------------------------------------------------------===//
531// ResourceSectionReader
532//===----------------------------------------------------------------------===//
533
534namespace {
535/// This class is used to read the resource section from the bytecode.
536class ResourceSectionReader {
537public:
538 /// Initialize the resource section reader with the given section data.
539 LogicalResult
540 initialize(Location fileLoc, const ParserConfig &config,
541 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
542 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
543 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
544 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
545
546 /// Parse a dialect resource handle from the resource section.
547 LogicalResult parseResourceHandle(EncodingReader &reader,
548 AsmDialectResourceHandle &result) {
549 return parseEntry(reader, entries&: dialectResources, entry&: result, entryStr: "resource handle");
550 }
551
552private:
553 /// The table of dialect resources within the bytecode file.
554 SmallVector<AsmDialectResourceHandle> dialectResources;
555 llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
556};
557
558class ParsedResourceEntry : public AsmParsedResourceEntry {
559public:
560 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
561 EncodingReader &reader, StringSectionReader &stringReader,
562 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
563 : key(key), kind(kind), reader(reader), stringReader(stringReader),
564 bufferOwnerRef(bufferOwnerRef) {}
565 ~ParsedResourceEntry() override = default;
566
567 StringRef getKey() const final { return key; }
568
569 InFlightDiagnostic emitError() const final { return reader.emitError(); }
570
571 AsmResourceEntryKind getKind() const final { return kind; }
572
573 FailureOr<bool> parseAsBool() const final {
574 if (kind != AsmResourceEntryKind::Bool)
575 return emitError() << "expected a bool resource entry, but found a "
576 << toString(kind) << " entry instead";
577
578 bool value;
579 if (failed(result: reader.parseByte(value)))
580 return failure();
581 return value;
582 }
583 FailureOr<std::string> parseAsString() const final {
584 if (kind != AsmResourceEntryKind::String)
585 return emitError() << "expected a string resource entry, but found a "
586 << toString(kind) << " entry instead";
587
588 StringRef string;
589 if (failed(result: stringReader.parseString(reader, result&: string)))
590 return failure();
591 return string.str();
592 }
593
594 FailureOr<AsmResourceBlob>
595 parseAsBlob(BlobAllocatorFn allocator) const final {
596 if (kind != AsmResourceEntryKind::Blob)
597 return emitError() << "expected a blob resource entry, but found a "
598 << toString(kind) << " entry instead";
599
600 ArrayRef<uint8_t> data;
601 uint64_t alignment;
602 if (failed(result: reader.parseBlobAndAlignment(data, alignment)))
603 return failure();
604
605 // If we have an extendable reference to the buffer owner, we don't need to
606 // allocate a new buffer for the data, and can use the data directly.
607 if (bufferOwnerRef) {
608 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
609 data.size());
610
611 // Allocate an unmanager buffer which captures a reference to the owner.
612 // For now we just mark this as immutable, but in the future we should
613 // explore marking this as mutable when desired.
614 return UnmanagedAsmResourceBlob::allocateWithAlign(
615 data: charData, align: alignment,
616 deleter: [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
617 }
618
619 // Allocate memory for the blob using the provided allocator and copy the
620 // data into it.
621 AsmResourceBlob blob = allocator(data.size(), alignment);
622 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
623 blob.isMutable() &&
624 "blob allocator did not return a properly aligned address");
625 memcpy(dest: blob.getMutableData().data(), src: data.data(), n: data.size());
626 return blob;
627 }
628
629private:
630 StringRef key;
631 AsmResourceEntryKind kind;
632 EncodingReader &reader;
633 StringSectionReader &stringReader;
634 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
635};
636} // namespace
637
638template <typename T>
639static LogicalResult
640parseResourceGroup(Location fileLoc, bool allowEmpty,
641 EncodingReader &offsetReader, EncodingReader &resourceReader,
642 StringSectionReader &stringReader, T *handler,
643 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
644 function_ref<StringRef(StringRef)> remapKey = {},
645 function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
646 uint64_t numResources;
647 if (failed(result: offsetReader.parseVarInt(result&: numResources)))
648 return failure();
649
650 for (uint64_t i = 0; i < numResources; ++i) {
651 StringRef key;
652 AsmResourceEntryKind kind;
653 uint64_t resourceOffset;
654 ArrayRef<uint8_t> data;
655 if (failed(result: stringReader.parseString(reader&: offsetReader, result&: key)) ||
656 failed(result: offsetReader.parseVarInt(result&: resourceOffset)) ||
657 failed(result: offsetReader.parseByte(value&: kind)) ||
658 failed(result: resourceReader.parseBytes(length: resourceOffset, result&: data)))
659 return failure();
660
661 // Process the resource key.
662 if ((processKeyFn && failed(result: processKeyFn(key))))
663 return failure();
664
665 // If the resource data is empty and we allow it, don't error out when
666 // parsing below, just skip it.
667 if (allowEmpty && data.empty())
668 continue;
669
670 // Ignore the entry if we don't have a valid handler.
671 if (!handler)
672 continue;
673
674 // Otherwise, parse the resource value.
675 EncodingReader entryReader(data, fileLoc);
676 key = remapKey(key);
677 ParsedResourceEntry entry(key, kind, entryReader, stringReader,
678 bufferOwnerRef);
679 if (failed(handler->parseResource(entry)))
680 return failure();
681 if (!entryReader.empty()) {
682 return entryReader.emitError(
683 args: "unexpected trailing bytes in resource entry '", args&: key, args: "'");
684 }
685 }
686 return success();
687}
688
689LogicalResult ResourceSectionReader::initialize(
690 Location fileLoc, const ParserConfig &config,
691 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
692 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
693 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
694 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
695 EncodingReader resourceReader(sectionData, fileLoc);
696 EncodingReader offsetReader(offsetSectionData, fileLoc);
697
698 // Read the number of external resource providers.
699 uint64_t numExternalResourceGroups;
700 if (failed(result: offsetReader.parseVarInt(result&: numExternalResourceGroups)))
701 return failure();
702
703 // Utility functor that dispatches to `parseResourceGroup`, but implicitly
704 // provides most of the arguments.
705 auto parseGroup = [&](auto *handler, bool allowEmpty = false,
706 function_ref<LogicalResult(StringRef)> keyFn = {}) {
707 auto resolveKey = [&](StringRef key) -> StringRef {
708 auto it = dialectResourceHandleRenamingMap.find(Key: key);
709 if (it == dialectResourceHandleRenamingMap.end())
710 return "";
711 return it->second;
712 };
713
714 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
715 stringReader, handler, bufferOwnerRef, resolveKey,
716 keyFn);
717 };
718
719 // Read the external resources from the bytecode.
720 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
721 StringRef key;
722 if (failed(result: stringReader.parseString(reader&: offsetReader, result&: key)))
723 return failure();
724
725 // Get the handler for these resources.
726 // TODO: Should we require handling external resources in some scenarios?
727 AsmResourceParser *handler = config.getResourceParser(name: key);
728 if (!handler) {
729 emitWarning(loc: fileLoc) << "ignoring unknown external resources for '" << key
730 << "'";
731 }
732
733 if (failed(result: parseGroup(handler)))
734 return failure();
735 }
736
737 // Read the dialect resources from the bytecode.
738 MLIRContext *ctx = fileLoc->getContext();
739 while (!offsetReader.empty()) {
740 std::unique_ptr<BytecodeDialect> *dialect;
741 if (failed(result: parseEntry(reader&: offsetReader, entries&: dialects, entry&: dialect, entryStr: "dialect")) ||
742 failed(result: (*dialect)->load(reader: dialectReader, ctx)))
743 return failure();
744 Dialect *loadedDialect = (*dialect)->getLoadedDialect();
745 if (!loadedDialect) {
746 return resourceReader.emitError()
747 << "dialect '" << (*dialect)->name << "' is unknown";
748 }
749 const auto *handler = dyn_cast<OpAsmDialectInterface>(Val: loadedDialect);
750 if (!handler) {
751 return resourceReader.emitError()
752 << "unexpected resources for dialect '" << (*dialect)->name << "'";
753 }
754
755 // Ensure that each resource is declared before being processed.
756 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
757 FailureOr<AsmDialectResourceHandle> handle =
758 handler->declareResource(key);
759 if (failed(result: handle)) {
760 return resourceReader.emitError()
761 << "unknown 'resource' key '" << key << "' for dialect '"
762 << (*dialect)->name << "'";
763 }
764 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(handle: *handle);
765 dialectResources.push_back(Elt: *handle);
766 return success();
767 };
768
769 // Parse the resources for this dialect. We allow empty resources because we
770 // just treat these as declarations.
771 if (failed(result: parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
772 return failure();
773 }
774
775 return success();
776}
777
778//===----------------------------------------------------------------------===//
779// Attribute/Type Reader
780//===----------------------------------------------------------------------===//
781
782namespace {
783/// This class provides support for reading attribute and type entries from the
784/// bytecode. Attribute and Type entries are read lazily on demand, so we use
785/// this reader to manage when to actually parse them from the bytecode.
786class AttrTypeReader {
787 /// This class represents a single attribute or type entry.
788 template <typename T>
789 struct Entry {
790 /// The entry, or null if it hasn't been resolved yet.
791 T entry = {};
792 /// The parent dialect of this entry.
793 BytecodeDialect *dialect = nullptr;
794 /// A flag indicating if the entry was encoded using a custom encoding,
795 /// instead of using the textual assembly format.
796 bool hasCustomEncoding = false;
797 /// The raw data of this entry in the bytecode.
798 ArrayRef<uint8_t> data;
799 };
800 using AttrEntry = Entry<Attribute>;
801 using TypeEntry = Entry<Type>;
802
803public:
804 AttrTypeReader(StringSectionReader &stringReader,
805 ResourceSectionReader &resourceReader,
806 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
807 uint64_t &bytecodeVersion, Location fileLoc,
808 const ParserConfig &config)
809 : stringReader(stringReader), resourceReader(resourceReader),
810 dialectsMap(dialectsMap), fileLoc(fileLoc),
811 bytecodeVersion(bytecodeVersion), parserConfig(config) {}
812
813 /// Initialize the attribute and type information within the reader.
814 LogicalResult
815 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
816 ArrayRef<uint8_t> sectionData,
817 ArrayRef<uint8_t> offsetSectionData);
818
819 /// Resolve the attribute or type at the given index. Returns nullptr on
820 /// failure.
821 Attribute resolveAttribute(size_t index) {
822 return resolveEntry(entries&: attributes, index, entryType: "Attribute");
823 }
824 Type resolveType(size_t index) { return resolveEntry(entries&: types, index, entryType: "Type"); }
825
826 /// Parse a reference to an attribute or type using the given reader.
827 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
828 uint64_t attrIdx;
829 if (failed(result: reader.parseVarInt(result&: attrIdx)))
830 return failure();
831 result = resolveAttribute(index: attrIdx);
832 return success(isSuccess: !!result);
833 }
834 LogicalResult parseOptionalAttribute(EncodingReader &reader,
835 Attribute &result) {
836 uint64_t attrIdx;
837 bool flag;
838 if (failed(result: reader.parseVarIntWithFlag(result&: attrIdx, flag)))
839 return failure();
840 if (!flag)
841 return success();
842 result = resolveAttribute(index: attrIdx);
843 return success(isSuccess: !!result);
844 }
845
846 LogicalResult parseType(EncodingReader &reader, Type &result) {
847 uint64_t typeIdx;
848 if (failed(result: reader.parseVarInt(result&: typeIdx)))
849 return failure();
850 result = resolveType(index: typeIdx);
851 return success(isSuccess: !!result);
852 }
853
854 template <typename T>
855 LogicalResult parseAttribute(EncodingReader &reader, T &result) {
856 Attribute baseResult;
857 if (failed(result: parseAttribute(reader, result&: baseResult)))
858 return failure();
859 if ((result = dyn_cast<T>(baseResult)))
860 return success();
861 return reader.emitError("expected attribute of type: ",
862 llvm::getTypeName<T>(), ", but got: ", baseResult);
863 }
864
865private:
866 /// Resolve the given entry at `index`.
867 template <typename T>
868 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
869 StringRef entryType);
870
871 /// Parse an entry using the given reader that was encoded using the textual
872 /// assembly format.
873 template <typename T>
874 LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
875 StringRef entryType);
876
877 /// Parse an entry using the given reader that was encoded using a custom
878 /// bytecode format.
879 template <typename T>
880 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
881 StringRef entryType);
882
883 /// The string section reader used to resolve string references when parsing
884 /// custom encoded attribute/type entries.
885 StringSectionReader &stringReader;
886
887 /// The resource section reader used to resolve resource references when
888 /// parsing custom encoded attribute/type entries.
889 ResourceSectionReader &resourceReader;
890
891 /// The map of the loaded dialects used to retrieve dialect information, such
892 /// as the dialect version.
893 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
894
895 /// The set of attribute and type entries.
896 SmallVector<AttrEntry> attributes;
897 SmallVector<TypeEntry> types;
898
899 /// A location used for error emission.
900 Location fileLoc;
901
902 /// Current bytecode version being used.
903 uint64_t &bytecodeVersion;
904
905 /// Reference to the parser configuration.
906 const ParserConfig &parserConfig;
907};
908
909class DialectReader : public DialectBytecodeReader {
910public:
911 DialectReader(AttrTypeReader &attrTypeReader,
912 StringSectionReader &stringReader,
913 ResourceSectionReader &resourceReader,
914 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
915 EncodingReader &reader, uint64_t &bytecodeVersion)
916 : attrTypeReader(attrTypeReader), stringReader(stringReader),
917 resourceReader(resourceReader), dialectsMap(dialectsMap),
918 reader(reader), bytecodeVersion(bytecodeVersion) {}
919
920 InFlightDiagnostic emitError(const Twine &msg) const override {
921 return reader.emitError(args: msg);
922 }
923
924 FailureOr<const DialectVersion *>
925 getDialectVersion(StringRef dialectName) const override {
926 // First check if the dialect is available in the map.
927 auto dialectEntry = dialectsMap.find(Key: dialectName);
928 if (dialectEntry == dialectsMap.end())
929 return failure();
930 // If the dialect was found, try to load it. This will trigger reading the
931 // bytecode version from the version buffer if it wasn't already processed.
932 // Return failure if either of those two actions could not be completed.
933 if (failed(result: dialectEntry->getValue()->load(reader: *this, ctx: getLoc().getContext())) ||
934 dialectEntry->getValue()->loadedVersion == nullptr)
935 return failure();
936 return dialectEntry->getValue()->loadedVersion.get();
937 }
938
939 MLIRContext *getContext() const override { return getLoc().getContext(); }
940
941 uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
942
943 DialectReader withEncodingReader(EncodingReader &encReader) const {
944 return DialectReader(attrTypeReader, stringReader, resourceReader,
945 dialectsMap, encReader, bytecodeVersion);
946 }
947
948 Location getLoc() const { return reader.getLoc(); }
949
950 //===--------------------------------------------------------------------===//
951 // IR
952 //===--------------------------------------------------------------------===//
953
954 LogicalResult readAttribute(Attribute &result) override {
955 return attrTypeReader.parseAttribute(reader, result);
956 }
957 LogicalResult readOptionalAttribute(Attribute &result) override {
958 return attrTypeReader.parseOptionalAttribute(reader, result);
959 }
960 LogicalResult readType(Type &result) override {
961 return attrTypeReader.parseType(reader, result);
962 }
963
964 FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
965 AsmDialectResourceHandle handle;
966 if (failed(result: resourceReader.parseResourceHandle(reader, result&: handle)))
967 return failure();
968 return handle;
969 }
970
971 //===--------------------------------------------------------------------===//
972 // Primitives
973 //===--------------------------------------------------------------------===//
974
975 LogicalResult readVarInt(uint64_t &result) override {
976 return reader.parseVarInt(result);
977 }
978
979 LogicalResult readSignedVarInt(int64_t &result) override {
980 uint64_t unsignedResult;
981 if (failed(result: reader.parseSignedVarInt(result&: unsignedResult)))
982 return failure();
983 result = static_cast<int64_t>(unsignedResult);
984 return success();
985 }
986
987 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
988 // Small values are encoded using a single byte.
989 if (bitWidth <= 8) {
990 uint8_t value;
991 if (failed(result: reader.parseByte(value)))
992 return failure();
993 return APInt(bitWidth, value);
994 }
995
996 // Large values up to 64 bits are encoded using a single varint.
997 if (bitWidth <= 64) {
998 uint64_t value;
999 if (failed(result: reader.parseSignedVarInt(result&: value)))
1000 return failure();
1001 return APInt(bitWidth, value);
1002 }
1003
1004 // Otherwise, for really big values we encode the array of active words in
1005 // the value.
1006 uint64_t numActiveWords;
1007 if (failed(result: reader.parseVarInt(result&: numActiveWords)))
1008 return failure();
1009 SmallVector<uint64_t, 4> words(numActiveWords);
1010 for (uint64_t i = 0; i < numActiveWords; ++i)
1011 if (failed(result: reader.parseSignedVarInt(result&: words[i])))
1012 return failure();
1013 return APInt(bitWidth, words);
1014 }
1015
1016 FailureOr<APFloat>
1017 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
1018 FailureOr<APInt> intVal =
1019 readAPIntWithKnownWidth(bitWidth: APFloat::getSizeInBits(Sem: semantics));
1020 if (failed(result: intVal))
1021 return failure();
1022 return APFloat(semantics, *intVal);
1023 }
1024
1025 LogicalResult readString(StringRef &result) override {
1026 return stringReader.parseString(reader, result);
1027 }
1028
1029 LogicalResult readBlob(ArrayRef<char> &result) override {
1030 uint64_t dataSize;
1031 ArrayRef<uint8_t> data;
1032 if (failed(result: reader.parseVarInt(result&: dataSize)) ||
1033 failed(result: reader.parseBytes(length: dataSize, result&: data)))
1034 return failure();
1035 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()),
1036 data.size());
1037 return success();
1038 }
1039
1040 LogicalResult readBool(bool &result) override {
1041 return reader.parseByte(value&: result);
1042 }
1043
1044private:
1045 AttrTypeReader &attrTypeReader;
1046 StringSectionReader &stringReader;
1047 ResourceSectionReader &resourceReader;
1048 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1049 EncodingReader &reader;
1050 uint64_t &bytecodeVersion;
1051};
1052
1053/// Wraps the properties section and handles reading properties out of it.
1054class PropertiesSectionReader {
1055public:
1056 /// Initialize the properties section reader with the given section data.
1057 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
1058 if (sectionData.empty())
1059 return success();
1060 EncodingReader propReader(sectionData, fileLoc);
1061 uint64_t count;
1062 if (failed(result: propReader.parseVarInt(result&: count)))
1063 return failure();
1064 // Parse the raw properties buffer.
1065 if (failed(result: propReader.parseBytes(length: propReader.size(), result&: propertiesBuffers)))
1066 return failure();
1067
1068 EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1069 offsetTable.reserve(N: count);
1070 for (auto idx : llvm::seq<int64_t>(Begin: 0, End: count)) {
1071 (void)idx;
1072 offsetTable.push_back(Elt: propertiesBuffers.size() - offsetsReader.size());
1073 ArrayRef<uint8_t> rawProperties;
1074 uint64_t dataSize;
1075 if (failed(result: offsetsReader.parseVarInt(result&: dataSize)) ||
1076 failed(result: offsetsReader.parseBytes(length: dataSize, result&: rawProperties)))
1077 return failure();
1078 }
1079 if (!offsetsReader.empty())
1080 return offsetsReader.emitError()
1081 << "Broken properties section: didn't exhaust the offsets table";
1082 return success();
1083 }
1084
1085 LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1086 OperationName *opName, OperationState &opState) {
1087 uint64_t propertiesIdx;
1088 if (failed(result: dialectReader.readVarInt(result&: propertiesIdx)))
1089 return failure();
1090 if (propertiesIdx >= offsetTable.size())
1091 return dialectReader.emitError(msg: "Properties idx out-of-bound for ")
1092 << opName->getStringRef();
1093 size_t propertiesOffset = offsetTable[propertiesIdx];
1094 if (propertiesIdx >= propertiesBuffers.size())
1095 return dialectReader.emitError(msg: "Properties offset out-of-bound for ")
1096 << opName->getStringRef();
1097
1098 // Acquire the sub-buffer that represent the requested properties.
1099 ArrayRef<char> rawProperties;
1100 {
1101 // "Seek" to the requested offset by getting a new reader with the right
1102 // sub-buffer.
1103 EncodingReader reader(propertiesBuffers.drop_front(N: propertiesOffset),
1104 fileLoc);
1105 // Properties are stored as a sequence of {size + raw_data}.
1106 if (failed(
1107 result: dialectReader.withEncodingReader(encReader&: reader).readBlob(result&: rawProperties)))
1108 return failure();
1109 }
1110 // Setup a new reader to read from the `rawProperties` sub-buffer.
1111 EncodingReader reader(
1112 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1113 DialectReader propReader = dialectReader.withEncodingReader(encReader&: reader);
1114
1115 auto *iface = opName->getInterface<BytecodeOpInterface>();
1116 if (iface)
1117 return iface->readProperties(propReader, opState);
1118 if (opName->isRegistered())
1119 return propReader.emitError(
1120 msg: "has properties but missing BytecodeOpInterface for ")
1121 << opName->getStringRef();
1122 // Unregistered op are storing properties as an attribute.
1123 return propReader.readAttribute(result&: opState.propertiesAttr);
1124 }
1125
1126private:
1127 /// The properties buffer referenced within the bytecode file.
1128 ArrayRef<uint8_t> propertiesBuffers;
1129
1130 /// Table of offset in the buffer above.
1131 SmallVector<int64_t> offsetTable;
1132};
1133} // namespace
1134
1135LogicalResult AttrTypeReader::initialize(
1136 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1137 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1138 EncodingReader offsetReader(offsetSectionData, fileLoc);
1139
1140 // Parse the number of attribute and type entries.
1141 uint64_t numAttributes, numTypes;
1142 if (failed(result: offsetReader.parseVarInt(result&: numAttributes)) ||
1143 failed(result: offsetReader.parseVarInt(result&: numTypes)))
1144 return failure();
1145 attributes.resize(N: numAttributes);
1146 types.resize(N: numTypes);
1147
1148 // A functor used to accumulate the offsets for the entries in the given
1149 // range.
1150 uint64_t currentOffset = 0;
1151 auto parseEntries = [&](auto &&range) {
1152 size_t currentIndex = 0, endIndex = range.size();
1153
1154 // Parse an individual entry.
1155 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1156 auto &entry = range[currentIndex++];
1157
1158 uint64_t entrySize;
1159 if (failed(offsetReader.parseVarIntWithFlag(result&: entrySize,
1160 flag&: entry.hasCustomEncoding)))
1161 return failure();
1162
1163 // Verify that the offset is actually valid.
1164 if (currentOffset + entrySize > sectionData.size()) {
1165 return offsetReader.emitError(
1166 args: "Attribute or Type entry offset points past the end of section");
1167 }
1168
1169 entry.data = sectionData.slice(N: currentOffset, M: entrySize);
1170 entry.dialect = dialect;
1171 currentOffset += entrySize;
1172 return success();
1173 };
1174 while (currentIndex != endIndex)
1175 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
1176 return failure();
1177 return success();
1178 };
1179
1180 // Process each of the attributes, and then the types.
1181 if (failed(result: parseEntries(attributes)) || failed(result: parseEntries(types)))
1182 return failure();
1183
1184 // Ensure that we read everything from the section.
1185 if (!offsetReader.empty()) {
1186 return offsetReader.emitError(
1187 args: "unexpected trailing data in the Attribute/Type offset section");
1188 }
1189
1190 return success();
1191}
1192
1193template <typename T>
1194T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
1195 StringRef entryType) {
1196 if (index >= entries.size()) {
1197 emitError(loc: fileLoc) << "invalid " << entryType << " index: " << index;
1198 return {};
1199 }
1200
1201 // If the entry has already been resolved, there is nothing left to do.
1202 Entry<T> &entry = entries[index];
1203 if (entry.entry)
1204 return entry.entry;
1205
1206 // Parse the entry.
1207 EncodingReader reader(entry.data, fileLoc);
1208
1209 // Parse based on how the entry was encoded.
1210 if (entry.hasCustomEncoding) {
1211 if (failed(parseCustomEntry(entry, reader, entryType)))
1212 return T();
1213 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
1214 return T();
1215 }
1216
1217 if (!reader.empty()) {
1218 reader.emitError(args: "unexpected trailing bytes after " + entryType + " entry");
1219 return T();
1220 }
1221 return entry.entry;
1222}
1223
1224template <typename T>
1225LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1226 StringRef entryType) {
1227 StringRef asmStr;
1228 if (failed(result: reader.parseNullTerminatedString(result&: asmStr)))
1229 return failure();
1230
1231 // Invoke the MLIR assembly parser to parse the entry text.
1232 size_t numRead = 0;
1233 MLIRContext *context = fileLoc->getContext();
1234 if constexpr (std::is_same_v<T, Type>)
1235 result =
1236 ::parseType(typeStr: asmStr, context, numRead: &numRead, /*isKnownNullTerminated=*/true);
1237 else
1238 result = ::parseAttribute(attrStr: asmStr, context, type: Type(), numRead: &numRead,
1239 /*isKnownNullTerminated=*/true);
1240 if (!result)
1241 return failure();
1242
1243 // Ensure there weren't dangling characters after the entry.
1244 if (numRead != asmStr.size()) {
1245 return reader.emitError(args: "trailing characters found after ", args&: entryType,
1246 args: " assembly format: ", args: asmStr.drop_front(N: numRead));
1247 }
1248 return success();
1249}
1250
1251template <typename T>
1252LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1253 EncodingReader &reader,
1254 StringRef entryType) {
1255 DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
1256 reader, bytecodeVersion);
1257 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
1258 return failure();
1259
1260 if constexpr (std::is_same_v<T, Type>) {
1261 // Try parsing with callbacks first if available.
1262 for (const auto &callback :
1263 parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
1264 if (failed(
1265 callback->read(reader&: dialectReader, dialectName: entry.dialect->name, entry&: entry.entry)))
1266 return failure();
1267 // Early return if parsing was successful.
1268 if (!!entry.entry)
1269 return success();
1270
1271 // Reset the reader if we failed to parse, so we can fall through the
1272 // other parsing functions.
1273 reader = EncodingReader(entry.data, reader.getLoc());
1274 }
1275 } else {
1276 // Try parsing with callbacks first if available.
1277 for (const auto &callback :
1278 parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
1279 if (failed(
1280 callback->read(reader&: dialectReader, dialectName: entry.dialect->name, entry&: entry.entry)))
1281 return failure();
1282 // Early return if parsing was successful.
1283 if (!!entry.entry)
1284 return success();
1285
1286 // Reset the reader if we failed to parse, so we can fall through the
1287 // other parsing functions.
1288 reader = EncodingReader(entry.data, reader.getLoc());
1289 }
1290 }
1291
1292 // Ensure that the dialect implements the bytecode interface.
1293 if (!entry.dialect->interface) {
1294 return reader.emitError("dialect '", entry.dialect->name,
1295 "' does not implement the bytecode interface");
1296 }
1297
1298 if constexpr (std::is_same_v<T, Type>)
1299 entry.entry = entry.dialect->interface->readType(dialectReader);
1300 else
1301 entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1302
1303 return success(!!entry.entry);
1304}
1305
1306//===----------------------------------------------------------------------===//
1307// Bytecode Reader
1308//===----------------------------------------------------------------------===//
1309
1310/// This class is used to read a bytecode buffer and translate it into MLIR.
1311class mlir::BytecodeReader::Impl {
1312 struct RegionReadState;
1313 using LazyLoadableOpsInfo =
1314 std::list<std::pair<Operation *, RegionReadState>>;
1315 using LazyLoadableOpsMap =
1316 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;
1317
1318public:
1319 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
1320 llvm::MemoryBufferRef buffer,
1321 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1322 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1323 attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1324 fileLoc, config),
1325 // Use the builtin unrealized conversion cast operation to represent
1326 // forward references to values that aren't yet defined.
1327 forwardRefOpState(UnknownLoc::get(config.getContext()),
1328 "builtin.unrealized_conversion_cast", ValueRange(),
1329 NoneType::get(config.getContext())),
1330 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1331
1332 /// Read the bytecode defined within `buffer` into the given block.
1333 LogicalResult read(Block *block,
1334 llvm::function_ref<bool(Operation *)> lazyOps);
1335
1336 /// Return the number of ops that haven't been materialized yet.
1337 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
1338
1339 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(Val: op); }
1340
1341 /// Materialize the provided operation, invoke the lazyOpsCallback on every
1342 /// newly found lazy operation.
1343 LogicalResult
1344 materialize(Operation *op,
1345 llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1346 this->lazyOpsCallback = lazyOpsCallback;
1347 auto resetlazyOpsCallback =
1348 llvm::make_scope_exit(F: [&] { this->lazyOpsCallback = nullptr; });
1349 auto it = lazyLoadableOpsMap.find(Val: op);
1350 assert(it != lazyLoadableOpsMap.end() &&
1351 "materialize called on non-materializable op");
1352 return materialize(it);
1353 }
1354
1355 /// Materialize all operations.
1356 LogicalResult materializeAll() {
1357 while (!lazyLoadableOpsMap.empty()) {
1358 if (failed(result: materialize(it: lazyLoadableOpsMap.begin())))
1359 return failure();
1360 }
1361 return success();
1362 }
1363
1364 /// Finalize the lazy-loading by calling back with every op that hasn't been
1365 /// materialized to let the client decide if the op should be deleted or
1366 /// materialized. The op is materialized if the callback returns true, deleted
1367 /// otherwise.
1368 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
1369 while (!lazyLoadableOps.empty()) {
1370 Operation *op = lazyLoadableOps.begin()->first;
1371 if (shouldMaterialize(op)) {
1372 if (failed(result: materialize(it: lazyLoadableOpsMap.find(Val: op))))
1373 return failure();
1374 continue;
1375 }
1376 op->dropAllReferences();
1377 op->erase();
1378 lazyLoadableOps.pop_front();
1379 lazyLoadableOpsMap.erase(Val: op);
1380 }
1381 return success();
1382 }
1383
1384private:
1385 LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
1386 assert(it != lazyLoadableOpsMap.end() &&
1387 "materialize called on non-materializable op");
1388 valueScopes.emplace_back();
1389 std::vector<RegionReadState> regionStack;
1390 regionStack.push_back(x: std::move(it->getSecond()->second));
1391 lazyLoadableOps.erase(position: it->getSecond());
1392 lazyLoadableOpsMap.erase(I: it);
1393
1394 while (!regionStack.empty())
1395 if (failed(result: parseRegions(regionStack, readState&: regionStack.back())))
1396 return failure();
1397 return success();
1398 }
1399
1400 /// Return the context for this config.
1401 MLIRContext *getContext() const { return config.getContext(); }
1402
1403 /// Parse the bytecode version.
1404 LogicalResult parseVersion(EncodingReader &reader);
1405
1406 //===--------------------------------------------------------------------===//
1407 // Dialect Section
1408
1409 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1410
1411 /// Parse an operation name reference using the given reader, and set the
1412 /// `wasRegistered` flag that indicates if the bytecode was produced by a
1413 /// context where opName was registered.
1414 FailureOr<OperationName> parseOpName(EncodingReader &reader,
1415 std::optional<bool> &wasRegistered);
1416
1417 //===--------------------------------------------------------------------===//
1418 // Attribute/Type Section
1419
1420 /// Parse an attribute or type using the given reader.
1421 template <typename T>
1422 LogicalResult parseAttribute(EncodingReader &reader, T &result) {
1423 return attrTypeReader.parseAttribute(reader, result);
1424 }
1425 LogicalResult parseType(EncodingReader &reader, Type &result) {
1426 return attrTypeReader.parseType(reader, result);
1427 }
1428
1429 //===--------------------------------------------------------------------===//
1430 // Resource Section
1431
1432 LogicalResult
1433 parseResourceSection(EncodingReader &reader,
1434 std::optional<ArrayRef<uint8_t>> resourceData,
1435 std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1436
1437 //===--------------------------------------------------------------------===//
1438 // IR Section
1439
1440 /// This struct represents the current read state of a range of regions. This
1441 /// struct is used to enable iterative parsing of regions.
1442 struct RegionReadState {
1443 RegionReadState(Operation *op, EncodingReader *reader,
1444 bool isIsolatedFromAbove)
1445 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
1446 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1447 bool isIsolatedFromAbove)
1448 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1449 isIsolatedFromAbove(isIsolatedFromAbove) {}
1450
1451 /// The current regions being read.
1452 MutableArrayRef<Region>::iterator curRegion, endRegion;
1453 /// This is the reader to use for this region, this pointer is pointing to
1454 /// the parent region reader unless the current region is IsolatedFromAbove,
1455 /// in which case the pointer is pointing to the `owningReader` which is a
1456 /// section dedicated to the current region.
1457 EncodingReader *reader;
1458 std::unique_ptr<EncodingReader> owningReader;
1459
1460 /// The number of values defined immediately within this region.
1461 unsigned numValues = 0;
1462
1463 /// The current blocks of the region being read.
1464 SmallVector<Block *> curBlocks;
1465 Region::iterator curBlock = {};
1466
1467 /// The number of operations remaining to be read from the current block
1468 /// being read.
1469 uint64_t numOpsRemaining = 0;
1470
1471 /// A flag indicating if the regions being read are isolated from above.
1472 bool isIsolatedFromAbove = false;
1473 };
1474
1475 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
1476 LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
1477 RegionReadState &readState);
1478 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1479 RegionReadState &readState,
1480 bool &isIsolatedFromAbove);
1481
1482 LogicalResult parseRegion(RegionReadState &readState);
1483 LogicalResult parseBlockHeader(EncodingReader &reader,
1484 RegionReadState &readState);
1485 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
1486
1487 //===--------------------------------------------------------------------===//
1488 // Value Processing
1489
1490 /// Parse an operand reference using the given reader. Returns nullptr in the
1491 /// case of failure.
1492 Value parseOperand(EncodingReader &reader);
1493
1494 /// Sequentially define the given value range.
1495 LogicalResult defineValues(EncodingReader &reader, ValueRange values);
1496
1497 /// Create a value to use for a forward reference.
1498 Value createForwardRef();
1499
1500 //===--------------------------------------------------------------------===//
1501 // Use-list order helpers
1502
1503 /// This struct is a simple storage that contains information required to
1504 /// reorder the use-list of a value with respect to the pre-order traversal
1505 /// ordering.
1506 struct UseListOrderStorage {
1507 UseListOrderStorage(bool isIndexPairEncoding,
1508 SmallVector<unsigned, 4> &&indices)
1509 : indices(std::move(indices)),
1510 isIndexPairEncoding(isIndexPairEncoding){};
1511 /// The vector containing the information required to reorder the
1512 /// use-list of a value.
1513 SmallVector<unsigned, 4> indices;
1514
1515 /// Whether indices represent a pair of type `(src, dst)` or it is a direct
1516 /// indexing, such as `dst = order[src]`.
1517 bool isIndexPairEncoding;
1518 };
1519
1520 /// Parse use-list order from bytecode for a range of values if available. The
1521 /// range is expected to be either a block argument or an op result range. On
1522 /// success, return a map of the position in the range and the use-list order
1523 /// encoding. The function assumes to know the size of the range it is
1524 /// processing.
1525 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1526 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1527 uint64_t rangeSize);
1528
1529 /// Shuffle the use-chain according to the order parsed.
1530 LogicalResult sortUseListOrder(Value value);
1531
1532 /// Recursively visit all the values defined within topLevelOp and sort the
1533 /// use-list orders according to the indices parsed.
1534 LogicalResult processUseLists(Operation *topLevelOp);
1535
1536 //===--------------------------------------------------------------------===//
1537 // Fields
1538
1539 /// This class represents a single value scope, in which a value scope is
1540 /// delimited by isolated from above regions.
1541 struct ValueScope {
1542 /// Push a new region state onto this scope, reserving enough values for
1543 /// those defined within the current region of the provided state.
1544 void push(RegionReadState &readState) {
1545 nextValueIDs.push_back(Elt: values.size());
1546 values.resize(new_size: values.size() + readState.numValues);
1547 }
1548
1549 /// Pop the values defined for the current region within the provided region
1550 /// state.
1551 void pop(RegionReadState &readState) {
1552 values.resize(new_size: values.size() - readState.numValues);
1553 nextValueIDs.pop_back();
1554 }
1555
1556 /// The set of values defined in this scope.
1557 std::vector<Value> values;
1558
1559 /// The ID for the next defined value for each region current being
1560 /// processed in this scope.
1561 SmallVector<unsigned, 4> nextValueIDs;
1562 };
1563
1564 /// The configuration of the parser.
1565 const ParserConfig &config;
1566
1567 /// A location to use when emitting errors.
1568 Location fileLoc;
1569
1570 /// Flag that indicates if lazyloading is enabled.
1571 bool lazyLoading;
1572
1573 /// Keep track of operations that have been lazy loaded (their regions haven't
1574 /// been materialized), along with the `RegionReadState` that allows to
1575 /// lazy-load the regions nested under the operation.
1576 LazyLoadableOpsInfo lazyLoadableOps;
1577 LazyLoadableOpsMap lazyLoadableOpsMap;
1578 llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1579
1580 /// The reader used to process attribute and types within the bytecode.
1581 AttrTypeReader attrTypeReader;
1582
1583 /// The version of the bytecode being read.
1584 uint64_t version = 0;
1585
1586 /// The producer of the bytecode being read.
1587 StringRef producer;
1588
1589 /// The table of IR units referenced within the bytecode file.
1590 SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
1591 llvm::StringMap<BytecodeDialect *> dialectsMap;
1592 SmallVector<BytecodeOperationName> opNames;
1593
1594 /// The reader used to process resources within the bytecode.
1595 ResourceSectionReader resourceReader;
1596
1597 /// Worklist of values with custom use-list orders to process before the end
1598 /// of the parsing.
1599 DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1600
1601 /// The table of strings referenced within the bytecode file.
1602 StringSectionReader stringReader;
1603
1604 /// The table of properties referenced by the operation in the bytecode file.
1605 PropertiesSectionReader propertiesReader;
1606
1607 /// The current set of available IR value scopes.
1608 std::vector<ValueScope> valueScopes;
1609
1610 /// The global pre-order operation ordering.
1611 DenseMap<Operation *, unsigned> operationIDs;
1612
1613 /// A block containing the set of operations defined to create forward
1614 /// references.
1615 Block forwardRefOps;
1616
1617 /// A block containing previously created, and no longer used, forward
1618 /// reference operations.
1619 Block openForwardRefOps;
1620
1621 /// An operation state used when instantiating forward references.
1622 OperationState forwardRefOpState;
1623
1624 /// Reference to the input buffer.
1625 llvm::MemoryBufferRef buffer;
1626
1627 /// The optional owning source manager, which when present may be used to
1628 /// extend the lifetime of the input buffer.
1629 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1630};
1631
1632LogicalResult BytecodeReader::Impl::read(
1633 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1634 EncodingReader reader(buffer.getBuffer(), fileLoc);
1635 this->lazyOpsCallback = lazyOpsCallback;
1636 auto resetlazyOpsCallback =
1637 llvm::make_scope_exit(F: [&] { this->lazyOpsCallback = nullptr; });
1638
1639 // Skip over the bytecode header, this should have already been checked.
1640 if (failed(result: reader.skipBytes(length: StringRef("ML\xefR").size())))
1641 return failure();
1642 // Parse the bytecode version and producer.
1643 if (failed(result: parseVersion(reader)) ||
1644 failed(result: reader.parseNullTerminatedString(result&: producer)))
1645 return failure();
1646
1647 // Add a diagnostic handler that attaches a note that includes the original
1648 // producer of the bytecode.
1649 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
1650 diag.attachNote() << "in bytecode version " << version
1651 << " produced by: " << producer;
1652 return failure();
1653 });
1654
1655 // Parse the raw data for each of the top-level sections of the bytecode.
1656 std::optional<ArrayRef<uint8_t>>
1657 sectionDatas[bytecode::Section::kNumSections];
1658 while (!reader.empty()) {
1659 // Read the next section from the bytecode.
1660 bytecode::Section::ID sectionID;
1661 ArrayRef<uint8_t> sectionData;
1662 if (failed(result: reader.parseSection(sectionID, sectionData)))
1663 return failure();
1664
1665 // Check for duplicate sections, we only expect one instance of each.
1666 if (sectionDatas[sectionID]) {
1667 return reader.emitError(args: "duplicate top-level section: ",
1668 args: ::toString(sectionID));
1669 }
1670 sectionDatas[sectionID] = sectionData;
1671 }
1672 // Check that all of the required sections were found.
1673 for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
1674 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
1675 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
1676 return reader.emitError(args: "missing data for top-level section: ",
1677 args: ::toString(sectionID));
1678 }
1679 }
1680
1681 // Process the string section first.
1682 if (failed(result: stringReader.initialize(
1683 fileLoc, sectionData: *sectionDatas[bytecode::Section::kString])))
1684 return failure();
1685
1686 // Process the properties section.
1687 if (sectionDatas[bytecode::Section::kProperties] &&
1688 failed(result: propertiesReader.initialize(
1689 fileLoc, sectionData: *sectionDatas[bytecode::Section::kProperties])))
1690 return failure();
1691
1692 // Process the dialect section.
1693 if (failed(result: parseDialectSection(sectionData: *sectionDatas[bytecode::Section::kDialect])))
1694 return failure();
1695
1696 // Process the resource section if present.
1697 if (failed(result: parseResourceSection(
1698 reader, resourceData: sectionDatas[bytecode::Section::kResource],
1699 resourceOffsetData: sectionDatas[bytecode::Section::kResourceOffset])))
1700 return failure();
1701
1702 // Process the attribute and type section.
1703 if (failed(result: attrTypeReader.initialize(
1704 dialects, sectionData: *sectionDatas[bytecode::Section::kAttrType],
1705 offsetSectionData: *sectionDatas[bytecode::Section::kAttrTypeOffset])))
1706 return failure();
1707
1708 // Finally, process the IR section.
1709 return parseIRSection(sectionData: *sectionDatas[bytecode::Section::kIR], block);
1710}
1711
1712LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
1713 if (failed(result: reader.parseVarInt(result&: version)))
1714 return failure();
1715
1716 // Validate the bytecode version.
1717 uint64_t currentVersion = bytecode::kVersion;
1718 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
1719 if (version < minSupportedVersion) {
1720 return reader.emitError(args: "bytecode version ", args&: version,
1721 args: " is older than the current version of ",
1722 args&: currentVersion, args: ", and upgrade is not supported");
1723 }
1724 if (version > currentVersion) {
1725 return reader.emitError(args: "bytecode version ", args&: version,
1726 args: " is newer than the current version ",
1727 args&: currentVersion);
1728 }
1729 // Override any request to lazy-load if the bytecode version is too old.
1730 if (version < bytecode::kLazyLoading)
1731 lazyLoading = false;
1732 return success();
1733}
1734
1735//===----------------------------------------------------------------------===//
1736// Dialect Section
1737
1738LogicalResult BytecodeDialect::load(const DialectReader &reader,
1739 MLIRContext *ctx) {
1740 if (dialect)
1741 return success();
1742 Dialect *loadedDialect = ctx->getOrLoadDialect(name);
1743 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
1744 return reader.emitError(msg: "dialect '")
1745 << name
1746 << "' is unknown. If this is intended, please call "
1747 "allowUnregisteredDialects() on the MLIRContext, or use "
1748 "-allow-unregistered-dialect with the MLIR tool used.";
1749 }
1750 dialect = loadedDialect;
1751
1752 // If the dialect was actually loaded, check to see if it has a bytecode
1753 // interface.
1754 if (loadedDialect)
1755 interface = dyn_cast<BytecodeDialectInterface>(Val: loadedDialect);
1756 if (!versionBuffer.empty()) {
1757 if (!interface)
1758 return reader.emitError(msg: "dialect '")
1759 << name
1760 << "' does not implement the bytecode interface, "
1761 "but found a version entry";
1762 EncodingReader encReader(versionBuffer, reader.getLoc());
1763 DialectReader versionReader = reader.withEncodingReader(encReader);
1764 loadedVersion = interface->readVersion(reader&: versionReader);
1765 if (!loadedVersion)
1766 return failure();
1767 }
1768 return success();
1769}
1770
1771LogicalResult
1772BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
1773 EncodingReader sectionReader(sectionData, fileLoc);
1774
1775 // Parse the number of dialects in the section.
1776 uint64_t numDialects;
1777 if (failed(result: sectionReader.parseVarInt(result&: numDialects)))
1778 return failure();
1779 dialects.resize(N: numDialects);
1780
1781 // Parse each of the dialects.
1782 for (uint64_t i = 0; i < numDialects; ++i) {
1783 dialects[i] = std::make_unique<BytecodeDialect>();
1784 /// Before version kDialectVersioning, there wasn't any versioning available
1785 /// for dialects, and the entryIdx represent the string itself.
1786 if (version < bytecode::kDialectVersioning) {
1787 if (failed(result: stringReader.parseString(reader&: sectionReader, result&: dialects[i]->name)))
1788 return failure();
1789 continue;
1790 }
1791
1792 // Parse ID representing dialect and version.
1793 uint64_t dialectNameIdx;
1794 bool versionAvailable;
1795 if (failed(result: sectionReader.parseVarIntWithFlag(result&: dialectNameIdx,
1796 flag&: versionAvailable)))
1797 return failure();
1798 if (failed(result: stringReader.parseStringAtIndex(reader&: sectionReader, index: dialectNameIdx,
1799 result&: dialects[i]->name)))
1800 return failure();
1801 if (versionAvailable) {
1802 bytecode::Section::ID sectionID;
1803 if (failed(result: sectionReader.parseSection(sectionID,
1804 sectionData&: dialects[i]->versionBuffer)))
1805 return failure();
1806 if (sectionID != bytecode::Section::kDialectVersions) {
1807 emitError(loc: fileLoc, message: "expected dialect version section");
1808 return failure();
1809 }
1810 }
1811 dialectsMap[dialects[i]->name] = dialects[i].get();
1812 }
1813
1814 // Parse the operation names, which are grouped by dialect.
1815 auto parseOpName = [&](BytecodeDialect *dialect) {
1816 StringRef opName;
1817 std::optional<bool> wasRegistered;
1818 // Prior to version kNativePropertiesEncoding, the information about wheter
1819 // an op was registered or not wasn't encoded.
1820 if (version < bytecode::kNativePropertiesEncoding) {
1821 if (failed(result: stringReader.parseString(reader&: sectionReader, result&: opName)))
1822 return failure();
1823 } else {
1824 bool wasRegisteredFlag;
1825 if (failed(result: stringReader.parseStringWithFlag(reader&: sectionReader, result&: opName,
1826 flag&: wasRegisteredFlag)))
1827 return failure();
1828 wasRegistered = wasRegisteredFlag;
1829 }
1830 opNames.emplace_back(Args&: dialect, Args&: opName, Args&: wasRegistered);
1831 return success();
1832 };
1833 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation
1834 // where the number of ops are known.
1835 if (version >= bytecode::kElideUnknownBlockArgLocation) {
1836 uint64_t numOps;
1837 if (failed(result: sectionReader.parseVarInt(result&: numOps)))
1838 return failure();
1839 opNames.reserve(N: numOps);
1840 }
1841 while (!sectionReader.empty())
1842 if (failed(result: parseDialectGrouping(reader&: sectionReader, dialects, entryCallback: parseOpName)))
1843 return failure();
1844 return success();
1845}
1846
1847FailureOr<OperationName>
1848BytecodeReader::Impl::parseOpName(EncodingReader &reader,
1849 std::optional<bool> &wasRegistered) {
1850 BytecodeOperationName *opName = nullptr;
1851 if (failed(result: parseEntry(reader, entries&: opNames, entry&: opName, entryStr: "operation name")))
1852 return failure();
1853 wasRegistered = opName->wasRegistered;
1854 // Check to see if this operation name has already been resolved. If we
1855 // haven't, load the dialect and build the operation name.
1856 if (!opName->opName) {
1857 // If the opName is empty, this is because we use to accept names such as
1858 // `foo` without any `.` separator. We shouldn't tolerate this in textual
1859 // format anymore but for now we'll be backward compatible. This can only
1860 // happen with unregistered dialects.
1861 if (opName->name.empty()) {
1862 opName->opName.emplace(args&: opName->dialect->name, args: getContext());
1863 } else {
1864 // Load the dialect and its version.
1865 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1866 dialectsMap, reader, version);
1867 if (failed(result: opName->dialect->load(reader: dialectReader, ctx: getContext())))
1868 return failure();
1869 opName->opName.emplace(args: (opName->dialect->name + "." + opName->name).str(),
1870 args: getContext());
1871 }
1872 }
1873 return *opName->opName;
1874}
1875
1876//===----------------------------------------------------------------------===//
1877// Resource Section
1878
1879LogicalResult BytecodeReader::Impl::parseResourceSection(
1880 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
1881 std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
1882 // Ensure both sections are either present or not.
1883 if (resourceData.has_value() != resourceOffsetData.has_value()) {
1884 if (resourceOffsetData)
1885 return emitError(loc: fileLoc, message: "unexpected resource offset section when "
1886 "resource section is not present");
1887 return emitError(
1888 loc: fileLoc,
1889 message: "expected resource offset section when resource section is present");
1890 }
1891
1892 // If the resource sections are absent, there is nothing to do.
1893 if (!resourceData)
1894 return success();
1895
1896 // Initialize the resource reader with the resource sections.
1897 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1898 dialectsMap, reader, version);
1899 return resourceReader.initialize(fileLoc, config, dialects, stringReader,
1900 sectionData: *resourceData, offsetSectionData: *resourceOffsetData,
1901 dialectReader, bufferOwnerRef);
1902}
1903
1904//===----------------------------------------------------------------------===//
1905// UseListOrder Helpers
1906
1907FailureOr<BytecodeReader::Impl::UseListMapT>
1908BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
1909 uint64_t numResults) {
1910 BytecodeReader::Impl::UseListMapT map;
1911 uint64_t numValuesToRead = 1;
1912 if (numResults > 1 && failed(result: reader.parseVarInt(result&: numValuesToRead)))
1913 return failure();
1914
1915 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
1916 uint64_t resultIdx = 0;
1917 if (numResults > 1 && failed(result: reader.parseVarInt(result&: resultIdx)))
1918 return failure();
1919
1920 uint64_t numValues;
1921 bool indexPairEncoding;
1922 if (failed(result: reader.parseVarIntWithFlag(result&: numValues, flag&: indexPairEncoding)))
1923 return failure();
1924
1925 SmallVector<unsigned, 4> useListOrders;
1926 for (size_t idx = 0; idx < numValues; idx++) {
1927 uint64_t index;
1928 if (failed(result: reader.parseVarInt(result&: index)))
1929 return failure();
1930 useListOrders.push_back(Elt: index);
1931 }
1932
1933 // Store in a map the result index
1934 map.try_emplace(Key: resultIdx, Args: UseListOrderStorage(indexPairEncoding,
1935 std::move(useListOrders)));
1936 }
1937
1938 return map;
1939}
1940
1941/// Sorts each use according to the order specified in the use-list parsed. If
1942/// the custom use-list is not found, this means that the order needs to be
1943/// consistent with the reverse pre-order walk of the IR. If multiple uses lie
1944/// on the same operation, the order will follow the reverse operand number
1945/// ordering.
1946LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
1947 // Early return for trivial use-lists.
1948 if (value.use_empty() || value.hasOneUse())
1949 return success();
1950
1951 bool hasIncomingOrder =
1952 valueToUseListMap.contains(Val: value.getAsOpaquePointer());
1953
1954 // Compute the current order of the use-list with respect to the global
1955 // ordering. Detect if the order is already sorted while doing so.
1956 bool alreadySorted = true;
1957 auto &firstUse = *value.use_begin();
1958 uint64_t prevID =
1959 bytecode::getUseID(val&: firstUse, ownerID: operationIDs.at(Val: firstUse.getOwner()));
1960 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
1961 for (auto item : llvm::drop_begin(RangeOrContainer: llvm::enumerate(First: value.getUses()))) {
1962 uint64_t currentID = bytecode::getUseID(
1963 val&: item.value(), ownerID: operationIDs.at(Val: item.value().getOwner()));
1964 alreadySorted &= prevID > currentID;
1965 currentOrder.push_back(Elt: {item.index(), currentID});
1966 prevID = currentID;
1967 }
1968
1969 // If the order is already sorted, and there wasn't a custom order to apply
1970 // from the bytecode file, we are done.
1971 if (alreadySorted && !hasIncomingOrder)
1972 return success();
1973
1974 // If not already sorted, sort the indices of the current order by descending
1975 // useIDs.
1976 if (!alreadySorted)
1977 std::sort(
1978 first: currentOrder.begin(), last: currentOrder.end(),
1979 comp: [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
1980
1981 if (!hasIncomingOrder) {
1982 // If the bytecode file did not contain any custom use-list order, it means
1983 // that the order was descending useID. Hence, shuffle by the first index
1984 // of the `currentOrder` pair.
1985 SmallVector<unsigned> shuffle = SmallVector<unsigned>(
1986 llvm::map_range(C&: currentOrder, F: [&](auto item) { return item.first; }));
1987 value.shuffleUseList(indices: shuffle);
1988 return success();
1989 }
1990
1991 // Pull the custom order info from the map.
1992 UseListOrderStorage customOrder =
1993 valueToUseListMap.at(Val: value.getAsOpaquePointer());
1994 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
1995 uint64_t numUses =
1996 std::distance(first: value.getUses().begin(), last: value.getUses().end());
1997
1998 // If the encoding was a pair of indices `(src, dst)` for every permutation,
1999 // reconstruct the shuffle vector for every use. Initialize the shuffle vector
2000 // as identity, and then apply the mapping encoded in the indices.
2001 if (customOrder.isIndexPairEncoding) {
2002 // Return failure if the number of indices was not representing pairs.
2003 if (shuffle.size() & 1)
2004 return failure();
2005
2006 SmallVector<unsigned, 4> newShuffle(numUses);
2007 size_t idx = 0;
2008 std::iota(first: newShuffle.begin(), last: newShuffle.end(), value: idx);
2009 for (idx = 0; idx < shuffle.size(); idx += 2)
2010 newShuffle[shuffle[idx]] = shuffle[idx + 1];
2011
2012 shuffle = std::move(newShuffle);
2013 }
2014
2015 // Make sure that the indices represent a valid mapping. That is, the sum of
2016 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
2017 // duplicates are allowed in the list.
2018 DenseSet<unsigned> set;
2019 uint64_t accumulator = 0;
2020 for (const auto &elem : shuffle) {
2021 if (set.contains(V: elem))
2022 return failure();
2023 accumulator += elem;
2024 set.insert(V: elem);
2025 }
2026 if (numUses != shuffle.size() ||
2027 accumulator != (((numUses - 1) * numUses) >> 1))
2028 return failure();
2029
2030 // Apply the current ordering map onto the shuffle vector to get the final
2031 // use-list sorting indices before shuffling.
2032 shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2033 C&: currentOrder, F: [&](auto item) { return shuffle[item.first]; }));
2034 value.shuffleUseList(indices: shuffle);
2035 return success();
2036}
2037
2038LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
2039 // Precompute operation IDs according to the pre-order walk of the IR. We
2040 // can't do this while parsing since parseRegions ordering is not strictly
2041 // equal to the pre-order walk.
2042 unsigned operationID = 0;
2043 topLevelOp->walk<mlir::WalkOrder::PreOrder>(
2044 callback: [&](Operation *op) { operationIDs.try_emplace(Key: op, Args: operationID++); });
2045
2046 auto blockWalk = topLevelOp->walk(callback: [this](Block *block) {
2047 for (auto arg : block->getArguments())
2048 if (failed(result: sortUseListOrder(value: arg)))
2049 return WalkResult::interrupt();
2050 return WalkResult::advance();
2051 });
2052
2053 auto resultWalk = topLevelOp->walk(callback: [this](Operation *op) {
2054 for (auto result : op->getResults())
2055 if (failed(result: sortUseListOrder(value: result)))
2056 return WalkResult::interrupt();
2057 return WalkResult::advance();
2058 });
2059
2060 return failure(isFailure: blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
2061}
2062
2063//===----------------------------------------------------------------------===//
2064// IR Section
2065
2066LogicalResult
2067BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2068 Block *block) {
2069 EncodingReader reader(sectionData, fileLoc);
2070
2071 // A stack of operation regions currently being read from the bytecode.
2072 std::vector<RegionReadState> regionStack;
2073
2074 // Parse the top-level block using a temporary module operation.
2075 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
2076 regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
2077 regionStack.back().curBlocks.push_back(Elt: moduleOp->getBody());
2078 regionStack.back().curBlock = regionStack.back().curRegion->begin();
2079 if (failed(result: parseBlockHeader(reader, readState&: regionStack.back())))
2080 return failure();
2081 valueScopes.emplace_back();
2082 valueScopes.back().push(readState&: regionStack.back());
2083
2084 // Iteratively parse regions until everything has been resolved.
2085 while (!regionStack.empty())
2086 if (failed(result: parseRegions(regionStack, readState&: regionStack.back())))
2087 return failure();
2088 if (!forwardRefOps.empty()) {
2089 return reader.emitError(
2090 args: "not all forward unresolved forward operand references");
2091 }
2092
2093 // Sort use-lists according to what specified in bytecode.
2094 if (failed(processUseLists(topLevelOp: *moduleOp)))
2095 return reader.emitError(
2096 args: "parsed use-list orders were invalid and could not be applied");
2097
2098 // Resolve dialect version.
2099 for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2100 // Parsing is complete, give an opportunity to each dialect to visit the
2101 // IR and perform upgrades.
2102 if (!byteCodeDialect->loadedVersion)
2103 continue;
2104 if (byteCodeDialect->interface &&
2105 failed(byteCodeDialect->interface->upgradeFromVersion(
2106 topLevelOp: *moduleOp, version: *byteCodeDialect->loadedVersion)))
2107 return failure();
2108 }
2109
2110 // Verify that the parsed operations are valid.
2111 if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
2112 return failure();
2113
2114 // Splice the parsed operations over to the provided top-level block.
2115 auto &parsedOps = moduleOp->getBody()->getOperations();
2116 auto &destOps = block->getOperations();
2117 destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
2118 return success();
2119}
2120
2121LogicalResult
2122BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
2123 RegionReadState &readState) {
2124 // Process regions, blocks, and operations until the end or if a nested
2125 // region is encountered. In this case we push a new state in regionStack and
2126 // return, the processing of the current region will resume afterward.
2127 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
2128 // If the current block hasn't been setup yet, parse the header for this
2129 // region. The current block is already setup when this function was
2130 // interrupted to recurse down in a nested region and we resume the current
2131 // block after processing the nested region.
2132 if (readState.curBlock == Region::iterator()) {
2133 if (failed(result: parseRegion(readState)))
2134 return failure();
2135
2136 // If the region is empty, there is nothing to more to do.
2137 if (readState.curRegion->empty())
2138 continue;
2139 }
2140
2141 // Parse the blocks within the region.
2142 EncodingReader &reader = *readState.reader;
2143 do {
2144 while (readState.numOpsRemaining--) {
2145 // Read in the next operation. We don't read its regions directly, we
2146 // handle those afterwards as necessary.
2147 bool isIsolatedFromAbove = false;
2148 FailureOr<Operation *> op =
2149 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
2150 if (failed(result: op))
2151 return failure();
2152
2153 // If the op has regions, add it to the stack for processing and return:
2154 // we stop the processing of the current region and resume it after the
2155 // inner one is completed. Unless LazyLoading is activated in which case
2156 // nested region parsing is delayed.
2157 if ((*op)->getNumRegions()) {
2158 RegionReadState childState(*op, &reader, isIsolatedFromAbove);
2159
2160 // Isolated regions are encoded as a section in version 2 and above.
2161 if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
2162 bytecode::Section::ID sectionID;
2163 ArrayRef<uint8_t> sectionData;
2164 if (failed(result: reader.parseSection(sectionID, sectionData)))
2165 return failure();
2166 if (sectionID != bytecode::Section::kIR)
2167 return emitError(loc: fileLoc, message: "expected IR section for region");
2168 childState.owningReader =
2169 std::make_unique<EncodingReader>(args&: sectionData, args&: fileLoc);
2170 childState.reader = childState.owningReader.get();
2171
2172 // If the user has a callback set, they have the opportunity to
2173 // control lazyloading as we go.
2174 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
2175 lazyLoadableOps.emplace_back(args&: *op, args: std::move(childState));
2176 lazyLoadableOpsMap.try_emplace(Key: *op,
2177 Args: std::prev(x: lazyLoadableOps.end()));
2178 continue;
2179 }
2180 }
2181 regionStack.push_back(x: std::move(childState));
2182
2183 // If the op is isolated from above, push a new value scope.
2184 if (isIsolatedFromAbove)
2185 valueScopes.emplace_back();
2186 return success();
2187 }
2188 }
2189
2190 // Move to the next block of the region.
2191 if (++readState.curBlock == readState.curRegion->end())
2192 break;
2193 if (failed(result: parseBlockHeader(reader, readState)))
2194 return failure();
2195 } while (true);
2196
2197 // Reset the current block and any values reserved for this region.
2198 readState.curBlock = {};
2199 valueScopes.back().pop(readState);
2200 }
2201
2202 // When the regions have been fully parsed, pop them off of the read stack. If
2203 // the regions were isolated from above, we also pop the last value scope.
2204 if (readState.isIsolatedFromAbove) {
2205 assert(!valueScopes.empty() && "Expect a valueScope after reading region");
2206 valueScopes.pop_back();
2207 }
2208 assert(!regionStack.empty() && "Expect a regionStack after reading region");
2209 regionStack.pop_back();
2210 return success();
2211}
2212
2213FailureOr<Operation *>
2214BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2215 RegionReadState &readState,
2216 bool &isIsolatedFromAbove) {
2217 // Parse the name of the operation.
2218 std::optional<bool> wasRegistered;
2219 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2220 if (failed(result: opName))
2221 return failure();
2222
2223 // Parse the operation mask, which indicates which components of the operation
2224 // are present.
2225 uint8_t opMask;
2226 if (failed(result: reader.parseByte(value&: opMask)))
2227 return failure();
2228
2229 /// Parse the location.
2230 LocationAttr opLoc;
2231 if (failed(result: parseAttribute(reader, result&: opLoc)))
2232 return failure();
2233
2234 // With the location and name resolved, we can start building the operation
2235 // state.
2236 OperationState opState(opLoc, *opName);
2237
2238 // Parse the attributes of the operation.
2239 if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
2240 DictionaryAttr dictAttr;
2241 if (failed(parseAttribute(reader, dictAttr)))
2242 return failure();
2243 opState.attributes = dictAttr;
2244 }
2245
2246 if (opMask & bytecode::OpEncodingMask::kHasProperties) {
2247 // kHasProperties wasn't emitted in older bytecode, we should never get
2248 // there without also having the `wasRegistered` flag available.
2249 if (!wasRegistered)
2250 return emitError(loc: fileLoc,
2251 message: "Unexpected missing `wasRegistered` opname flag at "
2252 "bytecode version ")
2253 << version << " with properties.";
2254 // When an operation is emitted without being registered, the properties are
2255 // stored as an attribute. Otherwise the op must implement the bytecode
2256 // interface and control the serialization.
2257 if (wasRegistered) {
2258 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2259 dialectsMap, reader, version);
2260 if (failed(
2261 result: propertiesReader.read(fileLoc, dialectReader, opName: &*opName, opState)))
2262 return failure();
2263 } else {
2264 // If the operation wasn't registered when it was emitted, the properties
2265 // was serialized as an attribute.
2266 if (failed(result: parseAttribute(reader, result&: opState.propertiesAttr)))
2267 return failure();
2268 }
2269 }
2270
2271 /// Parse the results of the operation.
2272 if (opMask & bytecode::OpEncodingMask::kHasResults) {
2273 uint64_t numResults;
2274 if (failed(result: reader.parseVarInt(result&: numResults)))
2275 return failure();
2276 opState.types.resize(N: numResults);
2277 for (int i = 0, e = numResults; i < e; ++i)
2278 if (failed(result: parseType(reader, result&: opState.types[i])))
2279 return failure();
2280 }
2281
2282 /// Parse the operands of the operation.
2283 if (opMask & bytecode::OpEncodingMask::kHasOperands) {
2284 uint64_t numOperands;
2285 if (failed(result: reader.parseVarInt(result&: numOperands)))
2286 return failure();
2287 opState.operands.resize(N: numOperands);
2288 for (int i = 0, e = numOperands; i < e; ++i)
2289 if (!(opState.operands[i] = parseOperand(reader)))
2290 return failure();
2291 }
2292
2293 /// Parse the successors of the operation.
2294 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
2295 uint64_t numSuccs;
2296 if (failed(result: reader.parseVarInt(result&: numSuccs)))
2297 return failure();
2298 opState.successors.resize(N: numSuccs);
2299 for (int i = 0, e = numSuccs; i < e; ++i) {
2300 if (failed(result: parseEntry(reader, entries&: readState.curBlocks, entry&: opState.successors[i],
2301 entryStr: "successor")))
2302 return failure();
2303 }
2304 }
2305
2306 /// Parse the use-list orders for the results of the operation. Use-list
2307 /// orders are available since version 3 of the bytecode.
2308 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2309 if (version >= bytecode::kUseListOrdering &&
2310 (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) {
2311 size_t numResults = opState.types.size();
2312 auto parseResult = parseUseListOrderForRange(reader, numResults);
2313 if (failed(result: parseResult))
2314 return failure();
2315 resultIdxToUseListMap = std::move(*parseResult);
2316 }
2317
2318 /// Parse the regions of the operation.
2319 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
2320 uint64_t numRegions;
2321 if (failed(result: reader.parseVarIntWithFlag(result&: numRegions, flag&: isIsolatedFromAbove)))
2322 return failure();
2323
2324 opState.regions.reserve(N: numRegions);
2325 for (int i = 0, e = numRegions; i < e; ++i)
2326 opState.regions.push_back(Elt: std::make_unique<Region>());
2327 }
2328
2329 // Create the operation at the back of the current block.
2330 Operation *op = Operation::create(state: opState);
2331 readState.curBlock->push_back(op);
2332
2333 // If the operation had results, update the value references. We don't need to
2334 // do this if the current value scope is empty. That is, the op was not
2335 // encoded within a parent region.
2336 if (readState.numValues && op->getNumResults() &&
2337 failed(result: defineValues(reader, values: op->getResults())))
2338 return failure();
2339
2340 /// Store a map for every value that received a custom use-list order from the
2341 /// bytecode file.
2342 if (resultIdxToUseListMap.has_value()) {
2343 for (size_t idx = 0; idx < op->getNumResults(); idx++) {
2344 if (resultIdxToUseListMap->contains(Val: idx)) {
2345 valueToUseListMap.try_emplace(Key: op->getResult(idx).getAsOpaquePointer(),
2346 Args: resultIdxToUseListMap->at(Val: idx));
2347 }
2348 }
2349 }
2350 return op;
2351}
2352
2353LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2354 EncodingReader &reader = *readState.reader;
2355
2356 // Parse the number of blocks in the region.
2357 uint64_t numBlocks;
2358 if (failed(result: reader.parseVarInt(result&: numBlocks)))
2359 return failure();
2360
2361 // If the region is empty, there is nothing else to do.
2362 if (numBlocks == 0)
2363 return success();
2364
2365 // Parse the number of values defined in this region.
2366 uint64_t numValues;
2367 if (failed(result: reader.parseVarInt(result&: numValues)))
2368 return failure();
2369 readState.numValues = numValues;
2370
2371 // Create the blocks within this region. We do this before processing so that
2372 // we can rely on the blocks existing when creating operations.
2373 readState.curBlocks.clear();
2374 readState.curBlocks.reserve(N: numBlocks);
2375 for (uint64_t i = 0; i < numBlocks; ++i) {
2376 readState.curBlocks.push_back(Elt: new Block());
2377 readState.curRegion->push_back(block: readState.curBlocks.back());
2378 }
2379
2380 // Prepare the current value scope for this region.
2381 valueScopes.back().push(readState);
2382
2383 // Parse the entry block of the region.
2384 readState.curBlock = readState.curRegion->begin();
2385 return parseBlockHeader(reader, readState);
2386}
2387
2388LogicalResult
2389BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2390 RegionReadState &readState) {
2391 bool hasArgs;
2392 if (failed(result: reader.parseVarIntWithFlag(result&: readState.numOpsRemaining, flag&: hasArgs)))
2393 return failure();
2394
2395 // Parse the arguments of the block.
2396 if (hasArgs && failed(result: parseBlockArguments(reader, block: &*readState.curBlock)))
2397 return failure();
2398
2399 // Uselist orders are available since version 3 of the bytecode.
2400 if (version < bytecode::kUseListOrdering)
2401 return success();
2402
2403 uint8_t hasUseListOrders = 0;
2404 if (hasArgs && failed(result: reader.parseByte(value&: hasUseListOrders)))
2405 return failure();
2406
2407 if (!hasUseListOrders)
2408 return success();
2409
2410 Block &blk = *readState.curBlock;
2411 auto argIdxToUseListMap =
2412 parseUseListOrderForRange(reader, numResults: blk.getNumArguments());
2413 if (failed(result: argIdxToUseListMap) || argIdxToUseListMap->empty())
2414 return failure();
2415
2416 for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
2417 if (argIdxToUseListMap->contains(Val: idx))
2418 valueToUseListMap.try_emplace(Key: blk.getArgument(i: idx).getAsOpaquePointer(),
2419 Args: argIdxToUseListMap->at(Val: idx));
2420
2421 // We don't parse the operations of the block here, that's done elsewhere.
2422 return success();
2423}
2424
2425LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2426 Block *block) {
2427 // Parse the value ID for the first argument, and the number of arguments.
2428 uint64_t numArgs;
2429 if (failed(result: reader.parseVarInt(result&: numArgs)))
2430 return failure();
2431
2432 SmallVector<Type> argTypes;
2433 SmallVector<Location> argLocs;
2434 argTypes.reserve(N: numArgs);
2435 argLocs.reserve(N: numArgs);
2436
2437 Location unknownLoc = UnknownLoc::get(config.getContext());
2438 while (numArgs--) {
2439 Type argType;
2440 LocationAttr argLoc = unknownLoc;
2441 if (version >= bytecode::kElideUnknownBlockArgLocation) {
2442 // Parse the type with hasLoc flag to determine if it has type.
2443 uint64_t typeIdx;
2444 bool hasLoc;
2445 if (failed(result: reader.parseVarIntWithFlag(result&: typeIdx, flag&: hasLoc)) ||
2446 !(argType = attrTypeReader.resolveType(index: typeIdx)))
2447 return failure();
2448 if (hasLoc && failed(result: parseAttribute(reader, result&: argLoc)))
2449 return failure();
2450 } else {
2451 // All args has type and location.
2452 if (failed(result: parseType(reader, result&: argType)) ||
2453 failed(result: parseAttribute(reader, result&: argLoc)))
2454 return failure();
2455 }
2456 argTypes.push_back(Elt: argType);
2457 argLocs.push_back(Elt: argLoc);
2458 }
2459 block->addArguments(types: argTypes, locs: argLocs);
2460 return defineValues(reader, values: block->getArguments());
2461}
2462
2463//===----------------------------------------------------------------------===//
// Value Processing
//===----------------------------------------------------------------------===//
2465
2466Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2467 std::vector<Value> &values = valueScopes.back().values;
2468 Value *value = nullptr;
2469 if (failed(result: parseEntry(reader, entries&: values, entry&: value, entryStr: "value")))
2470 return Value();
2471
2472 // Create a new forward reference if necessary.
2473 if (!*value)
2474 *value = createForwardRef();
2475 return *value;
2476}
2477
2478LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2479 ValueRange newValues) {
2480 ValueScope &valueScope = valueScopes.back();
2481 std::vector<Value> &values = valueScope.values;
2482
2483 unsigned &valueID = valueScope.nextValueIDs.back();
2484 unsigned valueIDEnd = valueID + newValues.size();
2485 if (valueIDEnd > values.size()) {
2486 return reader.emitError(
2487 args: "value index range was outside of the expected range for "
2488 "the parent region, got [",
2489 args&: valueID, args: ", ", args&: valueIDEnd, args: "), but the maximum index was ",
2490 args: values.size() - 1);
2491 }
2492
2493 // Assign the values and update any forward references.
2494 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2495 Value newValue = newValues[i];
2496
2497 // Check to see if a definition for this value already exists.
2498 if (Value oldValue = std::exchange(obj&: values[valueID], new_val&: newValue)) {
2499 Operation *forwardRefOp = oldValue.getDefiningOp();
2500
2501 // Assert that this is a forward reference operation. Given how we compute
2502 // definition ids (incrementally as we parse), it shouldn't be possible
2503 // for the value to be defined any other way.
2504 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&
2505 "value index was already defined?");
2506
2507 oldValue.replaceAllUsesWith(newValue);
2508 forwardRefOp->moveBefore(block: &openForwardRefOps, iterator: openForwardRefOps.end());
2509 }
2510 }
2511 return success();
2512}
2513
2514Value BytecodeReader::Impl::createForwardRef() {
2515 // Check for an avaliable existing operation to use. Otherwise, create a new
2516 // fake operation to use for the reference.
2517 if (!openForwardRefOps.empty()) {
2518 Operation *op = &openForwardRefOps.back();
2519 op->moveBefore(block: &forwardRefOps, iterator: forwardRefOps.end());
2520 } else {
2521 forwardRefOps.push_back(op: Operation::create(state: forwardRefOpState));
2522 }
2523 return forwardRefOps.back().getResult(idx: 0);
2524}
2525
2526//===----------------------------------------------------------------------===//
2527// Entry Points
2528//===----------------------------------------------------------------------===//
2529
2530BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); }
2531
2532BytecodeReader::BytecodeReader(
2533 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
2534 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2535 Location sourceFileLoc =
2536 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2537 /*line=*/0, /*column=*/0);
2538 impl = std::make_unique<Impl>(args&: sourceFileLoc, args: config, args&: lazyLoading, args&: buffer,
2539 args: bufferOwnerRef);
2540}
2541
2542LogicalResult BytecodeReader::readTopLevel(
2543 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2544 return impl->read(block, lazyOpsCallback);
2545}
2546
2547int64_t BytecodeReader::getNumOpsToMaterialize() const {
2548 return impl->getNumOpsToMaterialize();
2549}
2550
2551bool BytecodeReader::isMaterializable(Operation *op) {
2552 return impl->isMaterializable(op);
2553}
2554
2555LogicalResult BytecodeReader::materialize(
2556 Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2557 return impl->materialize(op, lazyOpsCallback);
2558}
2559
2560LogicalResult
2561BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
2562 return impl->finalize(shouldMaterialize);
2563}
2564
2565bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
2566 return buffer.getBuffer().starts_with(Prefix: "ML\xefR");
2567}
2568
2569/// Read the bytecode from the provided memory buffer reference.
2570/// `bufferOwnerRef` if provided is the owning source manager for the buffer,
2571/// and may be used to extend the lifetime of the buffer.
2572static LogicalResult
2573readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
2574 const ParserConfig &config,
2575 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2576 Location sourceFileLoc =
2577 FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
2578 /*line=*/0, /*column=*/0);
2579 if (!isBytecode(buffer)) {
2580 return emitError(loc: sourceFileLoc,
2581 message: "input buffer is not an MLIR bytecode file");
2582 }
2583
2584 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
2585 buffer, bufferOwnerRef);
2586 return reader.read(block, /*lazyOpsCallback=*/nullptr);
2587}
2588
2589LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
2590 const ParserConfig &config) {
2591 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
2592}
2593LogicalResult
2594mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
2595 Block *block, const ParserConfig &config) {
2596 return readBytecodeFileImpl(
2597 buffer: *sourceMgr->getMemoryBuffer(i: sourceMgr->getMainFileID()), block, config,
2598 bufferOwnerRef: sourceMgr);
2599}
2600

source code of mlir/lib/Bytecode/Reader/BytecodeReader.cpp