1 | use super::encode_percents; |
2 | use crate::{Error, Result}; |
3 | #[cfg (not(feature = "tokio" ))] |
4 | use async_io::Async; |
5 | #[cfg (not(feature = "tokio" ))] |
6 | use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; |
7 | use std::{ |
8 | collections::HashMap, |
9 | fmt::{Display, Formatter}, |
10 | str::FromStr, |
11 | }; |
12 | #[cfg (feature = "tokio" )] |
13 | use tokio::net::TcpStream; |
14 | |
15 | /// A TCP transport in a D-Bus address. |
16 | #[derive (Clone, Debug, PartialEq, Eq)] |
17 | pub struct Tcp { |
18 | pub(super) host: String, |
19 | pub(super) bind: Option<String>, |
20 | pub(super) port: u16, |
21 | pub(super) family: Option<TcpTransportFamily>, |
22 | pub(super) nonce_file: Option<Vec<u8>>, |
23 | } |
24 | |
25 | impl Tcp { |
26 | /// Create a new TCP transport with the given host and port. |
27 | pub fn new(host: &str, port: u16) -> Self { |
28 | Self { |
29 | host: host.to_owned(), |
30 | port, |
31 | bind: None, |
32 | family: None, |
33 | nonce_file: None, |
34 | } |
35 | } |
36 | |
37 | /// Set the `tcp:` address `bind` value. |
38 | pub fn set_bind(mut self, bind: Option<String>) -> Self { |
39 | self.bind = bind; |
40 | |
41 | self |
42 | } |
43 | |
44 | /// Set the `tcp:` address `family` value. |
45 | pub fn set_family(mut self, family: Option<TcpTransportFamily>) -> Self { |
46 | self.family = family; |
47 | |
48 | self |
49 | } |
50 | |
51 | /// Set the `tcp:` address `noncefile` value. |
52 | pub fn set_nonce_file(mut self, nonce_file: Option<Vec<u8>>) -> Self { |
53 | self.nonce_file = nonce_file; |
54 | |
55 | self |
56 | } |
57 | |
58 | /// Returns the `tcp:` address `host` value. |
59 | pub fn host(&self) -> &str { |
60 | &self.host |
61 | } |
62 | |
63 | /// Returns the `tcp:` address `bind` value. |
64 | pub fn bind(&self) -> Option<&str> { |
65 | self.bind.as_deref() |
66 | } |
67 | |
68 | /// Returns the `tcp:` address `port` value. |
69 | pub fn port(&self) -> u16 { |
70 | self.port |
71 | } |
72 | |
73 | /// Returns the `tcp:` address `family` value. |
74 | pub fn family(&self) -> Option<TcpTransportFamily> { |
75 | self.family |
76 | } |
77 | |
78 | /// The nonce file path, if any. |
79 | pub fn nonce_file(&self) -> Option<&[u8]> { |
80 | self.nonce_file.as_deref() |
81 | } |
82 | |
83 | /// Take ownership of the nonce file path, if any. |
84 | pub fn take_nonce_file(&mut self) -> Option<Vec<u8>> { |
85 | self.nonce_file.take() |
86 | } |
87 | |
88 | pub(super) fn from_options( |
89 | opts: HashMap<&str, &str>, |
90 | nonce_tcp_required: bool, |
91 | ) -> Result<Self> { |
92 | let bind = None; |
93 | if opts.contains_key("bind" ) { |
94 | return Err(Error::Address("`bind` isn't yet supported" .into())); |
95 | } |
96 | |
97 | let host = opts |
98 | .get("host" ) |
99 | .ok_or_else(|| Error::Address("tcp address is missing `host`" .into()))? |
100 | .to_string(); |
101 | let port = opts |
102 | .get("port" ) |
103 | .ok_or_else(|| Error::Address("tcp address is missing `port`" .into()))?; |
104 | let port = port |
105 | .parse::<u16>() |
106 | .map_err(|_| Error::Address("invalid tcp `port`" .into()))?; |
107 | let family = opts |
108 | .get("family" ) |
109 | .map(|f| TcpTransportFamily::from_str(f)) |
110 | .transpose()?; |
111 | let nonce_file = opts |
112 | .get("noncefile" ) |
113 | .map(|f| super::decode_percents(f)) |
114 | .transpose()?; |
115 | if nonce_tcp_required && nonce_file.is_none() { |
116 | return Err(Error::Address( |
117 | "nonce-tcp address is missing `noncefile`" .into(), |
118 | )); |
119 | } |
120 | |
121 | Ok(Self { |
122 | host, |
123 | bind, |
124 | port, |
125 | family, |
126 | nonce_file, |
127 | }) |
128 | } |
129 | |
130 | #[cfg (not(feature = "tokio" ))] |
131 | pub(super) async fn connect(self) -> Result<Async<TcpStream>> { |
132 | let addrs = crate::Task::spawn_blocking( |
133 | move || -> Result<Vec<SocketAddr>> { |
134 | let addrs = (self.host(), self.port()).to_socket_addrs()?.filter(|a| { |
135 | if let Some(family) = self.family() { |
136 | if family == TcpTransportFamily::Ipv4 { |
137 | a.is_ipv4() |
138 | } else { |
139 | a.is_ipv6() |
140 | } |
141 | } else { |
142 | true |
143 | } |
144 | }); |
145 | Ok(addrs.collect()) |
146 | }, |
147 | "connect tcp" , |
148 | ) |
149 | .await |
150 | .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}" )))?; |
151 | |
152 | // we could attempt connections in parallel? |
153 | let mut last_err = Error::Address("Failed to connect" .into()); |
154 | for addr in addrs { |
155 | match Async::<TcpStream>::connect(addr).await { |
156 | Ok(stream) => return Ok(stream), |
157 | Err(e) => last_err = e.into(), |
158 | } |
159 | } |
160 | |
161 | Err(last_err) |
162 | } |
163 | |
164 | #[cfg (feature = "tokio" )] |
165 | pub(super) async fn connect(self) -> Result<TcpStream> { |
166 | TcpStream::connect((self.host(), self.port())) |
167 | .await |
168 | .map_err(|e| Error::InputOutput(e.into())) |
169 | } |
170 | } |
171 | |
172 | impl Display for Tcp { |
173 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
174 | match self.nonce_file() { |
175 | Some(nonce_file) => { |
176 | f.write_str("nonce-tcp:noncefile=" )?; |
177 | encode_percents(f, nonce_file)?; |
178 | f.write_str("," )?; |
179 | } |
180 | None => f.write_str("tcp:" )?, |
181 | } |
182 | f.write_str("host=" )?; |
183 | |
184 | encode_percents(f, self.host().as_bytes())?; |
185 | |
186 | write!(f, ",port= {}" , self.port())?; |
187 | |
188 | if let Some(bind) = self.bind() { |
189 | f.write_str(",bind=" )?; |
190 | encode_percents(f, bind.as_bytes())?; |
191 | } |
192 | |
193 | if let Some(family) = self.family() { |
194 | write!(f, ",family= {family}" )?; |
195 | } |
196 | |
197 | Ok(()) |
198 | } |
199 | } |
200 | |
/// Address family of a `tcp:` address, i.e. the `family=ipv4` / `family=ipv6` key.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum TcpTransportFamily {
    /// IPv4 only (`ipv4`).
    Ipv4,
    /// IPv6 only (`ipv6`).
    Ipv6,
}
207 | |
208 | impl FromStr for TcpTransportFamily { |
209 | type Err = Error; |
210 | |
211 | fn from_str(family: &str) -> Result<Self> { |
212 | match family { |
213 | "ipv4" => Ok(Self::Ipv4), |
214 | "ipv6" => Ok(Self::Ipv6), |
215 | _ => Err(Error::Address(format!( |
216 | "invalid tcp address `family`: {family}" |
217 | ))), |
218 | } |
219 | } |
220 | } |
221 | |
222 | impl Display for TcpTransportFamily { |
223 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
224 | match self { |
225 | Self::Ipv4 => write!(f, "ipv4" ), |
226 | Self::Ipv6 => write!(f, "ipv6" ), |
227 | } |
228 | } |
229 | } |
230 | |