1 | //===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | #include "mlir/IR/SymbolTable.h" |
9 | #include "mlir/IR/BuiltinOps.h" |
10 | #include "mlir/IR/Verifier.h" |
11 | #include "mlir/Interfaces/CallInterfaces.h" |
12 | #include "mlir/Interfaces/FunctionInterfaces.h" |
13 | #include "mlir/Parser/Parser.h" |
14 | |
15 | #include "gtest/gtest.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | namespace test { |
20 | void registerTestDialect(DialectRegistry &); |
21 | } // namespace test |
22 | |
23 | class ReplaceAllSymbolUsesTest : public ::testing::Test { |
24 | protected: |
25 | using ReplaceFnType = llvm::function_ref<LogicalResult( |
26 | SymbolTable, ModuleOp, Operation *, Operation *)>; |
27 | |
28 | void SetUp() override { |
29 | ::test::registerTestDialect(registry); |
30 | context = std::make_unique<MLIRContext>(args&: registry); |
31 | } |
32 | |
33 | void testReplaceAllSymbolUses(ReplaceFnType replaceFn) { |
34 | // Set up IR and find func ops. |
35 | OwningOpRef<ModuleOp> module = |
36 | parseSourceString<ModuleOp>(kInput, context.get()); |
37 | SymbolTable symbolTable(module.get()); |
38 | auto opIterator = module->getBody(0)->getOperations().begin(); |
39 | auto fooOp = cast<FunctionOpInterface>(opIterator++); |
40 | auto barOp = cast<FunctionOpInterface>(opIterator++); |
41 | ASSERT_EQ(fooOp.getNameAttr(), "foo" ); |
42 | ASSERT_EQ(barOp.getNameAttr(), "bar" ); |
43 | |
44 | // Call test function that does symbol replacement. |
45 | LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp); |
46 | ASSERT_TRUE(succeeded(res)); |
47 | ASSERT_TRUE(succeeded(verify(module.get()))); |
48 | |
49 | // Check that it got renamed. |
50 | bool calleeFound = false; |
51 | fooOp->walk([&](CallOpInterface callOp) { |
52 | StringAttr callee = dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee()) |
53 | .getLeafReference(); |
54 | EXPECT_EQ(callee, "baz" ); |
55 | calleeFound = true; |
56 | }); |
57 | EXPECT_TRUE(calleeFound); |
58 | } |
59 | |
60 | std::unique_ptr<MLIRContext> context; |
61 | |
62 | private: |
63 | constexpr static llvm::StringLiteral kInput = R"MLIR( |
64 | module { |
65 | test.conversion_func_op private @foo() { |
66 | "test.conversion_call_op"() { callee=@bar } : () -> () |
67 | "test.return"() : () -> () |
68 | } |
69 | test.conversion_func_op private @bar() |
70 | } |
71 | )MLIR" ; |
72 | |
73 | DialectRegistry registry; |
74 | }; |
75 | |
76 | namespace { |
77 | |
78 | TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) { |
79 | // Symbol as `Operation *`, rename within module. |
80 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
81 | auto barOp) -> LogicalResult { |
82 | return symbolTable.replaceAllSymbolUses( |
83 | barOp, StringAttr::get(context.get(), "baz" ), module); |
84 | }); |
85 | } |
86 | |
87 | TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) { |
88 | // Symbol as `StringAttr`, rename within module. |
89 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
90 | auto barOp) -> LogicalResult { |
91 | return symbolTable.replaceAllSymbolUses( |
92 | StringAttr::get(context.get(), "bar" ), |
93 | StringAttr::get(context.get(), "baz" ), module); |
94 | }); |
95 | } |
96 | |
97 | TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) { |
98 | // Symbol as `Operation *`, rename within module body. |
99 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
100 | auto barOp) -> LogicalResult { |
101 | return symbolTable.replaceAllSymbolUses( |
102 | barOp, StringAttr::get(context.get(), "baz" ), &module->getRegion(0)); |
103 | }); |
104 | } |
105 | |
106 | TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) { |
107 | // Symbol as `StringAttr`, rename within module body. |
108 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
109 | auto barOp) -> LogicalResult { |
110 | return symbolTable.replaceAllSymbolUses( |
111 | StringAttr::get(context.get(), "bar" ), |
112 | StringAttr::get(context.get(), "baz" ), &module->getRegion(0)); |
113 | }); |
114 | } |
115 | |
116 | TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) { |
117 | // Symbol as `Operation *`, rename within function. |
118 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
119 | auto barOp) -> LogicalResult { |
120 | return symbolTable.replaceAllSymbolUses( |
121 | barOp, StringAttr::get(context.get(), "baz" ), fooOp); |
122 | }); |
123 | } |
124 | |
125 | TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { |
126 | // Symbol as `StringAttr`, rename within function. |
127 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
128 | auto barOp) -> LogicalResult { |
129 | return symbolTable.replaceAllSymbolUses( |
130 | StringAttr::get(context.get(), "bar" ), |
131 | StringAttr::get(context.get(), "baz" ), fooOp); |
132 | }); |
133 | } |
134 | |
135 | } // namespace |
136 | |