1//===- Traits.cpp - Common op traits shared by dialects -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Traits.h"
10#include "mlir/IR/BuiltinTypes.h"
11#include "mlir/IR/TypeUtilities.h"
12#include "llvm/Support/FormatVariadic.h"
13#include <optional>
14
15using namespace mlir;
16
17bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
18 ArrayRef<int64_t> shape2) {
19 SmallVector<SmallVector<int64_t, 6>, 2> extents;
20 extents.emplace_back(Args: shape1.begin(), Args: shape1.end());
21 extents.emplace_back(Args: shape2.begin(), Args: shape2.end());
22 return staticallyKnownBroadcastable(shapes: extents);
23}
24
25bool OpTrait::util::staticallyKnownBroadcastable(
26 ArrayRef<SmallVector<int64_t, 6>> shapes) {
27 assert(!shapes.empty() && "Expected at least one shape");
28 size_t maxRank = shapes[0].size();
29 for (size_t i = 1; i != shapes.size(); ++i)
30 maxRank = std::max(a: maxRank, b: shapes[i].size());
31
32 // We look backwards through every column of `shapes`.
33 for (size_t i = 0; i != maxRank; ++i) {
34 bool seenDynamic = false;
35 std::optional<int64_t> nonOneDim;
36 for (ArrayRef<int64_t> extent : shapes) {
37 int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
38
39 if (dim == 1)
40 continue;
41
42 // Dimensions are compatible when
43 //. 1. One is dynamic, the rest are 1
44 if (ShapedType::isDynamic(dim)) {
45 if (seenDynamic || nonOneDim)
46 return false;
47 seenDynamic = true;
48 }
49
50 // 2. All are 1 or a specific constant.
51 if (nonOneDim && dim != *nonOneDim)
52 return false;
53
54 nonOneDim = dim;
55 }
56 }
57 return true;
58}
59
60bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
61 ArrayRef<int64_t> shape2,
62 SmallVectorImpl<int64_t> &resultShape) {
63 // To compute the result broadcasted shape, we compare operand shapes
64 // element-wise: starting with the trailing dimensions, and working the
65 // way backward. Two dimensions are compatible when
66 // 1. they are equal, or
67 // 2. one of them is 1
68 // The result shape has the maximum among the two inputs at every
69 // dimension index.
70
71 resultShape.clear();
72 if (shape1.size() > shape2.size()) {
73 std::copy(first: shape1.begin(), last: shape1.end(), result: std::back_inserter(x&: resultShape));
74 } else {
75 std::copy(first: shape2.begin(), last: shape2.end(), result: std::back_inserter(x&: resultShape));
76 }
77
78 auto i1 = shape1.rbegin(), e1 = shape1.rend();
79 auto i2 = shape2.rbegin(), e2 = shape2.rend();
80 auto iR = resultShape.rbegin();
81
82 // Check each dimension is consistent.
83 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
84 if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
85 // One or both dimensions is unknown. Follow TensorFlow behavior:
86 // - If either dimension is greater than 1, we assume that the program is
87 // correct, and the other dimension will be broadcast to match it.
88 // - If either dimension is 1, the other dimension is the output.
89 if (*i1 > 1) {
90 *iR = *i1;
91 } else if (*i2 > 1) {
92 *iR = *i2;
93 } else if (*i1 == 1) {
94 *iR = *i2;
95 } else if (*i2 == 1) {
96 *iR = *i1;
97 } else {
98 *iR = ShapedType::kDynamic;
99 }
100 } else {
101 if (*i1 == *i2 || *i2 == 1) {
102 *iR = *i1;
103 } else if (*i1 == 1) {
104 *iR = *i2;
105 } else {
106 // This dimension of the two operand types is incompatible.
107 resultShape.clear();
108 return false;
109 }
110 }
111 }
112
113 return true;
114}
115
116/// Returns the shape of the given type. Scalars will be considered as having a
117/// shape with zero dimensions.
118static ArrayRef<int64_t> getShape(Type type) {
119 if (auto sType = dyn_cast<ShapedType>(type))
120 return sType.getShape();
121 return {};
122}
123
124/// Returns the result broadcast composition type from the two given types by
125/// following NumPy broadcast semantics. Returned type may have dynamic shape if
126/// either of the input types has dynamic shape. Returns null type if the two
127/// given types are not broadcast-compatible.
128///
129/// elementType, if specified, will be used as the element type of the
130/// broadcasted result type. Otherwise it is required that the element type of
131/// type1 and type2 is the same and this element type will be used as the
132/// resultant element type.
133Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
134 Type elementType) {
135 // If the elementType is not specified, then the use the common element type
136 // of the inputs or fail if there is no common element type.
137 if (!elementType) {
138 elementType = getElementTypeOrSelf(type: type1);
139 if (elementType != getElementTypeOrSelf(type: type2))
140 return {};
141 }
142
143 // If one of the types is unranked tensor, then the other type shouldn't be
144 // vector and the result should have unranked tensor type.
145 if (isa<UnrankedTensorType>(Val: type1) || isa<UnrankedTensorType>(Val: type2)) {
146 if (isa<VectorType>(Val: type1) || isa<VectorType>(Val: type2))
147 return {};
148 return UnrankedTensorType::get(elementType);
149 }
150
151 // Returns the type kind if the given type is a vector or ranked tensor type.
152 // Returns std::nullopt otherwise.
153 auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
154 if (isa<VectorType, RankedTensorType>(type))
155 return type.getTypeID();
156 return std::nullopt;
157 };
158
159 // Make sure the composite type, if has, is consistent.
160 std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
161 std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
162 std::optional<TypeID> resultCompositeKind;
163
164 if (compositeKind1 && compositeKind2) {
165 // Disallow mixing vector and tensor.
166 if (compositeKind1 != compositeKind2)
167 return {};
168 resultCompositeKind = compositeKind1;
169 } else if (compositeKind1) {
170 resultCompositeKind = compositeKind1;
171 } else if (compositeKind2) {
172 resultCompositeKind = compositeKind2;
173 }
174
175 // Get the shape of each type.
176 SmallVector<int64_t, 4> resultShape;
177 if (!getBroadcastedShape(shape1: getShape(type: type1), shape2: getShape(type: type2), resultShape))
178 return {};
179
180 // Compose the final broadcasted type
181 if (resultCompositeKind == VectorType::getTypeID())
182 return VectorType::get(resultShape, elementType);
183 if (resultCompositeKind == RankedTensorType::getTypeID())
184 return RankedTensorType::get(resultShape, elementType);
185 return elementType;
186}
187
188/// Returns a tuple corresponding to whether range has tensor or vector type.
189template <typename iterator_range>
190static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
191 return {llvm::any_of(types, llvm::IsaPred<TensorType>),
192 llvm::any_of(types, llvm::IsaPred<VectorType>)};
193}
194
195static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
196 ArrayRef<int64_t> existing) {
197 // If both interred and existing dimensions are static, they must be equal.
198 auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
199 return ShapedType::isDynamic(existingDim) ||
200 ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
201 };
202 if (inferred.size() != existing.size())
203 return false;
204 for (auto [inferredDim, existingDim] : llvm::zip_equal(t&: inferred, u&: existing))
205 if (!isCompatible(inferredDim, existingDim))
206 return false;
207 return true;
208}
209
210static std::string getShapeString(ArrayRef<int64_t> shape) {
211 // TODO: should replace with printing shape more uniformly across here and
212 // when in type.
213 std::string ret;
214 llvm::raw_string_ostream ss(ret);
215 ss << '\'';
216 llvm::interleave(
217 c: shape, os&: ss,
218 each_fn: [&](int64_t dim) {
219 if (ShapedType::isDynamic(dim))
220 ss << '?';
221 else
222 ss << dim;
223 },
224 separator: "x");
225 ss << '\'';
226 return ss.str();
227}
228
229LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
230 // Ensure broadcasting only tensor or only vector types.
231 auto operandsHasTensorVectorType =
232 hasTensorOrVectorType(types: op->getOperandTypes());
233 auto resultsHasTensorVectorType = hasTensorOrVectorType(types: op->getResultTypes());
234 if ((std::get<0>(t&: operandsHasTensorVectorType) ||
235 std::get<0>(t&: resultsHasTensorVectorType)) &&
236 (std::get<1>(t&: operandsHasTensorVectorType) ||
237 std::get<1>(t&: resultsHasTensorVectorType)))
238 return op->emitError(message: "cannot broadcast vector with tensor");
239
240 auto rankedOperands =
241 make_filter_range(Range: op->getOperandTypes(), Pred: llvm::IsaPred<RankedTensorType>);
242
243 // If all operands are unranked, then all result shapes are possible.
244 if (rankedOperands.empty())
245 return success();
246
247 // Compute broadcasted shape of operands (which requires that operands are
248 // broadcast compatible). The results need to be broadcast compatible with
249 // this result shape.
250 SmallVector<int64_t, 4> resultShape;
251 (void)util::getBroadcastedShape(shape1: getShape(type: *rankedOperands.begin()), shape2: {},
252 resultShape);
253 for (auto other : make_early_inc_range(Range&: rankedOperands)) {
254 SmallVector<int64_t, 4> temp = resultShape;
255 if (!util::getBroadcastedShape(shape1: temp, shape2: getShape(type: other), resultShape))
256 return op->emitOpError(message: "operands don't have broadcast-compatible shapes");
257 }
258
259 auto rankedResults =
260 make_filter_range(Range: op->getResultTypes(), Pred: llvm::IsaPred<RankedTensorType>);
261
262 // If all of the results are unranked then no further verification.
263 if (rankedResults.empty())
264 return success();
265
266 for (auto type : rankedResults) {
267 ArrayRef<int64_t> actualSuffix =
268 getShape(type).take_back(N: resultShape.size());
269 if (!isCompatibleInferredReturnShape(inferred: resultShape, existing: actualSuffix))
270 return op->emitOpError()
271 << "result type " << getShapeString(shape: getShape(type))
272 << " not broadcast compatible with broadcasted operands's shapes "
273 << getShapeString(shape: resultShape);
274 }
275 return success();
276}
277

source code of mlir/lib/Dialect/Traits.cpp