//===- unittests/Tooling/RefactoringActionRulesTest.cpp -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ReplacementTest.h"
10#include "RewriterTestContext.h"
11#include "clang/Tooling/Refactoring.h"
12#include "clang/Tooling/Refactoring/Extract/Extract.h"
13#include "clang/Tooling/Refactoring/RefactoringAction.h"
14#include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
15#include "clang/Tooling/Refactoring/Rename/SymbolName.h"
16#include "clang/Tooling/Tooling.h"
17#include "llvm/Support/Errc.h"
18#include "gtest/gtest.h"
19#include <optional>
20
21using namespace clang;
22using namespace tooling;
23
24namespace {
25
26class RefactoringActionRulesTest : public ::testing::Test {
27protected:
28 void SetUp() override {
29 Context.Sources.setMainFileID(
30 Context.createInMemoryFile(Name: "input.cpp", Content: DefaultCode));
31 }
32
33 RewriterTestContext Context;
34 std::string DefaultCode = std::string(100, 'a');
35};
36
37Expected<AtomicChanges>
38createReplacements(const std::unique_ptr<RefactoringActionRule> &Rule,
39 RefactoringRuleContext &Context) {
40 class Consumer final : public RefactoringResultConsumer {
41 void handleError(llvm::Error Err) override { Result = std::move(Err); }
42
43 void handle(AtomicChanges SourceReplacements) override {
44 Result = std::move(SourceReplacements);
45 }
46 void handle(SymbolOccurrences Occurrences) override {
47 RefactoringResultConsumer::handle(Occurrences: std::move(Occurrences));
48 }
49
50 public:
51 std::optional<Expected<AtomicChanges>> Result;
52 };
53
54 Consumer C;
55 Rule->invoke(Consumer&: C, Context);
56 return std::move(*C.Result);
57}
58
59TEST_F(RefactoringActionRulesTest, MyFirstRefactoringRule) {
60 class ReplaceAWithB : public SourceChangeRefactoringRule {
61 std::pair<SourceRange, int> Selection;
62
63 public:
64 ReplaceAWithB(std::pair<SourceRange, int> Selection)
65 : Selection(Selection) {}
66
67 static Expected<ReplaceAWithB>
68 initiate(RefactoringRuleContext &Cotnext,
69 std::pair<SourceRange, int> Selection) {
70 return ReplaceAWithB(Selection);
71 }
72
73 Expected<AtomicChanges>
74 createSourceReplacements(RefactoringRuleContext &Context) {
75 const SourceManager &SM = Context.getSources();
76 SourceLocation Loc =
77 Selection.first.getBegin().getLocWithOffset(Offset: Selection.second);
78 AtomicChange Change(SM, Loc);
79 llvm::Error E = Change.replace(SM, Loc, Length: 1, Text: "b");
80 if (E)
81 return std::move(E);
82 return AtomicChanges{Change};
83 }
84 };
85
86 class SelectionRequirement : public SourceRangeSelectionRequirement {
87 public:
88 Expected<std::pair<SourceRange, int>>
89 evaluate(RefactoringRuleContext &Context) const {
90 Expected<SourceRange> R =
91 SourceRangeSelectionRequirement::evaluate(Context);
92 if (!R)
93 return R.takeError();
94 return std::make_pair(x&: *R, y: 20);
95 }
96 };
97 auto Rule =
98 createRefactoringActionRule<ReplaceAWithB>(Requirements: SelectionRequirement());
99
100 // When the requirements are satisfied, the rule's function must be invoked.
101 {
102 RefactoringRuleContext RefContext(Context.Sources);
103 SourceLocation Cursor =
104 Context.Sources.getLocForStartOfFile(FID: Context.Sources.getMainFileID())
105 .getLocWithOffset(Offset: 10);
106 RefContext.setSelectionRange({Cursor, Cursor});
107
108 Expected<AtomicChanges> ErrorOrResult =
109 createReplacements(Rule, Context&: RefContext);
110 ASSERT_FALSE(!ErrorOrResult);
111 AtomicChanges Result = std::move(*ErrorOrResult);
112 ASSERT_EQ(Result.size(), 1u);
113 std::string YAMLString =
114 const_cast<AtomicChange &>(Result[0]).toYAMLString();
115
116 ASSERT_STREQ("---\n"
117 "Key: 'input.cpp:30'\n"
118 "FilePath: input.cpp\n"
119 "Error: ''\n"
120 "InsertedHeaders: []\n"
121 "RemovedHeaders: []\n"
122 "Replacements:\n"
123 " - FilePath: input.cpp\n"
124 " Offset: 30\n"
125 " Length: 1\n"
126 " ReplacementText: b\n"
127 "...\n",
128 YAMLString.c_str());
129 }
130
131 // When one of the requirements is not satisfied, invoke should return a
132 // valid error.
133 {
134 RefactoringRuleContext RefContext(Context.Sources);
135 Expected<AtomicChanges> ErrorOrResult =
136 createReplacements(Rule, Context&: RefContext);
137
138 ASSERT_TRUE(!ErrorOrResult);
139 unsigned DiagID;
140 llvm::handleAllErrors(E: ErrorOrResult.takeError(),
141 Handlers: [&](DiagnosticError &Error) {
142 DiagID = Error.getDiagnostic().second.getDiagID();
143 });
144 EXPECT_EQ(DiagID, diag::err_refactor_no_selection);
145 }
146}
147
148TEST_F(RefactoringActionRulesTest, ReturnError) {
149 class ErrorRule : public SourceChangeRefactoringRule {
150 public:
151 static Expected<ErrorRule> initiate(RefactoringRuleContext &,
152 SourceRange R) {
153 return ErrorRule(R);
154 }
155
156 ErrorRule(SourceRange R) {}
157 Expected<AtomicChanges> createSourceReplacements(RefactoringRuleContext &) {
158 return llvm::make_error<llvm::StringError>(
159 Args: "Error", Args: llvm::make_error_code(E: llvm::errc::invalid_argument));
160 }
161 };
162
163 auto Rule =
164 createRefactoringActionRule<ErrorRule>(Requirements: SourceRangeSelectionRequirement());
165 RefactoringRuleContext RefContext(Context.Sources);
166 SourceLocation Cursor =
167 Context.Sources.getLocForStartOfFile(FID: Context.Sources.getMainFileID());
168 RefContext.setSelectionRange({Cursor, Cursor});
169 Expected<AtomicChanges> Result = createReplacements(Rule, Context&: RefContext);
170
171 ASSERT_TRUE(!Result);
172 std::string Message;
173 llvm::handleAllErrors(E: Result.takeError(), Handlers: [&](llvm::StringError &Error) {
174 Message = Error.getMessage();
175 });
176 EXPECT_EQ(Message, "Error");
177}
178
179std::optional<SymbolOccurrences>
180findOccurrences(RefactoringActionRule &Rule, RefactoringRuleContext &Context) {
181 class Consumer final : public RefactoringResultConsumer {
182 void handleError(llvm::Error) override {}
183 void handle(SymbolOccurrences Occurrences) override {
184 Result = std::move(Occurrences);
185 }
186 void handle(AtomicChanges Changes) override {
187 RefactoringResultConsumer::handle(SourceReplacements: std::move(Changes));
188 }
189
190 public:
191 std::optional<SymbolOccurrences> Result;
192 };
193
194 Consumer C;
195 Rule.invoke(Consumer&: C, Context);
196 return std::move(C.Result);
197}
198
199TEST_F(RefactoringActionRulesTest, ReturnSymbolOccurrences) {
200 class FindOccurrences : public FindSymbolOccurrencesRefactoringRule {
201 SourceRange Selection;
202
203 public:
204 FindOccurrences(SourceRange Selection) : Selection(Selection) {}
205
206 static Expected<FindOccurrences> initiate(RefactoringRuleContext &,
207 SourceRange Selection) {
208 return FindOccurrences(Selection);
209 }
210
211 Expected<SymbolOccurrences>
212 findSymbolOccurrences(RefactoringRuleContext &) override {
213 SymbolOccurrences Occurrences;
214 Occurrences.push_back(x: SymbolOccurrence(SymbolName("test"),
215 SymbolOccurrence::MatchingSymbol,
216 Selection.getBegin()));
217 return std::move(Occurrences);
218 }
219 };
220
221 auto Rule = createRefactoringActionRule<FindOccurrences>(
222 Requirements: SourceRangeSelectionRequirement());
223
224 RefactoringRuleContext RefContext(Context.Sources);
225 SourceLocation Cursor =
226 Context.Sources.getLocForStartOfFile(FID: Context.Sources.getMainFileID());
227 RefContext.setSelectionRange({Cursor, Cursor});
228 std::optional<SymbolOccurrences> Result = findOccurrences(Rule&: *Rule, Context&: RefContext);
229
230 ASSERT_FALSE(!Result);
231 SymbolOccurrences Occurrences = std::move(*Result);
232 EXPECT_EQ(Occurrences.size(), 1u);
233 EXPECT_EQ(Occurrences[0].getKind(), SymbolOccurrence::MatchingSymbol);
234 EXPECT_EQ(Occurrences[0].getNameRanges().size(), 1u);
235 EXPECT_EQ(Occurrences[0].getNameRanges()[0],
236 SourceRange(Cursor, Cursor.getLocWithOffset(strlen("test"))));
237}
238
239TEST_F(RefactoringActionRulesTest, EditorCommandBinding) {
240 const RefactoringDescriptor &Descriptor = ExtractFunction::describe();
241 EXPECT_EQ(Descriptor.Name, "extract-function");
242 EXPECT_EQ(
243 Descriptor.Description,
244 "(WIP action; use with caution!) Extracts code into a new function");
245 EXPECT_EQ(Descriptor.Title, "Extract Function");
246}
247
248} // end anonymous namespace
249

// Source: clang/unittests/Tooling/RefactoringActionRulesTest.cpp