//===- RPC.cpp - Implementation of remote procedure calls from the GPU ----===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "RPC.h" |
10 | |
11 | #include "Shared/Debug.h" |
12 | #include "Shared/RPCOpcodes.h" |
13 | |
14 | #include "PluginInterface.h" |
15 | |
16 | #include "shared/rpc.h" |
17 | #include "shared/rpc_opcodes.h" |
18 | #include "shared/rpc_server.h" |
19 | |
20 | using namespace llvm; |
21 | using namespace omp; |
22 | using namespace target; |
23 | |
24 | template <uint32_t NumLanes> |
25 | rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device, |
26 | rpc::Server::Port &Port) { |
27 | |
28 | switch (Port.get_opcode()) { |
29 | case LIBC_MALLOC: { |
30 | Port.recv_and_send([&](rpc::Buffer *Buffer, uint32_t) { |
31 | Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate( |
32 | Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING)); |
33 | }); |
34 | break; |
35 | } |
36 | case LIBC_FREE: { |
37 | Port.recv([&](rpc::Buffer *Buffer, uint32_t) { |
38 | Device.free(reinterpret_cast<void *>(Buffer->data[0]), |
39 | TARGET_ALLOC_DEVICE_NON_BLOCKING); |
40 | }); |
41 | break; |
42 | } |
43 | case OFFLOAD_HOST_CALL: { |
44 | uint64_t Sizes[NumLanes] = {0}; |
45 | unsigned long long Results[NumLanes] = {0}; |
46 | void *Args[NumLanes] = {nullptr}; |
47 | Port.recv_n(Args, Sizes, [&](uint64_t Size) { return new char[Size]; }); |
48 | Port.recv([&](rpc::Buffer *buffer, uint32_t ID) { |
49 | using FuncPtrTy = unsigned long long (*)(void *); |
50 | auto Func = reinterpret_cast<FuncPtrTy>(buffer->data[0]); |
51 | Results[ID] = Func(Args[ID]); |
52 | }); |
53 | Port.send([&](rpc::Buffer *Buffer, uint32_t ID) { |
54 | Buffer->data[0] = static_cast<uint64_t>(Results[ID]); |
55 | delete[] reinterpret_cast<char *>(Args[ID]); |
56 | }); |
57 | break; |
58 | } |
59 | default: |
60 | return rpc::RPC_UNHANDLED_OPCODE; |
61 | break; |
62 | } |
63 | return rpc::RPC_SUCCESS; |
64 | } |
65 | |
66 | static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device, |
67 | rpc::Server::Port &Port, |
68 | uint32_t NumLanes) { |
69 | if (NumLanes == 1) |
70 | return handleOffloadOpcodes<1>(Device, Port); |
71 | else if (NumLanes == 32) |
72 | return handleOffloadOpcodes<32>(Device, Port); |
73 | else if (NumLanes == 64) |
74 | return handleOffloadOpcodes<64>(Device, Port); |
75 | else |
76 | return rpc::RPC_ERROR; |
77 | } |
78 | |
79 | static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) { |
80 | uint64_t NumPorts = |
81 | std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT); |
82 | rpc::Server Server(NumPorts, Buffer); |
83 | |
84 | auto Port = Server.try_open(Device.getWarpSize()); |
85 | if (!Port) |
86 | return rpc::RPC_SUCCESS; |
87 | |
88 | rpc::Status Status = |
89 | handleOffloadOpcodes(Device, *Port, Device.getWarpSize()); |
90 | |
91 | // Let the `libc` library handle any other unhandled opcodes. |
92 | if (Status == rpc::RPC_UNHANDLED_OPCODE) |
93 | Status = LIBC_NAMESPACE::shared::handle_libc_opcodes(*Port, |
94 | Device.getWarpSize()); |
95 | |
96 | Port->close(); |
97 | |
98 | return Status; |
99 | } |
100 | |
101 | void RPCServerTy::ServerThread::startThread() { |
102 | if (!Running.fetch_or(true, std::memory_order_acquire)) |
103 | Worker = std::thread([this]() { run(); }); |
104 | } |
105 | |
106 | void RPCServerTy::ServerThread::shutDown() { |
107 | if (!Running.fetch_and(false, std::memory_order_release)) |
108 | return; |
109 | { |
110 | std::lock_guard<decltype(Mutex)> Lock(Mutex); |
111 | CV.notify_all(); |
112 | } |
113 | if (Worker.joinable()) |
114 | Worker.join(); |
115 | } |
116 | |
117 | void RPCServerTy::ServerThread::run() { |
118 | std::unique_lock<decltype(Mutex)> Lock(Mutex); |
119 | for (;;) { |
120 | CV.wait(Lock, [&]() { |
121 | return NumUsers.load(std::memory_order_acquire) > 0 || |
122 | !Running.load(std::memory_order_acquire); |
123 | }); |
124 | |
125 | if (!Running.load(std::memory_order_acquire)) |
126 | return; |
127 | |
128 | Lock.unlock(); |
129 | while (NumUsers.load(std::memory_order_relaxed) > 0 && |
130 | Running.load(std::memory_order_relaxed)) { |
131 | std::lock_guard<decltype(Mutex)> Lock(BufferMutex); |
132 | for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) { |
133 | if (!Buffer || !Device) |
134 | continue; |
135 | |
136 | // If running the server failed, print a message but keep running. |
137 | if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS) |
138 | FAILURE_MESSAGE("Unhandled or invalid RPC opcode!" ); |
139 | } |
140 | } |
141 | Lock.lock(); |
142 | } |
143 | } |
144 | |
// Allocates the per-device buffer and device tables sized to the plugin's
// device count, then hands non-owning views of them to the worker thread.
// The arrays outlive the thread because RPCServerTy owns both.
RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
    : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())),
      Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
          Plugin.getNumDevices())),
      Thread(new ServerThread(Buffers.get(), Devices.get(),
                              Plugin.getNumDevices(), BufferMutex)) {}
