1//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/AffineMap.h"
10#include "mlir/IR/BuiltinAttributes.h"
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/IR/Dialect.h"
13#include "mlir/IR/DialectInterface.h"
14#include "llvm/ADT/SmallVector.h"
15#include "gtest/gtest.h"
16#include <cstdint>
17
18using namespace mlir;
19using namespace mlir::detail;
20
21namespace {
22TEST(ShapedTypeTest, CloneMemref) {
23 MLIRContext context;
24
25 Type i32 = IntegerType::get(&context, 32);
26 Type f32 = FloatType::getF32(ctx: &context);
27 Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
28 Type memrefOriginalType = i32;
29 llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
30 AffineMap map = makeStridedLinearLayoutMap(strides: {2, 3}, offset: 5, context: &context);
31
32 ShapedType memrefType =
33 (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
34 .setMemorySpace(memSpace)
35 .setLayout(AffineMapAttr::get(map));
36 // Update shape.
37 llvm::SmallVector<int64_t> memrefNewShape({30, 40});
38 ASSERT_NE(memrefOriginalShape, memrefNewShape);
39 ASSERT_EQ(memrefType.clone(memrefNewShape),
40 (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
41 .setMemorySpace(memSpace)
42 .setLayout(AffineMapAttr::get(map)));
43 // Update type.
44 Type memrefNewType = f32;
45 ASSERT_NE(memrefOriginalType, memrefNewType);
46 ASSERT_EQ(memrefType.clone(memrefNewType),
47 (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
48 .setMemorySpace(memSpace)
49 .setLayout(AffineMapAttr::get(map)));
50 // Update both.
51 ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
52 (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
53 .setMemorySpace(memSpace)
54 .setLayout(AffineMapAttr::get(map)));
55
56 // Test unranked memref cloning.
57 ShapedType unrankedTensorType =
58 UnrankedMemRefType::get(memrefOriginalType, memSpace);
59 ASSERT_EQ(unrankedTensorType.clone(memrefNewShape),
60 (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
61 .setMemorySpace(memSpace));
62 ASSERT_EQ(unrankedTensorType.clone(memrefNewType),
63 UnrankedMemRefType::get(memrefNewType, memSpace));
64 ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType),
65 (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
66 .setMemorySpace(memSpace));
67}
68
69TEST(ShapedTypeTest, CloneTensor) {
70 MLIRContext context;
71
72 Type i32 = IntegerType::get(&context, 32);
73 Type f32 = FloatType::getF32(ctx: &context);
74
75 Type tensorOriginalType = i32;
76 llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});
77
78 // Test ranked tensor cloning.
79 ShapedType tensorType =
80 RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
81 // Update shape.
82 llvm::SmallVector<int64_t> tensorNewShape({30, 40});
83 ASSERT_NE(tensorOriginalShape, tensorNewShape);
84 ASSERT_EQ(
85 tensorType.clone(tensorNewShape),
86 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
87 // Update type.
88 Type tensorNewType = f32;
89 ASSERT_NE(tensorOriginalType, tensorNewType);
90 ASSERT_EQ(
91 tensorType.clone(tensorNewType),
92 (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
93 // Update both.
94 ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
95 (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
96
97 // Test unranked tensor cloning.
98 ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
99 ASSERT_EQ(
100 unrankedTensorType.clone(tensorNewShape),
101 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
102 ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
103 (ShapedType)UnrankedTensorType::get(tensorNewType));
104 ASSERT_EQ(
105 unrankedTensorType.clone(tensorNewShape),
106 (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
107}
108
109TEST(ShapedTypeTest, CloneVector) {
110 MLIRContext context;
111
112 Type i32 = IntegerType::get(&context, 32);
113 Type f32 = FloatType::getF32(ctx: &context);
114
115 Type vectorOriginalType = i32;
116 llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
117 ShapedType vectorType =
118 VectorType::get(vectorOriginalShape, vectorOriginalType);
119 // Update shape.
120 llvm::SmallVector<int64_t> vectorNewShape({30, 40});
121 ASSERT_NE(vectorOriginalShape, vectorNewShape);
122 ASSERT_EQ(vectorType.clone(vectorNewShape),
123 VectorType::get(vectorNewShape, vectorOriginalType));
124 // Update type.
125 Type vectorNewType = f32;
126 ASSERT_NE(vectorOriginalType, vectorNewType);
127 ASSERT_EQ(vectorType.clone(vectorNewType),
128 VectorType::get(vectorOriginalShape, vectorNewType));
129 // Update both.
130 ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
131 VectorType::get(vectorNewShape, vectorNewType));
132}
133
134TEST(ShapedTypeTest, VectorTypeBuilder) {
135 MLIRContext context;
136 Type f32 = FloatType::getF32(ctx: &context);
137
138 SmallVector<int64_t> shape{2, 4, 8, 9, 1};
139 SmallVector<bool> scalableDims{true, false, true, false, false};
140 VectorType vectorType = VectorType::get(shape, f32, scalableDims);
141
142 {
143 // Drop some dims.
144 VectorType dropFrontTwoDims =
145 VectorType::Builder(vectorType).dropDim(0).dropDim(0);
146 ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
147 ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
148 ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
149 dropFrontTwoDims.getScalableDims());
150 }
151
152 {
153 // Set some dims.
154 VectorType setTwoDims =
155 VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
156 ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
157 ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
158 ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
159 }
160
161 {
162 // Test for bug from:
163 // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
164 // Constructs a temporary builder, modifies it, copies it to `builder`.
165 // This used to lead to a use-after-free. Running under sanitizers will
166 // catch any issues.
167 VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
168 VectorType newVectorType = VectorType(builder);
169 ASSERT_EQ(newVectorType.getDimSize(0), 16);
170 }
171
172 {
173 // Make builder from scratch (without scalable dims) -- this use to lead to
174 // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
175 // Running under sanitizers will catch any issues.
176 SmallVector<int64_t> shape{1, 2, 3, 4};
177 VectorType::Builder builder(shape, f32);
178 ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
179 }
180
181 {
182 // Set vector shape (without scalable dims) -- this use to lead to
183 // a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
184 // Running under sanitizers will catch any issues.
185 VectorType::Builder builder(vectorType);
186 SmallVector<int64_t> newShape{2, 2};
187 builder.setShape(newShape);
188 ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
189 }
190}
191
192TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
193 MLIRContext context;
194 Type f32 = FloatType::getF32(ctx: &context);
195
196 SmallVector<int64_t> shape{2, 4, 8, 16, 32};
197 RankedTensorType tensorType = RankedTensorType::get(shape, f32);
198
199 {
200 // Drop some dims.
201 RankedTensorType dropFrontTwoDims =
202 RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
203 ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType());
204 ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32}));
205 }
206
207 {
208 // Insert some dims.
209 RankedTensorType insertTwoDims =
210 RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
211 ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
212 ASSERT_EQ(insertTwoDims.getShape(),
213 ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
214 }
215
216 {
217 // Test for bug from:
218 // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
219 // Constructs a temporary builder, modifies it, copies it to `builder`.
220 // This used to lead to a use-after-free. Running under sanitizers will
221 // catch any issues.
222 RankedTensorType::Builder builder =
223 RankedTensorType::Builder(tensorType).dropDim(0);
224 RankedTensorType newTensorType = RankedTensorType(builder);
225 ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
226 }
227}
228
229} // namespace
230

source code of mlir/unittests/IR/ShapedTypeTest.cpp