1//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains various utilities that number IR structures in preparation
10// for bytecode emission.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
15#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
16
17#include "mlir/IR/OpImplementation.h"
18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/ADT/StringMap.h"
21#include <cstdint>
22
23namespace mlir {
24class BytecodeDialectInterface;
25class BytecodeWriterConfig;
26
27namespace bytecode {
28namespace detail {
29struct DialectNumbering;
30
31//===----------------------------------------------------------------------===//
32// Attribute and Type Numbering
33//===----------------------------------------------------------------------===//
34
35/// This class represents a numbering entry for an Attribute or Type.
36struct AttrTypeNumbering {
37 AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {}
38
39 /// The concrete value.
40 PointerUnion<Attribute, Type> value;
41
42 /// The number assigned to this value.
43 unsigned number = 0;
44
45 /// The number of references to this value.
46 unsigned refCount = 1;
47
48 /// The dialect of this value.
49 DialectNumbering *dialect = nullptr;
50};
51struct AttributeNumbering : public AttrTypeNumbering {
52 AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
53 Attribute getValue() const { return value.get<Attribute>(); }
54};
55struct TypeNumbering : public AttrTypeNumbering {
56 TypeNumbering(Type value) : AttrTypeNumbering(value) {}
57 Type getValue() const { return value.get<Type>(); }
58};
59
60//===----------------------------------------------------------------------===//
61// OpName Numbering
62//===----------------------------------------------------------------------===//
63
64/// This class represents the numbering entry of an operation name.
65struct OpNameNumbering {
66 OpNameNumbering(DialectNumbering *dialect, OperationName name)
67 : dialect(dialect), name(name) {}
68
69 /// The dialect of this value.
70 DialectNumbering *dialect;
71
72 /// The concrete name.
73 OperationName name;
74
75 /// The number assigned to this name.
76 unsigned number = 0;
77
78 /// The number of references to this name.
79 unsigned refCount = 1;
80};
81
82//===----------------------------------------------------------------------===//
83// Dialect Resource Numbering
84//===----------------------------------------------------------------------===//
85
/// Numbering entry for a single dialect resource, identified by its key.
struct DialectResourceNumbering {
  DialectResourceNumbering(std::string resourceKey)
      : key(std::move(resourceKey)) {}

  /// The key under which this resource is referenced.
  std::string key;

  /// The number assigned to this resource; defaults to 0.
  unsigned number{0};

  /// True when this resource is only declared rather than fully defined.
  /// Entries start out as declarations.
  bool isDeclaration{true};
};
100
101//===----------------------------------------------------------------------===//
102// Dialect Numbering
103//===----------------------------------------------------------------------===//
104
105/// This class represents a numbering entry for an Dialect.
106struct DialectNumbering {
107 DialectNumbering(StringRef name, unsigned number)
108 : name(name), number(number) {}
109
110 /// The namespace of the dialect.
111 StringRef name;
112
113 /// The number assigned to the dialect.
114 unsigned number;
115
116 /// The bytecode dialect interface of the dialect if defined.
117 const BytecodeDialectInterface *interface = nullptr;
118
119 /// The asm dialect interface of the dialect if defined.
120 const OpAsmDialectInterface *asmInterface = nullptr;
121
122 /// The referenced resources of this dialect.
123 SetVector<AsmDialectResourceHandle> resources;
124
125 /// A mapping from resource key to the corresponding resource numbering entry.
126 llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap;
127};
128
129//===----------------------------------------------------------------------===//
130// Operation Numbering
131//===----------------------------------------------------------------------===//
132
/// Numbering entry for a single operation.
struct OperationNumbering {
  OperationNumbering(unsigned opNumber) : number(opNumber) {}

  /// The number assigned to this operation.
  unsigned number;

  /// Whether this operation's regions are isolated from above. An empty
  /// optional means isolation is not yet known.
  std::optional<bool> isIsolatedFromAbove;
};
144
145//===----------------------------------------------------------------------===//
146// IRNumberingState
147//===----------------------------------------------------------------------===//
148
149/// This class manages numbering IR entities in preparation of bytecode
150/// emission.
151class IRNumberingState {
152public:
153 IRNumberingState(Operation *op, const BytecodeWriterConfig &config);
154
155 /// Return the numbered dialects.
156 auto getDialects() {
157 return llvm::make_pointee_range(Range: llvm::make_second_range(c&: dialects));
158 }
159 auto getAttributes() { return llvm::make_pointee_range(Range&: orderedAttrs); }
160 auto getOpNames() { return llvm::make_pointee_range(Range&: orderedOpNames); }
161 auto getTypes() { return llvm::make_pointee_range(Range&: orderedTypes); }
162
163 /// Return the number for the given IR unit.
164 unsigned getNumber(Attribute attr) {
165 assert(attrs.count(attr) && "attribute not numbered");
166 return attrs[attr]->number;
167 }
168 unsigned getNumber(Block *block) {
169 assert(blockIDs.count(block) && "block not numbered");
170 return blockIDs[block];
171 }
172 unsigned getNumber(Operation *op) {
173 assert(operations.count(op) && "operation not numbered");
174 return operations[op]->number;
175 }
176 unsigned getNumber(OperationName opName) {
177 assert(opNames.count(opName) && "opName not numbered");
178 return opNames[opName]->number;
179 }
180 unsigned getNumber(Type type) {
181 assert(types.count(type) && "type not numbered");
182 return types[type]->number;
183 }
184 unsigned getNumber(Value value) {
185 assert(valueIDs.count(value) && "value not numbered");
186 return valueIDs[value];
187 }
188 unsigned getNumber(const AsmDialectResourceHandle &resource) {
189 assert(dialectResources.count(resource) && "resource not numbered");
190 return dialectResources[resource]->number;
191 }
192
193 /// Return the block and value counts of the given region.
194 std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
195 assert(regionBlockValueCounts.count(region) && "value not numbered");
196 return regionBlockValueCounts[region];
197 }
198
199 /// Return the number of operations in the given block.
200 unsigned getOperationCount(Block *block) {
201 assert(blockOperationCounts.count(block) && "block not numbered");
202 return blockOperationCounts[block];
203 }
204
205 /// Return if the given operation is isolated from above.
206 bool isIsolatedFromAbove(Operation *op) {
207 assert(operations.count(op) && "operation not numbered");
208 return operations[op]->isIsolatedFromAbove.value_or(u: false);
209 }
210
211 /// Get the set desired bytecode version to emit.
212 int64_t getDesiredBytecodeVersion() const;
213
214private:
215 /// This class is used to provide a fake dialect writer for numbering nested
216 /// attributes and types.
217 struct NumberingDialectWriter;
218
219 /// Compute the global numbering state for the given root operation.
220 void computeGlobalNumberingState(Operation *rootOp);
221
222 /// Number the given IR unit for bytecode emission.
223 void number(Attribute attr);
224 void number(Block &block);
225 DialectNumbering &numberDialect(Dialect *dialect);
226 DialectNumbering &numberDialect(StringRef dialect);
227 void number(Operation &op);
228 void number(OperationName opName);
229 void number(Region &region);
230 void number(Type type);
231
232 /// Number the given dialect resources.
233 void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources);
234
235 /// Finalize the numberings of any dialect resources.
236 void finalizeDialectResourceNumberings(Operation *rootOp);
237
238 /// Mapping from IR to the respective numbering entries.
239 DenseMap<Attribute, AttributeNumbering *> attrs;
240 DenseMap<Operation *, OperationNumbering *> operations;
241 DenseMap<OperationName, OpNameNumbering *> opNames;
242 DenseMap<Type, TypeNumbering *> types;
243 DenseMap<Dialect *, DialectNumbering *> registeredDialects;
244 llvm::MapVector<StringRef, DialectNumbering *> dialects;
245 std::vector<AttributeNumbering *> orderedAttrs;
246 std::vector<OpNameNumbering *> orderedOpNames;
247 std::vector<TypeNumbering *> orderedTypes;
248
249 /// A mapping from dialect resource handle to the numbering for the referenced
250 /// resource.
251 llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *>
252 dialectResources;
253
254 /// Allocators used for the various numbering entries.
255 llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
256 llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
257 llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator;
258 llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
259 llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
260 llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
261
262 /// The value ID for each Block and Value.
263 DenseMap<Block *, unsigned> blockIDs;
264 DenseMap<Value, unsigned> valueIDs;
265
266 /// The number of operations in each block.
267 DenseMap<Block *, unsigned> blockOperationCounts;
268
269 /// A map from region to the number of blocks and values within that region.
270 DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts;
271
272 /// The next value ID to assign when numbering.
273 unsigned nextValueID = 0;
274
275 // Configuration: useful to query the required version to emit.
276 const BytecodeWriterConfig &config;
277};
278} // namespace detail
279} // namespace bytecode
280} // namespace mlir
281
282#endif
283

source code of mlir/lib/Bytecode/Writer/IRNumbering.h