1 | //===- ValueRange.h - Indexed Value-Iterators Range Classes -----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the ValueRange related classes. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_VALUERANGE_H |
14 | #define MLIR_IR_VALUERANGE_H |
15 | |
16 | #include "mlir/IR/BuiltinAttributes.h" |
17 | #include "mlir/IR/Types.h" |
18 | #include "mlir/IR/Value.h" |
19 | #include "llvm/ADT/PointerUnion.h" |
20 | #include "llvm/ADT/Sequence.h" |
21 | #include <optional> |
22 | |
23 | namespace mlir { |
// Forward declarations for range/iterator types that the classes below refer
// to before (or without) their full definition being visible in this header.
class MutableOperandRangeRange;
class OperandRangeRange;
class TypeRangeRange;
class ValueRange;
template <typename ValueIteratorT>
class ValueTypeIterator;
template <typename ValueRangeT>
class ValueTypeRange;
32 | |
33 | //===----------------------------------------------------------------------===// |
34 | // Operation Value-Iterators |
35 | //===----------------------------------------------------------------------===// |
36 | |
37 | //===----------------------------------------------------------------------===// |
38 | // OperandRange |
39 | |
40 | /// This class implements the operand iterators for the Operation class. |
41 | class OperandRange final : public llvm::detail::indexed_accessor_range_base< |
42 | OperandRange, OpOperand *, Value, Value, Value> { |
43 | public: |
44 | using RangeBaseT::RangeBaseT; |
45 | |
46 | /// Returns the types of the values within this range. |
47 | using type_iterator = ValueTypeIterator<iterator>; |
48 | using type_range = ValueTypeRange<OperandRange>; |
49 | type_range getTypes() const; |
50 | type_range getType() const; |
51 | |
52 | /// Return the operand index of the first element of this range. The range |
53 | /// must not be empty. |
54 | unsigned getBeginOperandIndex() const; |
55 | |
56 | /// Split this range into a set of contiguous subranges using the given |
57 | /// elements attribute, which contains the sizes of the sub ranges. |
58 | OperandRangeRange split(DenseI32ArrayAttr segmentSizes) const; |
59 | |
60 | private: |
61 | /// See `llvm::detail::indexed_accessor_range_base` for details. |
62 | static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) { |
63 | return object + index; |
64 | } |
65 | /// See `llvm::detail::indexed_accessor_range_base` for details. |
66 | static Value dereference_iterator(OpOperand *object, ptrdiff_t index) { |
67 | return object[index].get(); |
68 | } |
69 | |
70 | /// Allow access to `offset_base` and `dereference_iterator`. |
71 | friend RangeBaseT; |
72 | }; |
73 | |
74 | //===----------------------------------------------------------------------===// |
75 | // OperandRangeRange |
76 | |
77 | /// This class represents a contiguous range of operand ranges, e.g. from a |
78 | /// VariadicOfVariadic operand group. |
79 | class OperandRangeRange final |
80 | : public llvm::indexed_accessor_range< |
81 | OperandRangeRange, std::pair<OpOperand *, Attribute>, OperandRange, |
82 | OperandRange, OperandRange> { |
83 | using OwnerT = std::pair<OpOperand *, Attribute>; |
84 | using RangeBaseT = |
85 | llvm::indexed_accessor_range<OperandRangeRange, OwnerT, OperandRange, |
86 | OperandRange, OperandRange>; |
87 | |
88 | public: |
89 | using RangeBaseT::RangeBaseT; |
90 | |
91 | /// Returns the range of types of the values within this range. |
92 | TypeRangeRange getTypes() const; |
93 | TypeRangeRange getType() const; |
94 | |
95 | /// Construct a range given a parent set of operands, and an I32 elements |
96 | /// attribute containing the sizes of the sub ranges. |
97 | OperandRangeRange(OperandRange operands, Attribute operandSegments); |
98 | |
99 | /// Flatten all of the sub ranges into a single contiguous operand range. |
100 | OperandRange join() const; |
101 | |
102 | private: |
103 | /// See `llvm::indexed_accessor_range` for details. |
104 | static OperandRange dereference(const OwnerT &object, ptrdiff_t index); |
105 | |
106 | /// Allow access to `dereference_iterator`. |
107 | friend RangeBaseT; |
108 | }; |
109 | |
110 | //===----------------------------------------------------------------------===// |
111 | // MutableOperandRange |
112 | |
113 | /// This class provides a mutable adaptor for a range of operands. It allows for |
114 | /// setting, inserting, and erasing operands from the given range. |
115 | class MutableOperandRange { |
116 | public: |
117 | /// A pair of a named attribute corresponding to an operand segment attribute, |
118 | /// and the index within that attribute. The attribute should correspond to a |
119 | /// dense i32 array attr. |
120 | using OperandSegment = std::pair<unsigned, NamedAttribute>; |
121 | |
122 | /// Construct a new mutable range from the given operand, operand start index, |
123 | /// and range length. `operandSegments` is an optional set of operand segments |
124 | /// to be updated when mutating the operand list. |
125 | MutableOperandRange(Operation *owner, unsigned start, unsigned length, |
126 | ArrayRef<OperandSegment> operandSegments = std::nullopt); |
127 | MutableOperandRange(Operation *owner); |
128 | |
129 | /// Construct a new mutable range for the given OpOperand. |
130 | MutableOperandRange(OpOperand &opOperand); |
131 | |
132 | /// Slice this range into a sub range, with the additional operand segment. |
133 | MutableOperandRange |
134 | slice(unsigned subStart, unsigned subLen, |
135 | std::optional<OperandSegment> segment = std::nullopt) const; |
136 | |
137 | /// Append the given values to the range. |
138 | void append(ValueRange values); |
139 | |
140 | /// Assign this range to the given values. |
141 | void assign(ValueRange values); |
142 | |
143 | /// Assign the range to the given value. |
144 | void assign(Value value); |
145 | |
146 | /// Erase the operands within the given sub-range. |
147 | void erase(unsigned subStart, unsigned subLen = 1); |
148 | |
149 | /// Clear this range and erase all of the operands. |
150 | void clear(); |
151 | |
152 | /// Returns the current size of the range. |
153 | unsigned size() const { return length; } |
154 | |
155 | /// Returns if the current range is empty. |
156 | bool empty() const { return size() == 0; } |
157 | |
158 | /// Explicit conversion to an OperandRange. |
159 | OperandRange getAsOperandRange() const; |
160 | |
161 | /// Allow implicit conversion to an OperandRange. |
162 | operator OperandRange() const; |
163 | |
164 | /// Allow implicit conversion to a MutableArrayRef. |
165 | operator MutableArrayRef<OpOperand>() const; |
166 | |
167 | /// Returns the owning operation. |
168 | Operation *getOwner() const { return owner; } |
169 | |
170 | /// Split this range into a set of contiguous subranges using the given |
171 | /// elements attribute, which contains the sizes of the sub ranges. |
172 | MutableOperandRangeRange split(NamedAttribute segmentSizes) const; |
173 | |
174 | /// Returns the OpOperand at the given index. |
175 | OpOperand &operator[](unsigned index) const; |
176 | |
177 | /// Iterators enumerate OpOperands. |
178 | MutableArrayRef<OpOperand>::iterator begin() const; |
179 | MutableArrayRef<OpOperand>::iterator end() const; |
180 | |
181 | private: |
182 | /// Update the length of this range to the one provided. |
183 | void updateLength(unsigned newLength); |
184 | |
185 | /// The owning operation of this range. |
186 | Operation *owner; |
187 | |
188 | /// The start index of the operand range within the owner operand list, and |
189 | /// the length starting from `start`. |
190 | unsigned start, length; |
191 | |
192 | /// Optional set of operand segments that should be updated when mutating the |
193 | /// length of this range. |
194 | SmallVector<OperandSegment, 1> operandSegments; |
195 | }; |
196 | |
197 | //===----------------------------------------------------------------------===// |
198 | // MutableOperandRangeRange |
199 | |
200 | /// This class represents a contiguous range of mutable operand ranges, e.g. |
201 | /// from a VariadicOfVariadic operand group. |
202 | class MutableOperandRangeRange final |
203 | : public llvm::indexed_accessor_range< |
204 | MutableOperandRangeRange, |
205 | std::pair<MutableOperandRange, NamedAttribute>, MutableOperandRange, |
206 | MutableOperandRange, MutableOperandRange> { |
207 | using OwnerT = std::pair<MutableOperandRange, NamedAttribute>; |
208 | using RangeBaseT = |
209 | llvm::indexed_accessor_range<MutableOperandRangeRange, OwnerT, |
210 | MutableOperandRange, MutableOperandRange, |
211 | MutableOperandRange>; |
212 | |
213 | public: |
214 | using RangeBaseT::RangeBaseT; |
215 | |
216 | /// Construct a range given a parent set of operands, and an I32 tensor |
217 | /// elements attribute containing the sizes of the sub ranges. |
218 | MutableOperandRangeRange(const MutableOperandRange &operands, |
219 | NamedAttribute operandSegmentAttr); |
220 | |
221 | /// Flatten all of the sub ranges into a single contiguous mutable operand |
222 | /// range. |
223 | MutableOperandRange join() const; |
224 | |
225 | /// Allow implicit conversion to an OperandRangeRange. |
226 | operator OperandRangeRange() const; |
227 | |
228 | private: |
229 | /// See `llvm::indexed_accessor_range` for details. |
230 | static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index); |
231 | |
232 | /// Allow access to `dereference_iterator`. |
233 | friend RangeBaseT; |
234 | }; |
235 | |
236 | //===----------------------------------------------------------------------===// |
237 | // ResultRange |
238 | |
239 | /// This class implements the result iterators for the Operation class. |
240 | class ResultRange final |
241 | : public llvm::detail::indexed_accessor_range_base< |
242 | ResultRange, detail::OpResultImpl *, OpResult, OpResult, OpResult> { |
243 | public: |
244 | using RangeBaseT::RangeBaseT; |
245 | ResultRange(OpResult result); |
246 | |
247 | //===--------------------------------------------------------------------===// |
248 | // Types |
249 | //===--------------------------------------------------------------------===// |
250 | |
251 | /// Returns the types of the values within this range. |
252 | using type_iterator = ValueTypeIterator<iterator>; |
253 | using type_range = ValueTypeRange<ResultRange>; |
254 | type_range getTypes() const; |
255 | type_range getType() const; |
256 | |
257 | //===--------------------------------------------------------------------===// |
258 | // Uses |
259 | //===--------------------------------------------------------------------===// |
260 | |
261 | class UseIterator; |
262 | using use_iterator = UseIterator; |
263 | using use_range = iterator_range<use_iterator>; |
264 | |
265 | /// Returns a range of all uses of results within this range, which is useful |
266 | /// for iterating over all uses. |
267 | use_range getUses() const; |
268 | use_iterator use_begin() const; |
269 | use_iterator use_end() const; |
270 | |
271 | /// Returns true if no results in this range have uses. |
272 | bool use_empty() const { |
273 | return llvm::all_of(Range: *this, |
274 | P: [](OpResult result) { return result.use_empty(); }); |
275 | } |
276 | |
277 | /// Replace all uses of results of this range with the provided 'values'. The |
278 | /// size of `values` must match the size of this range. |
279 | template <typename ValuesT> |
280 | std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value> |
281 | replaceAllUsesWith(ValuesT &&values) { |
282 | assert(static_cast<size_t>(std::distance(values.begin(), values.end())) == |
283 | size() && |
284 | "expected 'values' to correspond 1-1 with the number of results" ); |
285 | |
286 | for (auto it : llvm::zip(*this, values)) |
287 | std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); |
288 | } |
289 | |
290 | /// Replace all uses of results of this range with results of 'op'. |
291 | void replaceAllUsesWith(Operation *op); |
292 | |
293 | /// Replace uses of results of this range with the provided 'values' if the |
294 | /// given callback returns true. The size of `values` must match the size of |
295 | /// this range. |
296 | template <typename ValuesT> |
297 | std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value> |
298 | replaceUsesWithIf(ValuesT &&values, |
299 | function_ref<bool(OpOperand &)> shouldReplace) { |
300 | assert(static_cast<size_t>(std::distance(values.begin(), values.end())) == |
301 | size() && |
302 | "expected 'values' to correspond 1-1 with the number of results" ); |
303 | |
304 | for (auto it : llvm::zip(*this, values)) |
305 | std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace); |
306 | } |
307 | |
308 | /// Replace uses of results of this range with results of `op` if the given |
309 | /// callback returns true. |
310 | void replaceUsesWithIf(Operation *op, |
311 | function_ref<bool(OpOperand &)> shouldReplace); |
312 | |
313 | //===--------------------------------------------------------------------===// |
314 | // Users |
315 | //===--------------------------------------------------------------------===// |
316 | |
317 | using user_iterator = ValueUserIterator<use_iterator, OpOperand>; |
318 | using user_range = iterator_range<user_iterator>; |
319 | |
320 | /// Returns a range of all users. |
321 | user_range getUsers(); |
322 | user_iterator user_begin(); |
323 | user_iterator user_end(); |
324 | |
325 | private: |
326 | /// See `llvm::detail::indexed_accessor_range_base` for details. |
327 | static detail::OpResultImpl *offset_base(detail::OpResultImpl *object, |
328 | ptrdiff_t index) { |
329 | return object->getNextResultAtOffset(offset: index); |
330 | } |
331 | /// See `llvm::detail::indexed_accessor_range_base` for details. |
332 | static OpResult dereference_iterator(detail::OpResultImpl *object, |
333 | ptrdiff_t index) { |
334 | return offset_base(object, index); |
335 | } |
336 | |
337 | /// Allow access to `offset_base` and `dereference_iterator`. |
338 | friend RangeBaseT; |
339 | }; |
340 | |
341 | /// This class implements a use iterator for a range of operation results. |
342 | /// This iterates over all uses of all results within the given result range. |
343 | class ResultRange::UseIterator final |
344 | : public llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag, |
345 | OpOperand> { |
346 | public: |
347 | /// Initialize the UseIterator. Specify `end` to return iterator to last |
348 | /// use, otherwise this is an iterator to the first use. |
349 | explicit UseIterator(ResultRange results, bool end = false); |
350 | |
351 | using llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag, |
352 | OpOperand>::operator++; |
353 | UseIterator &operator++(); |
354 | OpOperand *operator->() const { return use.getOperand(); } |
355 | OpOperand &operator*() const { return *use.getOperand(); } |
356 | |
357 | bool operator==(const UseIterator &rhs) const { return use == rhs.use; } |
358 | bool operator!=(const UseIterator &rhs) const { return !(*this == rhs); } |
359 | |
360 | private: |
361 | void skipOverResultsWithNoUsers(); |
362 | |
363 | /// The range of results being iterated over. |
364 | ResultRange::iterator it, endIt; |
365 | /// The use of the result. |
366 | Value::use_iterator use; |
367 | }; |
368 | |
369 | //===----------------------------------------------------------------------===// |
370 | // ValueRange |
371 | |
372 | /// This class provides an abstraction over the different types of ranges over |
373 | /// Values. In many cases, this prevents the need to explicitly materialize a |
374 | /// SmallVector/std::vector. This class should be used in places that are not |
375 | /// suitable for a more derived type (e.g. ArrayRef) or a template range |
376 | /// parameter. |
377 | class ValueRange final |
378 | : public llvm::detail::indexed_accessor_range_base< |
379 | ValueRange, |
380 | PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>, |
381 | Value, Value, Value> { |
382 | public: |
383 | /// The type representing the owner of a ValueRange. This is either a list of |
384 | /// values, operands, or results. |
385 | using OwnerT = |
386 | PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>; |
387 | |
388 | using RangeBaseT::RangeBaseT; |
389 | |
390 | template <typename Arg, |
391 | typename = std::enable_if_t< |
392 | std::is_constructible<ArrayRef<Value>, Arg>::value && |
393 | !std::is_convertible<Arg, Value>::value>> |
394 | ValueRange(Arg &&arg) : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {} |
395 | ValueRange(const Value &value) : ValueRange(&value, /*count=*/1) {} |
396 | ValueRange(const std::initializer_list<Value> &values) |
397 | : ValueRange(ArrayRef<Value>(values)) {} |
398 | ValueRange(iterator_range<OperandRange::iterator> values) |
399 | : ValueRange(OperandRange(values)) {} |
400 | ValueRange(iterator_range<ResultRange::iterator> values) |
401 | : ValueRange(ResultRange(values)) {} |
402 | ValueRange(ArrayRef<BlockArgument> values) |
403 | : ValueRange(ArrayRef<Value>(values.data(), values.size())) {} |
404 | ValueRange(ArrayRef<Value> values = std::nullopt); |
405 | ValueRange(OperandRange values); |
406 | ValueRange(ResultRange values); |
407 | |
408 | /// Returns the types of the values within this range. |
409 | using type_iterator = ValueTypeIterator<iterator>; |
410 | using type_range = ValueTypeRange<ValueRange>; |
411 | type_range getTypes() const; |
412 | type_range getType() const; |
413 | |
414 | private: |
415 | /// See `llvm::detail::indexed_accessor_range_base` for details. |
416 | static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); |
417 | /// See `llvm::detail::indexed_accessor_range_base` for details. |
418 | static Value dereference_iterator(const OwnerT &owner, ptrdiff_t index); |
419 | |
420 | /// Allow access to `offset_base` and `dereference_iterator`. |
421 | friend RangeBaseT; |
422 | }; |
423 | |
424 | } // namespace mlir |
425 | |
426 | #endif // MLIR_IR_VALUERANGE_H |
427 | |