//===-- LSPClient.cpp - Helper for ClangdLSPServer tests ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "LSPClient.h"
10#include "Protocol.h"
11#include "TestFS.h"
12#include "Transport.h"
13#include "support/Logger.h"
14#include "support/Threading.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/StringMap.h"
17#include "llvm/ADT/StringRef.h"
18#include "llvm/Support/Error.h"
19#include "llvm/Support/JSON.h"
20#include "llvm/Support/Path.h"
21#include "llvm/Support/raw_ostream.h"
22#include "gtest/gtest.h"
23#include <condition_variable>
24#include <cstddef>
25#include <cstdint>
26#include <deque>
27#include <functional>
28#include <memory>
29#include <mutex>
30#include <optional>
31#include <queue>
32#include <string>
33#include <utility>
34#include <vector>
35
36namespace clang {
37namespace clangd {
38
39llvm::Expected<llvm::json::Value> clang::clangd::LSPClient::CallResult::take() {
40 std::unique_lock<std::mutex> Lock(Mu);
41 static constexpr size_t TimeoutSecs = 60;
42 if (!clangd::wait(Lock, CV, D: timeoutSeconds(Seconds: TimeoutSecs),
43 F: [this] { return Value.has_value(); })) {
44 ADD_FAILURE() << "No result from call after " << TimeoutSecs << " seconds!";
45 return llvm::json::Value(nullptr);
46 }
47 auto Res = std::move(*Value);
48 Value.reset();
49 return Res;
50}
51
52llvm::json::Value LSPClient::CallResult::takeValue() {
53 auto ExpValue = take();
54 if (!ExpValue) {
55 ADD_FAILURE() << "takeValue(): " << llvm::toString(E: ExpValue.takeError());
56 return llvm::json::Value(nullptr);
57 }
58 return std::move(*ExpValue);
59}
60
61void LSPClient::CallResult::set(llvm::Expected<llvm::json::Value> V) {
62 std::lock_guard<std::mutex> Lock(Mu);
63 if (Value) {
64 ADD_FAILURE() << "Multiple replies";
65 llvm::consumeError(Err: V.takeError());
66 return;
67 }
68 Value = std::move(V);
69 CV.notify_all();
70}
71
72LSPClient::CallResult::~CallResult() {
73 if (Value && !*Value) {
74 ADD_FAILURE() << llvm::toString(E: Value->takeError());
75 }
76}
77
78static void logBody(llvm::StringRef Method, llvm::json::Value V, bool Send) {
79 // We invert <<< and >>> as the combined log is from the server's viewpoint.
80 vlog(Fmt: "{0} {1}: {2:2}", Vals: Send ? "<<<" : ">>>", Vals&: Method, Vals&: V);
81}
82
83class LSPClient::TransportImpl : public Transport {
84public:
85 std::pair<llvm::json::Value, CallResult *> addCallSlot() {
86 std::lock_guard<std::mutex> Lock(Mu);
87 unsigned ID = CallResults.size();
88 CallResults.emplace_back();
89 return {ID, &CallResults.back()};
90 }
91
92 // A null action causes the transport to shut down.
93 void enqueue(std::function<void(MessageHandler &)> Action) {
94 std::lock_guard<std::mutex> Lock(Mu);
95 Actions.push(x: std::move(Action));
96 CV.notify_all();
97 }
98
99 std::vector<llvm::json::Value> takeNotifications(llvm::StringRef Method) {
100 std::vector<llvm::json::Value> Result;
101 {
102 std::lock_guard<std::mutex> Lock(Mu);
103 std::swap(x&: Result, y&: Notifications[Method]);
104 }
105 return Result;
106 }
107
108 void expectCall(llvm::StringRef Method) {
109 std::lock_guard<std::mutex> Lock(Mu);
110 Calls[Method] = {};
111 }
112
113 std::vector<llvm::json::Value> takeCallParams(llvm::StringRef Method) {
114 std::vector<llvm::json::Value> Result;
115 {
116 std::lock_guard<std::mutex> Lock(Mu);
117 std::swap(x&: Result, y&: Calls[Method]);
118 }
119 return Result;
120 }
121
122private:
123 void reply(llvm::json::Value ID,
124 llvm::Expected<llvm::json::Value> V) override {
125 if (V) // Nothing additional to log for error.
126 logBody(Method: "reply", V: *V, /*Send=*/false);
127 std::lock_guard<std::mutex> Lock(Mu);
128 if (auto I = ID.getAsInteger()) {
129 if (*I >= 0 && *I < static_cast<int64_t>(CallResults.size())) {
130 CallResults[*I].set(std::move(V));
131 return;
132 }
133 }
134 ADD_FAILURE() << "Invalid reply to ID " << ID;
135 llvm::consumeError(Err: std::move(V).takeError());
136 }
137
138 void notify(llvm::StringRef Method, llvm::json::Value V) override {
139 logBody(Method, V, /*Send=*/false);
140 std::lock_guard<std::mutex> Lock(Mu);
141 Notifications[Method].push_back(x: std::move(V));
142 }
143
144 void call(llvm::StringRef Method, llvm::json::Value Params,
145 llvm::json::Value ID) override {
146 logBody(Method, V: Params, /*Send=*/false);
147 std::lock_guard<std::mutex> Lock(Mu);
148 if (Calls.contains(Key: Method)) {
149 Calls[Method].push_back(x: std::move(Params));
150 } else {
151 ADD_FAILURE() << "Unexpected server->client call " << Method;
152 }
153 }
154
155 llvm::Error loop(MessageHandler &H) override {
156 std::unique_lock<std::mutex> Lock(Mu);
157 while (true) {
158 CV.wait(lock&: Lock, p: [&] { return !Actions.empty(); });
159 if (!Actions.front()) // Stop!
160 return llvm::Error::success();
161 auto Action = std::move(Actions.front());
162 Actions.pop();
163 Lock.unlock();
164 Action(H);
165 Lock.lock();
166 }
167 }
168
169 std::mutex Mu;
170 std::deque<CallResult> CallResults;
171 std::queue<std::function<void(Transport::MessageHandler &)>> Actions;
172 std::condition_variable CV;
173 llvm::StringMap<std::vector<llvm::json::Value>> Notifications;
174 llvm::StringMap<std::vector<llvm::json::Value>> Calls;
175};
176
177LSPClient::LSPClient() : T(std::make_unique<TransportImpl>()) {}
178LSPClient::~LSPClient() = default;
179
180LSPClient::CallResult &LSPClient::call(llvm::StringRef Method,
181 llvm::json::Value Params) {
182 auto Slot = T->addCallSlot();
183 T->enqueue(Action: [ID(Slot.first), Method(Method.str()),
184 Params(std::move(Params))](Transport::MessageHandler &H) {
185 logBody(Method, V: Params, /*Send=*/true);
186 H.onCall(Method, Params: std::move(Params), ID);
187 });
188 return *Slot.second;
189}
190
191void LSPClient::expectServerCall(llvm::StringRef Method) {
192 T->expectCall(Method);
193}
194
195void LSPClient::notify(llvm::StringRef Method, llvm::json::Value Params) {
196 T->enqueue(Action: [Method(Method.str()),
197 Params(std::move(Params))](Transport::MessageHandler &H) {
198 logBody(Method, V: Params, /*Send=*/true);
199 H.onNotify(Method, std::move(Params));
200 });
201}
202
203std::vector<llvm::json::Value>
204LSPClient::takeNotifications(llvm::StringRef Method) {
205 return T->takeNotifications(Method);
206}
207
208std::vector<llvm::json::Value>
209LSPClient::takeCallParams(llvm::StringRef Method) {
210 return T->takeCallParams(Method);
211}
212
213void LSPClient::stop() { T->enqueue(Action: nullptr); }
214
215Transport &LSPClient::transport() { return *T; }
216
217using Obj = llvm::json::Object;
218
219llvm::json::Value LSPClient::uri(llvm::StringRef Path) {
220 std::string Storage;
221 if (!llvm::sys::path::is_absolute(path: Path))
222 Path = Storage = testPath(File: Path);
223 return toJSON(U: URIForFile::canonicalize(AbsPath: Path, TUPath: Path));
224}
225llvm::json::Value LSPClient::documentID(llvm::StringRef Path) {
226 return Obj{{.K: "uri", .V: uri(Path)}};
227}
228
229void LSPClient::didOpen(llvm::StringRef Path, llvm::StringRef Content) {
230 notify(
231 Method: "textDocument/didOpen",
232 Params: Obj{{.K: "textDocument",
233 .V: Obj{{.K: "uri", .V: uri(Path)}, {.K: "text", .V: Content}, {.K: "languageId", .V: "cpp"}}}});
234}
235void LSPClient::didChange(llvm::StringRef Path, llvm::StringRef Content) {
236 notify(Method: "textDocument/didChange",
237 Params: Obj{{.K: "textDocument", .V: documentID(Path)},
238 {.K: "contentChanges", .V: llvm::json::Array{Obj{{.K: "text", .V: Content}}}}});
239}
240void LSPClient::didClose(llvm::StringRef Path) {
241 notify(Method: "textDocument/didClose", Params: Obj{{.K: "textDocument", .V: documentID(Path)}});
242}
243
244void LSPClient::sync() { call(Method: "sync", Params: nullptr).takeValue(); }
245
246std::optional<std::vector<llvm::json::Value>>
247LSPClient::diagnostics(llvm::StringRef Path) {
248 sync();
249 auto Notifications = takeNotifications(Method: "textDocument/publishDiagnostics");
250 for (const auto &Notification : llvm::reverse(C&: Notifications)) {
251 if (const auto *PubDiagsParams = Notification.getAsObject()) {
252 auto U = PubDiagsParams->getString(K: "uri");
253 auto *D = PubDiagsParams->getArray(K: "diagnostics");
254 if (!U || !D) {
255 ADD_FAILURE() << "Bad PublishDiagnosticsParams: " << PubDiagsParams;
256 continue;
257 }
258 if (*U == uri(Path))
259 return std::vector<llvm::json::Value>(D->begin(), D->end());
260 }
261 }
262 return {};
263}
264
265} // namespace clangd
266} // namespace clang
267

source code of clang-tools-extra/clangd/unittests/LSPClient.cpp