1//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// Workaround for missing __has_builtin in < GCC 10.
10#ifndef __has_builtin
11#define __has_builtin(x) 0
12#endif
13
#include "llvmlibc_rpc_server.h"

#include "src/__support/RPC/rpc.h"
#include "src/__support/arg_list.h"
#include "src/stdio/printf_core/converter.h"
#include "src/stdio/printf_core/parser.h"
#include "src/stdio/printf_core/writer.h"

#include "src/stdio/gpu/file.h"
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <cstring>
#include <memory>
#include <mutex>
#include <new>
#include <unordered_map>
#include <variant>
#include <vector>
32
33using namespace LIBC_NAMESPACE;
34using namespace LIBC_NAMESPACE::printf_core;
35
36static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
37 "Buffer size mismatch");
38
39static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
40 "Incorrect maximum port count");
41
42template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
43 FILE *files[lane_size] = {nullptr};
44 // Get the appropriate output stream to use.
45 if (port.get_opcode() == RPC_PRINTF_TO_STREAM)
46 port.recv([&](rpc::Buffer *buffer, uint32_t id) {
47 files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
48 });
49 else if (port.get_opcode() == RPC_PRINTF_TO_STDOUT)
50 std::fill(files, files + lane_size, stdout);
51 else
52 std::fill(files, files + lane_size, stderr);
53
54 uint64_t format_sizes[lane_size] = {0};
55 void *format[lane_size] = {nullptr};
56
57 uint64_t args_sizes[lane_size] = {0};
58 void *args[lane_size] = {nullptr};
59
60 // Recieve the format string and arguments from the client.
61 port.recv_n(format, format_sizes,
62 [&](uint64_t size) { return new char[size]; });
63 port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; });
64
65 // Identify any arguments that are actually pointers to strings on the client.
66 // Additionally we want to determine how much buffer space we need to print.
67 std::vector<void *> strs_to_copy[lane_size];
68 int buffer_size[lane_size] = {0};
69 for (uint32_t lane = 0; lane < lane_size; ++lane) {
70 if (!format[lane])
71 continue;
72
73 WriteBuffer wb(nullptr, 0);
74 Writer writer(&wb);
75
76 internal::StructArgList printf_args(args[lane], args_sizes[lane]);
77 Parser<internal::StructArgList> parser(
78 reinterpret_cast<const char *>(format[lane]), printf_args);
79
80 for (FormatSection cur_section = parser.get_next_section();
81 !cur_section.raw_string.empty();
82 cur_section = parser.get_next_section()) {
83 if (cur_section.has_conv && cur_section.conv_name == 's' &&
84 cur_section.conv_val_ptr) {
85 strs_to_copy[lane].emplace_back(cur_section.conv_val_ptr);
86 } else if (cur_section.has_conv) {
87 // Ignore conversion errors for the first pass.
88 convert(writer: &writer, to_conv: cur_section);
89 } else {
90 writer.write(new_string: cur_section.raw_string);
91 }
92 }
93 buffer_size[lane] = writer.get_chars_written();
94 }
95
96 // Recieve any strings from the client and push them into a buffer.
97 std::vector<void *> copied_strs[lane_size];
98 while (std::any_of(std::begin(strs_to_copy), std::end(strs_to_copy),
99 [](const auto &v) { return !v.empty() && v.back(); })) {
100 port.send([&](rpc::Buffer *buffer, uint32_t id) {
101 void *ptr = !strs_to_copy[id].empty() ? strs_to_copy[id].back() : nullptr;
102 buffer->data[1] = reinterpret_cast<uintptr_t>(ptr);
103 if (!strs_to_copy[id].empty())
104 strs_to_copy[id].pop_back();
105 });
106 uint64_t str_sizes[lane_size] = {0};
107 void *strs[lane_size] = {nullptr};
108 port.recv_n(strs, str_sizes, [](uint64_t size) { return new char[size]; });
109 for (uint32_t lane = 0; lane < lane_size; ++lane) {
110 if (!strs[lane])
111 continue;
112
113 copied_strs[lane].emplace_back(strs[lane]);
114 buffer_size[lane] += str_sizes[lane];
115 }
116 }
117
118 // Perform the final formatting and printing using the LLVM C library printf.
119 int results[lane_size] = {0};
120 std::vector<void *> to_be_deleted;
121 for (uint32_t lane = 0; lane < lane_size; ++lane) {
122 if (!format[lane])
123 continue;
124
125 std::unique_ptr<char[]> buffer(new char[buffer_size[lane]]);
126 WriteBuffer wb(buffer.get(), buffer_size[lane]);
127 Writer writer(&wb);
128
129 internal::StructArgList printf_args(args[lane], args_sizes[lane]);
130 Parser<internal::StructArgList> parser(
131 reinterpret_cast<const char *>(format[lane]), printf_args);
132
133 // Parse and print the format string using the arguments we copied from
134 // the client.
135 int ret = 0;
136 for (FormatSection cur_section = parser.get_next_section();
137 !cur_section.raw_string.empty();
138 cur_section = parser.get_next_section()) {
139 // If this argument was a string we use the memory buffer we copied from
140 // the client by replacing the raw pointer with the copied one.
141 if (cur_section.has_conv && cur_section.conv_name == 's') {
142 if (!copied_strs[lane].empty()) {
143 cur_section.conv_val_ptr = copied_strs[lane].back();
144 to_be_deleted.push_back(copied_strs[lane].back());
145 copied_strs[lane].pop_back();
146 } else {
147 cur_section.conv_val_ptr = nullptr;
148 }
149 }
150 if (cur_section.has_conv) {
151 ret = convert(writer: &writer, to_conv: cur_section);
152 if (ret == -1)
153 break;
154 } else {
155 writer.write(new_string: cur_section.raw_string);
156 }
157 }
158
159 results[lane] =
160 fwrite(buffer.get(), 1, writer.get_chars_written(), files[lane]);
161 if (results[lane] != writer.get_chars_written() || ret == -1)
162 results[lane] = -1;
163 }
164
165 // Send the final return value and signal completion by setting the string
166 // argument to null.
167 port.send([&](rpc::Buffer *buffer, uint32_t id) {
168 buffer->data[0] = static_cast<uint64_t>(results[id]);
169 buffer->data[1] = reinterpret_cast<uintptr_t>(nullptr);
170 delete[] reinterpret_cast<char *>(format[id]);
171 delete[] reinterpret_cast<char *>(args[id]);
172 });
173 for (void *ptr : to_be_deleted)
174 delete[] reinterpret_cast<char *>(ptr);
175}
176
177template <uint32_t lane_size>
178rpc_status_t handle_server_impl(
179 rpc::Server &server,
180 const std::unordered_map<uint16_t, rpc_opcode_callback_ty> &callbacks,
181 const std::unordered_map<uint16_t, void *> &callback_data,
182 uint32_t &index) {
183 auto port = server.try_open(lane_size, start: index);
184 if (!port)
185 return RPC_STATUS_SUCCESS;
186
187 switch (port->get_opcode()) {
188 case RPC_WRITE_TO_STREAM:
189 case RPC_WRITE_TO_STDERR:
190 case RPC_WRITE_TO_STDOUT:
191 case RPC_WRITE_TO_STDOUT_NEWLINE: {
192 uint64_t sizes[lane_size] = {0};
193 void *strs[lane_size] = {nullptr};
194 FILE *files[lane_size] = {nullptr};
195 if (port->get_opcode() == RPC_WRITE_TO_STREAM) {
196 port->recv([&](rpc::Buffer *buffer, uint32_t id) {
197 files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
198 });
199 } else if (port->get_opcode() == RPC_WRITE_TO_STDERR) {
200 std::fill(files, files + lane_size, stderr);
201 } else {
202 std::fill(files, files + lane_size, stdout);
203 }
204
205 port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
206 port->send([&](rpc::Buffer *buffer, uint32_t id) {
207 flockfile(files[id]);
208 buffer->data[0] = fwrite_unlocked(strs[id], 1, sizes[id], files[id]);
209 if (port->get_opcode() == RPC_WRITE_TO_STDOUT_NEWLINE &&
210 buffer->data[0] == sizes[id])
211 buffer->data[0] += fwrite_unlocked("\n", 1, 1, files[id]);
212 funlockfile(files[id]);
213 delete[] reinterpret_cast<uint8_t *>(strs[id]);
214 });
215 break;
216 }
217 case RPC_READ_FROM_STREAM: {
218 uint64_t sizes[lane_size] = {0};
219 void *data[lane_size] = {nullptr};
220 port->recv([&](rpc::Buffer *buffer, uint32_t id) {
221 data[id] = new char[buffer->data[0]];
222 sizes[id] =
223 fread(data[id], 1, buffer->data[0], file::to_stream(f: buffer->data[1]));
224 });
225 port->send_n(data, sizes);
226 port->send([&](rpc::Buffer *buffer, uint32_t id) {
227 delete[] reinterpret_cast<uint8_t *>(data[id]);
228 std::memcpy(dest: buffer->data, src: &sizes[id], n: sizeof(uint64_t));
229 });
230 break;
231 }
232 case RPC_READ_FGETS: {
233 uint64_t sizes[lane_size] = {0};
234 void *data[lane_size] = {nullptr};
235 port->recv([&](rpc::Buffer *buffer, uint32_t id) {
236 data[id] = new char[buffer->data[0]];
237 const char *str =
238 fgets(s: reinterpret_cast<char *>(data[id]), n: buffer->data[0],
239 stream: file::to_stream(f: buffer->data[1]));
240 sizes[id] = !str ? 0 : std::strlen(s: str) + 1;
241 });
242 port->send_n(data, sizes);
243 for (uint32_t id = 0; id < lane_size; ++id)
244 if (data[id])
245 delete[] reinterpret_cast<uint8_t *>(data[id]);
246 break;
247 }
248 case RPC_OPEN_FILE: {
249 uint64_t sizes[lane_size] = {0};
250 void *paths[lane_size] = {nullptr};
251 port->recv_n(paths, sizes, [&](uint64_t size) { return new char[size]; });
252 port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
253 FILE *file = fopen(filename: reinterpret_cast<char *>(paths[id]),
254 modes: reinterpret_cast<char *>(buffer->data));
255 buffer->data[0] = reinterpret_cast<uintptr_t>(file);
256 });
257 break;
258 }
259 case RPC_CLOSE_FILE: {
260 port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
261 FILE *file = reinterpret_cast<FILE *>(buffer->data[0]);
262 buffer->data[0] = fclose(stream: file);
263 });
264 break;
265 }
266 case RPC_EXIT: {
267 // Send a response to the client to signal that we are ready to exit.
268 port->recv_and_send([](rpc::Buffer *) {});
269 port->recv([](rpc::Buffer *buffer) {
270 int status = 0;
271 std::memcpy(dest: &status, src: buffer->data, n: sizeof(int));
272 exit(status: status);
273 });
274 break;
275 }
276 case RPC_ABORT: {
277 // Send a response to the client to signal that we are ready to abort.
278 port->recv_and_send([](rpc::Buffer *) {});
279 port->recv([](rpc::Buffer *) {});
280 abort();
281 break;
282 }
283 case RPC_HOST_CALL: {
284 uint64_t sizes[lane_size] = {0};
285 void *args[lane_size] = {nullptr};
286 port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
287 port->recv([&](rpc::Buffer *buffer, uint32_t id) {
288 reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
289 });
290 port->send([&](rpc::Buffer *, uint32_t id) {
291 delete[] reinterpret_cast<uint8_t *>(args[id]);
292 });
293 break;
294 }
295 case RPC_FEOF: {
296 port->recv_and_send([](rpc::Buffer *buffer) {
297 buffer->data[0] = feof(stream: file::to_stream(f: buffer->data[0]));
298 });
299 break;
300 }
301 case RPC_FERROR: {
302 port->recv_and_send([](rpc::Buffer *buffer) {
303 buffer->data[0] = ferror(stream: file::to_stream(f: buffer->data[0]));
304 });
305 break;
306 }
307 case RPC_CLEARERR: {
308 port->recv_and_send([](rpc::Buffer *buffer) {
309 clearerr(stream: file::to_stream(f: buffer->data[0]));
310 });
311 break;
312 }
313 case RPC_FSEEK: {
314 port->recv_and_send([](rpc::Buffer *buffer) {
315 buffer->data[0] = fseek(stream: file::to_stream(f: buffer->data[0]),
316 off: static_cast<long>(buffer->data[1]),
317 whence: static_cast<int>(buffer->data[2]));
318 });
319 break;
320 }
321 case RPC_FTELL: {
322 port->recv_and_send([](rpc::Buffer *buffer) {
323 buffer->data[0] = ftell(stream: file::to_stream(f: buffer->data[0]));
324 });
325 break;
326 }
327 case RPC_FFLUSH: {
328 port->recv_and_send([](rpc::Buffer *buffer) {
329 buffer->data[0] = fflush(stream: file::to_stream(f: buffer->data[0]));
330 });
331 break;
332 }
333 case RPC_UNGETC: {
334 port->recv_and_send([](rpc::Buffer *buffer) {
335 buffer->data[0] = ungetc(c: static_cast<int>(buffer->data[0]),
336 stream: file::to_stream(f: buffer->data[1]));
337 });
338 break;
339 }
340 case RPC_PRINTF_TO_STREAM:
341 case RPC_PRINTF_TO_STDOUT:
342 case RPC_PRINTF_TO_STDERR: {
343 handle_printf<lane_size>(*port);
344 break;
345 }
346 case RPC_NOOP: {
347 port->recv([](rpc::Buffer *) {});
348 break;
349 }
350 default: {
351 auto handler =
352 callbacks.find(x: static_cast<rpc_opcode_t>(port->get_opcode()));
353
354 // We error out on an unhandled opcode.
355 if (handler == callbacks.end())
356 return RPC_STATUS_UNHANDLED_OPCODE;
357
358 // Invoke the registered callback with a reference to the port.
359 void *data =
360 callback_data.at(k: static_cast<rpc_opcode_t>(port->get_opcode()));
361 rpc_port_t port_ref{.handle: reinterpret_cast<uint64_t>(&*port), .lane_size: lane_size};
362 (handler->second)(port_ref, data);
363 }
364 }
365
366 // Increment the index so we start the scan after this port.
367 index = port->get_index() + 1;
368 port->close();
369
370 return RPC_STATUS_CONTINUE;
371}
372
373struct Device {
374 Device(uint32_t lane_size, uint32_t num_ports, void *buffer)
375 : lane_size(lane_size), buffer(buffer), server(num_ports, buffer),
376 client(num_ports, buffer) {}
377
378 rpc_status_t handle_server(uint32_t &index) {
379 switch (lane_size) {
380 case 1:
381 return handle_server_impl<1>(server, callbacks, callback_data, index);
382 case 32:
383 return handle_server_impl<32>(server, callbacks, callback_data, index);
384 case 64:
385 return handle_server_impl<64>(server, callbacks, callback_data, index);
386 default:
387 return RPC_STATUS_INVALID_LANE_SIZE;
388 }
389 }
390
391 uint32_t lane_size;
392 void *buffer;
393 rpc::Server server;
394 rpc::Client client;
395 std::unordered_map<uint16_t, rpc_opcode_callback_ty> callbacks;
396 std::unordered_map<uint16_t, void *> callback_data;
397};
398
399rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
400 uint32_t lane_size, rpc_alloc_ty alloc,
401 void *data) {
402 if (!rpc_device)
403 return RPC_STATUS_ERROR;
404 if (lane_size != 1 && lane_size != 32 && lane_size != 64)
405 return RPC_STATUS_INVALID_LANE_SIZE;
406
407 uint64_t size = rpc::Server::allocation_size(lane_size, port_count: num_ports);
408 void *buffer = alloc(size, data);
409
410 if (!buffer)
411 return RPC_STATUS_ERROR;
412
413 Device *device = new Device(lane_size, num_ports, buffer);
414 if (!device)
415 return RPC_STATUS_ERROR;
416
417 rpc_device->handle = reinterpret_cast<uintptr_t>(device);
418 return RPC_STATUS_SUCCESS;
419}
420
421rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
422 void *data) {
423 if (!rpc_device.handle)
424 return RPC_STATUS_ERROR;
425
426 Device *device = reinterpret_cast<Device *>(rpc_device.handle);
427 dealloc(device->buffer, data);
428 delete device;
429
430 return RPC_STATUS_SUCCESS;
431}
432
433rpc_status_t rpc_handle_server(rpc_device_t rpc_device) {
434 if (!rpc_device.handle)
435 return RPC_STATUS_ERROR;
436
437 Device *device = reinterpret_cast<Device *>(rpc_device.handle);
438 uint32_t index = 0;
439 for (;;) {
440 rpc_status_t status = device->handle_server(index);
441 if (status != RPC_STATUS_CONTINUE)
442 return status;
443 }
444}
445
446rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode,
447 rpc_opcode_callback_ty callback,
448 void *data) {
449 if (!rpc_device.handle)
450 return RPC_STATUS_ERROR;
451
452 Device *device = reinterpret_cast<Device *>(rpc_device.handle);
453
454 device->callbacks[opcode] = callback;
455 device->callback_data[opcode] = data;
456 return RPC_STATUS_SUCCESS;
457}
458
459const void *rpc_get_client_buffer(rpc_device_t rpc_device) {
460 if (!rpc_device.handle)
461 return nullptr;
462 Device *device = reinterpret_cast<Device *>(rpc_device.handle);
463 return &device->client;
464}
465
466uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }
467
468void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
469 auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
470 port->send(fill: [=](rpc::Buffer *buffer) {
471 callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
472 });
473}
474
475void rpc_send_n(rpc_port_t ref, const void *const *src, uint64_t *size) {
476 auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
477 port->send_n(src, size);
478}
479
480void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
481 auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
482 port->recv(use: [=](rpc::Buffer *buffer) {
483 callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
484 });
485}
486
487void rpc_recv_n(rpc_port_t ref, void **dst, uint64_t *size, rpc_alloc_ty alloc,
488 void *data) {
489 auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
490 auto alloc_fn = [=](uint64_t size) { return alloc(size, data); };
491 port->recv_n(dst, size, alloc&: alloc_fn);
492}
493
494void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
495 void *data) {
496 auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
497 port->recv_and_send(work: [=](rpc::Buffer *buffer) {
498 callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
499 });
500}
501

// Source: libc/utils/gpu/server/rpc_server.cpp