1 | //===- TypeRange.h ----------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the TypeRange and ValueTypeRange classes. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_TYPERANGE_H |
14 | #define MLIR_IR_TYPERANGE_H |
15 | |
16 | #include "mlir/IR/Types.h" |
17 | #include "mlir/IR/Value.h" |
18 | #include "mlir/IR/ValueRange.h" |
19 | #include "llvm/ADT/PointerUnion.h" |
20 | #include "llvm/ADT/Sequence.h" |
21 | |
22 | namespace mlir { |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // TypeRange |
26 | |
27 | /// This class provides an abstraction over the various different ranges of |
28 | /// value types. In many cases, this prevents the need to explicitly materialize |
29 | /// a SmallVector/std::vector. This class should be used in places that are not |
30 | /// suitable for a more derived type (e.g. ArrayRef) or a template range |
31 | /// parameter. |
32 | class TypeRange : public llvm::detail::indexed_accessor_range_base< |
33 | TypeRange, |
34 | llvm::PointerUnion<const Value *, const Type *, |
35 | OpOperand *, detail::OpResultImpl *>, |
36 | Type, Type, Type> { |
37 | public: |
38 | using RangeBaseT::RangeBaseT; |
39 | TypeRange(ArrayRef<Type> types = std::nullopt); |
40 | explicit TypeRange(OperandRange values); |
41 | explicit TypeRange(ResultRange values); |
42 | explicit TypeRange(ValueRange values); |
43 | template <typename ValueRangeT> |
44 | TypeRange(ValueTypeRange<ValueRangeT> values) |
45 | : TypeRange(ValueRange(ValueRangeT(values.begin().getCurrent(), |
46 | values.end().getCurrent()))) {} |
47 | template <typename Arg, typename = std::enable_if_t<std::is_constructible< |
48 | ArrayRef<Type>, Arg>::value>> |
49 | TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {} |
50 | TypeRange(std::initializer_list<Type> types) |
51 | : TypeRange(ArrayRef<Type>(types)) {} |
52 | |
53 | private: |
54 | /// The owner of the range is either: |
55 | /// * A pointer to the first element of an array of values. |
56 | /// * A pointer to the first element of an array of types. |
57 | /// * A pointer to the first element of an array of operands. |
58 | /// * A pointer to the first element of an array of results. |
59 | using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *, |
60 | detail::OpResultImpl *>; |
61 | |
62 | /// See `llvm::detail::indexed_accessor_range_base` for details. |
63 | static OwnerT offset_base(OwnerT object, ptrdiff_t index); |
64 | /// See `llvm::detail::indexed_accessor_range_base` for details. |
65 | static Type dereference_iterator(OwnerT object, ptrdiff_t index); |
66 | |
67 | /// Allow access to `offset_base` and `dereference_iterator`. |
68 | friend RangeBaseT; |
69 | }; |
70 | |
71 | /// Make TypeRange hashable. |
72 | inline ::llvm::hash_code hash_value(TypeRange arg) { |
73 | return ::llvm::hash_combine_range(first: arg.begin(), last: arg.end()); |
74 | } |
75 | |
76 | /// Emit a type range to the given output stream. |
77 | inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) { |
78 | llvm::interleaveComma(c: types, os); |
79 | return os; |
80 | } |
81 | |
82 | //===----------------------------------------------------------------------===// |
83 | // TypeRangeRange |
84 | |
85 | using TypeRangeRangeIterator = |
86 | llvm::mapped_iterator<llvm::iota_range<unsigned>::iterator, |
87 | std::function<TypeRange(unsigned)>>; |
88 | |
89 | /// This class provides an abstraction for a range of TypeRange. This is useful |
90 | /// when accessing the types of a range of ranges, such as when using |
91 | /// OperandRangeRange. |
92 | class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> { |
93 | public: |
94 | template <typename RangeT> |
95 | TypeRangeRange(const RangeT &range) |
96 | : TypeRangeRange(llvm::seq<unsigned>(0, range.size()), range) {} |
97 | |
98 | private: |
99 | template <typename RangeT> |
100 | TypeRangeRange(llvm::iota_range<unsigned> sizeRange, const RangeT &range) |
101 | : llvm::iterator_range<TypeRangeRangeIterator>( |
102 | {sizeRange.begin(), getRangeFn(range)}, |
103 | {sizeRange.end(), nullptr}) {} |
104 | |
105 | template <typename RangeT> |
106 | static std::function<TypeRange(unsigned)> getRangeFn(const RangeT &range) { |
107 | return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); }; |
108 | } |
109 | }; |
110 | |
111 | //===----------------------------------------------------------------------===// |
112 | // ValueTypeRange |
113 | |
114 | /// This class implements iteration on the types of a given range of values. |
115 | template <typename ValueIteratorT> |
116 | class ValueTypeIterator final |
117 | : public llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>, |
118 | ValueIteratorT, Type> { |
119 | public: |
120 | using llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>, |
121 | ValueIteratorT, Type>::mapped_iterator_base; |
122 | |
123 | /// Map the element to the iterator result type. |
124 | Type mapElement(Value value) const { return value.getType(); } |
125 | }; |
126 | |
127 | /// This class implements iteration on the types of a given range of values. |
128 | template <typename ValueRangeT> |
129 | class ValueTypeRange final |
130 | : public llvm::iterator_range< |
131 | ValueTypeIterator<typename ValueRangeT::iterator>> { |
132 | public: |
133 | using llvm::iterator_range< |
134 | ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range; |
135 | template <typename Container> |
136 | ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {} |
137 | |
138 | /// Return the type at the given index. |
139 | Type operator[](size_t index) const { |
140 | assert(index < size() && "invalid index into type range" ); |
141 | return *(this->begin() + index); |
142 | } |
143 | |
144 | /// Return the size of this range. |
145 | size_t size() const { return llvm::size(*this); } |
146 | |
147 | /// Return first type in the range. |
148 | Type front() { return (*this)[0]; } |
149 | |
150 | /// Compare this range with another. |
151 | template <typename OtherT> |
152 | bool operator==(const OtherT &other) const { |
153 | return llvm::size(*this) == llvm::size(other) && |
154 | std::equal(this->begin(), this->end(), other.begin()); |
155 | } |
156 | template <typename OtherT> |
157 | bool operator!=(const OtherT &other) const { |
158 | return !(*this == other); |
159 | } |
160 | }; |
161 | |
162 | template <typename RangeT> |
163 | inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) { |
164 | return lhs.size() == static_cast<size_t>(llvm::size(rhs)) && |
165 | std::equal(lhs.begin(), lhs.end(), rhs.begin()); |
166 | } |
167 | |
168 | //===----------------------------------------------------------------------===// |
169 | // SubElements |
170 | //===----------------------------------------------------------------------===// |
171 | |
172 | /// Enable TypeRange to be introspected for sub-elements. |
173 | template <> |
174 | struct AttrTypeSubElementHandler<TypeRange> { |
175 | static void walk(TypeRange param, AttrTypeImmediateSubElementWalker &walker) { |
176 | walker.walkRange(elements&: param); |
177 | } |
178 | static TypeRange replace(TypeRange param, |
179 | AttrSubElementReplacements &attrRepls, |
180 | TypeSubElementReplacements &typeRepls) { |
181 | return typeRepls.take_front(n: param.size()); |
182 | } |
183 | }; |
184 | |
185 | } // namespace mlir |
186 | |
187 | namespace llvm { |
188 | |
189 | // Provide DenseMapInfo for TypeRange. |
190 | template <> |
191 | struct DenseMapInfo<mlir::TypeRange> { |
192 | static mlir::TypeRange getEmptyKey() { |
193 | return mlir::TypeRange(getEmptyKeyPointer(), 0); |
194 | } |
195 | |
196 | static mlir::TypeRange getTombstoneKey() { |
197 | return mlir::TypeRange(getTombstoneKeyPointer(), 0); |
198 | } |
199 | |
200 | static unsigned getHashValue(mlir::TypeRange val) { return hash_value(arg: val); } |
201 | |
202 | static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) { |
203 | if (isEmptyKey(range: rhs)) |
204 | return isEmptyKey(range: lhs); |
205 | if (isTombstoneKey(range: rhs)) |
206 | return isTombstoneKey(range: lhs); |
207 | return lhs == rhs; |
208 | } |
209 | |
210 | private: |
211 | static const mlir::Type *getEmptyKeyPointer() { |
212 | return DenseMapInfo<mlir::Type *>::getEmptyKey(); |
213 | } |
214 | |
215 | static const mlir::Type *getTombstoneKeyPointer() { |
216 | return DenseMapInfo<mlir::Type *>::getTombstoneKey(); |
217 | } |
218 | |
219 | static bool isEmptyKey(mlir::TypeRange range) { |
220 | if (const auto *type = |
221 | llvm::dyn_cast_if_present<const mlir::Type *>(Val: range.getBase())) |
222 | return type == getEmptyKeyPointer(); |
223 | return false; |
224 | } |
225 | |
226 | static bool isTombstoneKey(mlir::TypeRange range) { |
227 | if (const auto *type = |
228 | llvm::dyn_cast_if_present<const mlir::Type *>(Val: range.getBase())) |
229 | return type == getTombstoneKeyPointer(); |
230 | return false; |
231 | } |
232 | }; |
233 | |
234 | } // namespace llvm |
235 | |
236 | #endif // MLIR_IR_TYPERANGE_H |
237 | |