1 | //===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This header defines various interfaces and utilities necessary for dialects |
10 | // to hook into bytecode serialization. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
15 | #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
16 | |
17 | #include "mlir/IR/Attributes.h" |
18 | #include "mlir/IR/Diagnostics.h" |
19 | #include "mlir/IR/Dialect.h" |
20 | #include "mlir/IR/DialectInterface.h" |
21 | #include "mlir/IR/OpImplementation.h" |
22 | #include "llvm/ADT/STLExtras.h" |
23 | #include "llvm/ADT/Twine.h" |
24 | |
25 | namespace mlir { |
//===--------------------------------------------------------------------===//
// Dialect Version Interface.
//===--------------------------------------------------------------------===//

/// This class is used to represent the version of a dialect, for the purpose
/// of polymorphic destruction.
///
/// It carries no state of its own: it is an empty polymorphic base that is
/// handed around as `std::unique_ptr<DialectVersion>` (returned by
/// `BytecodeDialectInterface::readVersion`) and `const DialectVersion &`
/// (consumed by `BytecodeDialectInterface::upgradeFromVersion`). The virtual
/// destructor makes destroying a derived version object through such a base
/// pointer well-defined.
class DialectVersion {
public:
  virtual ~DialectVersion() = default;
};
36 | |
37 | //===----------------------------------------------------------------------===// |
38 | // DialectBytecodeReader |
39 | //===----------------------------------------------------------------------===// |
40 | |
41 | /// This class defines a virtual interface for reading a bytecode stream, |
42 | /// providing hooks into the bytecode reader. As such, this class should only be |
43 | /// derived and defined by the main bytecode reader, users (i.e. dialects) |
44 | /// should generally only interact with this class via the |
45 | /// BytecodeDialectInterface below. |
46 | class DialectBytecodeReader { |
47 | public: |
48 | virtual ~DialectBytecodeReader() = default; |
49 | |
50 | /// Emit an error to the reader. |
51 | virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0; |
52 | |
53 | /// Retrieve the dialect version by name if available. |
54 | virtual FailureOr<const DialectVersion *> |
55 | getDialectVersion(StringRef dialectName) const = 0; |
56 | template <class T> |
57 | FailureOr<const DialectVersion *> getDialectVersion() const { |
58 | return getDialectVersion(T::getDialectNamespace()); |
59 | } |
60 | |
61 | /// Retrieve the context associated to the reader. |
62 | virtual MLIRContext *getContext() const = 0; |
63 | |
64 | /// Return the bytecode version being read. |
65 | virtual uint64_t getBytecodeVersion() const = 0; |
66 | |
67 | /// Read out a list of elements, invoking the provided callback for each |
68 | /// element. The callback function may be in any of the following forms: |
69 | /// * LogicalResult(T &) |
70 | /// * FailureOr<T>() |
71 | template <typename T, typename CallbackFn> |
72 | LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) { |
73 | uint64_t size; |
74 | if (failed(Result: readVarInt(result&: size))) |
75 | return failure(); |
76 | result.reserve(size); |
77 | |
78 | for (uint64_t i = 0; i < size; ++i) { |
79 | // Check if the callback uses FailureOr, or populates the result by |
80 | // reference. |
81 | if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) { |
82 | T element = {}; |
83 | if (failed(callback(element))) |
84 | return failure(); |
85 | result.emplace_back(std::move(element)); |
86 | } else { |
87 | FailureOr<T> element = callback(); |
88 | if (failed(element)) |
89 | return failure(); |
90 | result.emplace_back(std::move(*element)); |
91 | } |
92 | } |
93 | return success(); |
94 | } |
95 | |
96 | //===--------------------------------------------------------------------===// |
97 | // IR |
98 | //===--------------------------------------------------------------------===// |
99 | |
100 | /// Read a reference to the given attribute. |
101 | virtual LogicalResult readAttribute(Attribute &result) = 0; |
102 | /// Read an optional reference to the given attribute. Returns success even if |
103 | /// the Attribute isn't present. |
104 | virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0; |
105 | |
106 | template <typename T> |
107 | LogicalResult readAttributes(SmallVectorImpl<T> &attrs) { |
108 | return readList(attrs, [this](T &attr) { return readAttribute(attr); }); |
109 | } |
110 | template <typename T> |
111 | LogicalResult readAttribute(T &result) { |
112 | Attribute baseResult; |
113 | if (failed(Result: readAttribute(result&: baseResult))) |
114 | return failure(); |
115 | if ((result = dyn_cast<T>(baseResult))) |
116 | return success(); |
117 | return emitError() << "expected "<< llvm::getTypeName<T>() |
118 | << ", but got: "<< baseResult; |
119 | } |
120 | template <typename T> |
121 | LogicalResult readOptionalAttribute(T &result) { |
122 | Attribute baseResult; |
123 | if (failed(Result: readOptionalAttribute(attr&: baseResult))) |
124 | return failure(); |
125 | if (!baseResult) |
126 | return success(); |
127 | if ((result = dyn_cast<T>(baseResult))) |
128 | return success(); |
129 | return emitError() << "expected "<< llvm::getTypeName<T>() |
130 | << ", but got: "<< baseResult; |
131 | } |
132 | |
133 | /// Read a reference to the given type. |
134 | virtual LogicalResult readType(Type &result) = 0; |
135 | template <typename T> |
136 | LogicalResult readTypes(SmallVectorImpl<T> &types) { |
137 | return readList(types, [this](T &type) { return readType(type); }); |
138 | } |
139 | template <typename T> |
140 | LogicalResult readType(T &result) { |
141 | Type baseResult; |
142 | if (failed(Result: readType(result&: baseResult))) |
143 | return failure(); |
144 | if ((result = dyn_cast<T>(baseResult))) |
145 | return success(); |
146 | return emitError() << "expected "<< llvm::getTypeName<T>() |
147 | << ", but got: "<< baseResult; |
148 | } |
149 | |
150 | /// Read a handle to a dialect resource. |
151 | template <typename ResourceT> |
152 | FailureOr<ResourceT> readResourceHandle() { |
153 | FailureOr<AsmDialectResourceHandle> handle = readResourceHandle(); |
154 | if (failed(Result: handle)) |
155 | return failure(); |
156 | if (auto *result = dyn_cast<ResourceT>(&*handle)) |
157 | return std::move(*result); |
158 | return emitError() << "provided resource handle differs from the " |
159 | "expected resource type"; |
160 | } |
161 | |
162 | //===--------------------------------------------------------------------===// |
163 | // Primitives |
164 | //===--------------------------------------------------------------------===// |
165 | |
166 | /// Read a variable width integer. |
167 | virtual LogicalResult readVarInt(uint64_t &result) = 0; |
168 | |
169 | /// Read a signed variable width integer. |
170 | virtual LogicalResult readSignedVarInt(int64_t &result) = 0; |
171 | LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) { |
172 | return readList(result, |
173 | callback: [this](int64_t &value) { return readSignedVarInt(result&: value); }); |
174 | } |
175 | |
176 | /// Parse a variable length encoded integer whose low bit is used to encode an |
177 | /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. |
178 | LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) { |
179 | if (failed(Result: readVarInt(result))) |
180 | return failure(); |
181 | flag = result & 1; |
182 | result >>= 1; |
183 | return success(); |
184 | } |
185 | |
186 | /// Read a "small" sparse array of integer <= 32 bits elements, where |
187 | /// index/value pairs can be compressed when the array is small. |
188 | /// Note that only some position of the array will be read and the ones |
189 | /// not stored in the bytecode are gonne be left untouched. |
190 | /// If the provided array is too small for the stored indices, an error |
191 | /// will be returned. |
192 | template <typename T> |
193 | LogicalResult readSparseArray(MutableArrayRef<T> array) { |
194 | static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits"); |
195 | static_assert(std::is_integral<T>::value, "expects integer"); |
196 | uint64_t nonZeroesCount; |
197 | bool useSparseEncoding; |
198 | if (failed(Result: readVarIntWithFlag(result&: nonZeroesCount, flag&: useSparseEncoding))) |
199 | return failure(); |
200 | if (nonZeroesCount == 0) |
201 | return success(); |
202 | if (!useSparseEncoding) { |
203 | // This is a simple dense array. |
204 | if (nonZeroesCount > array.size()) { |
205 | emitError(msg: "trying to read an array of ") |
206 | << nonZeroesCount << " but only "<< array.size() |
207 | << " storage available."; |
208 | return failure(); |
209 | } |
210 | for (int64_t index : llvm::seq<int64_t>(Begin: 0, End: nonZeroesCount)) { |
211 | uint64_t value; |
212 | if (failed(Result: readVarInt(result&: value))) |
213 | return failure(); |
214 | array[index] = value; |
215 | } |
216 | return success(); |
217 | } |
218 | // Read sparse encoding |
219 | // This is the number of bits used for packing the index with the value. |
220 | uint64_t indexBitSize; |
221 | if (failed(Result: readVarInt(result&: indexBitSize))) |
222 | return failure(); |
223 | constexpr uint64_t maxIndexBitSize = 8; |
224 | if (indexBitSize > maxIndexBitSize) { |
225 | emitError(msg: "reading sparse array with indexing above 8 bits: ") |
226 | << indexBitSize; |
227 | return failure(); |
228 | } |
229 | for (uint32_t count : llvm::seq<uint32_t>(Begin: 0, End: nonZeroesCount)) { |
230 | (void)count; |
231 | uint64_t indexValuePair; |
232 | if (failed(Result: readVarInt(result&: indexValuePair))) |
233 | return failure(); |
234 | uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize)); |
235 | uint64_t value = indexValuePair >> indexBitSize; |
236 | if (index >= array.size()) { |
237 | emitError(msg: "reading a sparse array found index ") |
238 | << index << " but only "<< array.size() << " storage available."; |
239 | return failure(); |
240 | } |
241 | array[index] = value; |
242 | } |
243 | return success(); |
244 | } |
245 | |
246 | /// Read an APInt that is known to have been encoded with the given width. |
247 | virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0; |
248 | |
249 | /// Read an APFloat that is known to have been encoded with the given |
250 | /// semantics. |
251 | virtual FailureOr<APFloat> |
252 | readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0; |
253 | |
254 | /// Read a string from the bytecode. |
255 | virtual LogicalResult readString(StringRef &result) = 0; |
256 | |
257 | /// Read a blob from the bytecode. |
258 | virtual LogicalResult readBlob(ArrayRef<char> &result) = 0; |
259 | |
260 | /// Read a bool from the bytecode. |
261 | virtual LogicalResult readBool(bool &result) = 0; |
262 | |
263 | private: |
264 | /// Read a handle to a dialect resource. |
265 | virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0; |
266 | }; |
267 | |
268 | //===----------------------------------------------------------------------===// |
269 | // DialectBytecodeWriter |
270 | //===----------------------------------------------------------------------===// |
271 | |
272 | /// This class defines a virtual interface for writing to a bytecode stream, |
273 | /// providing hooks into the bytecode writer. As such, this class should only be |
274 | /// derived and defined by the main bytecode writer, users (i.e. dialects) |
275 | /// should generally only interact with this class via the |
276 | /// BytecodeDialectInterface below. |
277 | class DialectBytecodeWriter { |
278 | public: |
279 | virtual ~DialectBytecodeWriter() = default; |
280 | |
281 | //===--------------------------------------------------------------------===// |
282 | // IR |
283 | //===--------------------------------------------------------------------===// |
284 | |
285 | /// Write out a list of elements, invoking the provided callback for each |
286 | /// element. |
287 | template <typename RangeT, typename CallbackFn> |
288 | void writeList(RangeT &&range, CallbackFn &&callback) { |
289 | writeVarInt(value: llvm::size(range)); |
290 | for (auto &element : range) |
291 | callback(element); |
292 | } |
293 | |
294 | /// Write a reference to the given attribute. |
295 | virtual void writeAttribute(Attribute attr) = 0; |
296 | virtual void writeOptionalAttribute(Attribute attr) = 0; |
297 | template <typename T> |
298 | void writeAttributes(ArrayRef<T> attrs) { |
299 | writeList(attrs, [this](T attr) { writeAttribute(attr); }); |
300 | } |
301 | |
302 | /// Write a reference to the given type. |
303 | virtual void writeType(Type type) = 0; |
304 | template <typename T> |
305 | void writeTypes(ArrayRef<T> types) { |
306 | writeList(types, [this](T type) { writeType(type); }); |
307 | } |
308 | |
309 | /// Write the given handle to a dialect resource. |
310 | virtual void |
311 | writeResourceHandle(const AsmDialectResourceHandle &resource) = 0; |
312 | |
313 | //===--------------------------------------------------------------------===// |
314 | // Primitives |
315 | //===--------------------------------------------------------------------===// |
316 | |
317 | /// Write a variable width integer to the output stream. This should be the |
318 | /// preferred method for emitting integers whenever possible. |
319 | virtual void writeVarInt(uint64_t value) = 0; |
320 | |
321 | /// Write a signed variable width integer to the output stream. This should be |
322 | /// the preferred method for emitting signed integers whenever possible. |
323 | virtual void writeSignedVarInt(int64_t value) = 0; |
324 | void writeSignedVarInts(ArrayRef<int64_t> value) { |
325 | writeList(range&: value, callback: [this](int64_t value) { writeSignedVarInt(value); }); |
326 | } |
327 | |
328 | /// Write a VarInt and a flag packed together. |
329 | void writeVarIntWithFlag(uint64_t value, bool flag) { |
330 | writeVarInt(value: (value << 1) | (flag ? 1 : 0)); |
331 | } |
332 | |
333 | /// Write out a "small" sparse array of integer <= 32 bits elements, where |
334 | /// index/value pairs can be compressed when the array is small. This method |
335 | /// will scan the array multiple times and should not be used for large |
336 | /// arrays. The optional provided "zero" can be used to adjust for the |
337 | /// expected repeated value. We assume here that the array size fits in a 32 |
338 | /// bits integer. |
339 | template <typename T> |
340 | void writeSparseArray(ArrayRef<T> array) { |
341 | static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits"); |
342 | static_assert(std::is_integral<T>::value, "expects integer"); |
343 | uint32_t size = array.size(); |
344 | uint32_t nonZeroesCount = 0, lastIndex = 0; |
345 | for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: size)) { |
346 | if (!array[index]) |
347 | continue; |
348 | nonZeroesCount++; |
349 | lastIndex = index; |
350 | } |
351 | // If the last position is too large, or the array isn't at least 50% |
352 | // sparse, emit it with a dense encoding. |
353 | if (lastIndex > 256 || nonZeroesCount > size / 2) { |
354 | // Emit the array size and a flag which indicates whether it is sparse. |
355 | writeVarIntWithFlag(value: size, flag: false); |
356 | for (const T &elt : array) |
357 | writeVarInt(value: elt); |
358 | return; |
359 | } |
360 | // Emit sparse: first the number of elements we'll write and a flag |
361 | // indicating it is a sparse encoding. |
362 | writeVarIntWithFlag(value: nonZeroesCount, flag: true); |
363 | if (nonZeroesCount == 0) |
364 | return; |
365 | // This is the number of bits used for packing the index with the value. |
366 | int indexBitSize = llvm::Log2_32_Ceil(Value: lastIndex + 1); |
367 | writeVarInt(value: indexBitSize); |
368 | for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: lastIndex + 1)) { |
369 | T value = array[index]; |
370 | if (!value) |
371 | continue; |
372 | uint64_t indexValuePair = (value << indexBitSize) | (index); |
373 | writeVarInt(value: indexValuePair); |
374 | } |
375 | } |
376 | |
377 | /// Write an APInt to the bytecode stream whose bitwidth will be known |
378 | /// externally at read time. This method is useful for encoding APInt values |
379 | /// when the width is known via external means, such as via a type. This |
380 | /// method should generally only be invoked if you need an APInt, otherwise |
381 | /// use the varint methods above. APInt values are generally encoded using |
382 | /// zigzag encoding, to enable more efficient encodings for negative values. |
383 | virtual void writeAPIntWithKnownWidth(const APInt &value) = 0; |
384 | |
385 | /// Write an APFloat to the bytecode stream whose semantics will be known |
386 | /// externally at read time. This method is useful for encoding APFloat values |
387 | /// when the semantics are known via external means, such as via a type. |
388 | virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0; |
389 | |
390 | /// Write a string to the bytecode, which is owned by the caller and is |
391 | /// guaranteed to not die before the end of the bytecode process. This should |
392 | /// only be called if such a guarantee can be made, such as when the string is |
393 | /// owned by an attribute or type. |
394 | virtual void writeOwnedString(StringRef str) = 0; |
395 | |
396 | /// Write a blob to the bytecode, which is owned by the caller and is |
397 | /// guaranteed to not die before the end of the bytecode process. The blob is |
398 | /// written as-is, with no additional compression or compaction. |
399 | virtual void writeOwnedBlob(ArrayRef<char> blob) = 0; |
400 | |
401 | /// Write a bool to the output stream. |
402 | virtual void writeOwnedBool(bool value) = 0; |
403 | |
404 | /// Return the bytecode version being emitted for. |
405 | virtual int64_t getBytecodeVersion() const = 0; |
406 | |
407 | /// Retrieve the dialect version by name if available. |
408 | virtual FailureOr<const DialectVersion *> |
409 | getDialectVersion(StringRef dialectName) const = 0; |
410 | |
411 | template <class T> |
412 | FailureOr<const DialectVersion *> getDialectVersion() const { |
413 | return getDialectVersion(T::getDialectNamespace()); |
414 | } |
415 | }; |
416 | |
417 | //===----------------------------------------------------------------------===// |
418 | // BytecodeDialectInterface |
419 | //===----------------------------------------------------------------------===// |
420 | |
421 | class BytecodeDialectInterface |
422 | : public DialectInterface::Base<BytecodeDialectInterface> { |
423 | public: |
424 | using Base::Base; |
425 | |
426 | //===--------------------------------------------------------------------===// |
427 | // Reading |
428 | //===--------------------------------------------------------------------===// |
429 | |
430 | /// Read an attribute belonging to this dialect from the given reader. This |
431 | /// method should return null in the case of failure. Optionally, the dialect |
432 | /// version can be accessed through the reader. |
433 | virtual Attribute readAttribute(DialectBytecodeReader &reader) const { |
434 | reader.emitError() << "dialect "<< getDialect()->getNamespace() |
435 | << " does not support reading attributes from bytecode"; |
436 | return Attribute(); |
437 | } |
438 | |
439 | /// Read a type belonging to this dialect from the given reader. This method |
440 | /// should return null in the case of failure. Optionally, the dialect version |
441 | /// can be accessed thorugh the reader. |
442 | virtual Type readType(DialectBytecodeReader &reader) const { |
443 | reader.emitError() << "dialect "<< getDialect()->getNamespace() |
444 | << " does not support reading types from bytecode"; |
445 | return Type(); |
446 | } |
447 | |
448 | //===--------------------------------------------------------------------===// |
449 | // Writing |
450 | //===--------------------------------------------------------------------===// |
451 | |
452 | /// Write the given attribute, which belongs to this dialect, to the given |
453 | /// writer. This method may return failure to indicate that the given |
454 | /// attribute could not be encoded, in which case the textual format will be |
455 | /// used to encode this attribute instead. |
456 | virtual LogicalResult writeAttribute(Attribute attr, |
457 | DialectBytecodeWriter &writer) const { |
458 | return failure(); |
459 | } |
460 | |
461 | /// Write the given type, which belongs to this dialect, to the given writer. |
462 | /// This method may return failure to indicate that the given type could not |
463 | /// be encoded, in which case the textual format will be used to encode this |
464 | /// type instead. |
465 | virtual LogicalResult writeType(Type type, |
466 | DialectBytecodeWriter &writer) const { |
467 | return failure(); |
468 | } |
469 | |
470 | /// Write the version of this dialect to the given writer. |
471 | virtual void writeVersion(DialectBytecodeWriter &writer) const {} |
472 | |
473 | // Read the version of this dialect from the provided reader and return it as |
474 | // a `unique_ptr` to a dialect version object. |
475 | virtual std::unique_ptr<DialectVersion> |
476 | readVersion(DialectBytecodeReader &reader) const { |
477 | reader.emitError(msg: "Dialect does not support versioning"); |
478 | return nullptr; |
479 | } |
480 | |
481 | /// Hook invoked after parsing completed, if a version directive was present |
482 | /// and included an entry for the current dialect. This hook offers the |
483 | /// opportunity to the dialect to visit the IR and upgrades constructs emitted |
484 | /// by the version of the dialect corresponding to the provided version. |
485 | virtual LogicalResult |
486 | upgradeFromVersion(Operation *topLevelOp, |
487 | const DialectVersion &version) const { |
488 | return success(); |
489 | } |
490 | }; |
491 | |
492 | /// Helper for resource handle reading that returns LogicalResult. |
493 | template <typename T, typename... Ts> |
494 | static LogicalResult readResourceHandle(DialectBytecodeReader &reader, |
495 | FailureOr<T> &value, Ts &&...params) { |
496 | FailureOr<T> handle = reader.readResourceHandle<T>(); |
497 | if (failed(handle)) |
498 | return failure(); |
499 | if (auto *result = dyn_cast<T>(&*handle)) { |
500 | value = std::move(*result); |
501 | return success(); |
502 | } |
503 | return failure(); |
504 | } |
505 | |
/// Helper method that injects context only if needed, this helps unify some of
/// the attribute construction methods.
///
/// The overload is chosen at compile time, first match wins:
///   1. `T::get(params...)`, if `detail::has_get_method` is detected for
///      `T, Ts...`;
///   2. otherwise `T::get(context, params...)`, if detected for
///      `T, MLIRContext *, Ts...`;
///   3. otherwise `T::Base::get(context, params...)`.
/// `detail::has_get_method` is defined outside this view; presumably it is a
/// detector for whether `T::get` is callable with the given argument types.
/// Only the selected branch is instantiated, so `T` need not provide the
/// members used by the other branches.
template <typename T, typename... Ts>
auto get(MLIRContext *context, Ts &&...params) {
  // Prefer a direct `get` method if one exists.
  if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
    // The context is not forwarded in this branch; silence unused warnings.
    (void)context;
    return T::get(std::forward<Ts>(params)...);
  } else if constexpr (llvm::is_detected<detail::has_get_method, T,
                                         MLIRContext *, Ts...>::value) {
    // `T::get` exists but expects the context as its leading argument.
    return T::get(context, std::forward<Ts>(params)...);
  } else {
    // Otherwise, pass to the base get.
    return T::Base::get(context, std::forward<Ts>(params)...);
  }
}
522 | |
523 | } // namespace mlir |
524 | |
525 | #endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
526 |
Definitions
- DialectVersion
- ~DialectVersion
- DialectBytecodeReader
- ~DialectBytecodeReader
- getDialectVersion
- readList
- readAttributes
- readAttribute
- readOptionalAttribute
- readTypes
- readType
- readResourceHandle
- readSignedVarInts
- readVarIntWithFlag
- readSparseArray
- DialectBytecodeWriter
- ~DialectBytecodeWriter
- writeList
- writeAttributes
- writeTypes
- writeSignedVarInts
- writeVarIntWithFlag
- writeSparseArray
- getDialectVersion
- BytecodeDialectInterface
- readAttribute
- readType
- writeAttribute
- writeType
- writeVersion
- readVersion
- upgradeFromVersion
- readResourceHandle
Learn to use CMake with our Intro Training
Find out more