//===- BytecodeReaderConfig.h - MLIR Bytecode Reader Config -----*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This header defines the configuration interfaces used to customize how MLIR
// bytecode files/streams are read.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_BYTECODE_BYTECODEREADERCONFIG_H |
14 | #define MLIR_BYTECODE_BYTECODEREADERCONFIG_H |
15 | |
16 | #include "mlir/Support/LLVM.h" |
17 | #include "mlir/Support/LogicalResult.h" |
18 | #include "llvm/ADT/ArrayRef.h" |
19 | #include "llvm/ADT/SmallVector.h" |
20 | #include "llvm/ADT/StringRef.h" |
21 | |
22 | namespace mlir { |
23 | class Attribute; |
24 | class DialectBytecodeReader; |
25 | class Type; |
26 | |
27 | /// A class to interact with the attributes and types parser when parsing MLIR |
28 | /// bytecode. |
29 | template <class T> |
30 | class AttrTypeBytecodeReader { |
31 | public: |
32 | AttrTypeBytecodeReader() = default; |
33 | virtual ~AttrTypeBytecodeReader() = default; |
34 | |
35 | virtual LogicalResult read(DialectBytecodeReader &reader, |
36 | StringRef dialectName, T &entry) = 0; |
37 | |
38 | /// Return an Attribute/Type printer implemented via the given callable, whose |
39 | /// form should match that of the `parse` function above. |
40 | template <typename CallableT, |
41 | std::enable_if_t< |
42 | std::is_convertible_v< |
43 | CallableT, std::function<LogicalResult( |
44 | DialectBytecodeReader &, StringRef, T &)>>, |
45 | bool> = true> |
46 | static std::unique_ptr<AttrTypeBytecodeReader<T>> |
47 | fromCallable(CallableT &&readFn) { |
48 | struct Processor : public AttrTypeBytecodeReader<T> { |
49 | Processor(CallableT &&readFn) |
50 | : AttrTypeBytecodeReader(), readFn(std::move(readFn)) {} |
51 | LogicalResult read(DialectBytecodeReader &reader, StringRef dialectName, |
52 | T &entry) override { |
53 | return readFn(reader, dialectName, entry); |
54 | } |
55 | |
56 | std::decay_t<CallableT> readFn; |
57 | }; |
58 | return std::make_unique<Processor>(std::forward<CallableT>(readFn)); |
59 | } |
60 | }; |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // BytecodeReaderConfig |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | /// A class containing bytecode-specific configurations of the `ParserConfig`. |
67 | class BytecodeReaderConfig { |
68 | public: |
69 | BytecodeReaderConfig() = default; |
70 | |
71 | /// Returns the callbacks available to the parser. |
72 | ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>> |
73 | getAttributeCallbacks() const { |
74 | return attributeBytecodeParsers; |
75 | } |
76 | ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>> |
77 | getTypeCallbacks() const { |
78 | return typeBytecodeParsers; |
79 | } |
80 | |
81 | /// Attach a custom bytecode parser callback to the configuration for parsing |
82 | /// of custom type/attributes encodings. |
83 | void attachAttributeCallback( |
84 | std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) { |
85 | attributeBytecodeParsers.emplace_back(Args: std::move(parser)); |
86 | } |
87 | void |
88 | attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) { |
89 | typeBytecodeParsers.emplace_back(Args: std::move(parser)); |
90 | } |
91 | |
92 | /// Attach a custom bytecode parser callback to the configuration for parsing |
93 | /// of custom type/attributes encodings. |
94 | template <typename CallableT> |
95 | std::enable_if_t<std::is_convertible_v< |
96 | CallableT, std::function<LogicalResult(DialectBytecodeReader &, StringRef, |
97 | Attribute &)>>> |
98 | attachAttributeCallback(CallableT &&parserFn) { |
99 | attachAttributeCallback(AttrTypeBytecodeReader<Attribute>::fromCallable( |
100 | std::forward<CallableT>(parserFn))); |
101 | } |
102 | template <typename CallableT> |
103 | std::enable_if_t<std::is_convertible_v< |
104 | CallableT, |
105 | std::function<LogicalResult(DialectBytecodeReader &, StringRef, Type &)>>> |
106 | attachTypeCallback(CallableT &&parserFn) { |
107 | attachTypeCallback(AttrTypeBytecodeReader<Type>::fromCallable( |
108 | std::forward<CallableT>(parserFn))); |
109 | } |
110 | |
111 | private: |
112 | llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>> |
113 | attributeBytecodeParsers; |
114 | llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>> |
115 | typeBytecodeParsers; |
116 | }; |
117 | |
118 | } // namespace mlir |
119 | |
120 | #endif // MLIR_BYTECODE_BYTECODEREADERCONFIG_H |
121 | |