1//===- OperationSupport.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains out-of-line implementations of the support types that
10// Operation and related classes build on top of.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/SHA1.h"
#include <numeric>
#include <optional>
21
22using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// NamedAttrList
26//===----------------------------------------------------------------------===//
27
28NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
29 assign(inStart: attributes.begin(), inEnd: attributes.end());
30}
31
32NamedAttrList::NamedAttrList(DictionaryAttr attributes)
33 : NamedAttrList(attributes ? attributes.getValue()
34 : ArrayRef<NamedAttribute>()) {
35 dictionarySorted.setPointerAndInt(PtrVal: attributes, IntVal: true);
36}
37
38NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
39 assign(inStart, inEnd);
40}
41
42ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
43
44std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
45 std::optional<NamedAttribute> duplicate =
46 DictionaryAttr::findDuplicate(array&: attrs, isSorted: isSorted());
47 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
48 // state.
49 if (!isSorted())
50 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
51 return duplicate;
52}
53
54DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
55 if (!isSorted()) {
56 DictionaryAttr::sortInPlace(array&: attrs);
57 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
58 }
59 if (!dictionarySorted.getPointer())
60 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, value: attrs));
61 return llvm::cast<DictionaryAttr>(Val: dictionarySorted.getPointer());
62}
63
64/// Replaces the attributes with new list of attributes.
65void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
66 DictionaryAttr::sort(values: ArrayRef<NamedAttribute>{inStart, inEnd}, storage&: attrs);
67 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
68}
69
70void NamedAttrList::push_back(NamedAttribute newAttribute) {
71 if (isSorted())
72 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
73 dictionarySorted.setPointer(nullptr);
74 attrs.push_back(Elt: newAttribute);
75}
76
77/// Return the specified attribute if present, null otherwise.
78Attribute NamedAttrList::get(StringRef name) const {
79 auto it = findAttr(attrs: *this, name);
80 return it.second ? it.first->getValue() : Attribute();
81}
82Attribute NamedAttrList::get(StringAttr name) const {
83 auto it = findAttr(attrs: *this, name);
84 return it.second ? it.first->getValue() : Attribute();
85}
86
87/// Return the specified named attribute if present, std::nullopt otherwise.
88std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
89 auto it = findAttr(attrs: *this, name);
90 return it.second ? *it.first : std::optional<NamedAttribute>();
91}
92std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
93 auto it = findAttr(attrs: *this, name);
94 return it.second ? *it.first : std::optional<NamedAttribute>();
95}
96
97/// If the an attribute exists with the specified name, change it to the new
98/// value. Otherwise, add a new attribute with the specified name/value.
99Attribute NamedAttrList::set(StringAttr name, Attribute value) {
100 assert(value && "attributes may never be null");
101
102 // Look for an existing attribute with the given name, and set its value
103 // in-place. Return the previous value of the attribute, if there was one.
104 auto it = findAttr(attrs&: *this, name);
105 if (it.second) {
106 // Update the existing attribute by swapping out the old value for the new
107 // value. Return the old value.
108 Attribute oldValue = it.first->getValue();
109 if (it.first->getValue() != value) {
110 it.first->setValue(value);
111
112 // If the attributes have changed, the dictionary is invalidated.
113 dictionarySorted.setPointer(nullptr);
114 }
115 return oldValue;
116 }
117 // Perform a string lookup to insert the new attribute into its sorted
118 // position.
119 if (isSorted())
120 it = findAttr(attrs&: *this, name: name.strref());
121 attrs.insert(I: it.first, Elt: {name, value});
122 // Invalidate the dictionary. Return null as there was no previous value.
123 dictionarySorted.setPointer(nullptr);
124 return Attribute();
125}
126
127Attribute NamedAttrList::set(StringRef name, Attribute value) {
128 assert(value && "attributes may never be null");
129 return set(name: mlir::StringAttr::get(context: value.getContext(), bytes: name), value);
130}
131
132Attribute
133NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
134 // Erasing does not affect the sorted property.
135 Attribute attr = it->getValue();
136 attrs.erase(CI: it);
137 dictionarySorted.setPointer(nullptr);
138 return attr;
139}
140
141Attribute NamedAttrList::erase(StringAttr name) {
142 auto it = findAttr(attrs&: *this, name);
143 return it.second ? eraseImpl(it: it.first) : Attribute();
144}
145
146Attribute NamedAttrList::erase(StringRef name) {
147 auto it = findAttr(attrs&: *this, name);
148 return it.second ? eraseImpl(it: it.first) : Attribute();
149}
150
151NamedAttrList &
152NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
153 assign(inStart: rhs.begin(), inEnd: rhs.end());
154 return *this;
155}
156
157NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
158
159//===----------------------------------------------------------------------===//
160// OperationState
161//===----------------------------------------------------------------------===//
162
163OperationState::OperationState(Location location, StringRef name)
164 : location(location), name(name, location->getContext()) {}
165
166OperationState::OperationState(Location location, OperationName name)
167 : location(location), name(name) {}
168
169OperationState::OperationState(Location location, OperationName name,
170 ValueRange operands, TypeRange types,
171 ArrayRef<NamedAttribute> attributes,
172 BlockRange successors,
173 MutableArrayRef<std::unique_ptr<Region>> regions)
174 : location(location), name(name),
175 operands(operands.begin(), operands.end()),
176 types(types.begin(), types.end()),
177 attributes(attributes.begin(), attributes.end()),
178 successors(successors.begin(), successors.end()) {
179 for (std::unique_ptr<Region> &r : regions)
180 this->regions.push_back(Elt: std::move(r));
181}
182OperationState::OperationState(Location location, StringRef name,
183 ValueRange operands, TypeRange types,
184 ArrayRef<NamedAttribute> attributes,
185 BlockRange successors,
186 MutableArrayRef<std::unique_ptr<Region>> regions)
187 : OperationState(location, OperationName(name, location.getContext()),
188 operands, types, attributes, successors, regions) {}
189
190OperationState::~OperationState() {
191 if (properties)
192 propertiesDeleter(properties);
193}
194
195LogicalResult OperationState::setProperties(
196 Operation *op, function_ref<InFlightDiagnostic()> emitError) const {
197 if (LLVM_UNLIKELY(propertiesAttr)) {
198 assert(!properties);
199 return op->setPropertiesFromAttribute(attr: propertiesAttr, emitError);
200 }
201 if (properties)
202 propertiesSetter(op->getPropertiesStorage(), properties);
203 return success();
204}
205
206void OperationState::addOperands(ValueRange newOperands) {
207 operands.append(in_start: newOperands.begin(), in_end: newOperands.end());
208}
209
210void OperationState::addSuccessors(BlockRange newSuccessors) {
211 successors.append(in_start: newSuccessors.begin(), in_end: newSuccessors.end());
212}
213
214Region *OperationState::addRegion() {
215 regions.emplace_back(Args: new Region);
216 return regions.back().get();
217}
218
219void OperationState::addRegion(std::unique_ptr<Region> &&region) {
220 regions.push_back(Elt: std::move(region));
221}
222
223void OperationState::addRegions(
224 MutableArrayRef<std::unique_ptr<Region>> regions) {
225 for (std::unique_ptr<Region> &region : regions)
226 addRegion(region: std::move(region));
227}
228
229//===----------------------------------------------------------------------===//
230// OperandStorage
231//===----------------------------------------------------------------------===//
232
233detail::OperandStorage::OperandStorage(Operation *owner,
234 OpOperand *trailingOperands,
235 ValueRange values)
236 : isStorageDynamic(false), operandStorage(trailingOperands) {
237 numOperands = capacity = values.size();
238 for (unsigned i = 0; i < numOperands; ++i)
239 new (&operandStorage[i]) OpOperand(owner, values[i]);
240}
241
242detail::OperandStorage::~OperandStorage() {
243 for (auto &operand : getOperands())
244 operand.~OpOperand();
245
246 // If the storage is dynamic, deallocate it.
247 if (isStorageDynamic)
248 free(ptr: operandStorage);
249}
250
251/// Replace the operands contained in the storage with the ones provided in
252/// 'values'.
253void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
254 MutableArrayRef<OpOperand> storageOperands = resize(owner, newSize: values.size());
255 for (unsigned i = 0, e = values.size(); i != e; ++i)
256 storageOperands[i].set(values[i]);
257}
258
259/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
260/// with the ones provided in 'operands'. 'operands' may be smaller or larger
261/// than the range pointed to by 'start'+'length'.
262void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
263 unsigned length, ValueRange operands) {
264 // If the new size is the same, we can update inplace.
265 unsigned newSize = operands.size();
266 if (newSize == length) {
267 MutableArrayRef<OpOperand> storageOperands = getOperands();
268 for (unsigned i = 0, e = length; i != e; ++i)
269 storageOperands[start + i].set(operands[i]);
270 return;
271 }
272 // If the new size is greater, remove the extra operands and set the rest
273 // inplace.
274 if (newSize < length) {
275 eraseOperands(start: start + operands.size(), length: length - newSize);
276 setOperands(owner, start, length: newSize, operands);
277 return;
278 }
279 // Otherwise, the new size is greater so we need to grow the storage.
280 auto storageOperands = resize(owner, newSize: size() + (newSize - length));
281
282 // Shift operands to the right to make space for the new operands.
283 unsigned rotateSize = storageOperands.size() - (start + length);
284 auto rbegin = storageOperands.rbegin();
285 std::rotate(first: rbegin, middle: std::next(x: rbegin, n: newSize - length), last: rbegin + rotateSize);
286
287 // Update the operands inplace.
288 for (unsigned i = 0, e = operands.size(); i != e; ++i)
289 storageOperands[start + i].set(operands[i]);
290}
291
292/// Erase an operand held by the storage.
293void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
294 MutableArrayRef<OpOperand> operands = getOperands();
295 assert((start + length) <= operands.size());
296 numOperands -= length;
297
298 // Shift all operands down if the operand to remove is not at the end.
299 if (start != numOperands) {
300 auto *indexIt = std::next(x: operands.begin(), n: start);
301 std::rotate(first: indexIt, middle: std::next(x: indexIt, n: length), last: operands.end());
302 }
303 for (unsigned i = 0; i != length; ++i)
304 operands[numOperands + i].~OpOperand();
305}
306
307void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
308 MutableArrayRef<OpOperand> operands = getOperands();
309 assert(eraseIndices.size() == operands.size());
310
311 // Check that at least one operand is erased.
312 int firstErasedIndice = eraseIndices.find_first();
313 if (firstErasedIndice == -1)
314 return;
315
316 // Shift all of the removed operands to the end, and destroy them.
317 numOperands = firstErasedIndice;
318 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
319 if (!eraseIndices.test(Idx: i))
320 operands[numOperands++] = std::move(operands[i]);
321 for (OpOperand &operand : operands.drop_front(N: numOperands))
322 operand.~OpOperand();
323}
324
325/// Resize the storage to the given size. Returns the array containing the new
326/// operands.
327MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
328 unsigned newSize) {
329 // If the number of operands is less than or equal to the current amount, we
330 // can just update in place.
331 MutableArrayRef<OpOperand> origOperands = getOperands();
332 if (newSize <= numOperands) {
333 // If the number of new size is less than the current, remove any extra
334 // operands.
335 for (unsigned i = newSize; i != numOperands; ++i)
336 origOperands[i].~OpOperand();
337 numOperands = newSize;
338 return origOperands.take_front(N: newSize);
339 }
340
341 // If the new size is within the original inline capacity, grow inplace.
342 if (newSize <= capacity) {
343 OpOperand *opBegin = origOperands.data();
344 for (unsigned e = newSize; numOperands != e; ++numOperands)
345 new (&opBegin[numOperands]) OpOperand(owner);
346 return MutableArrayRef<OpOperand>(opBegin, newSize);
347 }
348
349 // Otherwise, we need to allocate a new storage.
350 unsigned newCapacity =
351 std::max(a: unsigned(llvm::NextPowerOf2(A: capacity + 2)), b: newSize);
352 OpOperand *newOperandStorage =
353 reinterpret_cast<OpOperand *>(malloc(size: sizeof(OpOperand) * newCapacity));
354
355 // Move the current operands to the new storage.
356 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
357 std::uninitialized_move(first: origOperands.begin(), last: origOperands.end(),
358 result: newOperands.begin());
359
360 // Destroy the original operands.
361 for (auto &operand : origOperands)
362 operand.~OpOperand();
363
364 // Initialize any new operands.
365 for (unsigned e = newSize; numOperands != e; ++numOperands)
366 new (&newOperands[numOperands]) OpOperand(owner);
367
368 // If the current storage is dynamic, free it.
369 if (isStorageDynamic)
370 free(ptr: operandStorage);
371
372 // Update the storage representation to use the new dynamic storage.
373 operandStorage = newOperandStorage;
374 capacity = newCapacity;
375 isStorageDynamic = true;
376 return newOperands;
377}
378
379//===----------------------------------------------------------------------===//
380// Operation Value-Iterators
381//===----------------------------------------------------------------------===//
382
383//===----------------------------------------------------------------------===//
384// OperandRange
385//===----------------------------------------------------------------------===//
386
387unsigned OperandRange::getBeginOperandIndex() const {
388 assert(!empty() && "range must not be empty");
389 return base->getOperandNumber();
390}
391
392OperandRangeRange OperandRange::split(DenseI32ArrayAttr segmentSizes) const {
393 return OperandRangeRange(*this, segmentSizes);
394}
395
396//===----------------------------------------------------------------------===//
397// OperandRangeRange
398//===----------------------------------------------------------------------===//
399
400OperandRangeRange::OperandRangeRange(OperandRange operands,
401 Attribute operandSegments)
402 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
403 llvm::cast<DenseI32ArrayAttr>(Val&: operandSegments).size()) {
404}
405
406OperandRange OperandRangeRange::join() const {
407 const OwnerT &owner = getBase();
408 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(Val: owner.second);
409 return OperandRange(owner.first,
410 std::accumulate(first: sizeData.begin(), last: sizeData.end(), init: 0));
411}
412
413OperandRange OperandRangeRange::dereference(const OwnerT &object,
414 ptrdiff_t index) {
415 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(Val: object.second);
416 uint32_t startIndex =
417 std::accumulate(first: sizeData.begin(), last: sizeData.begin() + index, init: 0);
418 return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
419}
420
421//===----------------------------------------------------------------------===//
422// MutableOperandRange
423//===----------------------------------------------------------------------===//
424
425/// Construct a new mutable range from the given operand, operand start index,
426/// and range length.
427MutableOperandRange::MutableOperandRange(
428 Operation *owner, unsigned start, unsigned length,
429 ArrayRef<OperandSegment> operandSegments)
430 : owner(owner), start(start), length(length),
431 operandSegments(operandSegments) {
432 assert((start + length) <= owner->getNumOperands() && "invalid range");
433}
434MutableOperandRange::MutableOperandRange(Operation *owner)
435 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
436
437/// Construct a new mutable range for the given OpOperand.
438MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
439 : MutableOperandRange(opOperand.getOwner(),
440 /*start=*/opOperand.getOperandNumber(),
441 /*length=*/1) {}
442
443/// Slice this range into a sub range, with the additional operand segment.
444MutableOperandRange
445MutableOperandRange::slice(unsigned subStart, unsigned subLen,
446 std::optional<OperandSegment> segment) const {
447 assert((subStart + subLen) <= length && "invalid sub-range");
448 MutableOperandRange subSlice(owner, start + subStart, subLen,
449 operandSegments);
450 if (segment)
451 subSlice.operandSegments.push_back(Elt: *segment);
452 return subSlice;
453}
454
455/// Append the given values to the range.
456void MutableOperandRange::append(ValueRange values) {
457 if (values.empty())
458 return;
459 owner->insertOperands(index: start + length, operands: values);
460 updateLength(newLength: length + values.size());
461}
462
463/// Assign this range to the given values.
464void MutableOperandRange::assign(ValueRange values) {
465 owner->setOperands(start, length, operands: values);
466 if (length != values.size())
467 updateLength(/*newLength=*/values.size());
468}
469
470/// Assign the range to the given value.
471void MutableOperandRange::assign(Value value) {
472 if (length == 1) {
473 owner->setOperand(idx: start, value);
474 } else {
475 owner->setOperands(start, length, operands: value);
476 updateLength(/*newLength=*/1);
477 }
478}
479
480/// Erase the operands within the given sub-range.
481void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
482 assert((subStart + subLen) <= length && "invalid sub-range");
483 if (length == 0)
484 return;
485 owner->eraseOperands(idx: start + subStart, length: subLen);
486 updateLength(newLength: length - subLen);
487}
488
489/// Clear this range and erase all of the operands.
490void MutableOperandRange::clear() {
491 if (length != 0) {
492 owner->eraseOperands(idx: start, length);
493 updateLength(/*newLength=*/0);
494 }
495}
496
497/// Explicit conversion to an OperandRange.
498OperandRange MutableOperandRange::getAsOperandRange() const {
499 return owner->getOperands().slice(n: start, m: length);
500}
501
502/// Allow implicit conversion to an OperandRange.
503MutableOperandRange::operator OperandRange() const {
504 return getAsOperandRange();
505}
506
507MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
508 return owner->getOpOperands().slice(N: start, M: length);
509}
510
511MutableOperandRangeRange
512MutableOperandRange::split(NamedAttribute segmentSizes) const {
513 return MutableOperandRangeRange(*this, segmentSizes);
514}
515
516/// Update the length of this range to the one provided.
517void MutableOperandRange::updateLength(unsigned newLength) {
518 int32_t diff = int32_t(newLength) - int32_t(length);
519 length = newLength;
520
521 // Update any of the provided segment attributes.
522 for (OperandSegment &segment : operandSegments) {
523 auto attr = llvm::cast<DenseI32ArrayAttr>(Val: segment.second.getValue());
524 SmallVector<int32_t, 8> segments(attr.asArrayRef());
525 segments[segment.first] += diff;
526 segment.second.setValue(
527 DenseI32ArrayAttr::get(context: attr.getContext(), content: segments));
528 owner->setAttr(name: segment.second.getName(), value: segment.second.getValue());
529 }
530}
531
532OpOperand &MutableOperandRange::operator[](unsigned index) const {
533 assert(index < length && "index is out of bounds");
534 return owner->getOpOperand(idx: start + index);
535}
536
537MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
538 return owner->getOpOperands().slice(N: start, M: length).begin();
539}
540
541MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
542 return owner->getOpOperands().slice(N: start, M: length).end();
543}
544
545//===----------------------------------------------------------------------===//
546// MutableOperandRangeRange
547//===----------------------------------------------------------------------===//
548
549MutableOperandRangeRange::MutableOperandRangeRange(
550 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
551 : MutableOperandRangeRange(
552 OwnerT(operands, operandSegmentAttr), 0,
553 llvm::cast<DenseI32ArrayAttr>(Val: operandSegmentAttr.getValue()).size()) {
554}
555
556MutableOperandRange MutableOperandRangeRange::join() const {
557 return getBase().first;
558}
559
560MutableOperandRangeRange::operator OperandRangeRange() const {
561 return OperandRangeRange(getBase().first, getBase().second.getValue());
562}
563
564MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
565 ptrdiff_t index) {
566 ArrayRef<int32_t> sizeData =
567 llvm::cast<DenseI32ArrayAttr>(Val: object.second.getValue());
568 uint32_t startIndex =
569 std::accumulate(first: sizeData.begin(), last: sizeData.begin() + index, init: 0);
570 return object.first.slice(
571 subStart: startIndex, subLen: *(sizeData.begin() + index),
572 segment: MutableOperandRange::OperandSegment(index, object.second));
573}
574
575//===----------------------------------------------------------------------===//
576// ResultRange
577//===----------------------------------------------------------------------===//
578
579ResultRange::ResultRange(OpResult result)
580 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
581 1) {}
582
583ResultRange::use_range ResultRange::getUses() const {
584 return {use_begin(), use_end()};
585}
586ResultRange::use_iterator ResultRange::use_begin() const {
587 return use_iterator(*this);
588}
589ResultRange::use_iterator ResultRange::use_end() const {
590 return use_iterator(*this, /*end=*/true);
591}
592ResultRange::user_range ResultRange::getUsers() {
593 return {user_begin(), user_end()};
594}
595ResultRange::user_iterator ResultRange::user_begin() {
596 return user_iterator(use_begin());
597}
598ResultRange::user_iterator ResultRange::user_end() {
599 return user_iterator(use_end());
600}
601
602ResultRange::UseIterator::UseIterator(ResultRange results, bool end)
603 : it(end ? results.end() : results.begin()), endIt(results.end()) {
604 // Only initialize current use if there are results/can be uses.
605 if (it != endIt)
606 skipOverResultsWithNoUsers();
607}
608
609ResultRange::UseIterator &ResultRange::UseIterator::operator++() {
610 // We increment over uses, if we reach the last use then move to next
611 // result.
612 if (use != (*it).use_end())
613 ++use;
614 if (use == (*it).use_end()) {
615 ++it;
616 skipOverResultsWithNoUsers();
617 }
618 return *this;
619}
620
621void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
622 while (it != endIt && (*it).use_empty())
623 ++it;
624
625 // If we are at the last result, then set use to first use of
626 // first result (sentinel value used for end).
627 if (it == endIt)
628 use = {};
629 else
630 use = (*it).use_begin();
631}
632
633void ResultRange::replaceAllUsesWith(Operation *op) {
634 replaceAllUsesWith(values: op->getResults());
635}
636
637void ResultRange::replaceUsesWithIf(
638 Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
639 replaceUsesWithIf(values: op->getResults(), shouldReplace);
640}
641
642//===----------------------------------------------------------------------===//
643// ValueRange
644//===----------------------------------------------------------------------===//
645
646ValueRange::ValueRange(ArrayRef<Value> values)
647 : ValueRange(values.data(), values.size()) {}
648ValueRange::ValueRange(OperandRange values)
649 : ValueRange(values.begin().getBase(), values.size()) {}
650ValueRange::ValueRange(ResultRange values)
651 : ValueRange(values.getBase(), values.size()) {}
652
653/// See `llvm::detail::indexed_accessor_range_base` for details.
654ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
655 ptrdiff_t index) {
656 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(Val: owner))
657 return {value + index};
658 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(Val: owner))
659 return {operand + index};
660 return cast<detail::OpResultImpl *>(Val: owner)->getNextResultAtOffset(offset: index);
661}
662/// See `llvm::detail::indexed_accessor_range_base` for details.
663Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
664 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(Val: owner))
665 return value[index];
666 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(Val: owner))
667 return operand[index].get();
668 return cast<detail::OpResultImpl *>(Val: owner)->getNextResultAtOffset(offset: index);
669}
670
671//===----------------------------------------------------------------------===//
672// Operation Equivalency
673//===----------------------------------------------------------------------===//
674
675llvm::hash_code OperationEquivalence::computeHash(
676 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
677 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
678 // Hash operations based upon their:
679 // - Operation Name
680 // - Attributes
681 // - Result Types
682 DictionaryAttr dictAttrs;
683 if (!(flags & Flags::IgnoreDiscardableAttrs))
684 dictAttrs = op->getRawDictionaryAttrs();
685 llvm::hash_code hash =
686 llvm::hash_combine(args: op->getName(), args: dictAttrs, args: op->getResultTypes());
687 if (!(flags & Flags::IgnoreProperties))
688 hash = llvm::hash_combine(args: hash, args: op->hashProperties());
689
690 // - Location if required
691 if (!(flags & Flags::IgnoreLocations))
692 hash = llvm::hash_combine(args: hash, args: op->getLoc());
693
694 // - Operands
695 if (op->hasTrait<mlir::OpTrait::IsCommutative>() &&
696 op->getNumOperands() > 0) {
697 size_t operandHash = hashOperands(op->getOperand(idx: 0));
698 for (auto operand : op->getOperands().drop_front())
699 operandHash += hashOperands(operand);
700 hash = llvm::hash_combine(args: hash, args: operandHash);
701 } else {
702 for (Value operand : op->getOperands())
703 hash = llvm::hash_combine(args: hash, args: hashOperands(operand));
704 }
705
706 // - Results
707 for (Value result : op->getResults())
708 hash = llvm::hash_combine(args: hash, args: hashResults(result));
709 return hash;
710}
711
712/*static*/ bool OperationEquivalence::isRegionEquivalentTo(
713 Region *lhs, Region *rhs,
714 function_ref<LogicalResult(Value, Value)> checkEquivalent,
715 function_ref<void(Value, Value)> markEquivalent,
716 OperationEquivalence::Flags flags,
717 function_ref<LogicalResult(ValueRange, ValueRange)>
718 checkCommutativeEquivalent) {
719 DenseMap<Block *, Block *> blocksMap;
720 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
721 // Check block arguments.
722 if (lBlock.getNumArguments() != rBlock.getNumArguments())
723 return false;
724
725 // Map the two blocks.
726 auto insertion = blocksMap.insert(KV: {&lBlock, &rBlock});
727 if (insertion.first->getSecond() != &rBlock)
728 return false;
729
730 for (auto argPair :
731 llvm::zip(t: lBlock.getArguments(), u: rBlock.getArguments())) {
732 Value curArg = std::get<0>(t&: argPair);
733 Value otherArg = std::get<1>(t&: argPair);
734 if (curArg.getType() != otherArg.getType())
735 return false;
736 if (!(flags & OperationEquivalence::IgnoreLocations) &&
737 curArg.getLoc() != otherArg.getLoc())
738 return false;
739 // Corresponding bbArgs are equivalent.
740 if (markEquivalent)
741 markEquivalent(curArg, otherArg);
742 }
743
744 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
745 // Check for op equality (recursively).
746 if (!OperationEquivalence::isEquivalentTo(lhs: &lOp, rhs: &rOp, checkEquivalent,
747 markEquivalent, flags,
748 checkCommutativeEquivalent))
749 return false;
750 // Check successor mapping.
751 for (auto successorsPair :
752 llvm::zip(t: lOp.getSuccessors(), u: rOp.getSuccessors())) {
753 Block *curSuccessor = std::get<0>(t&: successorsPair);
754 Block *otherSuccessor = std::get<1>(t&: successorsPair);
755 auto insertion = blocksMap.insert(KV: {curSuccessor, otherSuccessor});
756 if (insertion.first->getSecond() != otherSuccessor)
757 return false;
758 }
759 return true;
760 };
761 return llvm::all_of_zip(argsAndPredicate&: lBlock, argsAndPredicate&: rBlock, argsAndPredicate&: opsEquivalent);
762 };
763 return llvm::all_of_zip(argsAndPredicate&: *lhs, argsAndPredicate&: *rhs, argsAndPredicate&: blocksEquivalent);
764}
765
766// Value equivalence cache to be used with `isRegionEquivalentTo` and
767// `isEquivalentTo`.
768struct ValueEquivalenceCache {
769 DenseMap<Value, Value> equivalentValues;
770 LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
771 return success(IsSuccess: lhsValue == rhsValue ||
772 equivalentValues.lookup(Val: lhsValue) == rhsValue);
773 }
774 LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
775 ValueRange rhsRange) {
776 // Handle simple case where sizes mismatch.
777 if (lhsRange.size() != rhsRange.size())
778 return failure();
779
780 // Handle where operands in order are equivalent.
781 auto lhsIt = lhsRange.begin();
782 auto rhsIt = rhsRange.begin();
783 for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
784 if (failed(Result: checkEquivalent(lhsValue: *lhsIt, rhsValue: *rhsIt)))
785 break;
786 }
787 if (lhsIt == lhsRange.end())
788 return success();
789
790 // Handle another simple case where operands are just a permutation.
791 // Note: This is not sufficient, this handles simple cases relatively
792 // cheaply.
793 auto sortValues = [](ValueRange values) {
794 SmallVector<Value> sortedValues = llvm::to_vector(Range&: values);
795 llvm::sort(C&: sortedValues, Comp: [](Value a, Value b) {
796 return a.getAsOpaquePointer() < b.getAsOpaquePointer();
797 });
798 return sortedValues;
799 };
800 auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
801 auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
802 return success(IsSuccess: lhsSorted == rhsSorted);
803 }
804 void markEquivalent(Value lhsResult, Value rhsResult) {
805 auto insertion = equivalentValues.insert(KV: {lhsResult, rhsResult});
806 // Make sure that the value was not already marked equivalent to some other
807 // value.
808 (void)insertion;
809 assert(insertion.first->second == rhsResult &&
810 "inconsistent OperationEquivalence state");
811 }
812};
813
814/*static*/ bool
815OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
816 OperationEquivalence::Flags flags) {
817 ValueEquivalenceCache cache;
818 return isRegionEquivalentTo(
819 lhs, rhs,
820 checkEquivalent: [&](Value lhsValue, Value rhsValue) -> LogicalResult {
821 return cache.checkEquivalent(lhsValue, rhsValue);
822 },
823 markEquivalent: [&](Value lhsResult, Value rhsResult) {
824 cache.markEquivalent(lhsResult, rhsResult);
825 },
826 flags,
827 checkCommutativeEquivalent: [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
828 return cache.checkCommutativeEquivalent(lhsRange: lhs, rhsRange: rhs);
829 });
830}
831
832/*static*/ bool OperationEquivalence::isEquivalentTo(
833 Operation *lhs, Operation *rhs,
834 function_ref<LogicalResult(Value, Value)> checkEquivalent,
835 function_ref<void(Value, Value)> markEquivalent, Flags flags,
836 function_ref<LogicalResult(ValueRange, ValueRange)>
837 checkCommutativeEquivalent) {
838 if (lhs == rhs)
839 return true;
840
841 // 1. Compare the operation properties.
842 if (!(flags & IgnoreDiscardableAttrs) &&
843 lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs())
844 return false;
845
846 if (lhs->getName() != rhs->getName() ||
847 lhs->getNumRegions() != rhs->getNumRegions() ||
848 lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
849 lhs->getNumOperands() != rhs->getNumOperands() ||
850 lhs->getNumResults() != rhs->getNumResults())
851 return false;
852 if (!(flags & IgnoreProperties) &&
853 !(lhs->getName().compareOpProperties(lhs: lhs->getPropertiesStorage(),
854 rhs: rhs->getPropertiesStorage())))
855 return false;
856 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
857 return false;
858
859 // 2. Compare operands.
860 if (checkCommutativeEquivalent &&
861 lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
862 auto lhsRange = lhs->getOperands();
863 auto rhsRange = rhs->getOperands();
864 if (failed(Result: checkCommutativeEquivalent(lhsRange, rhsRange)))
865 return false;
866 } else {
867 // Check pair wise for equivalence.
868 for (auto operandPair : llvm::zip(t: lhs->getOperands(), u: rhs->getOperands())) {
869 Value curArg = std::get<0>(t&: operandPair);
870 Value otherArg = std::get<1>(t&: operandPair);
871 if (curArg == otherArg)
872 continue;
873 if (curArg.getType() != otherArg.getType())
874 return false;
875 if (failed(Result: checkEquivalent(curArg, otherArg)))
876 return false;
877 }
878 }
879
880 // 3. Compare result types and mark results as equivalent.
881 for (auto resultPair : llvm::zip(t: lhs->getResults(), u: rhs->getResults())) {
882 Value curArg = std::get<0>(t&: resultPair);
883 Value otherArg = std::get<1>(t&: resultPair);
884 if (curArg.getType() != otherArg.getType())
885 return false;
886 if (markEquivalent)
887 markEquivalent(curArg, otherArg);
888 }
889
890 // 4. Compare regions.
891 for (auto regionPair : llvm::zip(t: lhs->getRegions(), u: rhs->getRegions()))
892 if (!isRegionEquivalentTo(lhs: &std::get<0>(t&: regionPair),
893 rhs: &std::get<1>(t&: regionPair), checkEquivalent,
894 markEquivalent, flags))
895 return false;
896
897 return true;
898}
899
900/*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
901 Operation *rhs,
902 Flags flags) {
903 ValueEquivalenceCache cache;
904 return OperationEquivalence::isEquivalentTo(
905 lhs, rhs,
906 checkEquivalent: [&](Value lhsValue, Value rhsValue) -> LogicalResult {
907 return cache.checkEquivalent(lhsValue, rhsValue);
908 },
909 markEquivalent: [&](Value lhsResult, Value rhsResult) {
910 cache.markEquivalent(lhsResult, rhsResult);
911 },
912 flags,
913 checkCommutativeEquivalent: [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
914 return cache.checkCommutativeEquivalent(lhsRange: lhs, rhsRange: rhs);
915 });
916}
917
918//===----------------------------------------------------------------------===//
919// OperationFingerPrint
920//===----------------------------------------------------------------------===//
921
922template <typename T>
923static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
924 hasher.update(
925 Data: ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
926}
927
928OperationFingerPrint::OperationFingerPrint(Operation *topOp,
929 bool includeNested) {
930 llvm::SHA1 hasher;
931
932 // Helper function that hashes an operation based on its mutable bits:
933 auto addOperationToHash = [&](Operation *op) {
934 // - Operation pointer
935 addDataToHash(hasher, data: op);
936 // - Parent operation pointer (to take into account the nesting structure)
937 if (op != topOp)
938 addDataToHash(hasher, data: op->getParentOp());
939 // - Attributes
940 addDataToHash(hasher, data: op->getRawDictionaryAttrs());
941 // - Properties
942 addDataToHash(hasher, data: op->hashProperties());
943 // - Blocks in Regions
944 for (Region &region : op->getRegions()) {
945 for (Block &block : region) {
946 addDataToHash(hasher, data: &block);
947 for (BlockArgument arg : block.getArguments())
948 addDataToHash(hasher, data: arg);
949 }
950 }
951 // - Location
952 addDataToHash(hasher, data: op->getLoc().getAsOpaquePointer());
953 // - Operands
954 for (Value operand : op->getOperands())
955 addDataToHash(hasher, data: operand);
956 // - Successors
957 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
958 addDataToHash(hasher, data: op->getSuccessor(index: i));
959 // - Result types
960 for (Type t : op->getResultTypes())
961 addDataToHash(hasher, data: t);
962 };
963
964 if (includeNested)
965 topOp->walk(callback&: addOperationToHash);
966 else
967 addOperationToHash(topOp);
968
969 hash = hasher.result();
970}
971

source code of mlir/lib/IR/OperationSupport.cpp