1//===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This header defines various interfaces and utilities necessary for dialects
10// to hook into bytecode serialization.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
15#define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
16
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/Dialect.h"
20#include "mlir/IR/DialectInterface.h"
21#include "mlir/IR/OpImplementation.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/Twine.h"
24
25namespace mlir {
26//===--------------------------------------------------------------------===//
27// Dialect Version Interface.
28//===--------------------------------------------------------------------===//
29
/// Base class for the version of a dialect. It carries no data of its own;
/// its purpose is to let dialect-specific version objects be owned and
/// destroyed polymorphically through a base pointer (for example the
/// `std::unique_ptr<DialectVersion>` returned by
/// `BytecodeDialectInterface::readVersion`).
class DialectVersion {
public:
  /// Virtual so that deleting a derived version object through a
  /// `DialectVersion *` is well-defined.
  virtual ~DialectVersion() = default;
};
36
37//===----------------------------------------------------------------------===//
38// DialectBytecodeReader
39//===----------------------------------------------------------------------===//
40
41/// This class defines a virtual interface for reading a bytecode stream,
42/// providing hooks into the bytecode reader. As such, this class should only be
43/// derived and defined by the main bytecode reader, users (i.e. dialects)
44/// should generally only interact with this class via the
45/// BytecodeDialectInterface below.
46class DialectBytecodeReader {
47public:
48 virtual ~DialectBytecodeReader() = default;
49
50 /// Emit an error to the reader.
51 virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
52
53 /// Retrieve the dialect version by name if available.
54 virtual FailureOr<const DialectVersion *>
55 getDialectVersion(StringRef dialectName) const = 0;
56 template <class T>
57 FailureOr<const DialectVersion *> getDialectVersion() const {
58 return getDialectVersion(T::getDialectNamespace());
59 }
60
61 /// Retrieve the context associated to the reader.
62 virtual MLIRContext *getContext() const = 0;
63
64 /// Return the bytecode version being read.
65 virtual uint64_t getBytecodeVersion() const = 0;
66
67 /// Read out a list of elements, invoking the provided callback for each
68 /// element. The callback function may be in any of the following forms:
69 /// * LogicalResult(T &)
70 /// * FailureOr<T>()
71 template <typename T, typename CallbackFn>
72 LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) {
73 uint64_t size;
74 if (failed(Result: readVarInt(result&: size)))
75 return failure();
76 result.reserve(size);
77
78 for (uint64_t i = 0; i < size; ++i) {
79 // Check if the callback uses FailureOr, or populates the result by
80 // reference.
81 if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
82 T element = {};
83 if (failed(callback(element)))
84 return failure();
85 result.emplace_back(std::move(element));
86 } else {
87 FailureOr<T> element = callback();
88 if (failed(element))
89 return failure();
90 result.emplace_back(std::move(*element));
91 }
92 }
93 return success();
94 }
95
96 //===--------------------------------------------------------------------===//
97 // IR
98 //===--------------------------------------------------------------------===//
99
100 /// Read a reference to the given attribute.
101 virtual LogicalResult readAttribute(Attribute &result) = 0;
102 /// Read an optional reference to the given attribute. Returns success even if
103 /// the Attribute isn't present.
104 virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0;
105
106 template <typename T>
107 LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
108 return readList(attrs, [this](T &attr) { return readAttribute(attr); });
109 }
110 template <typename T>
111 LogicalResult readAttribute(T &result) {
112 Attribute baseResult;
113 if (failed(Result: readAttribute(result&: baseResult)))
114 return failure();
115 if ((result = dyn_cast<T>(baseResult)))
116 return success();
117 return emitError() << "expected " << llvm::getTypeName<T>()
118 << ", but got: " << baseResult;
119 }
120 template <typename T>
121 LogicalResult readOptionalAttribute(T &result) {
122 Attribute baseResult;
123 if (failed(Result: readOptionalAttribute(attr&: baseResult)))
124 return failure();
125 if (!baseResult)
126 return success();
127 if ((result = dyn_cast<T>(baseResult)))
128 return success();
129 return emitError() << "expected " << llvm::getTypeName<T>()
130 << ", but got: " << baseResult;
131 }
132
133 /// Read a reference to the given type.
134 virtual LogicalResult readType(Type &result) = 0;
135 template <typename T>
136 LogicalResult readTypes(SmallVectorImpl<T> &types) {
137 return readList(types, [this](T &type) { return readType(type); });
138 }
139 template <typename T>
140 LogicalResult readType(T &result) {
141 Type baseResult;
142 if (failed(Result: readType(result&: baseResult)))
143 return failure();
144 if ((result = dyn_cast<T>(baseResult)))
145 return success();
146 return emitError() << "expected " << llvm::getTypeName<T>()
147 << ", but got: " << baseResult;
148 }
149
150 /// Read a handle to a dialect resource.
151 template <typename ResourceT>
152 FailureOr<ResourceT> readResourceHandle() {
153 FailureOr<AsmDialectResourceHandle> handle = readResourceHandle();
154 if (failed(Result: handle))
155 return failure();
156 if (auto *result = dyn_cast<ResourceT>(&*handle))
157 return std::move(*result);
158 return emitError() << "provided resource handle differs from the "
159 "expected resource type";
160 }
161
162 //===--------------------------------------------------------------------===//
163 // Primitives
164 //===--------------------------------------------------------------------===//
165
166 /// Read a variable width integer.
167 virtual LogicalResult readVarInt(uint64_t &result) = 0;
168
169 /// Read a signed variable width integer.
170 virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
171 LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) {
172 return readList(result,
173 callback: [this](int64_t &value) { return readSignedVarInt(result&: value); });
174 }
175
176 /// Parse a variable length encoded integer whose low bit is used to encode an
177 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
178 LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) {
179 if (failed(Result: readVarInt(result)))
180 return failure();
181 flag = result & 1;
182 result >>= 1;
183 return success();
184 }
185
186 /// Read a "small" sparse array of integer <= 32 bits elements, where
187 /// index/value pairs can be compressed when the array is small.
188 /// Note that only some position of the array will be read and the ones
189 /// not stored in the bytecode are gonne be left untouched.
190 /// If the provided array is too small for the stored indices, an error
191 /// will be returned.
192 template <typename T>
193 LogicalResult readSparseArray(MutableArrayRef<T> array) {
194 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
195 static_assert(std::is_integral<T>::value, "expects integer");
196 uint64_t nonZeroesCount;
197 bool useSparseEncoding;
198 if (failed(Result: readVarIntWithFlag(result&: nonZeroesCount, flag&: useSparseEncoding)))
199 return failure();
200 if (nonZeroesCount == 0)
201 return success();
202 if (!useSparseEncoding) {
203 // This is a simple dense array.
204 if (nonZeroesCount > array.size()) {
205 emitError(msg: "trying to read an array of ")
206 << nonZeroesCount << " but only " << array.size()
207 << " storage available.";
208 return failure();
209 }
210 for (int64_t index : llvm::seq<int64_t>(Begin: 0, End: nonZeroesCount)) {
211 uint64_t value;
212 if (failed(Result: readVarInt(result&: value)))
213 return failure();
214 array[index] = value;
215 }
216 return success();
217 }
218 // Read sparse encoding
219 // This is the number of bits used for packing the index with the value.
220 uint64_t indexBitSize;
221 if (failed(Result: readVarInt(result&: indexBitSize)))
222 return failure();
223 constexpr uint64_t maxIndexBitSize = 8;
224 if (indexBitSize > maxIndexBitSize) {
225 emitError(msg: "reading sparse array with indexing above 8 bits: ")
226 << indexBitSize;
227 return failure();
228 }
229 for (uint32_t count : llvm::seq<uint32_t>(Begin: 0, End: nonZeroesCount)) {
230 (void)count;
231 uint64_t indexValuePair;
232 if (failed(Result: readVarInt(result&: indexValuePair)))
233 return failure();
234 uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
235 uint64_t value = indexValuePair >> indexBitSize;
236 if (index >= array.size()) {
237 emitError(msg: "reading a sparse array found index ")
238 << index << " but only " << array.size() << " storage available.";
239 return failure();
240 }
241 array[index] = value;
242 }
243 return success();
244 }
245
246 /// Read an APInt that is known to have been encoded with the given width.
247 virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
248
249 /// Read an APFloat that is known to have been encoded with the given
250 /// semantics.
251 virtual FailureOr<APFloat>
252 readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0;
253
254 /// Read a string from the bytecode.
255 virtual LogicalResult readString(StringRef &result) = 0;
256
257 /// Read a blob from the bytecode.
258 virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
259
260 /// Read a bool from the bytecode.
261 virtual LogicalResult readBool(bool &result) = 0;
262
263private:
264 /// Read a handle to a dialect resource.
265 virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
266};
267
268//===----------------------------------------------------------------------===//
269// DialectBytecodeWriter
270//===----------------------------------------------------------------------===//
271
272/// This class defines a virtual interface for writing to a bytecode stream,
273/// providing hooks into the bytecode writer. As such, this class should only be
274/// derived and defined by the main bytecode writer, users (i.e. dialects)
275/// should generally only interact with this class via the
276/// BytecodeDialectInterface below.
277class DialectBytecodeWriter {
278public:
279 virtual ~DialectBytecodeWriter() = default;
280
281 //===--------------------------------------------------------------------===//
282 // IR
283 //===--------------------------------------------------------------------===//
284
285 /// Write out a list of elements, invoking the provided callback for each
286 /// element.
287 template <typename RangeT, typename CallbackFn>
288 void writeList(RangeT &&range, CallbackFn &&callback) {
289 writeVarInt(value: llvm::size(range));
290 for (auto &element : range)
291 callback(element);
292 }
293
294 /// Write a reference to the given attribute.
295 virtual void writeAttribute(Attribute attr) = 0;
296 virtual void writeOptionalAttribute(Attribute attr) = 0;
297 template <typename T>
298 void writeAttributes(ArrayRef<T> attrs) {
299 writeList(attrs, [this](T attr) { writeAttribute(attr); });
300 }
301
302 /// Write a reference to the given type.
303 virtual void writeType(Type type) = 0;
304 template <typename T>
305 void writeTypes(ArrayRef<T> types) {
306 writeList(types, [this](T type) { writeType(type); });
307 }
308
309 /// Write the given handle to a dialect resource.
310 virtual void
311 writeResourceHandle(const AsmDialectResourceHandle &resource) = 0;
312
313 //===--------------------------------------------------------------------===//
314 // Primitives
315 //===--------------------------------------------------------------------===//
316
317 /// Write a variable width integer to the output stream. This should be the
318 /// preferred method for emitting integers whenever possible.
319 virtual void writeVarInt(uint64_t value) = 0;
320
321 /// Write a signed variable width integer to the output stream. This should be
322 /// the preferred method for emitting signed integers whenever possible.
323 virtual void writeSignedVarInt(int64_t value) = 0;
324 void writeSignedVarInts(ArrayRef<int64_t> value) {
325 writeList(range&: value, callback: [this](int64_t value) { writeSignedVarInt(value); });
326 }
327
328 /// Write a VarInt and a flag packed together.
329 void writeVarIntWithFlag(uint64_t value, bool flag) {
330 writeVarInt(value: (value << 1) | (flag ? 1 : 0));
331 }
332
333 /// Write out a "small" sparse array of integer <= 32 bits elements, where
334 /// index/value pairs can be compressed when the array is small. This method
335 /// will scan the array multiple times and should not be used for large
336 /// arrays. The optional provided "zero" can be used to adjust for the
337 /// expected repeated value. We assume here that the array size fits in a 32
338 /// bits integer.
339 template <typename T>
340 void writeSparseArray(ArrayRef<T> array) {
341 static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
342 static_assert(std::is_integral<T>::value, "expects integer");
343 uint32_t size = array.size();
344 uint32_t nonZeroesCount = 0, lastIndex = 0;
345 for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: size)) {
346 if (!array[index])
347 continue;
348 nonZeroesCount++;
349 lastIndex = index;
350 }
351 // If the last position is too large, or the array isn't at least 50%
352 // sparse, emit it with a dense encoding.
353 if (lastIndex > 256 || nonZeroesCount > size / 2) {
354 // Emit the array size and a flag which indicates whether it is sparse.
355 writeVarIntWithFlag(value: size, flag: false);
356 for (const T &elt : array)
357 writeVarInt(value: elt);
358 return;
359 }
360 // Emit sparse: first the number of elements we'll write and a flag
361 // indicating it is a sparse encoding.
362 writeVarIntWithFlag(value: nonZeroesCount, flag: true);
363 if (nonZeroesCount == 0)
364 return;
365 // This is the number of bits used for packing the index with the value.
366 int indexBitSize = llvm::Log2_32_Ceil(Value: lastIndex + 1);
367 writeVarInt(value: indexBitSize);
368 for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: lastIndex + 1)) {
369 T value = array[index];
370 if (!value)
371 continue;
372 uint64_t indexValuePair = (value << indexBitSize) | (index);
373 writeVarInt(value: indexValuePair);
374 }
375 }
376
377 /// Write an APInt to the bytecode stream whose bitwidth will be known
378 /// externally at read time. This method is useful for encoding APInt values
379 /// when the width is known via external means, such as via a type. This
380 /// method should generally only be invoked if you need an APInt, otherwise
381 /// use the varint methods above. APInt values are generally encoded using
382 /// zigzag encoding, to enable more efficient encodings for negative values.
383 virtual void writeAPIntWithKnownWidth(const APInt &value) = 0;
384
385 /// Write an APFloat to the bytecode stream whose semantics will be known
386 /// externally at read time. This method is useful for encoding APFloat values
387 /// when the semantics are known via external means, such as via a type.
388 virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0;
389
390 /// Write a string to the bytecode, which is owned by the caller and is
391 /// guaranteed to not die before the end of the bytecode process. This should
392 /// only be called if such a guarantee can be made, such as when the string is
393 /// owned by an attribute or type.
394 virtual void writeOwnedString(StringRef str) = 0;
395
396 /// Write a blob to the bytecode, which is owned by the caller and is
397 /// guaranteed to not die before the end of the bytecode process. The blob is
398 /// written as-is, with no additional compression or compaction.
399 virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
400
401 /// Write a bool to the output stream.
402 virtual void writeOwnedBool(bool value) = 0;
403
404 /// Return the bytecode version being emitted for.
405 virtual int64_t getBytecodeVersion() const = 0;
406
407 /// Retrieve the dialect version by name if available.
408 virtual FailureOr<const DialectVersion *>
409 getDialectVersion(StringRef dialectName) const = 0;
410
411 template <class T>
412 FailureOr<const DialectVersion *> getDialectVersion() const {
413 return getDialectVersion(T::getDialectNamespace());
414 }
415};
416
417//===----------------------------------------------------------------------===//
418// BytecodeDialectInterface
419//===----------------------------------------------------------------------===//
420
421class BytecodeDialectInterface
422 : public DialectInterface::Base<BytecodeDialectInterface> {
423public:
424 using Base::Base;
425
426 //===--------------------------------------------------------------------===//
427 // Reading
428 //===--------------------------------------------------------------------===//
429
430 /// Read an attribute belonging to this dialect from the given reader. This
431 /// method should return null in the case of failure. Optionally, the dialect
432 /// version can be accessed through the reader.
433 virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
434 reader.emitError() << "dialect " << getDialect()->getNamespace()
435 << " does not support reading attributes from bytecode";
436 return Attribute();
437 }
438
439 /// Read a type belonging to this dialect from the given reader. This method
440 /// should return null in the case of failure. Optionally, the dialect version
441 /// can be accessed thorugh the reader.
442 virtual Type readType(DialectBytecodeReader &reader) const {
443 reader.emitError() << "dialect " << getDialect()->getNamespace()
444 << " does not support reading types from bytecode";
445 return Type();
446 }
447
448 //===--------------------------------------------------------------------===//
449 // Writing
450 //===--------------------------------------------------------------------===//
451
452 /// Write the given attribute, which belongs to this dialect, to the given
453 /// writer. This method may return failure to indicate that the given
454 /// attribute could not be encoded, in which case the textual format will be
455 /// used to encode this attribute instead.
456 virtual LogicalResult writeAttribute(Attribute attr,
457 DialectBytecodeWriter &writer) const {
458 return failure();
459 }
460
461 /// Write the given type, which belongs to this dialect, to the given writer.
462 /// This method may return failure to indicate that the given type could not
463 /// be encoded, in which case the textual format will be used to encode this
464 /// type instead.
465 virtual LogicalResult writeType(Type type,
466 DialectBytecodeWriter &writer) const {
467 return failure();
468 }
469
470 /// Write the version of this dialect to the given writer.
471 virtual void writeVersion(DialectBytecodeWriter &writer) const {}
472
473 // Read the version of this dialect from the provided reader and return it as
474 // a `unique_ptr` to a dialect version object.
475 virtual std::unique_ptr<DialectVersion>
476 readVersion(DialectBytecodeReader &reader) const {
477 reader.emitError(msg: "Dialect does not support versioning");
478 return nullptr;
479 }
480
481 /// Hook invoked after parsing completed, if a version directive was present
482 /// and included an entry for the current dialect. This hook offers the
483 /// opportunity to the dialect to visit the IR and upgrades constructs emitted
484 /// by the version of the dialect corresponding to the provided version.
485 virtual LogicalResult
486 upgradeFromVersion(Operation *topLevelOp,
487 const DialectVersion &version) const {
488 return success();
489 }
490};
491
492/// Helper for resource handle reading that returns LogicalResult.
493template <typename T, typename... Ts>
494static LogicalResult readResourceHandle(DialectBytecodeReader &reader,
495 FailureOr<T> &value, Ts &&...params) {
496 FailureOr<T> handle = reader.readResourceHandle<T>();
497 if (failed(handle))
498 return failure();
499 if (auto *result = dyn_cast<T>(&*handle)) {
500 value = std::move(*result);
501 return success();
502 }
503 return failure();
504}
505
506/// Helper method that injects context only if needed, this helps unify some of
507/// the attribute construction methods.
508template <typename T, typename... Ts>
509auto get(MLIRContext *context, Ts &&...params) {
510 // Prefer a direct `get` method if one exists.
511 if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) {
512 (void)context;
513 return T::get(std::forward<Ts>(params)...);
514 } else if constexpr (llvm::is_detected<detail::has_get_method, T,
515 MLIRContext *, Ts...>::value) {
516 return T::get(context, std::forward<Ts>(params)...);
517 } else {
518 // Otherwise, pass to the base get.
519 return T::Base::get(context, std::forward<Ts>(params)...);
520 }
521}
522
523} // namespace mlir
524
525#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
526

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/include/mlir/Bytecode/BytecodeImplementation.h