1 | //===- BytecodeWriter.cpp - MLIR Bytecode Writer --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Bytecode/BytecodeWriter.h" |
10 | #include "IRNumbering.h" |
11 | #include "mlir/Bytecode/BytecodeImplementation.h" |
12 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
13 | #include "mlir/Bytecode/Encoding.h" |
14 | #include "mlir/IR/Attributes.h" |
15 | #include "mlir/IR/Diagnostics.h" |
16 | #include "mlir/IR/OpImplementation.h" |
17 | #include "mlir/Support/LogicalResult.h" |
18 | #include "llvm/ADT/ArrayRef.h" |
19 | #include "llvm/ADT/CachedHashString.h" |
20 | #include "llvm/ADT/MapVector.h" |
21 | #include "llvm/ADT/SmallVector.h" |
22 | #include "llvm/Support/Endian.h" |
23 | #include "llvm/Support/raw_ostream.h" |
24 | #include <optional> |
25 | |
26 | #define DEBUG_TYPE "mlir-bytecode-writer" |
27 | |
28 | using namespace mlir; |
29 | using namespace mlir::bytecode::detail; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // BytecodeWriterConfig |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | struct BytecodeWriterConfig::Impl { |
36 | Impl(StringRef producer) : producer(producer) {} |
37 | |
38 | /// Version to use when writing. |
39 | /// Note: This only differs from kVersion if a specific version is set. |
40 | int64_t bytecodeVersion = bytecode::kVersion; |
41 | |
42 | /// A flag specifying whether to elide emission of resources into the bytecode |
43 | /// file. |
44 | bool shouldElideResourceData = false; |
45 | |
46 | /// A map containing dialect version information for each dialect to emit. |
47 | llvm::StringMap<std::unique_ptr<DialectVersion>> dialectVersionMap; |
48 | |
49 | /// The producer of the bytecode. |
50 | StringRef producer; |
51 | |
52 | /// Printer callbacks used to emit custom type and attribute encodings. |
53 | llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> |
54 | attributeWriterCallbacks; |
55 | llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> |
56 | typeWriterCallbacks; |
57 | |
58 | /// A collection of non-dialect resource printers. |
59 | SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; |
60 | }; |
61 | |
62 | BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer) |
63 | : impl(std::make_unique<Impl>(args&: producer)) {} |
64 | BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map, |
65 | StringRef producer) |
66 | : BytecodeWriterConfig(producer) { |
67 | attachFallbackResourcePrinter(map); |
68 | } |
69 | BytecodeWriterConfig::~BytecodeWriterConfig() = default; |
70 | |
71 | ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>> |
72 | BytecodeWriterConfig::getAttributeWriterCallbacks() const { |
73 | return impl->attributeWriterCallbacks; |
74 | } |
75 | |
76 | ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>> |
77 | BytecodeWriterConfig::getTypeWriterCallbacks() const { |
78 | return impl->typeWriterCallbacks; |
79 | } |
80 | |
81 | void BytecodeWriterConfig::attachAttributeCallback( |
82 | std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback) { |
83 | impl->attributeWriterCallbacks.emplace_back(Args: std::move(callback)); |
84 | } |
85 | |
86 | void BytecodeWriterConfig::attachTypeCallback( |
87 | std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback) { |
88 | impl->typeWriterCallbacks.emplace_back(Args: std::move(callback)); |
89 | } |
90 | |
91 | void BytecodeWriterConfig::attachResourcePrinter( |
92 | std::unique_ptr<AsmResourcePrinter> printer) { |
93 | impl->externalResourcePrinters.emplace_back(Args: std::move(printer)); |
94 | } |
95 | |
96 | void BytecodeWriterConfig::setElideResourceDataFlag( |
97 | bool shouldElideResourceData) { |
98 | impl->shouldElideResourceData = shouldElideResourceData; |
99 | } |
100 | |
101 | void BytecodeWriterConfig::setDesiredBytecodeVersion(int64_t bytecodeVersion) { |
102 | impl->bytecodeVersion = bytecodeVersion; |
103 | } |
104 | |
105 | int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const { |
106 | return impl->bytecodeVersion; |
107 | } |
108 | |
109 | llvm::StringMap<std::unique_ptr<DialectVersion>> & |
110 | BytecodeWriterConfig::getDialectVersionMap() const { |
111 | return impl->dialectVersionMap; |
112 | } |
113 | |
114 | void BytecodeWriterConfig::setDialectVersion( |
115 | llvm::StringRef dialectName, |
116 | std::unique_ptr<DialectVersion> dialectVersion) const { |
117 | assert(!impl->dialectVersionMap.contains(dialectName) && |
118 | "cannot override a previously set dialect version" ); |
119 | impl->dialectVersionMap.insert(KV: {dialectName, std::move(dialectVersion)}); |
120 | } |
121 | |
122 | //===----------------------------------------------------------------------===// |
123 | // EncodingEmitter |
124 | //===----------------------------------------------------------------------===// |
125 | |
126 | namespace { |
127 | /// This class functions as the underlying encoding emitter for the bytecode |
128 | /// writer. This class is a bit different compared to other types of encoders; |
129 | /// it does not use a single buffer, but instead may contain several buffers |
130 | /// (some owned by the writer, and some not) that get concatted during the final |
131 | /// emission. |
132 | class EncodingEmitter { |
133 | public: |
134 | EncodingEmitter() = default; |
135 | EncodingEmitter(const EncodingEmitter &) = delete; |
136 | EncodingEmitter &operator=(const EncodingEmitter &) = delete; |
137 | |
138 | /// Write the current contents to the provided stream. |
139 | void writeTo(raw_ostream &os) const; |
140 | |
141 | /// Return the current size of the encoded buffer. |
142 | size_t size() const { return prevResultSize + currentResult.size(); } |
143 | |
144 | //===--------------------------------------------------------------------===// |
145 | // Emission |
146 | //===--------------------------------------------------------------------===// |
147 | |
148 | /// Backpatch a byte in the result buffer at the given offset. |
149 | void patchByte(uint64_t offset, uint8_t value) { |
150 | assert(offset < size() && offset >= prevResultSize && |
151 | "cannot patch previously emitted data" ); |
152 | currentResult[offset - prevResultSize] = value; |
153 | } |
154 | |
155 | /// Emit the provided blob of data, which is owned by the caller and is |
156 | /// guaranteed to not die before the end of the bytecode process. |
157 | void emitOwnedBlob(ArrayRef<uint8_t> data) { |
158 | // Push the current buffer before adding the provided data. |
159 | appendResult(result: std::move(currentResult)); |
160 | appendOwnedResult(result: data); |
161 | } |
162 | |
163 | /// Emit the provided blob of data that has the given alignment, which is |
164 | /// owned by the caller and is guaranteed to not die before the end of the |
165 | /// bytecode process. The alignment value is also encoded, making it available |
166 | /// on load. |
167 | void emitOwnedBlobAndAlignment(ArrayRef<uint8_t> data, uint32_t alignment) { |
168 | emitVarInt(value: alignment); |
169 | emitVarInt(value: data.size()); |
170 | |
171 | alignTo(alignment); |
172 | emitOwnedBlob(data); |
173 | } |
174 | void emitOwnedBlobAndAlignment(ArrayRef<char> data, uint32_t alignment) { |
175 | ArrayRef<uint8_t> castedData(reinterpret_cast<const uint8_t *>(data.data()), |
176 | data.size()); |
177 | emitOwnedBlobAndAlignment(data: castedData, alignment); |
178 | } |
179 | |
180 | /// Align the emitter to the given alignment. |
181 | void alignTo(unsigned alignment) { |
182 | if (alignment < 2) |
183 | return; |
184 | assert(llvm::isPowerOf2_32(alignment) && "expected valid alignment" ); |
185 | |
186 | // Check to see if we need to emit any padding bytes to meet the desired |
187 | // alignment. |
188 | size_t curOffset = size(); |
189 | size_t paddingSize = llvm::alignTo(Value: curOffset, Align: alignment) - curOffset; |
190 | while (paddingSize--) |
191 | emitByte(byte: bytecode::kAlignmentByte); |
192 | |
193 | // Keep track of the maximum required alignment. |
194 | requiredAlignment = std::max(a: requiredAlignment, b: alignment); |
195 | } |
196 | |
197 | //===--------------------------------------------------------------------===// |
198 | // Integer Emission |
199 | |
200 | /// Emit a single byte. |
201 | template <typename T> |
202 | void emitByte(T byte) { |
203 | currentResult.push_back(x: static_cast<uint8_t>(byte)); |
204 | } |
205 | |
206 | /// Emit a range of bytes. |
207 | void emitBytes(ArrayRef<uint8_t> bytes) { |
208 | llvm::append_range(C&: currentResult, R&: bytes); |
209 | } |
210 | |
211 | /// Emit a variable length integer. The first encoded byte contains a prefix |
212 | /// in the low bits indicating the encoded length of the value. This length |
213 | /// prefix is a bit sequence of '0's followed by a '1'. The number of '0' bits |
214 | /// indicate the number of _additional_ bytes (not including the prefix byte). |
215 | /// All remaining bits in the first byte, along with all of the bits in |
216 | /// additional bytes, provide the value of the integer encoded in |
217 | /// little-endian order. |
218 | void emitVarInt(uint64_t value) { |
219 | // In the most common case, the value can be represented in a single byte. |
220 | // Given how hot this case is, explicitly handle that here. |
221 | if ((value >> 7) == 0) |
222 | return emitByte(byte: (value << 1) | 0x1); |
223 | emitMultiByteVarInt(value); |
224 | } |
225 | |
226 | /// Emit a signed variable length integer. Signed varints are encoded using |
227 | /// a varint with zigzag encoding, meaning that we use the low bit of the |
228 | /// value to indicate the sign of the value. This allows for more efficient |
229 | /// encoding of negative values by limiting the number of active bits |
230 | void emitSignedVarInt(uint64_t value) { |
231 | emitVarInt(value: (value << 1) ^ (uint64_t)((int64_t)value >> 63)); |
232 | } |
233 | |
234 | /// Emit a variable length integer whose low bit is used to encode the |
235 | /// provided flag, i.e. encoded as: (value << 1) | (flag ? 1 : 0). |
236 | void emitVarIntWithFlag(uint64_t value, bool flag) { |
237 | emitVarInt(value: (value << 1) | (flag ? 1 : 0)); |
238 | } |
239 | |
240 | //===--------------------------------------------------------------------===// |
241 | // String Emission |
242 | |
243 | /// Emit the given string as a nul terminated string. |
244 | void emitNulTerminatedString(StringRef str) { |
245 | emitString(str); |
246 | emitByte(byte: 0); |
247 | } |
248 | |
249 | /// Emit the given string without a nul terminator. |
250 | void emitString(StringRef str) { |
251 | emitBytes(bytes: {reinterpret_cast<const uint8_t *>(str.data()), str.size()}); |
252 | } |
253 | |
254 | //===--------------------------------------------------------------------===// |
255 | // Section Emission |
256 | |
257 | /// Emit a nested section of the given code, whose contents are encoded in the |
258 | /// provided emitter. |
259 | void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) { |
260 | // Emit the section code and length. The high bit of the code is used to |
261 | // indicate whether the section alignment is present, so save an offset to |
262 | // it. |
263 | uint64_t codeOffset = currentResult.size(); |
264 | emitByte(byte: code); |
265 | emitVarInt(value: emitter.size()); |
266 | |
267 | // Integrate the alignment of the section into this emitter if necessary. |
268 | unsigned emitterAlign = emitter.requiredAlignment; |
269 | if (emitterAlign > 1) { |
270 | if (size() & (emitterAlign - 1)) { |
271 | emitVarInt(value: emitterAlign); |
272 | alignTo(alignment: emitterAlign); |
273 | |
274 | // Indicate that we needed to align the section, the high bit of the |
275 | // code field is used for this. |
276 | currentResult[codeOffset] |= 0b10000000; |
277 | } else { |
278 | // Otherwise, if we happen to be at a compatible offset, we just |
279 | // remember that we need this alignment. |
280 | requiredAlignment = std::max(a: requiredAlignment, b: emitterAlign); |
281 | } |
282 | } |
283 | |
284 | // Push our current buffer and then merge the provided section body into |
285 | // ours. |
286 | appendResult(result: std::move(currentResult)); |
287 | for (std::vector<uint8_t> &result : emitter.prevResultStorage) |
288 | prevResultStorage.push_back(x: std::move(result)); |
289 | llvm::append_range(C&: prevResultList, R&: emitter.prevResultList); |
290 | prevResultSize += emitter.prevResultSize; |
291 | appendResult(result: std::move(emitter.currentResult)); |
292 | } |
293 | |
294 | private: |
295 | /// Emit the given value using a variable width encoding. This method is a |
296 | /// fallback when the number of bytes needed to encode the value is greater |
297 | /// than 1. We mark it noinline here so that the single byte hot path isn't |
298 | /// pessimized. |
299 | LLVM_ATTRIBUTE_NOINLINE void emitMultiByteVarInt(uint64_t value); |
300 | |
301 | /// Append a new result buffer to the current contents. |
302 | void appendResult(std::vector<uint8_t> &&result) { |
303 | if (result.empty()) |
304 | return; |
305 | prevResultStorage.emplace_back(args: std::move(result)); |
306 | appendOwnedResult(result: prevResultStorage.back()); |
307 | } |
308 | void appendOwnedResult(ArrayRef<uint8_t> result) { |
309 | if (result.empty()) |
310 | return; |
311 | prevResultSize += result.size(); |
312 | prevResultList.emplace_back(args&: result); |
313 | } |
314 | |
315 | /// The result of the emitter currently being built. We refrain from building |
316 | /// a single buffer to simplify emitting sections, large data, and more. The |
317 | /// result is thus represented using multiple distinct buffers, some of which |
318 | /// we own (via prevResultStorage), and some of which are just pointers into |
319 | /// externally owned buffers. |
320 | std::vector<uint8_t> currentResult; |
321 | std::vector<ArrayRef<uint8_t>> prevResultList; |
322 | std::vector<std::vector<uint8_t>> prevResultStorage; |
323 | |
324 | /// An up-to-date total size of all of the buffers within `prevResultList`. |
325 | /// This enables O(1) size checks of the current encoding. |
326 | size_t prevResultSize = 0; |
327 | |
328 | /// The highest required alignment for the start of this section. |
329 | unsigned requiredAlignment = 1; |
330 | }; |
331 | |
332 | //===----------------------------------------------------------------------===// |
333 | // StringSectionBuilder |
334 | //===----------------------------------------------------------------------===// |
335 | |
336 | namespace { |
337 | /// This class is used to simplify the process of emitting the string section. |
338 | class StringSectionBuilder { |
339 | public: |
340 | /// Add the given string to the string section, and return the index of the |
341 | /// string within the section. |
342 | size_t insert(StringRef str) { |
343 | auto it = strings.insert(KV: {llvm::CachedHashStringRef(str), strings.size()}); |
344 | return it.first->second; |
345 | } |
346 | |
347 | /// Write the current set of strings to the given emitter. |
348 | void write(EncodingEmitter &emitter) { |
349 | emitter.emitVarInt(value: strings.size()); |
350 | |
351 | // Emit the sizes in reverse order, so that we don't need to backpatch an |
352 | // offset to the string data or have a separate section. |
353 | for (const auto &it : llvm::reverse(C&: strings)) |
354 | emitter.emitVarInt(value: it.first.size() + 1); |
355 | // Emit the string data itself. |
356 | for (const auto &it : strings) |
357 | emitter.emitNulTerminatedString(str: it.first.val()); |
358 | } |
359 | |
360 | private: |
361 | /// A set of strings referenced within the bytecode. The value of the map is |
362 | /// unused. |
363 | llvm::MapVector<llvm::CachedHashStringRef, size_t> strings; |
364 | }; |
365 | } // namespace |
366 | |
367 | class DialectWriter : public DialectBytecodeWriter { |
368 | using DialectVersionMapT = llvm::StringMap<std::unique_ptr<DialectVersion>>; |
369 | |
370 | public: |
371 | DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter, |
372 | IRNumberingState &numberingState, |
373 | StringSectionBuilder &stringSection, |
374 | const DialectVersionMapT &dialectVersionMap) |
375 | : bytecodeVersion(bytecodeVersion), emitter(emitter), |
376 | numberingState(numberingState), stringSection(stringSection), |
377 | dialectVersionMap(dialectVersionMap) {} |
378 | |
379 | //===--------------------------------------------------------------------===// |
380 | // IR |
381 | //===--------------------------------------------------------------------===// |
382 | |
383 | void writeAttribute(Attribute attr) override { |
384 | emitter.emitVarInt(value: numberingState.getNumber(attr)); |
385 | } |
386 | void writeOptionalAttribute(Attribute attr) override { |
387 | if (!attr) { |
388 | emitter.emitVarInt(value: 0); |
389 | return; |
390 | } |
391 | emitter.emitVarIntWithFlag(value: numberingState.getNumber(attr), flag: true); |
392 | } |
393 | |
394 | void writeType(Type type) override { |
395 | emitter.emitVarInt(value: numberingState.getNumber(type)); |
396 | } |
397 | |
398 | void writeResourceHandle(const AsmDialectResourceHandle &resource) override { |
399 | emitter.emitVarInt(value: numberingState.getNumber(resource)); |
400 | } |
401 | |
402 | //===--------------------------------------------------------------------===// |
403 | // Primitives |
404 | //===--------------------------------------------------------------------===// |
405 | |
406 | void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } |
407 | |
408 | void writeSignedVarInt(int64_t value) override { |
409 | emitter.emitSignedVarInt(value); |
410 | } |
411 | |
412 | void writeAPIntWithKnownWidth(const APInt &value) override { |
413 | size_t bitWidth = value.getBitWidth(); |
414 | |
415 | // If the value is a single byte, just emit it directly without going |
416 | // through a varint. |
417 | if (bitWidth <= 8) |
418 | return emitter.emitByte(byte: value.getLimitedValue()); |
419 | |
420 | // If the value fits within a single varint, emit it directly. |
421 | if (bitWidth <= 64) |
422 | return emitter.emitSignedVarInt(value: value.getLimitedValue()); |
423 | |
424 | // Otherwise, we need to encode a variable number of active words. We use |
425 | // active words instead of the number of total words under the observation |
426 | // that smaller values will be more common. |
427 | unsigned numActiveWords = value.getActiveWords(); |
428 | emitter.emitVarInt(value: numActiveWords); |
429 | |
430 | const uint64_t *rawValueData = value.getRawData(); |
431 | for (unsigned i = 0; i < numActiveWords; ++i) |
432 | emitter.emitSignedVarInt(value: rawValueData[i]); |
433 | } |
434 | |
435 | void writeAPFloatWithKnownSemantics(const APFloat &value) override { |
436 | writeAPIntWithKnownWidth(value: value.bitcastToAPInt()); |
437 | } |
438 | |
439 | void writeOwnedString(StringRef str) override { |
440 | emitter.emitVarInt(value: stringSection.insert(str)); |
441 | } |
442 | |
443 | void writeOwnedBlob(ArrayRef<char> blob) override { |
444 | emitter.emitVarInt(value: blob.size()); |
445 | emitter.emitOwnedBlob(data: ArrayRef<uint8_t>( |
446 | reinterpret_cast<const uint8_t *>(blob.data()), blob.size())); |
447 | } |
448 | |
449 | void writeOwnedBool(bool value) override { emitter.emitByte(byte: value); } |
450 | |
451 | int64_t getBytecodeVersion() const override { return bytecodeVersion; } |
452 | |
453 | FailureOr<const DialectVersion *> |
454 | getDialectVersion(StringRef dialectName) const override { |
455 | auto dialectEntry = dialectVersionMap.find(Key: dialectName); |
456 | if (dialectEntry == dialectVersionMap.end()) |
457 | return failure(); |
458 | return dialectEntry->getValue().get(); |
459 | } |
460 | |
461 | private: |
462 | int64_t bytecodeVersion; |
463 | EncodingEmitter &emitter; |
464 | IRNumberingState &numberingState; |
465 | StringSectionBuilder &stringSection; |
466 | const DialectVersionMapT &dialectVersionMap; |
467 | }; |
468 | |
469 | namespace { |
470 | class PropertiesSectionBuilder { |
471 | public: |
472 | PropertiesSectionBuilder(IRNumberingState &numberingState, |
473 | StringSectionBuilder &stringSection, |
474 | const BytecodeWriterConfig::Impl &config) |
475 | : numberingState(numberingState), stringSection(stringSection), |
476 | config(config) {} |
477 | |
478 | /// Emit the op properties in the properties section and return the index of |
479 | /// the properties within the section. Return -1 if no properties was emitted. |
480 | std::optional<ssize_t> emit(Operation *op) { |
481 | EncodingEmitter propertiesEmitter; |
482 | if (!op->getPropertiesStorageSize()) |
483 | return std::nullopt; |
484 | if (!op->isRegistered()) { |
485 | // Unregistered op are storing properties as an optional attribute. |
486 | Attribute prop = *op->getPropertiesStorage().as<Attribute *>(); |
487 | if (!prop) |
488 | return std::nullopt; |
489 | EncodingEmitter sizeEmitter; |
490 | sizeEmitter.emitVarInt(value: numberingState.getNumber(attr: prop)); |
491 | scratch.clear(); |
492 | llvm::raw_svector_ostream os(scratch); |
493 | sizeEmitter.writeTo(os); |
494 | return emit(rawProperties: scratch); |
495 | } |
496 | |
497 | EncodingEmitter emitter; |
498 | DialectWriter propertiesWriter(config.bytecodeVersion, emitter, |
499 | numberingState, stringSection, |
500 | config.dialectVersionMap); |
501 | auto iface = cast<BytecodeOpInterface>(op); |
502 | iface.writeProperties(propertiesWriter); |
503 | scratch.clear(); |
504 | llvm::raw_svector_ostream os(scratch); |
505 | emitter.writeTo(os); |
506 | return emit(rawProperties: scratch); |
507 | } |
508 | |
509 | /// Write the current set of properties to the given emitter. |
510 | void write(EncodingEmitter &emitter) { |
511 | emitter.emitVarInt(value: propertiesStorage.size()); |
512 | if (propertiesStorage.empty()) |
513 | return; |
514 | for (const auto &storage : propertiesStorage) { |
515 | if (storage.empty()) { |
516 | emitter.emitBytes(bytes: ArrayRef<uint8_t>()); |
517 | continue; |
518 | } |
519 | emitter.emitBytes(bytes: ArrayRef(reinterpret_cast<const uint8_t *>(&storage[0]), |
520 | storage.size())); |
521 | } |
522 | } |
523 | |
524 | /// Returns true if the section is empty. |
525 | bool empty() { return propertiesStorage.empty(); } |
526 | |
527 | private: |
528 | /// Emit raw data and returns the offset in the internal buffer. |
529 | /// Data are deduplicated and will be copied in the internal buffer only if |
530 | /// they don't exist there already. |
531 | ssize_t emit(ArrayRef<char> rawProperties) { |
532 | // Populate a scratch buffer with the properties size. |
533 | SmallVector<char> sizeScratch; |
534 | { |
535 | EncodingEmitter sizeEmitter; |
536 | sizeEmitter.emitVarInt(value: rawProperties.size()); |
537 | llvm::raw_svector_ostream os(sizeScratch); |
538 | sizeEmitter.writeTo(os); |
539 | } |
540 | // Append a new storage to the table now. |
541 | size_t index = propertiesStorage.size(); |
542 | propertiesStorage.emplace_back(); |
543 | std::vector<char> &newStorage = propertiesStorage.back(); |
544 | size_t propertiesSize = sizeScratch.size() + rawProperties.size(); |
545 | newStorage.reserve(n: propertiesSize); |
546 | newStorage.insert(position: newStorage.end(), first: sizeScratch.begin(), last: sizeScratch.end()); |
547 | newStorage.insert(position: newStorage.end(), first: rawProperties.begin(), |
548 | last: rawProperties.end()); |
549 | |
550 | // Try to de-duplicate the new serialized properties. |
551 | // If the properties is a duplicate, pop it back from the storage. |
552 | auto inserted = propertiesUniquing.insert( |
553 | KV: std::make_pair(x: ArrayRef<char>(newStorage), y&: index)); |
554 | if (!inserted.second) |
555 | propertiesStorage.pop_back(); |
556 | return inserted.first->getSecond(); |
557 | } |
558 | |
559 | /// Storage for properties. |
560 | std::vector<std::vector<char>> propertiesStorage; |
561 | SmallVector<char> scratch; |
562 | DenseMap<ArrayRef<char>, int64_t> propertiesUniquing; |
563 | IRNumberingState &numberingState; |
564 | StringSectionBuilder &stringSection; |
565 | const BytecodeWriterConfig::Impl &config; |
566 | }; |
567 | } // namespace |
568 | |
569 | /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need |
570 | /// to go through an intermediate buffer when interacting with code that wants a |
571 | /// raw_ostream. |
572 | class RawEmitterOstream : public raw_ostream { |
573 | public: |
574 | explicit RawEmitterOstream(EncodingEmitter &emitter) : emitter(emitter) { |
575 | SetUnbuffered(); |
576 | } |
577 | |
578 | private: |
579 | void write_impl(const char *ptr, size_t size) override { |
580 | emitter.emitBytes(bytes: {reinterpret_cast<const uint8_t *>(ptr), size}); |
581 | } |
582 | uint64_t current_pos() const override { return emitter.size(); } |
583 | |
584 | /// The section being emitted to. |
585 | EncodingEmitter &emitter; |
586 | }; |
587 | } // namespace |
588 | |
589 | void EncodingEmitter::writeTo(raw_ostream &os) const { |
590 | for (auto &prevResult : prevResultList) |
591 | os.write(Ptr: (const char *)prevResult.data(), Size: prevResult.size()); |
592 | os.write(Ptr: (const char *)currentResult.data(), Size: currentResult.size()); |
593 | } |
594 | |
595 | void EncodingEmitter::emitMultiByteVarInt(uint64_t value) { |
596 | // Compute the number of bytes needed to encode the value. Each byte can hold |
597 | // up to 7-bits of data. We only check up to the number of bits we can encode |
598 | // in the first byte (8). |
599 | uint64_t it = value >> 7; |
600 | for (size_t numBytes = 2; numBytes < 9; ++numBytes) { |
601 | if (LLVM_LIKELY(it >>= 7) == 0) { |
602 | uint64_t encodedValue = (value << 1) | 0x1; |
603 | encodedValue <<= (numBytes - 1); |
604 | llvm::support::ulittle64_t encodedValueLE(encodedValue); |
605 | emitBytes(bytes: {reinterpret_cast<uint8_t *>(&encodedValueLE), numBytes}); |
606 | return; |
607 | } |
608 | } |
609 | |
610 | // If the value is too large to encode in a single byte, emit a special all |
611 | // zero marker byte and splat the value directly. |
612 | emitByte(byte: 0); |
613 | llvm::support::ulittle64_t valueLE(value); |
614 | emitBytes(bytes: {reinterpret_cast<uint8_t *>(&valueLE), sizeof(valueLE)}); |
615 | } |
616 | |
617 | //===----------------------------------------------------------------------===// |
618 | // Bytecode Writer |
619 | //===----------------------------------------------------------------------===// |
620 | |
621 | namespace { |
622 | class BytecodeWriter { |
623 | public: |
624 | BytecodeWriter(Operation *op, const BytecodeWriterConfig &config) |
625 | : numberingState(op, config), config(config.getImpl()), |
626 | propertiesSection(numberingState, stringSection, config.getImpl()) {} |
627 | |
628 | /// Write the bytecode for the given root operation. |
629 | LogicalResult write(Operation *rootOp, raw_ostream &os); |
630 | |
631 | private: |
632 | //===--------------------------------------------------------------------===// |
633 | // Dialects |
634 | |
635 | void writeDialectSection(EncodingEmitter &emitter); |
636 | |
637 | //===--------------------------------------------------------------------===// |
638 | // Attributes and Types |
639 | |
640 | void writeAttrTypeSection(EncodingEmitter &emitter); |
641 | |
642 | //===--------------------------------------------------------------------===// |
643 | // Operations |
644 | |
645 | LogicalResult writeBlock(EncodingEmitter &emitter, Block *block); |
646 | LogicalResult writeOp(EncodingEmitter &emitter, Operation *op); |
647 | LogicalResult writeRegion(EncodingEmitter &emitter, Region *region); |
648 | LogicalResult writeIRSection(EncodingEmitter &emitter, Operation *op); |
649 | |
650 | LogicalResult writeRegions(EncodingEmitter &emitter, |
651 | MutableArrayRef<Region> regions) { |
652 | return success(isSuccess: llvm::all_of(Range&: regions, P: [&](Region ®ion) { |
653 | return succeeded(result: writeRegion(emitter, region: ®ion)); |
654 | })); |
655 | } |
656 | |
657 | //===--------------------------------------------------------------------===// |
658 | // Resources |
659 | |
660 | void writeResourceSection(Operation *op, EncodingEmitter &emitter); |
661 | |
662 | //===--------------------------------------------------------------------===// |
663 | // Strings |
664 | |
665 | void writeStringSection(EncodingEmitter &emitter); |
666 | |
667 | //===--------------------------------------------------------------------===// |
668 | // Properties |
669 | |
670 | void writePropertiesSection(EncodingEmitter &emitter); |
671 | |
672 | //===--------------------------------------------------------------------===// |
673 | // Helpers |
674 | |
675 | void writeUseListOrders(EncodingEmitter &emitter, uint8_t &opEncodingMask, |
676 | ValueRange range); |
677 | |
678 | //===--------------------------------------------------------------------===// |
679 | // Fields |
680 | |
681 | /// The builder used for the string section. |
682 | StringSectionBuilder stringSection; |
683 | |
684 | /// The IR numbering state generated for the root operation. |
685 | IRNumberingState numberingState; |
686 | |
687 | /// Configuration dictating bytecode emission. |
688 | const BytecodeWriterConfig::Impl &config; |
689 | |
690 | /// Storage for the properties section |
691 | PropertiesSectionBuilder propertiesSection; |
692 | }; |
693 | } // namespace |
694 | |
695 | LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { |
696 | EncodingEmitter emitter; |
697 | |
698 | // Emit the bytecode file header. This is how we identify the output as a |
699 | // bytecode file. |
700 | emitter.emitString(str: "ML\xefR" ); |
701 | |
702 | // Emit the bytecode version. |
703 | if (config.bytecodeVersion < bytecode::kMinSupportedVersion || |
704 | config.bytecodeVersion > bytecode::kVersion) |
705 | return rootOp->emitError() |
706 | << "unsupported version requested " << config.bytecodeVersion |
707 | << ", must be in range [" |
708 | << static_cast<int64_t>(bytecode::kMinSupportedVersion) << ", " |
709 | << static_cast<int64_t>(bytecode::kVersion) << ']'; |
710 | emitter.emitVarInt(value: config.bytecodeVersion); |
711 | |
712 | // Emit the producer. |
713 | emitter.emitNulTerminatedString(str: config.producer); |
714 | |
715 | // Emit the dialect section. |
716 | writeDialectSection(emitter); |
717 | |
718 | // Emit the attributes and types section. |
719 | writeAttrTypeSection(emitter); |
720 | |
721 | // Emit the IR section. |
722 | if (failed(result: writeIRSection(emitter, op: rootOp))) |
723 | return failure(); |
724 | |
725 | // Emit the resources section. |
726 | writeResourceSection(op: rootOp, emitter); |
727 | |
728 | // Emit the string section. |
729 | writeStringSection(emitter); |
730 | |
731 | // Emit the properties section. |
732 | if (config.bytecodeVersion >= bytecode::kNativePropertiesEncoding) |
733 | writePropertiesSection(emitter); |
734 | else if (!propertiesSection.empty()) |
735 | return rootOp->emitError( |
736 | message: "unexpected properties emitted incompatible with bytecode <5" ); |
737 | |
738 | // Write the generated bytecode to the provided output stream. |
739 | emitter.writeTo(os); |
740 | |
741 | return success(); |
742 | } |
743 | |
744 | //===----------------------------------------------------------------------===// |
745 | // Dialects |
746 | |
747 | /// Write the given entries in contiguous groups with the same parent dialect. |
748 | /// Each dialect sub-group is encoded with the parent dialect and number of |
749 | /// elements, followed by the encoding for the entries. The given callback is |
750 | /// invoked to encode each individual entry. |
751 | template <typename EntriesT, typename EntryCallbackT> |
752 | static void writeDialectGrouping(EncodingEmitter &emitter, EntriesT &&entries, |
753 | EntryCallbackT &&callback) { |
754 | for (auto it = entries.begin(), e = entries.end(); it != e;) { |
755 | auto groupStart = it++; |
756 | |
757 | // Find the end of the group that shares the same parent dialect. |
758 | DialectNumbering *currentDialect = groupStart->dialect; |
759 | it = std::find_if(it, e, [&](const auto &entry) { |
760 | return entry.dialect != currentDialect; |
761 | }); |
762 | |
763 | // Emit the dialect and number of elements. |
764 | emitter.emitVarInt(value: currentDialect->number); |
765 | emitter.emitVarInt(value: std::distance(groupStart, it)); |
766 | |
767 | // Emit the entries within the group. |
768 | for (auto &entry : llvm::make_range(groupStart, it)) |
769 | callback(entry); |
770 | } |
771 | } |
772 | |
773 | void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) { |
774 | EncodingEmitter dialectEmitter; |
775 | |
776 | // Emit the referenced dialects. |
777 | auto dialects = numberingState.getDialects(); |
778 | dialectEmitter.emitVarInt(value: llvm::size(Range&: dialects)); |
779 | for (DialectNumbering &dialect : dialects) { |
780 | // Write the string section and get the ID. |
781 | size_t nameID = stringSection.insert(str: dialect.name); |
782 | |
783 | if (config.bytecodeVersion < bytecode::kDialectVersioning) { |
784 | dialectEmitter.emitVarInt(value: nameID); |
785 | continue; |
786 | } |
787 | |
788 | // Try writing the version to the versionEmitter. |
789 | EncodingEmitter versionEmitter; |
790 | if (dialect.interface) { |
791 | // The writer used when emitting using a custom bytecode encoding. |
792 | DialectWriter versionWriter(config.bytecodeVersion, versionEmitter, |
793 | numberingState, stringSection, |
794 | config.dialectVersionMap); |
795 | dialect.interface->writeVersion(writer&: versionWriter); |
796 | } |
797 | |
798 | // If the version emitter is empty, version is not available. We can encode |
799 | // this in the dialect ID, so if there is no version, we don't write the |
800 | // section. |
801 | size_t versionAvailable = versionEmitter.size() > 0; |
802 | dialectEmitter.emitVarIntWithFlag(value: nameID, flag: versionAvailable); |
803 | if (versionAvailable) |
804 | dialectEmitter.emitSection(code: bytecode::Section::kDialectVersions, |
805 | emitter: std::move(versionEmitter)); |
806 | } |
807 | |
808 | if (config.bytecodeVersion >= bytecode::kElideUnknownBlockArgLocation) |
809 | dialectEmitter.emitVarInt(value: size(Range: numberingState.getOpNames())); |
810 | |
811 | // Emit the referenced operation names grouped by dialect. |
812 | auto emitOpName = [&](OpNameNumbering &name) { |
813 | size_t stringId = stringSection.insert(str: name.name.stripDialect()); |
814 | if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding) |
815 | dialectEmitter.emitVarInt(value: stringId); |
816 | else |
817 | dialectEmitter.emitVarIntWithFlag(value: stringId, flag: name.name.isRegistered()); |
818 | }; |
819 | writeDialectGrouping(emitter&: dialectEmitter, entries: numberingState.getOpNames(), callback&: emitOpName); |
820 | |
821 | emitter.emitSection(code: bytecode::Section::kDialect, emitter: std::move(dialectEmitter)); |
822 | } |
823 | |
824 | //===----------------------------------------------------------------------===// |
825 | // Attributes and Types |
826 | |
827 | void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { |
828 | EncodingEmitter attrTypeEmitter; |
829 | EncodingEmitter offsetEmitter; |
830 | offsetEmitter.emitVarInt(value: llvm::size(Range: numberingState.getAttributes())); |
831 | offsetEmitter.emitVarInt(value: llvm::size(Range: numberingState.getTypes())); |
832 | |
833 | // A functor used to emit an attribute or type entry. |
834 | uint64_t prevOffset = 0; |
835 | auto emitAttrOrType = [&](auto &entry) { |
836 | auto entryValue = entry.getValue(); |
837 | |
838 | auto emitAttrOrTypeRawImpl = [&]() -> void { |
839 | RawEmitterOstream(attrTypeEmitter) << entryValue; |
840 | attrTypeEmitter.emitByte(byte: 0); |
841 | }; |
842 | auto emitAttrOrTypeImpl = [&]() -> bool { |
843 | // TODO: We don't currently support custom encoded mutable types and |
844 | // attributes. |
845 | if (entryValue.template hasTrait<TypeTrait::IsMutable>() || |
846 | entryValue.template hasTrait<AttributeTrait::IsMutable>()) { |
847 | emitAttrOrTypeRawImpl(); |
848 | return false; |
849 | } |
850 | |
851 | DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, |
852 | numberingState, stringSection, |
853 | config.dialectVersionMap); |
854 | if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) { |
855 | for (const auto &callback : config.typeWriterCallbacks) { |
856 | if (succeeded(callback->write(entryValue, dialectWriter))) |
857 | return true; |
858 | } |
859 | if (const BytecodeDialectInterface *interface = |
860 | entry.dialect->interface) { |
861 | if (succeeded(interface->writeType(type: entryValue, writer&: dialectWriter))) |
862 | return true; |
863 | } |
864 | } else { |
865 | for (const auto &callback : config.attributeWriterCallbacks) { |
866 | if (succeeded(callback->write(entryValue, dialectWriter))) |
867 | return true; |
868 | } |
869 | if (const BytecodeDialectInterface *interface = |
870 | entry.dialect->interface) { |
871 | if (succeeded(interface->writeAttribute(attr: entryValue, writer&: dialectWriter))) |
872 | return true; |
873 | } |
874 | } |
875 | |
876 | // If the entry was not emitted using a callback or a dialect interface, |
877 | // emit it using the textual format. |
878 | emitAttrOrTypeRawImpl(); |
879 | return false; |
880 | }; |
881 | |
882 | bool hasCustomEncoding = emitAttrOrTypeImpl(); |
883 | |
884 | // Record the offset of this entry. |
885 | uint64_t curOffset = attrTypeEmitter.size(); |
886 | offsetEmitter.emitVarIntWithFlag(value: curOffset - prevOffset, flag: hasCustomEncoding); |
887 | prevOffset = curOffset; |
888 | }; |
889 | |
890 | // Emit the attribute and type entries for each dialect. |
891 | writeDialectGrouping(emitter&: offsetEmitter, entries: numberingState.getAttributes(), |
892 | callback&: emitAttrOrType); |
893 | writeDialectGrouping(emitter&: offsetEmitter, entries: numberingState.getTypes(), |
894 | callback&: emitAttrOrType); |
895 | |
896 | // Emit the sections to the stream. |
897 | emitter.emitSection(code: bytecode::Section::kAttrTypeOffset, |
898 | emitter: std::move(offsetEmitter)); |
899 | emitter.emitSection(code: bytecode::Section::kAttrType, emitter: std::move(attrTypeEmitter)); |
900 | } |
901 | |
902 | //===----------------------------------------------------------------------===// |
903 | // Operations |
904 | |
905 | LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter, |
906 | Block *block) { |
907 | ArrayRef<BlockArgument> args = block->getArguments(); |
908 | bool hasArgs = !args.empty(); |
909 | |
910 | // Emit the number of operations in this block, and if it has arguments. We |
911 | // use the low bit of the operation count to indicate if the block has |
912 | // arguments. |
913 | unsigned numOps = numberingState.getOperationCount(block); |
914 | emitter.emitVarIntWithFlag(value: numOps, flag: hasArgs); |
915 | |
916 | // Emit the arguments of the block. |
917 | if (hasArgs) { |
918 | emitter.emitVarInt(value: args.size()); |
919 | for (BlockArgument arg : args) { |
920 | Location argLoc = arg.getLoc(); |
921 | if (config.bytecodeVersion >= bytecode::kElideUnknownBlockArgLocation) { |
922 | emitter.emitVarIntWithFlag(numberingState.getNumber(arg.getType()), |
923 | !isa<UnknownLoc>(argLoc)); |
924 | if (!isa<UnknownLoc>(argLoc)) |
925 | emitter.emitVarInt(value: numberingState.getNumber(attr: argLoc)); |
926 | } else { |
927 | emitter.emitVarInt(value: numberingState.getNumber(type: arg.getType())); |
928 | emitter.emitVarInt(value: numberingState.getNumber(attr: argLoc)); |
929 | } |
930 | } |
931 | if (config.bytecodeVersion >= bytecode::kUseListOrdering) { |
932 | uint64_t maskOffset = emitter.size(); |
933 | uint8_t encodingMask = 0; |
934 | emitter.emitByte(byte: 0); |
935 | writeUseListOrders(emitter, opEncodingMask&: encodingMask, range: args); |
936 | if (encodingMask) |
937 | emitter.patchByte(offset: maskOffset, value: encodingMask); |
938 | } |
939 | } |
940 | |
941 | // Emit the operations within the block. |
942 | for (Operation &op : *block) |
943 | if (failed(result: writeOp(emitter, op: &op))) |
944 | return failure(); |
945 | return success(); |
946 | } |
947 | |
948 | LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { |
949 | emitter.emitVarInt(value: numberingState.getNumber(opName: op->getName())); |
950 | |
951 | // Emit a mask for the operation components. We need to fill this in later |
952 | // (when we actually know what needs to be emitted), so emit a placeholder for |
953 | // now. |
954 | uint64_t maskOffset = emitter.size(); |
955 | uint8_t opEncodingMask = 0; |
956 | emitter.emitByte(byte: 0); |
957 | |
958 | // Emit the location for this operation. |
959 | emitter.emitVarInt(value: numberingState.getNumber(attr: op->getLoc())); |
960 | |
961 | // Emit the attributes of this operation. |
962 | DictionaryAttr attrs = op->getDiscardableAttrDictionary(); |
963 | // Allow deployment to version <kNativePropertiesEncoding by merging inherent |
964 | // attribute with the discardable ones. We should fail if there are any |
965 | // conflicts. When properties are not used by the op, also store everything as |
966 | // attributes. |
967 | if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding || |
968 | !op->getPropertiesStorage()) { |
969 | attrs = op->getAttrDictionary(); |
970 | } |
971 | if (!attrs.empty()) { |
972 | opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs; |
973 | emitter.emitVarInt(value: numberingState.getNumber(attrs)); |
974 | } |
975 | |
976 | // Emit the properties of this operation, for now we still support deployment |
977 | // to version <kNativePropertiesEncoding. |
978 | if (config.bytecodeVersion >= bytecode::kNativePropertiesEncoding) { |
979 | std::optional<ssize_t> propertiesId = propertiesSection.emit(op); |
980 | if (propertiesId.has_value()) { |
981 | opEncodingMask |= bytecode::OpEncodingMask::kHasProperties; |
982 | emitter.emitVarInt(value: *propertiesId); |
983 | } |
984 | } |
985 | |
986 | // Emit the result types of the operation. |
987 | if (unsigned numResults = op->getNumResults()) { |
988 | opEncodingMask |= bytecode::OpEncodingMask::kHasResults; |
989 | emitter.emitVarInt(value: numResults); |
990 | for (Type type : op->getResultTypes()) |
991 | emitter.emitVarInt(value: numberingState.getNumber(type)); |
992 | } |
993 | |
994 | // Emit the operands of the operation. |
995 | if (unsigned numOperands = op->getNumOperands()) { |
996 | opEncodingMask |= bytecode::OpEncodingMask::kHasOperands; |
997 | emitter.emitVarInt(value: numOperands); |
998 | for (Value operand : op->getOperands()) |
999 | emitter.emitVarInt(value: numberingState.getNumber(value: operand)); |
1000 | } |
1001 | |
1002 | // Emit the successors of the operation. |
1003 | if (unsigned numSuccessors = op->getNumSuccessors()) { |
1004 | opEncodingMask |= bytecode::OpEncodingMask::kHasSuccessors; |
1005 | emitter.emitVarInt(value: numSuccessors); |
1006 | for (Block *successor : op->getSuccessors()) |
1007 | emitter.emitVarInt(value: numberingState.getNumber(block: successor)); |
1008 | } |
1009 | |
1010 | // Emit the use-list orders to bytecode, so we can reconstruct the same order |
1011 | // at parsing. |
1012 | if (config.bytecodeVersion >= bytecode::kUseListOrdering) |
1013 | writeUseListOrders(emitter, opEncodingMask, range: ValueRange(op->getResults())); |
1014 | |
1015 | // Check for regions. |
1016 | unsigned numRegions = op->getNumRegions(); |
1017 | if (numRegions) |
1018 | opEncodingMask |= bytecode::OpEncodingMask::kHasInlineRegions; |
1019 | |
1020 | // Update the mask for the operation. |
1021 | emitter.patchByte(offset: maskOffset, value: opEncodingMask); |
1022 | |
1023 | // With the mask emitted, we can now emit the regions of the operation. We do |
1024 | // this after mask emission to avoid offset complications that may arise by |
1025 | // emitting the regions first (e.g. if the regions are huge, backpatching the |
1026 | // op encoding mask is more annoying). |
1027 | if (numRegions) { |
1028 | bool isIsolatedFromAbove = numberingState.isIsolatedFromAbove(op); |
1029 | emitter.emitVarIntWithFlag(value: numRegions, flag: isIsolatedFromAbove); |
1030 | |
1031 | // If the region is not isolated from above, or we are emitting bytecode |
1032 | // targeting version <kLazyLoading, we don't use a section. |
1033 | if (isIsolatedFromAbove && |
1034 | config.bytecodeVersion >= bytecode::kLazyLoading) { |
1035 | EncodingEmitter regionEmitter; |
1036 | if (failed(result: writeRegions(emitter&: regionEmitter, regions: op->getRegions()))) |
1037 | return failure(); |
1038 | emitter.emitSection(code: bytecode::Section::kIR, emitter: std::move(regionEmitter)); |
1039 | |
1040 | } else if (failed(result: writeRegions(emitter, regions: op->getRegions()))) { |
1041 | return failure(); |
1042 | } |
1043 | } |
1044 | return success(); |
1045 | } |
1046 | |
1047 | void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter, |
1048 | uint8_t &opEncodingMask, |
1049 | ValueRange range) { |
1050 | // Loop over the results and store the use-list order per result index. |
1051 | DenseMap<unsigned, llvm::SmallVector<unsigned>> map; |
1052 | for (auto item : llvm::enumerate(First&: range)) { |
1053 | auto value = item.value(); |
1054 | // No need to store a custom use-list order if the result does not have |
1055 | // multiple uses. |
1056 | if (value.use_empty() || value.hasOneUse()) |
1057 | continue; |
1058 | |
1059 | // For each result, assemble the list of pairs (use-list-index, |
1060 | // global-value-index). While doing so, detect if the global-value-index is |
1061 | // already ordered with respect to the use-list-index. |
1062 | bool alreadyOrdered = true; |
1063 | auto &firstUse = *value.use_begin(); |
1064 | uint64_t prevID = bytecode::getUseID( |
1065 | val&: firstUse, ownerID: numberingState.getNumber(op: firstUse.getOwner())); |
1066 | llvm::SmallVector<std::pair<unsigned, uint64_t>> useListPairs( |
1067 | {{0, prevID}}); |
1068 | |
1069 | for (auto use : llvm::drop_begin(RangeOrContainer: llvm::enumerate(First: value.getUses()))) { |
1070 | uint64_t currentID = bytecode::getUseID( |
1071 | val&: use.value(), ownerID: numberingState.getNumber(op: use.value().getOwner())); |
1072 | // The use-list order achieved when building the IR at parsing always |
1073 | // pushes new uses on front. Hence, if the order by unique ID is |
1074 | // monotonically decreasing, a roundtrip to bytecode preserves such order. |
1075 | alreadyOrdered &= (prevID > currentID); |
1076 | useListPairs.push_back(Elt: {use.index(), currentID}); |
1077 | prevID = currentID; |
1078 | } |
1079 | |
1080 | // Do not emit if the order is already sorted. |
1081 | if (alreadyOrdered) |
1082 | continue; |
1083 | |
1084 | // Sort the use indices by the unique ID indices in descending order. |
1085 | std::sort( |
1086 | first: useListPairs.begin(), last: useListPairs.end(), |
1087 | comp: [](auto elem1, auto elem2) { return elem1.second > elem2.second; }); |
1088 | |
1089 | map.try_emplace(Key: item.index(), Args: llvm::map_range(C&: useListPairs, F: [](auto elem) { |
1090 | return elem.first; |
1091 | })); |
1092 | } |
1093 | |
1094 | if (map.empty()) |
1095 | return; |
1096 | |
1097 | opEncodingMask |= bytecode::OpEncodingMask::kHasUseListOrders; |
1098 | // Emit the number of results that have a custom use-list order if the number |
1099 | // of results is greater than one. |
1100 | if (range.size() != 1) |
1101 | emitter.emitVarInt(value: map.size()); |
1102 | |
1103 | for (const auto &item : map) { |
1104 | auto resultIdx = item.getFirst(); |
1105 | auto useListOrder = item.getSecond(); |
1106 | |
1107 | // Compute the number of uses that are actually shuffled. If those are less |
1108 | // than half of the total uses, encoding the index pair `(src, dst)` is more |
1109 | // space efficient. |
1110 | size_t shuffledElements = |
1111 | llvm::count_if(Range: llvm::enumerate(First&: useListOrder), |
1112 | P: [](auto item) { return item.index() != item.value(); }); |
1113 | bool indexPairEncoding = shuffledElements < (useListOrder.size() / 2); |
1114 | |
1115 | // For single result, we don't need to store the result index. |
1116 | if (range.size() != 1) |
1117 | emitter.emitVarInt(value: resultIdx); |
1118 | |
1119 | if (indexPairEncoding) { |
1120 | emitter.emitVarIntWithFlag(value: shuffledElements * 2, flag: indexPairEncoding); |
1121 | for (auto pair : llvm::enumerate(First&: useListOrder)) { |
1122 | if (pair.index() != pair.value()) { |
1123 | emitter.emitVarInt(value: pair.value()); |
1124 | emitter.emitVarInt(value: pair.index()); |
1125 | } |
1126 | } |
1127 | } else { |
1128 | emitter.emitVarIntWithFlag(value: useListOrder.size(), flag: indexPairEncoding); |
1129 | for (const auto &index : useListOrder) |
1130 | emitter.emitVarInt(value: index); |
1131 | } |
1132 | } |
1133 | } |
1134 | |
1135 | LogicalResult BytecodeWriter::writeRegion(EncodingEmitter &emitter, |
1136 | Region *region) { |
1137 | // If the region is empty, we only need to emit the number of blocks (which is |
1138 | // zero). |
1139 | if (region->empty()) { |
1140 | emitter.emitVarInt(/*numBlocks*/ value: 0); |
1141 | return success(); |
1142 | } |
1143 | |
1144 | // Emit the number of blocks and values within the region. |
1145 | unsigned numBlocks, numValues; |
1146 | std::tie(args&: numBlocks, args&: numValues) = numberingState.getBlockValueCount(region); |
1147 | emitter.emitVarInt(value: numBlocks); |
1148 | emitter.emitVarInt(value: numValues); |
1149 | |
1150 | // Emit the blocks within the region. |
1151 | for (Block &block : *region) |
1152 | if (failed(result: writeBlock(emitter, block: &block))) |
1153 | return failure(); |
1154 | return success(); |
1155 | } |
1156 | |
1157 | LogicalResult BytecodeWriter::writeIRSection(EncodingEmitter &emitter, |
1158 | Operation *op) { |
1159 | EncodingEmitter irEmitter; |
1160 | |
1161 | // Write the IR section the same way as a block with no arguments. Note that |
1162 | // the low-bit of the operation count for a block is used to indicate if the |
1163 | // block has arguments, which in this case is always false. |
1164 | irEmitter.emitVarIntWithFlag(/*numOps*/ value: 1, /*hasArgs*/ flag: false); |
1165 | |
1166 | // Emit the operations. |
1167 | if (failed(result: writeOp(emitter&: irEmitter, op))) |
1168 | return failure(); |
1169 | |
1170 | emitter.emitSection(code: bytecode::Section::kIR, emitter: std::move(irEmitter)); |
1171 | return success(); |
1172 | } |
1173 | |
1174 | //===----------------------------------------------------------------------===// |
1175 | // Resources |
1176 | |
1177 | namespace { |
1178 | /// This class represents a resource builder implementation for the MLIR |
1179 | /// bytecode format. |
1180 | class ResourceBuilder : public AsmResourceBuilder { |
1181 | public: |
1182 | using PostProcessFn = function_ref<void(StringRef, AsmResourceEntryKind)>; |
1183 | |
1184 | ResourceBuilder(EncodingEmitter &emitter, StringSectionBuilder &stringSection, |
1185 | PostProcessFn postProcessFn, bool shouldElideData) |
1186 | : emitter(emitter), stringSection(stringSection), |
1187 | postProcessFn(postProcessFn), shouldElideData(shouldElideData) {} |
1188 | ~ResourceBuilder() override = default; |
1189 | |
1190 | void buildBlob(StringRef key, ArrayRef<char> data, |
1191 | uint32_t dataAlignment) final { |
1192 | if (!shouldElideData) |
1193 | emitter.emitOwnedBlobAndAlignment(data, alignment: dataAlignment); |
1194 | postProcessFn(key, AsmResourceEntryKind::Blob); |
1195 | } |
1196 | void buildBool(StringRef key, bool data) final { |
1197 | if (!shouldElideData) |
1198 | emitter.emitByte(byte: data); |
1199 | postProcessFn(key, AsmResourceEntryKind::Bool); |
1200 | } |
1201 | void buildString(StringRef key, StringRef data) final { |
1202 | if (!shouldElideData) |
1203 | emitter.emitVarInt(value: stringSection.insert(str: data)); |
1204 | postProcessFn(key, AsmResourceEntryKind::String); |
1205 | } |
1206 | |
1207 | private: |
1208 | EncodingEmitter &emitter; |
1209 | StringSectionBuilder &stringSection; |
1210 | PostProcessFn postProcessFn; |
1211 | bool shouldElideData = false; |
1212 | }; |
1213 | } // namespace |
1214 | |
1215 | void BytecodeWriter::writeResourceSection(Operation *op, |
1216 | EncodingEmitter &emitter) { |
1217 | EncodingEmitter resourceEmitter; |
1218 | EncodingEmitter resourceOffsetEmitter; |
1219 | uint64_t prevOffset = 0; |
1220 | SmallVector<std::tuple<StringRef, AsmResourceEntryKind, uint64_t>> |
1221 | curResourceEntries; |
1222 | |
1223 | // Functor used to process the offset for a resource of `kind` defined by |
1224 | // 'key'. |
1225 | auto appendResourceOffset = [&](StringRef key, AsmResourceEntryKind kind) { |
1226 | uint64_t curOffset = resourceEmitter.size(); |
1227 | curResourceEntries.emplace_back(Args&: key, Args&: kind, Args: curOffset - prevOffset); |
1228 | prevOffset = curOffset; |
1229 | }; |
1230 | |
1231 | // Functor used to emit a resource group defined by 'key'. |
1232 | auto emitResourceGroup = [&](uint64_t key) { |
1233 | resourceOffsetEmitter.emitVarInt(value: key); |
1234 | resourceOffsetEmitter.emitVarInt(value: curResourceEntries.size()); |
1235 | for (auto [key, kind, size] : curResourceEntries) { |
1236 | resourceOffsetEmitter.emitVarInt(value: stringSection.insert(str: key)); |
1237 | resourceOffsetEmitter.emitVarInt(value: size); |
1238 | resourceOffsetEmitter.emitByte(byte: kind); |
1239 | } |
1240 | }; |
1241 | |
1242 | // Builder used to emit resources. |
1243 | ResourceBuilder entryBuilder(resourceEmitter, stringSection, |
1244 | appendResourceOffset, |
1245 | config.shouldElideResourceData); |
1246 | |
1247 | // Emit the external resource entries. |
1248 | resourceOffsetEmitter.emitVarInt(value: config.externalResourcePrinters.size()); |
1249 | for (const auto &printer : config.externalResourcePrinters) { |
1250 | curResourceEntries.clear(); |
1251 | printer->buildResources(op, builder&: entryBuilder); |
1252 | emitResourceGroup(stringSection.insert(str: printer->getName())); |
1253 | } |
1254 | |
1255 | // Emit the dialect resource entries. |
1256 | for (DialectNumbering &dialect : numberingState.getDialects()) { |
1257 | if (!dialect.asmInterface) |
1258 | continue; |
1259 | curResourceEntries.clear(); |
1260 | dialect.asmInterface->buildResources(op, referencedResources: dialect.resources, builder&: entryBuilder); |
1261 | |
1262 | // Emit the declaration resources for this dialect, these didn't get emitted |
1263 | // by the interface. These resources don't have data attached, so just use a |
1264 | // "blob" kind as a placeholder. |
1265 | for (const auto &resource : dialect.resourceMap) |
1266 | if (resource.second->isDeclaration) |
1267 | appendResourceOffset(resource.first, AsmResourceEntryKind::Blob); |
1268 | |
1269 | // Emit the resource group for this dialect. |
1270 | if (!curResourceEntries.empty()) |
1271 | emitResourceGroup(dialect.number); |
1272 | } |
1273 | |
1274 | // If we didn't emit any resource groups, elide the resource sections. |
1275 | if (resourceOffsetEmitter.size() == 0) |
1276 | return; |
1277 | |
1278 | emitter.emitSection(code: bytecode::Section::kResourceOffset, |
1279 | emitter: std::move(resourceOffsetEmitter)); |
1280 | emitter.emitSection(code: bytecode::Section::kResource, emitter: std::move(resourceEmitter)); |
1281 | } |
1282 | |
1283 | //===----------------------------------------------------------------------===// |
1284 | // Strings |
1285 | |
1286 | void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) { |
1287 | EncodingEmitter stringEmitter; |
1288 | stringSection.write(emitter&: stringEmitter); |
1289 | emitter.emitSection(code: bytecode::Section::kString, emitter: std::move(stringEmitter)); |
1290 | } |
1291 | |
1292 | //===----------------------------------------------------------------------===// |
1293 | // Properties |
1294 | |
1295 | void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) { |
1296 | EncodingEmitter propertiesEmitter; |
1297 | propertiesSection.write(emitter&: propertiesEmitter); |
1298 | emitter.emitSection(code: bytecode::Section::kProperties, |
1299 | emitter: std::move(propertiesEmitter)); |
1300 | } |
1301 | |
1302 | //===----------------------------------------------------------------------===// |
1303 | // Entry Points |
1304 | //===----------------------------------------------------------------------===// |
1305 | |
1306 | LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os, |
1307 | const BytecodeWriterConfig &config) { |
1308 | BytecodeWriter writer(op, config); |
1309 | return writer.write(rootOp: op, os); |
1310 | } |
1311 | |