1//===- ValueRange.h - Indexed Value-Iterators Range Classes -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the ValueRange related classes.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_VALUERANGE_H
14#define MLIR_IR_VALUERANGE_H
15
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/Types.h"
18#include "mlir/IR/Value.h"
19#include "llvm/ADT/PointerUnion.h"
20#include "llvm/ADT/Sequence.h"
21#include <optional>
22
23namespace mlir {
// Forward declarations of range helpers that the classes below refer to
// before their full definitions are available.
class MutableOperandRangeRange;
class OperandRangeRange;
class TypeRangeRange;
class ValueRange;
template <typename ValueIteratorT> class ValueTypeIterator;
template <typename ValueRangeT> class ValueTypeRange;
32
33//===----------------------------------------------------------------------===//
34// Operation Value-Iterators
35//===----------------------------------------------------------------------===//
36
37//===----------------------------------------------------------------------===//
38// OperandRange
39
40/// This class implements the operand iterators for the Operation class.
41class OperandRange final : public llvm::detail::indexed_accessor_range_base<
42 OperandRange, OpOperand *, Value, Value, Value> {
43public:
44 using RangeBaseT::RangeBaseT;
45
46 /// Returns the types of the values within this range.
47 using type_iterator = ValueTypeIterator<iterator>;
48 using type_range = ValueTypeRange<OperandRange>;
49 type_range getTypes() const;
50 type_range getType() const;
51
52 /// Return the operand index of the first element of this range. The range
53 /// must not be empty.
54 unsigned getBeginOperandIndex() const;
55
56 /// Split this range into a set of contiguous subranges using the given
57 /// elements attribute, which contains the sizes of the sub ranges.
58 OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const;
59
60private:
61 /// See `llvm::detail::indexed_accessor_range_base` for details.
62 static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
63 return object + index;
64 }
65 /// See `llvm::detail::indexed_accessor_range_base` for details.
66 static Value dereference_iterator(OpOperand *object, ptrdiff_t index) {
67 return object[index].get();
68 }
69
70 /// Allow access to `offset_base` and `dereference_iterator`.
71 friend RangeBaseT;
72};
73
74//===----------------------------------------------------------------------===//
75// OperandRangeRange
76
77/// This class represents a contiguous range of operand ranges, e.g. from a
78/// VariadicOfVariadic operand group.
79class OperandRangeRange final
80 : public llvm::indexed_accessor_range<
81 OperandRangeRange, std::pair<OpOperand *, Attribute>, OperandRange,
82 OperandRange, OperandRange> {
83 using OwnerT = std::pair<OpOperand *, Attribute>;
84 using RangeBaseT =
85 llvm::indexed_accessor_range<OperandRangeRange, OwnerT, OperandRange,
86 OperandRange, OperandRange>;
87
88public:
89 using RangeBaseT::RangeBaseT;
90
91 /// Returns the range of types of the values within this range.
92 TypeRangeRange getTypes() const;
93 TypeRangeRange getType() const;
94
95 /// Construct a range given a parent set of operands, and an I32 elements
96 /// attribute containing the sizes of the sub ranges.
97 OperandRangeRange(OperandRange operands, Attribute operandSegments);
98
99 /// Flatten all of the sub ranges into a single contiguous operand range.
100 OperandRange join() const;
101
102private:
103 /// See `llvm::indexed_accessor_range` for details.
104 static OperandRange dereference(const OwnerT &object, ptrdiff_t index);
105
106 /// Allow access to `dereference_iterator`.
107 friend RangeBaseT;
108};
109
110//===----------------------------------------------------------------------===//
111// MutableOperandRange
112
113/// This class provides a mutable adaptor for a range of operands. It allows for
114/// setting, inserting, and erasing operands from the given range.
115class MutableOperandRange {
116public:
117 /// A pair of a named attribute corresponding to an operand segment attribute,
118 /// and the index within that attribute. The attribute should correspond to a
119 /// dense i32 array attr.
120 using OperandSegment = std::pair<unsigned, NamedAttribute>;
121
122 /// Construct a new mutable range from the given operand, operand start index,
123 /// and range length. `operandSegments` is an optional set of operand segments
124 /// to be updated when mutating the operand list.
125 MutableOperandRange(Operation *owner, unsigned start, unsigned length,
126 ArrayRef<OperandSegment> operandSegments = std::nullopt);
127 MutableOperandRange(Operation *owner);
128
129 /// Construct a new mutable range for the given OpOperand.
130 MutableOperandRange(OpOperand &opOperand);
131
132 /// Slice this range into a sub range, with the additional operand segment.
133 MutableOperandRange
134 slice(unsigned subStart, unsigned subLen,
135 std::optional<OperandSegment> segment = std::nullopt) const;
136
137 /// Append the given values to the range.
138 void append(ValueRange values);
139
140 /// Assign this range to the given values.
141 void assign(ValueRange values);
142
143 /// Assign the range to the given value.
144 void assign(Value value);
145
146 /// Erase the operands within the given sub-range.
147 void erase(unsigned subStart, unsigned subLen = 1);
148
149 /// Clear this range and erase all of the operands.
150 void clear();
151
152 /// Returns the current size of the range.
153 unsigned size() const { return length; }
154
155 /// Returns if the current range is empty.
156 bool empty() const { return size() == 0; }
157
158 /// Explicit conversion to an OperandRange.
159 OperandRange getAsOperandRange() const;
160
161 /// Allow implicit conversion to an OperandRange.
162 operator OperandRange() const;
163
164 /// Allow implicit conversion to a MutableArrayRef.
165 operator MutableArrayRef<OpOperand>() const;
166
167 /// Returns the owning operation.
168 Operation *getOwner() const { return owner; }
169
170 /// Split this range into a set of contiguous subranges using the given
171 /// elements attribute, which contains the sizes of the sub ranges.
172 MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
173
174 /// Returns the OpOperand at the given index.
175 OpOperand &operator[](unsigned index) const;
176
177 /// Iterators enumerate OpOperands.
178 MutableArrayRef<OpOperand>::iterator begin() const;
179 MutableArrayRef<OpOperand>::iterator end() const;
180
181private:
182 /// Update the length of this range to the one provided.
183 void updateLength(unsigned newLength);
184
185 /// The owning operation of this range.
186 Operation *owner;
187
188 /// The start index of the operand range within the owner operand list, and
189 /// the length starting from `start`.
190 unsigned start, length;
191
192 /// Optional set of operand segments that should be updated when mutating the
193 /// length of this range.
194 SmallVector<OperandSegment, 1> operandSegments;
195};
196
197//===----------------------------------------------------------------------===//
198// MutableOperandRangeRange
199
200/// This class represents a contiguous range of mutable operand ranges, e.g.
201/// from a VariadicOfVariadic operand group.
202class MutableOperandRangeRange final
203 : public llvm::indexed_accessor_range<
204 MutableOperandRangeRange,
205 std::pair<MutableOperandRange, NamedAttribute>, MutableOperandRange,
206 MutableOperandRange, MutableOperandRange> {
207 using OwnerT = std::pair<MutableOperandRange, NamedAttribute>;
208 using RangeBaseT =
209 llvm::indexed_accessor_range<MutableOperandRangeRange, OwnerT,
210 MutableOperandRange, MutableOperandRange,
211 MutableOperandRange>;
212
213public:
214 using RangeBaseT::RangeBaseT;
215
216 /// Construct a range given a parent set of operands, and an I32 tensor
217 /// elements attribute containing the sizes of the sub ranges.
218 MutableOperandRangeRange(const MutableOperandRange &operands,
219 NamedAttribute operandSegmentAttr);
220
221 /// Flatten all of the sub ranges into a single contiguous mutable operand
222 /// range.
223 MutableOperandRange join() const;
224
225 /// Allow implicit conversion to an OperandRangeRange.
226 operator OperandRangeRange() const;
227
228private:
229 /// See `llvm::indexed_accessor_range` for details.
230 static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index);
231
232 /// Allow access to `dereference_iterator`.
233 friend RangeBaseT;
234};
235
236//===----------------------------------------------------------------------===//
237// ResultRange
238
239/// This class implements the result iterators for the Operation class.
240class ResultRange final
241 : public llvm::detail::indexed_accessor_range_base<
242 ResultRange, detail::OpResultImpl *, OpResult, OpResult, OpResult> {
243public:
244 using RangeBaseT::RangeBaseT;
245 ResultRange(OpResult result);
246
247 //===--------------------------------------------------------------------===//
248 // Types
249 //===--------------------------------------------------------------------===//
250
251 /// Returns the types of the values within this range.
252 using type_iterator = ValueTypeIterator<iterator>;
253 using type_range = ValueTypeRange<ResultRange>;
254 type_range getTypes() const;
255 type_range getType() const;
256
257 //===--------------------------------------------------------------------===//
258 // Uses
259 //===--------------------------------------------------------------------===//
260
261 class UseIterator;
262 using use_iterator = UseIterator;
263 using use_range = iterator_range<use_iterator>;
264
265 /// Returns a range of all uses of results within this range, which is useful
266 /// for iterating over all uses.
267 use_range getUses() const;
268 use_iterator use_begin() const;
269 use_iterator use_end() const;
270
271 /// Returns true if no results in this range have uses.
272 bool use_empty() const {
273 return llvm::all_of(Range: *this,
274 P: [](OpResult result) { return result.use_empty(); });
275 }
276
277 /// Replace all uses of results of this range with the provided 'values'. The
278 /// size of `values` must match the size of this range.
279 template <typename ValuesT>
280 std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
281 replaceAllUsesWith(ValuesT &&values) {
282 assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
283 size() &&
284 "expected 'values' to correspond 1-1 with the number of results");
285
286 for (auto it : llvm::zip(*this, values))
287 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
288 }
289
290 /// Replace all uses of results of this range with results of 'op'.
291 void replaceAllUsesWith(Operation *op);
292
293 /// Replace uses of results of this range with the provided 'values' if the
294 /// given callback returns true. The size of `values` must match the size of
295 /// this range.
296 template <typename ValuesT>
297 std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
298 replaceUsesWithIf(ValuesT &&values,
299 function_ref<bool(OpOperand &)> shouldReplace) {
300 assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
301 size() &&
302 "expected 'values' to correspond 1-1 with the number of results");
303
304 for (auto it : llvm::zip(*this, values))
305 std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace);
306 }
307
308 /// Replace uses of results of this range with results of `op` if the given
309 /// callback returns true.
310 void replaceUsesWithIf(Operation *op,
311 function_ref<bool(OpOperand &)> shouldReplace);
312
313 //===--------------------------------------------------------------------===//
314 // Users
315 //===--------------------------------------------------------------------===//
316
317 using user_iterator = ValueUserIterator<use_iterator, OpOperand>;
318 using user_range = iterator_range<user_iterator>;
319
320 /// Returns a range of all users.
321 user_range getUsers();
322 user_iterator user_begin();
323 user_iterator user_end();
324
325private:
326 /// See `llvm::detail::indexed_accessor_range_base` for details.
327 static detail::OpResultImpl *offset_base(detail::OpResultImpl *object,
328 ptrdiff_t index) {
329 return object->getNextResultAtOffset(offset: index);
330 }
331 /// See `llvm::detail::indexed_accessor_range_base` for details.
332 static OpResult dereference_iterator(detail::OpResultImpl *object,
333 ptrdiff_t index) {
334 return offset_base(object, index);
335 }
336
337 /// Allow access to `offset_base` and `dereference_iterator`.
338 friend RangeBaseT;
339};
340
341/// This class implements a use iterator for a range of operation results.
342/// This iterates over all uses of all results within the given result range.
343class ResultRange::UseIterator final
344 : public llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag,
345 OpOperand> {
346public:
347 /// Initialize the UseIterator. Specify `end` to return iterator to last
348 /// use, otherwise this is an iterator to the first use.
349 explicit UseIterator(ResultRange results, bool end = false);
350
351 using llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag,
352 OpOperand>::operator++;
353 UseIterator &operator++();
354 OpOperand *operator->() const { return use.getOperand(); }
355 OpOperand &operator*() const { return *use.getOperand(); }
356
357 bool operator==(const UseIterator &rhs) const { return use == rhs.use; }
358 bool operator!=(const UseIterator &rhs) const { return !(*this == rhs); }
359
360private:
361 void skipOverResultsWithNoUsers();
362
363 /// The range of results being iterated over.
364 ResultRange::iterator it, endIt;
365 /// The use of the result.
366 Value::use_iterator use;
367};
368
369//===----------------------------------------------------------------------===//
370// ValueRange
371
372/// This class provides an abstraction over the different types of ranges over
373/// Values. In many cases, this prevents the need to explicitly materialize a
374/// SmallVector/std::vector. This class should be used in places that are not
375/// suitable for a more derived type (e.g. ArrayRef) or a template range
376/// parameter.
377class ValueRange final
378 : public llvm::detail::indexed_accessor_range_base<
379 ValueRange,
380 PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
381 Value, Value, Value> {
382public:
383 /// The type representing the owner of a ValueRange. This is either a list of
384 /// values, operands, or results.
385 using OwnerT =
386 PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
387
388 using RangeBaseT::RangeBaseT;
389
390 template <typename Arg,
391 typename = std::enable_if_t<
392 std::is_constructible<ArrayRef<Value>, Arg>::value &&
393 !std::is_convertible<Arg, Value>::value>>
394 ValueRange(Arg &&arg) : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {}
395 ValueRange(const Value &value) : ValueRange(&value, /*count=*/1) {}
396 ValueRange(const std::initializer_list<Value> &values)
397 : ValueRange(ArrayRef<Value>(values)) {}
398 ValueRange(iterator_range<OperandRange::iterator> values)
399 : ValueRange(OperandRange(values)) {}
400 ValueRange(iterator_range<ResultRange::iterator> values)
401 : ValueRange(ResultRange(values)) {}
402 ValueRange(ArrayRef<BlockArgument> values)
403 : ValueRange(ArrayRef<Value>(values.data(), values.size())) {}
404 ValueRange(ArrayRef<Value> values = std::nullopt);
405 ValueRange(OperandRange values);
406 ValueRange(ResultRange values);
407
408 /// Returns the types of the values within this range.
409 using type_iterator = ValueTypeIterator<iterator>;
410 using type_range = ValueTypeRange<ValueRange>;
411 type_range getTypes() const;
412 type_range getType() const;
413
414private:
415 /// See `llvm::detail::indexed_accessor_range_base` for details.
416 static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
417 /// See `llvm::detail::indexed_accessor_range_base` for details.
418 static Value dereference_iterator(const OwnerT &owner, ptrdiff_t index);
419
420 /// Allow access to `offset_base` and `dereference_iterator`.
421 friend RangeBaseT;
422};
423
424} // namespace mlir
425
426#endif // MLIR_IR_VALUERANGE_H
427

// source code of mlir/include/mlir/IR/ValueRange.h