//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace mlir;

bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> extents;
  extents.emplace_back(shape1.begin(), shape1.end());
  extents.emplace_back(shape2.begin(), shape2.end());
  return staticallyKnownBroadcastable(extents);
}

bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");
  size_t maxRank = shapes[0].size();
  for (size_t i = 1; i != shapes.size(); ++i)
    maxRank = std::max(maxRank, shapes[i].size());

  // We look backwards through every column of `shapes`.
  for (size_t i = 0; i != maxRank; ++i) {
    bool seenDynamic = false;
    std::optional<int64_t> nonOneDim;
    for (ArrayRef<int64_t> extent : shapes) {
      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];

      if (dim == 1)
        continue;

      // Dimensions are compatible when
      // 1. One is dynamic, the rest are 1.
      if (ShapedType::isDynamic(dim)) {
        if (seenDynamic || nonOneDim)
          return false;
        seenDynamic = true;
      }

      // 2. All are 1 or a specific constant.
      if (nonOneDim && dim != *nonOneDim)
        return false;

      nonOneDim = dim;
    }
  }
  return true;
}
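
// As an illustrative sketch of the intended semantics (not exercised in this
// file): shapes are aligned at their trailing dimensions, and broadcastability
// is only reported when it can be proven from the static information alone.
//
//   staticallyKnownBroadcastable({2, 1, 4}, {1, 3, 4});  // true: 2x3x4
//   staticallyKnownBroadcastable({2, 3, 4}, {3, 1});     // true: 2x3x4
//   staticallyKnownBroadcastable({2, 3}, {4, 3});        // false: 2 vs. 4
//   staticallyKnownBroadcastable({ShapedType::kDynamic, 4}, {3, 4});
//   // false: the dynamic dimension cannot be statically proven to equal 3.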

bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the result broadcasted shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions, and working the
  // way backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1
  // The result shape has the maximum among the two inputs at every
  // dimension index.

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check that each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
      // One or both dimensions are unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = ShapedType::kDynamic;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}
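
// A worked sketch of these semantics (illustrative only; '?' stands for
// ShapedType::kDynamic):
//
//   SmallVector<int64_t> result;
//   getBroadcastedShape({2, 1, 4}, {3, 4}, result);  // true, result = {2, 3, 4}
//   getBroadcastedShape({?, 4}, {3, 4}, result);     // true, result = {3, 4}
//   getBroadcastedShape({2, 3}, {4, 3}, result);     // false, result is cleared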

/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = dyn_cast<ShapedType>(type))
    return sType.getShape();
  return {};
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. The returned type may have a dynamic
/// shape if either of the input types has a dynamic shape. Returns a null
/// type if the two given types are not broadcast-compatible.
///
/// elementType, if specified, will be used as the element type of the
/// broadcasted result type. Otherwise the element types of type1 and type2
/// must be the same, and that element type is used as the resultant element
/// type.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  // If the elementType is not specified, then use the common element type of
  // the inputs, or fail if there is no common element type.
  if (!elementType) {
    elementType = getElementTypeOrSelf(type1);
    if (elementType != getElementTypeOrSelf(type2))
      return {};
  }

  // If one of the types is an unranked tensor, then the other type must not be
  // a vector, and the result is an unranked tensor type.
  if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
    if (isa<VectorType>(type1) || isa<VectorType>(type2))
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns std::nullopt otherwise.
  auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
    if (isa<VectorType, RankedTensorType>(type))
      return type.getTypeID();
    return std::nullopt;
  };

  // Make sure the composite type kinds, if present, are consistent.
  std::optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
  std::optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
  std::optional<TypeID> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Compute the broadcasted shape from the shapes of the two types.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type.
  if (resultCompositeKind == VectorType::getTypeID())
    return VectorType::get(resultShape, elementType);
  if (resultCompositeKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}
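
// As a rough usage sketch (the context and type values here are illustrative
// assumptions, not taken from this file):
//
//   MLIRContext context;
//   Type f32 = Float32Type::get(&context);
//   Type t1 = RankedTensorType::get({2, 1, 4}, f32);
//   Type t2 = RankedTensorType::get({3, 4}, f32);
//   OpTrait::util::getBroadcastedType(t1, t2);   // tensor<2x3x4xf32>
//   OpTrait::util::getBroadcastedType(f32, t2);  // tensor<3x4xf32>
//   OpTrait::util::getBroadcastedType(VectorType::get({4}, f32), t2);
//   // -> null Type: mixing vector and tensor is rejected.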

/// Returns a tuple corresponding to whether the range has tensor or vector
/// type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  return {llvm::any_of(types, llvm::IsaPred<TensorType>),
          llvm::any_of(types, llvm::IsaPred<VectorType>)};
}

static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  // If both inferred and existing dimensions are static, they must be equal.
  auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
    return ShapedType::isDynamic(existingDim) ||
           ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
  };
  if (inferred.size() != existing.size())
    return false;
  for (auto [inferredDim, existingDim] : llvm::zip_equal(inferred, existing))
    if (!isCompatible(inferredDim, existingDim))
      return false;
  return true;
}
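
// For illustration (a sketch of the compatibility rule, with '?' standing for
// ShapedType::kDynamic): an inferred shape of {2, ?, 4} is compatible with an
// existing shape of {2, 3, 4} or {?, 3, 4}, whereas {2, 3, 4} is compatible
// neither with {2, 3, 5} (static dimensions differ) nor with {3, 4} (ranks
// differ).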

static std::string getShapeString(ArrayRef<int64_t> shape) {
  // TODO: should replace with printing shape more uniformly across here and
  // when in type.
  std::string ret;
  llvm::raw_string_ostream ss(ret);
  ss << '\'';
  llvm::interleave(
      shape, ss,
      [&](int64_t dim) {
        if (ShapedType::isDynamic(dim))
          ss << '?';
        else
          ss << dim;
      },
      "x");
  ss << '\'';
  return ss.str();
}
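
// For example, a shape of {2, ShapedType::kDynamic, 4} is rendered by
// getShapeString as '2x?x4', including the surrounding single quotes.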

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure broadcasting only tensor or only vector types.
  auto operandsHasTensorVectorType =
      hasTensorOrVectorType(op->getOperandTypes());
  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
  if ((std::get<0>(operandsHasTensorVectorType) ||
       std::get<0>(resultsHasTensorVectorType)) &&
      (std::get<1>(operandsHasTensorVectorType) ||
       std::get<1>(resultsHasTensorVectorType)))
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands =
      make_filter_range(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Compute the broadcasted shape of the operands (which requires that the
  // operands are broadcast compatible). The results need to be broadcast
  // compatible with this result shape.
  SmallVector<int64_t, 4> resultShape;
  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
                                  resultShape);
  for (auto other : make_early_inc_range(rankedOperands)) {
    SmallVector<int64_t, 4> temp = resultShape;
    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults =
      make_filter_range(op->getResultTypes(), llvm::IsaPred<RankedTensorType>);

  // If all of the results are unranked, then no further verification is
  // needed.
  if (rankedResults.empty())
    return success();

  for (auto type : rankedResults) {
    ArrayRef<int64_t> actualSuffix =
        getShape(type).take_back(resultShape.size());
    if (!isCompatibleInferredReturnShape(resultShape, actualSuffix))
      return op->emitOpError()
             << "result type " << getShapeString(getShape(type))
             << " not broadcast compatible with broadcasted operands' shapes "
             << getShapeString(resultShape);
  }
  return success();
}
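
// As an end-to-end sketch, for a hypothetical op carrying this trait (the op
// name "test.broadcastable" is an assumption used only for illustration):
//
//   // Verifies: the operands broadcast to 2x3x4, which matches the result.
//   %ok = "test.broadcastable"(%a, %b)
//       : (tensor<2x1x4xf32>, tensor<3x4xf32>) -> tensor<2x3x4xf32>
//
//   // Fails: the result dimension 5 conflicts with the broadcasted operand
//   // dimension 3.
//   %err = "test.broadcastable"(%a, %b)
//       : (tensor<2x1x4xf32>, tensor<3x4xf32>) -> tensor<2x5x4xf32>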