1 | //===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header defines various interfaces and utilities necessary for dialects |
10 | // to hook into bytecode serialization. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
15 | #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
16 | |
17 | #include "mlir/IR/Attributes.h" |
18 | #include "mlir/IR/Diagnostics.h" |
19 | #include "mlir/IR/Dialect.h" |
20 | #include "mlir/IR/DialectInterface.h" |
21 | #include "mlir/IR/OpImplementation.h" |
22 | #include "mlir/Support/LogicalResult.h" |
23 | #include "llvm/ADT/STLExtras.h" |
24 | #include "llvm/ADT/Twine.h" |
25 | |
26 | namespace mlir { |
27 | //===--------------------------------------------------------------------===// |
28 | // Dialect Version Interface. |
29 | //===--------------------------------------------------------------------===// |
30 | |
/// Abstract base for dialect-specific version objects. It carries no data of
/// its own; its only role is to allow version objects to be owned and
/// destroyed through a base pointer (e.g. `std::unique_ptr<DialectVersion>`).
class DialectVersion {
public:
  virtual ~DialectVersion() noexcept = default;
};
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // DialectBytecodeReader |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | /// This class defines a virtual interface for reading a bytecode stream, |
43 | /// providing hooks into the bytecode reader. As such, this class should only be |
44 | /// derived and defined by the main bytecode reader, users (i.e. dialects) |
45 | /// should generally only interact with this class via the |
46 | /// BytecodeDialectInterface below. |
47 | class DialectBytecodeReader { |
48 | public: |
49 | virtual ~DialectBytecodeReader() = default; |
50 | |
51 | /// Emit an error to the reader. |
52 | virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0; |
53 | |
54 | /// Retrieve the dialect version by name if available. |
55 | virtual FailureOr<const DialectVersion *> |
56 | getDialectVersion(StringRef dialectName) const = 0; |
57 | template <class T> |
58 | FailureOr<const DialectVersion *> getDialectVersion() const { |
59 | return getDialectVersion(T::getDialectNamespace()); |
60 | } |
61 | |
62 | /// Retrieve the context associated to the reader. |
63 | virtual MLIRContext *getContext() const = 0; |
64 | |
65 | /// Return the bytecode version being read. |
66 | virtual uint64_t getBytecodeVersion() const = 0; |
67 | |
68 | /// Read out a list of elements, invoking the provided callback for each |
69 | /// element. The callback function may be in any of the following forms: |
70 | /// * LogicalResult(T &) |
71 | /// * FailureOr<T>() |
72 | template <typename T, typename CallbackFn> |
73 | LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) { |
74 | uint64_t size; |
75 | if (failed(result: readVarInt(result&: size))) |
76 | return failure(); |
77 | result.reserve(size); |
78 | |
79 | for (uint64_t i = 0; i < size; ++i) { |
80 | // Check if the callback uses FailureOr, or populates the result by |
81 | // reference. |
82 | if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) { |
83 | T element = {}; |
84 | if (failed(callback(element))) |
85 | return failure(); |
86 | result.emplace_back(std::move(element)); |
87 | } else { |
88 | FailureOr<T> element = callback(); |
89 | if (failed(element)) |
90 | return failure(); |
91 | result.emplace_back(std::move(*element)); |
92 | } |
93 | } |
94 | return success(); |
95 | } |
96 | |
97 | //===--------------------------------------------------------------------===// |
98 | // IR |
99 | //===--------------------------------------------------------------------===// |
100 | |
101 | /// Read a reference to the given attribute. |
102 | virtual LogicalResult readAttribute(Attribute &result) = 0; |
103 | /// Read an optional reference to the given attribute. Returns success even if |
104 | /// the Attribute isn't present. |
105 | virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0; |
106 | |
107 | template <typename T> |
108 | LogicalResult readAttributes(SmallVectorImpl<T> &attrs) { |
109 | return readList(attrs, [this](T &attr) { return readAttribute(attr); }); |
110 | } |
111 | template <typename T> |
112 | LogicalResult readAttribute(T &result) { |
113 | Attribute baseResult; |
114 | if (failed(result: readAttribute(result&: baseResult))) |
115 | return failure(); |
116 | if ((result = dyn_cast<T>(baseResult))) |
117 | return success(); |
118 | return emitError() << "expected " << llvm::getTypeName<T>() |
119 | << ", but got: " << baseResult; |
120 | } |
121 | template <typename T> |
122 | LogicalResult readOptionalAttribute(T &result) { |
123 | Attribute baseResult; |
124 | if (failed(result: readOptionalAttribute(attr&: baseResult))) |
125 | return failure(); |
126 | if (!baseResult) |
127 | return success(); |
128 | if ((result = dyn_cast<T>(baseResult))) |
129 | return success(); |
130 | return emitError() << "expected " << llvm::getTypeName<T>() |
131 | << ", but got: " << baseResult; |
132 | } |
133 | |
134 | /// Read a reference to the given type. |
135 | virtual LogicalResult readType(Type &result) = 0; |
136 | template <typename T> |
137 | LogicalResult readTypes(SmallVectorImpl<T> &types) { |
138 | return readList(types, [this](T &type) { return readType(type); }); |
139 | } |
140 | template <typename T> |
141 | LogicalResult readType(T &result) { |
142 | Type baseResult; |
143 | if (failed(result: readType(result&: baseResult))) |
144 | return failure(); |
145 | if ((result = dyn_cast<T>(baseResult))) |
146 | return success(); |
147 | return emitError() << "expected " << llvm::getTypeName<T>() |
148 | << ", but got: " << baseResult; |
149 | } |
150 | |
151 | /// Read a handle to a dialect resource. |
152 | template <typename ResourceT> |
153 | FailureOr<ResourceT> readResourceHandle() { |
154 | FailureOr<AsmDialectResourceHandle> handle = readResourceHandle(); |
155 | if (failed(result: handle)) |
156 | return failure(); |
157 | if (auto *result = dyn_cast<ResourceT>(&*handle)) |
158 | return std::move(*result); |
159 | return emitError() << "provided resource handle differs from the " |
160 | "expected resource type" ; |
161 | } |
162 | |
163 | //===--------------------------------------------------------------------===// |
164 | // Primitives |
165 | //===--------------------------------------------------------------------===// |
166 | |
167 | /// Read a variable width integer. |
168 | virtual LogicalResult readVarInt(uint64_t &result) = 0; |
169 | |
170 | /// Read a signed variable width integer. |
171 | virtual LogicalResult readSignedVarInt(int64_t &result) = 0; |
172 | LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) { |
173 | return readList(result, |
174 | callback: [this](int64_t &value) { return readSignedVarInt(result&: value); }); |
175 | } |
176 | |
177 | /// Parse a variable length encoded integer whose low bit is used to encode an |
178 | /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. |
179 | LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) { |
180 | if (failed(result: readVarInt(result))) |
181 | return failure(); |
182 | flag = result & 1; |
183 | result >>= 1; |
184 | return success(); |
185 | } |
186 | |
187 | /// Read a "small" sparse array of integer <= 32 bits elements, where |
188 | /// index/value pairs can be compressed when the array is small. |
189 | /// Note that only some position of the array will be read and the ones |
190 | /// not stored in the bytecode are gonne be left untouched. |
191 | /// If the provided array is too small for the stored indices, an error |
192 | /// will be returned. |
193 | template <typename T> |
194 | LogicalResult readSparseArray(MutableArrayRef<T> array) { |
195 | static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits" ); |
196 | static_assert(std::is_integral<T>::value, "expects integer" ); |
197 | uint64_t nonZeroesCount; |
198 | bool useSparseEncoding; |
199 | if (failed(result: readVarIntWithFlag(result&: nonZeroesCount, flag&: useSparseEncoding))) |
200 | return failure(); |
201 | if (nonZeroesCount == 0) |
202 | return success(); |
203 | if (!useSparseEncoding) { |
204 | // This is a simple dense array. |
205 | if (nonZeroesCount > array.size()) { |
206 | emitError(msg: "trying to read an array of " ) |
207 | << nonZeroesCount << " but only " << array.size() |
208 | << " storage available." ; |
209 | return failure(); |
210 | } |
211 | for (int64_t index : llvm::seq<int64_t>(Begin: 0, End: nonZeroesCount)) { |
212 | uint64_t value; |
213 | if (failed(result: readVarInt(result&: value))) |
214 | return failure(); |
215 | array[index] = value; |
216 | } |
217 | return success(); |
218 | } |
219 | // Read sparse encoding |
220 | // This is the number of bits used for packing the index with the value. |
221 | uint64_t indexBitSize; |
222 | if (failed(result: readVarInt(result&: indexBitSize))) |
223 | return failure(); |
224 | constexpr uint64_t maxIndexBitSize = 8; |
225 | if (indexBitSize > maxIndexBitSize) { |
226 | emitError(msg: "reading sparse array with indexing above 8 bits: " ) |
227 | << indexBitSize; |
228 | return failure(); |
229 | } |
230 | for (uint32_t count : llvm::seq<uint32_t>(Begin: 0, End: nonZeroesCount)) { |
231 | (void)count; |
232 | uint64_t indexValuePair; |
233 | if (failed(result: readVarInt(result&: indexValuePair))) |
234 | return failure(); |
235 | uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize)); |
236 | uint64_t value = indexValuePair >> indexBitSize; |
237 | if (index >= array.size()) { |
238 | emitError(msg: "reading a sparse array found index " ) |
239 | << index << " but only " << array.size() << " storage available." ; |
240 | return failure(); |
241 | } |
242 | array[index] = value; |
243 | } |
244 | return success(); |
245 | } |
246 | |
247 | /// Read an APInt that is known to have been encoded with the given width. |
248 | virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0; |
249 | |
250 | /// Read an APFloat that is known to have been encoded with the given |
251 | /// semantics. |
252 | virtual FailureOr<APFloat> |
253 | readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0; |
254 | |
255 | /// Read a string from the bytecode. |
256 | virtual LogicalResult readString(StringRef &result) = 0; |
257 | |
258 | /// Read a blob from the bytecode. |
259 | virtual LogicalResult readBlob(ArrayRef<char> &result) = 0; |
260 | |
261 | /// Read a bool from the bytecode. |
262 | virtual LogicalResult readBool(bool &result) = 0; |
263 | |
264 | private: |
265 | /// Read a handle to a dialect resource. |
266 | virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0; |
267 | }; |
268 | |
269 | //===----------------------------------------------------------------------===// |
270 | // DialectBytecodeWriter |
271 | //===----------------------------------------------------------------------===// |
272 | |
273 | /// This class defines a virtual interface for writing to a bytecode stream, |
274 | /// providing hooks into the bytecode writer. As such, this class should only be |
275 | /// derived and defined by the main bytecode writer, users (i.e. dialects) |
276 | /// should generally only interact with this class via the |
277 | /// BytecodeDialectInterface below. |
278 | class DialectBytecodeWriter { |
279 | public: |
280 | virtual ~DialectBytecodeWriter() = default; |
281 | |
282 | //===--------------------------------------------------------------------===// |
283 | // IR |
284 | //===--------------------------------------------------------------------===// |
285 | |
286 | /// Write out a list of elements, invoking the provided callback for each |
287 | /// element. |
288 | template <typename RangeT, typename CallbackFn> |
289 | void writeList(RangeT &&range, CallbackFn &&callback) { |
290 | writeVarInt(value: llvm::size(range)); |
291 | for (auto &element : range) |
292 | callback(element); |
293 | } |
294 | |
295 | /// Write a reference to the given attribute. |
296 | virtual void writeAttribute(Attribute attr) = 0; |
297 | virtual void writeOptionalAttribute(Attribute attr) = 0; |
298 | template <typename T> |
299 | void writeAttributes(ArrayRef<T> attrs) { |
300 | writeList(attrs, [this](T attr) { writeAttribute(attr); }); |
301 | } |
302 | |
303 | /// Write a reference to the given type. |
304 | virtual void writeType(Type type) = 0; |
305 | template <typename T> |
306 | void writeTypes(ArrayRef<T> types) { |
307 | writeList(types, [this](T type) { writeType(type); }); |
308 | } |
309 | |
310 | /// Write the given handle to a dialect resource. |
311 | virtual void |
312 | writeResourceHandle(const AsmDialectResourceHandle &resource) = 0; |
313 | |
314 | //===--------------------------------------------------------------------===// |
315 | // Primitives |
316 | //===--------------------------------------------------------------------===// |
317 | |
318 | /// Write a variable width integer to the output stream. This should be the |
319 | /// preferred method for emitting integers whenever possible. |
320 | virtual void writeVarInt(uint64_t value) = 0; |
321 | |
322 | /// Write a signed variable width integer to the output stream. This should be |
323 | /// the preferred method for emitting signed integers whenever possible. |
324 | virtual void writeSignedVarInt(int64_t value) = 0; |
325 | void writeSignedVarInts(ArrayRef<int64_t> value) { |
326 | writeList(range&: value, callback: [this](int64_t value) { writeSignedVarInt(value); }); |
327 | } |
328 | |
329 | /// Write a VarInt and a flag packed together. |
330 | void writeVarIntWithFlag(uint64_t value, bool flag) { |
331 | writeVarInt(value: (value << 1) | (flag ? 1 : 0)); |
332 | } |
333 | |
334 | /// Write out a "small" sparse array of integer <= 32 bits elements, where |
335 | /// index/value pairs can be compressed when the array is small. This method |
336 | /// will scan the array multiple times and should not be used for large |
337 | /// arrays. The optional provided "zero" can be used to adjust for the |
338 | /// expected repeated value. We assume here that the array size fits in a 32 |
339 | /// bits integer. |
340 | template <typename T> |
341 | void writeSparseArray(ArrayRef<T> array) { |
342 | static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits" ); |
343 | static_assert(std::is_integral<T>::value, "expects integer" ); |
344 | uint32_t size = array.size(); |
345 | uint32_t nonZeroesCount = 0, lastIndex = 0; |
346 | for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: size)) { |
347 | if (!array[index]) |
348 | continue; |
349 | nonZeroesCount++; |
350 | lastIndex = index; |
351 | } |
352 | // If the last position is too large, or the array isn't at least 50% |
353 | // sparse, emit it with a dense encoding. |
354 | if (lastIndex > 256 || nonZeroesCount > size / 2) { |
355 | // Emit the array size and a flag which indicates whether it is sparse. |
356 | writeVarIntWithFlag(value: size, flag: false); |
357 | for (const T &elt : array) |
358 | writeVarInt(value: elt); |
359 | return; |
360 | } |
361 | // Emit sparse: first the number of elements we'll write and a flag |
362 | // indicating it is a sparse encoding. |
363 | writeVarIntWithFlag(value: nonZeroesCount, flag: true); |
364 | if (nonZeroesCount == 0) |
365 | return; |
366 | // This is the number of bits used for packing the index with the value. |
367 | int indexBitSize = llvm::Log2_32_Ceil(Value: lastIndex + 1); |
368 | writeVarInt(value: indexBitSize); |
369 | for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: lastIndex + 1)) { |
370 | T value = array[index]; |
371 | if (!value) |
372 | continue; |
373 | uint64_t indexValuePair = (value << indexBitSize) | (index); |
374 | writeVarInt(value: indexValuePair); |
375 | } |
376 | } |
377 | |
378 | /// Write an APInt to the bytecode stream whose bitwidth will be known |
379 | /// externally at read time. This method is useful for encoding APInt values |
380 | /// when the width is known via external means, such as via a type. This |
381 | /// method should generally only be invoked if you need an APInt, otherwise |
382 | /// use the varint methods above. APInt values are generally encoded using |
383 | /// zigzag encoding, to enable more efficient encodings for negative values. |
384 | virtual void writeAPIntWithKnownWidth(const APInt &value) = 0; |
385 | |
386 | /// Write an APFloat to the bytecode stream whose semantics will be known |
387 | /// externally at read time. This method is useful for encoding APFloat values |
388 | /// when the semantics are known via external means, such as via a type. |
389 | virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0; |
390 | |
391 | /// Write a string to the bytecode, which is owned by the caller and is |
392 | /// guaranteed to not die before the end of the bytecode process. This should |
393 | /// only be called if such a guarantee can be made, such as when the string is |
394 | /// owned by an attribute or type. |
395 | virtual void writeOwnedString(StringRef str) = 0; |
396 | |
397 | /// Write a blob to the bytecode, which is owned by the caller and is |
398 | /// guaranteed to not die before the end of the bytecode process. The blob is |
399 | /// written as-is, with no additional compression or compaction. |
400 | virtual void writeOwnedBlob(ArrayRef<char> blob) = 0; |
401 | |
402 | /// Write a bool to the output stream. |
403 | virtual void writeOwnedBool(bool value) = 0; |
404 | |
405 | /// Return the bytecode version being emitted for. |
406 | virtual int64_t getBytecodeVersion() const = 0; |
407 | |
408 | /// Retrieve the dialect version by name if available. |
409 | virtual FailureOr<const DialectVersion *> |
410 | getDialectVersion(StringRef dialectName) const = 0; |
411 | |
412 | template <class T> |
413 | FailureOr<const DialectVersion *> getDialectVersion() const { |
414 | return getDialectVersion(T::getDialectNamespace()); |
415 | } |
416 | }; |
417 | |
418 | //===----------------------------------------------------------------------===// |
419 | // BytecodeDialectInterface |
420 | //===----------------------------------------------------------------------===// |
421 | |
422 | class BytecodeDialectInterface |
423 | : public DialectInterface::Base<BytecodeDialectInterface> { |
424 | public: |
425 | using Base::Base; |
426 | |
427 | //===--------------------------------------------------------------------===// |
428 | // Reading |
429 | //===--------------------------------------------------------------------===// |
430 | |
431 | /// Read an attribute belonging to this dialect from the given reader. This |
432 | /// method should return null in the case of failure. Optionally, the dialect |
433 | /// version can be accessed through the reader. |
434 | virtual Attribute readAttribute(DialectBytecodeReader &reader) const { |
435 | reader.emitError() << "dialect " << getDialect()->getNamespace() |
436 | << " does not support reading attributes from bytecode" ; |
437 | return Attribute(); |
438 | } |
439 | |
440 | /// Read a type belonging to this dialect from the given reader. This method |
441 | /// should return null in the case of failure. Optionally, the dialect version |
442 | /// can be accessed thorugh the reader. |
443 | virtual Type readType(DialectBytecodeReader &reader) const { |
444 | reader.emitError() << "dialect " << getDialect()->getNamespace() |
445 | << " does not support reading types from bytecode" ; |
446 | return Type(); |
447 | } |
448 | |
449 | //===--------------------------------------------------------------------===// |
450 | // Writing |
451 | //===--------------------------------------------------------------------===// |
452 | |
453 | /// Write the given attribute, which belongs to this dialect, to the given |
454 | /// writer. This method may return failure to indicate that the given |
455 | /// attribute could not be encoded, in which case the textual format will be |
456 | /// used to encode this attribute instead. |
457 | virtual LogicalResult writeAttribute(Attribute attr, |
458 | DialectBytecodeWriter &writer) const { |
459 | return failure(); |
460 | } |
461 | |
462 | /// Write the given type, which belongs to this dialect, to the given writer. |
463 | /// This method may return failure to indicate that the given type could not |
464 | /// be encoded, in which case the textual format will be used to encode this |
465 | /// type instead. |
466 | virtual LogicalResult writeType(Type type, |
467 | DialectBytecodeWriter &writer) const { |
468 | return failure(); |
469 | } |
470 | |
471 | /// Write the version of this dialect to the given writer. |
472 | virtual void writeVersion(DialectBytecodeWriter &writer) const {} |
473 | |
474 | // Read the version of this dialect from the provided reader and return it as |
475 | // a `unique_ptr` to a dialect version object. |
476 | virtual std::unique_ptr<DialectVersion> |
477 | readVersion(DialectBytecodeReader &reader) const { |
478 | reader.emitError(msg: "Dialect does not support versioning" ); |
479 | return nullptr; |
480 | } |
481 | |
482 | /// Hook invoked after parsing completed, if a version directive was present |
483 | /// and included an entry for the current dialect. This hook offers the |
484 | /// opportunity to the dialect to visit the IR and upgrades constructs emitted |
485 | /// by the version of the dialect corresponding to the provided version. |
486 | virtual LogicalResult |
487 | upgradeFromVersion(Operation *topLevelOp, |
488 | const DialectVersion &version) const { |
489 | return success(); |
490 | } |
491 | }; |
492 | |
493 | /// Helper for resource handle reading that returns LogicalResult. |
494 | template <typename T, typename... Ts> |
495 | static LogicalResult readResourceHandle(DialectBytecodeReader &reader, |
496 | FailureOr<T> &value, Ts &&...params) { |
497 | FailureOr<T> handle = reader.readResourceHandle<T>(); |
498 | if (failed(handle)) |
499 | return failure(); |
500 | if (auto *result = dyn_cast<T>(&*handle)) { |
501 | value = std::move(*result); |
502 | return success(); |
503 | } |
504 | return failure(); |
505 | } |
506 | |
507 | /// Helper method that injects context only if needed, this helps unify some of |
508 | /// the attribute construction methods. |
509 | template <typename T, typename... Ts> |
510 | auto get(MLIRContext *context, Ts &&...params) { |
511 | // Prefer a direct `get` method if one exists. |
512 | if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) { |
513 | (void)context; |
514 | return T::get(std::forward<Ts>(params)...); |
515 | } else if constexpr (llvm::is_detected<detail::has_get_method, T, |
516 | MLIRContext *, Ts...>::value) { |
517 | return T::get(context, std::forward<Ts>(params)...); |
518 | } else { |
519 | // Otherwise, pass to the base get. |
520 | return T::Base::get(context, std::forward<Ts>(params)...); |
521 | } |
522 | } |
523 | |
524 | } // namespace mlir |
525 | |
526 | #endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
527 | |