1 | //===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "IRNumbering.h" |
10 | #include "mlir/Bytecode/BytecodeImplementation.h" |
11 | #include "mlir/Bytecode/BytecodeOpInterface.h" |
12 | #include "mlir/Bytecode/Encoding.h" |
13 | #include "mlir/IR/AsmState.h" |
14 | #include "mlir/IR/BuiltinTypes.h" |
15 | #include "mlir/IR/OpDefinition.h" |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::bytecode::detail; |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // NumberingDialectWriter |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { |
25 | NumberingDialectWriter( |
26 | IRNumberingState &state, |
27 | llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap) |
28 | : state(state), dialectVersionMap(dialectVersionMap) {} |
29 | |
30 | void writeAttribute(Attribute attr) override { state.number(attr); } |
31 | void writeOptionalAttribute(Attribute attr) override { |
32 | if (attr) |
33 | state.number(attr); |
34 | } |
35 | void writeType(Type type) override { state.number(type); } |
36 | void writeResourceHandle(const AsmDialectResourceHandle &resource) override { |
37 | state.number(dialect: resource.getDialect(), resources: resource); |
38 | } |
39 | |
40 | /// Stubbed out methods that are not used for numbering. |
41 | void writeVarInt(uint64_t) override {} |
42 | void writeSignedVarInt(int64_t value) override {} |
43 | void writeAPIntWithKnownWidth(const APInt &value) override {} |
44 | void writeAPFloatWithKnownSemantics(const APFloat &value) override {} |
45 | void writeOwnedString(StringRef) override { |
46 | // TODO: It might be nice to prenumber strings and sort by the number of |
47 | // references. This could potentially be useful for optimizing things like |
48 | // file locations. |
49 | } |
50 | void writeOwnedBlob(ArrayRef<char> blob) override {} |
51 | void writeOwnedBool(bool value) override {} |
52 | |
53 | int64_t getBytecodeVersion() const override { |
54 | return state.getDesiredBytecodeVersion(); |
55 | } |
56 | |
57 | FailureOr<const DialectVersion *> |
58 | getDialectVersion(StringRef dialectName) const override { |
59 | auto dialectEntry = dialectVersionMap.find(Key: dialectName); |
60 | if (dialectEntry == dialectVersionMap.end()) |
61 | return failure(); |
62 | return dialectEntry->getValue().get(); |
63 | } |
64 | |
65 | /// The parent numbering state that is populated by this writer. |
66 | IRNumberingState &state; |
67 | |
68 | /// A map containing dialect version information for each dialect to emit. |
69 | llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap; |
70 | }; |
71 | |
72 | //===----------------------------------------------------------------------===// |
73 | // IR Numbering |
74 | //===----------------------------------------------------------------------===// |
75 | |
76 | /// Group and sort the elements of the given range by their parent dialect. This |
77 | /// grouping is applied to sub-sections of the ranged defined by how many bytes |
78 | /// it takes to encode a varint index to that sub-section. |
79 | template <typename T> |
80 | static void groupByDialectPerByte(T range) { |
81 | if (range.empty()) |
82 | return; |
83 | |
84 | // A functor used to sort by a given dialect, with a desired dialect to be |
85 | // ordered first (to better enable sharing of dialects across byte groups). |
86 | auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs, |
87 | const auto &rhs) { |
88 | if (lhs->dialect->number == dialectToOrderFirst) |
89 | return rhs->dialect->number != dialectToOrderFirst; |
90 | if (rhs->dialect->number == dialectToOrderFirst) |
91 | return false; |
92 | return lhs->dialect->number < rhs->dialect->number; |
93 | }; |
94 | |
95 | unsigned dialectToOrderFirst = 0; |
96 | size_t elementsInByteGroup = 0; |
97 | auto iterRange = range; |
98 | for (unsigned i = 1; i < 9; ++i) { |
99 | // Update the number of elements in the current byte grouping. Reminder |
100 | // that varint encodes 7-bits per byte, so that's how we compute the |
101 | // number of elements in each byte grouping. |
102 | elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup; |
103 | |
104 | // Slice out the sub-set of elements that are in the current byte grouping |
105 | // to be sorted. |
106 | auto byteSubRange = iterRange.take_front(elementsInByteGroup); |
107 | iterRange = iterRange.drop_front(byteSubRange.size()); |
108 | |
109 | // Sort the sub range for this byte. |
110 | llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) { |
111 | return sortByDialect(dialectToOrderFirst, lhs, rhs); |
112 | }); |
113 | |
114 | // Update the dialect to order first to be the dialect at the end of the |
115 | // current grouping. This seeks to allow larger dialect groupings across |
116 | // byte boundaries. |
117 | dialectToOrderFirst = byteSubRange.back()->dialect->number; |
118 | |
119 | // If the data range is now empty, we are done. |
120 | if (iterRange.empty()) |
121 | break; |
122 | } |
123 | |
124 | // Assign the entry numbers based on the sort order. |
125 | for (auto [idx, value] : llvm::enumerate(range)) |
126 | value->number = idx; |
127 | } |
128 | |
129 | IRNumberingState::IRNumberingState(Operation *op, |
130 | const BytecodeWriterConfig &config) |
131 | : config(config) { |
132 | computeGlobalNumberingState(rootOp: op); |
133 | |
134 | // Number the root operation. |
135 | number(op&: *op); |
136 | |
137 | // A worklist of region contexts to number and the next value id before that |
138 | // region. |
139 | SmallVector<std::pair<Region *, unsigned>, 8> numberContext; |
140 | |
141 | // Functor to push the regions of the given operation onto the numbering |
142 | // context. |
143 | auto addOpRegionsToNumber = [&](Operation *op) { |
144 | MutableArrayRef<Region> regions = op->getRegions(); |
145 | if (regions.empty()) |
146 | return; |
147 | |
148 | // Isolated regions don't share value numbers with their parent, so we can |
149 | // start numbering these regions at zero. |
150 | unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID; |
151 | for (Region ®ion : regions) |
152 | numberContext.emplace_back(Args: ®ion, Args&: opFirstValueID); |
153 | }; |
154 | addOpRegionsToNumber(op); |
155 | |
156 | // Iteratively process each of the nested regions. |
157 | while (!numberContext.empty()) { |
158 | Region *region; |
159 | std::tie(args&: region, args&: nextValueID) = numberContext.pop_back_val(); |
160 | number(region&: *region); |
161 | |
162 | // Traverse into nested regions. |
163 | for (Operation &op : region->getOps()) |
164 | addOpRegionsToNumber(&op); |
165 | } |
166 | |
167 | // Number each of the dialects. For now this is just in the order they were |
168 | // found, given that the number of dialects on average is small enough to fit |
169 | // within a singly byte (128). If we ever have real world use cases that have |
170 | // a huge number of dialects, this could be made more intelligent. |
171 | for (auto [idx, dialect] : llvm::enumerate(First&: dialects)) |
172 | dialect.second->number = idx; |
173 | |
174 | // Number each of the recorded components within each dialect. |
175 | |
176 | // First sort by ref count so that the most referenced elements are first. We |
177 | // try to bias more heavily used elements to the front. This allows for more |
178 | // frequently referenced things to be encoded using smaller varints. |
179 | auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) { |
180 | return lhs->refCount > rhs->refCount; |
181 | }; |
182 | llvm::stable_sort(Range&: orderedAttrs, C: sortByRefCountFn); |
183 | llvm::stable_sort(Range&: orderedOpNames, C: sortByRefCountFn); |
184 | llvm::stable_sort(Range&: orderedTypes, C: sortByRefCountFn); |
185 | |
186 | // After that, we apply a secondary ordering based on the parent dialect. This |
187 | // ordering is applied to sub-sections of the element list defined by how many |
188 | // bytes it takes to encode a varint index to that sub-section. This allows |
189 | // for more efficiently encoding components of the same dialect (e.g. we only |
190 | // have to encode the dialect reference once). |
191 | groupByDialectPerByte(range: llvm::MutableArrayRef(orderedAttrs)); |
192 | groupByDialectPerByte(range: llvm::MutableArrayRef(orderedOpNames)); |
193 | groupByDialectPerByte(range: llvm::MutableArrayRef(orderedTypes)); |
194 | |
195 | // Finalize the numbering of the dialect resources. |
196 | finalizeDialectResourceNumberings(rootOp: op); |
197 | } |
198 | |
199 | void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) { |
200 | // A simple state struct tracking data used when walking operations. |
201 | struct StackState { |
202 | /// The operation currently being walked. |
203 | Operation *op; |
204 | |
205 | /// The numbering of the operation. |
206 | OperationNumbering *numbering; |
207 | |
208 | /// A flag indicating if the current state or one of its parents has |
209 | /// unresolved isolation status. This is tracked separately from the |
210 | /// isIsolatedFromAbove bit on `numbering` because we need to be able to |
211 | /// handle the given case: |
212 | /// top.op { |
213 | /// %value = ... |
214 | /// middle.op { |
215 | /// %value2 = ... |
216 | /// inner.op { |
217 | /// // Here we mark `inner.op` as not isolated. Note `middle.op` |
218 | /// // isn't known not isolated yet. |
219 | /// use.op %value2 |
220 | /// |
221 | /// // Here inner.op is already known to be non-isolated, but |
222 | /// // `middle.op` is now also discovered to be non-isolated. |
223 | /// use.op %value |
224 | /// } |
225 | /// } |
226 | /// } |
227 | bool hasUnresolvedIsolation; |
228 | }; |
229 | |
230 | // Compute a global operation ID numbering according to the pre-order walk of |
231 | // the IR. This is used as reference to construct use-list orders. |
232 | unsigned operationID = 0; |
233 | |
234 | // Walk each of the operations within the IR, tracking a stack of operations |
235 | // as we recurse into nested regions. This walk method hooks in at two stages |
236 | // during the walk: |
237 | // |
238 | // BeforeAllRegions: |
239 | // Here we generate a numbering for the operation and push it onto the |
240 | // stack if it has regions. We also compute the isolation status of parent |
241 | // regions at this stage. This is done by checking the parent regions of |
242 | // operands used by the operation, and marking each region between the |
243 | // the operand region and the current as not isolated. See |
244 | // StackState::hasUnresolvedIsolation above for an example. |
245 | // |
246 | // AfterAllRegions: |
247 | // Here we pop the operation from the stack, and if it hasn't been marked |
248 | // as non-isolated, we mark it as so. A non-isolated use would have been |
249 | // found while walking the regions, so it is safe to mark the operation at |
250 | // this point. |
251 | // |
252 | SmallVector<StackState> opStack; |
253 | rootOp->walk(callback: [&](Operation *op, const WalkStage &stage) { |
254 | // After visiting all nested regions, we pop the operation from the stack. |
255 | if (op->getNumRegions() && stage.isAfterAllRegions()) { |
256 | // If no non-isolated uses were found, we can safely mark this operation |
257 | // as isolated from above. |
258 | OperationNumbering *numbering = opStack.pop_back_val().numbering; |
259 | if (!numbering->isIsolatedFromAbove.has_value()) |
260 | numbering->isIsolatedFromAbove = true; |
261 | return; |
262 | } |
263 | |
264 | // When visiting before nested regions, we process "IsolatedFromAbove" |
265 | // checks and compute the number for this operation. |
266 | if (!stage.isBeforeAllRegions()) |
267 | return; |
268 | // Update the isolation status of parent regions if any have yet to be |
269 | // resolved. |
270 | if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) { |
271 | Region *parentRegion = op->getParentRegion(); |
272 | for (Value operand : op->getOperands()) { |
273 | Region *operandRegion = operand.getParentRegion(); |
274 | if (operandRegion == parentRegion) |
275 | continue; |
276 | // We've found a use of an operand outside of the current region, |
277 | // walk the operation stack searching for the parent operation, |
278 | // marking every region on the way as not isolated. |
279 | Operation *operandContainerOp = operandRegion->getParentOp(); |
280 | auto it = std::find_if( |
281 | first: opStack.rbegin(), last: opStack.rend(), pred: [=](const StackState &it) { |
282 | // We only need to mark up to the container region, or the first |
283 | // that has an unresolved status. |
284 | return !it.hasUnresolvedIsolation || it.op == operandContainerOp; |
285 | }); |
286 | assert(it != opStack.rend() && "expected to find the container" ); |
287 | for (auto &state : llvm::make_range(x: opStack.rbegin(), y: it)) { |
288 | // If we stopped at a region that knows its isolation status, we can |
289 | // stop updating the isolation status for the parent regions. |
290 | state.hasUnresolvedIsolation = it->hasUnresolvedIsolation; |
291 | state.numbering->isIsolatedFromAbove = false; |
292 | } |
293 | } |
294 | } |
295 | |
296 | // Compute the number for this op and push it onto the stack. |
297 | auto *numbering = |
298 | new (opAllocator.Allocate()) OperationNumbering(operationID++); |
299 | if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
300 | numbering->isIsolatedFromAbove = true; |
301 | operations.try_emplace(Key: op, Args&: numbering); |
302 | if (op->getNumRegions()) { |
303 | opStack.emplace_back(Args: StackState{ |
304 | .op: op, .numbering: numbering, .hasUnresolvedIsolation: !numbering->isIsolatedFromAbove.has_value()}); |
305 | } |
306 | }); |
307 | } |
308 | |
309 | void IRNumberingState::number(Attribute attr) { |
310 | auto it = attrs.insert(KV: {attr, nullptr}); |
311 | if (!it.second) { |
312 | ++it.first->second->refCount; |
313 | return; |
314 | } |
315 | auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr); |
316 | it.first->second = numbering; |
317 | orderedAttrs.push_back(x: numbering); |
318 | |
319 | // Check for OpaqueAttr, which is a dialect-specific attribute that didn't |
320 | // have a registered dialect when it got created. We don't want to encode this |
321 | // as the builtin OpaqueAttr, we want to encode it as if the dialect was |
322 | // actually loaded. |
323 | if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) { |
324 | numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace()); |
325 | return; |
326 | } |
327 | numbering->dialect = &numberDialect(dialect: &attr.getDialect()); |
328 | |
329 | // If this attribute will be emitted using the bytecode format, perform a |
330 | // dummy writing to number any nested components. |
331 | // TODO: We don't allow custom encodings for mutable attributes right now. |
332 | if (!attr.hasTrait<AttributeTrait::IsMutable>()) { |
333 | // Try overriding emission with callbacks. |
334 | for (const auto &callback : config.getAttributeWriterCallbacks()) { |
335 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
336 | // The client has the ability to override the group name through the |
337 | // callback. |
338 | std::optional<StringRef> groupNameOverride; |
339 | if (succeeded(result: callback->write(entry: attr, name&: groupNameOverride, writer))) { |
340 | if (groupNameOverride.has_value()) |
341 | numbering->dialect = &numberDialect(dialect: *groupNameOverride); |
342 | return; |
343 | } |
344 | } |
345 | |
346 | if (const auto *interface = numbering->dialect->interface) { |
347 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
348 | if (succeeded(result: interface->writeAttribute(attr, writer))) |
349 | return; |
350 | } |
351 | } |
352 | // If this attribute will be emitted using the fallback, number the nested |
353 | // dialect resources. We don't number everything (e.g. no nested |
354 | // attributes/types), because we don't want to encode things we won't decode |
355 | // (the textual format can't really share much). |
356 | AsmState tempState(attr.getContext()); |
357 | llvm::raw_null_ostream dummyOS; |
358 | attr.print(os&: dummyOS, state&: tempState); |
359 | |
360 | // Number the used dialect resources. |
361 | for (const auto &it : tempState.getDialectResources()) |
362 | number(dialect: it.getFirst(), resources: it.getSecond().getArrayRef()); |
363 | } |
364 | |
365 | void IRNumberingState::number(Block &block) { |
366 | // Number the arguments of the block. |
367 | for (BlockArgument arg : block.getArguments()) { |
368 | valueIDs.try_emplace(Key: arg, Args: nextValueID++); |
369 | number(attr: arg.getLoc()); |
370 | number(type: arg.getType()); |
371 | } |
372 | |
373 | // Number the operations in this block. |
374 | unsigned &numOps = blockOperationCounts[&block]; |
375 | for (Operation &op : block) { |
376 | number(op); |
377 | ++numOps; |
378 | } |
379 | } |
380 | |
381 | auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & { |
382 | DialectNumbering *&numbering = registeredDialects[dialect]; |
383 | if (!numbering) { |
384 | numbering = &numberDialect(dialect: dialect->getNamespace()); |
385 | numbering->interface = dyn_cast<BytecodeDialectInterface>(Val: dialect); |
386 | numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(Val: dialect); |
387 | } |
388 | return *numbering; |
389 | } |
390 | |
391 | auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & { |
392 | DialectNumbering *&numbering = dialects[dialect]; |
393 | if (!numbering) { |
394 | numbering = new (dialectAllocator.Allocate()) |
395 | DialectNumbering(dialect, dialects.size() - 1); |
396 | } |
397 | return *numbering; |
398 | } |
399 | |
400 | void IRNumberingState::number(Region ®ion) { |
401 | if (region.empty()) |
402 | return; |
403 | size_t firstValueID = nextValueID; |
404 | |
405 | // Number the blocks within this region. |
406 | size_t blockCount = 0; |
407 | for (auto it : llvm::enumerate(First&: region)) { |
408 | blockIDs.try_emplace(Key: &it.value(), Args: it.index()); |
409 | number(block&: it.value()); |
410 | ++blockCount; |
411 | } |
412 | |
413 | // Remember the number of blocks and values in this region. |
414 | regionBlockValueCounts.try_emplace(Key: ®ion, Args&: blockCount, |
415 | Args: nextValueID - firstValueID); |
416 | } |
417 | |
418 | void IRNumberingState::number(Operation &op) { |
419 | // Number the components of an operation that won't be numbered elsewhere |
420 | // (e.g. we don't number operands, regions, or successors here). |
421 | number(opName: op.getName()); |
422 | for (OpResult result : op.getResults()) { |
423 | valueIDs.try_emplace(Key: result, Args: nextValueID++); |
424 | number(type: result.getType()); |
425 | } |
426 | |
427 | // Prior to a version with native property encoding, or when properties are |
428 | // not used, we need to number also the merged dictionary containing both the |
429 | // inherent and discardable attribute. |
430 | DictionaryAttr dictAttr; |
431 | if (config.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding) |
432 | dictAttr = op.getRawDictionaryAttrs(); |
433 | else |
434 | dictAttr = op.getAttrDictionary(); |
435 | // Only number the operation's dictionary if it isn't empty. |
436 | if (!dictAttr.empty()) |
437 | number(dictAttr); |
438 | |
439 | // Visit the operation properties (if any) to make sure referenced attributes |
440 | // are numbered. |
441 | if (config.getDesiredBytecodeVersion() >= |
442 | bytecode::kNativePropertiesEncoding && |
443 | op.getPropertiesStorageSize()) { |
444 | if (op.isRegistered()) { |
445 | // Operation that have properties *must* implement this interface. |
446 | auto iface = cast<BytecodeOpInterface>(op); |
447 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
448 | iface.writeProperties(writer); |
449 | } else { |
450 | // Unregistered op are storing properties as an optional attribute. |
451 | if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>()) |
452 | number(attr: prop); |
453 | } |
454 | } |
455 | |
456 | number(attr: op.getLoc()); |
457 | } |
458 | |
459 | void IRNumberingState::number(OperationName opName) { |
460 | OpNameNumbering *&numbering = opNames[opName]; |
461 | if (numbering) { |
462 | ++numbering->refCount; |
463 | return; |
464 | } |
465 | DialectNumbering *dialectNumber = nullptr; |
466 | if (Dialect *dialect = opName.getDialect()) |
467 | dialectNumber = &numberDialect(dialect); |
468 | else |
469 | dialectNumber = &numberDialect(dialect: opName.getDialectNamespace()); |
470 | |
471 | numbering = |
472 | new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName); |
473 | orderedOpNames.push_back(x: numbering); |
474 | } |
475 | |
476 | void IRNumberingState::number(Type type) { |
477 | auto it = types.insert(KV: {type, nullptr}); |
478 | if (!it.second) { |
479 | ++it.first->second->refCount; |
480 | return; |
481 | } |
482 | auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type); |
483 | it.first->second = numbering; |
484 | orderedTypes.push_back(x: numbering); |
485 | |
486 | // Check for OpaqueType, which is a dialect-specific type that didn't have a |
487 | // registered dialect when it got created. We don't want to encode this as the |
488 | // builtin OpaqueType, we want to encode it as if the dialect was actually |
489 | // loaded. |
490 | if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) { |
491 | numbering->dialect = &numberDialect(opaqueType.getDialectNamespace()); |
492 | return; |
493 | } |
494 | numbering->dialect = &numberDialect(dialect: &type.getDialect()); |
495 | |
496 | // If this type will be emitted using the bytecode format, perform a dummy |
497 | // writing to number any nested components. |
498 | // TODO: We don't allow custom encodings for mutable types right now. |
499 | if (!type.hasTrait<TypeTrait::IsMutable>()) { |
500 | // Try overriding emission with callbacks. |
501 | for (const auto &callback : config.getTypeWriterCallbacks()) { |
502 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
503 | // The client has the ability to override the group name through the |
504 | // callback. |
505 | std::optional<StringRef> groupNameOverride; |
506 | if (succeeded(result: callback->write(entry: type, name&: groupNameOverride, writer))) { |
507 | if (groupNameOverride.has_value()) |
508 | numbering->dialect = &numberDialect(dialect: *groupNameOverride); |
509 | return; |
510 | } |
511 | } |
512 | |
513 | // If this attribute will be emitted using the bytecode format, perform a |
514 | // dummy writing to number any nested components. |
515 | if (const auto *interface = numbering->dialect->interface) { |
516 | NumberingDialectWriter writer(*this, config.getDialectVersionMap()); |
517 | if (succeeded(result: interface->writeType(type, writer))) |
518 | return; |
519 | } |
520 | } |
521 | // If this type will be emitted using the fallback, number the nested dialect |
522 | // resources. We don't number everything (e.g. no nested attributes/types), |
523 | // because we don't want to encode things we won't decode (the textual format |
524 | // can't really share much). |
525 | AsmState tempState(type.getContext()); |
526 | llvm::raw_null_ostream dummyOS; |
527 | type.print(os&: dummyOS, state&: tempState); |
528 | |
529 | // Number the used dialect resources. |
530 | for (const auto &it : tempState.getDialectResources()) |
531 | number(dialect: it.getFirst(), resources: it.getSecond().getArrayRef()); |
532 | } |
533 | |
534 | void IRNumberingState::number(Dialect *dialect, |
535 | ArrayRef<AsmDialectResourceHandle> resources) { |
536 | DialectNumbering &dialectNumber = numberDialect(dialect); |
537 | assert( |
538 | dialectNumber.asmInterface && |
539 | "expected dialect owning a resource to implement OpAsmDialectInterface" ); |
540 | |
541 | for (const auto &resource : resources) { |
542 | // Check if this is a newly seen resource. |
543 | if (!dialectNumber.resources.insert(X: resource)) |
544 | return; |
545 | |
546 | auto *numbering = |
547 | new (resourceAllocator.Allocate()) DialectResourceNumbering( |
548 | dialectNumber.asmInterface->getResourceKey(handle: resource)); |
549 | dialectNumber.resourceMap.insert(KV: {numbering->key, numbering}); |
550 | dialectResources.try_emplace(Key: resource, Args&: numbering); |
551 | } |
552 | } |
553 | |
554 | int64_t IRNumberingState::getDesiredBytecodeVersion() const { |
555 | return config.getDesiredBytecodeVersion(); |
556 | } |
557 | |
558 | namespace { |
559 | /// A dummy resource builder used to number dialect resources. |
560 | struct NumberingResourceBuilder : public AsmResourceBuilder { |
561 | NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID) |
562 | : dialect(dialect), nextResourceID(nextResourceID) {} |
563 | ~NumberingResourceBuilder() override = default; |
564 | |
565 | void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final { |
566 | numberEntry(key); |
567 | } |
568 | void buildBool(StringRef key, bool) final { numberEntry(key); } |
569 | void buildString(StringRef key, StringRef) final { |
570 | // TODO: We could pre-number the value string here as well. |
571 | numberEntry(key); |
572 | } |
573 | |
574 | /// Number the dialect entry for the given key. |
575 | void numberEntry(StringRef key) { |
576 | // TODO: We could pre-number resource key strings here as well. |
577 | |
578 | auto *it = dialect->resourceMap.find(Key: key); |
579 | if (it != dialect->resourceMap.end()) { |
580 | it->second->number = nextResourceID++; |
581 | it->second->isDeclaration = false; |
582 | } |
583 | } |
584 | |
585 | DialectNumbering *dialect; |
586 | unsigned &nextResourceID; |
587 | }; |
588 | } // namespace |
589 | |
590 | void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) { |
591 | unsigned nextResourceID = 0; |
592 | for (DialectNumbering &dialect : getDialects()) { |
593 | if (!dialect.asmInterface) |
594 | continue; |
595 | NumberingResourceBuilder entryBuilder(&dialect, nextResourceID); |
596 | dialect.asmInterface->buildResources(op: rootOp, referencedResources: dialect.resources, |
597 | builder&: entryBuilder); |
598 | |
599 | // Number any resources that weren't added by the dialect. This can happen |
600 | // if there was no backing data to the resource, but we still want these |
601 | // resource references to roundtrip, so we number them and indicate that the |
602 | // data is missing. |
603 | for (const auto &it : dialect.resourceMap) |
604 | if (it.second->isDeclaration) |
605 | it.second->number = nextResourceID++; |
606 | } |
607 | } |
608 | |