1 | //===-- Acceptor.cpp --------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "Acceptor.h" |
10 | |
11 | #include "llvm/ADT/StringRef.h" |
12 | #include "llvm/Support/ScopedPrinter.h" |
13 | |
14 | #include "lldb/Host/ConnectionFileDescriptor.h" |
15 | #include "lldb/Host/common/TCPSocket.h" |
16 | #include "lldb/Utility/StreamString.h" |
17 | #include "lldb/Utility/UriParser.h" |
18 | #include <optional> |
19 | |
20 | using namespace lldb; |
21 | using namespace lldb_private; |
22 | using namespace lldb_private::lldb_server; |
23 | using namespace llvm; |
24 | |
25 | namespace { |
26 | |
27 | struct SocketScheme { |
28 | const char *m_scheme; |
29 | const Socket::SocketProtocol m_protocol; |
30 | }; |
31 | |
32 | SocketScheme socket_schemes[] = { |
33 | {.m_scheme: "tcp" , .m_protocol: Socket::ProtocolTcp}, |
34 | {.m_scheme: "udp" , .m_protocol: Socket::ProtocolUdp}, |
35 | {.m_scheme: "unix" , .m_protocol: Socket::ProtocolUnixDomain}, |
36 | {.m_scheme: "unix-abstract" , .m_protocol: Socket::ProtocolUnixAbstract}, |
37 | }; |
38 | |
39 | bool FindProtocolByScheme(const char *scheme, |
40 | Socket::SocketProtocol &protocol) { |
41 | for (auto s : socket_schemes) { |
42 | if (!strcmp(s1: s.m_scheme, s2: scheme)) { |
43 | protocol = s.m_protocol; |
44 | return true; |
45 | } |
46 | } |
47 | return false; |
48 | } |
49 | |
50 | const char *FindSchemeByProtocol(const Socket::SocketProtocol protocol) { |
51 | for (auto s : socket_schemes) { |
52 | if (s.m_protocol == protocol) |
53 | return s.m_scheme; |
54 | } |
55 | return nullptr; |
56 | } |
57 | } |
58 | |
59 | Status Acceptor::Listen(int backlog) { |
60 | return m_listener_socket_up->Listen(name: StringRef(m_name), backlog); |
61 | } |
62 | |
63 | Status Acceptor::Accept(const bool child_processes_inherit, Connection *&conn) { |
64 | Socket *conn_socket = nullptr; |
65 | auto error = m_listener_socket_up->Accept(socket&: conn_socket); |
66 | if (error.Success()) |
67 | conn = new ConnectionFileDescriptor(conn_socket); |
68 | |
69 | return error; |
70 | } |
71 | |
72 | Socket::SocketProtocol Acceptor::GetSocketProtocol() const { |
73 | return m_listener_socket_up->GetSocketProtocol(); |
74 | } |
75 | |
76 | const char *Acceptor::GetSocketScheme() const { |
77 | return FindSchemeByProtocol(protocol: GetSocketProtocol()); |
78 | } |
79 | |
80 | std::string Acceptor::GetLocalSocketId() const { return m_local_socket_id(); } |
81 | |
82 | std::unique_ptr<Acceptor> Acceptor::Create(StringRef name, |
83 | const bool child_processes_inherit, |
84 | Status &error) { |
85 | error.Clear(); |
86 | |
87 | Socket::SocketProtocol socket_protocol = Socket::ProtocolUnixDomain; |
88 | // Try to match socket name as URL - e.g., tcp://localhost:5555 |
89 | if (std::optional<URI> res = URI::Parse(uri: name)) { |
90 | if (!FindProtocolByScheme(scheme: res->scheme.str().c_str(), protocol&: socket_protocol)) |
91 | error.SetErrorStringWithFormat("Unknown protocol scheme \"%s\"" , |
92 | res->scheme.str().c_str()); |
93 | else |
94 | name = name.drop_front(N: res->scheme.size() + strlen(s: "://" )); |
95 | } else { |
96 | // Try to match socket name as $host:port - e.g., localhost:5555 |
97 | if (!llvm::errorToBool(Err: Socket::DecodeHostAndPort(host_and_port: name).takeError())) |
98 | socket_protocol = Socket::ProtocolTcp; |
99 | } |
100 | |
101 | if (error.Fail()) |
102 | return std::unique_ptr<Acceptor>(); |
103 | |
104 | std::unique_ptr<Socket> listener_socket_up = |
105 | Socket::Create(protocol: socket_protocol, child_processes_inherit, error); |
106 | |
107 | LocalSocketIdFunc local_socket_id; |
108 | if (error.Success()) { |
109 | if (listener_socket_up->GetSocketProtocol() == Socket::ProtocolTcp) { |
110 | TCPSocket *tcp_socket = |
111 | static_cast<TCPSocket *>(listener_socket_up.get()); |
112 | local_socket_id = [tcp_socket]() { |
113 | auto local_port = tcp_socket->GetLocalPortNumber(); |
114 | return (local_port != 0) ? llvm::to_string(Value: local_port) : "" ; |
115 | }; |
116 | } else { |
117 | const std::string socket_name = std::string(name); |
118 | local_socket_id = [socket_name]() { return socket_name; }; |
119 | } |
120 | |
121 | return std::unique_ptr<Acceptor>( |
122 | new Acceptor(std::move(listener_socket_up), name, local_socket_id)); |
123 | } |
124 | |
125 | return std::unique_ptr<Acceptor>(); |
126 | } |
127 | |
128 | Acceptor::Acceptor(std::unique_ptr<Socket> &&listener_socket, StringRef name, |
129 | const LocalSocketIdFunc &local_socket_id) |
130 | : m_listener_socket_up(std::move(listener_socket)), m_name(name.str()), |
131 | m_local_socket_id(local_socket_id) {} |
132 | |