| 1 | //===- TypeUtilities.cpp - Helper function for type queries ---------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines generic type utilities. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/IR/TypeUtilities.h" |
| 14 | #include "mlir/IR/Attributes.h" |
| 15 | #include "mlir/IR/BuiltinTypes.h" |
| 16 | #include "mlir/IR/Types.h" |
| 17 | #include "mlir/IR/Value.h" |
| 18 | #include "llvm/ADT/SmallVectorExtras.h" |
| 19 | #include <numeric> |
| 20 | |
| 21 | using namespace mlir; |
| 22 | |
| 23 | Type mlir::getElementTypeOrSelf(Type type) { |
| 24 | if (auto st = llvm::dyn_cast<ShapedType>(type)) |
| 25 | return st.getElementType(); |
| 26 | return type; |
| 27 | } |
| 28 | |
| 29 | Type mlir::getElementTypeOrSelf(Value val) { |
| 30 | return getElementTypeOrSelf(type: val.getType()); |
| 31 | } |
| 32 | |
| 33 | Type mlir::getElementTypeOrSelf(Attribute attr) { |
| 34 | if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) |
| 35 | return getElementTypeOrSelf(typedAttr.getType()); |
| 36 | return {}; |
| 37 | } |
| 38 | |
| 39 | SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) { |
| 40 | SmallVector<Type, 10> fTypes; |
| 41 | t.getFlattenedTypes(fTypes); |
| 42 | return fTypes; |
| 43 | } |
| 44 | |
| 45 | /// Return true if the specified type is an opaque type with the specified |
| 46 | /// dialect and typeData. |
| 47 | bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, |
| 48 | StringRef typeData) { |
| 49 | if (auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type)) |
| 50 | return opaque.getDialectNamespace() == dialect && |
| 51 | opaque.getTypeData() == typeData; |
| 52 | return false; |
| 53 | } |
| 54 | |
| 55 | /// Returns success if the given two shapes are compatible. That is, they have |
| 56 | /// the same size and each pair of the elements are equal or one of them is |
| 57 | /// dynamic. |
| 58 | LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1, |
| 59 | ArrayRef<int64_t> shape2) { |
| 60 | if (shape1.size() != shape2.size()) |
| 61 | return failure(); |
| 62 | for (auto dims : llvm::zip(t&: shape1, u&: shape2)) { |
| 63 | int64_t dim1 = std::get<0>(t&: dims); |
| 64 | int64_t dim2 = std::get<1>(t&: dims); |
| 65 | if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && |
| 66 | dim1 != dim2) |
| 67 | return failure(); |
| 68 | } |
| 69 | return success(); |
| 70 | } |
| 71 | |
| 72 | /// Returns success if the given two types have compatible shape. That is, |
| 73 | /// they are both scalars (not shaped), or they are both shaped types and at |
| 74 | /// least one is unranked or they have compatible dimensions. Dimensions are |
| 75 | /// compatible if at least one is dynamic or both are equal. The element type |
| 76 | /// does not matter. |
| 77 | LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { |
| 78 | auto sType1 = llvm::dyn_cast<ShapedType>(type1); |
| 79 | auto sType2 = llvm::dyn_cast<ShapedType>(type2); |
| 80 | |
| 81 | // Either both or neither type should be shaped. |
| 82 | if (!sType1) |
| 83 | return success(!sType2); |
| 84 | if (!sType2) |
| 85 | return failure(); |
| 86 | |
| 87 | if (!sType1.hasRank() || !sType2.hasRank()) |
| 88 | return success(); |
| 89 | |
| 90 | return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); |
| 91 | } |
| 92 | |
| 93 | /// Returns success if the given two arrays have the same number of elements and |
| 94 | /// each pair wise entries have compatible shape. |
| 95 | LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) { |
| 96 | if (types1.size() != types2.size()) |
| 97 | return failure(); |
| 98 | for (auto it : llvm::zip_first(t&: types1, u&: types2)) |
| 99 | if (failed(Result: verifyCompatibleShape(type1: std::get<0>(t&: it), type2: std::get<1>(t&: it)))) |
| 100 | return failure(); |
| 101 | return success(); |
| 102 | } |
| 103 | |
| 104 | LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) { |
| 105 | if (dims.empty()) |
| 106 | return success(); |
| 107 | auto staticDim = std::accumulate( |
| 108 | first: dims.begin(), last: dims.end(), init: dims.front(), binary_op: [](auto fold, auto dim) { |
| 109 | return ShapedType::isDynamic(dim) ? fold : dim; |
| 110 | }); |
| 111 | return success(IsSuccess: llvm::all_of(Range&: dims, P: [&](auto dim) { |
| 112 | return ShapedType::isDynamic(dim) || dim == staticDim; |
| 113 | })); |
| 114 | } |
| 115 | |
| 116 | /// Returns success if all given types have compatible shapes. That is, they are |
| 117 | /// all scalars (not shaped), or they are all shaped types and any ranked shapes |
| 118 | /// have compatible dimensions. Dimensions are compatible if all non-dynamic |
| 119 | /// dims are equal. The element type does not matter. |
| 120 | LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { |
| 121 | auto shapedTypes = llvm::map_to_vector<8>( |
| 122 | types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }); |
| 123 | // Return failure if some, but not all are not shaped. Return early if none |
| 124 | // are shaped also. |
| 125 | if (llvm::none_of(shapedTypes, [](auto t) { return t; })) |
| 126 | return success(); |
| 127 | if (!llvm::all_of(shapedTypes, [](auto t) { return t; })) |
| 128 | return failure(); |
| 129 | |
| 130 | // Return failure if some, but not all, are scalable vectors. |
| 131 | bool hasScalableVecTypes = false; |
| 132 | bool hasNonScalableVecTypes = false; |
| 133 | for (Type t : types) { |
| 134 | auto vType = llvm::dyn_cast<VectorType>(t); |
| 135 | if (vType && vType.isScalable()) |
| 136 | hasScalableVecTypes = true; |
| 137 | else |
| 138 | hasNonScalableVecTypes = true; |
| 139 | if (hasScalableVecTypes && hasNonScalableVecTypes) |
| 140 | return failure(); |
| 141 | } |
| 142 | |
| 143 | // Remove all unranked shapes |
| 144 | auto shapes = llvm::filter_to_vector<8>( |
| 145 | shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }); |
| 146 | if (shapes.empty()) |
| 147 | return success(); |
| 148 | |
| 149 | // All ranks should be equal |
| 150 | auto firstRank = shapes.front().getRank(); |
| 151 | if (llvm::any_of(shapes, |
| 152 | [&](auto shape) { return firstRank != shape.getRank(); })) |
| 153 | return failure(); |
| 154 | |
| 155 | for (unsigned i = 0; i < firstRank; ++i) { |
| 156 | // Retrieve all ranked dimensions |
| 157 | auto dims = llvm::map_to_vector<8>( |
| 158 | llvm::make_filter_range( |
| 159 | shapes, [&](auto shape) { return shape.getRank() >= i; }), |
| 160 | [&](auto shape) { return shape.getDimSize(i); }); |
| 161 | if (verifyCompatibleDims(dims).failed()) |
| 162 | return failure(); |
| 163 | } |
| 164 | |
| 165 | return success(); |
| 166 | } |
| 167 | |
| 168 | Type OperandElementTypeIterator::mapElement(Value value) const { |
| 169 | return llvm::cast<ShapedType>(value.getType()).getElementType(); |
| 170 | } |
| 171 | |
| 172 | Type ResultElementTypeIterator::mapElement(Value value) const { |
| 173 | return llvm::cast<ShapedType>(value.getType()).getElementType(); |
| 174 | } |
| 175 | |
| 176 | TypeRange mlir::insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices, |
| 177 | TypeRange newTypes, |
| 178 | SmallVectorImpl<Type> &storage) { |
| 179 | assert(indices.size() == newTypes.size() && |
| 180 | "mismatch between indice and type count" ); |
| 181 | if (indices.empty()) |
| 182 | return oldTypes; |
| 183 | |
| 184 | auto fromIt = oldTypes.begin(); |
| 185 | for (auto it : llvm::zip(t&: indices, u&: newTypes)) { |
| 186 | const auto toIt = oldTypes.begin() + std::get<0>(t&: it); |
| 187 | storage.append(in_start: fromIt, in_end: toIt); |
| 188 | storage.push_back(Elt: std::get<1>(t&: it)); |
| 189 | fromIt = toIt; |
| 190 | } |
| 191 | storage.append(in_start: fromIt, in_end: oldTypes.end()); |
| 192 | return storage; |
| 193 | } |
| 194 | |
| 195 | TypeRange mlir::filterTypesOut(TypeRange types, const BitVector &indices, |
| 196 | SmallVectorImpl<Type> &storage) { |
| 197 | if (indices.none()) |
| 198 | return types; |
| 199 | |
| 200 | for (unsigned i = 0, e = types.size(); i < e; ++i) |
| 201 | if (!indices[i]) |
| 202 | storage.emplace_back(Args: types[i]); |
| 203 | return storage; |
| 204 | } |
| 205 | |