//===- RPC.cpp - Interface for remote procedure calls from the GPU --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "RPC.h"
10
11#include "Shared/Debug.h"
12#include "Shared/RPCOpcodes.h"
13
14#include "PluginInterface.h"
15
16#include "shared/rpc.h"
17#include "shared/rpc_opcodes.h"
18#include "shared/rpc_server.h"
19
20using namespace llvm;
21using namespace omp;
22using namespace target;
23
24template <uint32_t NumLanes>
25rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
26 rpc::Server::Port &Port) {
27
28 switch (Port.get_opcode()) {
29 case LIBC_MALLOC: {
30 Port.recv_and_send([&](rpc::Buffer *Buffer, uint32_t) {
31 Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
32 Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
33 });
34 break;
35 }
36 case LIBC_FREE: {
37 Port.recv([&](rpc::Buffer *Buffer, uint32_t) {
38 Device.free(reinterpret_cast<void *>(Buffer->data[0]),
39 TARGET_ALLOC_DEVICE_NON_BLOCKING);
40 });
41 break;
42 }
43 case OFFLOAD_HOST_CALL: {
44 uint64_t Sizes[NumLanes] = {0};
45 unsigned long long Results[NumLanes] = {0};
46 void *Args[NumLanes] = {nullptr};
47 Port.recv_n(Args, Sizes, [&](uint64_t Size) { return new char[Size]; });
48 Port.recv([&](rpc::Buffer *buffer, uint32_t ID) {
49 using FuncPtrTy = unsigned long long (*)(void *);
50 auto Func = reinterpret_cast<FuncPtrTy>(buffer->data[0]);
51 Results[ID] = Func(Args[ID]);
52 });
53 Port.send([&](rpc::Buffer *Buffer, uint32_t ID) {
54 Buffer->data[0] = static_cast<uint64_t>(Results[ID]);
55 delete[] reinterpret_cast<char *>(Args[ID]);
56 });
57 break;
58 }
59 default:
60 return rpc::RPC_UNHANDLED_OPCODE;
61 break;
62 }
63 return rpc::RPC_SUCCESS;
64}
65
66static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
67 rpc::Server::Port &Port,
68 uint32_t NumLanes) {
69 if (NumLanes == 1)
70 return handleOffloadOpcodes<1>(Device, Port);
71 else if (NumLanes == 32)
72 return handleOffloadOpcodes<32>(Device, Port);
73 else if (NumLanes == 64)
74 return handleOffloadOpcodes<64>(Device, Port);
75 else
76 return rpc::RPC_ERROR;
77}
78
79static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) {
80 uint64_t NumPorts =
81 std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
82 rpc::Server Server(NumPorts, Buffer);
83
84 auto Port = Server.try_open(Device.getWarpSize());
85 if (!Port)
86 return rpc::RPC_SUCCESS;
87
88 rpc::Status Status =
89 handleOffloadOpcodes(Device, *Port, Device.getWarpSize());
90
91 // Let the `libc` library handle any other unhandled opcodes.
92 if (Status == rpc::RPC_UNHANDLED_OPCODE)
93 Status = LIBC_NAMESPACE::shared::handle_libc_opcodes(*Port,
94 Device.getWarpSize());
95
96 Port->close();
97
98 return Status;
99}
100
101void RPCServerTy::ServerThread::startThread() {
102 if (!Running.fetch_or(true, std::memory_order_acquire))
103 Worker = std::thread([this]() { run(); });
104}
105
106void RPCServerTy::ServerThread::shutDown() {
107 if (!Running.fetch_and(false, std::memory_order_release))
108 return;
109 {
110 std::lock_guard<decltype(Mutex)> Lock(Mutex);
111 CV.notify_all();
112 }
113 if (Worker.joinable())
114 Worker.join();
115}
116
117void RPCServerTy::ServerThread::run() {
118 std::unique_lock<decltype(Mutex)> Lock(Mutex);
119 for (;;) {
120 CV.wait(Lock, [&]() {
121 return NumUsers.load(std::memory_order_acquire) > 0 ||
122 !Running.load(std::memory_order_acquire);
123 });
124
125 if (!Running.load(std::memory_order_acquire))
126 return;
127
128 Lock.unlock();
129 while (NumUsers.load(std::memory_order_relaxed) > 0 &&
130 Running.load(std::memory_order_relaxed)) {
131 std::lock_guard<decltype(Mutex)> Lock(BufferMutex);
132 for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
133 if (!Buffer || !Device)
134 continue;
135
136 // If running the server failed, print a message but keep running.
137 if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS)
138 FAILURE_MESSAGE("Unhandled or invalid RPC opcode!");
139 }
140 }
141 Lock.lock();
142 }
143}
144
/// Creates the RPC server state for every device of \p Plugin.
///
/// `Buffers` and `Devices` get one slot per plugin device; array
/// `make_unique<T[]>` value-initializes, so every slot starts as nullptr
/// until initDevice() fills it. The ServerThread object receives raw views
/// of both arrays plus the shared `BufferMutex`, but the worker thread
/// itself is not started here (see startThread()).
///
/// NOTE: `Thread`'s initializer reads `Buffers.get()` and `Devices.get()`,
/// so this relies on `Buffers` and `Devices` being declared before `Thread`
/// in the class (members are initialized in declaration order).
RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
    : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())),
      Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
          Plugin.getNumDevices())),
      Thread(new ServerThread(Buffers.get(), Devices.get(),
                              Plugin.getNumDevices(), BufferMutex)) {}
151
152llvm::Error RPCServerTy::startThread() {
153 Thread->startThread();
154 return Error::success();
155}
156
157llvm::Error RPCServerTy::shutDown() {
158 Thread->shutDown();
159 return Error::success();
160}
161
162llvm::Expected<bool>
163RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
164 plugin::GenericGlobalHandlerTy &Handler,
165 plugin::DeviceImageTy &Image) {
166 return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client");
167}
168
169Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
170 plugin::GenericGlobalHandlerTy &Handler,
171 plugin::DeviceImageTy &Image) {
172 uint64_t NumPorts =
173 std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
174 void *RPCBuffer = Device.allocate(
175 rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr,
176 TARGET_ALLOC_HOST);
177 if (!RPCBuffer)
178 return plugin::Plugin::error(
179 error::ErrorCode::UNKNOWN,
180 "failed to initialize RPC server for device %d", Device.getDeviceId());
181
182 // Get the address of the RPC client from the device.
183 plugin::GlobalTy ClientGlobal("__llvm_rpc_client", sizeof(rpc::Client));
184 if (auto Err =
185 Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal))
186 return Err;
187
188 rpc::Client client(NumPorts, RPCBuffer);
189 if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client,
190 sizeof(rpc::Client), nullptr))
191 return Err;
192 std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
193 Buffers[Device.getDeviceId()] = RPCBuffer;
194 Devices[Device.getDeviceId()] = &Device;
195
196 return Error::success();
197}
198
199Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
200 std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
201 Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
202 Buffers[Device.getDeviceId()] = nullptr;
203 Devices[Device.getDeviceId()] = nullptr;
204 return Error::success();
205}
206

source code of offload/plugins-nextgen/common/src/RPC.cpp