1 | //===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===// |
2 | // |
3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/ExecutionEngine/CRunnerUtils.h" |
10 | #include "llvm/ADT/SmallVector.h" |
11 | |
12 | #include "gmock/gmock.h" |
13 | |
14 | using namespace ::mlir; |
15 | using namespace ::testing; |
16 | |
17 | TEST(DynamicMemRef, rankZero) { |
18 | int data = 57; |
19 | |
20 | StridedMemRefType<int, 0> memRef; |
21 | memRef.basePtr = &data; |
22 | memRef.data = &data; |
23 | memRef.offset = 0; |
24 | |
25 | DynamicMemRefType<int> dynamicMemRef(memRef); |
26 | |
27 | llvm::SmallVector<int, 1> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
28 | EXPECT_THAT(values, ElementsAre(57)); |
29 | } |
30 | |
31 | TEST(DynamicMemRef, rankOne) { |
32 | std::array<int, 3> data; |
33 | |
34 | for (size_t i = 0; i < data.size(); ++i) { |
35 | data[i] = i; |
36 | } |
37 | |
38 | StridedMemRefType<int, 1> memRef; |
39 | memRef.basePtr = data.data(); |
40 | memRef.data = data.data(); |
41 | memRef.offset = 0; |
42 | memRef.sizes[0] = 3; |
43 | memRef.strides[0] = 1; |
44 | |
45 | DynamicMemRefType<int> dynamicMemRef(memRef); |
46 | |
47 | llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
48 | EXPECT_THAT(values, ElementsAreArray(data)); |
49 | |
50 | for (int64_t i = 0; i < 3; ++i) { |
51 | EXPECT_EQ(*dynamicMemRef[i], data[i]); |
52 | } |
53 | } |
54 | |
55 | TEST(DynamicMemRef, rankTwo) { |
56 | std::array<int, 6> data; |
57 | |
58 | for (size_t i = 0; i < data.size(); ++i) { |
59 | data[i] = i; |
60 | } |
61 | |
62 | StridedMemRefType<int, 2> memRef; |
63 | memRef.basePtr = data.data(); |
64 | memRef.data = data.data(); |
65 | memRef.offset = 0; |
66 | memRef.sizes[0] = 2; |
67 | memRef.sizes[1] = 3; |
68 | memRef.strides[0] = 3; |
69 | memRef.strides[1] = 1; |
70 | |
71 | DynamicMemRefType<int> dynamicMemRef(memRef); |
72 | |
73 | llvm::SmallVector<int, 6> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
74 | EXPECT_THAT(values, ElementsAreArray(data)); |
75 | } |
76 | |
77 | TEST(DynamicMemRef, rankThree) { |
78 | std::array<int, 24> data; |
79 | |
80 | for (size_t i = 0; i < data.size(); ++i) { |
81 | data[i] = i; |
82 | } |
83 | |
84 | StridedMemRefType<int, 3> memRef; |
85 | memRef.basePtr = data.data(); |
86 | memRef.data = data.data(); |
87 | memRef.offset = 0; |
88 | memRef.sizes[0] = 2; |
89 | memRef.sizes[1] = 3; |
90 | memRef.sizes[2] = 4; |
91 | memRef.strides[0] = 12; |
92 | memRef.strides[1] = 4; |
93 | memRef.strides[2] = 1; |
94 | |
95 | DynamicMemRefType<int> dynamicMemRef(memRef); |
96 | |
97 | llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
98 | EXPECT_THAT(values, ElementsAreArray(data)); |
99 | } |
100 | |
101 | TEST(DynamicMemRef, rankOneWithOffset) { |
102 | constexpr int offset = 4; |
103 | std::array<int, 3 + offset> buffer; |
104 | |
105 | for (size_t i = 0; i < buffer.size(); ++i) { |
106 | buffer[i] = i; |
107 | } |
108 | |
109 | StridedMemRefType<int, 1> memRef; |
110 | memRef.basePtr = buffer.data(); |
111 | memRef.data = buffer.data(); |
112 | memRef.offset = offset; |
113 | memRef.sizes[0] = 3; |
114 | memRef.strides[0] = 1; |
115 | |
116 | DynamicMemRefType<int> dynamicMemRef(memRef); |
117 | |
118 | llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
119 | |
120 | for (int64_t i = 0; i < 3; ++i) { |
121 | EXPECT_EQ(values[i], buffer[offset + i]); |
122 | EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]); |
123 | } |
124 | } |
125 | |