1 | //===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains various utilities that number IR structures in preparation |
10 | // for bytecode emission. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H |
15 | #define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H |
16 | |
17 | #include "mlir/IR/OpImplementation.h" |
18 | #include "llvm/ADT/MapVector.h" |
19 | #include "llvm/ADT/SetVector.h" |
20 | #include "llvm/ADT/StringMap.h" |
21 | #include <cstdint> |
22 | |
23 | namespace mlir { |
24 | class BytecodeDialectInterface; |
25 | class BytecodeWriterConfig; |
26 | |
27 | namespace bytecode { |
28 | namespace detail { |
29 | struct DialectNumbering; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // Attribute and Type Numbering |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | /// This class represents a numbering entry for an Attribute or Type. |
36 | struct AttrTypeNumbering { |
37 | AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {} |
38 | |
39 | /// The concrete value. |
40 | PointerUnion<Attribute, Type> value; |
41 | |
42 | /// The number assigned to this value. |
43 | unsigned number = 0; |
44 | |
45 | /// The number of references to this value. |
46 | unsigned refCount = 1; |
47 | |
48 | /// The dialect of this value. |
49 | DialectNumbering *dialect = nullptr; |
50 | }; |
51 | struct AttributeNumbering : public AttrTypeNumbering { |
52 | AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {} |
53 | Attribute getValue() const { return value.get<Attribute>(); } |
54 | }; |
55 | struct TypeNumbering : public AttrTypeNumbering { |
56 | TypeNumbering(Type value) : AttrTypeNumbering(value) {} |
57 | Type getValue() const { return value.get<Type>(); } |
58 | }; |
59 | |
60 | //===----------------------------------------------------------------------===// |
61 | // OpName Numbering |
62 | //===----------------------------------------------------------------------===// |
63 | |
64 | /// This class represents the numbering entry of an operation name. |
65 | struct { |
66 | (DialectNumbering *dialect, OperationName name) |
67 | : dialect(dialect), name(name) {} |
68 | |
69 | /// The dialect of this value. |
70 | DialectNumbering *; |
71 | |
72 | /// The concrete name. |
73 | OperationName ; |
74 | |
75 | /// The number assigned to this name. |
76 | unsigned = 0; |
77 | |
78 | /// The number of references to this name. |
79 | unsigned = 1; |
80 | }; |
81 | |
82 | //===----------------------------------------------------------------------===// |
83 | // Dialect Resource Numbering |
84 | //===----------------------------------------------------------------------===// |
85 | |
/// Numbering entry for a single dialect resource, identified by its key.
struct DialectResourceNumbering {
  DialectResourceNumbering(std::string key) : key{std::move(key)} {}

  /// Key under which the resource is referenced.
  std::string key;

  /// Number assigned to the resource; initially zero.
  unsigned number{0};

  /// True when the resource is only declared rather than fully defined;
  /// a new entry starts out as a declaration.
  bool isDeclaration{true};
};
100 | |
101 | //===----------------------------------------------------------------------===// |
102 | // Dialect Numbering |
103 | //===----------------------------------------------------------------------===// |
104 | |
105 | /// This class represents a numbering entry for an Dialect. |
106 | struct DialectNumbering { |
107 | DialectNumbering(StringRef name, unsigned number) |
108 | : name(name), number(number) {} |
109 | |
110 | /// The namespace of the dialect. |
111 | StringRef name; |
112 | |
113 | /// The number assigned to the dialect. |
114 | unsigned number; |
115 | |
116 | /// The bytecode dialect interface of the dialect if defined. |
117 | const BytecodeDialectInterface *interface = nullptr; |
118 | |
119 | /// The asm dialect interface of the dialect if defined. |
120 | const OpAsmDialectInterface *asmInterface = nullptr; |
121 | |
122 | /// The referenced resources of this dialect. |
123 | SetVector<AsmDialectResourceHandle> resources; |
124 | |
125 | /// A mapping from resource key to the corresponding resource numbering entry. |
126 | llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap; |
127 | }; |
128 | |
129 | //===----------------------------------------------------------------------===// |
130 | // Operation Numbering |
131 | //===----------------------------------------------------------------------===// |
132 | |
/// Numbering entry for an operation.
struct OperationNumbering {
  OperationNumbering(unsigned number) : number{number} {}

  /// Number assigned to the operation.
  unsigned number;

  /// Whether the operation's regions are isolated from above. Stays empty
  /// while that property is not yet known.
  std::optional<bool> isIsolatedFromAbove;
};
144 | |
145 | //===----------------------------------------------------------------------===// |
146 | // IRNumberingState |
147 | //===----------------------------------------------------------------------===// |
148 | |
149 | /// This class manages numbering IR entities in preparation of bytecode |
150 | /// emission. |
151 | class IRNumberingState { |
152 | public: |
153 | IRNumberingState(Operation *op, const BytecodeWriterConfig &config); |
154 | |
155 | /// Return the numbered dialects. |
156 | auto getDialects() { |
157 | return llvm::make_pointee_range(Range: llvm::make_second_range(c&: dialects)); |
158 | } |
159 | auto getAttributes() { return llvm::make_pointee_range(Range&: orderedAttrs); } |
160 | auto getOpNames() { return llvm::make_pointee_range(Range&: orderedOpNames); } |
161 | auto getTypes() { return llvm::make_pointee_range(Range&: orderedTypes); } |
162 | |
163 | /// Return the number for the given IR unit. |
164 | unsigned getNumber(Attribute attr) { |
165 | assert(attrs.count(attr) && "attribute not numbered" ); |
166 | return attrs[attr]->number; |
167 | } |
168 | unsigned getNumber(Block *block) { |
169 | assert(blockIDs.count(block) && "block not numbered" ); |
170 | return blockIDs[block]; |
171 | } |
172 | unsigned getNumber(Operation *op) { |
173 | assert(operations.count(op) && "operation not numbered" ); |
174 | return operations[op]->number; |
175 | } |
176 | unsigned getNumber(OperationName opName) { |
177 | assert(opNames.count(opName) && "opName not numbered" ); |
178 | return opNames[opName]->number; |
179 | } |
180 | unsigned getNumber(Type type) { |
181 | assert(types.count(type) && "type not numbered" ); |
182 | return types[type]->number; |
183 | } |
184 | unsigned getNumber(Value value) { |
185 | assert(valueIDs.count(value) && "value not numbered" ); |
186 | return valueIDs[value]; |
187 | } |
188 | unsigned getNumber(const AsmDialectResourceHandle &resource) { |
189 | assert(dialectResources.count(resource) && "resource not numbered" ); |
190 | return dialectResources[resource]->number; |
191 | } |
192 | |
193 | /// Return the block and value counts of the given region. |
194 | std::pair<unsigned, unsigned> getBlockValueCount(Region *region) { |
195 | assert(regionBlockValueCounts.count(region) && "value not numbered" ); |
196 | return regionBlockValueCounts[region]; |
197 | } |
198 | |
199 | /// Return the number of operations in the given block. |
200 | unsigned getOperationCount(Block *block) { |
201 | assert(blockOperationCounts.count(block) && "block not numbered" ); |
202 | return blockOperationCounts[block]; |
203 | } |
204 | |
205 | /// Return if the given operation is isolated from above. |
206 | bool isIsolatedFromAbove(Operation *op) { |
207 | assert(operations.count(op) && "operation not numbered" ); |
208 | return operations[op]->isIsolatedFromAbove.value_or(u: false); |
209 | } |
210 | |
211 | /// Get the set desired bytecode version to emit. |
212 | int64_t getDesiredBytecodeVersion() const; |
213 | |
214 | private: |
215 | /// This class is used to provide a fake dialect writer for numbering nested |
216 | /// attributes and types. |
217 | struct NumberingDialectWriter; |
218 | |
219 | /// Compute the global numbering state for the given root operation. |
220 | void computeGlobalNumberingState(Operation *rootOp); |
221 | |
222 | /// Number the given IR unit for bytecode emission. |
223 | void number(Attribute attr); |
224 | void number(Block &block); |
225 | DialectNumbering &numberDialect(Dialect *dialect); |
226 | DialectNumbering &numberDialect(StringRef dialect); |
227 | void number(Operation &op); |
228 | void number(OperationName opName); |
229 | void number(Region ®ion); |
230 | void number(Type type); |
231 | |
232 | /// Number the given dialect resources. |
233 | void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources); |
234 | |
235 | /// Finalize the numberings of any dialect resources. |
236 | void finalizeDialectResourceNumberings(Operation *rootOp); |
237 | |
238 | /// Mapping from IR to the respective numbering entries. |
239 | DenseMap<Attribute, AttributeNumbering *> attrs; |
240 | DenseMap<Operation *, OperationNumbering *> operations; |
241 | DenseMap<OperationName, OpNameNumbering *> opNames; |
242 | DenseMap<Type, TypeNumbering *> types; |
243 | DenseMap<Dialect *, DialectNumbering *> registeredDialects; |
244 | llvm::MapVector<StringRef, DialectNumbering *> dialects; |
245 | std::vector<AttributeNumbering *> orderedAttrs; |
246 | std::vector<OpNameNumbering *> orderedOpNames; |
247 | std::vector<TypeNumbering *> orderedTypes; |
248 | |
249 | /// A mapping from dialect resource handle to the numbering for the referenced |
250 | /// resource. |
251 | llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *> |
252 | dialectResources; |
253 | |
254 | /// Allocators used for the various numbering entries. |
255 | llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator; |
256 | llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator; |
257 | llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator; |
258 | llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator; |
259 | llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator; |
260 | llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator; |
261 | |
262 | /// The value ID for each Block and Value. |
263 | DenseMap<Block *, unsigned> blockIDs; |
264 | DenseMap<Value, unsigned> valueIDs; |
265 | |
266 | /// The number of operations in each block. |
267 | DenseMap<Block *, unsigned> blockOperationCounts; |
268 | |
269 | /// A map from region to the number of blocks and values within that region. |
270 | DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts; |
271 | |
272 | /// The next value ID to assign when numbering. |
273 | unsigned nextValueID = 0; |
274 | |
275 | // Configuration: useful to query the required version to emit. |
276 | const BytecodeWriterConfig &config; |
277 | }; |
278 | } // namespace detail |
279 | } // namespace bytecode |
280 | } // namespace mlir |
281 | |
282 | #endif |
283 | |