1 | //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/AffineMap.h" |
10 | #include "mlir/IR/BuiltinAttributes.h" |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/IR/Dialect.h" |
13 | #include "mlir/IR/DialectInterface.h" |
14 | #include "llvm/ADT/SmallVector.h" |
15 | #include "gtest/gtest.h" |
16 | #include <cstdint> |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::detail; |
20 | |
21 | namespace { |
22 | TEST(ShapedTypeTest, CloneMemref) { |
23 | MLIRContext context; |
24 | |
25 | Type i32 = IntegerType::get(&context, 32); |
26 | Type f32 = FloatType::getF32(ctx: &context); |
27 | Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7); |
28 | Type memrefOriginalType = i32; |
29 | llvm::SmallVector<int64_t> memrefOriginalShape({10, 20}); |
30 | AffineMap map = makeStridedLinearLayoutMap(strides: {2, 3}, offset: 5, context: &context); |
31 | |
32 | ShapedType memrefType = |
33 | (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType) |
34 | .setMemorySpace(memSpace) |
35 | .setLayout(AffineMapAttr::get(map)); |
36 | // Update shape. |
37 | llvm::SmallVector<int64_t> memrefNewShape({30, 40}); |
38 | ASSERT_NE(memrefOriginalShape, memrefNewShape); |
39 | ASSERT_EQ(memrefType.clone(memrefNewShape), |
40 | (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType) |
41 | .setMemorySpace(memSpace) |
42 | .setLayout(AffineMapAttr::get(map))); |
43 | // Update type. |
44 | Type memrefNewType = f32; |
45 | ASSERT_NE(memrefOriginalType, memrefNewType); |
46 | ASSERT_EQ(memrefType.clone(memrefNewType), |
47 | (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType) |
48 | .setMemorySpace(memSpace) |
49 | .setLayout(AffineMapAttr::get(map))); |
50 | // Update both. |
51 | ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType), |
52 | (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) |
53 | .setMemorySpace(memSpace) |
54 | .setLayout(AffineMapAttr::get(map))); |
55 | |
56 | // Test unranked memref cloning. |
57 | ShapedType unrankedTensorType = |
58 | UnrankedMemRefType::get(memrefOriginalType, memSpace); |
59 | ASSERT_EQ(unrankedTensorType.clone(memrefNewShape), |
60 | (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType) |
61 | .setMemorySpace(memSpace)); |
62 | ASSERT_EQ(unrankedTensorType.clone(memrefNewType), |
63 | UnrankedMemRefType::get(memrefNewType, memSpace)); |
64 | ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType), |
65 | (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType) |
66 | .setMemorySpace(memSpace)); |
67 | } |
68 | |
69 | TEST(ShapedTypeTest, CloneTensor) { |
70 | MLIRContext context; |
71 | |
72 | Type i32 = IntegerType::get(&context, 32); |
73 | Type f32 = FloatType::getF32(ctx: &context); |
74 | |
75 | Type tensorOriginalType = i32; |
76 | llvm::SmallVector<int64_t> tensorOriginalShape({10, 20}); |
77 | |
78 | // Test ranked tensor cloning. |
79 | ShapedType tensorType = |
80 | RankedTensorType::get(tensorOriginalShape, tensorOriginalType); |
81 | // Update shape. |
82 | llvm::SmallVector<int64_t> tensorNewShape({30, 40}); |
83 | ASSERT_NE(tensorOriginalShape, tensorNewShape); |
84 | ASSERT_EQ( |
85 | tensorType.clone(tensorNewShape), |
86 | (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); |
87 | // Update type. |
88 | Type tensorNewType = f32; |
89 | ASSERT_NE(tensorOriginalType, tensorNewType); |
90 | ASSERT_EQ( |
91 | tensorType.clone(tensorNewType), |
92 | (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType)); |
93 | // Update both. |
94 | ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType), |
95 | (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType)); |
96 | |
97 | // Test unranked tensor cloning. |
98 | ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType); |
99 | ASSERT_EQ( |
100 | unrankedTensorType.clone(tensorNewShape), |
101 | (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); |
102 | ASSERT_EQ(unrankedTensorType.clone(tensorNewType), |
103 | (ShapedType)UnrankedTensorType::get(tensorNewType)); |
104 | ASSERT_EQ( |
105 | unrankedTensorType.clone(tensorNewShape), |
106 | (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType)); |
107 | } |
108 | |
109 | TEST(ShapedTypeTest, CloneVector) { |
110 | MLIRContext context; |
111 | |
112 | Type i32 = IntegerType::get(&context, 32); |
113 | Type f32 = FloatType::getF32(ctx: &context); |
114 | |
115 | Type vectorOriginalType = i32; |
116 | llvm::SmallVector<int64_t> vectorOriginalShape({10, 20}); |
117 | ShapedType vectorType = |
118 | VectorType::get(vectorOriginalShape, vectorOriginalType); |
119 | // Update shape. |
120 | llvm::SmallVector<int64_t> vectorNewShape({30, 40}); |
121 | ASSERT_NE(vectorOriginalShape, vectorNewShape); |
122 | ASSERT_EQ(vectorType.clone(vectorNewShape), |
123 | VectorType::get(vectorNewShape, vectorOriginalType)); |
124 | // Update type. |
125 | Type vectorNewType = f32; |
126 | ASSERT_NE(vectorOriginalType, vectorNewType); |
127 | ASSERT_EQ(vectorType.clone(vectorNewType), |
128 | VectorType::get(vectorOriginalShape, vectorNewType)); |
129 | // Update both. |
130 | ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType), |
131 | VectorType::get(vectorNewShape, vectorNewType)); |
132 | } |
133 | |
134 | TEST(ShapedTypeTest, VectorTypeBuilder) { |
135 | MLIRContext context; |
136 | Type f32 = FloatType::getF32(ctx: &context); |
137 | |
138 | SmallVector<int64_t> shape{2, 4, 8, 9, 1}; |
139 | SmallVector<bool> scalableDims{true, false, true, false, false}; |
140 | VectorType vectorType = VectorType::get(shape, f32, scalableDims); |
141 | |
142 | { |
143 | // Drop some dims. |
144 | VectorType dropFrontTwoDims = |
145 | VectorType::Builder(vectorType).dropDim(0).dropDim(0); |
146 | ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType()); |
147 | ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape()); |
148 | ASSERT_EQ(vectorType.getScalableDims().drop_front(2), |
149 | dropFrontTwoDims.getScalableDims()); |
150 | } |
151 | |
152 | { |
153 | // Set some dims. |
154 | VectorType setTwoDims = |
155 | VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12); |
156 | ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1})); |
157 | ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType()); |
158 | ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims()); |
159 | } |
160 | |
161 | { |
162 | // Test for bug from: |
163 | // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a |
164 | // Constructs a temporary builder, modifies it, copies it to `builder`. |
165 | // This used to lead to a use-after-free. Running under sanitizers will |
166 | // catch any issues. |
167 | VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16); |
168 | VectorType newVectorType = VectorType(builder); |
169 | ASSERT_EQ(newVectorType.getDimSize(0), 16); |
170 | } |
171 | |
172 | { |
173 | // Make builder from scratch (without scalable dims) -- this use to lead to |
174 | // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969. |
175 | // Running under sanitizers will catch any issues. |
176 | SmallVector<int64_t> shape{1, 2, 3, 4}; |
177 | VectorType::Builder builder(shape, f32); |
178 | ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape)); |
179 | } |
180 | |
181 | { |
182 | // Set vector shape (without scalable dims) -- this use to lead to |
183 | // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969. |
184 | // Running under sanitizers will catch any issues. |
185 | VectorType::Builder builder(vectorType); |
186 | SmallVector<int64_t> newShape{2, 2}; |
187 | builder.setShape(newShape); |
188 | ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape)); |
189 | } |
190 | } |
191 | |
192 | TEST(ShapedTypeTest, RankedTensorTypeBuilder) { |
193 | MLIRContext context; |
194 | Type f32 = FloatType::getF32(ctx: &context); |
195 | |
196 | SmallVector<int64_t> shape{2, 4, 8, 16, 32}; |
197 | RankedTensorType tensorType = RankedTensorType::get(shape, f32); |
198 | |
199 | { |
200 | // Drop some dims. |
201 | RankedTensorType dropFrontTwoDims = |
202 | RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0); |
203 | ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType()); |
204 | ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32})); |
205 | } |
206 | |
207 | { |
208 | // Insert some dims. |
209 | RankedTensorType insertTwoDims = |
210 | RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3); |
211 | ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType()); |
212 | ASSERT_EQ(insertTwoDims.getShape(), |
213 | ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32})); |
214 | } |
215 | |
216 | { |
217 | // Test for bug from: |
218 | // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a |
219 | // Constructs a temporary builder, modifies it, copies it to `builder`. |
220 | // This used to lead to a use-after-free. Running under sanitizers will |
221 | // catch any issues. |
222 | RankedTensorType::Builder builder = |
223 | RankedTensorType::Builder(tensorType).dropDim(0); |
224 | RankedTensorType newTensorType = RankedTensorType(builder); |
225 | ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape()); |
226 | } |
227 | } |
228 | |
229 | } // namespace |
230 | |