1 | //===-- Shared memory RPC server instantiation ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // Workaround for missing __has_builtin in < GCC 10. |
10 | #ifndef __has_builtin |
11 | #define __has_builtin(x) 0 |
12 | #endif |
13 | |
14 | #include "llvmlibc_rpc_server.h" |
15 | |
16 | #include "src/__support/RPC/rpc.h" |
17 | #include "src/__support/arg_list.h" |
18 | #include "src/stdio/printf_core/converter.h" |
19 | #include "src/stdio/printf_core/parser.h" |
20 | #include "src/stdio/printf_core/writer.h" |
21 | |
22 | #include "src/stdio/gpu/file.h" |
23 | #include <algorithm> |
24 | #include <atomic> |
25 | #include <cstdio> |
26 | #include <cstring> |
27 | #include <memory> |
28 | #include <mutex> |
29 | #include <unordered_map> |
30 | #include <variant> |
31 | #include <vector> |
32 | |
33 | using namespace LIBC_NAMESPACE; |
34 | using namespace LIBC_NAMESPACE::printf_core; |
35 | |
36 | static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer), |
37 | "Buffer size mismatch" ); |
38 | |
39 | static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT, |
40 | "Incorrect maximum port count" ); |
41 | |
42 | template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) { |
43 | FILE *files[lane_size] = {nullptr}; |
44 | // Get the appropriate output stream to use. |
45 | if (port.get_opcode() == RPC_PRINTF_TO_STREAM) |
46 | port.recv([&](rpc::Buffer *buffer, uint32_t id) { |
47 | files[id] = reinterpret_cast<FILE *>(buffer->data[0]); |
48 | }); |
49 | else if (port.get_opcode() == RPC_PRINTF_TO_STDOUT) |
50 | std::fill(files, files + lane_size, stdout); |
51 | else |
52 | std::fill(files, files + lane_size, stderr); |
53 | |
54 | uint64_t format_sizes[lane_size] = {0}; |
55 | void *format[lane_size] = {nullptr}; |
56 | |
57 | uint64_t args_sizes[lane_size] = {0}; |
58 | void *args[lane_size] = {nullptr}; |
59 | |
60 | // Recieve the format string and arguments from the client. |
61 | port.recv_n(format, format_sizes, |
62 | [&](uint64_t size) { return new char[size]; }); |
63 | port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; }); |
64 | |
65 | // Identify any arguments that are actually pointers to strings on the client. |
66 | // Additionally we want to determine how much buffer space we need to print. |
67 | std::vector<void *> strs_to_copy[lane_size]; |
68 | int buffer_size[lane_size] = {0}; |
69 | for (uint32_t lane = 0; lane < lane_size; ++lane) { |
70 | if (!format[lane]) |
71 | continue; |
72 | |
73 | WriteBuffer wb(nullptr, 0); |
74 | Writer writer(&wb); |
75 | |
76 | internal::StructArgList printf_args(args[lane], args_sizes[lane]); |
77 | Parser<internal::StructArgList> parser( |
78 | reinterpret_cast<const char *>(format[lane]), printf_args); |
79 | |
80 | for (FormatSection cur_section = parser.get_next_section(); |
81 | !cur_section.raw_string.empty(); |
82 | cur_section = parser.get_next_section()) { |
83 | if (cur_section.has_conv && cur_section.conv_name == 's' && |
84 | cur_section.conv_val_ptr) { |
85 | strs_to_copy[lane].emplace_back(cur_section.conv_val_ptr); |
86 | } else if (cur_section.has_conv) { |
87 | // Ignore conversion errors for the first pass. |
88 | convert(writer: &writer, to_conv: cur_section); |
89 | } else { |
90 | writer.write(new_string: cur_section.raw_string); |
91 | } |
92 | } |
93 | buffer_size[lane] = writer.get_chars_written(); |
94 | } |
95 | |
96 | // Recieve any strings from the client and push them into a buffer. |
97 | std::vector<void *> copied_strs[lane_size]; |
98 | while (std::any_of(std::begin(strs_to_copy), std::end(strs_to_copy), |
99 | [](const auto &v) { return !v.empty() && v.back(); })) { |
100 | port.send([&](rpc::Buffer *buffer, uint32_t id) { |
101 | void *ptr = !strs_to_copy[id].empty() ? strs_to_copy[id].back() : nullptr; |
102 | buffer->data[1] = reinterpret_cast<uintptr_t>(ptr); |
103 | if (!strs_to_copy[id].empty()) |
104 | strs_to_copy[id].pop_back(); |
105 | }); |
106 | uint64_t str_sizes[lane_size] = {0}; |
107 | void *strs[lane_size] = {nullptr}; |
108 | port.recv_n(strs, str_sizes, [](uint64_t size) { return new char[size]; }); |
109 | for (uint32_t lane = 0; lane < lane_size; ++lane) { |
110 | if (!strs[lane]) |
111 | continue; |
112 | |
113 | copied_strs[lane].emplace_back(strs[lane]); |
114 | buffer_size[lane] += str_sizes[lane]; |
115 | } |
116 | } |
117 | |
118 | // Perform the final formatting and printing using the LLVM C library printf. |
119 | int results[lane_size] = {0}; |
120 | std::vector<void *> to_be_deleted; |
121 | for (uint32_t lane = 0; lane < lane_size; ++lane) { |
122 | if (!format[lane]) |
123 | continue; |
124 | |
125 | std::unique_ptr<char[]> buffer(new char[buffer_size[lane]]); |
126 | WriteBuffer wb(buffer.get(), buffer_size[lane]); |
127 | Writer writer(&wb); |
128 | |
129 | internal::StructArgList printf_args(args[lane], args_sizes[lane]); |
130 | Parser<internal::StructArgList> parser( |
131 | reinterpret_cast<const char *>(format[lane]), printf_args); |
132 | |
133 | // Parse and print the format string using the arguments we copied from |
134 | // the client. |
135 | int ret = 0; |
136 | for (FormatSection cur_section = parser.get_next_section(); |
137 | !cur_section.raw_string.empty(); |
138 | cur_section = parser.get_next_section()) { |
139 | // If this argument was a string we use the memory buffer we copied from |
140 | // the client by replacing the raw pointer with the copied one. |
141 | if (cur_section.has_conv && cur_section.conv_name == 's') { |
142 | if (!copied_strs[lane].empty()) { |
143 | cur_section.conv_val_ptr = copied_strs[lane].back(); |
144 | to_be_deleted.push_back(copied_strs[lane].back()); |
145 | copied_strs[lane].pop_back(); |
146 | } else { |
147 | cur_section.conv_val_ptr = nullptr; |
148 | } |
149 | } |
150 | if (cur_section.has_conv) { |
151 | ret = convert(writer: &writer, to_conv: cur_section); |
152 | if (ret == -1) |
153 | break; |
154 | } else { |
155 | writer.write(new_string: cur_section.raw_string); |
156 | } |
157 | } |
158 | |
159 | results[lane] = |
160 | fwrite(buffer.get(), 1, writer.get_chars_written(), files[lane]); |
161 | if (results[lane] != writer.get_chars_written() || ret == -1) |
162 | results[lane] = -1; |
163 | } |
164 | |
165 | // Send the final return value and signal completion by setting the string |
166 | // argument to null. |
167 | port.send([&](rpc::Buffer *buffer, uint32_t id) { |
168 | buffer->data[0] = static_cast<uint64_t>(results[id]); |
169 | buffer->data[1] = reinterpret_cast<uintptr_t>(nullptr); |
170 | delete[] reinterpret_cast<char *>(format[id]); |
171 | delete[] reinterpret_cast<char *>(args[id]); |
172 | }); |
173 | for (void *ptr : to_be_deleted) |
174 | delete[] reinterpret_cast<char *>(ptr); |
175 | } |
176 | |
177 | template <uint32_t lane_size> |
178 | rpc_status_t handle_server_impl( |
179 | rpc::Server &server, |
180 | const std::unordered_map<uint16_t, rpc_opcode_callback_ty> &callbacks, |
181 | const std::unordered_map<uint16_t, void *> &callback_data, |
182 | uint32_t &index) { |
183 | auto port = server.try_open(lane_size, start: index); |
184 | if (!port) |
185 | return RPC_STATUS_SUCCESS; |
186 | |
187 | switch (port->get_opcode()) { |
188 | case RPC_WRITE_TO_STREAM: |
189 | case RPC_WRITE_TO_STDERR: |
190 | case RPC_WRITE_TO_STDOUT: |
191 | case RPC_WRITE_TO_STDOUT_NEWLINE: { |
192 | uint64_t sizes[lane_size] = {0}; |
193 | void *strs[lane_size] = {nullptr}; |
194 | FILE *files[lane_size] = {nullptr}; |
195 | if (port->get_opcode() == RPC_WRITE_TO_STREAM) { |
196 | port->recv([&](rpc::Buffer *buffer, uint32_t id) { |
197 | files[id] = reinterpret_cast<FILE *>(buffer->data[0]); |
198 | }); |
199 | } else if (port->get_opcode() == RPC_WRITE_TO_STDERR) { |
200 | std::fill(files, files + lane_size, stderr); |
201 | } else { |
202 | std::fill(files, files + lane_size, stdout); |
203 | } |
204 | |
205 | port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; }); |
206 | port->send([&](rpc::Buffer *buffer, uint32_t id) { |
207 | flockfile(files[id]); |
208 | buffer->data[0] = fwrite_unlocked(strs[id], 1, sizes[id], files[id]); |
209 | if (port->get_opcode() == RPC_WRITE_TO_STDOUT_NEWLINE && |
210 | buffer->data[0] == sizes[id]) |
211 | buffer->data[0] += fwrite_unlocked("\n" , 1, 1, files[id]); |
212 | funlockfile(files[id]); |
213 | delete[] reinterpret_cast<uint8_t *>(strs[id]); |
214 | }); |
215 | break; |
216 | } |
217 | case RPC_READ_FROM_STREAM: { |
218 | uint64_t sizes[lane_size] = {0}; |
219 | void *data[lane_size] = {nullptr}; |
220 | port->recv([&](rpc::Buffer *buffer, uint32_t id) { |
221 | data[id] = new char[buffer->data[0]]; |
222 | sizes[id] = |
223 | fread(data[id], 1, buffer->data[0], file::to_stream(f: buffer->data[1])); |
224 | }); |
225 | port->send_n(data, sizes); |
226 | port->send([&](rpc::Buffer *buffer, uint32_t id) { |
227 | delete[] reinterpret_cast<uint8_t *>(data[id]); |
228 | std::memcpy(dest: buffer->data, src: &sizes[id], n: sizeof(uint64_t)); |
229 | }); |
230 | break; |
231 | } |
232 | case RPC_READ_FGETS: { |
233 | uint64_t sizes[lane_size] = {0}; |
234 | void *data[lane_size] = {nullptr}; |
235 | port->recv([&](rpc::Buffer *buffer, uint32_t id) { |
236 | data[id] = new char[buffer->data[0]]; |
237 | const char *str = |
238 | fgets(s: reinterpret_cast<char *>(data[id]), n: buffer->data[0], |
239 | stream: file::to_stream(f: buffer->data[1])); |
240 | sizes[id] = !str ? 0 : std::strlen(s: str) + 1; |
241 | }); |
242 | port->send_n(data, sizes); |
243 | for (uint32_t id = 0; id < lane_size; ++id) |
244 | if (data[id]) |
245 | delete[] reinterpret_cast<uint8_t *>(data[id]); |
246 | break; |
247 | } |
248 | case RPC_OPEN_FILE: { |
249 | uint64_t sizes[lane_size] = {0}; |
250 | void *paths[lane_size] = {nullptr}; |
251 | port->recv_n(paths, sizes, [&](uint64_t size) { return new char[size]; }); |
252 | port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) { |
253 | FILE *file = fopen(filename: reinterpret_cast<char *>(paths[id]), |
254 | modes: reinterpret_cast<char *>(buffer->data)); |
255 | buffer->data[0] = reinterpret_cast<uintptr_t>(file); |
256 | }); |
257 | break; |
258 | } |
259 | case RPC_CLOSE_FILE: { |
260 | port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) { |
261 | FILE *file = reinterpret_cast<FILE *>(buffer->data[0]); |
262 | buffer->data[0] = fclose(stream: file); |
263 | }); |
264 | break; |
265 | } |
266 | case RPC_EXIT: { |
267 | // Send a response to the client to signal that we are ready to exit. |
268 | port->recv_and_send([](rpc::Buffer *) {}); |
269 | port->recv([](rpc::Buffer *buffer) { |
270 | int status = 0; |
271 | std::memcpy(dest: &status, src: buffer->data, n: sizeof(int)); |
272 | exit(status: status); |
273 | }); |
274 | break; |
275 | } |
276 | case RPC_ABORT: { |
277 | // Send a response to the client to signal that we are ready to abort. |
278 | port->recv_and_send([](rpc::Buffer *) {}); |
279 | port->recv([](rpc::Buffer *) {}); |
280 | abort(); |
281 | break; |
282 | } |
283 | case RPC_HOST_CALL: { |
284 | uint64_t sizes[lane_size] = {0}; |
285 | void *args[lane_size] = {nullptr}; |
286 | port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; }); |
287 | port->recv([&](rpc::Buffer *buffer, uint32_t id) { |
288 | reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]); |
289 | }); |
290 | port->send([&](rpc::Buffer *, uint32_t id) { |
291 | delete[] reinterpret_cast<uint8_t *>(args[id]); |
292 | }); |
293 | break; |
294 | } |
295 | case RPC_FEOF: { |
296 | port->recv_and_send([](rpc::Buffer *buffer) { |
297 | buffer->data[0] = feof(stream: file::to_stream(f: buffer->data[0])); |
298 | }); |
299 | break; |
300 | } |
301 | case RPC_FERROR: { |
302 | port->recv_and_send([](rpc::Buffer *buffer) { |
303 | buffer->data[0] = ferror(stream: file::to_stream(f: buffer->data[0])); |
304 | }); |
305 | break; |
306 | } |
307 | case RPC_CLEARERR: { |
308 | port->recv_and_send([](rpc::Buffer *buffer) { |
309 | clearerr(stream: file::to_stream(f: buffer->data[0])); |
310 | }); |
311 | break; |
312 | } |
313 | case RPC_FSEEK: { |
314 | port->recv_and_send([](rpc::Buffer *buffer) { |
315 | buffer->data[0] = fseek(stream: file::to_stream(f: buffer->data[0]), |
316 | off: static_cast<long>(buffer->data[1]), |
317 | whence: static_cast<int>(buffer->data[2])); |
318 | }); |
319 | break; |
320 | } |
321 | case RPC_FTELL: { |
322 | port->recv_and_send([](rpc::Buffer *buffer) { |
323 | buffer->data[0] = ftell(stream: file::to_stream(f: buffer->data[0])); |
324 | }); |
325 | break; |
326 | } |
327 | case RPC_FFLUSH: { |
328 | port->recv_and_send([](rpc::Buffer *buffer) { |
329 | buffer->data[0] = fflush(stream: file::to_stream(f: buffer->data[0])); |
330 | }); |
331 | break; |
332 | } |
333 | case RPC_UNGETC: { |
334 | port->recv_and_send([](rpc::Buffer *buffer) { |
335 | buffer->data[0] = ungetc(c: static_cast<int>(buffer->data[0]), |
336 | stream: file::to_stream(f: buffer->data[1])); |
337 | }); |
338 | break; |
339 | } |
340 | case RPC_PRINTF_TO_STREAM: |
341 | case RPC_PRINTF_TO_STDOUT: |
342 | case RPC_PRINTF_TO_STDERR: { |
343 | handle_printf<lane_size>(*port); |
344 | break; |
345 | } |
346 | case RPC_NOOP: { |
347 | port->recv([](rpc::Buffer *) {}); |
348 | break; |
349 | } |
350 | default: { |
351 | auto handler = |
352 | callbacks.find(x: static_cast<rpc_opcode_t>(port->get_opcode())); |
353 | |
354 | // We error out on an unhandled opcode. |
355 | if (handler == callbacks.end()) |
356 | return RPC_STATUS_UNHANDLED_OPCODE; |
357 | |
358 | // Invoke the registered callback with a reference to the port. |
359 | void *data = |
360 | callback_data.at(k: static_cast<rpc_opcode_t>(port->get_opcode())); |
361 | rpc_port_t port_ref{.handle: reinterpret_cast<uint64_t>(&*port), .lane_size: lane_size}; |
362 | (handler->second)(port_ref, data); |
363 | } |
364 | } |
365 | |
366 | // Increment the index so we start the scan after this port. |
367 | index = port->get_index() + 1; |
368 | port->close(); |
369 | |
370 | return RPC_STATUS_CONTINUE; |
371 | } |
372 | |
373 | struct Device { |
374 | Device(uint32_t lane_size, uint32_t num_ports, void *buffer) |
375 | : lane_size(lane_size), buffer(buffer), server(num_ports, buffer), |
376 | client(num_ports, buffer) {} |
377 | |
378 | rpc_status_t handle_server(uint32_t &index) { |
379 | switch (lane_size) { |
380 | case 1: |
381 | return handle_server_impl<1>(server, callbacks, callback_data, index); |
382 | case 32: |
383 | return handle_server_impl<32>(server, callbacks, callback_data, index); |
384 | case 64: |
385 | return handle_server_impl<64>(server, callbacks, callback_data, index); |
386 | default: |
387 | return RPC_STATUS_INVALID_LANE_SIZE; |
388 | } |
389 | } |
390 | |
391 | uint32_t lane_size; |
392 | void *buffer; |
393 | rpc::Server server; |
394 | rpc::Client client; |
395 | std::unordered_map<uint16_t, rpc_opcode_callback_ty> callbacks; |
396 | std::unordered_map<uint16_t, void *> callback_data; |
397 | }; |
398 | |
399 | rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports, |
400 | uint32_t lane_size, rpc_alloc_ty alloc, |
401 | void *data) { |
402 | if (!rpc_device) |
403 | return RPC_STATUS_ERROR; |
404 | if (lane_size != 1 && lane_size != 32 && lane_size != 64) |
405 | return RPC_STATUS_INVALID_LANE_SIZE; |
406 | |
407 | uint64_t size = rpc::Server::allocation_size(lane_size, port_count: num_ports); |
408 | void *buffer = alloc(size, data); |
409 | |
410 | if (!buffer) |
411 | return RPC_STATUS_ERROR; |
412 | |
413 | Device *device = new Device(lane_size, num_ports, buffer); |
414 | if (!device) |
415 | return RPC_STATUS_ERROR; |
416 | |
417 | rpc_device->handle = reinterpret_cast<uintptr_t>(device); |
418 | return RPC_STATUS_SUCCESS; |
419 | } |
420 | |
421 | rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc, |
422 | void *data) { |
423 | if (!rpc_device.handle) |
424 | return RPC_STATUS_ERROR; |
425 | |
426 | Device *device = reinterpret_cast<Device *>(rpc_device.handle); |
427 | dealloc(device->buffer, data); |
428 | delete device; |
429 | |
430 | return RPC_STATUS_SUCCESS; |
431 | } |
432 | |
433 | rpc_status_t rpc_handle_server(rpc_device_t rpc_device) { |
434 | if (!rpc_device.handle) |
435 | return RPC_STATUS_ERROR; |
436 | |
437 | Device *device = reinterpret_cast<Device *>(rpc_device.handle); |
438 | uint32_t index = 0; |
439 | for (;;) { |
440 | rpc_status_t status = device->handle_server(index); |
441 | if (status != RPC_STATUS_CONTINUE) |
442 | return status; |
443 | } |
444 | } |
445 | |
446 | rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode, |
447 | rpc_opcode_callback_ty callback, |
448 | void *data) { |
449 | if (!rpc_device.handle) |
450 | return RPC_STATUS_ERROR; |
451 | |
452 | Device *device = reinterpret_cast<Device *>(rpc_device.handle); |
453 | |
454 | device->callbacks[opcode] = callback; |
455 | device->callback_data[opcode] = data; |
456 | return RPC_STATUS_SUCCESS; |
457 | } |
458 | |
459 | const void *rpc_get_client_buffer(rpc_device_t rpc_device) { |
460 | if (!rpc_device.handle) |
461 | return nullptr; |
462 | Device *device = reinterpret_cast<Device *>(rpc_device.handle); |
463 | return &device->client; |
464 | } |
465 | |
466 | uint64_t rpc_get_client_size() { return sizeof(rpc::Client); } |
467 | |
468 | void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) { |
469 | auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle); |
470 | port->send(fill: [=](rpc::Buffer *buffer) { |
471 | callback(reinterpret_cast<rpc_buffer_t *>(buffer), data); |
472 | }); |
473 | } |
474 | |
475 | void rpc_send_n(rpc_port_t ref, const void *const *src, uint64_t *size) { |
476 | auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle); |
477 | port->send_n(src, size); |
478 | } |
479 | |
480 | void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) { |
481 | auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle); |
482 | port->recv(use: [=](rpc::Buffer *buffer) { |
483 | callback(reinterpret_cast<rpc_buffer_t *>(buffer), data); |
484 | }); |
485 | } |
486 | |
487 | void rpc_recv_n(rpc_port_t ref, void **dst, uint64_t *size, rpc_alloc_ty alloc, |
488 | void *data) { |
489 | auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle); |
490 | auto alloc_fn = [=](uint64_t size) { return alloc(size, data); }; |
491 | port->recv_n(dst, size, alloc&: alloc_fn); |
492 | } |
493 | |
494 | void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback, |
495 | void *data) { |
496 | auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle); |
497 | port->recv_and_send(work: [=](rpc::Buffer *buffer) { |
498 | callback(reinterpret_cast<rpc_buffer_t *>(buffer), data); |
499 | }); |
500 | } |
501 | |