1//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Bytecode/BytecodeReader.h"
10#include "mlir/AsmParser/AsmParser.h"
11#include "mlir/Bytecode/BytecodeImplementation.h"
12#include "mlir/Bytecode/BytecodeOpInterface.h"
13#include "mlir/Bytecode/Encoding.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Verifier.h"
18#include "mlir/IR/Visitors.h"
19#include "mlir/Support/LLVM.h"
20#include "llvm/ADT/ArrayRef.h"
21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/Support/Endian.h"
25#include "llvm/Support/MemoryBufferRef.h"
26#include "llvm/Support/SourceMgr.h"
27
28#include <cstddef>
29#include <list>
30#include <memory>
31#include <numeric>
32#include <optional>
33
34#define DEBUG_TYPE "mlir-bytecode-reader"
35
36using namespace mlir;
37
38/// Stringify the given section ID.
39static std::string toString(bytecode::Section::ID sectionID) {
40 switch (sectionID) {
41 case bytecode::Section::kString:
42 return "String (0)";
43 case bytecode::Section::kDialect:
44 return "Dialect (1)";
45 case bytecode::Section::kAttrType:
46 return "AttrType (2)";
47 case bytecode::Section::kAttrTypeOffset:
48 return "AttrTypeOffset (3)";
49 case bytecode::Section::kIR:
50 return "IR (4)";
51 case bytecode::Section::kResource:
52 return "Resource (5)";
53 case bytecode::Section::kResourceOffset:
54 return "ResourceOffset (6)";
55 case bytecode::Section::kDialectVersions:
56 return "DialectVersions (7)";
57 case bytecode::Section::kProperties:
58 return "Properties (8)";
59 default:
60 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
61 }
62}
63
64/// Returns true if the given top-level section ID is optional.
65static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
66 switch (sectionID) {
67 case bytecode::Section::kString:
68 case bytecode::Section::kDialect:
69 case bytecode::Section::kAttrType:
70 case bytecode::Section::kAttrTypeOffset:
71 case bytecode::Section::kIR:
72 return false;
73 case bytecode::Section::kResource:
74 case bytecode::Section::kResourceOffset:
75 case bytecode::Section::kDialectVersions:
76 return true;
77 case bytecode::Section::kProperties:
78 return version < bytecode::kNativePropertiesEncoding;
79 default:
80 llvm_unreachable("unknown section ID");
81 }
82}
83
84//===----------------------------------------------------------------------===//
85// EncodingReader
86//===----------------------------------------------------------------------===//
87
88namespace {
89class EncodingReader {
90public:
91 explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
92 : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
93 explicit EncodingReader(StringRef contents, Location fileLoc)
94 : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
95 contents.size()},
96 fileLoc) {}
97
98 /// Returns true if the entire section has been read.
99 bool empty() const { return dataIt == buffer.end(); }
100
101 /// Returns the remaining size of the bytecode.
102 size_t size() const { return buffer.end() - dataIt; }
103
104 /// Align the current reader position to the specified alignment.
105 LogicalResult alignTo(unsigned alignment) {
106 if (!llvm::isPowerOf2_32(Value: alignment))
107 return emitError(args: "expected alignment to be a power-of-two");
108
109 auto isUnaligned = [&](const uint8_t *ptr) {
110 return ((uintptr_t)ptr & (alignment - 1)) != 0;
111 };
112
113 // Shift the reader position to the next alignment boundary.
114 while (isUnaligned(dataIt)) {
115 uint8_t padding;
116 if (failed(Result: parseByte(value&: padding)))
117 return failure();
118 if (padding != bytecode::kAlignmentByte) {
119 return emitError(args: "expected alignment byte (0xCB), but got: '0x" +
120 llvm::utohexstr(X: padding) + "'");
121 }
122 }
123
124 // Ensure the data iterator is now aligned. This case is unlikely because we
125 // *just* went through the effort to align the data iterator.
126 if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
127 return emitError(args: "expected data iterator aligned to ", args&: alignment,
128 args: ", but got pointer: '0x" +
129 llvm::utohexstr(X: (uintptr_t)dataIt) + "'");
130 }
131
132 return success();
133 }
134
135 /// Emit an error using the given arguments.
136 template <typename... Args>
137 InFlightDiagnostic emitError(Args &&...args) const {
138 return ::emitError(loc: fileLoc).append(std::forward<Args>(args)...);
139 }
140 InFlightDiagnostic emitError() const { return ::emitError(loc: fileLoc); }
141
142 /// Parse a single byte from the stream.
143 template <typename T>
144 LogicalResult parseByte(T &value) {
145 if (empty())
146 return emitError(args: "attempting to parse a byte at the end of the bytecode");
147 value = static_cast<T>(*dataIt++);
148 return success();
149 }
150 /// Parse a range of bytes of 'length' into the given result.
151 LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
152 if (length > size()) {
153 return emitError(args: "attempting to parse ", args&: length, args: " bytes when only ",
154 args: size(), args: " remain");
155 }
156 result = {dataIt, length};
157 dataIt += length;
158 return success();
159 }
160 /// Parse a range of bytes of 'length' into the given result, which can be
161 /// assumed to be large enough to hold `length`.
162 LogicalResult parseBytes(size_t length, uint8_t *result) {
163 if (length > size()) {
164 return emitError(args: "attempting to parse ", args&: length, args: " bytes when only ",
165 args: size(), args: " remain");
166 }
167 memcpy(dest: result, src: dataIt, n: length);
168 dataIt += length;
169 return success();
170 }
171
172 /// Parse an aligned blob of data, where the alignment was encoded alongside
173 /// the data.
174 LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
175 uint64_t &alignment) {
176 uint64_t dataSize;
177 if (failed(Result: parseVarInt(result&: alignment)) || failed(Result: parseVarInt(result&: dataSize)) ||
178 failed(Result: alignTo(alignment)))
179 return failure();
180 return parseBytes(length: dataSize, result&: data);
181 }
182
183 /// Parse a variable length encoded integer from the byte stream. The first
184 /// encoded byte contains a prefix in the low bits indicating the encoded
185 /// length of the value. This length prefix is a bit sequence of '0's followed
186 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
187 /// (not including the prefix byte). All remaining bits in the first byte,
188 /// along with all of the bits in additional bytes, provide the value of the
189 /// integer encoded in little-endian order.
190 LogicalResult parseVarInt(uint64_t &result) {
191 // Parse the first byte of the encoding, which contains the length prefix.
192 if (failed(Result: parseByte(value&: result)))
193 return failure();
194
195 // Handle the overwhelmingly common case where the value is stored in a
196 // single byte. In this case, the first bit is the `1` marker bit.
197 if (LLVM_LIKELY(result & 1)) {
198 result >>= 1;
199 return success();
200 }
201
202 // Handle the overwhelming uncommon case where the value required all 8
203 // bytes (i.e. a really really big number). In this case, the marker byte is
204 // all zeros: `00000000`.
205 if (LLVM_UNLIKELY(result == 0)) {
206 llvm::support::ulittle64_t resultLE;
207 if (failed(Result: parseBytes(length: sizeof(resultLE),
208 result: reinterpret_cast<uint8_t *>(&resultLE))))
209 return failure();
210 result = resultLE;
211 return success();
212 }
213 return parseMultiByteVarInt(result);
214 }
215
216 /// Parse a signed variable length encoded integer from the byte stream. A
217 /// signed varint is encoded as a normal varint with zigzag encoding applied,
218 /// i.e. the low bit of the value is used to indicate the sign.
219 LogicalResult parseSignedVarInt(uint64_t &result) {
220 if (failed(Result: parseVarInt(result)))
221 return failure();
222 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
223 result = (result >> 1) ^ (~(result & 1) + 1);
224 return success();
225 }
226
227 /// Parse a variable length encoded integer whose low bit is used to encode an
228 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
229 LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
230 if (failed(Result: parseVarInt(result)))
231 return failure();
232 flag = result & 1;
233 result >>= 1;
234 return success();
235 }
236
237 /// Skip the first `length` bytes within the reader.
238 LogicalResult skipBytes(size_t length) {
239 if (length > size()) {
240 return emitError(args: "attempting to skip ", args&: length, args: " bytes when only ",
241 args: size(), args: " remain");
242 }
243 dataIt += length;
244 return success();
245 }
246
247 /// Parse a null-terminated string into `result` (without including the NUL
248 /// terminator).
249 LogicalResult parseNullTerminatedString(StringRef &result) {
250 const char *startIt = (const char *)dataIt;
251 const char *nulIt = (const char *)memchr(s: startIt, c: 0, n: size());
252 if (!nulIt)
253 return emitError(
254 args: "malformed null-terminated string, no null character found");
255
256 result = StringRef(startIt, nulIt - startIt);
257 dataIt = (const uint8_t *)nulIt + 1;
258 return success();
259 }
260
261 /// Parse a section header, placing the kind of section in `sectionID` and the
262 /// contents of the section in `sectionData`.
263 LogicalResult parseSection(bytecode::Section::ID &sectionID,
264 ArrayRef<uint8_t> &sectionData) {
265 uint8_t sectionIDAndHasAlignment;
266 uint64_t length;
267 if (failed(Result: parseByte(value&: sectionIDAndHasAlignment)) ||
268 failed(Result: parseVarInt(result&: length)))
269 return failure();
270
271 // Extract the section ID and whether the section is aligned. The high bit
272 // of the ID is the alignment flag.
273 sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
274 0b01111111);
275 bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;
276
277 // Check that the section is actually valid before trying to process its
278 // data.
279 if (sectionID >= bytecode::Section::kNumSections)
280 return emitError(args: "invalid section ID: ", args: unsigned(sectionID));
281
282 // Process the section alignment if present.
283 if (hasAlignment) {
284 uint64_t alignment;
285 if (failed(Result: parseVarInt(result&: alignment)) || failed(Result: alignTo(alignment)))
286 return failure();
287 }
288
289 // Parse the actual section data.
290 return parseBytes(length: static_cast<size_t>(length), result&: sectionData);
291 }
292
293 Location getLoc() const { return fileLoc; }
294
295private:
296 /// Parse a variable length encoded integer from the byte stream. This method
297 /// is a fallback when the number of bytes used to encode the value is greater
298 /// than 1, but less than the max (9). The provided `result` value can be
299 /// assumed to already contain the first byte of the value.
300 /// NOTE: This method is marked noinline to avoid pessimizing the common case
301 /// of single byte encoding.
302 LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
303 // Count the number of trailing zeros in the marker byte, this indicates the
304 // number of trailing bytes that are part of the value. We use `uint32_t`
305 // here because we only care about the first byte, and so that be actually
306 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
307 // implementation).
308 uint32_t numBytes = llvm::countr_zero<uint32_t>(Val: result);
309 assert(numBytes > 0 && numBytes <= 7 &&
310 "unexpected number of trailing zeros in varint encoding");
311
312 // Parse in the remaining bytes of the value.
313 llvm::support::ulittle64_t resultLE(result);
314 if (failed(
315 Result: parseBytes(length: numBytes, result: reinterpret_cast<uint8_t *>(&resultLE) + 1)))
316 return failure();
317
318 // Shift out the low-order bits that were used to mark how the value was
319 // encoded.
320 result = resultLE >> (numBytes + 1);
321 return success();
322 }
323
324 /// The bytecode buffer.
325 ArrayRef<uint8_t> buffer;
326
327 /// The current iterator within the 'buffer'.
328 const uint8_t *dataIt;
329
330 /// A location for the bytecode used to report errors.
331 Location fileLoc;
332};
333} // namespace
334
335/// Resolve an index into the given entry list. `entry` may either be a
336/// reference, in which case it is assigned to the corresponding value in
337/// `entries`, or a pointer, in which case it is assigned to the address of the
338/// element in `entries`.
339template <typename RangeT, typename T>
340static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
341 uint64_t index, T &entry,
342 StringRef entryStr) {
343 if (index >= entries.size())
344 return reader.emitError(args: "invalid ", args&: entryStr, args: " index: ", args&: index);
345
346 // If the provided entry is a pointer, resolve to the address of the entry.
347 if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>)
348 entry = entries[index];
349 else
350 entry = &entries[index];
351 return success();
352}
353
354/// Parse and resolve an index into the given entry list.
355template <typename RangeT, typename T>
356static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
357 T &entry, StringRef entryStr) {
358 uint64_t entryIdx;
359 if (failed(Result: reader.parseVarInt(result&: entryIdx)))
360 return failure();
361 return resolveEntry(reader, entries, entryIdx, entry, entryStr);
362}
363
364//===----------------------------------------------------------------------===//
365// StringSectionReader
366//===----------------------------------------------------------------------===//
367
368namespace {
369/// This class is used to read references to the string section from the
370/// bytecode.
371class StringSectionReader {
372public:
373 /// Initialize the string section reader with the given section data.
374 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
375
376 /// Parse a shared string from the string section. The shared string is
377 /// encoded using an index to a corresponding string in the string section.
378 LogicalResult parseString(EncodingReader &reader, StringRef &result) const {
379 return parseEntry(reader, entries: strings, entry&: result, entryStr: "string");
380 }
381
382 /// Parse a shared string from the string section. The shared string is
383 /// encoded using an index to a corresponding string in the string section.
384 /// This variant parses a flag compressed with the index.
385 LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
386 bool &flag) const {
387 uint64_t entryIdx;
388 if (failed(Result: reader.parseVarIntWithFlag(result&: entryIdx, flag)))
389 return failure();
390 return parseStringAtIndex(reader, index: entryIdx, result);
391 }
392
393 /// Parse a shared string from the string section. The shared string is
394 /// encoded using an index to a corresponding string in the string section.
395 LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
396 StringRef &result) const {
397 return resolveEntry(reader, entries: strings, index, entry&: result, entryStr: "string");
398 }
399
400private:
401 /// The table of strings referenced within the bytecode file.
402 SmallVector<StringRef> strings;
403};
404} // namespace
405
406LogicalResult StringSectionReader::initialize(Location fileLoc,
407 ArrayRef<uint8_t> sectionData) {
408 EncodingReader stringReader(sectionData, fileLoc);
409
410 // Parse the number of strings in the section.
411 uint64_t numStrings;
412 if (failed(Result: stringReader.parseVarInt(result&: numStrings)))
413 return failure();
414 strings.resize(N: numStrings);
415
416 // Parse each of the strings. The sizes of the strings are encoded in reverse
417 // order, so that's the order we populate the table.
418 size_t stringDataEndOffset = sectionData.size();
419 for (StringRef &string : llvm::reverse(C&: strings)) {
420 uint64_t stringSize;
421 if (failed(Result: stringReader.parseVarInt(result&: stringSize)))
422 return failure();
423 if (stringDataEndOffset < stringSize) {
424 return stringReader.emitError(
425 args: "string size exceeds the available data size");
426 }
427
428 // Extract the string from the data, dropping the null character.
429 size_t stringOffset = stringDataEndOffset - stringSize;
430 string = StringRef(
431 reinterpret_cast<const char *>(sectionData.data() + stringOffset),
432 stringSize - 1);
433 stringDataEndOffset = stringOffset;
434 }
435
436 // Check that the only remaining data was for the strings, i.e. the reader
437 // should be at the same offset as the first string.
438 if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
439 return stringReader.emitError(args: "unexpected trailing data between the "
440 "offsets for strings and their data");
441 }
442 return success();
443}
444
445//===----------------------------------------------------------------------===//
446// BytecodeDialect
447//===----------------------------------------------------------------------===//
448
449namespace {
450class DialectReader;
451
452/// This struct represents a dialect entry within the bytecode.
453struct BytecodeDialect {
454 /// Load the dialect into the provided context if it hasn't been loaded yet.
455 /// Returns failure if the dialect couldn't be loaded *and* the provided
456 /// context does not allow unregistered dialects. The provided reader is used
457 /// for error emission if necessary.
458 LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
459
460 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
461 /// only be called after `load`.
462 Dialect *getLoadedDialect() const {
463 assert(dialect &&
464 "expected `load` to be invoked before `getLoadedDialect`");
465 return *dialect;
466 }
467
468 /// The loaded dialect entry. This field is std::nullopt if we haven't
469 /// attempted to load, nullptr if we failed to load, otherwise the loaded
470 /// dialect.
471 std::optional<Dialect *> dialect;
472
473 /// The bytecode interface of the dialect, or nullptr if the dialect does not
474 /// implement the bytecode interface. This field should only be checked if the
475 /// `dialect` field is not std::nullopt.
476 const BytecodeDialectInterface *interface = nullptr;
477
478 /// The name of the dialect.
479 StringRef name;
480
481 /// A buffer containing the encoding of the dialect version parsed.
482 ArrayRef<uint8_t> versionBuffer;
483
484 /// Lazy loaded dialect version from the handle above.
485 std::unique_ptr<DialectVersion> loadedVersion;
486};
487
488/// This struct represents an operation name entry within the bytecode.
489struct BytecodeOperationName {
490 BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
491 std::optional<bool> wasRegistered)
492 : dialect(dialect), name(name), wasRegistered(wasRegistered) {}
493
494 /// The loaded operation name, or std::nullopt if it hasn't been processed
495 /// yet.
496 std::optional<OperationName> opName;
497
498 /// The dialect that owns this operation name.
499 BytecodeDialect *dialect;
500
501 /// The name of the operation, without the dialect prefix.
502 StringRef name;
503
504 /// Whether this operation was registered when the bytecode was produced.
505 /// This flag is populated when bytecode version >=kNativePropertiesEncoding.
506 std::optional<bool> wasRegistered;
507};
508} // namespace
509
510/// Parse a single dialect group encoded in the byte stream.
511static LogicalResult parseDialectGrouping(
512 EncodingReader &reader,
513 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
514 function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
515 // Parse the dialect and the number of entries in the group.
516 std::unique_ptr<BytecodeDialect> *dialect;
517 if (failed(Result: parseEntry(reader, entries&: dialects, entry&: dialect, entryStr: "dialect")))
518 return failure();
519 uint64_t numEntries;
520 if (failed(Result: reader.parseVarInt(result&: numEntries)))
521 return failure();
522
523 for (uint64_t i = 0; i < numEntries; ++i)
524 if (failed(Result: entryCallback(dialect->get())))
525 return failure();
526 return success();
527}
528
529//===----------------------------------------------------------------------===//
530// ResourceSectionReader
531//===----------------------------------------------------------------------===//
532
533namespace {
534/// This class is used to read the resource section from the bytecode.
535class ResourceSectionReader {
536public:
537 /// Initialize the resource section reader with the given section data.
538 LogicalResult
539 initialize(Location fileLoc, const ParserConfig &config,
540 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
541 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
542 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
543 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
544
545 /// Parse a dialect resource handle from the resource section.
546 LogicalResult parseResourceHandle(EncodingReader &reader,
547 AsmDialectResourceHandle &result) const {
548 return parseEntry(reader, entries: dialectResources, entry&: result, entryStr: "resource handle");
549 }
550
551private:
552 /// The table of dialect resources within the bytecode file.
553 SmallVector<AsmDialectResourceHandle> dialectResources;
554 llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
555};
556
557class ParsedResourceEntry : public AsmParsedResourceEntry {
558public:
559 ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
560 EncodingReader &reader, StringSectionReader &stringReader,
561 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
562 : key(key), kind(kind), reader(reader), stringReader(stringReader),
563 bufferOwnerRef(bufferOwnerRef) {}
564 ~ParsedResourceEntry() override = default;
565
566 StringRef getKey() const final { return key; }
567
568 InFlightDiagnostic emitError() const final { return reader.emitError(); }
569
570 AsmResourceEntryKind getKind() const final { return kind; }
571
572 FailureOr<bool> parseAsBool() const final {
573 if (kind != AsmResourceEntryKind::Bool)
574 return emitError() << "expected a bool resource entry, but found a "
575 << toString(kind) << " entry instead";
576
577 bool value;
578 if (failed(Result: reader.parseByte(value)))
579 return failure();
580 return value;
581 }
582 FailureOr<std::string> parseAsString() const final {
583 if (kind != AsmResourceEntryKind::String)
584 return emitError() << "expected a string resource entry, but found a "
585 << toString(kind) << " entry instead";
586
587 StringRef string;
588 if (failed(Result: stringReader.parseString(reader, result&: string)))
589 return failure();
590 return string.str();
591 }
592
593 FailureOr<AsmResourceBlob>
594 parseAsBlob(BlobAllocatorFn allocator) const final {
595 if (kind != AsmResourceEntryKind::Blob)
596 return emitError() << "expected a blob resource entry, but found a "
597 << toString(kind) << " entry instead";
598
599 ArrayRef<uint8_t> data;
600 uint64_t alignment;
601 if (failed(Result: reader.parseBlobAndAlignment(data, alignment)))
602 return failure();
603
604 // If we have an extendable reference to the buffer owner, we don't need to
605 // allocate a new buffer for the data, and can use the data directly.
606 if (bufferOwnerRef) {
607 ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
608 data.size());
609
610 // Allocate an unmanager buffer which captures a reference to the owner.
611 // For now we just mark this as immutable, but in the future we should
612 // explore marking this as mutable when desired.
613 return UnmanagedAsmResourceBlob::allocateWithAlign(
614 data: charData, align: alignment,
615 deleter: [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
616 }
617
618 // Allocate memory for the blob using the provided allocator and copy the
619 // data into it.
620 AsmResourceBlob blob = allocator(data.size(), alignment);
621 assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
622 blob.isMutable() &&
623 "blob allocator did not return a properly aligned address");
624 memcpy(dest: blob.getMutableData().data(), src: data.data(), n: data.size());
625 return blob;
626 }
627
628private:
629 StringRef key;
630 AsmResourceEntryKind kind;
631 EncodingReader &reader;
632 StringSectionReader &stringReader;
633 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
634};
635} // namespace
636
637template <typename T>
638static LogicalResult
639parseResourceGroup(Location fileLoc, bool allowEmpty,
640 EncodingReader &offsetReader, EncodingReader &resourceReader,
641 StringSectionReader &stringReader, T *handler,
642 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
643 function_ref<StringRef(StringRef)> remapKey = {},
644 function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
645 uint64_t numResources;
646 if (failed(Result: offsetReader.parseVarInt(result&: numResources)))
647 return failure();
648
649 for (uint64_t i = 0; i < numResources; ++i) {
650 StringRef key;
651 AsmResourceEntryKind kind;
652 uint64_t resourceOffset;
653 ArrayRef<uint8_t> data;
654 if (failed(Result: stringReader.parseString(reader&: offsetReader, result&: key)) ||
655 failed(Result: offsetReader.parseVarInt(result&: resourceOffset)) ||
656 failed(Result: offsetReader.parseByte(value&: kind)) ||
657 failed(Result: resourceReader.parseBytes(length: resourceOffset, result&: data)))
658 return failure();
659
660 // Process the resource key.
661 if ((processKeyFn && failed(Result: processKeyFn(key))))
662 return failure();
663
664 // If the resource data is empty and we allow it, don't error out when
665 // parsing below, just skip it.
666 if (allowEmpty && data.empty())
667 continue;
668
669 // Ignore the entry if we don't have a valid handler.
670 if (!handler)
671 continue;
672
673 // Otherwise, parse the resource value.
674 EncodingReader entryReader(data, fileLoc);
675 key = remapKey(key);
676 ParsedResourceEntry entry(key, kind, entryReader, stringReader,
677 bufferOwnerRef);
678 if (failed(handler->parseResource(entry)))
679 return failure();
680 if (!entryReader.empty()) {
681 return entryReader.emitError(
682 args: "unexpected trailing bytes in resource entry '", args&: key, args: "'");
683 }
684 }
685 return success();
686}
687
688LogicalResult ResourceSectionReader::initialize(
689 Location fileLoc, const ParserConfig &config,
690 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
691 StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
692 ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
693 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
694 EncodingReader resourceReader(sectionData, fileLoc);
695 EncodingReader offsetReader(offsetSectionData, fileLoc);
696
697 // Read the number of external resource providers.
698 uint64_t numExternalResourceGroups;
699 if (failed(Result: offsetReader.parseVarInt(result&: numExternalResourceGroups)))
700 return failure();
701
702 // Utility functor that dispatches to `parseResourceGroup`, but implicitly
703 // provides most of the arguments.
704 auto parseGroup = [&](auto *handler, bool allowEmpty = false,
705 function_ref<LogicalResult(StringRef)> keyFn = {}) {
706 auto resolveKey = [&](StringRef key) -> StringRef {
707 auto it = dialectResourceHandleRenamingMap.find(Key: key);
708 if (it == dialectResourceHandleRenamingMap.end())
709 return key;
710 return it->second;
711 };
712
713 return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
714 stringReader, handler, bufferOwnerRef, resolveKey,
715 keyFn);
716 };
717
718 // Read the external resources from the bytecode.
719 for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
720 StringRef key;
721 if (failed(Result: stringReader.parseString(reader&: offsetReader, result&: key)))
722 return failure();
723
724 // Get the handler for these resources.
725 // TODO: Should we require handling external resources in some scenarios?
726 AsmResourceParser *handler = config.getResourceParser(name: key);
727 if (!handler) {
728 emitWarning(loc: fileLoc) << "ignoring unknown external resources for '" << key
729 << "'";
730 }
731
732 if (failed(Result: parseGroup(handler)))
733 return failure();
734 }
735
736 // Read the dialect resources from the bytecode.
737 MLIRContext *ctx = fileLoc->getContext();
738 while (!offsetReader.empty()) {
739 std::unique_ptr<BytecodeDialect> *dialect;
740 if (failed(Result: parseEntry(reader&: offsetReader, entries&: dialects, entry&: dialect, entryStr: "dialect")) ||
741 failed(Result: (*dialect)->load(reader: dialectReader, ctx)))
742 return failure();
743 Dialect *loadedDialect = (*dialect)->getLoadedDialect();
744 if (!loadedDialect) {
745 return resourceReader.emitError()
746 << "dialect '" << (*dialect)->name << "' is unknown";
747 }
748 const auto *handler = dyn_cast<OpAsmDialectInterface>(Val: loadedDialect);
749 if (!handler) {
750 return resourceReader.emitError()
751 << "unexpected resources for dialect '" << (*dialect)->name << "'";
752 }
753
754 // Ensure that each resource is declared before being processed.
755 auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
756 FailureOr<AsmDialectResourceHandle> handle =
757 handler->declareResource(key);
758 if (failed(Result: handle)) {
759 return resourceReader.emitError()
760 << "unknown 'resource' key '" << key << "' for dialect '"
761 << (*dialect)->name << "'";
762 }
763 dialectResourceHandleRenamingMap[key] = handler->getResourceKey(handle: *handle);
764 dialectResources.push_back(Elt: *handle);
765 return success();
766 };
767
768 // Parse the resources for this dialect. We allow empty resources because we
769 // just treat these as declarations.
770 if (failed(Result: parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn)))
771 return failure();
772 }
773
774 return success();
775}
776
777//===----------------------------------------------------------------------===//
778// Attribute/Type Reader
779//===----------------------------------------------------------------------===//
780
781namespace {
782/// This class provides support for reading attribute and type entries from the
783/// bytecode. Attribute and Type entries are read lazily on demand, so we use
784/// this reader to manage when to actually parse them from the bytecode.
785class AttrTypeReader {
786 /// This class represents a single attribute or type entry.
787 template <typename T>
788 struct Entry {
789 /// The entry, or null if it hasn't been resolved yet.
790 T entry = {};
791 /// The parent dialect of this entry.
792 BytecodeDialect *dialect = nullptr;
793 /// A flag indicating if the entry was encoded using a custom encoding,
794 /// instead of using the textual assembly format.
795 bool hasCustomEncoding = false;
796 /// The raw data of this entry in the bytecode.
797 ArrayRef<uint8_t> data;
798 };
799 using AttrEntry = Entry<Attribute>;
800 using TypeEntry = Entry<Type>;
801
802public:
803 AttrTypeReader(const StringSectionReader &stringReader,
804 const ResourceSectionReader &resourceReader,
805 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
806 uint64_t &bytecodeVersion, Location fileLoc,
807 const ParserConfig &config)
808 : stringReader(stringReader), resourceReader(resourceReader),
809 dialectsMap(dialectsMap), fileLoc(fileLoc),
810 bytecodeVersion(bytecodeVersion), parserConfig(config) {}
811
812 /// Initialize the attribute and type information within the reader.
813 LogicalResult
814 initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
815 ArrayRef<uint8_t> sectionData,
816 ArrayRef<uint8_t> offsetSectionData);
817
818 /// Resolve the attribute or type at the given index. Returns nullptr on
819 /// failure.
820 Attribute resolveAttribute(size_t index) {
821 return resolveEntry(entries&: attributes, index, entryType: "Attribute");
822 }
823 Type resolveType(size_t index) { return resolveEntry(entries&: types, index, entryType: "Type"); }
824
825 /// Parse a reference to an attribute or type using the given reader.
826 LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
827 uint64_t attrIdx;
828 if (failed(Result: reader.parseVarInt(result&: attrIdx)))
829 return failure();
830 result = resolveAttribute(index: attrIdx);
831 return success(IsSuccess: !!result);
832 }
833 LogicalResult parseOptionalAttribute(EncodingReader &reader,
834 Attribute &result) {
835 uint64_t attrIdx;
836 bool flag;
837 if (failed(Result: reader.parseVarIntWithFlag(result&: attrIdx, flag)))
838 return failure();
839 if (!flag)
840 return success();
841 result = resolveAttribute(index: attrIdx);
842 return success(IsSuccess: !!result);
843 }
844
845 LogicalResult parseType(EncodingReader &reader, Type &result) {
846 uint64_t typeIdx;
847 if (failed(Result: reader.parseVarInt(result&: typeIdx)))
848 return failure();
849 result = resolveType(index: typeIdx);
850 return success(IsSuccess: !!result);
851 }
852
853 template <typename T>
854 LogicalResult parseAttribute(EncodingReader &reader, T &result) {
855 Attribute baseResult;
856 if (failed(Result: parseAttribute(reader, result&: baseResult)))
857 return failure();
858 if ((result = dyn_cast<T>(baseResult)))
859 return success();
860 return reader.emitError("expected attribute of type: ",
861 llvm::getTypeName<T>(), ", but got: ", baseResult);
862 }
863
864private:
865 /// Resolve the given entry at `index`.
866 template <typename T>
867 T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
868 StringRef entryType);
869
870 /// Parse an entry using the given reader that was encoded using the textual
871 /// assembly format.
872 template <typename T>
873 LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
874 StringRef entryType);
875
876 /// Parse an entry using the given reader that was encoded using a custom
877 /// bytecode format.
878 template <typename T>
879 LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
880 StringRef entryType);
881
882 /// The string section reader used to resolve string references when parsing
883 /// custom encoded attribute/type entries.
884 const StringSectionReader &stringReader;
885
886 /// The resource section reader used to resolve resource references when
887 /// parsing custom encoded attribute/type entries.
888 const ResourceSectionReader &resourceReader;
889
890 /// The map of the loaded dialects used to retrieve dialect information, such
891 /// as the dialect version.
892 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
893
894 /// The set of attribute and type entries.
895 SmallVector<AttrEntry> attributes;
896 SmallVector<TypeEntry> types;
897
898 /// The map of cached attributes, used to avoid re-parsing the same
899 /// attribute multiple times.
900 llvm::StringMap<Attribute> attributesCache;
901
902 /// A location used for error emission.
903 Location fileLoc;
904
905 /// Current bytecode version being used.
906 uint64_t &bytecodeVersion;
907
908 /// Reference to the parser configuration.
909 const ParserConfig &parserConfig;
910};
911
912class DialectReader : public DialectBytecodeReader {
913public:
914 DialectReader(AttrTypeReader &attrTypeReader,
915 const StringSectionReader &stringReader,
916 const ResourceSectionReader &resourceReader,
917 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
918 EncodingReader &reader, uint64_t &bytecodeVersion)
919 : attrTypeReader(attrTypeReader), stringReader(stringReader),
920 resourceReader(resourceReader), dialectsMap(dialectsMap),
921 reader(reader), bytecodeVersion(bytecodeVersion) {}
922
923 InFlightDiagnostic emitError(const Twine &msg) const override {
924 return reader.emitError(args: msg);
925 }
926
927 FailureOr<const DialectVersion *>
928 getDialectVersion(StringRef dialectName) const override {
929 // First check if the dialect is available in the map.
930 auto dialectEntry = dialectsMap.find(Key: dialectName);
931 if (dialectEntry == dialectsMap.end())
932 return failure();
933 // If the dialect was found, try to load it. This will trigger reading the
934 // bytecode version from the version buffer if it wasn't already processed.
935 // Return failure if either of those two actions could not be completed.
936 if (failed(Result: dialectEntry->getValue()->load(reader: *this, ctx: getLoc().getContext())) ||
937 dialectEntry->getValue()->loadedVersion == nullptr)
938 return failure();
939 return dialectEntry->getValue()->loadedVersion.get();
940 }
941
942 MLIRContext *getContext() const override { return getLoc().getContext(); }
943
944 uint64_t getBytecodeVersion() const override { return bytecodeVersion; }
945
946 DialectReader withEncodingReader(EncodingReader &encReader) const {
947 return DialectReader(attrTypeReader, stringReader, resourceReader,
948 dialectsMap, encReader, bytecodeVersion);
949 }
950
951 Location getLoc() const { return reader.getLoc(); }
952
953 //===--------------------------------------------------------------------===//
954 // IR
955 //===--------------------------------------------------------------------===//
956
957 LogicalResult readAttribute(Attribute &result) override {
958 return attrTypeReader.parseAttribute(reader, result);
959 }
960 LogicalResult readOptionalAttribute(Attribute &result) override {
961 return attrTypeReader.parseOptionalAttribute(reader, result);
962 }
963 LogicalResult readType(Type &result) override {
964 return attrTypeReader.parseType(reader, result);
965 }
966
967 FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
968 AsmDialectResourceHandle handle;
969 if (failed(Result: resourceReader.parseResourceHandle(reader, result&: handle)))
970 return failure();
971 return handle;
972 }
973
974 //===--------------------------------------------------------------------===//
975 // Primitives
976 //===--------------------------------------------------------------------===//
977
978 LogicalResult readVarInt(uint64_t &result) override {
979 return reader.parseVarInt(result);
980 }
981
982 LogicalResult readSignedVarInt(int64_t &result) override {
983 uint64_t unsignedResult;
984 if (failed(Result: reader.parseSignedVarInt(result&: unsignedResult)))
985 return failure();
986 result = static_cast<int64_t>(unsignedResult);
987 return success();
988 }
989
990 FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
991 // Small values are encoded using a single byte.
992 if (bitWidth <= 8) {
993 uint8_t value;
994 if (failed(Result: reader.parseByte(value)))
995 return failure();
996 return APInt(bitWidth, value);
997 }
998
999 // Large values up to 64 bits are encoded using a single varint.
1000 if (bitWidth <= 64) {
1001 uint64_t value;
1002 if (failed(Result: reader.parseSignedVarInt(result&: value)))
1003 return failure();
1004 return APInt(bitWidth, value);
1005 }
1006
1007 // Otherwise, for really big values we encode the array of active words in
1008 // the value.
1009 uint64_t numActiveWords;
1010 if (failed(Result: reader.parseVarInt(result&: numActiveWords)))
1011 return failure();
1012 SmallVector<uint64_t, 4> words(numActiveWords);
1013 for (uint64_t i = 0; i < numActiveWords; ++i)
1014 if (failed(Result: reader.parseSignedVarInt(result&: words[i])))
1015 return failure();
1016 return APInt(bitWidth, words);
1017 }
1018
1019 FailureOr<APFloat>
1020 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
1021 FailureOr<APInt> intVal =
1022 readAPIntWithKnownWidth(bitWidth: APFloat::getSizeInBits(Sem: semantics));
1023 if (failed(Result: intVal))
1024 return failure();
1025 return APFloat(semantics, *intVal);
1026 }
1027
1028 LogicalResult readString(StringRef &result) override {
1029 return stringReader.parseString(reader, result);
1030 }
1031
1032 LogicalResult readBlob(ArrayRef<char> &result) override {
1033 uint64_t dataSize;
1034 ArrayRef<uint8_t> data;
1035 if (failed(Result: reader.parseVarInt(result&: dataSize)) ||
1036 failed(Result: reader.parseBytes(length: dataSize, result&: data)))
1037 return failure();
1038 result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()),
1039 data.size());
1040 return success();
1041 }
1042
1043 LogicalResult readBool(bool &result) override {
1044 return reader.parseByte(value&: result);
1045 }
1046
1047private:
1048 AttrTypeReader &attrTypeReader;
1049 const StringSectionReader &stringReader;
1050 const ResourceSectionReader &resourceReader;
1051 const llvm::StringMap<BytecodeDialect *> &dialectsMap;
1052 EncodingReader &reader;
1053 uint64_t &bytecodeVersion;
1054};
1055
1056/// Wraps the properties section and handles reading properties out of it.
1057class PropertiesSectionReader {
1058public:
1059 /// Initialize the properties section reader with the given section data.
1060 LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
1061 if (sectionData.empty())
1062 return success();
1063 EncodingReader propReader(sectionData, fileLoc);
1064 uint64_t count;
1065 if (failed(Result: propReader.parseVarInt(result&: count)))
1066 return failure();
1067 // Parse the raw properties buffer.
1068 if (failed(Result: propReader.parseBytes(length: propReader.size(), result&: propertiesBuffers)))
1069 return failure();
1070
1071 EncodingReader offsetsReader(propertiesBuffers, fileLoc);
1072 offsetTable.reserve(N: count);
1073 for (auto idx : llvm::seq<int64_t>(Begin: 0, End: count)) {
1074 (void)idx;
1075 offsetTable.push_back(Elt: propertiesBuffers.size() - offsetsReader.size());
1076 ArrayRef<uint8_t> rawProperties;
1077 uint64_t dataSize;
1078 if (failed(Result: offsetsReader.parseVarInt(result&: dataSize)) ||
1079 failed(Result: offsetsReader.parseBytes(length: dataSize, result&: rawProperties)))
1080 return failure();
1081 }
1082 if (!offsetsReader.empty())
1083 return offsetsReader.emitError()
1084 << "Broken properties section: didn't exhaust the offsets table";
1085 return success();
1086 }
1087
1088 LogicalResult read(Location fileLoc, DialectReader &dialectReader,
1089 OperationName *opName, OperationState &opState) const {
1090 uint64_t propertiesIdx;
1091 if (failed(Result: dialectReader.readVarInt(result&: propertiesIdx)))
1092 return failure();
1093 if (propertiesIdx >= offsetTable.size())
1094 return dialectReader.emitError(msg: "Properties idx out-of-bound for ")
1095 << opName->getStringRef();
1096 size_t propertiesOffset = offsetTable[propertiesIdx];
1097 if (propertiesIdx >= propertiesBuffers.size())
1098 return dialectReader.emitError(msg: "Properties offset out-of-bound for ")
1099 << opName->getStringRef();
1100
1101 // Acquire the sub-buffer that represent the requested properties.
1102 ArrayRef<char> rawProperties;
1103 {
1104 // "Seek" to the requested offset by getting a new reader with the right
1105 // sub-buffer.
1106 EncodingReader reader(propertiesBuffers.drop_front(N: propertiesOffset),
1107 fileLoc);
1108 // Properties are stored as a sequence of {size + raw_data}.
1109 if (failed(
1110 Result: dialectReader.withEncodingReader(encReader&: reader).readBlob(result&: rawProperties)))
1111 return failure();
1112 }
1113 // Setup a new reader to read from the `rawProperties` sub-buffer.
1114 EncodingReader reader(
1115 StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
1116 DialectReader propReader = dialectReader.withEncodingReader(encReader&: reader);
1117
1118 auto *iface = opName->getInterface<BytecodeOpInterface>();
1119 if (iface)
1120 return iface->readProperties(propReader, opState);
1121 if (opName->isRegistered())
1122 return propReader.emitError(
1123 msg: "has properties but missing BytecodeOpInterface for ")
1124 << opName->getStringRef();
1125 // Unregistered op are storing properties as an attribute.
1126 return propReader.readAttribute(result&: opState.propertiesAttr);
1127 }
1128
1129private:
1130 /// The properties buffer referenced within the bytecode file.
1131 ArrayRef<uint8_t> propertiesBuffers;
1132
1133 /// Table of offset in the buffer above.
1134 SmallVector<int64_t> offsetTable;
1135};
1136} // namespace
1137
1138LogicalResult AttrTypeReader::initialize(
1139 MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
1140 ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
1141 EncodingReader offsetReader(offsetSectionData, fileLoc);
1142
1143 // Parse the number of attribute and type entries.
1144 uint64_t numAttributes, numTypes;
1145 if (failed(Result: offsetReader.parseVarInt(result&: numAttributes)) ||
1146 failed(Result: offsetReader.parseVarInt(result&: numTypes)))
1147 return failure();
1148 attributes.resize(N: numAttributes);
1149 types.resize(N: numTypes);
1150
1151 // A functor used to accumulate the offsets for the entries in the given
1152 // range.
1153 uint64_t currentOffset = 0;
1154 auto parseEntries = [&](auto &&range) {
1155 size_t currentIndex = 0, endIndex = range.size();
1156
1157 // Parse an individual entry.
1158 auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
1159 auto &entry = range[currentIndex++];
1160
1161 uint64_t entrySize;
1162 if (failed(offsetReader.parseVarIntWithFlag(result&: entrySize,
1163 flag&: entry.hasCustomEncoding)))
1164 return failure();
1165
1166 // Verify that the offset is actually valid.
1167 if (currentOffset + entrySize > sectionData.size()) {
1168 return offsetReader.emitError(
1169 args: "Attribute or Type entry offset points past the end of section");
1170 }
1171
1172 entry.data = sectionData.slice(N: currentOffset, M: entrySize);
1173 entry.dialect = dialect;
1174 currentOffset += entrySize;
1175 return success();
1176 };
1177 while (currentIndex != endIndex)
1178 if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
1179 return failure();
1180 return success();
1181 };
1182
1183 // Process each of the attributes, and then the types.
1184 if (failed(Result: parseEntries(attributes)) || failed(Result: parseEntries(types)))
1185 return failure();
1186
1187 // Ensure that we read everything from the section.
1188 if (!offsetReader.empty()) {
1189 return offsetReader.emitError(
1190 args: "unexpected trailing data in the Attribute/Type offset section");
1191 }
1192
1193 return success();
1194}
1195
1196template <typename T>
1197T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
1198 StringRef entryType) {
1199 if (index >= entries.size()) {
1200 emitError(loc: fileLoc) << "invalid " << entryType << " index: " << index;
1201 return {};
1202 }
1203
1204 // If the entry has already been resolved, there is nothing left to do.
1205 Entry<T> &entry = entries[index];
1206 if (entry.entry)
1207 return entry.entry;
1208
1209 // Parse the entry.
1210 EncodingReader reader(entry.data, fileLoc);
1211
1212 // Parse based on how the entry was encoded.
1213 if (entry.hasCustomEncoding) {
1214 if (failed(parseCustomEntry(entry, reader, entryType)))
1215 return T();
1216 } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
1217 return T();
1218 }
1219
1220 if (!reader.empty()) {
1221 reader.emitError(args: "unexpected trailing bytes after " + entryType + " entry");
1222 return T();
1223 }
1224 return entry.entry;
1225}
1226
1227template <typename T>
1228LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1229 StringRef entryType) {
1230 StringRef asmStr;
1231 if (failed(Result: reader.parseNullTerminatedString(result&: asmStr)))
1232 return failure();
1233
1234 // Invoke the MLIR assembly parser to parse the entry text.
1235 size_t numRead = 0;
1236 MLIRContext *context = fileLoc->getContext();
1237 if constexpr (std::is_same_v<T, Type>)
1238 result =
1239 ::parseType(typeStr: asmStr, context, numRead: &numRead, /*isKnownNullTerminated=*/true);
1240 else
1241 result = ::parseAttribute(attrStr: asmStr, context, type: Type(), numRead: &numRead,
1242 /*isKnownNullTerminated=*/true, attributesCache: &attributesCache);
1243 if (!result)
1244 return failure();
1245
1246 // Ensure there weren't dangling characters after the entry.
1247 if (numRead != asmStr.size()) {
1248 return reader.emitError(args: "trailing characters found after ", args&: entryType,
1249 args: " assembly format: ", args: asmStr.drop_front(N: numRead));
1250 }
1251 return success();
1252}
1253
1254template <typename T>
1255LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
1256 EncodingReader &reader,
1257 StringRef entryType) {
1258 DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
1259 reader, bytecodeVersion);
1260 if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
1261 return failure();
1262
1263 if constexpr (std::is_same_v<T, Type>) {
1264 // Try parsing with callbacks first if available.
1265 for (const auto &callback :
1266 parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) {
1267 if (failed(
1268 callback->read(reader&: dialectReader, dialectName: entry.dialect->name, entry&: entry.entry)))
1269 return failure();
1270 // Early return if parsing was successful.
1271 if (!!entry.entry)
1272 return success();
1273
1274 // Reset the reader if we failed to parse, so we can fall through the
1275 // other parsing functions.
1276 reader = EncodingReader(entry.data, reader.getLoc());
1277 }
1278 } else {
1279 // Try parsing with callbacks first if available.
1280 for (const auto &callback :
1281 parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) {
1282 if (failed(
1283 callback->read(reader&: dialectReader, dialectName: entry.dialect->name, entry&: entry.entry)))
1284 return failure();
1285 // Early return if parsing was successful.
1286 if (!!entry.entry)
1287 return success();
1288
1289 // Reset the reader if we failed to parse, so we can fall through the
1290 // other parsing functions.
1291 reader = EncodingReader(entry.data, reader.getLoc());
1292 }
1293 }
1294
1295 // Ensure that the dialect implements the bytecode interface.
1296 if (!entry.dialect->interface) {
1297 return reader.emitError("dialect '", entry.dialect->name,
1298 "' does not implement the bytecode interface");
1299 }
1300
1301 if constexpr (std::is_same_v<T, Type>)
1302 entry.entry = entry.dialect->interface->readType(dialectReader);
1303 else
1304 entry.entry = entry.dialect->interface->readAttribute(dialectReader);
1305
1306 return success(!!entry.entry);
1307}
1308
1309//===----------------------------------------------------------------------===//
1310// Bytecode Reader
1311//===----------------------------------------------------------------------===//
1312
1313/// This class is used to read a bytecode buffer and translate it into MLIR.
1314class mlir::BytecodeReader::Impl {
1315 struct RegionReadState;
1316 using LazyLoadableOpsInfo =
1317 std::list<std::pair<Operation *, RegionReadState>>;
1318 using LazyLoadableOpsMap =
1319 DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;
1320
1321public:
1322 Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
1323 llvm::MemoryBufferRef buffer,
1324 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
1325 : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
1326 attrTypeReader(stringReader, resourceReader, dialectsMap, version,
1327 fileLoc, config),
1328 // Use the builtin unrealized conversion cast operation to represent
1329 // forward references to values that aren't yet defined.
1330 forwardRefOpState(UnknownLoc::get(context: config.getContext()),
1331 "builtin.unrealized_conversion_cast", ValueRange(),
1332 NoneType::get(context: config.getContext())),
1333 buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
1334
1335 /// Read the bytecode defined within `buffer` into the given block.
1336 LogicalResult read(Block *block,
1337 llvm::function_ref<bool(Operation *)> lazyOps);
1338
1339 /// Return the number of ops that haven't been materialized yet.
1340 int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
1341
1342 bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(Val: op); }
1343
1344 /// Materialize the provided operation, invoke the lazyOpsCallback on every
1345 /// newly found lazy operation.
1346 LogicalResult
1347 materialize(Operation *op,
1348 llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1349 this->lazyOpsCallback = lazyOpsCallback;
1350 auto resetlazyOpsCallback =
1351 llvm::make_scope_exit(F: [&] { this->lazyOpsCallback = nullptr; });
1352 auto it = lazyLoadableOpsMap.find(Val: op);
1353 assert(it != lazyLoadableOpsMap.end() &&
1354 "materialize called on non-materializable op");
1355 return materialize(it);
1356 }
1357
1358 /// Materialize all operations.
1359 LogicalResult materializeAll() {
1360 while (!lazyLoadableOpsMap.empty()) {
1361 if (failed(Result: materialize(it: lazyLoadableOpsMap.begin())))
1362 return failure();
1363 }
1364 return success();
1365 }
1366
1367 /// Finalize the lazy-loading by calling back with every op that hasn't been
1368 /// materialized to let the client decide if the op should be deleted or
1369 /// materialized. The op is materialized if the callback returns true, deleted
1370 /// otherwise.
1371 LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
1372 while (!lazyLoadableOps.empty()) {
1373 Operation *op = lazyLoadableOps.begin()->first;
1374 if (shouldMaterialize(op)) {
1375 if (failed(Result: materialize(it: lazyLoadableOpsMap.find(Val: op))))
1376 return failure();
1377 continue;
1378 }
1379 op->dropAllReferences();
1380 op->erase();
1381 lazyLoadableOps.pop_front();
1382 lazyLoadableOpsMap.erase(Val: op);
1383 }
1384 return success();
1385 }
1386
1387private:
1388 LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
1389 assert(it != lazyLoadableOpsMap.end() &&
1390 "materialize called on non-materializable op");
1391 valueScopes.emplace_back();
1392 std::vector<RegionReadState> regionStack;
1393 regionStack.push_back(x: std::move(it->getSecond()->second));
1394 lazyLoadableOps.erase(position: it->getSecond());
1395 lazyLoadableOpsMap.erase(I: it);
1396
1397 while (!regionStack.empty())
1398 if (failed(Result: parseRegions(regionStack, readState&: regionStack.back())))
1399 return failure();
1400 return success();
1401 }
1402
1403 /// Return the context for this config.
1404 MLIRContext *getContext() const { return config.getContext(); }
1405
1406 /// Parse the bytecode version.
1407 LogicalResult parseVersion(EncodingReader &reader);
1408
1409 //===--------------------------------------------------------------------===//
1410 // Dialect Section
1411
1412 LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
1413
1414 /// Parse an operation name reference using the given reader, and set the
1415 /// `wasRegistered` flag that indicates if the bytecode was produced by a
1416 /// context where opName was registered.
1417 FailureOr<OperationName> parseOpName(EncodingReader &reader,
1418 std::optional<bool> &wasRegistered);
1419
1420 //===--------------------------------------------------------------------===//
1421 // Attribute/Type Section
1422
1423 /// Parse an attribute or type using the given reader.
1424 template <typename T>
1425 LogicalResult parseAttribute(EncodingReader &reader, T &result) {
1426 return attrTypeReader.parseAttribute(reader, result);
1427 }
1428 LogicalResult parseType(EncodingReader &reader, Type &result) {
1429 return attrTypeReader.parseType(reader, result);
1430 }
1431
1432 //===--------------------------------------------------------------------===//
1433 // Resource Section
1434
1435 LogicalResult
1436 parseResourceSection(EncodingReader &reader,
1437 std::optional<ArrayRef<uint8_t>> resourceData,
1438 std::optional<ArrayRef<uint8_t>> resourceOffsetData);
1439
1440 //===--------------------------------------------------------------------===//
1441 // IR Section
1442
1443 /// This struct represents the current read state of a range of regions. This
1444 /// struct is used to enable iterative parsing of regions.
1445 struct RegionReadState {
1446 RegionReadState(Operation *op, EncodingReader *reader,
1447 bool isIsolatedFromAbove)
1448 : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
1449 RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
1450 bool isIsolatedFromAbove)
1451 : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
1452 isIsolatedFromAbove(isIsolatedFromAbove) {}
1453
1454 /// The current regions being read.
1455 MutableArrayRef<Region>::iterator curRegion, endRegion;
1456 /// This is the reader to use for this region, this pointer is pointing to
1457 /// the parent region reader unless the current region is IsolatedFromAbove,
1458 /// in which case the pointer is pointing to the `owningReader` which is a
1459 /// section dedicated to the current region.
1460 EncodingReader *reader;
1461 std::unique_ptr<EncodingReader> owningReader;
1462
1463 /// The number of values defined immediately within this region.
1464 unsigned numValues = 0;
1465
1466 /// The current blocks of the region being read.
1467 SmallVector<Block *> curBlocks;
1468 Region::iterator curBlock = {};
1469
1470 /// The number of operations remaining to be read from the current block
1471 /// being read.
1472 uint64_t numOpsRemaining = 0;
1473
1474 /// A flag indicating if the regions being read are isolated from above.
1475 bool isIsolatedFromAbove = false;
1476 };
1477
1478 LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
1479 LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
1480 RegionReadState &readState);
1481 FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
1482 RegionReadState &readState,
1483 bool &isIsolatedFromAbove);
1484
1485 LogicalResult parseRegion(RegionReadState &readState);
1486 LogicalResult parseBlockHeader(EncodingReader &reader,
1487 RegionReadState &readState);
1488 LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
1489
1490 //===--------------------------------------------------------------------===//
1491 // Value Processing
1492
1493 /// Parse an operand reference using the given reader. Returns nullptr in the
1494 /// case of failure.
1495 Value parseOperand(EncodingReader &reader);
1496
1497 /// Sequentially define the given value range.
1498 LogicalResult defineValues(EncodingReader &reader, ValueRange values);
1499
1500 /// Create a value to use for a forward reference.
1501 Value createForwardRef();
1502
1503 //===--------------------------------------------------------------------===//
1504 // Use-list order helpers
1505
1506 /// This struct is a simple storage that contains information required to
1507 /// reorder the use-list of a value with respect to the pre-order traversal
1508 /// ordering.
1509 struct UseListOrderStorage {
1510 UseListOrderStorage(bool isIndexPairEncoding,
1511 SmallVector<unsigned, 4> &&indices)
1512 : indices(std::move(indices)),
1513 isIndexPairEncoding(isIndexPairEncoding){};
1514 /// The vector containing the information required to reorder the
1515 /// use-list of a value.
1516 SmallVector<unsigned, 4> indices;
1517
1518 /// Whether indices represent a pair of type `(src, dst)` or it is a direct
1519 /// indexing, such as `dst = order[src]`.
1520 bool isIndexPairEncoding;
1521 };
1522
1523 /// Parse use-list order from bytecode for a range of values if available. The
1524 /// range is expected to be either a block argument or an op result range. On
1525 /// success, return a map of the position in the range and the use-list order
1526 /// encoding. The function assumes to know the size of the range it is
1527 /// processing.
1528 using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
1529 FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
1530 uint64_t rangeSize);
1531
1532 /// Shuffle the use-chain according to the order parsed.
1533 LogicalResult sortUseListOrder(Value value);
1534
1535 /// Recursively visit all the values defined within topLevelOp and sort the
1536 /// use-list orders according to the indices parsed.
1537 LogicalResult processUseLists(Operation *topLevelOp);
1538
1539 //===--------------------------------------------------------------------===//
1540 // Fields
1541
1542 /// This class represents a single value scope, in which a value scope is
1543 /// delimited by isolated from above regions.
1544 struct ValueScope {
1545 /// Push a new region state onto this scope, reserving enough values for
1546 /// those defined within the current region of the provided state.
1547 void push(RegionReadState &readState) {
1548 nextValueIDs.push_back(Elt: values.size());
1549 values.resize(new_size: values.size() + readState.numValues);
1550 }
1551
1552 /// Pop the values defined for the current region within the provided region
1553 /// state.
1554 void pop(RegionReadState &readState) {
1555 values.resize(new_size: values.size() - readState.numValues);
1556 nextValueIDs.pop_back();
1557 }
1558
1559 /// The set of values defined in this scope.
1560 std::vector<Value> values;
1561
1562 /// The ID for the next defined value for each region current being
1563 /// processed in this scope.
1564 SmallVector<unsigned, 4> nextValueIDs;
1565 };
1566
1567 /// The configuration of the parser.
1568 const ParserConfig &config;
1569
1570 /// A location to use when emitting errors.
1571 Location fileLoc;
1572
1573 /// Flag that indicates if lazyloading is enabled.
1574 bool lazyLoading;
1575
1576 /// Keep track of operations that have been lazy loaded (their regions haven't
1577 /// been materialized), along with the `RegionReadState` that allows to
1578 /// lazy-load the regions nested under the operation.
1579 LazyLoadableOpsInfo lazyLoadableOps;
1580 LazyLoadableOpsMap lazyLoadableOpsMap;
1581 llvm::function_ref<bool(Operation *)> lazyOpsCallback;
1582
1583 /// The reader used to process attribute and types within the bytecode.
1584 AttrTypeReader attrTypeReader;
1585
1586 /// The version of the bytecode being read.
1587 uint64_t version = 0;
1588
1589 /// The producer of the bytecode being read.
1590 StringRef producer;
1591
1592 /// The table of IR units referenced within the bytecode file.
1593 SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
1594 llvm::StringMap<BytecodeDialect *> dialectsMap;
1595 SmallVector<BytecodeOperationName> opNames;
1596
1597 /// The reader used to process resources within the bytecode.
1598 ResourceSectionReader resourceReader;
1599
1600 /// Worklist of values with custom use-list orders to process before the end
1601 /// of the parsing.
1602 DenseMap<void *, UseListOrderStorage> valueToUseListMap;
1603
1604 /// The table of strings referenced within the bytecode file.
1605 StringSectionReader stringReader;
1606
1607 /// The table of properties referenced by the operation in the bytecode file.
1608 PropertiesSectionReader propertiesReader;
1609
1610 /// The current set of available IR value scopes.
1611 std::vector<ValueScope> valueScopes;
1612
1613 /// The global pre-order operation ordering.
1614 DenseMap<Operation *, unsigned> operationIDs;
1615
1616 /// A block containing the set of operations defined to create forward
1617 /// references.
1618 Block forwardRefOps;
1619
1620 /// A block containing previously created, and no longer used, forward
1621 /// reference operations.
1622 Block openForwardRefOps;
1623
1624 /// An operation state used when instantiating forward references.
1625 OperationState forwardRefOpState;
1626
1627 /// Reference to the input buffer.
1628 llvm::MemoryBufferRef buffer;
1629
1630 /// The optional owning source manager, which when present may be used to
1631 /// extend the lifetime of the input buffer.
1632 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
1633};
1634
1635LogicalResult BytecodeReader::Impl::read(
1636 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
1637 EncodingReader reader(buffer.getBuffer(), fileLoc);
1638 this->lazyOpsCallback = lazyOpsCallback;
1639 auto resetlazyOpsCallback =
1640 llvm::make_scope_exit(F: [&] { this->lazyOpsCallback = nullptr; });
1641
1642 // Skip over the bytecode header, this should have already been checked.
1643 if (failed(Result: reader.skipBytes(length: StringRef("ML\xefR").size())))
1644 return failure();
1645 // Parse the bytecode version and producer.
1646 if (failed(Result: parseVersion(reader)) ||
1647 failed(Result: reader.parseNullTerminatedString(result&: producer)))
1648 return failure();
1649
1650 // Add a diagnostic handler that attaches a note that includes the original
1651 // producer of the bytecode.
1652 ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
1653 diag.attachNote() << "in bytecode version " << version
1654 << " produced by: " << producer;
1655 return failure();
1656 });
1657
1658 // Parse the raw data for each of the top-level sections of the bytecode.
1659 std::optional<ArrayRef<uint8_t>>
1660 sectionDatas[bytecode::Section::kNumSections];
1661 while (!reader.empty()) {
1662 // Read the next section from the bytecode.
1663 bytecode::Section::ID sectionID;
1664 ArrayRef<uint8_t> sectionData;
1665 if (failed(Result: reader.parseSection(sectionID, sectionData)))
1666 return failure();
1667
1668 // Check for duplicate sections, we only expect one instance of each.
1669 if (sectionDatas[sectionID]) {
1670 return reader.emitError(args: "duplicate top-level section: ",
1671 args: ::toString(sectionID));
1672 }
1673 sectionDatas[sectionID] = sectionData;
1674 }
1675 // Check that all of the required sections were found.
1676 for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
1677 bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
1678 if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) {
1679 return reader.emitError(args: "missing data for top-level section: ",
1680 args: ::toString(sectionID));
1681 }
1682 }
1683
1684 // Process the string section first.
1685 if (failed(Result: stringReader.initialize(
1686 fileLoc, sectionData: *sectionDatas[bytecode::Section::kString])))
1687 return failure();
1688
1689 // Process the properties section.
1690 if (sectionDatas[bytecode::Section::kProperties] &&
1691 failed(Result: propertiesReader.initialize(
1692 fileLoc, sectionData: *sectionDatas[bytecode::Section::kProperties])))
1693 return failure();
1694
1695 // Process the dialect section.
1696 if (failed(Result: parseDialectSection(sectionData: *sectionDatas[bytecode::Section::kDialect])))
1697 return failure();
1698
1699 // Process the resource section if present.
1700 if (failed(Result: parseResourceSection(
1701 reader, resourceData: sectionDatas[bytecode::Section::kResource],
1702 resourceOffsetData: sectionDatas[bytecode::Section::kResourceOffset])))
1703 return failure();
1704
1705 // Process the attribute and type section.
1706 if (failed(Result: attrTypeReader.initialize(
1707 dialects, sectionData: *sectionDatas[bytecode::Section::kAttrType],
1708 offsetSectionData: *sectionDatas[bytecode::Section::kAttrTypeOffset])))
1709 return failure();
1710
1711 // Finally, process the IR section.
1712 return parseIRSection(sectionData: *sectionDatas[bytecode::Section::kIR], block);
1713}
1714
1715LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
1716 if (failed(Result: reader.parseVarInt(result&: version)))
1717 return failure();
1718
1719 // Validate the bytecode version.
1720 uint64_t currentVersion = bytecode::kVersion;
1721 uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
1722 if (version < minSupportedVersion) {
1723 return reader.emitError(args: "bytecode version ", args&: version,
1724 args: " is older than the current version of ",
1725 args&: currentVersion, args: ", and upgrade is not supported");
1726 }
1727 if (version > currentVersion) {
1728 return reader.emitError(args: "bytecode version ", args&: version,
1729 args: " is newer than the current version ",
1730 args&: currentVersion);
1731 }
1732 // Override any request to lazy-load if the bytecode version is too old.
1733 if (version < bytecode::kLazyLoading)
1734 lazyLoading = false;
1735 return success();
1736}
1737
1738//===----------------------------------------------------------------------===//
1739// Dialect Section
1740//===----------------------------------------------------------------------===//
1741
1742LogicalResult BytecodeDialect::load(const DialectReader &reader,
1743 MLIRContext *ctx) {
1744 if (dialect)
1745 return success();
1746 Dialect *loadedDialect = ctx->getOrLoadDialect(name);
1747 if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
1748 return reader.emitError(msg: "dialect '")
1749 << name
1750 << "' is unknown. If this is intended, please call "
1751 "allowUnregisteredDialects() on the MLIRContext, or use "
1752 "-allow-unregistered-dialect with the MLIR tool used.";
1753 }
1754 dialect = loadedDialect;
1755
1756 // If the dialect was actually loaded, check to see if it has a bytecode
1757 // interface.
1758 if (loadedDialect)
1759 interface = dyn_cast<BytecodeDialectInterface>(Val: loadedDialect);
1760 if (!versionBuffer.empty()) {
1761 if (!interface)
1762 return reader.emitError(msg: "dialect '")
1763 << name
1764 << "' does not implement the bytecode interface, "
1765 "but found a version entry";
1766 EncodingReader encReader(versionBuffer, reader.getLoc());
1767 DialectReader versionReader = reader.withEncodingReader(encReader);
1768 loadedVersion = interface->readVersion(reader&: versionReader);
1769 if (!loadedVersion)
1770 return failure();
1771 }
1772 return success();
1773}
1774
1775LogicalResult
1776BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
1777 EncodingReader sectionReader(sectionData, fileLoc);
1778
1779 // Parse the number of dialects in the section.
1780 uint64_t numDialects;
1781 if (failed(Result: sectionReader.parseVarInt(result&: numDialects)))
1782 return failure();
1783 dialects.resize(N: numDialects);
1784
1785 // Parse each of the dialects.
1786 for (uint64_t i = 0; i < numDialects; ++i) {
1787 dialects[i] = std::make_unique<BytecodeDialect>();
1788 /// Before version kDialectVersioning, there wasn't any versioning available
1789 /// for dialects, and the entryIdx represent the string itself.
1790 if (version < bytecode::kDialectVersioning) {
1791 if (failed(Result: stringReader.parseString(reader&: sectionReader, result&: dialects[i]->name)))
1792 return failure();
1793 continue;
1794 }
1795
1796 // Parse ID representing dialect and version.
1797 uint64_t dialectNameIdx;
1798 bool versionAvailable;
1799 if (failed(Result: sectionReader.parseVarIntWithFlag(result&: dialectNameIdx,
1800 flag&: versionAvailable)))
1801 return failure();
1802 if (failed(Result: stringReader.parseStringAtIndex(reader&: sectionReader, index: dialectNameIdx,
1803 result&: dialects[i]->name)))
1804 return failure();
1805 if (versionAvailable) {
1806 bytecode::Section::ID sectionID;
1807 if (failed(Result: sectionReader.parseSection(sectionID,
1808 sectionData&: dialects[i]->versionBuffer)))
1809 return failure();
1810 if (sectionID != bytecode::Section::kDialectVersions) {
1811 emitError(loc: fileLoc, message: "expected dialect version section");
1812 return failure();
1813 }
1814 }
1815 dialectsMap[dialects[i]->name] = dialects[i].get();
1816 }
1817
1818 // Parse the operation names, which are grouped by dialect.
1819 auto parseOpName = [&](BytecodeDialect *dialect) {
1820 StringRef opName;
1821 std::optional<bool> wasRegistered;
1822 // Prior to version kNativePropertiesEncoding, the information about wheter
1823 // an op was registered or not wasn't encoded.
1824 if (version < bytecode::kNativePropertiesEncoding) {
1825 if (failed(Result: stringReader.parseString(reader&: sectionReader, result&: opName)))
1826 return failure();
1827 } else {
1828 bool wasRegisteredFlag;
1829 if (failed(Result: stringReader.parseStringWithFlag(reader&: sectionReader, result&: opName,
1830 flag&: wasRegisteredFlag)))
1831 return failure();
1832 wasRegistered = wasRegisteredFlag;
1833 }
1834 opNames.emplace_back(Args&: dialect, Args&: opName, Args&: wasRegistered);
1835 return success();
1836 };
1837 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation
1838 // where the number of ops are known.
1839 if (version >= bytecode::kElideUnknownBlockArgLocation) {
1840 uint64_t numOps;
1841 if (failed(Result: sectionReader.parseVarInt(result&: numOps)))
1842 return failure();
1843 opNames.reserve(N: numOps);
1844 }
1845 while (!sectionReader.empty())
1846 if (failed(Result: parseDialectGrouping(reader&: sectionReader, dialects, entryCallback: parseOpName)))
1847 return failure();
1848 return success();
1849}
1850
1851FailureOr<OperationName>
1852BytecodeReader::Impl::parseOpName(EncodingReader &reader,
1853 std::optional<bool> &wasRegistered) {
1854 BytecodeOperationName *opName = nullptr;
1855 if (failed(Result: parseEntry(reader, entries&: opNames, entry&: opName, entryStr: "operation name")))
1856 return failure();
1857 wasRegistered = opName->wasRegistered;
1858 // Check to see if this operation name has already been resolved. If we
1859 // haven't, load the dialect and build the operation name.
1860 if (!opName->opName) {
1861 // If the opName is empty, this is because we use to accept names such as
1862 // `foo` without any `.` separator. We shouldn't tolerate this in textual
1863 // format anymore but for now we'll be backward compatible. This can only
1864 // happen with unregistered dialects.
1865 if (opName->name.empty()) {
1866 opName->opName.emplace(args&: opName->dialect->name, args: getContext());
1867 } else {
1868 // Load the dialect and its version.
1869 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1870 dialectsMap, reader, version);
1871 if (failed(Result: opName->dialect->load(reader: dialectReader, ctx: getContext())))
1872 return failure();
1873 opName->opName.emplace(args: (opName->dialect->name + "." + opName->name).str(),
1874 args: getContext());
1875 }
1876 }
1877 return *opName->opName;
1878}
1879
1880//===----------------------------------------------------------------------===//
1881// Resource Section
1882//===----------------------------------------------------------------------===//
1883
1884LogicalResult BytecodeReader::Impl::parseResourceSection(
1885 EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
1886 std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
1887 // Ensure both sections are either present or not.
1888 if (resourceData.has_value() != resourceOffsetData.has_value()) {
1889 if (resourceOffsetData)
1890 return emitError(loc: fileLoc, message: "unexpected resource offset section when "
1891 "resource section is not present");
1892 return emitError(
1893 loc: fileLoc,
1894 message: "expected resource offset section when resource section is present");
1895 }
1896
1897 // If the resource sections are absent, there is nothing to do.
1898 if (!resourceData)
1899 return success();
1900
1901 // Initialize the resource reader with the resource sections.
1902 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
1903 dialectsMap, reader, version);
1904 return resourceReader.initialize(fileLoc, config, dialects, stringReader,
1905 sectionData: *resourceData, offsetSectionData: *resourceOffsetData,
1906 dialectReader, bufferOwnerRef);
1907}
1908
1909//===----------------------------------------------------------------------===//
1910// UseListOrder Helpers
1911//===----------------------------------------------------------------------===//
1912
1913FailureOr<BytecodeReader::Impl::UseListMapT>
1914BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
1915 uint64_t numResults) {
1916 BytecodeReader::Impl::UseListMapT map;
1917 uint64_t numValuesToRead = 1;
1918 if (numResults > 1 && failed(Result: reader.parseVarInt(result&: numValuesToRead)))
1919 return failure();
1920
1921 for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
1922 uint64_t resultIdx = 0;
1923 if (numResults > 1 && failed(Result: reader.parseVarInt(result&: resultIdx)))
1924 return failure();
1925
1926 uint64_t numValues;
1927 bool indexPairEncoding;
1928 if (failed(Result: reader.parseVarIntWithFlag(result&: numValues, flag&: indexPairEncoding)))
1929 return failure();
1930
1931 SmallVector<unsigned, 4> useListOrders;
1932 for (size_t idx = 0; idx < numValues; idx++) {
1933 uint64_t index;
1934 if (failed(Result: reader.parseVarInt(result&: index)))
1935 return failure();
1936 useListOrders.push_back(Elt: index);
1937 }
1938
1939 // Store in a map the result index
1940 map.try_emplace(Key: resultIdx, Args: UseListOrderStorage(indexPairEncoding,
1941 std::move(useListOrders)));
1942 }
1943
1944 return map;
1945}
1946
1947/// Sorts each use according to the order specified in the use-list parsed. If
1948/// the custom use-list is not found, this means that the order needs to be
1949/// consistent with the reverse pre-order walk of the IR. If multiple uses lie
1950/// on the same operation, the order will follow the reverse operand number
1951/// ordering.
1952LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
1953 // Early return for trivial use-lists.
1954 if (value.use_empty() || value.hasOneUse())
1955 return success();
1956
1957 bool hasIncomingOrder =
1958 valueToUseListMap.contains(Val: value.getAsOpaquePointer());
1959
1960 // Compute the current order of the use-list with respect to the global
1961 // ordering. Detect if the order is already sorted while doing so.
1962 bool alreadySorted = true;
1963 auto &firstUse = *value.use_begin();
1964 uint64_t prevID =
1965 bytecode::getUseID(val&: firstUse, ownerID: operationIDs.at(Val: firstUse.getOwner()));
1966 llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
1967 for (auto item : llvm::drop_begin(RangeOrContainer: llvm::enumerate(First: value.getUses()))) {
1968 uint64_t currentID = bytecode::getUseID(
1969 val&: item.value(), ownerID: operationIDs.at(Val: item.value().getOwner()));
1970 alreadySorted &= prevID > currentID;
1971 currentOrder.push_back(Elt: {item.index(), currentID});
1972 prevID = currentID;
1973 }
1974
1975 // If the order is already sorted, and there wasn't a custom order to apply
1976 // from the bytecode file, we are done.
1977 if (alreadySorted && !hasIncomingOrder)
1978 return success();
1979
1980 // If not already sorted, sort the indices of the current order by descending
1981 // useIDs.
1982 if (!alreadySorted)
1983 std::sort(
1984 first: currentOrder.begin(), last: currentOrder.end(),
1985 comp: [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
1986
1987 if (!hasIncomingOrder) {
1988 // If the bytecode file did not contain any custom use-list order, it means
1989 // that the order was descending useID. Hence, shuffle by the first index
1990 // of the `currentOrder` pair.
1991 SmallVector<unsigned> shuffle(llvm::make_first_range(c&: currentOrder));
1992 value.shuffleUseList(indices: shuffle);
1993 return success();
1994 }
1995
1996 // Pull the custom order info from the map.
1997 UseListOrderStorage customOrder =
1998 valueToUseListMap.at(Val: value.getAsOpaquePointer());
1999 SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
2000 uint64_t numUses = value.getNumUses();
2001
2002 // If the encoding was a pair of indices `(src, dst)` for every permutation,
2003 // reconstruct the shuffle vector for every use. Initialize the shuffle vector
2004 // as identity, and then apply the mapping encoded in the indices.
2005 if (customOrder.isIndexPairEncoding) {
2006 // Return failure if the number of indices was not representing pairs.
2007 if (shuffle.size() & 1)
2008 return failure();
2009
2010 SmallVector<unsigned, 4> newShuffle(numUses);
2011 size_t idx = 0;
2012 std::iota(first: newShuffle.begin(), last: newShuffle.end(), value: idx);
2013 for (idx = 0; idx < shuffle.size(); idx += 2)
2014 newShuffle[shuffle[idx]] = shuffle[idx + 1];
2015
2016 shuffle = std::move(newShuffle);
2017 }
2018
2019 // Make sure that the indices represent a valid mapping. That is, the sum of
2020 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
2021 // duplicates are allowed in the list.
2022 DenseSet<unsigned> set;
2023 uint64_t accumulator = 0;
2024 for (const auto &elem : shuffle) {
2025 if (!set.insert(V: elem).second)
2026 return failure();
2027 accumulator += elem;
2028 }
2029 if (numUses != shuffle.size() ||
2030 accumulator != (((numUses - 1) * numUses) >> 1))
2031 return failure();
2032
2033 // Apply the current ordering map onto the shuffle vector to get the final
2034 // use-list sorting indices before shuffling.
2035 shuffle = SmallVector<unsigned, 4>(llvm::map_range(
2036 C&: currentOrder, F: [&](auto item) { return shuffle[item.first]; }));
2037 value.shuffleUseList(indices: shuffle);
2038 return success();
2039}
2040
2041LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
2042 // Precompute operation IDs according to the pre-order walk of the IR. We
2043 // can't do this while parsing since parseRegions ordering is not strictly
2044 // equal to the pre-order walk.
2045 unsigned operationID = 0;
2046 topLevelOp->walk<mlir::WalkOrder::PreOrder>(
2047 callback: [&](Operation *op) { operationIDs.try_emplace(Key: op, Args: operationID++); });
2048
2049 auto blockWalk = topLevelOp->walk(callback: [this](Block *block) {
2050 for (auto arg : block->getArguments())
2051 if (failed(Result: sortUseListOrder(value: arg)))
2052 return WalkResult::interrupt();
2053 return WalkResult::advance();
2054 });
2055
2056 auto resultWalk = topLevelOp->walk(callback: [this](Operation *op) {
2057 for (auto result : op->getResults())
2058 if (failed(Result: sortUseListOrder(value: result)))
2059 return WalkResult::interrupt();
2060 return WalkResult::advance();
2061 });
2062
2063 return failure(IsFailure: blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
2064}
2065
2066//===----------------------------------------------------------------------===//
2067// IR Section
2068//===----------------------------------------------------------------------===//
2069
2070LogicalResult
2071BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
2072 Block *block) {
2073 EncodingReader reader(sectionData, fileLoc);
2074
2075 // A stack of operation regions currently being read from the bytecode.
2076 std::vector<RegionReadState> regionStack;
2077
2078 // Parse the top-level block using a temporary module operation.
2079 OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(loc: fileLoc);
2080 regionStack.emplace_back(args: *moduleOp, args: &reader, /*isIsolatedFromAbove=*/args: true);
2081 regionStack.back().curBlocks.push_back(Elt: moduleOp->getBody());
2082 regionStack.back().curBlock = regionStack.back().curRegion->begin();
2083 if (failed(Result: parseBlockHeader(reader, readState&: regionStack.back())))
2084 return failure();
2085 valueScopes.emplace_back();
2086 valueScopes.back().push(readState&: regionStack.back());
2087
2088 // Iteratively parse regions until everything has been resolved.
2089 while (!regionStack.empty())
2090 if (failed(Result: parseRegions(regionStack, readState&: regionStack.back())))
2091 return failure();
2092 if (!forwardRefOps.empty()) {
2093 return reader.emitError(
2094 args: "not all forward unresolved forward operand references");
2095 }
2096
2097 // Sort use-lists according to what specified in bytecode.
2098 if (failed(Result: processUseLists(topLevelOp: *moduleOp)))
2099 return reader.emitError(
2100 args: "parsed use-list orders were invalid and could not be applied");
2101
2102 // Resolve dialect version.
2103 for (const std::unique_ptr<BytecodeDialect> &byteCodeDialect : dialects) {
2104 // Parsing is complete, give an opportunity to each dialect to visit the
2105 // IR and perform upgrades.
2106 if (!byteCodeDialect->loadedVersion)
2107 continue;
2108 if (byteCodeDialect->interface &&
2109 failed(Result: byteCodeDialect->interface->upgradeFromVersion(
2110 topLevelOp: *moduleOp, version: *byteCodeDialect->loadedVersion)))
2111 return failure();
2112 }
2113
2114 // Verify that the parsed operations are valid.
2115 if (config.shouldVerifyAfterParse() && failed(Result: verify(op: *moduleOp)))
2116 return failure();
2117
2118 // Splice the parsed operations over to the provided top-level block.
2119 auto &parsedOps = moduleOp->getBody()->getOperations();
2120 auto &destOps = block->getOperations();
2121 destOps.splice(where: destOps.end(), L2&: parsedOps, first: parsedOps.begin(), last: parsedOps.end());
2122 return success();
2123}
2124
2125LogicalResult
2126BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
2127 RegionReadState &readState) {
2128 // Process regions, blocks, and operations until the end or if a nested
2129 // region is encountered. In this case we push a new state in regionStack and
2130 // return, the processing of the current region will resume afterward.
2131 for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
2132 // If the current block hasn't been setup yet, parse the header for this
2133 // region. The current block is already setup when this function was
2134 // interrupted to recurse down in a nested region and we resume the current
2135 // block after processing the nested region.
2136 if (readState.curBlock == Region::iterator()) {
2137 if (failed(Result: parseRegion(readState)))
2138 return failure();
2139
2140 // If the region is empty, there is nothing to more to do.
2141 if (readState.curRegion->empty())
2142 continue;
2143 }
2144
2145 // Parse the blocks within the region.
2146 EncodingReader &reader = *readState.reader;
2147 do {
2148 while (readState.numOpsRemaining--) {
2149 // Read in the next operation. We don't read its regions directly, we
2150 // handle those afterwards as necessary.
2151 bool isIsolatedFromAbove = false;
2152 FailureOr<Operation *> op =
2153 parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
2154 if (failed(Result: op))
2155 return failure();
2156
2157 // If the op has regions, add it to the stack for processing and return:
2158 // we stop the processing of the current region and resume it after the
2159 // inner one is completed. Unless LazyLoading is activated in which case
2160 // nested region parsing is delayed.
2161 if ((*op)->getNumRegions()) {
2162 RegionReadState childState(*op, &reader, isIsolatedFromAbove);
2163
2164 // Isolated regions are encoded as a section in version 2 and above.
2165 if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
2166 bytecode::Section::ID sectionID;
2167 ArrayRef<uint8_t> sectionData;
2168 if (failed(Result: reader.parseSection(sectionID, sectionData)))
2169 return failure();
2170 if (sectionID != bytecode::Section::kIR)
2171 return emitError(loc: fileLoc, message: "expected IR section for region");
2172 childState.owningReader =
2173 std::make_unique<EncodingReader>(args&: sectionData, args&: fileLoc);
2174 childState.reader = childState.owningReader.get();
2175
2176 // If the user has a callback set, they have the opportunity to
2177 // control lazyloading as we go.
2178 if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
2179 lazyLoadableOps.emplace_back(args&: *op, args: std::move(childState));
2180 lazyLoadableOpsMap.try_emplace(Key: *op,
2181 Args: std::prev(x: lazyLoadableOps.end()));
2182 continue;
2183 }
2184 }
2185 regionStack.push_back(x: std::move(childState));
2186
2187 // If the op is isolated from above, push a new value scope.
2188 if (isIsolatedFromAbove)
2189 valueScopes.emplace_back();
2190 return success();
2191 }
2192 }
2193
2194 // Move to the next block of the region.
2195 if (++readState.curBlock == readState.curRegion->end())
2196 break;
2197 if (failed(Result: parseBlockHeader(reader, readState)))
2198 return failure();
2199 } while (true);
2200
2201 // Reset the current block and any values reserved for this region.
2202 readState.curBlock = {};
2203 valueScopes.back().pop(readState);
2204 }
2205
2206 // When the regions have been fully parsed, pop them off of the read stack. If
2207 // the regions were isolated from above, we also pop the last value scope.
2208 if (readState.isIsolatedFromAbove) {
2209 assert(!valueScopes.empty() && "Expect a valueScope after reading region");
2210 valueScopes.pop_back();
2211 }
2212 assert(!regionStack.empty() && "Expect a regionStack after reading region");
2213 regionStack.pop_back();
2214 return success();
2215}
2216
2217FailureOr<Operation *>
2218BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
2219 RegionReadState &readState,
2220 bool &isIsolatedFromAbove) {
2221 // Parse the name of the operation.
2222 std::optional<bool> wasRegistered;
2223 FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
2224 if (failed(Result: opName))
2225 return failure();
2226
2227 // Parse the operation mask, which indicates which components of the operation
2228 // are present.
2229 uint8_t opMask;
2230 if (failed(Result: reader.parseByte(value&: opMask)))
2231 return failure();
2232
2233 /// Parse the location.
2234 LocationAttr opLoc;
2235 if (failed(Result: parseAttribute(reader, result&: opLoc)))
2236 return failure();
2237
2238 // With the location and name resolved, we can start building the operation
2239 // state.
2240 OperationState opState(opLoc, *opName);
2241
2242 // Parse the attributes of the operation.
2243 if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
2244 DictionaryAttr dictAttr;
2245 if (failed(Result: parseAttribute(reader, result&: dictAttr)))
2246 return failure();
2247 opState.attributes = dictAttr;
2248 }
2249
2250 if (opMask & bytecode::OpEncodingMask::kHasProperties) {
2251 // kHasProperties wasn't emitted in older bytecode, we should never get
2252 // there without also having the `wasRegistered` flag available.
2253 if (!wasRegistered)
2254 return emitError(loc: fileLoc,
2255 message: "Unexpected missing `wasRegistered` opname flag at "
2256 "bytecode version ")
2257 << version << " with properties.";
2258 // When an operation is emitted without being registered, the properties are
2259 // stored as an attribute. Otherwise the op must implement the bytecode
2260 // interface and control the serialization.
2261 if (wasRegistered) {
2262 DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
2263 dialectsMap, reader, version);
2264 if (failed(
2265 Result: propertiesReader.read(fileLoc, dialectReader, opName: &*opName, opState)))
2266 return failure();
2267 } else {
2268 // If the operation wasn't registered when it was emitted, the properties
2269 // was serialized as an attribute.
2270 if (failed(Result: parseAttribute(reader, result&: opState.propertiesAttr)))
2271 return failure();
2272 }
2273 }
2274
2275 /// Parse the results of the operation.
2276 if (opMask & bytecode::OpEncodingMask::kHasResults) {
2277 uint64_t numResults;
2278 if (failed(Result: reader.parseVarInt(result&: numResults)))
2279 return failure();
2280 opState.types.resize(N: numResults);
2281 for (int i = 0, e = numResults; i < e; ++i)
2282 if (failed(Result: parseType(reader, result&: opState.types[i])))
2283 return failure();
2284 }
2285
2286 /// Parse the operands of the operation.
2287 if (opMask & bytecode::OpEncodingMask::kHasOperands) {
2288 uint64_t numOperands;
2289 if (failed(Result: reader.parseVarInt(result&: numOperands)))
2290 return failure();
2291 opState.operands.resize(N: numOperands);
2292 for (int i = 0, e = numOperands; i < e; ++i)
2293 if (!(opState.operands[i] = parseOperand(reader)))
2294 return failure();
2295 }
2296
2297 /// Parse the successors of the operation.
2298 if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
2299 uint64_t numSuccs;
2300 if (failed(Result: reader.parseVarInt(result&: numSuccs)))
2301 return failure();
2302 opState.successors.resize(N: numSuccs);
2303 for (int i = 0, e = numSuccs; i < e; ++i) {
2304 if (failed(Result: parseEntry(reader, entries&: readState.curBlocks, entry&: opState.successors[i],
2305 entryStr: "successor")))
2306 return failure();
2307 }
2308 }
2309
2310 /// Parse the use-list orders for the results of the operation. Use-list
2311 /// orders are available since version 3 of the bytecode.
2312 std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
2313 if (version >= bytecode::kUseListOrdering &&
2314 (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) {
2315 size_t numResults = opState.types.size();
2316 auto parseResult = parseUseListOrderForRange(reader, numResults);
2317 if (failed(Result: parseResult))
2318 return failure();
2319 resultIdxToUseListMap = std::move(*parseResult);
2320 }
2321
2322 /// Parse the regions of the operation.
2323 if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
2324 uint64_t numRegions;
2325 if (failed(Result: reader.parseVarIntWithFlag(result&: numRegions, flag&: isIsolatedFromAbove)))
2326 return failure();
2327
2328 opState.regions.reserve(N: numRegions);
2329 for (int i = 0, e = numRegions; i < e; ++i)
2330 opState.regions.push_back(Elt: std::make_unique<Region>());
2331 }
2332
2333 // Create the operation at the back of the current block.
2334 Operation *op = Operation::create(state: opState);
2335 readState.curBlock->push_back(op);
2336
2337 // If the operation had results, update the value references. We don't need to
2338 // do this if the current value scope is empty. That is, the op was not
2339 // encoded within a parent region.
2340 if (readState.numValues && op->getNumResults() &&
2341 failed(Result: defineValues(reader, values: op->getResults())))
2342 return failure();
2343
2344 /// Store a map for every value that received a custom use-list order from the
2345 /// bytecode file.
2346 if (resultIdxToUseListMap.has_value()) {
2347 for (size_t idx = 0; idx < op->getNumResults(); idx++) {
2348 if (resultIdxToUseListMap->contains(Val: idx)) {
2349 valueToUseListMap.try_emplace(Key: op->getResult(idx).getAsOpaquePointer(),
2350 Args: resultIdxToUseListMap->at(Val: idx));
2351 }
2352 }
2353 }
2354 return op;
2355}
2356
2357LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
2358 EncodingReader &reader = *readState.reader;
2359
2360 // Parse the number of blocks in the region.
2361 uint64_t numBlocks;
2362 if (failed(Result: reader.parseVarInt(result&: numBlocks)))
2363 return failure();
2364
2365 // If the region is empty, there is nothing else to do.
2366 if (numBlocks == 0)
2367 return success();
2368
2369 // Parse the number of values defined in this region.
2370 uint64_t numValues;
2371 if (failed(Result: reader.parseVarInt(result&: numValues)))
2372 return failure();
2373 readState.numValues = numValues;
2374
2375 // Create the blocks within this region. We do this before processing so that
2376 // we can rely on the blocks existing when creating operations.
2377 readState.curBlocks.clear();
2378 readState.curBlocks.reserve(N: numBlocks);
2379 for (uint64_t i = 0; i < numBlocks; ++i) {
2380 readState.curBlocks.push_back(Elt: new Block());
2381 readState.curRegion->push_back(block: readState.curBlocks.back());
2382 }
2383
2384 // Prepare the current value scope for this region.
2385 valueScopes.back().push(readState);
2386
2387 // Parse the entry block of the region.
2388 readState.curBlock = readState.curRegion->begin();
2389 return parseBlockHeader(reader, readState);
2390}
2391
2392LogicalResult
2393BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
2394 RegionReadState &readState) {
2395 bool hasArgs;
2396 if (failed(Result: reader.parseVarIntWithFlag(result&: readState.numOpsRemaining, flag&: hasArgs)))
2397 return failure();
2398
2399 // Parse the arguments of the block.
2400 if (hasArgs && failed(Result: parseBlockArguments(reader, block: &*readState.curBlock)))
2401 return failure();
2402
2403 // Uselist orders are available since version 3 of the bytecode.
2404 if (version < bytecode::kUseListOrdering)
2405 return success();
2406
2407 uint8_t hasUseListOrders = 0;
2408 if (hasArgs && failed(Result: reader.parseByte(value&: hasUseListOrders)))
2409 return failure();
2410
2411 if (!hasUseListOrders)
2412 return success();
2413
2414 Block &blk = *readState.curBlock;
2415 auto argIdxToUseListMap =
2416 parseUseListOrderForRange(reader, numResults: blk.getNumArguments());
2417 if (failed(Result: argIdxToUseListMap) || argIdxToUseListMap->empty())
2418 return failure();
2419
2420 for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
2421 if (argIdxToUseListMap->contains(Val: idx))
2422 valueToUseListMap.try_emplace(Key: blk.getArgument(i: idx).getAsOpaquePointer(),
2423 Args: argIdxToUseListMap->at(Val: idx));
2424
2425 // We don't parse the operations of the block here, that's done elsewhere.
2426 return success();
2427}
2428
2429LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
2430 Block *block) {
2431 // Parse the value ID for the first argument, and the number of arguments.
2432 uint64_t numArgs;
2433 if (failed(Result: reader.parseVarInt(result&: numArgs)))
2434 return failure();
2435
2436 SmallVector<Type> argTypes;
2437 SmallVector<Location> argLocs;
2438 argTypes.reserve(N: numArgs);
2439 argLocs.reserve(N: numArgs);
2440
2441 Location unknownLoc = UnknownLoc::get(context: config.getContext());
2442 while (numArgs--) {
2443 Type argType;
2444 LocationAttr argLoc = unknownLoc;
2445 if (version >= bytecode::kElideUnknownBlockArgLocation) {
2446 // Parse the type with hasLoc flag to determine if it has type.
2447 uint64_t typeIdx;
2448 bool hasLoc;
2449 if (failed(Result: reader.parseVarIntWithFlag(result&: typeIdx, flag&: hasLoc)) ||
2450 !(argType = attrTypeReader.resolveType(index: typeIdx)))
2451 return failure();
2452 if (hasLoc && failed(Result: parseAttribute(reader, result&: argLoc)))
2453 return failure();
2454 } else {
2455 // All args has type and location.
2456 if (failed(Result: parseType(reader, result&: argType)) ||
2457 failed(Result: parseAttribute(reader, result&: argLoc)))
2458 return failure();
2459 }
2460 argTypes.push_back(Elt: argType);
2461 argLocs.push_back(Elt: argLoc);
2462 }
2463 block->addArguments(types: argTypes, locs: argLocs);
2464 return defineValues(reader, values: block->getArguments());
2465}
2466
2467//===----------------------------------------------------------------------===//
2468// Value Processing
2469//===----------------------------------------------------------------------===//
2470
2471Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
2472 std::vector<Value> &values = valueScopes.back().values;
2473 Value *value = nullptr;
2474 if (failed(Result: parseEntry(reader, entries&: values, entry&: value, entryStr: "value")))
2475 return Value();
2476
2477 // Create a new forward reference if necessary.
2478 if (!*value)
2479 *value = createForwardRef();
2480 return *value;
2481}
2482
2483LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
2484 ValueRange newValues) {
2485 ValueScope &valueScope = valueScopes.back();
2486 std::vector<Value> &values = valueScope.values;
2487
2488 unsigned &valueID = valueScope.nextValueIDs.back();
2489 unsigned valueIDEnd = valueID + newValues.size();
2490 if (valueIDEnd > values.size()) {
2491 return reader.emitError(
2492 args: "value index range was outside of the expected range for "
2493 "the parent region, got [",
2494 args&: valueID, args: ", ", args&: valueIDEnd, args: "), but the maximum index was ",
2495 args: values.size() - 1);
2496 }
2497
2498 // Assign the values and update any forward references.
2499 for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
2500 Value newValue = newValues[i];
2501
2502 // Check to see if a definition for this value already exists.
2503 if (Value oldValue = std::exchange(obj&: values[valueID], new_val&: newValue)) {
2504 Operation *forwardRefOp = oldValue.getDefiningOp();
2505
2506 // Assert that this is a forward reference operation. Given how we compute
2507 // definition ids (incrementally as we parse), it shouldn't be possible
2508 // for the value to be defined any other way.
2509 assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&
2510 "value index was already defined?");
2511
2512 oldValue.replaceAllUsesWith(newValue);
2513 forwardRefOp->moveBefore(block: &openForwardRefOps, iterator: openForwardRefOps.end());
2514 }
2515 }
2516 return success();
2517}
2518
2519Value BytecodeReader::Impl::createForwardRef() {
2520 // Check for an available existing operation to use. Otherwise, create a new
2521 // fake operation to use for the reference.
2522 if (!openForwardRefOps.empty()) {
2523 Operation *op = &openForwardRefOps.back();
2524 op->moveBefore(block: &forwardRefOps, iterator: forwardRefOps.end());
2525 } else {
2526 forwardRefOps.push_back(op: Operation::create(state: forwardRefOpState));
2527 }
2528 return forwardRefOps.back().getResult(idx: 0);
2529}
2530
2531//===----------------------------------------------------------------------===//
2532// Entry Points
2533//===----------------------------------------------------------------------===//
2534
2535BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); }
2536
2537BytecodeReader::BytecodeReader(
2538 llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
2539 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2540 Location sourceFileLoc =
2541 FileLineColLoc::get(context: config.getContext(), fileName: buffer.getBufferIdentifier(),
2542 /*line=*/0, /*column=*/0);
2543 impl = std::make_unique<Impl>(args&: sourceFileLoc, args: config, args&: lazyLoading, args&: buffer,
2544 args: bufferOwnerRef);
2545}
2546
2547LogicalResult BytecodeReader::readTopLevel(
2548 Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2549 return impl->read(block, lazyOpsCallback);
2550}
2551
2552int64_t BytecodeReader::getNumOpsToMaterialize() const {
2553 return impl->getNumOpsToMaterialize();
2554}
2555
2556bool BytecodeReader::isMaterializable(Operation *op) {
2557 return impl->isMaterializable(op);
2558}
2559
2560LogicalResult BytecodeReader::materialize(
2561 Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
2562 return impl->materialize(op, lazyOpsCallback);
2563}
2564
2565LogicalResult
2566BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
2567 return impl->finalize(shouldMaterialize);
2568}
2569
2570bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
2571 return buffer.getBuffer().starts_with(Prefix: "ML\xefR");
2572}
2573
2574/// Read the bytecode from the provided memory buffer reference.
2575/// `bufferOwnerRef` if provided is the owning source manager for the buffer,
2576/// and may be used to extend the lifetime of the buffer.
2577static LogicalResult
2578readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
2579 const ParserConfig &config,
2580 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
2581 Location sourceFileLoc =
2582 FileLineColLoc::get(context: config.getContext(), fileName: buffer.getBufferIdentifier(),
2583 /*line=*/0, /*column=*/0);
2584 if (!isBytecode(buffer)) {
2585 return emitError(loc: sourceFileLoc,
2586 message: "input buffer is not an MLIR bytecode file");
2587 }
2588
2589 BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
2590 buffer, bufferOwnerRef);
2591 return reader.read(block, /*lazyOpsCallback=*/nullptr);
2592}
2593
2594LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
2595 const ParserConfig &config) {
2596 return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
2597}
2598LogicalResult
2599mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
2600 Block *block, const ParserConfig &config) {
2601 return readBytecodeFileImpl(
2602 buffer: *sourceMgr->getMemoryBuffer(i: sourceMgr->getMainFileID()), block, config,
2603 bufferOwnerRef: sourceMgr);
2604}
2605

source code of mlir/lib/Bytecode/Reader/BytecodeReader.cpp