151 | |
152 | llvm::Error RPCServerTy::startThread() { |
153 | Thread->startThread(); |
154 | return Error::success(); |
155 | } |
156 | |
157 | llvm::Error RPCServerTy::shutDown() { |
158 | Thread->shutDown(); |
159 | return Error::success(); |
160 | } |
161 | |
// A device image participates in RPC iff it exports the client symbol that
// the device-side runtime uses to post requests.
llvm::Expected<bool>
RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
                              plugin::GenericGlobalHandlerTy &Handler,
                              plugin::DeviceImageTy &Image) {
  return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client");
}
168 | |
169 | Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, |
170 | plugin::GenericGlobalHandlerTy &Handler, |
171 | plugin::DeviceImageTy &Image) { |
172 | uint64_t NumPorts = |
173 | std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT); |
174 | void *RPCBuffer = Device.allocate( |
175 | rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr, |
176 | TARGET_ALLOC_HOST); |
177 | if (!RPCBuffer) |
178 | return plugin::Plugin::error( |
179 | error::ErrorCode::UNKNOWN, |
180 | "failed to initialize RPC server for device %d" , Device.getDeviceId()); |
181 | |
182 | // Get the address of the RPC client from the device. |
183 | plugin::GlobalTy ClientGlobal("__llvm_rpc_client" , sizeof(rpc::Client)); |
184 | if (auto Err = |
185 | Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal)) |
186 | return Err; |
187 | |
188 | rpc::Client client(NumPorts, RPCBuffer); |
189 | if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client, |
190 | sizeof(rpc::Client), nullptr)) |
191 | return Err; |
192 | std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex); |
193 | Buffers[Device.getDeviceId()] = RPCBuffer; |
194 | Devices[Device.getDeviceId()] = &Device; |
195 | |
196 | return Error::success(); |
197 | } |
198 | |
199 | Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) { |
200 | std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex); |
201 | Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST); |
202 | Buffers[Device.getDeviceId()] = nullptr; |
203 | Devices[Device.getDeviceId()] = nullptr; |
204 | return Error::success(); |
205 | } |
206 | |