1 | //! DNS Resolution used by the `HttpConnector`. |
2 | //! |
3 | //! This module contains: |
4 | //! |
5 | //! - A [`GaiResolver`](GaiResolver) that is the default resolver for the |
6 | //! `HttpConnector`. |
7 | //! - The `Name` type used as an argument to custom resolvers. |
8 | //! |
9 | //! # Resolvers are `Service`s |
10 | //! |
11 | //! A resolver is just a |
12 | //! `Service<Name, Response = impl Iterator<Item = SocketAddr>>`. |
13 | //! |
14 | //! A simple resolver that ignores the name and always returns a specific |
15 | //! address: |
16 | //! |
17 | //! ```rust,ignore |
18 | //! use std::{convert::Infallible, iter, net::SocketAddr}; |
19 | //! |
20 | //! let resolver = tower::service_fn(|_name| async { |
21 | //! Ok::<_, Infallible>(iter::once(SocketAddr::from(([127, 0, 0, 1], 8080)))) |
22 | //! }); |
23 | //! ``` |
24 | use std::error::Error; |
25 | use std::future::Future; |
26 | use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; |
27 | use std::pin::Pin; |
28 | use std::str::FromStr; |
29 | use std::task::{Context, Poll}; |
30 | use std::{fmt, io, vec}; |
31 | |
32 | use tokio::task::JoinHandle; |
33 | use tower_service::Service; |
34 | use tracing::debug; |
35 | |
36 | pub(super) use self::sealed::Resolve; |
37 | |
/// A domain name that is to be resolved into IP addresses.
///
/// Cloning, equality and hashing are all derived from the stored host text.
#[derive(Clone, Hash, Eq, PartialEq)]
pub struct Name {
    /// The host text, stored as an owned string slice.
    host: Box<str>,
}
43 | |
/// The default resolver: performs blocking `getaddrinfo` lookups on a
/// threadpool.
#[derive(Clone)]
pub struct GaiResolver {
    // Private placeholder field, so code outside this module cannot build the
    // resolver with a struct literal.
    _priv: (),
}
49 | |
50 | /// An iterator of IP addresses returned from `getaddrinfo`. |
51 | pub struct GaiAddrs { |
52 | inner: SocketAddrs, |
53 | } |
54 | |
55 | /// A future to resolve a name returned by `GaiResolver`. |
56 | pub struct GaiFuture { |
57 | inner: JoinHandle<Result<SocketAddrs, io::Error>>, |
58 | } |
59 | |
60 | impl Name { |
61 | pub(super) fn new(host: Box<str>) -> Name { |
62 | Name { host } |
63 | } |
64 | |
65 | /// View the hostname as a string slice. |
66 | pub fn as_str(&self) -> &str { |
67 | &self.host |
68 | } |
69 | } |
70 | |
71 | impl fmt::Debug for Name { |
72 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
73 | fmt::Debug::fmt(&self.host, f) |
74 | } |
75 | } |
76 | |
77 | impl fmt::Display for Name { |
78 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
79 | fmt::Display::fmt(&self.host, f) |
80 | } |
81 | } |
82 | |
83 | impl FromStr for Name { |
84 | type Err = InvalidNameError; |
85 | |
86 | fn from_str(host: &str) -> Result<Self, Self::Err> { |
87 | // Possibly add validation later |
88 | Ok(Name::new(host:host.into())) |
89 | } |
90 | } |
91 | |
/// Error indicating a given string was not a valid domain name.
///
/// The private unit field keeps the type opaque: code outside this module
/// cannot construct it.
#[derive(Debug)]
pub struct InvalidNameError(());

impl fmt::Display for InvalidNameError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Not a valid domain name")
    }
}

