| 1 | //===- DynamicMemRef.cpp ----------------------------------------*- C++ -*-===// |
| 2 | // |
| 3 | // This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/SmallVector.h"

#include "gmock/gmock.h"

#include <array>
#include <numeric>
| 13 | |
| 14 | using namespace ::mlir; |
| 15 | using namespace ::testing; |
| 16 | |
| 17 | TEST(DynamicMemRef, rankZero) { |
| 18 | int data = 57; |
| 19 | |
| 20 | StridedMemRefType<int, 0> memRef; |
| 21 | memRef.basePtr = &data; |
| 22 | memRef.data = &data; |
| 23 | memRef.offset = 0; |
| 24 | |
| 25 | DynamicMemRefType<int> dynamicMemRef(memRef); |
| 26 | |
| 27 | llvm::SmallVector<int, 1> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
| 28 | EXPECT_THAT(values, ElementsAre(57)); |
| 29 | } |
| 30 | |
| 31 | TEST(DynamicMemRef, rankOne) { |
| 32 | std::array<int, 3> data; |
| 33 | |
| 34 | for (size_t i = 0; i < data.size(); ++i) { |
| 35 | data[i] = i; |
| 36 | } |
| 37 | |
| 38 | StridedMemRefType<int, 1> memRef; |
| 39 | memRef.basePtr = data.data(); |
| 40 | memRef.data = data.data(); |
| 41 | memRef.offset = 0; |
| 42 | memRef.sizes[0] = 3; |
| 43 | memRef.strides[0] = 1; |
| 44 | |
| 45 | DynamicMemRefType<int> dynamicMemRef(memRef); |
| 46 | |
| 47 | llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
| 48 | EXPECT_THAT(values, ElementsAreArray(data)); |
| 49 | |
| 50 | for (int64_t i = 0; i < 3; ++i) { |
| 51 | EXPECT_EQ(*dynamicMemRef[i], data[i]); |
| 52 | } |
| 53 | } |
| 54 | |
| 55 | TEST(DynamicMemRef, rankTwo) { |
| 56 | std::array<int, 6> data; |
| 57 | |
| 58 | for (size_t i = 0; i < data.size(); ++i) { |
| 59 | data[i] = i; |
| 60 | } |
| 61 | |
| 62 | StridedMemRefType<int, 2> memRef; |
| 63 | memRef.basePtr = data.data(); |
| 64 | memRef.data = data.data(); |
| 65 | memRef.offset = 0; |
| 66 | memRef.sizes[0] = 2; |
| 67 | memRef.sizes[1] = 3; |
| 68 | memRef.strides[0] = 3; |
| 69 | memRef.strides[1] = 1; |
| 70 | |
| 71 | DynamicMemRefType<int> dynamicMemRef(memRef); |
| 72 | |
| 73 | llvm::SmallVector<int, 6> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
| 74 | EXPECT_THAT(values, ElementsAreArray(data)); |
| 75 | } |
| 76 | |
| 77 | TEST(DynamicMemRef, rankThree) { |
| 78 | std::array<int, 24> data; |
| 79 | |
| 80 | for (size_t i = 0; i < data.size(); ++i) { |
| 81 | data[i] = i; |
| 82 | } |
| 83 | |
| 84 | StridedMemRefType<int, 3> memRef; |
| 85 | memRef.basePtr = data.data(); |
| 86 | memRef.data = data.data(); |
| 87 | memRef.offset = 0; |
| 88 | memRef.sizes[0] = 2; |
| 89 | memRef.sizes[1] = 3; |
| 90 | memRef.sizes[2] = 4; |
| 91 | memRef.strides[0] = 12; |
| 92 | memRef.strides[1] = 4; |
| 93 | memRef.strides[2] = 1; |
| 94 | |
| 95 | DynamicMemRefType<int> dynamicMemRef(memRef); |
| 96 | |
| 97 | llvm::SmallVector<int, 24> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
| 98 | EXPECT_THAT(values, ElementsAreArray(data)); |
| 99 | } |
| 100 | |
| 101 | TEST(DynamicMemRef, rankOneWithOffset) { |
| 102 | constexpr int offset = 4; |
| 103 | std::array<int, 3 + offset> buffer; |
| 104 | |
| 105 | for (size_t i = 0; i < buffer.size(); ++i) { |
| 106 | buffer[i] = i; |
| 107 | } |
| 108 | |
| 109 | StridedMemRefType<int, 1> memRef; |
| 110 | memRef.basePtr = buffer.data(); |
| 111 | memRef.data = buffer.data(); |
| 112 | memRef.offset = offset; |
| 113 | memRef.sizes[0] = 3; |
| 114 | memRef.strides[0] = 1; |
| 115 | |
| 116 | DynamicMemRefType<int> dynamicMemRef(memRef); |
| 117 | |
| 118 | llvm::SmallVector<int, 3> values(dynamicMemRef.begin(), dynamicMemRef.end()); |
| 119 | |
| 120 | for (int64_t i = 0; i < 3; ++i) { |
| 121 | EXPECT_EQ(values[i], buffer[offset + i]); |
| 122 | EXPECT_EQ(*dynamicMemRef[i], buffer[offset + i]); |
| 123 | } |
| 124 | } |
| 125 | |