| 1 | //===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file contains various utilities that number IR structures in preparation |
| 10 | // for bytecode emission. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H |
| 15 | #define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H |
| 16 | |
| 17 | #include "mlir/IR/OpImplementation.h" |
| 18 | #include "llvm/ADT/MapVector.h" |
| 19 | #include "llvm/ADT/SetVector.h" |
| 20 | #include "llvm/ADT/StringMap.h" |
| 21 | #include <cstdint> |
| 22 | |
| 23 | namespace mlir { |
| 24 | class BytecodeDialectInterface; |
| 25 | class BytecodeWriterConfig; |
| 26 | |
| 27 | namespace bytecode { |
| 28 | namespace detail { |
| 29 | struct DialectNumbering; |
| 30 | |
| 31 | //===----------------------------------------------------------------------===// |
| 32 | // Attribute and Type Numbering |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | |
| 35 | /// This class represents a numbering entry for an Attribute or Type. |
| 36 | struct AttrTypeNumbering { |
| 37 | AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {} |
| 38 | |
| 39 | /// The concrete value. |
| 40 | PointerUnion<Attribute, Type> value; |
| 41 | |
| 42 | /// The number assigned to this value. |
| 43 | unsigned number = 0; |
| 44 | |
| 45 | /// The number of references to this value. |
| 46 | unsigned refCount = 1; |
| 47 | |
| 48 | /// The dialect of this value. |
| 49 | DialectNumbering *dialect = nullptr; |
| 50 | }; |
| 51 | struct AttributeNumbering : public AttrTypeNumbering { |
| 52 | AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {} |
| 53 | Attribute getValue() const { return cast<Attribute>(Val: value); } |
| 54 | }; |
| 55 | struct TypeNumbering : public AttrTypeNumbering { |
| 56 | TypeNumbering(Type value) : AttrTypeNumbering(value) {} |
| 57 | Type getValue() const { return cast<Type>(Val: value); } |
| 58 | }; |
| 59 | |
| 60 | //===----------------------------------------------------------------------===// |
| 61 | // OpName Numbering |
| 62 | //===----------------------------------------------------------------------===// |
| 63 | |
| 64 | /// This class represents the numbering entry of an operation name. |
| 65 | struct { |
| 66 | (DialectNumbering *dialect, OperationName name) |
| 67 | : dialect(dialect), name(name) {} |
| 68 | |
| 69 | /// The dialect of this value. |
| 70 | DialectNumbering *; |
| 71 | |
| 72 | /// The concrete name. |
| 73 | OperationName ; |
| 74 | |
| 75 | /// The number assigned to this name. |
| 76 | unsigned = 0; |
| 77 | |
| 78 | /// The number of references to this name. |
| 79 | unsigned = 1; |
| 80 | }; |
| 81 | |
| 82 | //===----------------------------------------------------------------------===// |
| 83 | // Dialect Resource Numbering |
| 84 | //===----------------------------------------------------------------------===// |
| 85 | |
/// Numbering entry for a single dialect resource.
///
/// An entry starts out numbered zero and flagged as a declaration only; the
/// flag records whether the resource is merely declared rather than fully
/// defined.
struct DialectResourceNumbering {
  DialectResourceNumbering(std::string resourceKey)
      : key(std::move(resourceKey)) {}

  /// Key through which the resource is referenced.
  std::string key;

  /// Number assigned to the resource.
  unsigned number{0};

  /// True while the resource is only declared, not fully defined.
  bool isDeclaration{true};
};
| 100 | |
| 101 | //===----------------------------------------------------------------------===// |
| 102 | // Dialect Numbering |
| 103 | //===----------------------------------------------------------------------===// |
| 104 | |
| 105 | /// This class represents a numbering entry for an Dialect. |
| 106 | struct DialectNumbering { |
| 107 | DialectNumbering(StringRef name, unsigned number) |
| 108 | : name(name), number(number) {} |
| 109 | |
| 110 | /// The namespace of the dialect. |
| 111 | StringRef name; |
| 112 | |
| 113 | /// The number assigned to the dialect. |
| 114 | unsigned number; |
| 115 | |
| 116 | /// The bytecode dialect interface of the dialect if defined. |
| 117 | const BytecodeDialectInterface *interface = nullptr; |
| 118 | |
| 119 | /// The asm dialect interface of the dialect if defined. |
| 120 | const OpAsmDialectInterface *asmInterface = nullptr; |
| 121 | |
| 122 | /// The referenced resources of this dialect. |
| 123 | SetVector<AsmDialectResourceHandle> resources; |
| 124 | |
| 125 | /// A mapping from resource key to the corresponding resource numbering entry. |
| 126 | llvm::MapVector<StringRef, DialectResourceNumbering *> resourceMap; |
| 127 | }; |
| 128 | |
| 129 | //===----------------------------------------------------------------------===// |
| 130 | // Operation Numbering |
| 131 | //===----------------------------------------------------------------------===// |
| 132 | |
/// Numbering entry for an operation.
struct OperationNumbering {
  OperationNumbering(unsigned id) : number{id} {}

  /// Number assigned to the operation.
  unsigned number;

  /// Whether the operation's regions are isolated from above. An empty
  /// optional means isolation has not been established yet.
  std::optional<bool> isIsolatedFromAbove;
};
| 144 | |
| 145 | //===----------------------------------------------------------------------===// |
| 146 | // IRNumberingState |
| 147 | //===----------------------------------------------------------------------===// |
| 148 | |
| 149 | /// This class manages numbering IR entities in preparation of bytecode |
| 150 | /// emission. |
| 151 | class IRNumberingState { |
| 152 | public: |
| 153 | IRNumberingState(Operation *op, const BytecodeWriterConfig &config); |
| 154 | |
| 155 | /// Return the numbered dialects. |
| 156 | auto getDialects() { |
| 157 | return llvm::make_pointee_range(Range: llvm::make_second_range(c&: dialects)); |
| 158 | } |
| 159 | auto getAttributes() { return llvm::make_pointee_range(Range&: orderedAttrs); } |
| 160 | auto getOpNames() { return llvm::make_pointee_range(Range&: orderedOpNames); } |
| 161 | auto getTypes() { return llvm::make_pointee_range(Range&: orderedTypes); } |
| 162 | |
| 163 | /// Return the number for the given IR unit. |
| 164 | unsigned getNumber(Attribute attr) { |
| 165 | assert(attrs.count(attr) && "attribute not numbered" ); |
| 166 | return attrs[attr]->number; |
| 167 | } |
| 168 | unsigned getNumber(Block *block) { |
| 169 | assert(blockIDs.count(block) && "block not numbered" ); |
| 170 | return blockIDs[block]; |
| 171 | } |
| 172 | unsigned getNumber(Operation *op) { |
| 173 | assert(operations.count(op) && "operation not numbered" ); |
| 174 | return operations[op]->number; |
| 175 | } |
| 176 | unsigned getNumber(OperationName opName) { |
| 177 | assert(opNames.count(opName) && "opName not numbered" ); |
| 178 | return opNames[opName]->number; |
| 179 | } |
| 180 | unsigned getNumber(Type type) { |
| 181 | assert(types.count(type) && "type not numbered" ); |
| 182 | return types[type]->number; |
| 183 | } |
| 184 | unsigned getNumber(Value value) { |
| 185 | assert(valueIDs.count(value) && "value not numbered" ); |
| 186 | return valueIDs[value]; |
| 187 | } |
| 188 | unsigned getNumber(const AsmDialectResourceHandle &resource) { |
| 189 | assert(dialectResources.count(resource) && "resource not numbered" ); |
| 190 | return dialectResources[resource]->number; |
| 191 | } |
| 192 | |
| 193 | /// Return the block and value counts of the given region. |
| 194 | std::pair<unsigned, unsigned> getBlockValueCount(Region *region) { |
| 195 | assert(regionBlockValueCounts.count(region) && "value not numbered" ); |
| 196 | return regionBlockValueCounts[region]; |
| 197 | } |
| 198 | |
| 199 | /// Return the number of operations in the given block. |
| 200 | unsigned getOperationCount(Block *block) { |
| 201 | assert(blockOperationCounts.count(block) && "block not numbered" ); |
| 202 | return blockOperationCounts[block]; |
| 203 | } |
| 204 | |
| 205 | /// Return if the given operation is isolated from above. |
| 206 | bool isIsolatedFromAbove(Operation *op) { |
| 207 | assert(operations.count(op) && "operation not numbered" ); |
| 208 | return operations[op]->isIsolatedFromAbove.value_or(u: false); |
| 209 | } |
| 210 | |
| 211 | /// Get the set desired bytecode version to emit. |
| 212 | int64_t getDesiredBytecodeVersion() const; |
| 213 | |
| 214 | private: |
| 215 | /// This class is used to provide a fake dialect writer for numbering nested |
| 216 | /// attributes and types. |
| 217 | struct NumberingDialectWriter; |
| 218 | |
| 219 | /// Compute the global numbering state for the given root operation. |
| 220 | void computeGlobalNumberingState(Operation *rootOp); |
| 221 | |
| 222 | /// Number the given IR unit for bytecode emission. |
| 223 | void number(Attribute attr); |
| 224 | void number(Block &block); |
| 225 | DialectNumbering &numberDialect(Dialect *dialect); |
| 226 | DialectNumbering &numberDialect(StringRef dialect); |
| 227 | void number(Operation &op); |
| 228 | void number(OperationName opName); |
| 229 | void number(Region ®ion); |
| 230 | void number(Type type); |
| 231 | |
| 232 | /// Number the given dialect resources. |
| 233 | void number(Dialect *dialect, ArrayRef<AsmDialectResourceHandle> resources); |
| 234 | |
| 235 | /// Finalize the numberings of any dialect resources. |
| 236 | void finalizeDialectResourceNumberings(Operation *rootOp); |
| 237 | |
| 238 | /// Mapping from IR to the respective numbering entries. |
| 239 | DenseMap<Attribute, AttributeNumbering *> attrs; |
| 240 | DenseMap<Operation *, OperationNumbering *> operations; |
| 241 | DenseMap<OperationName, OpNameNumbering *> opNames; |
| 242 | DenseMap<Type, TypeNumbering *> types; |
| 243 | DenseMap<Dialect *, DialectNumbering *> registeredDialects; |
| 244 | llvm::MapVector<StringRef, DialectNumbering *> dialects; |
| 245 | std::vector<AttributeNumbering *> orderedAttrs; |
| 246 | std::vector<OpNameNumbering *> orderedOpNames; |
| 247 | std::vector<TypeNumbering *> orderedTypes; |
| 248 | |
| 249 | /// A mapping from dialect resource handle to the numbering for the referenced |
| 250 | /// resource. |
| 251 | llvm::DenseMap<AsmDialectResourceHandle, DialectResourceNumbering *> |
| 252 | dialectResources; |
| 253 | |
| 254 | /// Allocators used for the various numbering entries. |
| 255 | llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator; |
| 256 | llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator; |
| 257 | llvm::SpecificBumpPtrAllocator<OperationNumbering> opAllocator; |
| 258 | llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator; |
| 259 | llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator; |
| 260 | llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator; |
| 261 | |
| 262 | /// The value ID for each Block and Value. |
| 263 | DenseMap<Block *, unsigned> blockIDs; |
| 264 | DenseMap<Value, unsigned> valueIDs; |
| 265 | |
| 266 | /// The number of operations in each block. |
| 267 | DenseMap<Block *, unsigned> blockOperationCounts; |
| 268 | |
| 269 | /// A map from region to the number of blocks and values within that region. |
| 270 | DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts; |
| 271 | |
| 272 | /// The next value ID to assign when numbering. |
| 273 | unsigned nextValueID = 0; |
| 274 | |
| 275 | // Configuration: useful to query the required version to emit. |
| 276 | const BytecodeWriterConfig &config; |
| 277 | }; |
| 278 | } // namespace detail |
| 279 | } // namespace bytecode |
| 280 | } // namespace mlir |
| 281 | |
| 282 | #endif |
| 283 | |