1 | //===- TypeUtilities.cpp - Helper function for type queries ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines generic type utilities. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/IR/TypeUtilities.h" |
14 | #include "mlir/IR/Attributes.h" |
15 | #include "mlir/IR/BuiltinTypes.h" |
16 | #include "mlir/IR/Types.h" |
17 | #include "mlir/IR/Value.h" |
18 | #include "llvm/ADT/SmallVectorExtras.h" |
19 | #include <numeric> |
20 | |
21 | using namespace mlir; |
22 | |
23 | Type mlir::getElementTypeOrSelf(Type type) { |
24 | if (auto st = llvm::dyn_cast<ShapedType>(type)) |
25 | return st.getElementType(); |
26 | return type; |
27 | } |
28 | |
29 | Type mlir::getElementTypeOrSelf(Value val) { |
30 | return getElementTypeOrSelf(type: val.getType()); |
31 | } |
32 | |
33 | Type mlir::getElementTypeOrSelf(Attribute attr) { |
34 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) |
35 | return getElementTypeOrSelf(typedAttr.getType()); |
36 | return {}; |
37 | } |
38 | |
39 | SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) { |
40 | SmallVector<Type, 10> fTypes; |
41 | t.getFlattenedTypes(fTypes); |
42 | return fTypes; |
43 | } |
44 | |
45 | /// Return true if the specified type is an opaque type with the specified |
46 | /// dialect and typeData. |
47 | bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, |
48 | StringRef typeData) { |
49 | if (auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type)) |
50 | return opaque.getDialectNamespace() == dialect && |
51 | opaque.getTypeData() == typeData; |
52 | return false; |
53 | } |
54 | |
55 | /// Returns success if the given two shapes are compatible. That is, they have |
56 | /// the same size and each pair of the elements are equal or one of them is |
57 | /// dynamic. |
58 | LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1, |
59 | ArrayRef<int64_t> shape2) { |
60 | if (shape1.size() != shape2.size()) |
61 | return failure(); |
62 | for (auto dims : llvm::zip(t&: shape1, u&: shape2)) { |
63 | int64_t dim1 = std::get<0>(t&: dims); |
64 | int64_t dim2 = std::get<1>(t&: dims); |
65 | if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && |
66 | dim1 != dim2) |
67 | return failure(); |
68 | } |
69 | return success(); |
70 | } |
71 | |
72 | /// Returns success if the given two types have compatible shape. That is, |
73 | /// they are both scalars (not shaped), or they are both shaped types and at |
74 | /// least one is unranked or they have compatible dimensions. Dimensions are |
75 | /// compatible if at least one is dynamic or both are equal. The element type |
76 | /// does not matter. |
77 | LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { |
78 | auto sType1 = llvm::dyn_cast<ShapedType>(type1); |
79 | auto sType2 = llvm::dyn_cast<ShapedType>(type2); |
80 | |
81 | // Either both or neither type should be shaped. |
82 | if (!sType1) |
83 | return success(!sType2); |
84 | if (!sType2) |
85 | return failure(); |
86 | |
87 | if (!sType1.hasRank() || !sType2.hasRank()) |
88 | return success(); |
89 | |
90 | return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); |
91 | } |
92 | |
93 | /// Returns success if the given two arrays have the same number of elements and |
94 | /// each pair wise entries have compatible shape. |
95 | LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) { |
96 | if (types1.size() != types2.size()) |
97 | return failure(); |
98 | for (auto it : llvm::zip_first(t&: types1, u&: types2)) |
99 | if (failed(result: verifyCompatibleShape(type1: std::get<0>(t&: it), type2: std::get<1>(t&: it)))) |
100 | return failure(); |
101 | return success(); |
102 | } |
103 | |
104 | LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) { |
105 | if (dims.empty()) |
106 | return success(); |
107 | auto staticDim = std::accumulate( |
108 | first: dims.begin(), last: dims.end(), init: dims.front(), binary_op: [](auto fold, auto dim) { |
109 | return ShapedType::isDynamic(dim) ? fold : dim; |
110 | }); |
111 | return success(isSuccess: llvm::all_of(Range&: dims, P: [&](auto dim) { |
112 | return ShapedType::isDynamic(dim) || dim == staticDim; |
113 | })); |
114 | } |
115 | |
116 | /// Returns success if all given types have compatible shapes. That is, they are |
117 | /// all scalars (not shaped), or they are all shaped types and any ranked shapes |
118 | /// have compatible dimensions. Dimensions are compatible if all non-dynamic |
119 | /// dims are equal. The element type does not matter. |
120 | LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { |
121 | auto shapedTypes = llvm::map_to_vector<8>( |
122 | types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }); |
123 | // Return failure if some, but not all are not shaped. Return early if none |
124 | // are shaped also. |
125 | if (llvm::none_of(shapedTypes, [](auto t) { return t; })) |
126 | return success(); |
127 | if (!llvm::all_of(shapedTypes, [](auto t) { return t; })) |
128 | return failure(); |
129 | |
130 | // Return failure if some, but not all, are scalable vectors. |
131 | bool hasScalableVecTypes = false; |
132 | bool hasNonScalableVecTypes = false; |
133 | for (Type t : types) { |
134 | auto vType = llvm::dyn_cast<VectorType>(t); |
135 | if (vType && vType.isScalable()) |
136 | hasScalableVecTypes = true; |
137 | else |
138 | hasNonScalableVecTypes = true; |
139 | if (hasScalableVecTypes && hasNonScalableVecTypes) |
140 | return failure(); |
141 | } |
142 | |
143 | // Remove all unranked shapes |
144 | auto shapes = llvm::to_vector<8>(llvm::make_filter_range( |
145 | shapedTypes, [](auto shapedType) { return shapedType.hasRank(); })); |
146 | if (shapes.empty()) |
147 | return success(); |
148 | |
149 | // All ranks should be equal |
150 | auto firstRank = shapes.front().getRank(); |
151 | if (llvm::any_of(shapes, |
152 | [&](auto shape) { return firstRank != shape.getRank(); })) |
153 | return failure(); |
154 | |
155 | for (unsigned i = 0; i < firstRank; ++i) { |
156 | // Retrieve all ranked dimensions |
157 | auto dims = llvm::map_to_vector<8>( |
158 | llvm::make_filter_range( |
159 | shapes, [&](auto shape) { return shape.getRank() >= i; }), |
160 | [&](auto shape) { return shape.getDimSize(i); }); |
161 | if (verifyCompatibleDims(dims).failed()) |
162 | return failure(); |
163 | } |
164 | |
165 | return success(); |
166 | } |
167 | |
168 | Type OperandElementTypeIterator::mapElement(Value value) const { |
169 | return llvm::cast<ShapedType>(value.getType()).getElementType(); |
170 | } |
171 | |
172 | Type ResultElementTypeIterator::mapElement(Value value) const { |
173 | return llvm::cast<ShapedType>(value.getType()).getElementType(); |
174 | } |
175 | |
176 | TypeRange mlir::insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices, |
177 | TypeRange newTypes, |
178 | SmallVectorImpl<Type> &storage) { |
179 | assert(indices.size() == newTypes.size() && |
180 | "mismatch between indice and type count" ); |
181 | if (indices.empty()) |
182 | return oldTypes; |
183 | |
184 | auto fromIt = oldTypes.begin(); |
185 | for (auto it : llvm::zip(t&: indices, u&: newTypes)) { |
186 | const auto toIt = oldTypes.begin() + std::get<0>(t&: it); |
187 | storage.append(in_start: fromIt, in_end: toIt); |
188 | storage.push_back(Elt: std::get<1>(t&: it)); |
189 | fromIt = toIt; |
190 | } |
191 | storage.append(in_start: fromIt, in_end: oldTypes.end()); |
192 | return storage; |
193 | } |
194 | |
195 | TypeRange mlir::filterTypesOut(TypeRange types, const BitVector &indices, |
196 | SmallVectorImpl<Type> &storage) { |
197 | if (indices.none()) |
198 | return types; |
199 | |
200 | for (unsigned i = 0, e = types.size(); i < e; ++i) |
201 | if (!indices[i]) |
202 | storage.emplace_back(Args: types[i]); |
203 | return storage; |
204 | } |
205 | |