1 | //===- BroadcastShapeTest.cpp - broadcasting shape unit tests -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Traits.h" |
10 | #include "llvm/ADT/SmallVector.h" |
11 | #include "gmock/gmock.h" |
12 | |
13 | using namespace mlir::OpTrait::util; |
14 | |
15 | using llvm::SmallVector; |
16 | using ::testing::ElementsAre; |
17 | |
18 | TEST(BroadcastShapeTest, CompatibleScalarAndScalar) { |
19 | SmallVector<int64_t, 4> result; |
20 | ASSERT_TRUE(getBroadcastedShape({}, {}, result)); |
21 | EXPECT_TRUE(result.empty()); |
22 | } |
23 | |
24 | TEST(BroadcastShapeTest, Compatible0DAnd1DTensor) { |
25 | SmallVector<int64_t, 4> result; |
26 | ASSERT_TRUE(getBroadcastedShape({}, {4}, result)); |
27 | EXPECT_THAT(result, ElementsAre(4)); |
28 | } |
29 | |
30 | TEST(BroadcastShapeTest, Compatible0DAnd3DTensor) { |
31 | SmallVector<int64_t, 4> result; |
32 | ASSERT_TRUE(getBroadcastedShape({}, {3, 5, 4}, result)); |
33 | EXPECT_THAT(result, ElementsAre(3, 5, 4)); |
34 | } |
35 | |
36 | TEST(BroadcastShapeTest, CompatibleTensorAndTensor) { |
37 | SmallVector<int64_t, 4> result; |
38 | ASSERT_TRUE(getBroadcastedShape({1, 7, 8, 9}, {8, 9}, result)); |
39 | EXPECT_THAT(result, ElementsAre(1, 7, 8, 9)); |
40 | } |
41 | |
42 | TEST(BroadcastShapeTest, InterleavingOnes) { |
43 | SmallVector<int64_t, 4> result; |
44 | ASSERT_TRUE(getBroadcastedShape({8, 1, 2, 1, 4}, {5, 1, 7, 1}, result)); |
45 | EXPECT_THAT(result, ElementsAre(8, 5, 2, 7, 4)); |
46 | } |
47 | |
48 | TEST(BroadcastShapeTest, InterleavingUnknowns) { |
49 | SmallVector<int64_t, 4> result; |
50 | int64_t dyn = mlir::ShapedType::kDynamic; |
51 | ASSERT_TRUE(getBroadcastedShape({1, 2, dyn, dyn, dyn}, {dyn, dyn, dyn, 4, 1}, |
52 | result)); |
53 | EXPECT_THAT(result, ElementsAre(dyn, 2, dyn, 4, dyn)); |
54 | } |
55 | |
56 | TEST(BroadcastShapeTest, IncompatibleLowDim) { |
57 | SmallVector<int64_t, 4> result; |
58 | ASSERT_FALSE(getBroadcastedShape({4, 3, 5, 5}, {3, 5, 4}, result)); |
59 | EXPECT_TRUE(result.empty()); |
60 | } |
61 | |
62 | TEST(BroadcastShapeTest, IncompatibleMiddleDim) { |
63 | SmallVector<int64_t, 4> result; |
64 | ASSERT_FALSE(getBroadcastedShape({4, 3, 5, 5}, {3, 7, 5}, result)); |
65 | EXPECT_TRUE(result.empty()); |
66 | } |
67 | |
68 | TEST(BroadcastShapeTest, IncompatibleHighDim) { |
69 | SmallVector<int64_t, 4> result; |
70 | ASSERT_FALSE(getBroadcastedShape({3, 5, 5}, {4, 5, 5}, result)); |
71 | EXPECT_TRUE(result.empty()); |
72 | } |
73 | |