//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>

using namespace mlir;
using namespace mlir::detail;

namespace {
TEST(ShapedTypeTest, CloneMemref) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = Float32Type::get(&context);
  Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
  Type memrefOriginalType = i32;
  llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
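  // Strided layout: (d0, d1) -> (d0 * 2 + d1 * 3 + 5).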
  AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);

  ShapedType memrefType =
      (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
          .setMemorySpace(memSpace)
          .setLayout(AffineMapAttr::get(map));
  // Update shape.
  llvm::SmallVector<int64_t> memrefNewShape({30, 40});
  ASSERT_NE(memrefOriginalShape, memrefNewShape);
  ASSERT_EQ(memrefType.clone(memrefNewShape),
            (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));
  // Update type.
  Type memrefNewType = f32;
  ASSERT_NE(memrefOriginalType, memrefNewType);
  ASSERT_EQ(memrefType.clone(memrefNewType),
            (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));
  // Update both.
  ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));

  // Test unranked memref cloning.
  ShapedType unrankedMemrefType =
      UnrankedMemRefType::get(memrefOriginalType, memSpace);
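  // Cloning with a new shape produces a ranked memref; cloning with only a
  // new element type stays unranked.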
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                .setMemorySpace(memSpace));
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewType),
            UnrankedMemRefType::get(memrefNewType, memSpace));
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape, memrefNewType),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                .setMemorySpace(memSpace));
}

TEST(ShapedTypeTest, CloneTensor) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = Float32Type::get(&context);

  Type tensorOriginalType = i32;
  llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});

  // Test ranked tensor cloning.
  ShapedType tensorType =
      RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
  // Update shape.
  llvm::SmallVector<int64_t> tensorNewShape({30, 40});
  ASSERT_NE(tensorOriginalShape, tensorNewShape);
  ASSERT_EQ(
      tensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
  // Update type.
  Type tensorNewType = f32;
  ASSERT_NE(tensorOriginalType, tensorNewType);
  ASSERT_EQ(
      tensorType.clone(tensorNewType),
      (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
  // Update both.
  ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
            (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));

  // Test unranked tensor cloning.
  ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
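  // As with memrefs, cloning with a new shape produces a ranked tensor;
  // cloning with only a new element type stays unranked.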
  ASSERT_EQ(
      unrankedTensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
  ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
            (ShapedType)UnrankedTensorType::get(tensorNewType));
  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape, tensorNewType),
            (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
}

TEST(ShapedTypeTest, CloneVector) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = Float32Type::get(&context);

  Type vectorOriginalType = i32;
  llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
  ShapedType vectorType =
      VectorType::get(vectorOriginalShape, vectorOriginalType);
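  // Vector types have no unranked form, so only ranked cloning is exercised.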
  // Update shape.
  llvm::SmallVector<int64_t> vectorNewShape({30, 40});
  ASSERT_NE(vectorOriginalShape, vectorNewShape);
  ASSERT_EQ(vectorType.clone(vectorNewShape),
            VectorType::get(vectorNewShape, vectorOriginalType));
  // Update type.
  Type vectorNewType = f32;
  ASSERT_NE(vectorOriginalType, vectorNewType);
  ASSERT_EQ(vectorType.clone(vectorNewType),
            VectorType::get(vectorOriginalShape, vectorNewType));
  // Update both.
  ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
            VectorType::get(vectorNewShape, vectorNewType));
}

TEST(ShapedTypeTest, VectorTypeBuilder) {
  MLIRContext context;
  Type f32 = Float32Type::get(&context);

  SmallVector<int64_t> shape{2, 4, 8, 9, 1};
  SmallVector<bool> scalableDims{true, false, true, false, false};
  VectorType vectorType = VectorType::get(shape, f32, scalableDims);
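  // Dims 0 and 2 are scalable; the builder operations below must keep the
  // shape and the scalable-dim flags in sync.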

  {
    // Drop some dims.
    VectorType dropFrontTwoDims =
        VectorType::Builder(vectorType).dropDim(0).dropDim(0);
    ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
    ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
    ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
              dropFrontTwoDims.getScalableDims());
  }

  {
    // Set some dims.
    VectorType setTwoDims =
        VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
    ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
    ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
    ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
  }

  {
    // Test for bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // Constructs a temporary builder, modifies it, copies it to `builder`.
    // This used to lead to a use-after-free. Running under sanitizers will
    // catch any issues.
    VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
    VectorType newVectorType = VectorType(builder);
    ASSERT_EQ(newVectorType.getDimSize(0), 16);
  }

  {
    // Make a builder from scratch (without scalable dims) -- this used to lead
    // to a use-after-free, see: https://github.com/llvm/llvm-project/pull/68969.
    // Running under sanitizers will catch any issues.
    SmallVector<int64_t> shape{1, 2, 3, 4};
    VectorType::Builder builder(shape, f32);
    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
  }

  {
    // Set the vector shape (without scalable dims) -- this used to lead to
    // a use-after-free, see: https://github.com/llvm/llvm-project/pull/68969.
    // Running under sanitizers will catch any issues.
    VectorType::Builder builder(vectorType);
    SmallVector<int64_t> newShape{2, 2};
    builder.setShape(newShape);
    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
  }
}

TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
  MLIRContext context;
  Type f32 = Float32Type::get(&context);

  SmallVector<int64_t> shape{2, 4, 8, 16, 32};
  RankedTensorType tensorType = RankedTensorType::get(shape, f32);

  {
    // Drop some dims.
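    // dropDim indices refer to the current (already-shrunk) shape: dropDim(0)
    // removes 2, dropDim(1) then removes 8, and the final dropDim(0) removes 4.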
    RankedTensorType dropFrontThreeDims =
        RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
    ASSERT_EQ(tensorType.getElementType(),
              dropFrontThreeDims.getElementType());
    ASSERT_EQ(dropFrontThreeDims.getShape(), ArrayRef<int64_t>({16, 32}));
  }

  {
    // Insert some dims.
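    // insertDim(size, pos) operates on the current shape: 7 is inserted at
    // index 2, then 9 at index 3 of the already-extended shape.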
    RankedTensorType insertTwoDims =
        RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
    ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
    ASSERT_EQ(insertTwoDims.getShape(),
              ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
  }

  {
    // Test for bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // Constructs a temporary builder, modifies it, copies it to `builder`.
    // This used to lead to a use-after-free. Running under sanitizers will
    // catch any issues.
    RankedTensorType::Builder builder =
        RankedTensorType::Builder(tensorType).dropDim(0);
    RankedTensorType newTensorType = RankedTensorType(builder);
    ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
  }
}

/// Simple wrapper class to enable "isa querying" and simple accessing of
/// encoding.
class TensorWithString : public RankedTensorType {
public:
  using RankedTensorType::RankedTensorType;
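  // Inheriting RankedTensorType's constructors lets an existing
  // RankedTensorType be viewed as this type via mlir::cast, gated by classof
  // below.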

  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
                              StringRef name) {
    return mlir::cast<TensorWithString>(RankedTensorType::get(
        shape, elementType, StringAttr::get(elementType.getContext(), name)));
  }

  StringRef getName() const {
    if (Attribute enc = getEncoding())
      return mlir::cast<StringAttr>(enc).getValue();
    return {};
  }

  static bool classof(Type type) {
    if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
      return mlir::isa_and_present<StringAttr>(rt.getEncoding());
    return false;
  }
};

TEST(ShapedTypeTest, RankedTensorTypeView) {
  MLIRContext context;
  Type f32 = Float32Type::get(&context);

  Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);

  UnitAttr unitAttr = UnitAttr::get(&context);
  Type unitEncodingRankedTensorType =
      RankedTensorType::get({10, 20}, f32, unitAttr);

  StringAttr stringAttr = StringAttr::get(&context, "app");
  Type stringEncodingRankedTensorType =
      RankedTensorType::get({10, 20}, f32, stringAttr);

  EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
  EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
  ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));

  // Cast to TensorWithString view.
  auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
  ASSERT_TRUE(mlir::isa<TensorWithString>(view));
  EXPECT_EQ(view.getName(), "app");
  // Verify that the view type can be cast back to the base type.
  ASSERT_TRUE(mlir::isa<RankedTensorType>(view));

  Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
  ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
  ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
  view = mlir::cast<TensorWithString>(viewCreated);
  EXPECT_EQ(view.getName(), "bob");

  // Verify encoding clone methods.
  EXPECT_EQ(unitEncodingRankedTensorType,
            cast<RankedTensorType>(noEncodingRankedTensorType)
                .cloneWithEncoding(unitAttr));
  EXPECT_EQ(stringEncodingRankedTensorType,
            cast<RankedTensorType>(noEncodingRankedTensorType)
                .cloneWithEncoding(stringAttr));
  EXPECT_EQ(
      noEncodingRankedTensorType,
      cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
  EXPECT_EQ(
      noEncodingRankedTensorType,
      cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
}

} // namespace
