1//===- OperationSupport.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains out-of-line implementations of the support types that
10// Operation and related classes build on top of.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Support/MemAlloc.h"
#include "llvm/Support/SHA1.h"
#include <memory>
#include <numeric>
#include <optional>
22
23using namespace mlir;
24
25//===----------------------------------------------------------------------===//
26// NamedAttrList
27//===----------------------------------------------------------------------===//
28
29NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
30 assign(inStart: attributes.begin(), inEnd: attributes.end());
31}
32
33NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34 : NamedAttrList(attributes ? attributes.getValue()
35 : ArrayRef<NamedAttribute>()) {
36 dictionarySorted.setPointerAndInt(PtrVal: attributes, IntVal: true);
37}
38
39NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
40 assign(inStart, inEnd);
41}
42
43ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
44
45std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46 std::optional<NamedAttribute> duplicate =
47 DictionaryAttr::findDuplicate(attrs, isSorted());
48 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49 // state.
50 if (!isSorted())
51 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
52 return duplicate;
53}
54
55DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56 if (!isSorted()) {
57 DictionaryAttr::sortInPlace(attrs);
58 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
59 }
60 if (!dictionarySorted.getPointer())
61 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62 return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
63}
64
65/// Replaces the attributes with new list of attributes.
66void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
67 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
68 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
69}
70
71void NamedAttrList::push_back(NamedAttribute newAttribute) {
72 if (isSorted())
73 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
74 dictionarySorted.setPointer(nullptr);
75 attrs.push_back(Elt: newAttribute);
76}
77
78/// Return the specified attribute if present, null otherwise.
79Attribute NamedAttrList::get(StringRef name) const {
80 auto it = findAttr(attrs: *this, name);
81 return it.second ? it.first->getValue() : Attribute();
82}
83Attribute NamedAttrList::get(StringAttr name) const {
84 auto it = findAttr(*this, name);
85 return it.second ? it.first->getValue() : Attribute();
86}
87
88/// Return the specified named attribute if present, std::nullopt otherwise.
89std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
90 auto it = findAttr(attrs: *this, name);
91 return it.second ? *it.first : std::optional<NamedAttribute>();
92}
93std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
94 auto it = findAttr(*this, name);
95 return it.second ? *it.first : std::optional<NamedAttribute>();
96}
97
98/// If the an attribute exists with the specified name, change it to the new
99/// value. Otherwise, add a new attribute with the specified name/value.
100Attribute NamedAttrList::set(StringAttr name, Attribute value) {
101 assert(value && "attributes may never be null");
102
103 // Look for an existing attribute with the given name, and set its value
104 // in-place. Return the previous value of the attribute, if there was one.
105 auto it = findAttr(*this, name);
106 if (it.second) {
107 // Update the existing attribute by swapping out the old value for the new
108 // value. Return the old value.
109 Attribute oldValue = it.first->getValue();
110 if (it.first->getValue() != value) {
111 it.first->setValue(value);
112
113 // If the attributes have changed, the dictionary is invalidated.
114 dictionarySorted.setPointer(nullptr);
115 }
116 return oldValue;
117 }
118 // Perform a string lookup to insert the new attribute into its sorted
119 // position.
120 if (isSorted())
121 it = findAttr(*this, name.strref());
122 attrs.insert(it.first, {name, value});
123 // Invalidate the dictionary. Return null as there was no previous value.
124 dictionarySorted.setPointer(nullptr);
125 return Attribute();
126}
127
128Attribute NamedAttrList::set(StringRef name, Attribute value) {
129 assert(value && "attributes may never be null");
130 return set(mlir::StringAttr::get(value.getContext(), name), value);
131}
132
133Attribute
134NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
135 // Erasing does not affect the sorted property.
136 Attribute attr = it->getValue();
137 attrs.erase(CI: it);
138 dictionarySorted.setPointer(nullptr);
139 return attr;
140}
141
142Attribute NamedAttrList::erase(StringAttr name) {
143 auto it = findAttr(*this, name);
144 return it.second ? eraseImpl(it: it.first) : Attribute();
145}
146
147Attribute NamedAttrList::erase(StringRef name) {
148 auto it = findAttr(attrs&: *this, name);
149 return it.second ? eraseImpl(it: it.first) : Attribute();
150}
151
152NamedAttrList &
153NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
154 assign(inStart: rhs.begin(), inEnd: rhs.end());
155 return *this;
156}
157
158NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
159
160//===----------------------------------------------------------------------===//
161// OperationState
162//===----------------------------------------------------------------------===//
163
164OperationState::OperationState(Location location, StringRef name)
165 : location(location), name(name, location->getContext()) {}
166
167OperationState::OperationState(Location location, OperationName name)
168 : location(location), name(name) {}
169
170OperationState::OperationState(Location location, OperationName name,
171 ValueRange operands, TypeRange types,
172 ArrayRef<NamedAttribute> attributes,
173 BlockRange successors,
174 MutableArrayRef<std::unique_ptr<Region>> regions)
175 : location(location), name(name),
176 operands(operands.begin(), operands.end()),
177 types(types.begin(), types.end()),
178 attributes(attributes.begin(), attributes.end()),
179 successors(successors.begin(), successors.end()) {
180 for (std::unique_ptr<Region> &r : regions)
181 this->regions.push_back(Elt: std::move(r));
182}
183OperationState::OperationState(Location location, StringRef name,
184 ValueRange operands, TypeRange types,
185 ArrayRef<NamedAttribute> attributes,
186 BlockRange successors,
187 MutableArrayRef<std::unique_ptr<Region>> regions)
188 : OperationState(location, OperationName(name, location.getContext()),
189 operands, types, attributes, successors, regions) {}
190
191OperationState::~OperationState() {
192 if (properties)
193 propertiesDeleter(properties);
194}
195
196LogicalResult OperationState::setProperties(
197 Operation *op, function_ref<InFlightDiagnostic()> emitError) const {
198 if (LLVM_UNLIKELY(propertiesAttr)) {
199 assert(!properties);
200 return op->setPropertiesFromAttribute(attr: propertiesAttr, emitError);
201 }
202 if (properties)
203 propertiesSetter(op->getPropertiesStorage(), properties);
204 return success();
205}
206
207void OperationState::addOperands(ValueRange newOperands) {
208 operands.append(in_start: newOperands.begin(), in_end: newOperands.end());
209}
210
211void OperationState::addSuccessors(BlockRange newSuccessors) {
212 successors.append(in_start: newSuccessors.begin(), in_end: newSuccessors.end());
213}
214
215Region *OperationState::addRegion() {
216 regions.emplace_back(Args: new Region);
217 return regions.back().get();
218}
219
220void OperationState::addRegion(std::unique_ptr<Region> &&region) {
221 regions.push_back(Elt: std::move(region));
222}
223
224void OperationState::addRegions(
225 MutableArrayRef<std::unique_ptr<Region>> regions) {
226 for (std::unique_ptr<Region> &region : regions)
227 addRegion(region: std::move(region));
228}
229
230//===----------------------------------------------------------------------===//
231// OperandStorage
232//===----------------------------------------------------------------------===//
233
234detail::OperandStorage::OperandStorage(Operation *owner,
235 OpOperand *trailingOperands,
236 ValueRange values)
237 : isStorageDynamic(false), operandStorage(trailingOperands) {
238 numOperands = capacity = values.size();
239 for (unsigned i = 0; i < numOperands; ++i)
240 new (&operandStorage[i]) OpOperand(owner, values[i]);
241}
242
243detail::OperandStorage::~OperandStorage() {
244 for (auto &operand : getOperands())
245 operand.~OpOperand();
246
247 // If the storage is dynamic, deallocate it.
248 if (isStorageDynamic)
249 free(ptr: operandStorage);
250}
251
252/// Replace the operands contained in the storage with the ones provided in
253/// 'values'.
254void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
255 MutableArrayRef<OpOperand> storageOperands = resize(owner, newSize: values.size());
256 for (unsigned i = 0, e = values.size(); i != e; ++i)
257 storageOperands[i].set(values[i]);
258}
259
260/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
261/// with the ones provided in 'operands'. 'operands' may be smaller or larger
262/// than the range pointed to by 'start'+'length'.
263void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
264 unsigned length, ValueRange operands) {
265 // If the new size is the same, we can update inplace.
266 unsigned newSize = operands.size();
267 if (newSize == length) {
268 MutableArrayRef<OpOperand> storageOperands = getOperands();
269 for (unsigned i = 0, e = length; i != e; ++i)
270 storageOperands[start + i].set(operands[i]);
271 return;
272 }
273 // If the new size is greater, remove the extra operands and set the rest
274 // inplace.
275 if (newSize < length) {
276 eraseOperands(start: start + operands.size(), length: length - newSize);
277 setOperands(owner, start, length: newSize, operands);
278 return;
279 }
280 // Otherwise, the new size is greater so we need to grow the storage.
281 auto storageOperands = resize(owner, newSize: size() + (newSize - length));
282
283 // Shift operands to the right to make space for the new operands.
284 unsigned rotateSize = storageOperands.size() - (start + length);
285 auto rbegin = storageOperands.rbegin();
286 std::rotate(first: rbegin, middle: std::next(x: rbegin, n: newSize - length), last: rbegin + rotateSize);
287
288 // Update the operands inplace.
289 for (unsigned i = 0, e = operands.size(); i != e; ++i)
290 storageOperands[start + i].set(operands[i]);
291}
292
293/// Erase an operand held by the storage.
294void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
295 MutableArrayRef<OpOperand> operands = getOperands();
296 assert((start + length) <= operands.size());
297 numOperands -= length;
298
299 // Shift all operands down if the operand to remove is not at the end.
300 if (start != numOperands) {
301 auto *indexIt = std::next(x: operands.begin(), n: start);
302 std::rotate(first: indexIt, middle: std::next(x: indexIt, n: length), last: operands.end());
303 }
304 for (unsigned i = 0; i != length; ++i)
305 operands[numOperands + i].~OpOperand();
306}
307
308void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
309 MutableArrayRef<OpOperand> operands = getOperands();
310 assert(eraseIndices.size() == operands.size());
311
312 // Check that at least one operand is erased.
313 int firstErasedIndice = eraseIndices.find_first();
314 if (firstErasedIndice == -1)
315 return;
316
317 // Shift all of the removed operands to the end, and destroy them.
318 numOperands = firstErasedIndice;
319 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
320 if (!eraseIndices.test(Idx: i))
321 operands[numOperands++] = std::move(operands[i]);
322 for (OpOperand &operand : operands.drop_front(N: numOperands))
323 operand.~OpOperand();
324}
325
326/// Resize the storage to the given size. Returns the array containing the new
327/// operands.
328MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
329 unsigned newSize) {
330 // If the number of operands is less than or equal to the current amount, we
331 // can just update in place.
332 MutableArrayRef<OpOperand> origOperands = getOperands();
333 if (newSize <= numOperands) {
334 // If the number of new size is less than the current, remove any extra
335 // operands.
336 for (unsigned i = newSize; i != numOperands; ++i)
337 origOperands[i].~OpOperand();
338 numOperands = newSize;
339 return origOperands.take_front(N: newSize);
340 }
341
342 // If the new size is within the original inline capacity, grow inplace.
343 if (newSize <= capacity) {
344 OpOperand *opBegin = origOperands.data();
345 for (unsigned e = newSize; numOperands != e; ++numOperands)
346 new (&opBegin[numOperands]) OpOperand(owner);
347 return MutableArrayRef<OpOperand>(opBegin, newSize);
348 }
349
350 // Otherwise, we need to allocate a new storage.
351 unsigned newCapacity =
352 std::max(a: unsigned(llvm::NextPowerOf2(A: capacity + 2)), b: newSize);
353 OpOperand *newOperandStorage =
354 reinterpret_cast<OpOperand *>(malloc(size: sizeof(OpOperand) * newCapacity));
355
356 // Move the current operands to the new storage.
357 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
358 std::uninitialized_move(first: origOperands.begin(), last: origOperands.end(),
359 result: newOperands.begin());
360
361 // Destroy the original operands.
362 for (auto &operand : origOperands)
363 operand.~OpOperand();
364
365 // Initialize any new operands.
366 for (unsigned e = newSize; numOperands != e; ++numOperands)
367 new (&newOperands[numOperands]) OpOperand(owner);
368
369 // If the current storage is dynamic, free it.
370 if (isStorageDynamic)
371 free(ptr: operandStorage);
372
373 // Update the storage representation to use the new dynamic storage.
374 operandStorage = newOperandStorage;
375 capacity = newCapacity;
376 isStorageDynamic = true;
377 return newOperands;
378}
379
380//===----------------------------------------------------------------------===//
381// Operation Value-Iterators
382//===----------------------------------------------------------------------===//
383
384//===----------------------------------------------------------------------===//
385// OperandRange
386//===----------------------------------------------------------------------===//
387
388unsigned OperandRange::getBeginOperandIndex() const {
389 assert(!empty() && "range must not be empty");
390 return base->getOperandNumber();
391}
392
393OperandRangeRange OperandRange::split(DenseI32ArrayAttr segmentSizes) const {
394 return OperandRangeRange(*this, segmentSizes);
395}
396
397//===----------------------------------------------------------------------===//
398// OperandRangeRange
399//===----------------------------------------------------------------------===//
400
401OperandRangeRange::OperandRangeRange(OperandRange operands,
402 Attribute operandSegments)
403 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
404 llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
405}
406
407OperandRange OperandRangeRange::join() const {
408 const OwnerT &owner = getBase();
409 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
410 return OperandRange(owner.first,
411 std::accumulate(first: sizeData.begin(), last: sizeData.end(), init: 0));
412}
413
414OperandRange OperandRangeRange::dereference(const OwnerT &object,
415 ptrdiff_t index) {
416 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
417 uint32_t startIndex =
418 std::accumulate(first: sizeData.begin(), last: sizeData.begin() + index, init: 0);
419 return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
420}
421
422//===----------------------------------------------------------------------===//
423// MutableOperandRange
424//===----------------------------------------------------------------------===//
425
426/// Construct a new mutable range from the given operand, operand start index,
427/// and range length.
428MutableOperandRange::MutableOperandRange(
429 Operation *owner, unsigned start, unsigned length,
430 ArrayRef<OperandSegment> operandSegments)
431 : owner(owner), start(start), length(length),
432 operandSegments(operandSegments) {
433 assert((start + length) <= owner->getNumOperands() && "invalid range");
434}
435MutableOperandRange::MutableOperandRange(Operation *owner)
436 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
437
438/// Construct a new mutable range for the given OpOperand.
439MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
440 : MutableOperandRange(opOperand.getOwner(),
441 /*start=*/opOperand.getOperandNumber(),
442 /*length=*/1) {}
443
444/// Slice this range into a sub range, with the additional operand segment.
445MutableOperandRange
446MutableOperandRange::slice(unsigned subStart, unsigned subLen,
447 std::optional<OperandSegment> segment) const {
448 assert((subStart + subLen) <= length && "invalid sub-range");
449 MutableOperandRange subSlice(owner, start + subStart, subLen,
450 operandSegments);
451 if (segment)
452 subSlice.operandSegments.push_back(Elt: *segment);
453 return subSlice;
454}
455
456/// Append the given values to the range.
457void MutableOperandRange::append(ValueRange values) {
458 if (values.empty())
459 return;
460 owner->insertOperands(index: start + length, operands: values);
461 updateLength(newLength: length + values.size());
462}
463
464/// Assign this range to the given values.
465void MutableOperandRange::assign(ValueRange values) {
466 owner->setOperands(start, length, operands: values);
467 if (length != values.size())
468 updateLength(/*newLength=*/values.size());
469}
470
471/// Assign the range to the given value.
472void MutableOperandRange::assign(Value value) {
473 if (length == 1) {
474 owner->setOperand(idx: start, value);
475 } else {
476 owner->setOperands(start, length, operands: value);
477 updateLength(/*newLength=*/1);
478 }
479}
480
481/// Erase the operands within the given sub-range.
482void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
483 assert((subStart + subLen) <= length && "invalid sub-range");
484 if (length == 0)
485 return;
486 owner->eraseOperands(idx: start + subStart, length: subLen);
487 updateLength(newLength: length - subLen);
488}
489
490/// Clear this range and erase all of the operands.
491void MutableOperandRange::clear() {
492 if (length != 0) {
493 owner->eraseOperands(idx: start, length);
494 updateLength(/*newLength=*/0);
495 }
496}
497
498/// Explicit conversion to an OperandRange.
499OperandRange MutableOperandRange::getAsOperandRange() const {
500 return owner->getOperands().slice(n: start, m: length);
501}
502
503/// Allow implicit conversion to an OperandRange.
504MutableOperandRange::operator OperandRange() const {
505 return getAsOperandRange();
506}
507
508MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
509 return owner->getOpOperands().slice(N: start, M: length);
510}
511
512MutableOperandRangeRange
513MutableOperandRange::split(NamedAttribute segmentSizes) const {
514 return MutableOperandRangeRange(*this, segmentSizes);
515}
516
517/// Update the length of this range to the one provided.
518void MutableOperandRange::updateLength(unsigned newLength) {
519 int32_t diff = int32_t(newLength) - int32_t(length);
520 length = newLength;
521
522 // Update any of the provided segment attributes.
523 for (OperandSegment &segment : operandSegments) {
524 auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
525 SmallVector<int32_t, 8> segments(attr.asArrayRef());
526 segments[segment.first] += diff;
527 segment.second.setValue(
528 DenseI32ArrayAttr::get(attr.getContext(), segments));
529 owner->setAttr(segment.second.getName(), segment.second.getValue());
530 }
531}
532
533OpOperand &MutableOperandRange::operator[](unsigned index) const {
534 assert(index < length && "index is out of bounds");
535 return owner->getOpOperand(idx: start + index);
536}
537
538MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
539 return owner->getOpOperands().slice(N: start, M: length).begin();
540}
541
542MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
543 return owner->getOpOperands().slice(N: start, M: length).end();
544}
545
546//===----------------------------------------------------------------------===//
547// MutableOperandRangeRange
548//===----------------------------------------------------------------------===//
549
550MutableOperandRangeRange::MutableOperandRangeRange(
551 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
552 : MutableOperandRangeRange(
553 OwnerT(operands, operandSegmentAttr), 0,
554 llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
555}
556
557MutableOperandRange MutableOperandRangeRange::join() const {
558 return getBase().first;
559}
560
561MutableOperandRangeRange::operator OperandRangeRange() const {
562 return OperandRangeRange(getBase().first, getBase().second.getValue());
563}
564
565MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
566 ptrdiff_t index) {
567 ArrayRef<int32_t> sizeData =
568 llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
569 uint32_t startIndex =
570 std::accumulate(first: sizeData.begin(), last: sizeData.begin() + index, init: 0);
571 return object.first.slice(
572 subStart: startIndex, subLen: *(sizeData.begin() + index),
573 segment: MutableOperandRange::OperandSegment(index, object.second));
574}
575
576//===----------------------------------------------------------------------===//
577// ResultRange
578//===----------------------------------------------------------------------===//
579
580ResultRange::ResultRange(OpResult result)
581 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
582 1) {}
583
584ResultRange::use_range ResultRange::getUses() const {
585 return {use_begin(), use_end()};
586}
587ResultRange::use_iterator ResultRange::use_begin() const {
588 return use_iterator(*this);
589}
590ResultRange::use_iterator ResultRange::use_end() const {
591 return use_iterator(*this, /*end=*/true);
592}
593ResultRange::user_range ResultRange::getUsers() {
594 return {user_begin(), user_end()};
595}
596ResultRange::user_iterator ResultRange::user_begin() {
597 return user_iterator(use_begin());
598}
599ResultRange::user_iterator ResultRange::user_end() {
600 return user_iterator(use_end());
601}
602
603ResultRange::UseIterator::UseIterator(ResultRange results, bool end)
604 : it(end ? results.end() : results.begin()), endIt(results.end()) {
605 // Only initialize current use if there are results/can be uses.
606 if (it != endIt)
607 skipOverResultsWithNoUsers();
608}
609
610ResultRange::UseIterator &ResultRange::UseIterator::operator++() {
611 // We increment over uses, if we reach the last use then move to next
612 // result.
613 if (use != (*it).use_end())
614 ++use;
615 if (use == (*it).use_end()) {
616 ++it;
617 skipOverResultsWithNoUsers();
618 }
619 return *this;
620}
621
622void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
623 while (it != endIt && (*it).use_empty())
624 ++it;
625
626 // If we are at the last result, then set use to first use of
627 // first result (sentinel value used for end).
628 if (it == endIt)
629 use = {};
630 else
631 use = (*it).use_begin();
632}
633
634void ResultRange::replaceAllUsesWith(Operation *op) {
635 replaceAllUsesWith(values: op->getResults());
636}
637
638void ResultRange::replaceUsesWithIf(
639 Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
640 replaceUsesWithIf(values: op->getResults(), shouldReplace);
641}
642
643//===----------------------------------------------------------------------===//
644// ValueRange
645//===----------------------------------------------------------------------===//
646
647ValueRange::ValueRange(ArrayRef<Value> values)
648 : ValueRange(values.data(), values.size()) {}
649ValueRange::ValueRange(OperandRange values)
650 : ValueRange(values.begin().getBase(), values.size()) {}
651ValueRange::ValueRange(ResultRange values)
652 : ValueRange(values.getBase(), values.size()) {}
653
654/// See `llvm::detail::indexed_accessor_range_base` for details.
655ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
656 ptrdiff_t index) {
657 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(Val: owner))
658 return {value + index};
659 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(Val: owner))
660 return {operand + index};
661 return cast<detail::OpResultImpl *>(Val: owner)->getNextResultAtOffset(offset: index);
662}
663/// See `llvm::detail::indexed_accessor_range_base` for details.
664Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
665 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(Val: owner))
666 return value[index];
667 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(Val: owner))
668 return operand[index].get();
669 return cast<detail::OpResultImpl *>(Val: owner)->getNextResultAtOffset(offset: index);
670}
671
672//===----------------------------------------------------------------------===//
673// Operation Equivalency
674//===----------------------------------------------------------------------===//
675
676llvm::hash_code OperationEquivalence::computeHash(
677 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
678 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
679 // Hash operations based upon their:
680 // - Operation Name
681 // - Attributes
682 // - Result Types
683 DictionaryAttr dictAttrs;
684 if (!(flags & Flags::IgnoreDiscardableAttrs))
685 dictAttrs = op->getRawDictionaryAttrs();
686 llvm::hash_code hash =
687 llvm::hash_combine(op->getName(), dictAttrs, op->getResultTypes());
688 if (!(flags & Flags::IgnoreProperties))
689 hash = llvm::hash_combine(args: hash, args: op->hashProperties());
690
691 // - Location if required
692 if (!(flags & Flags::IgnoreLocations))
693 hash = llvm::hash_combine(args: hash, args: op->getLoc());
694
695 // - Operands
696 if (op->hasTrait<mlir::OpTrait::IsCommutative>() &&
697 op->getNumOperands() > 0) {
698 size_t operandHash = hashOperands(op->getOperand(idx: 0));
699 for (auto operand : op->getOperands().drop_front())
700 operandHash += hashOperands(operand);
701 hash = llvm::hash_combine(args: hash, args: operandHash);
702 } else {
703 for (Value operand : op->getOperands())
704 hash = llvm::hash_combine(args: hash, args: hashOperands(operand));
705 }
706
707 // - Results
708 for (Value result : op->getResults())
709 hash = llvm::hash_combine(args: hash, args: hashResults(result));
710 return hash;
711}
712
713/*static*/ bool OperationEquivalence::isRegionEquivalentTo(
714 Region *lhs, Region *rhs,
715 function_ref<LogicalResult(Value, Value)> checkEquivalent,
716 function_ref<void(Value, Value)> markEquivalent,
717 OperationEquivalence::Flags flags,
718 function_ref<LogicalResult(ValueRange, ValueRange)>
719 checkCommutativeEquivalent) {
720 DenseMap<Block *, Block *> blocksMap;
721 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
722 // Check block arguments.
723 if (lBlock.getNumArguments() != rBlock.getNumArguments())
724 return false;
725
726 // Map the two blocks.
727 auto insertion = blocksMap.insert(KV: {&lBlock, &rBlock});
728 if (insertion.first->getSecond() != &rBlock)
729 return false;
730
731 for (auto argPair :
732 llvm::zip(t: lBlock.getArguments(), u: rBlock.getArguments())) {
733 Value curArg = std::get<0>(t&: argPair);
734 Value otherArg = std::get<1>(t&: argPair);
735 if (curArg.getType() != otherArg.getType())
736 return false;
737 if (!(flags & OperationEquivalence::IgnoreLocations) &&
738 curArg.getLoc() != otherArg.getLoc())
739 return false;
740 // Corresponding bbArgs are equivalent.
741 if (markEquivalent)
742 markEquivalent(curArg, otherArg);
743 }
744
745 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
746 // Check for op equality (recursively).
747 if (!OperationEquivalence::isEquivalentTo(lhs: &lOp, rhs: &rOp, checkEquivalent,
748 markEquivalent, flags,
749 checkCommutativeEquivalent))
750 return false;
751 // Check successor mapping.
752 for (auto successorsPair :
753 llvm::zip(t: lOp.getSuccessors(), u: rOp.getSuccessors())) {
754 Block *curSuccessor = std::get<0>(t&: successorsPair);
755 Block *otherSuccessor = std::get<1>(t&: successorsPair);
756 auto insertion = blocksMap.insert(KV: {curSuccessor, otherSuccessor});
757 if (insertion.first->getSecond() != otherSuccessor)
758 return false;
759 }
760 return true;
761 };
762 return llvm::all_of_zip(argsAndPredicate&: lBlock, argsAndPredicate&: rBlock, argsAndPredicate&: opsEquivalent);
763 };
764 return llvm::all_of_zip(argsAndPredicate&: *lhs, argsAndPredicate&: *rhs, argsAndPredicate&: blocksEquivalent);
765}
766
767// Value equivalence cache to be used with `isRegionEquivalentTo` and
768// `isEquivalentTo`.
769struct ValueEquivalenceCache {
770 DenseMap<Value, Value> equivalentValues;
771 LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
772 return success(IsSuccess: lhsValue == rhsValue ||
773 equivalentValues.lookup(Val: lhsValue) == rhsValue);
774 }
775 LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
776 ValueRange rhsRange) {
777 // Handle simple case where sizes mismatch.
778 if (lhsRange.size() != rhsRange.size())
779 return failure();
780
781 // Handle where operands in order are equivalent.
782 auto lhsIt = lhsRange.begin();
783 auto rhsIt = rhsRange.begin();
784 for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
785 if (failed(Result: checkEquivalent(lhsValue: *lhsIt, rhsValue: *rhsIt)))
786 break;
787 }
788 if (lhsIt == lhsRange.end())
789 return success();
790
791 // Handle another simple case where operands are just a permutation.
792 // Note: This is not sufficient, this handles simple cases relatively
793 // cheaply.
794 auto sortValues = [](ValueRange values) {
795 SmallVector<Value> sortedValues = llvm::to_vector(Range&: values);
796 llvm::sort(C&: sortedValues, Comp: [](Value a, Value b) {
797 return a.getAsOpaquePointer() < b.getAsOpaquePointer();
798 });
799 return sortedValues;
800 };
801 auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
802 auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
803 return success(IsSuccess: lhsSorted == rhsSorted);
804 }
805 void markEquivalent(Value lhsResult, Value rhsResult) {
806 auto insertion = equivalentValues.insert(KV: {lhsResult, rhsResult});
807 // Make sure that the value was not already marked equivalent to some other
808 // value.
809 (void)insertion;
810 assert(insertion.first->second == rhsResult &&
811 "inconsistent OperationEquivalence state");
812 }
813};
814
815/*static*/ bool
816OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
817 OperationEquivalence::Flags flags) {
818 ValueEquivalenceCache cache;
819 return isRegionEquivalentTo(
820 lhs, rhs,
821 checkEquivalent: [&](Value lhsValue, Value rhsValue) -> LogicalResult {
822 return cache.checkEquivalent(lhsValue, rhsValue);
823 },
824 markEquivalent: [&](Value lhsResult, Value rhsResult) {
825 cache.markEquivalent(lhsResult, rhsResult);
826 },
827 flags,
828 checkCommutativeEquivalent: [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
829 return cache.checkCommutativeEquivalent(lhsRange: lhs, rhsRange: rhs);
830 });
831}
832
833/*static*/ bool OperationEquivalence::isEquivalentTo(
834 Operation *lhs, Operation *rhs,
835 function_ref<LogicalResult(Value, Value)> checkEquivalent,
836 function_ref<void(Value, Value)> markEquivalent, Flags flags,
837 function_ref<LogicalResult(ValueRange, ValueRange)>
838 checkCommutativeEquivalent) {
839 if (lhs == rhs)
840 return true;
841
842 // 1. Compare the operation properties.
843 if (!(flags & IgnoreDiscardableAttrs) &&
844 lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs())
845 return false;
846
847 if (lhs->getName() != rhs->getName() ||
848 lhs->getNumRegions() != rhs->getNumRegions() ||
849 lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
850 lhs->getNumOperands() != rhs->getNumOperands() ||
851 lhs->getNumResults() != rhs->getNumResults())
852 return false;
853 if (!(flags & IgnoreProperties) &&
854 !(lhs->getName().compareOpProperties(lhs: lhs->getPropertiesStorage(),
855 rhs: rhs->getPropertiesStorage())))
856 return false;
857 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
858 return false;
859
860 // 2. Compare operands.
861 if (checkCommutativeEquivalent &&
862 lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
863 auto lhsRange = lhs->getOperands();
864 auto rhsRange = rhs->getOperands();
865 if (failed(Result: checkCommutativeEquivalent(lhsRange, rhsRange)))
866 return false;
867 } else {
868 // Check pair wise for equivalence.
869 for (auto operandPair : llvm::zip(t: lhs->getOperands(), u: rhs->getOperands())) {
870 Value curArg = std::get<0>(t&: operandPair);
871 Value otherArg = std::get<1>(t&: operandPair);
872 if (curArg == otherArg)
873 continue;
874 if (curArg.getType() != otherArg.getType())
875 return false;
876 if (failed(Result: checkEquivalent(curArg, otherArg)))
877 return false;
878 }
879 }
880
881 // 3. Compare result types and mark results as equivalent.
882 for (auto resultPair : llvm::zip(t: lhs->getResults(), u: rhs->getResults())) {
883 Value curArg = std::get<0>(t&: resultPair);
884 Value otherArg = std::get<1>(t&: resultPair);
885 if (curArg.getType() != otherArg.getType())
886 return false;
887 if (markEquivalent)
888 markEquivalent(curArg, otherArg);
889 }
890
891 // 4. Compare regions.
892 for (auto regionPair : llvm::zip(t: lhs->getRegions(), u: rhs->getRegions()))
893 if (!isRegionEquivalentTo(lhs: &std::get<0>(t&: regionPair),
894 rhs: &std::get<1>(t&: regionPair), checkEquivalent,
895 markEquivalent, flags))
896 return false;
897
898 return true;
899}
900
901/*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
902 Operation *rhs,
903 Flags flags) {
904 ValueEquivalenceCache cache;
905 return OperationEquivalence::isEquivalentTo(
906 lhs, rhs,
907 checkEquivalent: [&](Value lhsValue, Value rhsValue) -> LogicalResult {
908 return cache.checkEquivalent(lhsValue, rhsValue);
909 },
910 markEquivalent: [&](Value lhsResult, Value rhsResult) {
911 cache.markEquivalent(lhsResult, rhsResult);
912 },
913 flags,
914 checkCommutativeEquivalent: [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
915 return cache.checkCommutativeEquivalent(lhsRange: lhs, rhsRange: rhs);
916 });
917}
918
919//===----------------------------------------------------------------------===//
920// OperationFingerPrint
921//===----------------------------------------------------------------------===//
922
923template <typename T>
924static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
925 hasher.update(
926 Data: ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
927}
928
929OperationFingerPrint::OperationFingerPrint(Operation *topOp,
930 bool includeNested) {
931 llvm::SHA1 hasher;
932
933 // Helper function that hashes an operation based on its mutable bits:
934 auto addOperationToHash = [&](Operation *op) {
935 // - Operation pointer
936 addDataToHash(hasher, data: op);
937 // - Parent operation pointer (to take into account the nesting structure)
938 if (op != topOp)
939 addDataToHash(hasher, data: op->getParentOp());
940 // - Attributes
941 addDataToHash(hasher, data: op->getRawDictionaryAttrs());
942 // - Properties
943 addDataToHash(hasher, data: op->hashProperties());
944 // - Blocks in Regions
945 for (Region &region : op->getRegions()) {
946 for (Block &block : region) {
947 addDataToHash(hasher, data: &block);
948 for (BlockArgument arg : block.getArguments())
949 addDataToHash(hasher, data: arg);
950 }
951 }
952 // - Location
953 addDataToHash(hasher, data: op->getLoc().getAsOpaquePointer());
954 // - Operands
955 for (Value operand : op->getOperands())
956 addDataToHash(hasher, data: operand);
957 // - Successors
958 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
959 addDataToHash(hasher, data: op->getSuccessor(index: i));
960 // - Result types
961 for (Type t : op->getResultTypes())
962 addDataToHash(hasher, data: t);
963 };
964
965 if (includeNested)
966 topOp->walk(callback&: addOperationToHash);
967 else
968 addOperationToHash(topOp);
969
970 hash = hasher.result();
971}
972

source code of mlir/lib/IR/OperationSupport.cpp