1//===- TypeRange.h ----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the TypeRange and ValueTypeRange classes.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_TYPERANGE_H
14#define MLIR_IR_TYPERANGE_H
15
16#include "mlir/IR/Types.h"
17#include "mlir/IR/Value.h"
18#include "mlir/IR/ValueRange.h"
19#include "llvm/ADT/PointerUnion.h"
20#include "llvm/ADT/Sequence.h"
21
22namespace mlir {
23
24//===----------------------------------------------------------------------===//
25// TypeRange
26
27/// This class provides an abstraction over the various different ranges of
28/// value types. In many cases, this prevents the need to explicitly materialize
29/// a SmallVector/std::vector. This class should be used in places that are not
30/// suitable for a more derived type (e.g. ArrayRef) or a template range
31/// parameter.
32class TypeRange : public llvm::detail::indexed_accessor_range_base<
33 TypeRange,
34 llvm::PointerUnion<const Value *, const Type *,
35 OpOperand *, detail::OpResultImpl *>,
36 Type, Type, Type> {
37public:
38 using RangeBaseT::RangeBaseT;
39 TypeRange(ArrayRef<Type> types = std::nullopt);
40 explicit TypeRange(OperandRange values);
41 explicit TypeRange(ResultRange values);
42 explicit TypeRange(ValueRange values);
43 template <typename ValueRangeT>
44 TypeRange(ValueTypeRange<ValueRangeT> values)
45 : TypeRange(ValueRange(ValueRangeT(values.begin().getCurrent(),
46 values.end().getCurrent()))) {}
47 template <typename Arg, typename = std::enable_if_t<std::is_constructible<
48 ArrayRef<Type>, Arg>::value>>
49 TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
50 TypeRange(std::initializer_list<Type> types)
51 : TypeRange(ArrayRef<Type>(types)) {}
52
53private:
54 /// The owner of the range is either:
55 /// * A pointer to the first element of an array of values.
56 /// * A pointer to the first element of an array of types.
57 /// * A pointer to the first element of an array of operands.
58 /// * A pointer to the first element of an array of results.
59 using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
60 detail::OpResultImpl *>;
61
62 /// See `llvm::detail::indexed_accessor_range_base` for details.
63 static OwnerT offset_base(OwnerT object, ptrdiff_t index);
64 /// See `llvm::detail::indexed_accessor_range_base` for details.
65 static Type dereference_iterator(OwnerT object, ptrdiff_t index);
66
67 /// Allow access to `offset_base` and `dereference_iterator`.
68 friend RangeBaseT;
69};
70
71/// Make TypeRange hashable.
72inline ::llvm::hash_code hash_value(TypeRange arg) {
73 return ::llvm::hash_combine_range(first: arg.begin(), last: arg.end());
74}
75
76/// Emit a type range to the given output stream.
77inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
78 llvm::interleaveComma(c: types, os);
79 return os;
80}
81
82//===----------------------------------------------------------------------===//
83// TypeRangeRange
84
85using TypeRangeRangeIterator =
86 llvm::mapped_iterator<llvm::iota_range<unsigned>::iterator,
87 std::function<TypeRange(unsigned)>>;
88
89/// This class provides an abstraction for a range of TypeRange. This is useful
90/// when accessing the types of a range of ranges, such as when using
91/// OperandRangeRange.
92class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> {
93public:
94 template <typename RangeT>
95 TypeRangeRange(const RangeT &range)
96 : TypeRangeRange(llvm::seq<unsigned>(0, range.size()), range) {}
97
98private:
99 template <typename RangeT>
100 TypeRangeRange(llvm::iota_range<unsigned> sizeRange, const RangeT &range)
101 : llvm::iterator_range<TypeRangeRangeIterator>(
102 {sizeRange.begin(), getRangeFn(range)},
103 {sizeRange.end(), nullptr}) {}
104
105 template <typename RangeT>
106 static std::function<TypeRange(unsigned)> getRangeFn(const RangeT &range) {
107 return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); };
108 }
109};
110
111//===----------------------------------------------------------------------===//
112// ValueTypeRange
113
114/// This class implements iteration on the types of a given range of values.
115template <typename ValueIteratorT>
116class ValueTypeIterator final
117 : public llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>,
118 ValueIteratorT, Type> {
119public:
120 using llvm::mapped_iterator_base<ValueTypeIterator<ValueIteratorT>,
121 ValueIteratorT, Type>::mapped_iterator_base;
122
123 /// Map the element to the iterator result type.
124 Type mapElement(Value value) const { return value.getType(); }
125};
126
127/// This class implements iteration on the types of a given range of values.
128template <typename ValueRangeT>
129class ValueTypeRange final
130 : public llvm::iterator_range<
131 ValueTypeIterator<typename ValueRangeT::iterator>> {
132public:
133 using llvm::iterator_range<
134 ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
135 template <typename Container>
136 ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
137
138 /// Return the type at the given index.
139 Type operator[](size_t index) const {
140 assert(index < size() && "invalid index into type range");
141 return *(this->begin() + index);
142 }
143
144 /// Return the size of this range.
145 size_t size() const { return llvm::size(*this); }
146
147 /// Return first type in the range.
148 Type front() { return (*this)[0]; }
149
150 /// Compare this range with another.
151 template <typename OtherT>
152 bool operator==(const OtherT &other) const {
153 return llvm::size(*this) == llvm::size(other) &&
154 std::equal(this->begin(), this->end(), other.begin());
155 }
156 template <typename OtherT>
157 bool operator!=(const OtherT &other) const {
158 return !(*this == other);
159 }
160};
161
162template <typename RangeT>
163inline bool operator==(ArrayRef<Type> lhs, const ValueTypeRange<RangeT> &rhs) {
164 return lhs.size() == static_cast<size_t>(llvm::size(rhs)) &&
165 std::equal(lhs.begin(), lhs.end(), rhs.begin());
166}
167
168//===----------------------------------------------------------------------===//
169// SubElements
170//===----------------------------------------------------------------------===//
171
172/// Enable TypeRange to be introspected for sub-elements.
173template <>
174struct AttrTypeSubElementHandler<TypeRange> {
175 static void walk(TypeRange param, AttrTypeImmediateSubElementWalker &walker) {
176 walker.walkRange(elements&: param);
177 }
178 static TypeRange replace(TypeRange param,
179 AttrSubElementReplacements &attrRepls,
180 TypeSubElementReplacements &typeRepls) {
181 return typeRepls.take_front(n: param.size());
182 }
183};
184
185} // namespace mlir
186
187namespace llvm {
188
189// Provide DenseMapInfo for TypeRange.
190template <>
191struct DenseMapInfo<mlir::TypeRange> {
192 static mlir::TypeRange getEmptyKey() {
193 return mlir::TypeRange(getEmptyKeyPointer(), 0);
194 }
195
196 static mlir::TypeRange getTombstoneKey() {
197 return mlir::TypeRange(getTombstoneKeyPointer(), 0);
198 }
199
200 static unsigned getHashValue(mlir::TypeRange val) { return hash_value(arg: val); }
201
202 static bool isEqual(mlir::TypeRange lhs, mlir::TypeRange rhs) {
203 if (isEmptyKey(range: rhs))
204 return isEmptyKey(range: lhs);
205 if (isTombstoneKey(range: rhs))
206 return isTombstoneKey(range: lhs);
207 return lhs == rhs;
208 }
209
210private:
211 static const mlir::Type *getEmptyKeyPointer() {
212 return DenseMapInfo<mlir::Type *>::getEmptyKey();
213 }
214
215 static const mlir::Type *getTombstoneKeyPointer() {
216 return DenseMapInfo<mlir::Type *>::getTombstoneKey();
217 }
218
219 static bool isEmptyKey(mlir::TypeRange range) {
220 if (const auto *type =
221 llvm::dyn_cast_if_present<const mlir::Type *>(Val: range.getBase()))
222 return type == getEmptyKeyPointer();
223 return false;
224 }
225
226 static bool isTombstoneKey(mlir::TypeRange range) {
227 if (const auto *type =
228 llvm::dyn_cast_if_present<const mlir::Type *>(Val: range.getBase()))
229 return type == getTombstoneKeyPointer();
230 return false;
231 }
232};
233
234} // namespace llvm
235
236#endif // MLIR_IR_TYPERANGE_H
237

source code of mlir/include/mlir/IR/TypeRange.h