1//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "mlir/IR/SymbolTable.h"
9#include "mlir/IR/BuiltinOps.h"
10#include "mlir/IR/Verifier.h"
11#include "mlir/Interfaces/CallInterfaces.h"
12#include "mlir/Interfaces/FunctionInterfaces.h"
13#include "mlir/Parser/Parser.h"
14
15#include "gtest/gtest.h"
16
17using namespace mlir;
18
19namespace test {
20void registerTestDialect(DialectRegistry &);
21} // namespace test
22
23class ReplaceAllSymbolUsesTest : public ::testing::Test {
24protected:
25 using ReplaceFnType = llvm::function_ref<LogicalResult(
26 SymbolTable, ModuleOp, Operation *, Operation *)>;
27
28 void SetUp() override {
29 ::test::registerTestDialect(registry);
30 context = std::make_unique<MLIRContext>(args&: registry);
31 }
32
33 void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
34 // Set up IR and find func ops.
35 OwningOpRef<ModuleOp> module =
36 parseSourceString<ModuleOp>(kInput, context.get());
37 SymbolTable symbolTable(module.get());
38 auto opIterator = module->getBody(0)->getOperations().begin();
39 auto fooOp = cast<FunctionOpInterface>(opIterator++);
40 auto barOp = cast<FunctionOpInterface>(opIterator++);
41 ASSERT_EQ(fooOp.getNameAttr(), "foo");
42 ASSERT_EQ(barOp.getNameAttr(), "bar");
43
44 // Call test function that does symbol replacement.
45 LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
46 ASSERT_TRUE(succeeded(res));
47 ASSERT_TRUE(succeeded(verify(module.get())));
48
49 // Check that it got renamed.
50 bool calleeFound = false;
51 fooOp->walk([&](CallOpInterface callOp) {
52 StringAttr callee = dyn_cast<SymbolRefAttr>(callOp.getCallableForCallee())
53 .getLeafReference();
54 EXPECT_EQ(callee, "baz");
55 calleeFound = true;
56 });
57 EXPECT_TRUE(calleeFound);
58 }
59
60 std::unique_ptr<MLIRContext> context;
61
62private:
63 constexpr static llvm::StringLiteral kInput = R"MLIR(
64 module {
65 test.conversion_func_op private @foo() {
66 "test.conversion_call_op"() { callee=@bar } : () -> ()
67 "test.return"() : () -> ()
68 }
69 test.conversion_func_op private @bar()
70 }
71 )MLIR";
72
73 DialectRegistry registry;
74};
75
76namespace {
77
78TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
79 // Symbol as `Operation *`, rename within module.
80 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
81 auto barOp) -> LogicalResult {
82 return symbolTable.replaceAllSymbolUses(
83 barOp, StringAttr::get(context.get(), "baz"), module);
84 });
85}
86
87TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
88 // Symbol as `StringAttr`, rename within module.
89 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
90 auto barOp) -> LogicalResult {
91 return symbolTable.replaceAllSymbolUses(
92 StringAttr::get(context.get(), "bar"),
93 StringAttr::get(context.get(), "baz"), module);
94 });
95}
96
97TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
98 // Symbol as `Operation *`, rename within module body.
99 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
100 auto barOp) -> LogicalResult {
101 return symbolTable.replaceAllSymbolUses(
102 barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
103 });
104}
105
106TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
107 // Symbol as `StringAttr`, rename within module body.
108 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
109 auto barOp) -> LogicalResult {
110 return symbolTable.replaceAllSymbolUses(
111 StringAttr::get(context.get(), "bar"),
112 StringAttr::get(context.get(), "baz"), &module->getRegion(0));
113 });
114}
115
116TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
117 // Symbol as `Operation *`, rename within function.
118 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
119 auto barOp) -> LogicalResult {
120 return symbolTable.replaceAllSymbolUses(
121 barOp, StringAttr::get(context.get(), "baz"), fooOp);
122 });
123}
124
125TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
126 // Symbol as `StringAttr`, rename within function.
127 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
128 auto barOp) -> LogicalResult {
129 return symbolTable.replaceAllSymbolUses(
130 StringAttr::get(context.get(), "bar"),
131 StringAttr::get(context.get(), "baz"), fooOp);
132 });
133}
134
135} // namespace
136

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/unittests/IR/SymbolTableTest.cpp