1//===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/ExecutionEngine/CRunnerUtils.h"
10#include "llvm/ADT/SmallVector.h"
11
12#include "gmock/gmock.h"
13
14using namespace ::mlir;
15using namespace ::testing;
16
17TEST(DynamicMemRef, rankZero) {
18 int data = 57;
19
20 StridedMemRefType<int, 0> memRef;
21 memRef.basePtr = &data;
22 memRef.data = &data;
23 memRef.offset = 0;
24
25 DynamicMemRefType<int> dynamicMemRef(memRef);
26
27 llvm::SmallVector<int, 1> values(dynamicMemRef.begin(), dynamicMemRef.end());
28 EXPECT_THAT(values, ElementsAre(57));
29}
30
31TEST(DynamicMemRef, rankOne) {
32 std::array<int, 3> data;
33
34 for (size_t i = 0; i < data.size(); ++i) {
35 data[i] = i;
36 }
37
38 StridedMemRefType<int, 1> memRef;
39 memRef.basePtr = data.data();
40 memRef.data = data.data();
41 memRef.offset = 0;
42 memRef.sizes[0] = 3;
43 memRef.strides[0] = 1;
44
45 DynamicMemRefType<int> dynamicMemRef(memRef);
46
47 llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
48 EXPECT_THAT(values, ElementsAreArray(data));
49
50 for (int64_t i = 0; i < 3; ++i) {
51 EXPECT_EQ(*dynamicMemRef[i], data[i]);
52 }
53}
54
55TEST(DynamicMemRef, rankTwo) {
56 std::array<int, 6> data;
57
58 for (size_t i = 0; i < data.size(); ++i) {
59 data[i] = i;
60 }
61
62 StridedMemRefType<int, 2> memRef;
63 memRef.basePtr = data.data();
64 memRef.data = data.data();
65 memRef.offset = 0;
66 memRef.sizes[0] = 2;
67 memRef.sizes[1] = 3;
68 memRef.strides[0] = 3;
69 memRef.strides[1] = 1;
70
71 DynamicMemRefType<int> dynamicMemRef(memRef);
72
73 llvm::SmallVector<int, 6> values(dynamicMemRef.begin(), dynamicMemRef.end());
74 EXPECT_THAT(values, ElementsAreArray(data));
75}
76
77TEST(DynamicMemRef, rankThree) {
78 std::array<int, 24> data;
79
80 for (size_t i = 0; i < data.size(); ++i) {
81 data[i] = i;
82 }
83
84 StridedMemRefType<int, 3> memRef;
85 memRef.basePtr = data.data();
86 memRef.data = data.data();
87 memRef.offset = 0;
88 memRef.sizes[0] = 2;
89 memRef.sizes[1] = 3;
90 memRef.sizes[2] = 4;
91 memRef.strides[0] = 12;
92 memRef.strides[1] = 4;
93 memRef.strides[2] = 1;
94
95 DynamicMemRefType<int> dynamicMemRef(memRef);
96
97 llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end());
98 EXPECT_THAT(values, ElementsAreArray(data));
99}
100
101TEST(DynamicMemRef, rankOneWithOffset) {
102 constexpr int offset = 4;
103 std::array<int, 3 + offset> buffer;
104
105 for (size_t i = 0; i < buffer.size(); ++i) {
106 buffer[i] = i;
107 }
108
109 StridedMemRefType<int, 1> memRef;
110 memRef.basePtr = buffer.data();
111 memRef.data = buffer.data();
112 memRef.offset = offset;
113 memRef.sizes[0] = 3;
114 memRef.strides[0] = 1;
115
116 DynamicMemRefType<int> dynamicMemRef(memRef);
117
118 llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end());
119
120 for (int64_t i = 0; i < 3; ++i) {
121 EXPECT_EQ(values[i], buffer[offset + i]);
122 EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]);
123 }
124}
125

source code of mlir/unittests/ExecutionEngine/DynamicMemRef.cpp