1 | //===- TensorSpecTest.cpp - test for TensorSpec ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/Analysis/TensorSpec.h" |
10 | #include "llvm/Support/Path.h" |
11 | #include "llvm/Support/SourceMgr.h" |
12 | #include "llvm/Testing/Support/SupportHelpers.h" |
13 | #include "gtest/gtest.h" |
14 | |
15 | using namespace llvm; |
16 | |
17 | extern const char *TestMainArgv0; |
18 | |
19 | TEST(TensorSpecTest, JSONParsing) { |
20 | auto Value = json::parse( |
21 | JSON: R"({"name": "tensor_name", |
22 | "port": 2, |
23 | "type": "int32_t", |
24 | "shape":[1,4] |
25 | })" ); |
26 | EXPECT_TRUE(!!Value); |
27 | LLVMContext Ctx; |
28 | std::optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, Value: *Value); |
29 | EXPECT_TRUE(Spec); |
30 | EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name" , {1, 4}, 2)); |
31 | } |
32 | |
33 | TEST(TensorSpecTest, JSONParsingInvalidTensorType) { |
34 | auto Value = json::parse( |
35 | JSON: R"( |
36 | {"name": "tensor_name", |
37 | "port": 2, |
38 | "type": "no such type", |
39 | "shape":[1,4] |
40 | } |
41 | )" ); |
42 | EXPECT_TRUE(!!Value); |
43 | LLVMContext Ctx; |
44 | auto Spec = getTensorSpecFromJSON(Ctx, Value: *Value); |
45 | EXPECT_FALSE(Spec); |
46 | } |
47 | |
48 | TEST(TensorSpecTest, TensorSpecSizesAndTypes) { |
49 | auto Spec1D = TensorSpec::createSpec<int16_t>(Name: "Hi1" , Shape: {1}); |
50 | auto Spec2D = TensorSpec::createSpec<int16_t>(Name: "Hi2" , Shape: {1, 1}); |
51 | auto Spec1DLarge = TensorSpec::createSpec<float>(Name: "Hi3" , Shape: {10}); |
52 | auto Spec3DLarge = TensorSpec::createSpec<float>(Name: "Hi3" , Shape: {2, 4, 10}); |
53 | EXPECT_TRUE(Spec1D.isElementType<int16_t>()); |
54 | EXPECT_FALSE(Spec3DLarge.isElementType<double>()); |
55 | EXPECT_EQ(Spec1D.getElementCount(), 1U); |
56 | EXPECT_EQ(Spec2D.getElementCount(), 1U); |
57 | EXPECT_EQ(Spec1DLarge.getElementCount(), 10U); |
58 | EXPECT_EQ(Spec3DLarge.getElementCount(), 80U); |
59 | EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float)); |
60 | EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t)); |
61 | } |
62 | |
63 | TEST(TensorSpecTest, PrintValueForDebug) { |
64 | std::vector<int32_t> Values{1, 3}; |
65 | EXPECT_EQ(tensorValueToString(reinterpret_cast<const char *>(Values.data()), |
66 | TensorSpec::createSpec<int32_t>("name" , {2})), |
67 | "1,3" ); |
68 | } |