1//===- OperationSupport.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains out-of-line implementations of the support types that
10// Operation and related classes build on top of.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/IR/OperationSupport.h"
15#include "mlir/IR/BuiltinAttributes.h"
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/OpDefinition.h"
18#include "llvm/ADT/BitVector.h"
19#include "llvm/Support/SHA1.h"
20#include <numeric>
21#include <optional>
22
23using namespace mlir;
24
25//===----------------------------------------------------------------------===//
26// NamedAttrList
27//===----------------------------------------------------------------------===//
28
29NamedAttrList::NamedAttrList(ArrayRef<NamedAttribute> attributes) {
30 assign(inStart: attributes.begin(), inEnd: attributes.end());
31}
32
33NamedAttrList::NamedAttrList(DictionaryAttr attributes)
34 : NamedAttrList(attributes ? attributes.getValue()
35 : ArrayRef<NamedAttribute>()) {
36 dictionarySorted.setPointerAndInt(PtrVal: attributes, IntVal: true);
37}
38
39NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
40 assign(inStart, inEnd);
41}
42
43ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
44
45std::optional<NamedAttribute> NamedAttrList::findDuplicate() const {
46 std::optional<NamedAttribute> duplicate =
47 DictionaryAttr::findDuplicate(attrs, isSorted());
48 // DictionaryAttr::findDuplicate will sort the list, so reset the sorted
49 // state.
50 if (!isSorted())
51 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
52 return duplicate;
53}
54
55DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
56 if (!isSorted()) {
57 DictionaryAttr::sortInPlace(attrs);
58 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
59 }
60 if (!dictionarySorted.getPointer())
61 dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
62 return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
63}
64
65/// Add an attribute with the specified name.
66void NamedAttrList::append(StringRef name, Attribute attr) {
67 append(StringAttr::get(attr.getContext(), name), attr);
68}
69
70/// Replaces the attributes with new list of attributes.
71void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
72 DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
73 dictionarySorted.setPointerAndInt(PtrVal: nullptr, IntVal: true);
74}
75
76void NamedAttrList::push_back(NamedAttribute newAttribute) {
77 if (isSorted())
78 dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute);
79 dictionarySorted.setPointer(nullptr);
80 attrs.push_back(Elt: newAttribute);
81}
82
83/// Return the specified attribute if present, null otherwise.
84Attribute NamedAttrList::get(StringRef name) const {
85 auto it = findAttr(attrs: *this, name);
86 return it.second ? it.first->getValue() : Attribute();
87}
88Attribute NamedAttrList::get(StringAttr name) const {
89 auto it = findAttr(*this, name);
90 return it.second ? it.first->getValue() : Attribute();
91}
92
93/// Return the specified named attribute if present, std::nullopt otherwise.
94std::optional<NamedAttribute> NamedAttrList::getNamed(StringRef name) const {
95 auto it = findAttr(attrs: *this, name);
96 return it.second ? *it.first : std::optional<NamedAttribute>();
97}
98std::optional<NamedAttribute> NamedAttrList::getNamed(StringAttr name) const {
99 auto it = findAttr(*this, name);
100 return it.second ? *it.first : std::optional<NamedAttribute>();
101}
102
103/// If the an attribute exists with the specified name, change it to the new
104/// value. Otherwise, add a new attribute with the specified name/value.
105Attribute NamedAttrList::set(StringAttr name, Attribute value) {
106 assert(value && "attributes may never be null");
107
108 // Look for an existing attribute with the given name, and set its value
109 // in-place. Return the previous value of the attribute, if there was one.
110 auto it = findAttr(*this, name);
111 if (it.second) {
112 // Update the existing attribute by swapping out the old value for the new
113 // value. Return the old value.
114 Attribute oldValue = it.first->getValue();
115 if (it.first->getValue() != value) {
116 it.first->setValue(value);
117
118 // If the attributes have changed, the dictionary is invalidated.
119 dictionarySorted.setPointer(nullptr);
120 }
121 return oldValue;
122 }
123 // Perform a string lookup to insert the new attribute into its sorted
124 // position.
125 if (isSorted())
126 it = findAttr(*this, name.strref());
127 attrs.insert(it.first, {name, value});
128 // Invalidate the dictionary. Return null as there was no previous value.
129 dictionarySorted.setPointer(nullptr);
130 return Attribute();
131}
132
133Attribute NamedAttrList::set(StringRef name, Attribute value) {
134 assert(value && "attributes may never be null");
135 return set(mlir::StringAttr::get(value.getContext(), name), value);
136}
137
138Attribute
139NamedAttrList::eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it) {
140 // Erasing does not affect the sorted property.
141 Attribute attr = it->getValue();
142 attrs.erase(CI: it);
143 dictionarySorted.setPointer(nullptr);
144 return attr;
145}
146
147Attribute NamedAttrList::erase(StringAttr name) {
148 auto it = findAttr(*this, name);
149 return it.second ? eraseImpl(it: it.first) : Attribute();
150}
151
152Attribute NamedAttrList::erase(StringRef name) {
153 auto it = findAttr(attrs&: *this, name);
154 return it.second ? eraseImpl(it: it.first) : Attribute();
155}
156
157NamedAttrList &
158NamedAttrList::operator=(const SmallVectorImpl<NamedAttribute> &rhs) {
159 assign(inStart: rhs.begin(), inEnd: rhs.end());
160 return *this;
161}
162
163NamedAttrList::operator ArrayRef<NamedAttribute>() const { return attrs; }
164
165//===----------------------------------------------------------------------===//
166// OperationState
167//===----------------------------------------------------------------------===//
168
169OperationState::OperationState(Location location, StringRef name)
170 : location(location), name(name, location->getContext()) {}
171
172OperationState::OperationState(Location location, OperationName name)
173 : location(location), name(name) {}
174
175OperationState::OperationState(Location location, OperationName name,
176 ValueRange operands, TypeRange types,
177 ArrayRef<NamedAttribute> attributes,
178 BlockRange successors,
179 MutableArrayRef<std::unique_ptr<Region>> regions)
180 : location(location), name(name),
181 operands(operands.begin(), operands.end()),
182 types(types.begin(), types.end()),
183 attributes(attributes.begin(), attributes.end()),
184 successors(successors.begin(), successors.end()) {
185 for (std::unique_ptr<Region> &r : regions)
186 this->regions.push_back(Elt: std::move(r));
187}
188OperationState::OperationState(Location location, StringRef name,
189 ValueRange operands, TypeRange types,
190 ArrayRef<NamedAttribute> attributes,
191 BlockRange successors,
192 MutableArrayRef<std::unique_ptr<Region>> regions)
193 : OperationState(location, OperationName(name, location.getContext()),
194 operands, types, attributes, successors, regions) {}
195
196OperationState::~OperationState() {
197 if (properties)
198 propertiesDeleter(properties);
199}
200
201LogicalResult OperationState::setProperties(
202 Operation *op, function_ref<InFlightDiagnostic()> emitError) const {
203 if (LLVM_UNLIKELY(propertiesAttr)) {
204 assert(!properties);
205 return op->setPropertiesFromAttribute(attr: propertiesAttr, emitError);
206 }
207 if (properties)
208 propertiesSetter(op->getPropertiesStorage(), properties);
209 return success();
210}
211
212void OperationState::addOperands(ValueRange newOperands) {
213 operands.append(in_start: newOperands.begin(), in_end: newOperands.end());
214}
215
216void OperationState::addSuccessors(BlockRange newSuccessors) {
217 successors.append(in_start: newSuccessors.begin(), in_end: newSuccessors.end());
218}
219
220Region *OperationState::addRegion() {
221 regions.emplace_back(Args: new Region);
222 return regions.back().get();
223}
224
225void OperationState::addRegion(std::unique_ptr<Region> &&region) {
226 regions.push_back(Elt: std::move(region));
227}
228
229void OperationState::addRegions(
230 MutableArrayRef<std::unique_ptr<Region>> regions) {
231 for (std::unique_ptr<Region> &region : regions)
232 addRegion(region: std::move(region));
233}
234
235//===----------------------------------------------------------------------===//
236// OperandStorage
237//===----------------------------------------------------------------------===//
238
239detail::OperandStorage::OperandStorage(Operation *owner,
240 OpOperand *trailingOperands,
241 ValueRange values)
242 : isStorageDynamic(false), operandStorage(trailingOperands) {
243 numOperands = capacity = values.size();
244 for (unsigned i = 0; i < numOperands; ++i)
245 new (&operandStorage[i]) OpOperand(owner, values[i]);
246}
247
248detail::OperandStorage::~OperandStorage() {
249 for (auto &operand : getOperands())
250 operand.~OpOperand();
251
252 // If the storage is dynamic, deallocate it.
253 if (isStorageDynamic)
254 free(ptr: operandStorage);
255}
256
257/// Replace the operands contained in the storage with the ones provided in
258/// 'values'.
259void detail::OperandStorage::setOperands(Operation *owner, ValueRange values) {
260 MutableArrayRef<OpOperand> storageOperands = resize(owner, newSize: values.size());
261 for (unsigned i = 0, e = values.size(); i != e; ++i)
262 storageOperands[i].set(values[i]);
263}
264
265/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
266/// with the ones provided in 'operands'. 'operands' may be smaller or larger
267/// than the range pointed to by 'start'+'length'.
268void detail::OperandStorage::setOperands(Operation *owner, unsigned start,
269 unsigned length, ValueRange operands) {
270 // If the new size is the same, we can update inplace.
271 unsigned newSize = operands.size();
272 if (newSize == length) {
273 MutableArrayRef<OpOperand> storageOperands = getOperands();
274 for (unsigned i = 0, e = length; i != e; ++i)
275 storageOperands[start + i].set(operands[i]);
276 return;
277 }
278 // If the new size is greater, remove the extra operands and set the rest
279 // inplace.
280 if (newSize < length) {
281 eraseOperands(start: start + operands.size(), length: length - newSize);
282 setOperands(owner, start, length: newSize, operands);
283 return;
284 }
285 // Otherwise, the new size is greater so we need to grow the storage.
286 auto storageOperands = resize(owner, newSize: size() + (newSize - length));
287
288 // Shift operands to the right to make space for the new operands.
289 unsigned rotateSize = storageOperands.size() - (start + length);
290 auto rbegin = storageOperands.rbegin();
291 std::rotate(first: rbegin, middle: std::next(x: rbegin, n: newSize - length), last: rbegin + rotateSize);
292
293 // Update the operands inplace.
294 for (unsigned i = 0, e = operands.size(); i != e; ++i)
295 storageOperands[start + i].set(operands[i]);
296}
297
298/// Erase an operand held by the storage.
299void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
300 MutableArrayRef<OpOperand> operands = getOperands();
301 assert((start + length) <= operands.size());
302 numOperands -= length;
303
304 // Shift all operands down if the operand to remove is not at the end.
305 if (start != numOperands) {
306 auto *indexIt = std::next(x: operands.begin(), n: start);
307 std::rotate(first: indexIt, middle: std::next(x: indexIt, n: length), last: operands.end());
308 }
309 for (unsigned i = 0; i != length; ++i)
310 operands[numOperands + i].~OpOperand();
311}
312
313void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
314 MutableArrayRef<OpOperand> operands = getOperands();
315 assert(eraseIndices.size() == operands.size());
316
317 // Check that at least one operand is erased.
318 int firstErasedIndice = eraseIndices.find_first();
319 if (firstErasedIndice == -1)
320 return;
321
322 // Shift all of the removed operands to the end, and destroy them.
323 numOperands = firstErasedIndice;
324 for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
325 if (!eraseIndices.test(Idx: i))
326 operands[numOperands++] = std::move(operands[i]);
327 for (OpOperand &operand : operands.drop_front(N: numOperands))
328 operand.~OpOperand();
329}
330
331/// Resize the storage to the given size. Returns the array containing the new
332/// operands.
333MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
334 unsigned newSize) {
335 // If the number of operands is less than or equal to the current amount, we
336 // can just update in place.
337 MutableArrayRef<OpOperand> origOperands = getOperands();
338 if (newSize <= numOperands) {
339 // If the number of new size is less than the current, remove any extra
340 // operands.
341 for (unsigned i = newSize; i != numOperands; ++i)
342 origOperands[i].~OpOperand();
343 numOperands = newSize;
344 return origOperands.take_front(N: newSize);
345 }
346
347 // If the new size is within the original inline capacity, grow inplace.
348 if (newSize <= capacity) {
349 OpOperand *opBegin = origOperands.data();
350 for (unsigned e = newSize; numOperands != e; ++numOperands)
351 new (&opBegin[numOperands]) OpOperand(owner);
352 return MutableArrayRef<OpOperand>(opBegin, newSize);
353 }
354
355 // Otherwise, we need to allocate a new storage.
356 unsigned newCapacity =
357 std::max(a: unsigned(llvm::NextPowerOf2(A: capacity + 2)), b: newSize);
358 OpOperand *newOperandStorage =
359 reinterpret_cast<OpOperand *>(malloc(size: sizeof(OpOperand) * newCapacity));
360
361 // Move the current operands to the new storage.
362 MutableArrayRef<OpOperand> newOperands(newOperandStorage, newSize);
363 std::uninitialized_move(first: origOperands.begin(), last: origOperands.end(),
364 result: newOperands.begin());
365
366 // Destroy the original operands.
367 for (auto &operand : origOperands)
368 operand.~OpOperand();
369
370 // Initialize any new operands.
371 for (unsigned e = newSize; numOperands != e; ++numOperands)
372 new (&newOperands[numOperands]) OpOperand(owner);
373
374 // If the current storage is dynamic, free it.
375 if (isStorageDynamic)
376 free(ptr: operandStorage);
377
378 // Update the storage representation to use the new dynamic storage.
379 operandStorage = newOperandStorage;
380 capacity = newCapacity;
381 isStorageDynamic = true;
382 return newOperands;
383}
384
385//===----------------------------------------------------------------------===//
386// Operation Value-Iterators
387//===----------------------------------------------------------------------===//
388
389//===----------------------------------------------------------------------===//
390// OperandRange
391
392unsigned OperandRange::getBeginOperandIndex() const {
393 assert(!empty() && "range must not be empty");
394 return base->getOperandNumber();
395}
396
397OperandRangeRange OperandRange::split(DenseI32ArrayAttr segmentSizes) const {
398 return OperandRangeRange(*this, segmentSizes);
399}
400
401//===----------------------------------------------------------------------===//
402// OperandRangeRange
403
404OperandRangeRange::OperandRangeRange(OperandRange operands,
405 Attribute operandSegments)
406 : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
407 llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
408}
409
410OperandRange OperandRangeRange::join() const {
411 const OwnerT &owner = getBase();
412 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
413 return OperandRange(owner.first,
414 std::accumulate(first: sizeData.begin(), last: sizeData.end(), init: 0));
415}
416
417OperandRange OperandRangeRange::dereference(const OwnerT &object,
418 ptrdiff_t index) {
419 ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
420 uint32_t startIndex =
421 std::accumulate(first: sizeData.begin(), last: sizeData.begin() + index, init: 0);
422 return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
423}
424
425//===----------------------------------------------------------------------===//
426// MutableOperandRange
427
428/// Construct a new mutable range from the given operand, operand start index,
429/// and range length.
430MutableOperandRange::MutableOperandRange(
431 Operation *owner, unsigned start, unsigned length,
432 ArrayRef<OperandSegment> operandSegments)
433 : owner(owner), start(start), length(length),
434 operandSegments(operandSegments.begin(), operandSegments.end()) {
435 assert((start + length) <= owner->getNumOperands() && "invalid range");
436}
437MutableOperandRange::MutableOperandRange(Operation *owner)
438 : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
439
440/// Construct a new mutable range for the given OpOperand.
441MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
442 : MutableOperandRange(opOperand.getOwner(),
443 /*start=*/opOperand.getOperandNumber(),
444 /*length=*/1) {}
445
446/// Slice this range into a sub range, with the additional operand segment.
447MutableOperandRange
448MutableOperandRange::slice(unsigned subStart, unsigned subLen,
449 std::optional<OperandSegment> segment) const {
450 assert((subStart + subLen) <= length && "invalid sub-range");
451 MutableOperandRange subSlice(owner, start + subStart, subLen,
452 operandSegments);
453 if (segment)
454 subSlice.operandSegments.push_back(Elt: *segment);
455 return subSlice;
456}
457
458/// Append the given values to the range.
459void MutableOperandRange::append(ValueRange values) {
460 if (values.empty())
461 return;
462 owner->insertOperands(index: start + length, operands: values);
463 updateLength(newLength: length + values.size());
464}
465
466/// Assign this range to the given values.
467void MutableOperandRange::assign(ValueRange values) {
468 owner->setOperands(start, length, operands: values);
469 if (length != values.size())
470 updateLength(/*newLength=*/values.size());
471}
472
473/// Assign the range to the given value.
474void MutableOperandRange::assign(Value value) {
475 if (length == 1) {
476 owner->setOperand(idx: start, value);
477 } else {
478 owner->setOperands(start, length, operands: value);
479 updateLength(/*newLength=*/1);
480 }
481}
482
483/// Erase the operands within the given sub-range.
484void MutableOperandRange::erase(unsigned subStart, unsigned subLen) {
485 assert((subStart + subLen) <= length && "invalid sub-range");
486 if (length == 0)
487 return;
488 owner->eraseOperands(idx: start + subStart, length: subLen);
489 updateLength(newLength: length - subLen);
490}
491
492/// Clear this range and erase all of the operands.
493void MutableOperandRange::clear() {
494 if (length != 0) {
495 owner->eraseOperands(idx: start, length);
496 updateLength(/*newLength=*/0);
497 }
498}
499
500/// Explicit conversion to an OperandRange.
501OperandRange MutableOperandRange::getAsOperandRange() const {
502 return owner->getOperands().slice(n: start, m: length);
503}
504
505/// Allow implicit conversion to an OperandRange.
506MutableOperandRange::operator OperandRange() const {
507 return getAsOperandRange();
508}
509
510MutableOperandRange::operator MutableArrayRef<OpOperand>() const {
511 return owner->getOpOperands().slice(N: start, M: length);
512}
513
514MutableOperandRangeRange
515MutableOperandRange::split(NamedAttribute segmentSizes) const {
516 return MutableOperandRangeRange(*this, segmentSizes);
517}
518
519/// Update the length of this range to the one provided.
520void MutableOperandRange::updateLength(unsigned newLength) {
521 int32_t diff = int32_t(newLength) - int32_t(length);
522 length = newLength;
523
524 // Update any of the provided segment attributes.
525 for (OperandSegment &segment : operandSegments) {
526 auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
527 SmallVector<int32_t, 8> segments(attr.asArrayRef());
528 segments[segment.first] += diff;
529 segment.second.setValue(
530 DenseI32ArrayAttr::get(attr.getContext(), segments));
531 owner->setAttr(segment.second.getName(), segment.second.getValue());
532 }
533}
534
535OpOperand &MutableOperandRange::operator[](unsigned index) const {
536 assert(index < length && "index is out of bounds");
537 return owner->getOpOperand(idx: start + index);
538}
539
540MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
541 return owner->getOpOperands().slice(N: start, M: length).begin();
542}
543
544MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
545 return owner->getOpOperands().slice(N: start, M: length).end();
546}
547
548//===----------------------------------------------------------------------===//
549// MutableOperandRangeRange
550
551MutableOperandRangeRange::MutableOperandRangeRange(
552 const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
553 : MutableOperandRangeRange(
554 OwnerT(operands, operandSegmentAttr), 0,
555 llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
556}
557
558MutableOperandRange MutableOperandRangeRange::join() const {
559 return getBase().first;
560}
561
562MutableOperandRangeRange::operator OperandRangeRange() const {
563 return OperandRangeRange(getBase().first, getBase().second.getValue());
564}
565
566MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
567 ptrdiff_t index) {
568 ArrayRef<int32_t> sizeData =
569 llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
570 uint32_t startIndex =
571 std::accumulate(first: sizeData.begin(), last: sizeData.begin() + index, init: 0);
572 return object.first.slice(
573 subStart: startIndex, subLen: *(sizeData.begin() + index),
574 segment: MutableOperandRange::OperandSegment(index, object.second));
575}
576
577//===----------------------------------------------------------------------===//
578// ResultRange
579
580ResultRange::ResultRange(OpResult result)
581 : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
582 1) {}
583
584ResultRange::use_range ResultRange::getUses() const {
585 return {use_begin(), use_end()};
586}
587ResultRange::use_iterator ResultRange::use_begin() const {
588 return use_iterator(*this);
589}
590ResultRange::use_iterator ResultRange::use_end() const {
591 return use_iterator(*this, /*end=*/true);
592}
593ResultRange::user_range ResultRange::getUsers() {
594 return {user_begin(), user_end()};
595}
596ResultRange::user_iterator ResultRange::user_begin() {
597 return user_iterator(use_begin());
598}
599ResultRange::user_iterator ResultRange::user_end() {
600 return user_iterator(use_end());
601}
602
603ResultRange::UseIterator::UseIterator(ResultRange results, bool end)
604 : it(end ? results.end() : results.begin()), endIt(results.end()) {
605 // Only initialize current use if there are results/can be uses.
606 if (it != endIt)
607 skipOverResultsWithNoUsers();
608}
609
610ResultRange::UseIterator &ResultRange::UseIterator::operator++() {
611 // We increment over uses, if we reach the last use then move to next
612 // result.
613 if (use != (*it).use_end())
614 ++use;
615 if (use == (*it).use_end()) {
616 ++it;
617 skipOverResultsWithNoUsers();
618 }
619 return *this;
620}
621
622void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
623 while (it != endIt && (*it).use_empty())
624 ++it;
625
626 // If we are at the last result, then set use to first use of
627 // first result (sentinel value used for end).
628 if (it == endIt)
629 use = {};
630 else
631 use = (*it).use_begin();
632}
633
634void ResultRange::replaceAllUsesWith(Operation *op) {
635 replaceAllUsesWith(values: op->getResults());
636}
637
638void ResultRange::replaceUsesWithIf(
639 Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
640 replaceUsesWithIf(values: op->getResults(), shouldReplace);
641}
642
643//===----------------------------------------------------------------------===//
644// ValueRange
645
646ValueRange::ValueRange(ArrayRef<Value> values)
647 : ValueRange(values.data(), values.size()) {}
648ValueRange::ValueRange(OperandRange values)
649 : ValueRange(values.begin().getBase(), values.size()) {}
650ValueRange::ValueRange(ResultRange values)
651 : ValueRange(values.getBase(), values.size()) {}
652
653/// See `llvm::detail::indexed_accessor_range_base` for details.
654ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
655 ptrdiff_t index) {
656 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(Val: owner))
657 return {value + index};
658 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(Val: owner))
659 return {operand + index};
660 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(offset: index);
661}
662/// See `llvm::detail::indexed_accessor_range_base` for details.
663Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
664 if (const auto *value = llvm::dyn_cast_if_present<const Value *>(Val: owner))
665 return value[index];
666 if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(Val: owner))
667 return operand[index].get();
668 return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(offset: index);
669}
670
671//===----------------------------------------------------------------------===//
672// Operation Equivalency
673//===----------------------------------------------------------------------===//
674
675llvm::hash_code OperationEquivalence::computeHash(
676 Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
677 function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
678 // Hash operations based upon their:
679 // - Operation Name
680 // - Attributes
681 // - Result Types
682 llvm::hash_code hash =
683 llvm::hash_combine(args: op->getName(), args: op->getRawDictionaryAttrs(),
684 args: op->getResultTypes(), args: op->hashProperties());
685
686 // - Location if required
687 if (!(flags & Flags::IgnoreLocations))
688 hash = llvm::hash_combine(args: hash, args: op->getLoc());
689
690 // - Operands
691 if (op->hasTrait<mlir::OpTrait::IsCommutative>() &&
692 op->getNumOperands() > 0) {
693 size_t operandHash = hashOperands(op->getOperand(idx: 0));
694 for (auto operand : op->getOperands().drop_front())
695 operandHash += hashOperands(operand);
696 hash = llvm::hash_combine(args: hash, args: operandHash);
697 } else {
698 for (Value operand : op->getOperands())
699 hash = llvm::hash_combine(args: hash, args: hashOperands(operand));
700 }
701
702 // - Results
703 for (Value result : op->getResults())
704 hash = llvm::hash_combine(args: hash, args: hashResults(result));
705 return hash;
706}
707
708/*static*/ bool OperationEquivalence::isRegionEquivalentTo(
709 Region *lhs, Region *rhs,
710 function_ref<LogicalResult(Value, Value)> checkEquivalent,
711 function_ref<void(Value, Value)> markEquivalent,
712 OperationEquivalence::Flags flags,
713 function_ref<LogicalResult(ValueRange, ValueRange)>
714 checkCommutativeEquivalent) {
715 DenseMap<Block *, Block *> blocksMap;
716 auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
717 // Check block arguments.
718 if (lBlock.getNumArguments() != rBlock.getNumArguments())
719 return false;
720
721 // Map the two blocks.
722 auto insertion = blocksMap.insert(KV: {&lBlock, &rBlock});
723 if (insertion.first->getSecond() != &rBlock)
724 return false;
725
726 for (auto argPair :
727 llvm::zip(t: lBlock.getArguments(), u: rBlock.getArguments())) {
728 Value curArg = std::get<0>(t&: argPair);
729 Value otherArg = std::get<1>(t&: argPair);
730 if (curArg.getType() != otherArg.getType())
731 return false;
732 if (!(flags & OperationEquivalence::IgnoreLocations) &&
733 curArg.getLoc() != otherArg.getLoc())
734 return false;
735 // Corresponding bbArgs are equivalent.
736 if (markEquivalent)
737 markEquivalent(curArg, otherArg);
738 }
739
740 auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
741 // Check for op equality (recursively).
742 if (!OperationEquivalence::isEquivalentTo(lhs: &lOp, rhs: &rOp, checkEquivalent,
743 markEquivalent, flags,
744 checkCommutativeEquivalent))
745 return false;
746 // Check successor mapping.
747 for (auto successorsPair :
748 llvm::zip(t: lOp.getSuccessors(), u: rOp.getSuccessors())) {
749 Block *curSuccessor = std::get<0>(t&: successorsPair);
750 Block *otherSuccessor = std::get<1>(t&: successorsPair);
751 auto insertion = blocksMap.insert(KV: {curSuccessor, otherSuccessor});
752 if (insertion.first->getSecond() != otherSuccessor)
753 return false;
754 }
755 return true;
756 };
757 return llvm::all_of_zip(argsAndPredicate&: lBlock, argsAndPredicate&: rBlock, argsAndPredicate&: opsEquivalent);
758 };
759 return llvm::all_of_zip(argsAndPredicate&: *lhs, argsAndPredicate&: *rhs, argsAndPredicate&: blocksEquivalent);
760}
761
762// Value equivalence cache to be used with `isRegionEquivalentTo` and
763// `isEquivalentTo`.
764struct ValueEquivalenceCache {
765 DenseMap<Value, Value> equivalentValues;
766 LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
767 return success(isSuccess: lhsValue == rhsValue ||
768 equivalentValues.lookup(Val: lhsValue) == rhsValue);
769 }
770 LogicalResult checkCommutativeEquivalent(ValueRange lhsRange,
771 ValueRange rhsRange) {
772 // Handle simple case where sizes mismatch.
773 if (lhsRange.size() != rhsRange.size())
774 return failure();
775
776 // Handle where operands in order are equivalent.
777 auto lhsIt = lhsRange.begin();
778 auto rhsIt = rhsRange.begin();
779 for (; lhsIt != lhsRange.end(); ++lhsIt, ++rhsIt) {
780 if (failed(result: checkEquivalent(lhsValue: *lhsIt, rhsValue: *rhsIt)))
781 break;
782 }
783 if (lhsIt == lhsRange.end())
784 return success();
785
786 // Handle another simple case where operands are just a permutation.
787 // Note: This is not sufficient, this handles simple cases relatively
788 // cheaply.
789 auto sortValues = [](ValueRange values) {
790 SmallVector<Value> sortedValues = llvm::to_vector(Range&: values);
791 llvm::sort(C&: sortedValues, Comp: [](Value a, Value b) {
792 return a.getAsOpaquePointer() < b.getAsOpaquePointer();
793 });
794 return sortedValues;
795 };
796 auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
797 auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
798 return success(isSuccess: lhsSorted == rhsSorted);
799 }
800 void markEquivalent(Value lhsResult, Value rhsResult) {
801 auto insertion = equivalentValues.insert(KV: {lhsResult, rhsResult});
802 // Make sure that the value was not already marked equivalent to some other
803 // value.
804 (void)insertion;
805 assert(insertion.first->second == rhsResult &&
806 "inconsistent OperationEquivalence state");
807 }
808};
809
810/*static*/ bool
811OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
812 OperationEquivalence::Flags flags) {
813 ValueEquivalenceCache cache;
814 return isRegionEquivalentTo(
815 lhs, rhs,
816 checkEquivalent: [&](Value lhsValue, Value rhsValue) -> LogicalResult {
817 return cache.checkEquivalent(lhsValue, rhsValue);
818 },
819 markEquivalent: [&](Value lhsResult, Value rhsResult) {
820 cache.markEquivalent(lhsResult, rhsResult);
821 },
822 flags,
823 checkCommutativeEquivalent: [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
824 return cache.checkCommutativeEquivalent(lhsRange: lhs, rhsRange: rhs);
825 });
826}
827
828/*static*/ bool OperationEquivalence::isEquivalentTo(
829 Operation *lhs, Operation *rhs,
830 function_ref<LogicalResult(Value, Value)> checkEquivalent,
831 function_ref<void(Value, Value)> markEquivalent, Flags flags,
832 function_ref<LogicalResult(ValueRange, ValueRange)>
833 checkCommutativeEquivalent) {
834 if (lhs == rhs)
835 return true;
836
837 // 1. Compare the operation properties.
838 if (lhs->getName() != rhs->getName() ||
839 lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() ||
840 lhs->getNumRegions() != rhs->getNumRegions() ||
841 lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
842 lhs->getNumOperands() != rhs->getNumOperands() ||
843 lhs->getNumResults() != rhs->getNumResults() ||
844 !lhs->getName().compareOpProperties(lhs: lhs->getPropertiesStorage(),
845 rhs: rhs->getPropertiesStorage()))
846 return false;
847 if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
848 return false;
849
850 // 2. Compare operands.
851 if (checkCommutativeEquivalent &&
852 lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
853 auto lhsRange = lhs->getOperands();
854 auto rhsRange = rhs->getOperands();
855 if (failed(result: checkCommutativeEquivalent(lhsRange, rhsRange)))
856 return false;
857 } else {
858 // Check pair wise for equivalence.
859 for (auto operandPair : llvm::zip(t: lhs->getOperands(), u: rhs->getOperands())) {
860 Value curArg = std::get<0>(t&: operandPair);
861 Value otherArg = std::get<1>(t&: operandPair);
862 if (curArg == otherArg)
863 continue;
864 if (curArg.getType() != otherArg.getType())
865 return false;
866 if (failed(result: checkEquivalent(curArg, otherArg)))
867 return false;
868 }
869 }
870
871 // 3. Compare result types and mark results as equivalent.
872 for (auto resultPair : llvm::zip(t: lhs->getResults(), u: rhs->getResults())) {
873 Value curArg = std::get<0>(t&: resultPair);
874 Value otherArg = std::get<1>(t&: resultPair);
875 if (curArg.getType() != otherArg.getType())
876 return false;
877 if (markEquivalent)
878 markEquivalent(curArg, otherArg);
879 }
880
881 // 4. Compare regions.
882 for (auto regionPair : llvm::zip(t: lhs->getRegions(), u: rhs->getRegions()))
883 if (!isRegionEquivalentTo(lhs: &std::get<0>(t&: regionPair),
884 rhs: &std::get<1>(t&: regionPair), checkEquivalent,
885 markEquivalent, flags))
886 return false;
887
888 return true;
889}
890
891/*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
892 Operation *rhs,
893 Flags flags) {
894 ValueEquivalenceCache cache;
895 return OperationEquivalence::isEquivalentTo(
896 lhs, rhs,
897 checkEquivalent: [&](Value lhsValue, Value rhsValue) -> LogicalResult {
898 return cache.checkEquivalent(lhsValue, rhsValue);
899 },
900 markEquivalent: [&](Value lhsResult, Value rhsResult) {
901 cache.markEquivalent(lhsResult, rhsResult);
902 },
903 flags,
904 checkCommutativeEquivalent: [&](ValueRange lhs, ValueRange rhs) -> LogicalResult {
905 return cache.checkCommutativeEquivalent(lhsRange: lhs, rhsRange: rhs);
906 });
907}
908
909//===----------------------------------------------------------------------===//
910// OperationFingerPrint
911//===----------------------------------------------------------------------===//
912
913template <typename T>
914static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
915 hasher.update(
916 Data: ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
917}
918
919OperationFingerPrint::OperationFingerPrint(Operation *topOp,
920 bool includeNested) {
921 llvm::SHA1 hasher;
922
923 // Helper function that hashes an operation based on its mutable bits:
924 auto addOperationToHash = [&](Operation *op) {
925 // - Operation pointer
926 addDataToHash(hasher, data: op);
927 // - Parent operation pointer (to take into account the nesting structure)
928 if (op != topOp)
929 addDataToHash(hasher, data: op->getParentOp());
930 // - Attributes
931 addDataToHash(hasher, data: op->getRawDictionaryAttrs());
932 // - Properties
933 addDataToHash(hasher, data: op->hashProperties());
934 // - Blocks in Regions
935 for (Region &region : op->getRegions()) {
936 for (Block &block : region) {
937 addDataToHash(hasher, data: &block);
938 for (BlockArgument arg : block.getArguments())
939 addDataToHash(hasher, data: arg);
940 }
941 }
942 // - Location
943 addDataToHash(hasher, data: op->getLoc().getAsOpaquePointer());
944 // - Operands
945 for (Value operand : op->getOperands())
946 addDataToHash(hasher, data: operand);
947 // - Successors
948 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
949 addDataToHash(hasher, data: op->getSuccessor(index: i));
950 // - Result types
951 for (Type t : op->getResultTypes())
952 addDataToHash(hasher, data: t);
953 };
954
955 if (includeNested)
956 topOp->walk(callback&: addOperationToHash);
957 else
958 addOperationToHash(topOp);
959
960 hash = hasher.result();
961}
962

source code of mlir/lib/IR/OperationSupport.cpp