| 1 | //===- Traits.h - Common op traits shared by dialects -----------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file declares common op traits that are not core to MLIR but can be |
| 10 | // shared by multiple dialects. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #ifndef MLIR_DIALECT_TRAITS_H |
| 15 | #define MLIR_DIALECT_TRAITS_H |
| 16 | |
| 17 | #include "mlir/IR/OpDefinition.h" |
| 18 | |
| 19 | namespace mlir { |
| 20 | namespace OpTrait { |
| 21 | |
| 22 | // Out-of-line implementations of the verification methods of the trait |
| 23 | // classes declared below. Keeping them out of line avoids having them |
| 24 | // instantiated/duplicated once per template instantiation of a trait. |
| 25 | namespace impl { |
| 26 | LogicalResult verifyCompatibleOperandBroadcast(Operation *op); |
| 27 | } // namespace impl |
| 28 | |
| 29 | namespace util { |
| 30 | /// Returns true and sets `resultShape` to the broadcasted shape from the two |
| 31 | /// given shapes if they are broadcast compatible. Returns false and clears |
| 32 | /// `resultShape` otherwise. |
| 33 | /// |
| 34 | /// The rules for determining the result shape are: |
| 35 | /// |
| 36 | /// Zip together the dimensions in the two given shapes by prepending the shape |
| 37 | /// with fewer dimensions with 1s. For each dimension pair, deduce the result |
| 38 | /// dimension according to the following order: |
| 39 | /// - If there are unknown dimensions, follow the TensorFlow behavior: |
| 40 | ///   - If either dimension is greater than 1, we assume that the program is |
| 41 | ///     correct, and the other dimension will be broadcast to match it. |
| 42 | ///   - If either dimension is 1, the other dimension is the result. |
| 43 | ///   - Otherwise, the result dimension is an unknown dimension. |
| 44 | /// - If one of the dimensions is 1, the other dimension is the result. |
| 45 | /// - If the two dimensions are the same, that is the result. |
| 46 | /// - Otherwise, the shapes are incompatible. |
| 47 | bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2, |
| 48 | SmallVectorImpl<int64_t> &resultShape); |
| 49 | |
| 50 | /// Returns true if a broadcast between n shapes is guaranteed to be |
| 51 | /// successful and not result in an error. A false result does not prove that |
| 52 | /// the shapes are not broadcastable: it may mean that they are known to be |
| 53 | /// incompatible, or merely that this function does not have enough |
| 54 | /// information to know. |
| 55 | /// |
| 56 | /// Conceptually, this returns true if getBroadcastedShape would have returned |
| 57 | /// true and vice versa, with one exception. If a dimension is unknown in both |
| 58 | /// shapes, getBroadcastedShape would return true and have a result with an |
| 59 | /// unknown dimension, while this function will return false because it is |
| 60 | /// possible for both shapes to have a dimension greater than 1 and different |
| 61 | /// from each other, which would fail to broadcast. |
| 62 | bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes); |
| 63 | bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1, |
| 64 | ArrayRef<int64_t> shape2); |
| 65 | |
| 66 | /// Returns the result broadcast composition type from the two given types by |
| 67 | /// following NumPy broadcast semantics. The returned type may have a dynamic |
| 68 | /// shape if either of the input types has a dynamic shape. Returns a null type |
| 69 | /// if the two given types are not broadcast-compatible. |
| 70 | /// |
| 71 | /// `elementType`, if specified, will be used as the element type of the |
| 72 | /// broadcasted result type. Otherwise, `type1` and `type2` are required to |
| 73 | /// have the same element type, and that element type will be used as the |
| 74 | /// resultant element type. |
| 75 | Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr); |
| 76 | |
| 77 | } // namespace util |
| 78 | |
| 79 | /// Trait for ops that are known to have broadcast compatible operands and |
| 80 | /// result types. Specifically, starting from the most varying dimension, each |
| 81 | /// dimension pair of the operands' shapes should either be the same or one |
| 82 | /// of them is one. Also, the results' shapes should have the corresponding |
| 83 | /// dimension equal to the larger one, if known. Shapes are checked partially if |
| 84 | /// ranks or dimensions are not known. For example, an op with tensor<?x2xf32> |
| 85 | /// and tensor<2xf32> as operand types and tensor<5x3x2xi16> as the result |
| 86 | /// type has broadcast compatible operands and result types. |
| 87 | template <typename ConcreteType> |
| 88 | class ResultsBroadcastableShape |
| 89 | : public TraitBase<ConcreteType, ResultsBroadcastableShape> { |
| 90 | public: |
| 91 | static LogicalResult verifyTrait(Operation *op) { |
| 92 | return impl::verifyCompatibleOperandBroadcast(op); |
| 93 | } |
| 94 | }; |
| 95 | |
| 96 | } // namespace OpTrait |
| 97 | } // namespace mlir |
| 98 | |
| 99 | #endif // MLIR_DIALECT_TRAITS_H |
| 100 | |