//===- OperationSupportTest.cpp - Operation support unit tests ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/OperationSupport.h"
10#include "../../test/lib/Dialect/Test/TestDialect.h"
11#include "../../test/lib/Dialect/Test/TestOps.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "llvm/ADT/BitVector.h"
15#include "llvm/Support/FormatVariadic.h"
16#include "gtest/gtest.h"
17
18using namespace mlir;
19using namespace mlir::detail;
20
21static Operation *createOp(MLIRContext *context, ArrayRef<Value> operands = {},
22 ArrayRef<Type> resultTypes = {},
23 unsigned int numRegions = 0) {
24 context->allowUnregisteredDialects();
25 return Operation::create(location: UnknownLoc::get(context),
26 name: OperationName("foo.bar", context), resultTypes,
27 operands, attributes: std::nullopt, properties: nullptr, successors: {}, numRegions);
28}
29
30namespace {
31TEST(OperandStorageTest, NonResizable) {
32 MLIRContext context;
33 Builder builder(&context);
34
35 Operation *useOp =
36 createOp(context: &context, /*operands=*/{}, resultTypes: builder.getIntegerType(width: 16));
37 Value operand = useOp->getResult(idx: 0);
38
39 // Create a non-resizable operation with one operand.
40 Operation *user = createOp(context: &context, operands: operand);
41
42 // The same number of operands is okay.
43 user->setOperands(operand);
44 EXPECT_EQ(user->getNumOperands(), 1u);
45
46 // Removing is okay.
47 user->setOperands({});
48 EXPECT_EQ(user->getNumOperands(), 0u);
49
50 // Destroy the operations.
51 user->destroy();
52 useOp->destroy();
53}
54
55TEST(OperandStorageTest, Resizable) {
56 MLIRContext context;
57 Builder builder(&context);
58
59 Operation *useOp =
60 createOp(context: &context, /*operands=*/{}, resultTypes: builder.getIntegerType(width: 16));
61 Value operand = useOp->getResult(idx: 0);
62
63 // Create a resizable operation with one operand.
64 Operation *user = createOp(context: &context, operands: operand);
65
66 // The same number of operands is okay.
67 user->setOperands(operand);
68 EXPECT_EQ(user->getNumOperands(), 1u);
69
70 // Removing is okay.
71 user->setOperands({});
72 EXPECT_EQ(user->getNumOperands(), 0u);
73
74 // Adding more operands is okay.
75 user->setOperands({operand, operand, operand});
76 EXPECT_EQ(user->getNumOperands(), 3u);
77
78 // Destroy the operations.
79 user->destroy();
80 useOp->destroy();
81}
82
83TEST(OperandStorageTest, RangeReplace) {
84 MLIRContext context;
85 Builder builder(&context);
86
87 Operation *useOp =
88 createOp(context: &context, /*operands=*/{}, resultTypes: builder.getIntegerType(width: 16));
89 Value operand = useOp->getResult(idx: 0);
90
91 // Create a resizable operation with one operand.
92 Operation *user = createOp(context: &context, operands: operand);
93
94 // Check setting with the same number of operands.
95 user->setOperands(/*start=*/0, /*length=*/1, operands: operand);
96 EXPECT_EQ(user->getNumOperands(), 1u);
97
98 // Check setting with more operands.
99 user->setOperands(/*start=*/0, /*length=*/1, operands: {operand, operand, operand});
100 EXPECT_EQ(user->getNumOperands(), 3u);
101
102 // Check setting with less operands.
103 user->setOperands(/*start=*/1, /*length=*/2, operands: {operand});
104 EXPECT_EQ(user->getNumOperands(), 2u);
105
106 // Check inserting without replacing operands.
107 user->setOperands(/*start=*/2, /*length=*/0, operands: {operand});
108 EXPECT_EQ(user->getNumOperands(), 3u);
109
110 // Check erasing operands.
111 user->setOperands(/*start=*/0, /*length=*/3, operands: {});
112 EXPECT_EQ(user->getNumOperands(), 0u);
113
114 // Destroy the operations.
115 user->destroy();
116 useOp->destroy();
117}
118
119TEST(OperandStorageTest, MutableRange) {
120 MLIRContext context;
121 Builder builder(&context);
122
123 Operation *useOp =
124 createOp(context: &context, /*operands=*/{}, resultTypes: builder.getIntegerType(width: 16));
125 Value operand = useOp->getResult(idx: 0);
126
127 // Create a resizable operation with one operand.
128 Operation *user = createOp(context: &context, operands: operand);
129
130 // Check setting with the same number of operands.
131 MutableOperandRange mutableOperands(user);
132 mutableOperands.assign(value: operand);
133 EXPECT_EQ(mutableOperands.size(), 1u);
134 EXPECT_EQ(user->getNumOperands(), 1u);
135
136 // Check setting with more operands.
137 mutableOperands.assign(values: {operand, operand, operand});
138 EXPECT_EQ(mutableOperands.size(), 3u);
139 EXPECT_EQ(user->getNumOperands(), 3u);
140
141 // Check with inserting a new operand.
142 mutableOperands.append(values: {operand, operand});
143 EXPECT_EQ(mutableOperands.size(), 5u);
144 EXPECT_EQ(user->getNumOperands(), 5u);
145
146 // Check erasing operands.
147 mutableOperands.clear();
148 EXPECT_EQ(mutableOperands.size(), 0u);
149 EXPECT_EQ(user->getNumOperands(), 0u);
150
151 // Destroy the operations.
152 user->destroy();
153 useOp->destroy();
154}
155
156TEST(OperandStorageTest, RangeErase) {
157 MLIRContext context;
158 Builder builder(&context);
159
160 Type type = builder.getNoneType();
161 Operation *useOp = createOp(context: &context, /*operands=*/{}, resultTypes: {type, type});
162 Value operand1 = useOp->getResult(idx: 0);
163 Value operand2 = useOp->getResult(idx: 1);
164
165 // Create an operation with operands to erase.
166 Operation *user =
167 createOp(context: &context, operands: {operand2, operand1, operand2, operand1});
168 BitVector eraseIndices(user->getNumOperands());
169
170 // Check erasing no operands.
171 user->eraseOperands(eraseIndices);
172 EXPECT_EQ(user->getNumOperands(), 4u);
173
174 // Check erasing disjoint operands.
175 eraseIndices.set(0);
176 eraseIndices.set(3);
177 user->eraseOperands(eraseIndices);
178 EXPECT_EQ(user->getNumOperands(), 2u);
179 EXPECT_EQ(user->getOperand(0), operand1);
180 EXPECT_EQ(user->getOperand(1), operand2);
181
182 // Destroy the operations.
183 user->destroy();
184 useOp->destroy();
185}
186
187TEST(OperationOrderTest, OrderIsAlwaysValid) {
188 MLIRContext context;
189 Builder builder(&context);
190
191 Operation *containerOp = createOp(context: &context, /*operands=*/{},
192 /*resultTypes=*/{},
193 /*numRegions=*/1);
194 Region &region = containerOp->getRegion(index: 0);
195 Block *block = new Block();
196 region.push_back(block);
197
198 // Insert two operations, then iteratively add more operations in the middle
199 // of them. Eventually we will insert more than kOrderStride operations and
200 // the block order will need to be recomputed.
201 Operation *frontOp = createOp(context: &context);
202 Operation *backOp = createOp(context: &context);
203 block->push_back(op: frontOp);
204 block->push_back(op: backOp);
205
206 // Chosen to be larger than Operation::kOrderStride.
207 int kNumOpsToInsert = 10;
208 for (int i = 0; i < kNumOpsToInsert; ++i) {
209 Operation *op = createOp(context: &context);
210 block->getOperations().insert(where: backOp->getIterator(), New: op);
211 ASSERT_TRUE(op->isBeforeInBlock(backOp));
212 // Note verifyOpOrder() returns false if the order is valid.
213 ASSERT_FALSE(block->verifyOpOrder());
214 }
215
216 containerOp->destroy();
217}
218
219TEST(OperationFormatPrintTest, CanUseVariadicFormat) {
220 MLIRContext context;
221 Builder builder(&context);
222
223 Operation *op = createOp(context: &context);
224
225 std::string str = formatv(Fmt: "{0}", Vals&: *op).str();
226 ASSERT_STREQ(str.c_str(), "\"foo.bar\"() : () -> ()");
227
228 op->destroy();
229}
230
231TEST(OperationFormatPrintTest, CanPrintNameAsPrefix) {
232 MLIRContext context;
233 Builder builder(&context);
234
235 context.allowUnregisteredDialects();
236 Operation *op = Operation::create(
237 location: NameLoc::get(name: StringAttr::get(context: &context, bytes: "my_named_loc")),
238 name: OperationName("t.op", &context), resultTypes: builder.getIntegerType(width: 16), operands: {},
239 attributes: std::nullopt, properties: nullptr, successors: {}, numRegions: 0);
240
241 std::string str;
242 OpPrintingFlags flags;
243 flags.printNameLocAsPrefix(enable: true);
244 llvm::raw_string_ostream os(str);
245 op->print(os, flags);
246 ASSERT_STREQ(str.c_str(), "%my_named_loc = \"t.op\"() : () -> i16\n");
247
248 op->destroy();
249}
250
251TEST(NamedAttrListTest, TestAppendAssign) {
252 MLIRContext ctx;
253 NamedAttrList attrs;
254 Builder b(&ctx);
255
256 attrs.append(name: b.getStringAttr(bytes: "foo"), attr: b.getStringAttr(bytes: "bar"));
257 attrs.append(name: "baz", attr: b.getStringAttr(bytes: "boo"));
258
259 {
260 auto *it = attrs.begin();
261 EXPECT_EQ(it->getName(), b.getStringAttr("foo"));
262 EXPECT_EQ(it->getValue(), b.getStringAttr("bar"));
263 ++it;
264 EXPECT_EQ(it->getName(), b.getStringAttr("baz"));
265 EXPECT_EQ(it->getValue(), b.getStringAttr("boo"));
266 }
267
268 attrs.append(name: "foo", attr: b.getStringAttr(bytes: "zoo"));
269 {
270 auto dup = attrs.findDuplicate();
271 ASSERT_TRUE(dup.has_value());
272 }
273
274 SmallVector<NamedAttribute> newAttrs = {
275 b.getNamedAttr(name: "foo", val: b.getStringAttr(bytes: "f")),
276 b.getNamedAttr(name: "zoo", val: b.getStringAttr(bytes: "z")),
277 };
278 attrs.assign(range: newAttrs);
279
280 auto dup = attrs.findDuplicate();
281 ASSERT_FALSE(dup.has_value());
282
283 {
284 auto *it = attrs.begin();
285 EXPECT_EQ(it->getName(), b.getStringAttr("foo"));
286 EXPECT_EQ(it->getValue(), b.getStringAttr("f"));
287 ++it;
288 EXPECT_EQ(it->getName(), b.getStringAttr("zoo"));
289 EXPECT_EQ(it->getValue(), b.getStringAttr("z"));
290 }
291
292 attrs.assign(range: {});
293 ASSERT_TRUE(attrs.empty());
294}
295
296TEST(OperandStorageTest, PopulateDefaultAttrs) {
297 MLIRContext context;
298 context.getOrLoadDialect<test::TestDialect>();
299 Builder builder(&context);
300
301 OpBuilder b(&context);
302 auto req1 = b.getI32IntegerAttr(value: 10);
303 auto req2 = b.getI32IntegerAttr(value: 60);
304 // Verify default attributes populated post op creation.
305 Operation *op = b.create<test::OpAttrMatch1>(location: b.getUnknownLoc(), args&: req1, args: nullptr,
306 args: nullptr, args&: req2);
307 auto opt = op->getInherentAttr(name: "default_valued_attr");
308 EXPECT_NE(opt, nullptr) << *op;
309
310 op->destroy();
311}
312
313TEST(OperationEquivalenceTest, HashWorksWithFlags) {
314 MLIRContext context;
315 context.getOrLoadDialect<test::TestDialect>();
316 OpBuilder b(&context);
317
318 auto *op1 = createOp(context: &context);
319 // `op1` has an unknown loc.
320 auto *op2 = createOp(context: &context);
321 op2->setLoc(NameLoc::get(name: StringAttr::get(context: &context, bytes: "foo")));
322 auto getHash = [](Operation *op, OperationEquivalence::Flags flags) {
323 return OperationEquivalence::computeHash(
324 op, hashOperands: OperationEquivalence::ignoreHashValue,
325 hashResults: OperationEquivalence::ignoreHashValue, flags);
326 };
327 // Check ignore location.
328 EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreLocations),
329 getHash(op2, OperationEquivalence::IgnoreLocations));
330 EXPECT_NE(getHash(op1, OperationEquivalence::None),
331 getHash(op2, OperationEquivalence::None));
332 op1->setLoc(NameLoc::get(name: StringAttr::get(context: &context, bytes: "foo")));
333 // Check ignore discardable dictionary attributes.
334 SmallVector<NamedAttribute> newAttrs = {
335 b.getNamedAttr(name: "foo", val: b.getStringAttr(bytes: "f"))};
336 op1->setAttrs(newAttrs);
337 EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreDiscardableAttrs),
338 getHash(op2, OperationEquivalence::IgnoreDiscardableAttrs));
339 EXPECT_NE(getHash(op1, OperationEquivalence::None),
340 getHash(op2, OperationEquivalence::None));
341 op1->destroy();
342 op2->destroy();
343
344 // Check ignore properties.
345 auto req1 = b.getI32IntegerAttr(value: 10);
346 Operation *opWithProperty1 = b.create<test::OpAttrMatch1>(
347 location: b.getUnknownLoc(), args&: req1, args: nullptr, args: nullptr, args&: req1);
348 auto req2 = b.getI32IntegerAttr(value: 60);
349 Operation *opWithProperty2 = b.create<test::OpAttrMatch1>(
350 location: b.getUnknownLoc(), args&: req2, args: nullptr, args: nullptr, args&: req2);
351 EXPECT_EQ(getHash(opWithProperty1, OperationEquivalence::IgnoreProperties),
352 getHash(opWithProperty2, OperationEquivalence::IgnoreProperties));
353 EXPECT_NE(getHash(opWithProperty1, OperationEquivalence::None),
354 getHash(opWithProperty2, OperationEquivalence::None));
355 opWithProperty1->destroy();
356 opWithProperty2->destroy();
357}
358
359} // namespace
360

source code of mlir/unittests/IR/OperationSupportTest.cpp