1//===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header defines various interfaces and utilities necessary for dialects
10// to hook into bytecode serialization.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
15#define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
16
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/Dialect.h"
20#include "mlir/IR/DialectInterface.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/Support/LogicalResult.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Twine.h"
25
26namespace mlir {
27//===--------------------------------------------------------------------===//
28// Dialect Version Interface.
29//===--------------------------------------------------------------------===//
30
/// Opaque base type for the version information of a dialect. It declares no
/// state or behavior of its own: its only purpose is to provide a virtual
/// destructor, so that a dialect-defined version object can be destroyed
/// through a pointer to this base class.
class DialectVersion {
public:
  virtual ~DialectVersion() = default;
};
37
38//===----------------------------------------------------------------------===//
39// DialectBytecodeReader
40//===----------------------------------------------------------------------===//
41
42/// This class defines a virtual interface for reading a bytecode stream,
43/// providing hooks into the bytecode reader. As such, this class should only be
44/// derived and defined by the main bytecode reader, users (i.e. dialects)
45/// should generally only interact with this class via the
46/// BytecodeDialectInterface below.
47class DialectBytecodeReader {
48public:
49 virtual ~DialectBytecodeReader() = default;
50
51 /// Emit an error to the reader.
52 virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
53
54 /// Retrieve the dialect version by name if available.
55 virtual FailureOr<const DialectVersion *>
56 getDialectVersion(StringRef dialectName) const = 0;
57 template <class T>
58 FailureOr<const DialectVersion *> getDialectVersion() const {
59 return getDialectVersion(T::getDialectNamespace());
60 }
61
62 /// Retrieve the context associated to the reader.
63 virtual MLIRContext *getContext() const = 0;
64
65 /// Return the bytecode version being read.
66 virtual uint64_t getBytecodeVersion() const = 0;
67
68 /// Read out a list of elements, invoking the provided callback for each
69 /// element. The callback function may be in any of the following forms:
70 /// * LogicalResult(T &)
71 /// * FailureOr<T>()
72 template <typename T, typename CallbackFn>
73 LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) {
74 uint64_t size;
75 if (failed(result: readVarInt(result&: size)))
76 return failure();
77 result.reserve(size);
78
79 for (uint64_t i = 0; i < size; ++i) {
80 // Check if the callback uses FailureOr, or populates the result by
81 // reference.
82 if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
83 T element = {};
84 if (failed(callback(element)))
85 return failure();
86 result.emplace_back(std::move(element));
87 } else {
88 FailureOr<T> element = callback();
89 if (failed(element))
90 return failure();
91 result.emplace_back(std::move(*element));
92 }
93 }
94 return success();
95 }
96
97 //===--------------------------------------------------------------------===//
98 // IR
99 //===--------------------------------------------------------------------===//
100
101 /// Read a reference to the given attribute.
102 virtual LogicalResult readAttribute(Attribute &result) = 0;
103 /// Read an optional reference to the given attribute. Returns success even if
104 /// the Attribute isn't present.
105 virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0;
106
107 template <typename T>
108 LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
109 return readList(attrs, [this](T &attr) { return readAttribute(attr); });
110 }
111 template <typename T>
112 LogicalResult readAttribute(T &result) {
113 Attribute baseResult;
114 if (failed(result: readAttribute(result&: baseResult)))
115 return failure();
116 if ((result = dyn_cast<T>(baseResult)))
117 return success();
118 return emitError() << "expected " << llvm::getTypeName<T>()
119 << ", but got: " << baseResult;
120 }
121 template <typename T>
122 LogicalResult readOptionalAttribute(T &result) {
123 Attribute baseResult;
124 if (failed(result: readOptionalAttribute(attr&: baseResult)))
125 return failure();
126 if (!baseResult)
127 return success();
128 if ((result = dyn_cast<T>(baseResult)))
129 return success();
130 return emitError() << "expected " << llvm::getTypeName<T>()
131 << ", but got: " << baseResult;
132 }
133
134 /// Read a reference to the given type.
135 virtual LogicalResult readType(Type &result) = 0;
136 template <typename T>
137 LogicalResult readTypes(SmallVectorImpl<T> &types) {
138 return readList(types, [this](T &type) { return readType(type); });
139 }
140 template <typename T>
141 LogicalResult readType(T &result) {
142 Type baseResult;
143 if (failed(result: readType(result&: baseResult)))
144 return failure();
145 if ((result = dyn_cast<T>(baseResult)))
146 return success();
147 return emitError() << "expected " << llvm::getTypeName<T>()
148 << ", but got: " << baseResult;
149 }
150
151 /// Read a handle to a dialect resource.
152 template <typename ResourceT>
153 FailureOr<ResourceT> readResourceHandle() {
154 FailureOr<AsmDialectResourceHandle> handle = readResourceHandle();
155 if (failed(result: handle))
156 return failure();
157 if (auto *result = dyn_cast<ResourceT>(&*handle))
158 return std::move(*result);
159 return emitError() << "provided resource handle differs from the "
160 "expected resource type";
161 }
162
163 //===--------------------------------------------------------------------===//
164 // Primitives
165 //===--------------------------------------------------------------------===//
166
167 /// Read a variable width integer.
168 virtual LogicalResult readVarInt(uint64_t &result) = 0;
169
170 /// Read a signed variable width integer.
171 virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
172 LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) {
173 return readList(result,
174 callback: [this](int64_t &value) { return readSignedVarInt(result&: value); });
175 }
176
177 /// Parse a variable length encoded integer whose low bit is used to encode an
178 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
179 LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) {
180 if (failed(result: readVarInt(result)))
181 return failure();
182 flag = result & 1;
183 result >>= 1;
184 return success();
185 }
186
187 /// Read a "small" sparse array of integer <= 32 bits elements, where
188 /// index/value pairs can be compressed when the array is small.
189 /// Note that only some position of the array will be read and the ones
190 /// not stored in the bytecode are gonne be left untouched.
191 /// If the provided array is too small for the stored indices, an error
192 /// will be returned.
193 template <typename T>
194 LogicalResult readSparseArray(MutableArrayRef<T> array) {
195 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
196 static_assert(std::is_integral<T>::value, "expects integer");
197 uint64_t nonZeroesCount;
198 bool useSparseEncoding;
199 if (failed(result: readVarIntWithFlag(result&: nonZeroesCount, flag&: useSparseEncoding)))
200 return failure();
201 if (nonZeroesCount == 0)
202 return success();
203 if (!useSparseEncoding) {
204 // This is a simple dense array.
205 if (nonZeroesCount > array.size()) {
206 emitError(msg: "trying to read an array of ")
207 << nonZeroesCount << " but only " << array.size()
208 << " storage available.";
209 return failure();
210 }
211 for (int64_t index : llvm::seq<int64_t>(Begin: 0, End: nonZeroesCount)) {
212 uint64_t value;
213 if (failed(result: readVarInt(result&: value)))
214 return failure();
215 array[index] = value;
216 }
217 return success();
218 }
219 // Read sparse encoding
220 // This is the number of bits used for packing the index with the value.
221 uint64_t indexBitSize;
222 if (failed(result: readVarInt(result&: indexBitSize)))
223 return failure();
224 constexpr uint64_t maxIndexBitSize = 8;
225 if (indexBitSize > maxIndexBitSize) {
226 emitError(msg: "reading sparse array with indexing above 8 bits: ")
227 << indexBitSize;
228 return failure();
229 }
230 for (uint32_t count : llvm::seq<uint32_t>(Begin: 0, End: nonZeroesCount)) {
231 (void)count;
232 uint64_t indexValuePair;
233 if (failed(result: readVarInt(result&: indexValuePair)))
234 return failure();
235 uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
236 uint64_t value = indexValuePair >> indexBitSize;
237 if (index >= array.size()) {
238 emitError(msg: "reading a sparse array found index ")
239 << index << " but only " << array.size() << " storage available.";
240 return failure();
241 }
242 array[index] = value;
243 }
244 return success();
245 }
246
247 /// Read an APInt that is known to have been encoded with the given width.
248 virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
249
250 /// Read an APFloat that is known to have been encoded with the given
251 /// semantics.
252 virtual FailureOr<APFloat>
253 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0;
254
255 /// Read a string from the bytecode.
256 virtual LogicalResult readString(StringRef &result) = 0;
257
258 /// Read a blob from the bytecode.
259 virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
260
261 /// Read a bool from the bytecode.
262 virtual LogicalResult readBool(bool &result) = 0;
263
264private:
265 /// Read a handle to a dialect resource.
266 virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
267};
268
269//===----------------------------------------------------------------------===//
270// DialectBytecodeWriter
271//===----------------------------------------------------------------------===//
272
273/// This class defines a virtual interface for writing to a bytecode stream,
274/// providing hooks into the bytecode writer. As such, this class should only be
275/// derived and defined by the main bytecode writer, users (i.e. dialects)
276/// should generally only interact with this class via the
277/// BytecodeDialectInterface below.
278class DialectBytecodeWriter {
279public:
280 virtual ~DialectBytecodeWriter() = default;
281
282 //===--------------------------------------------------------------------===//
283 // IR
284 //===--------------------------------------------------------------------===//
285
286 /// Write out a list of elements, invoking the provided callback for each
287 /// element.
288 template <typename RangeT, typename CallbackFn>
289 void writeList(RangeT &&range, CallbackFn &&callback) {
290 writeVarInt(value: llvm::size(range));
291 for (auto &element : range)
292 callback(element);
293 }
294
295 /// Write a reference to the given attribute.
296 virtual void writeAttribute(Attribute attr) = 0;
297 virtual void writeOptionalAttribute(Attribute attr) = 0;
298 template <typename T>
299 void writeAttributes(ArrayRef<T> attrs) {
300 writeList(attrs, [this](T attr) { writeAttribute(attr); });
301 }
302
303 /// Write a reference to the given type.
304 virtual void writeType(Type type) = 0;
305 template <typename T>
306 void writeTypes(ArrayRef<T> types) {
307 writeList(types, [this](T type) { writeType(type); });
308 }
309
310 /// Write the given handle to a dialect resource.
311 virtual void
312 writeResourceHandle(const AsmDialectResourceHandle &resource) = 0;
313
314 //===--------------------------------------------------------------------===//
315 // Primitives
316 //===--------------------------------------------------------------------===//
317
318 /// Write a variable width integer to the output stream. This should be the
319 /// preferred method for emitting integers whenever possible.
320 virtual void writeVarInt(uint64_t value) = 0;
321
322 /// Write a signed variable width integer to the output stream. This should be
323 /// the preferred method for emitting signed integers whenever possible.
324 virtual void writeSignedVarInt(int64_t value) = 0;
325 void writeSignedVarInts(ArrayRef<int64_t> value) {
326 writeList(range&: value, callback: [this](int64_t value) { writeSignedVarInt(value); });
327 }
328
329 /// Write a VarInt and a flag packed together.
330 void writeVarIntWithFlag(uint64_t value, bool flag) {
331 writeVarInt(value: (value << 1) | (flag ? 1 : 0));
332 }
333
334 /// Write out a "small" sparse array of integer <= 32 bits elements, where
335 /// index/value pairs can be compressed when the array is small. This method
336 /// will scan the array multiple times and should not be used for large
337 /// arrays. The optional provided "zero" can be used to adjust for the
338 /// expected repeated value. We assume here that the array size fits in a 32
339 /// bits integer.
340 template <typename T>
341 void writeSparseArray(ArrayRef<T> array) {
342 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
343 static_assert(std::is_integral<T>::value, "expects integer");
344 uint32_t size = array.size();
345 uint32_t nonZeroesCount = 0, lastIndex = 0;
346 for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: size)) {
347 if (!array[index])
348 continue;
349 nonZeroesCount++;
350 lastIndex = index;
351 }
352 // If the last position is too large, or the array isn't at least 50%
353 // sparse, emit it with a dense encoding.
354 if (lastIndex > 256 || nonZeroesCount > size / 2) {
355 // Emit the array size and a flag which indicates whether it is sparse.
356 writeVarIntWithFlag(value: size, flag: false);
357 for (const T &elt : array)
358 writeVarInt(value: elt);
359 return;
360 }
361 // Emit sparse: first the number of elements we'll write and a flag
362 // indicating it is a sparse encoding.
363 writeVarIntWithFlag(value: nonZeroesCount, flag: true);
364 if (nonZeroesCount == 0)
365 return;
366 // This is the number of bits used for packing the index with the value.
367 int indexBitSize = llvm::Log2_32_Ceil(Value: lastIndex + 1);
368 writeVarInt(value: indexBitSize);
369 for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: lastIndex + 1)) {
370 T value = array[index];
371 if (!value)
372 continue;
373 uint64_t indexValuePair = (value << indexBitSize) | (index);
374 writeVarInt(value: indexValuePair);
375 }
376 }
377
378 /// Write an APInt to the bytecode stream whose bitwidth will be known
379 /// externally at read time. This method is useful for encoding APInt values
380 /// when the width is known via external means, such as via a type. This
381 /// method should generally only be invoked if you need an APInt, otherwise
382 /// use the varint methods above. APInt values are generally encoded using
383 /// zigzag encoding, to enable more efficient encodings for negative values.
384 virtual void writeAPIntWithKnownWidth(const APInt &value) = 0;
385
386 /// Write an APFloat to the bytecode stream whose semantics will be known
387 /// externally at read time. This method is useful for encoding APFloat values
388 /// when the semantics are known via external means, such as via a type.
389 virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0;
390
391 /// Write a string to the bytecode, which is owned by the caller and is
392 /// guaranteed to not die before the end of the bytecode process. This should
393 /// only be called if such a guarantee can be made, such as when the string is
394 /// owned by an attribute or type.
395 virtual void writeOwnedString(StringRef str) = 0;
396
397 /// Write a blob to the bytecode, which is owned by the caller and is
398 /// guaranteed to not die before the end of the bytecode process. The blob is
399 /// written as-is, with no additional compression or compaction.
400 virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
401
402 /// Write a bool to the output stream.
403 virtual void writeOwnedBool(bool value) = 0;
404
405 /// Return the bytecode version being emitted for.
406 virtual int64_t getBytecodeVersion() const = 0;
407
408 /// Retrieve the dialect version by name if available.
409 virtual FailureOr<const DialectVersion *>
410 getDialectVersion(StringRef dialectName) const = 0;
411
412 template <class T>
413 FailureOr<const DialectVersion *> getDialectVersion() const {
414 return getDialectVersion(T::getDialectNamespace());
415 }
416};
417
418//===----------------------------------------------------------------------===//
419// BytecodeDialectInterface
420//===----------------------------------------------------------------------===//
421
422class BytecodeDialectInterface
423 : public DialectInterface::Base<BytecodeDialectInterface> {
424public:
425 using Base::Base;
426
427 //===--------------------------------------------------------------------===//
428 // Reading
429 //===--------------------------------------------------------------------===//
430
431 /// Read an attribute belonging to this dialect from the given reader. This
432 /// method should return null in the case of failure. Optionally, the dialect
433 /// version can be accessed through the reader.
434 virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
435 reader.emitError() << "dialect " << getDialect()->getNamespace()
436 << " does not support reading attributes from bytecode";
437 return Attribute();
438 }
439
440 /// Read a type belonging to this dialect from the given reader. This method
441 /// should return null in the case of failure. Optionally, the dialect version
442 /// can be accessed thorugh the reader.
443 virtual Type readType(DialectBytecodeReader &reader) const {
444 reader.emitError() << "dialect " << getDialect()->getNamespace()
445 << " does not support reading types from bytecode";
446 return Type();
447 }
448
449 //===--------------------------------------------------------------------===//
450 // Writing
451 //===--------------------------------------------------------------------===//
452
453 /// Write the given attribute, which belongs to this dialect, to the given
454 /// writer. This method may return failure to indicate that the given
455 /// attribute could not be encoded, in which case the textual format will be
456 /// used to encode this attribute instead.
457 virtual LogicalResult writeAttribute(Attribute attr,
458 DialectBytecodeWriter &writer) const {
459 return failure();
460 }
461
462 /// Write the given type, which belongs to this dialect, to the given writer.
463 /// This method may return failure to indicate that the given type could not
464 /// be encoded, in which case the textual format will be used to encode this
465 /// type instead.
466 virtual LogicalResult writeType(Type type,
467 DialectBytecodeWriter &writer) const {
468 return failure();
469 }
470
471 /// Write the version of this dialect to the given writer.
472 virtual void writeVersion(DialectBytecodeWriter &writer) const {}
473
474 // Read the version of this dialect from the provided reader and return it as
475 // a `unique_ptr` to a dialect version object.
476 virtual std::unique_ptr<DialectVersion>
477 readVersion(DialectBytecodeReader &reader) const {
478 reader.emitError(msg: "Dialect does not support versioning");
479 return nullptr;
480 }
481
482 /// Hook invoked after parsing completed, if a version directive was present
483 /// and included an entry for the current dialect. This hook offers the
484 /// opportunity to the dialect to visit the IR and upgrades constructs emitted
485 /// by the version of the dialect corresponding to the provided version.
486 virtual LogicalResult
487 upgradeFromVersion(Operation *topLevelOp,
488 const DialectVersion &version) const {
489 return success();
490 }
491};
492
493/// Helper for resource handle reading that returns LogicalResult.
494template <typename T, typename... Ts>
495static LogicalResult readResourceHandle(DialectBytecodeReader &reader,
496 FailureOr<T> &value, Ts &&...params) {
497 FailureOr<T> handle = reader.readResourceHandle<T>();
498 if (failed(handle))
499 return failure();
500 if (auto *result = dyn_cast<T>(&*handle)) {
501 value = std::move(*result);
502 return success();
503 }
504 return failure();
505}
506
507/// Helper method that injects context only if needed, this helps unify some of
508/// the attribute construction methods.
509template <typename T, typename... Ts>
510auto get(MLIRContext *context, Ts &&...params) {
511 // Prefer a direct `get` method if one exists.
512 if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
513 (void)context;
514 return T::get(std::forward<Ts>(params)...);
515 } else if constexpr (llvm::is_detected<detail::has_get_method, T,
516 MLIRContext *, Ts...>::value) {
517 return T::get(context, std::forward<Ts>(params)...);
518 } else {
519 // Otherwise, pass to the base get.
520 return T::Base::get(context, std::forward<Ts>(params)...);
521 }
522}
523
524} // namespace mlir
525
526#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
527

source code of mlir/include/mlir/Bytecode/BytecodeImplementation.h