//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace mlir;

bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> extents;
  extents.emplace_back(shape1.begin(), shape1.end());
  extents.emplace_back(shape2.begin(), shape2.end());
  return staticallyKnownBroadcastable(extents);
}

bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");
  size_t maxRank = shapes[0].size();
  for (size_t i = 1; i != shapes.size(); ++i)
    maxRank = std::max(maxRank, shapes[i].size());

  // We look backwards through every column of `shapes`.
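  // A column holds the dimensions sitting at the same distance from the end of
  // each shape; shapes shorter than `maxRank` contribute an implicit 1. For
  // example, for shapes {2, ?, 4} and {4}, the columns are {4, 4}, {?, 1}, and
  // {2, 1}.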
  for (size_t i = 0; i != maxRank; ++i) {
    bool seenDynamic = false;
    std::optional<int64_t> nonOneDim;
    for (ArrayRef<int64_t> extent : shapes) {
      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];

      if (dim == 1)
        continue;

      // Dimensions are compatible when
      //   1. One is dynamic, the rest are 1
      if (ShapedType::isDynamic(dim)) {
        if (seenDynamic || nonOneDim)
          return false;
        seenDynamic = true;
      }

      //   2. All are 1 or a specific constant.
      if (nonOneDim && dim != *nonOneDim)
        return false;

      nonOneDim = dim;
    }
  }
  return true;
}

bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the result broadcasted shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions, and working the
  // way backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1
  // The result shape has the maximum among the two inputs at every
  // dimension index.
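  // For example, broadcasting '4x1x2' with '3x2' (or with '1x3x2') yields
  // '4x3x2', whereas '2x3' and '4x3' have no broadcasted shape.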

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    llvm::append_range(resultShape, shape1);
  } else {
    llvm::append_range(resultShape, shape2);
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
      // One or both dimensions are unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcasted to match it.
      // - If either dimension is 1, the other dimension is the output.
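      // For example, '?' broadcast with '3' gives '3', '?' broadcast with '1'
      // gives '?', and '?' broadcast with '?' gives '?'.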
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = ShapedType::kDynamic;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}

/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
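/// For example, 'tensor<2x3xf32>' yields {2, 3}, while 'f32' yields an empty
/// shape.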
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = dyn_cast<ShapedType>(type))
    return sType.getShape();
  return {};
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
///
/// elementType, if specified, will be used as the element type of the
/// broadcasted result type. Otherwise it is required that the element type of
/// type1 and type2 is the same and this element type will be used as the
/// resultant element type.
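///
/// For example, broadcasting 'tensor<4x1xf32>' with 'tensor<3xf32>' yields
/// 'tensor<4x3xf32>', and broadcasting 'f32' with 'vector<4xf32>' yields
/// 'vector<4xf32>'.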
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  // If the elementType is not specified, then use the common element type of
  // the inputs or fail if there is no common element type.
  if (!elementType) {
    elementType = getElementTypeOrSelf(type1);
    if (elementType != getElementTypeOrSelf(type2))
      return {};
  }

  // If one of the types is an unranked tensor, then the other type must not be
  // a vector, and the result is an unranked tensor type.
  if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
    if (isa<VectorType>(type1) || isa<VectorType>(type2))
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns std::nullopt otherwise.
  auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
    if (isa<VectorType, RankedTensorType>(type))
      return type.getTypeID();
    return std::nullopt;
  };

  // Make sure the composite type, if present, is consistent.
  std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
  std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
  std::optional<TypeID> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Get the shape of each type.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type
  if (resultCompositeKind == VectorType::getTypeID())
    return VectorType::get(resultShape, elementType);
  if (resultCompositeKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}

/// Returns a tuple corresponding to whether range has tensor or vector type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
          llvm::any_of(types, llvm::IsaPred<VectorType>)};
}

static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  // If both inferred and existing dimensions are static, they must be equal.
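  // For example, an inferred shape of {2, ?} is compatible with an existing
  // shape of {2, 3}, but {2, 3} is not compatible with {2, 4}.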
  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
    return ShapedType::isDynamic(existingDim) ||
           ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
  };
  if (inferred.size() != existing.size())
    return false;
  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
    if (!isCompatible(inferredDim, existingDim))
      return false;
  return true;
}

static std::string getShapeString(ArrayRef<int64_t> shape) {
  // TODO: Replace this with a more uniform way of printing shapes here and in
  // types.
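  // For example, the shape {2, ShapedType::kDynamic, 3} renders as '2x?x3'.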
  std::string ret;
  llvm::raw_string_ostream ss(ret);
  ss << '\'';
  llvm::interleave(
      shape, ss,
      [&](int64_t dim) {
        if (ShapedType::isDynamic(dim))
          ss << '?';
        else
          ss << dim;
      },
      "x");
  ss << '\'';
  return ret;
}

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure broadcasting only tensor or only vector types.
  auto operandsHasTensorVectorType =
      hasTensorOrVectorType(op->getOperandTypes());
  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
  if ((std::get<0>(operandsHasTensorVectorType) ||
       std::get<0>(resultsHasTensorVectorType)) &&
      (std::get<1>(operandsHasTensorVectorType) ||
       std::get<1>(resultsHasTensorVectorType)))
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands =
      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Compute broadcasted shape of operands (which requires that operands are
  // broadcast compatible). The results need to be broadcast compatible with
  // this result shape.
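  // Note: passing an empty second shape below simply seeds `resultShape` with
  // the first ranked operand's shape.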
  SmallVector<int64_t, 4> resultShape;
  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
                                  resultShape);
  for (auto other : make_early_inc_range(rankedOperands)) {
    SmallVector<int64_t, 4> temp = resultShape;
    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults =
      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);

  // If all of the results are unranked, then no further verification is
  // needed.
  if (rankedResults.empty())
    return success();

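  // Only the trailing `resultShape.size()` dimensions of each ranked result
  // are checked; any additional leading dimensions are left unconstrained
  // (they may come from unranked operands that were skipped above).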
  for (auto type : rankedResults) {
    ArrayRef<int64_t> actualSuffix =
        getShape(type).take_back(resultShape.size());
    if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
      return op->emitOpError()
             << "result type " << getShapeString(getShape(type))
             << " not broadcast compatible with broadcasted operands' shapes "
             << getShapeString(resultShape);
  }
  return success();
}