1 | //===-- LSPClient.cpp - Helper for ClangdLSPServer tests ------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "LSPClient.h" |
10 | #include "Protocol.h" |
11 | #include "TestFS.h" |
12 | #include "Transport.h" |
13 | #include "support/Logger.h" |
14 | #include "support/Threading.h" |
15 | #include "llvm/ADT/STLExtras.h" |
16 | #include "llvm/ADT/StringMap.h" |
17 | #include "llvm/ADT/StringRef.h" |
18 | #include "llvm/Support/Error.h" |
19 | #include "llvm/Support/JSON.h" |
20 | #include "llvm/Support/Path.h" |
21 | #include "llvm/Support/raw_ostream.h" |
22 | #include "gtest/gtest.h" |
23 | #include <condition_variable> |
24 | #include <cstddef> |
25 | #include <cstdint> |
26 | #include <deque> |
27 | #include <functional> |
28 | #include <memory> |
29 | #include <mutex> |
30 | #include <optional> |
31 | #include <queue> |
32 | #include <string> |
33 | #include <utility> |
34 | #include <vector> |
35 | |
36 | namespace clang { |
37 | namespace clangd { |
38 | |
39 | llvm::Expected<llvm::json::Value> clang::clangd::LSPClient::CallResult::take() { |
40 | std::unique_lock<std::mutex> Lock(Mu); |
41 | static constexpr size_t TimeoutSecs = 60; |
42 | if (!clangd::wait(Lock, CV, D: timeoutSeconds(Seconds: TimeoutSecs), |
43 | F: [this] { return Value.has_value(); })) { |
44 | ADD_FAILURE() << "No result from call after " << TimeoutSecs << " seconds!" ; |
45 | return llvm::json::Value(nullptr); |
46 | } |
47 | auto Res = std::move(*Value); |
48 | Value.reset(); |
49 | return Res; |
50 | } |
51 | |
52 | llvm::json::Value LSPClient::CallResult::takeValue() { |
53 | auto ExpValue = take(); |
54 | if (!ExpValue) { |
55 | ADD_FAILURE() << "takeValue(): " << llvm::toString(E: ExpValue.takeError()); |
56 | return llvm::json::Value(nullptr); |
57 | } |
58 | return std::move(*ExpValue); |
59 | } |
60 | |
61 | void LSPClient::CallResult::set(llvm::Expected<llvm::json::Value> V) { |
62 | std::lock_guard<std::mutex> Lock(Mu); |
63 | if (Value) { |
64 | ADD_FAILURE() << "Multiple replies" ; |
65 | llvm::consumeError(Err: V.takeError()); |
66 | return; |
67 | } |
68 | Value = std::move(V); |
69 | CV.notify_all(); |
70 | } |
71 | |
72 | LSPClient::CallResult::~CallResult() { |
73 | if (Value && !*Value) { |
74 | ADD_FAILURE() << llvm::toString(E: Value->takeError()); |
75 | } |
76 | } |
77 | |
78 | static void logBody(llvm::StringRef Method, llvm::json::Value V, bool Send) { |
79 | // We invert <<< and >>> as the combined log is from the server's viewpoint. |
80 | vlog(Fmt: "{0} {1}: {2:2}" , Vals: Send ? "<<<" : ">>>" , Vals&: Method, Vals&: V); |
81 | } |
82 | |
83 | class LSPClient::TransportImpl : public Transport { |
84 | public: |
85 | std::pair<llvm::json::Value, CallResult *> addCallSlot() { |
86 | std::lock_guard<std::mutex> Lock(Mu); |
87 | unsigned ID = CallResults.size(); |
88 | CallResults.emplace_back(); |
89 | return {ID, &CallResults.back()}; |
90 | } |
91 | |
92 | // A null action causes the transport to shut down. |
93 | void enqueue(std::function<void(MessageHandler &)> Action) { |
94 | std::lock_guard<std::mutex> Lock(Mu); |
95 | Actions.push(x: std::move(Action)); |
96 | CV.notify_all(); |
97 | } |
98 | |
99 | std::vector<llvm::json::Value> takeNotifications(llvm::StringRef Method) { |
100 | std::vector<llvm::json::Value> Result; |
101 | { |
102 | std::lock_guard<std::mutex> Lock(Mu); |
103 | std::swap(x&: Result, y&: Notifications[Method]); |
104 | } |
105 | return Result; |
106 | } |
107 | |
108 | void expectCall(llvm::StringRef Method) { |
109 | std::lock_guard<std::mutex> Lock(Mu); |
110 | Calls[Method] = {}; |
111 | } |
112 | |
113 | std::vector<llvm::json::Value> takeCallParams(llvm::StringRef Method) { |
114 | std::vector<llvm::json::Value> Result; |
115 | { |
116 | std::lock_guard<std::mutex> Lock(Mu); |
117 | std::swap(x&: Result, y&: Calls[Method]); |
118 | } |
119 | return Result; |
120 | } |
121 | |
122 | private: |
123 | void reply(llvm::json::Value ID, |
124 | llvm::Expected<llvm::json::Value> V) override { |
125 | if (V) // Nothing additional to log for error. |
126 | logBody(Method: "reply" , V: *V, /*Send=*/false); |
127 | std::lock_guard<std::mutex> Lock(Mu); |
128 | if (auto I = ID.getAsInteger()) { |
129 | if (*I >= 0 && *I < static_cast<int64_t>(CallResults.size())) { |
130 | CallResults[*I].set(std::move(V)); |
131 | return; |
132 | } |
133 | } |
134 | ADD_FAILURE() << "Invalid reply to ID " << ID; |
135 | llvm::consumeError(Err: std::move(V).takeError()); |
136 | } |
137 | |
138 | void notify(llvm::StringRef Method, llvm::json::Value V) override { |
139 | logBody(Method, V, /*Send=*/false); |
140 | std::lock_guard<std::mutex> Lock(Mu); |
141 | Notifications[Method].push_back(x: std::move(V)); |
142 | } |
143 | |
144 | void call(llvm::StringRef Method, llvm::json::Value Params, |
145 | llvm::json::Value ID) override { |
146 | logBody(Method, V: Params, /*Send=*/false); |
147 | std::lock_guard<std::mutex> Lock(Mu); |
148 | if (Calls.contains(Key: Method)) { |
149 | Calls[Method].push_back(x: std::move(Params)); |
150 | } else { |
151 | ADD_FAILURE() << "Unexpected server->client call " << Method; |
152 | } |
153 | } |
154 | |
155 | llvm::Error loop(MessageHandler &H) override { |
156 | std::unique_lock<std::mutex> Lock(Mu); |
157 | while (true) { |
158 | CV.wait(lock&: Lock, p: [&] { return !Actions.empty(); }); |
159 | if (!Actions.front()) // Stop! |
160 | return llvm::Error::success(); |
161 | auto Action = std::move(Actions.front()); |
162 | Actions.pop(); |
163 | Lock.unlock(); |
164 | Action(H); |
165 | Lock.lock(); |
166 | } |
167 | } |
168 | |
169 | std::mutex Mu; |
170 | std::deque<CallResult> CallResults; |
171 | std::queue<std::function<void(Transport::MessageHandler &)>> Actions; |
172 | std::condition_variable CV; |
173 | llvm::StringMap<std::vector<llvm::json::Value>> Notifications; |
174 | llvm::StringMap<std::vector<llvm::json::Value>> Calls; |
175 | }; |
176 | |
177 | LSPClient::LSPClient() : T(std::make_unique<TransportImpl>()) {} |
178 | LSPClient::~LSPClient() = default; |
179 | |
180 | LSPClient::CallResult &LSPClient::call(llvm::StringRef Method, |
181 | llvm::json::Value Params) { |
182 | auto Slot = T->addCallSlot(); |
183 | T->enqueue(Action: [ID(Slot.first), Method(Method.str()), |
184 | Params(std::move(Params))](Transport::MessageHandler &H) { |
185 | logBody(Method, V: Params, /*Send=*/true); |
186 | H.onCall(Method, Params: std::move(Params), ID); |
187 | }); |
188 | return *Slot.second; |
189 | } |
190 | |
191 | void LSPClient::expectServerCall(llvm::StringRef Method) { |
192 | T->expectCall(Method); |
193 | } |
194 | |
195 | void LSPClient::notify(llvm::StringRef Method, llvm::json::Value Params) { |
196 | T->enqueue(Action: [Method(Method.str()), |
197 | Params(std::move(Params))](Transport::MessageHandler &H) { |
198 | logBody(Method, V: Params, /*Send=*/true); |
199 | H.onNotify(Method, std::move(Params)); |
200 | }); |
201 | } |
202 | |
203 | std::vector<llvm::json::Value> |
204 | LSPClient::takeNotifications(llvm::StringRef Method) { |
205 | return T->takeNotifications(Method); |
206 | } |
207 | |
208 | std::vector<llvm::json::Value> |
209 | LSPClient::takeCallParams(llvm::StringRef Method) { |
210 | return T->takeCallParams(Method); |
211 | } |
212 | |
213 | void LSPClient::stop() { T->enqueue(Action: nullptr); } |
214 | |
215 | Transport &LSPClient::transport() { return *T; } |
216 | |
217 | using Obj = llvm::json::Object; |
218 | |
219 | llvm::json::Value LSPClient::uri(llvm::StringRef Path) { |
220 | std::string Storage; |
221 | if (!llvm::sys::path::is_absolute(path: Path)) |
222 | Path = Storage = testPath(File: Path); |
223 | return toJSON(U: URIForFile::canonicalize(AbsPath: Path, TUPath: Path)); |
224 | } |
225 | llvm::json::Value LSPClient::documentID(llvm::StringRef Path) { |
226 | return Obj{{.K: "uri" , .V: uri(Path)}}; |
227 | } |
228 | |
229 | void LSPClient::didOpen(llvm::StringRef Path, llvm::StringRef Content) { |
230 | notify( |
231 | Method: "textDocument/didOpen" , |
232 | Params: Obj{{.K: "textDocument" , |
233 | .V: Obj{{.K: "uri" , .V: uri(Path)}, {.K: "text" , .V: Content}, {.K: "languageId" , .V: "cpp" }}}}); |
234 | } |
235 | void LSPClient::didChange(llvm::StringRef Path, llvm::StringRef Content) { |
236 | notify(Method: "textDocument/didChange" , |
237 | Params: Obj{{.K: "textDocument" , .V: documentID(Path)}, |
238 | {.K: "contentChanges" , .V: llvm::json::Array{Obj{{.K: "text" , .V: Content}}}}}); |
239 | } |
240 | void LSPClient::didClose(llvm::StringRef Path) { |
241 | notify(Method: "textDocument/didClose" , Params: Obj{{.K: "textDocument" , .V: documentID(Path)}}); |
242 | } |
243 | |
244 | void LSPClient::sync() { call(Method: "sync" , Params: nullptr).takeValue(); } |
245 | |
246 | std::optional<std::vector<llvm::json::Value>> |
247 | LSPClient::diagnostics(llvm::StringRef Path) { |
248 | sync(); |
249 | auto Notifications = takeNotifications(Method: "textDocument/publishDiagnostics" ); |
250 | for (const auto &Notification : llvm::reverse(C&: Notifications)) { |
251 | if (const auto *PubDiagsParams = Notification.getAsObject()) { |
252 | auto U = PubDiagsParams->getString(K: "uri" ); |
253 | auto *D = PubDiagsParams->getArray(K: "diagnostics" ); |
254 | if (!U || !D) { |
255 | ADD_FAILURE() << "Bad PublishDiagnosticsParams: " << PubDiagsParams; |
256 | continue; |
257 | } |
258 | if (*U == uri(Path)) |
259 | return std::vector<llvm::json::Value>(D->begin(), D->end()); |
260 | } |
261 | } |
262 | return {}; |
263 | } |
264 | |
265 | } // namespace clangd |
266 | } // namespace clang |
267 | |