//===- RPC.cpp - Interface for remote procedure calls from the GPU --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RPC.h"

#include "Shared/Debug.h"
#include "Shared/RPCOpcodes.h"

#include "PluginInterface.h"

#include "shared/rpc.h"
#include "shared/rpc_opcodes.h"
#include "shared/rpc_server.h"

using namespace llvm;
using namespace omp;
using namespace target;

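// Handles the opcodes the plugin services itself rather than deferring to
// `libc`: device-side allocation (LIBC_MALLOC / LIBC_FREE) and
// OFFLOAD_HOST_CALL, which invokes a host function pointer supplied by the
// device. `NumLanes` is the SIMT width of the device issuing the call.
// Returns RPC_UNHANDLED_OPCODE for anything it does not recognize.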
template <uint32_t NumLanes>
rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
                                 rpc::Server::Port &Port) {
  switch (Port.get_opcode()) {
  case LIBC_MALLOC: {
    Port.recv_and_send([&](rpc::Buffer *Buffer, uint32_t) {
      Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
          Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
    });
    break;
  }
  case LIBC_FREE: {
    Port.recv([&](rpc::Buffer *Buffer, uint32_t) {
      Device.free(reinterpret_cast<void *>(Buffer->data[0]),
                  TARGET_ALLOC_DEVICE_NON_BLOCKING);
    });
    break;
  }
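  // A host call ships a function pointer and per-lane argument buffers to the
  // host, invokes the function for each active lane, and sends back each
  // result. The temporary argument copies are freed once the results are sent.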
  case OFFLOAD_HOST_CALL: {
    uint64_t Sizes[NumLanes] = {0};
    unsigned long long Results[NumLanes] = {0};
    void *Args[NumLanes] = {nullptr};
    Port.recv_n(Args, Sizes, [&](uint64_t Size) { return new char[Size]; });
    Port.recv([&](rpc::Buffer *Buffer, uint32_t ID) {
      using FuncPtrTy = unsigned long long (*)(void *);
      auto Func = reinterpret_cast<FuncPtrTy>(Buffer->data[0]);
      Results[ID] = Func(Args[ID]);
    });
    Port.send([&](rpc::Buffer *Buffer, uint32_t ID) {
      Buffer->data[0] = static_cast<uint64_t>(Results[ID]);
      delete[] reinterpret_cast<char *>(Args[ID]);
    });
    break;
  }
  default:
    return rpc::RPC_UNHANDLED_OPCODE;
  }
  return rpc::RPC_SUCCESS;
}

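// Dispatches to the instantiation matching the device's warp size. Only
// single-lane, 32-lane, and 64-lane devices are supported.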
static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
                                        rpc::Server::Port &Port,
                                        uint32_t NumLanes) {
  if (NumLanes == 1)
    return handleOffloadOpcodes<1>(Device, Port);
  else if (NumLanes == 32)
    return handleOffloadOpcodes<32>(Device, Port);
  else if (NumLanes == 64)
    return handleOffloadOpcodes<64>(Device, Port);
  else
    return rpc::RPC_ERROR;
}

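// Checks the server for an open port and services a single transaction on
// it. Opcodes the plugin does not handle itself are forwarded to the `libc`
// implementation. Returns RPC_SUCCESS if no work was pending.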
static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) {
  uint64_t NumPorts =
      std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
  rpc::Server Server(NumPorts, Buffer);

  auto Port = Server.try_open(Device.getWarpSize());
  if (!Port)
    return rpc::RPC_SUCCESS;

  rpc::Status Status =
      handleOffloadOpcodes(Device, *Port, Device.getWarpSize());

  // Let the `libc` library handle any other unhandled opcodes.
  if (Status == rpc::RPC_UNHANDLED_OPCODE)
    Status = LIBC_NAMESPACE::shared::handle_libc_opcodes(*Port,
                                                         Device.getWarpSize());

  Port->close();

  return Status;
}

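// The atomic `fetch_or` guarantees that only the first caller observes
// `Running` as false and spawns the worker thread; later calls are no-ops.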
void RPCServerTy::ServerThread::startThread() {
  if (!Running.fetch_or(true, std::memory_order_acquire))
    Worker = std::thread([this]() { run(); });
}

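// Clears `Running` and wakes the worker. Notifying while holding the mutex
// ensures the worker cannot miss the wake-up between checking its predicate
// and going back to sleep, after which it can be joined safely.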
void RPCServerTy::ServerThread::shutDown() {
  if (!Running.fetch_and(false, std::memory_order_release))
    return;
  {
    std::lock_guard<decltype(Mutex)> Lock(Mutex);
    CV.notify_all();
  }
  if (Worker.joinable())
    Worker.join();
}

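// The worker sleeps until a device registers itself or shutdown is requested,
// then repeatedly polls every registered device's server buffer until there
// are no more users.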
void RPCServerTy::ServerThread::run() {
  std::unique_lock<decltype(Mutex)> Lock(Mutex);
  for (;;) {
    CV.wait(Lock, [&]() {
      return NumUsers.load(std::memory_order_acquire) > 0 ||
             !Running.load(std::memory_order_acquire);
    });

    if (!Running.load(std::memory_order_acquire))
      return;

    Lock.unlock();
    while (NumUsers.load(std::memory_order_relaxed) > 0 &&
           Running.load(std::memory_order_relaxed)) {
      std::lock_guard<decltype(Mutex)> BufferLock(BufferMutex);
      for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
        if (!Buffer || !Device)
          continue;

        // If running the server failed, print a message but keep running.
        if (runServer(*Device, Buffer) != rpc::RPC_SUCCESS)
          FAILURE_MESSAGE("Unhandled or invalid RPC opcode!");
      }
    }
    Lock.lock();
  }
}

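// The buffer and device tables are sized for every device the plugin may
// expose; entries stay null until `initDevice` registers them.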
RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
    : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())),
      Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
          Plugin.getNumDevices())),
      Thread(new ServerThread(Buffers.get(), Devices.get(),
                              Plugin.getNumDevices(), BufferMutex)) {}

llvm::Error RPCServerTy::startThread() {
  Thread->startThread();
  return Error::success();
}

llvm::Error RPCServerTy::shutDown() {
  Thread->shutDown();
  return Error::success();
}

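// A device image participates in RPC only if it contains the
// `__llvm_rpc_client` global that the device-side client is built around.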
llvm::Expected<bool>
RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
                              plugin::GenericGlobalHandlerTy &Handler,
                              plugin::DeviceImageTy &Image) {
  return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client");
}

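// Allocates the shared buffer the client and server communicate through and
// writes an initialized client handle into the image's `__llvm_rpc_client`
// global so device code can find it.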
Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
                              plugin::GenericGlobalHandlerTy &Handler,
                              plugin::DeviceImageTy &Image) {
  uint64_t NumPorts =
      std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
  void *RPCBuffer = Device.allocate(
      rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr,
      TARGET_ALLOC_HOST);
  if (!RPCBuffer)
    return plugin::Plugin::error(
        error::ErrorCode::UNKNOWN,
        "failed to initialize RPC server for device %d", Device.getDeviceId());

  // Get the address of the RPC client from the device.
  plugin::GlobalTy ClientGlobal("__llvm_rpc_client", sizeof(rpc::Client));
  if (auto Err =
          Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal))
    return Err;

  rpc::Client Client(NumPorts, RPCBuffer);
  if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &Client,
                                   sizeof(rpc::Client), nullptr))
    return Err;
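
  // Registering the buffer and device makes them visible to the server
  // thread, which begins polling this device on its next pass.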
  std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
  Buffers[Device.getDeviceId()] = RPCBuffer;
  Devices[Device.getDeviceId()] = &Device;

  return Error::success();
}

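// Tears down the RPC state for a device and releases its shared buffer.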
Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
  std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
  Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
  Buffers[Device.getDeviceId()] = nullptr;
  Devices[Device.getDeviceId()] = nullptr;
  return Error::success();
}