// No underlying cause to report, so the default `Error` methods suffice.
impl Error for InvalidNameError {}
103 | |
104 | impl GaiResolver { |
105 | /// Construct a new `GaiResolver`. |
106 | pub fn new() -> Self { |
107 | GaiResolver { _priv: () } |
108 | } |
109 | } |
110 | |
111 | impl Service<Name> for GaiResolver { |
112 | type Response = GaiAddrs; |
113 | type Error = io::Error; |
114 | type Future = GaiFuture; |
115 | |
116 | fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
117 | Poll::Ready(Ok(())) |
118 | } |
119 | |
120 | fn call(&mut self, name: Name) -> Self::Future { |
121 | let blocking: JoinHandle> = tokio::task::spawn_blocking(move || { |
122 | debug!("resolving host= {:?}" , name.host); |
123 | (&*name.host, 0) |
124 | .to_socket_addrs() |
125 | .map(|i: IntoIter| SocketAddrs { iter: i }) |
126 | }); |
127 | |
128 | GaiFuture { inner: blocking } |
129 | } |
130 | } |
131 | |
132 | impl fmt::Debug for GaiResolver { |
133 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
134 | f.pad("GaiResolver" ) |
135 | } |
136 | } |
137 | |
138 | impl Future for GaiFuture { |
139 | type Output = Result<GaiAddrs, io::Error>; |
140 | |
141 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
142 | Pin::new(&mut self.inner).poll(cx).map(|res: Result, …>| match res { |
143 | Ok(Ok(addrs: SocketAddrs)) => Ok(GaiAddrs { inner: addrs }), |
144 | Ok(Err(err: Error)) => Err(err), |
145 | Err(join_err: JoinError) => { |
146 | if join_err.is_cancelled() { |
147 | Err(io::Error::new(kind:io::ErrorKind::Interrupted, error:join_err)) |
148 | } else { |
149 | panic!("gai background task failed: {:?}" , join_err) |
150 | } |
151 | } |
152 | }) |
153 | } |
154 | } |
155 | |
156 | impl fmt::Debug for GaiFuture { |
157 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
158 | f.pad("GaiFuture" ) |
159 | } |
160 | } |
161 | |
162 | impl Drop for GaiFuture { |
163 | fn drop(&mut self) { |
164 | self.inner.abort(); |
165 | } |
166 | } |
167 | |
168 | impl Iterator for GaiAddrs { |
169 | type Item = SocketAddr; |
170 | |
171 | fn next(&mut self) -> Option<Self::Item> { |
172 | self.inner.next() |
173 | } |
174 | } |
175 | |
176 | impl fmt::Debug for GaiAddrs { |
177 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
178 | f.pad("GaiAddrs" ) |
179 | } |
180 | } |
181 | |
182 | pub(super) struct SocketAddrs { |
183 | iter: vec::IntoIter<SocketAddr>, |
184 | } |
185 | |
186 | impl SocketAddrs { |
187 | pub(super) fn new(addrs: Vec<SocketAddr>) -> Self { |
188 | SocketAddrs { |
189 | iter: addrs.into_iter(), |
190 | } |
191 | } |
192 | |
193 | pub(super) fn try_parse(host: &str, port: u16) -> Option<SocketAddrs> { |
194 | if let Ok(addr) = host.parse::<Ipv4Addr>() { |
195 | let addr = SocketAddrV4::new(addr, port); |
196 | return Some(SocketAddrs { |
197 | iter: vec![SocketAddr::V4(addr)].into_iter(), |
198 | }); |
199 | } |
200 | if let Ok(addr) = host.parse::<Ipv6Addr>() { |
201 | let addr = SocketAddrV6::new(addr, port, 0, 0); |
202 | return Some(SocketAddrs { |
203 | iter: vec![SocketAddr::V6(addr)].into_iter(), |
204 | }); |
205 | } |
206 | None |
207 | } |
208 | |
209 | #[inline ] |
210 | fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs { |
211 | SocketAddrs::new(self.iter.filter(predicate).collect()) |
212 | } |
213 | |
214 | pub(super) fn split_by_preference( |
215 | self, |
216 | local_addr_ipv4: Option<Ipv4Addr>, |
217 | local_addr_ipv6: Option<Ipv6Addr>, |
218 | ) -> (SocketAddrs, SocketAddrs) { |
219 | match (local_addr_ipv4, local_addr_ipv6) { |
220 | (Some(_), None) => (self.filter(SocketAddr::is_ipv4), SocketAddrs::new(vec![])), |
221 | (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), SocketAddrs::new(vec![])), |
222 | _ => { |
223 | let preferring_v6 = self |
224 | .iter |
225 | .as_slice() |
226 | .first() |
227 | .map(SocketAddr::is_ipv6) |
228 | .unwrap_or(false); |
229 | |
230 | let (preferred, fallback) = self |
231 | .iter |
232 | .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6); |
233 | |
234 | (SocketAddrs::new(preferred), SocketAddrs::new(fallback)) |
235 | } |
236 | } |
237 | } |
238 | |
239 | pub(super) fn is_empty(&self) -> bool { |
240 | self.iter.as_slice().is_empty() |
241 | } |
242 | |
243 | pub(super) fn len(&self) -> usize { |
244 | self.iter.as_slice().len() |
245 | } |
246 | } |
247 | |
248 | impl Iterator for SocketAddrs { |
249 | type Item = SocketAddr; |
250 | #[inline ] |
251 | fn next(&mut self) -> Option<SocketAddr> { |
252 | self.iter.next() |
253 | } |
254 | } |
255 | |
256 | /* |
257 | /// A resolver using `getaddrinfo` calls via the `tokio_executor::threadpool::blocking` API. |
258 | /// |
259 | /// Unlike the `GaiResolver` this will not spawn dedicated threads, but only works when running on the |
260 | /// multi-threaded Tokio runtime. |
261 | #[cfg(feature = "runtime")] |
262 | #[derive(Clone, Debug)] |
263 | pub struct TokioThreadpoolGaiResolver(()); |
264 | |
265 | /// The future returned by `TokioThreadpoolGaiResolver`. |
266 | #[cfg(feature = "runtime")] |
267 | #[derive(Debug)] |
268 | pub struct TokioThreadpoolGaiFuture { |
269 | name: Name, |
270 | } |
271 | |
272 | #[cfg(feature = "runtime")] |
273 | impl TokioThreadpoolGaiResolver { |
274 | /// Creates a new DNS resolver that will use tokio threadpool's blocking |
275 | /// feature. |
276 | /// |
277 | /// **Requires** its futures to be run on the threadpool runtime. |
278 | pub fn new() -> Self { |
279 | TokioThreadpoolGaiResolver(()) |
280 | } |
281 | } |
282 | |
283 | #[cfg(feature = "runtime")] |
284 | impl Service<Name> for TokioThreadpoolGaiResolver { |
285 | type Response = GaiAddrs; |
286 | type Error = io::Error; |
287 | type Future = TokioThreadpoolGaiFuture; |
288 | |
289 | fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
290 | Poll::Ready(Ok(())) |
291 | } |
292 | |
293 | fn call(&mut self, name: Name) -> Self::Future { |
294 | TokioThreadpoolGaiFuture { name } |
295 | } |
296 | } |
297 | |
298 | #[cfg(feature = "runtime")] |
299 | impl Future for TokioThreadpoolGaiFuture { |
300 | type Output = Result<GaiAddrs, io::Error>; |
301 | |
302 | fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> { |
303 | match ready!(tokio_executor::threadpool::blocking(|| ( |
304 | self.name.as_str(), |
305 | 0 |
306 | ) |
307 | .to_socket_addrs())) |
308 | { |
309 | Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs { |
310 | inner: IpAddrs { iter }, |
311 | })), |
312 | Ok(Err(e)) => Poll::Ready(Err(e)), |
313 | // a BlockingError, meaning not on a tokio_executor::threadpool :( |
314 | Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), |
315 | } |
316 | } |
317 | } |
318 | */ |
319 | |
320 | mod sealed { |
321 | use std::future::Future; |
322 | use std::task::{Context, Poll}; |
323 | |
324 | use super::{Name, SocketAddr}; |
325 | use tower_service::Service; |
326 | |
327 | // "Trait alias" for `Service<Name, Response = Addrs>` |
328 | pub trait Resolve { |
329 | type Addrs: Iterator<Item = SocketAddr>; |
330 | type Error: Into<Box<dyn std::error::Error + Send + Sync>>; |
331 | type Future: Future<Output = Result<Self::Addrs, Self::Error>>; |
332 | |
333 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>; |
334 | fn resolve(&mut self, name: Name) -> Self::Future; |
335 | } |
336 | |
337 | impl<S> Resolve for S |
338 | where |
339 | S: Service<Name>, |
340 | S::Response: Iterator<Item = SocketAddr>, |
341 | S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, |
342 | { |
343 | type Addrs = S::Response; |
344 | type Error = S::Error; |
345 | type Future = S::Future; |
346 | |
347 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
348 | Service::poll_ready(self, cx) |
349 | } |
350 | |
351 | fn resolve(&mut self, name: Name) -> Self::Future { |
352 | Service::call(self, name) |
353 | } |
354 | } |
355 | } |
356 | |
357 | pub(super) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> |
358 | where |
359 | R: Resolve, |
360 | { |
361 | futures_util::future::poll_fn(|cx: &mut Context<'_>| resolver.poll_ready(cx)).await?; |
362 | resolver.resolve(name).await |
363 | } |
364 | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Checks how a mixed IPv4/IPv6 list is split for each combination of
    /// configured local addresses.
    #[test]
    fn test_ip_addrs_split_by_preference() {
        let ip_v4 = Ipv4Addr::new(127, 0, 0, 1);
        let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
        let v4_addr: SocketAddr = (ip_v4, 80).into();
        let v6_addr: SocketAddr = (ip_v6, 80).into();

        // Fresh inputs for each case, since splitting consumes the list.
        let v4_first = || SocketAddrs::new(vec![v4_addr, v6_addr]);
        let v6_first = || SocketAddrs::new(vec![v6_addr, v4_addr]);

        // No local addresses: the family of the first entry is preferred.
        let (mut preferred, mut fallback) = v4_first().split_by_preference(None, None);
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.next().unwrap().is_ipv6());

        let (mut preferred, mut fallback) = v6_first().split_by_preference(None, None);
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.next().unwrap().is_ipv4());

        // Both local addresses: same rule as with none.
        let (mut preferred, mut fallback) =
            v4_first().split_by_preference(Some(ip_v4), Some(ip_v6));
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.next().unwrap().is_ipv6());

        let (mut preferred, mut fallback) =
            v6_first().split_by_preference(Some(ip_v4), Some(ip_v6));
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.next().unwrap().is_ipv4());

        // Only one local address: that family is kept and there is no fallback.
        let (mut preferred, fallback) = v4_first().split_by_preference(Some(ip_v4), None);
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.is_empty());

        let (mut preferred, fallback) = v4_first().split_by_preference(None, Some(ip_v6));
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.is_empty());
    }

    /// A parsed `Name` round-trips through `as_str` and `Display`.
    #[test]
    fn test_name_from_str() {
        const DOMAIN: &str = "test.example.com";
        let name: Name = DOMAIN.parse().expect("Should be a valid domain");
        assert_eq!(name.as_str(), DOMAIN);
        assert_eq!(name.to_string(), DOMAIN);
    }
}
428 | |