1 | //===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | #include "mlir/IR/SymbolTable.h" |
9 | #include "mlir/IR/BuiltinOps.h" |
10 | #include "mlir/IR/Verifier.h" |
11 | #include "mlir/Interfaces/CallInterfaces.h" |
12 | #include "mlir/Interfaces/FunctionInterfaces.h" |
13 | #include "mlir/Parser/Parser.h" |
14 | |
15 | #include "gtest/gtest.h" |
16 | |
17 | using namespace mlir; |
18 | |
19 | namespace test { |
20 | void registerTestDialect(DialectRegistry &); |
21 | } // namespace test |
22 | |
23 | class ReplaceAllSymbolUsesTest : public ::testing::Test { |
24 | protected: |
25 | using ReplaceFnType = llvm::function_ref<LogicalResult( |
26 | SymbolTable, ModuleOp, Operation *, Operation *)>; |
27 | |
28 | void SetUp() override { |
29 | ::test::registerTestDialect(registry); |
30 | context = std::make_unique<MLIRContext>(args&: registry); |
31 | } |
32 | |
33 | void testReplaceAllSymbolUses(ReplaceFnType replaceFn) { |
34 | // Set up IR and find func ops. |
35 | OwningOpRef<ModuleOp> module = |
36 | parseSourceString<ModuleOp>(kInput, context.get()); |
37 | SymbolTable symbolTable(module.get()); |
38 | auto opIterator = module->getBody(0)->getOperations().begin(); |
39 | auto fooOp = cast<FunctionOpInterface>(opIterator++); |
40 | auto barOp = cast<FunctionOpInterface>(opIterator++); |
41 | ASSERT_EQ(fooOp.getNameAttr(), "foo" ); |
42 | ASSERT_EQ(barOp.getNameAttr(), "bar" ); |
43 | |
44 | // Call test function that does symbol replacement. |
45 | LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp); |
46 | ASSERT_TRUE(succeeded(res)); |
47 | ASSERT_TRUE(succeeded(verify(module.get()))); |
48 | |
49 | // Check that it got renamed. |
50 | bool calleeFound = false; |
51 | fooOp->walk([&](CallOpInterface callOp) { |
52 | StringAttr callee = callOp.getCallableForCallee() |
53 | .dyn_cast<SymbolRefAttr>() |
54 | .getLeafReference(); |
55 | EXPECT_EQ(callee, "baz" ); |
56 | calleeFound = true; |
57 | }); |
58 | EXPECT_TRUE(calleeFound); |
59 | } |
60 | |
61 | std::unique_ptr<MLIRContext> context; |
62 | |
63 | private: |
64 | constexpr static llvm::StringLiteral kInput = R"MLIR( |
65 | module { |
66 | test.conversion_func_op private @foo() { |
67 | "test.conversion_call_op"() { callee=@bar } : () -> () |
68 | "test.return"() : () -> () |
69 | } |
70 | test.conversion_func_op private @bar() |
71 | } |
72 | )MLIR" ; |
73 | |
74 | DialectRegistry registry; |
75 | }; |
76 | |
77 | namespace { |
78 | |
79 | TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) { |
80 | // Symbol as `Operation *`, rename within module. |
81 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
82 | auto barOp) -> LogicalResult { |
83 | return symbolTable.replaceAllSymbolUses( |
84 | barOp, StringAttr::get(context.get(), "baz" ), module); |
85 | }); |
86 | } |
87 | |
88 | TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) { |
89 | // Symbol as `StringAttr`, rename within module. |
90 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
91 | auto barOp) -> LogicalResult { |
92 | return symbolTable.replaceAllSymbolUses( |
93 | StringAttr::get(context.get(), "bar" ), |
94 | StringAttr::get(context.get(), "baz" ), module); |
95 | }); |
96 | } |
97 | |
98 | TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) { |
99 | // Symbol as `Operation *`, rename within module body. |
100 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
101 | auto barOp) -> LogicalResult { |
102 | return symbolTable.replaceAllSymbolUses( |
103 | barOp, StringAttr::get(context.get(), "baz" ), &module->getRegion(0)); |
104 | }); |
105 | } |
106 | |
107 | TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) { |
108 | // Symbol as `StringAttr`, rename within module body. |
109 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
110 | auto barOp) -> LogicalResult { |
111 | return symbolTable.replaceAllSymbolUses( |
112 | StringAttr::get(context.get(), "bar" ), |
113 | StringAttr::get(context.get(), "baz" ), &module->getRegion(0)); |
114 | }); |
115 | } |
116 | |
117 | TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) { |
118 | // Symbol as `Operation *`, rename within function. |
119 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
120 | auto barOp) -> LogicalResult { |
121 | return symbolTable.replaceAllSymbolUses( |
122 | barOp, StringAttr::get(context.get(), "baz" ), fooOp); |
123 | }); |
124 | } |
125 | |
126 | TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { |
127 | // Symbol as `StringAttr`, rename within function. |
128 | testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, |
129 | auto barOp) -> LogicalResult { |
130 | return symbolTable.replaceAllSymbolUses( |
131 | StringAttr::get(context.get(), "bar" ), |
132 | StringAttr::get(context.get(), "baz" ), fooOp); |
133 | }); |
134 | } |
135 | |
136 | } // namespace |
137 | |