| 1 | //! DNS Resolution used by the `HttpConnector`. |
| 2 | //! |
| 3 | //! This module contains: |
| 4 | //! |
| 5 | //! - A [`GaiResolver`] that is the default resolver for the `HttpConnector`. |
| 6 | //! - The `Name` type used as an argument to custom resolvers. |
| 7 | //! |
| 8 | //! # Resolvers are `Service`s |
| 9 | //! |
| 10 | //! A resolver is just a |
| 11 | //! `Service<Name, Response = impl Iterator<Item = SocketAddr>>`. |
| 12 | //! |
| 13 | //! A simple resolver that ignores the name and always returns a specific |
| 14 | //! address: |
| 15 | //! |
| 16 | //! ```rust,ignore |
| 17 | //! use std::{convert::Infallible, iter, net::SocketAddr}; |
| 18 | //! |
| 19 | //! let resolver = tower::service_fn(|_name| async { |
| 20 | //! Ok::<_, Infallible>(iter::once(SocketAddr::from(([127, 0, 0, 1], 8080)))) |
| 21 | //! }); |
| 22 | //! ``` |
| 23 | use std::error::Error; |
| 24 | use std::future::Future; |
| 25 | use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; |
| 26 | use std::pin::Pin; |
| 27 | use std::str::FromStr; |
| 28 | use std::task::{self, Poll}; |
| 29 | use std::{fmt, io, vec}; |
| 30 | |
| 31 | use tokio::task::JoinHandle; |
| 32 | use tower_service::Service; |
| 33 | use tracing::debug_span; |
| 34 | |
| 35 | pub(super) use self::sealed::Resolve; |
| 36 | |
/// A hostname that should be resolved into one or more IP addresses.
///
/// The hostname is held as an owned `Box<str>`; equality and hashing compare
/// the hostname text.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Name {
    host: Box<str>,
}
| 42 | |
/// A resolver that performs blocking `getaddrinfo` lookups on Tokio's
/// blocking thread pool.
///
/// The resolver carries no configuration. [`Default`] is derived so it can be
/// used wherever a `Default` bound is required; the derived value is the same
/// as the one built by `GaiResolver::new()`.
#[derive(Clone, Default)]
pub struct GaiResolver {
    // Private unit field: prevents struct-literal construction outside this module.
    _priv: (),
}
| 48 | |
| 49 | /// An iterator of IP addresses returned from `getaddrinfo`. |
| 50 | pub struct GaiAddrs { |
| 51 | inner: SocketAddrs, |
| 52 | } |
| 53 | |
| 54 | /// A future to resolve a name returned by `GaiResolver`. |
| 55 | pub struct GaiFuture { |
| 56 | inner: JoinHandle<Result<SocketAddrs, io::Error>>, |
| 57 | } |
| 58 | |
| 59 | impl Name { |
| 60 | pub(super) fn new(host: Box<str>) -> Name { |
| 61 | Name { host } |
| 62 | } |
| 63 | |
| 64 | /// View the hostname as a string slice. |
| 65 | pub fn as_str(&self) -> &str { |
| 66 | &self.host |
| 67 | } |
| 68 | } |
| 69 | |
| 70 | impl fmt::Debug for Name { |
| 71 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 72 | fmt::Debug::fmt(&self.host, f) |
| 73 | } |
| 74 | } |
| 75 | |
| 76 | impl fmt::Display for Name { |
| 77 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 78 | fmt::Display::fmt(&self.host, f) |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | impl FromStr for Name { |
| 83 | type Err = InvalidNameError; |
| 84 | |
| 85 | fn from_str(host: &str) -> Result<Self, Self::Err> { |
| 86 | // Possibly add validation later |
| 87 | Ok(Name::new(host.into())) |
| 88 | } |
| 89 | } |
| 90 | |
/// Error returned when a string cannot be turned into a valid domain name.
///
/// The unit field is private, so values can only be created inside this module.
#[derive(Debug)]
pub struct InvalidNameError(());

impl fmt::Display for InvalidNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "Not a valid domain name";
        f.write_str(MESSAGE)
    }
}

// No underlying cause to expose, so the default `source()` (None) is used.
impl Error for InvalidNameError {}
| 102 | |
| 103 | impl GaiResolver { |
| 104 | /// Construct a new `GaiResolver`. |
| 105 | pub fn new() -> Self { |
| 106 | GaiResolver { _priv: () } |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | impl Service<Name> for GaiResolver { |
| 111 | type Response = GaiAddrs; |
| 112 | type Error = io::Error; |
| 113 | type Future = GaiFuture; |
| 114 | |
| 115 | fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { |
| 116 | Poll::Ready(Ok(())) |
| 117 | } |
| 118 | |
| 119 | fn call(&mut self, name: Name) -> Self::Future { |
| 120 | let span: Span = debug_span!("resolve" , host = %name.host); |
| 121 | let blocking: JoinHandle> = tokio::task::spawn_blocking(move || { |
| 122 | let _enter: Entered<'_> = span.enter(); |
| 123 | (&*name.host, 0) |
| 124 | .to_socket_addrs() |
| 125 | .map(|i: IntoIter| SocketAddrs { iter: i }) |
| 126 | }); |
| 127 | |
| 128 | GaiFuture { inner: blocking } |
| 129 | } |
| 130 | } |
| 131 | |
| 132 | impl fmt::Debug for GaiResolver { |
| 133 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 134 | f.pad("GaiResolver" ) |
| 135 | } |
| 136 | } |
| 137 | |
| 138 | impl Future for GaiFuture { |
| 139 | type Output = Result<GaiAddrs, io::Error>; |
| 140 | |
| 141 | fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { |
| 142 | Pin::new(&mut self.inner).poll(cx).map(|res: Result, …>| match res { |
| 143 | Ok(Ok(addrs: SocketAddrs)) => Ok(GaiAddrs { inner: addrs }), |
| 144 | Ok(Err(err: Error)) => Err(err), |
| 145 | Err(join_err: JoinError) => { |
| 146 | if join_err.is_cancelled() { |
| 147 | Err(io::Error::new(kind:io::ErrorKind::Interrupted, error:join_err)) |
| 148 | } else { |
| 149 | panic!("gai background task failed: {:?}" , join_err) |
| 150 | } |
| 151 | } |
| 152 | }) |
| 153 | } |
| 154 | } |
| 155 | |
| 156 | impl fmt::Debug for GaiFuture { |
| 157 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 158 | f.pad("GaiFuture" ) |
| 159 | } |
| 160 | } |
| 161 | |
| 162 | impl Drop for GaiFuture { |
| 163 | fn drop(&mut self) { |
| 164 | self.inner.abort(); |
| 165 | } |
| 166 | } |
| 167 | |
| 168 | impl Iterator for GaiAddrs { |
| 169 | type Item = SocketAddr; |
| 170 | |
| 171 | fn next(&mut self) -> Option<Self::Item> { |
| 172 | self.inner.next() |
| 173 | } |
| 174 | } |
| 175 | |
| 176 | impl fmt::Debug for GaiAddrs { |
| 177 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 178 | f.pad("GaiAddrs" ) |
| 179 | } |
| 180 | } |
| 181 | |
| 182 | pub(super) struct SocketAddrs { |
| 183 | iter: vec::IntoIter<SocketAddr>, |
| 184 | } |
| 185 | |
| 186 | impl SocketAddrs { |
| 187 | pub(super) fn new(addrs: Vec<SocketAddr>) -> Self { |
| 188 | SocketAddrs { |
| 189 | iter: addrs.into_iter(), |
| 190 | } |
| 191 | } |
| 192 | |
| 193 | pub(super) fn try_parse(host: &str, port: u16) -> Option<SocketAddrs> { |
| 194 | if let Ok(addr) = host.parse::<Ipv4Addr>() { |
| 195 | let addr = SocketAddrV4::new(addr, port); |
| 196 | return Some(SocketAddrs { |
| 197 | iter: vec![SocketAddr::V4(addr)].into_iter(), |
| 198 | }); |
| 199 | } |
| 200 | if let Ok(addr) = host.parse::<Ipv6Addr>() { |
| 201 | let addr = SocketAddrV6::new(addr, port, 0, 0); |
| 202 | return Some(SocketAddrs { |
| 203 | iter: vec![SocketAddr::V6(addr)].into_iter(), |
| 204 | }); |
| 205 | } |
| 206 | None |
| 207 | } |
| 208 | |
| 209 | #[inline ] |
| 210 | fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs { |
| 211 | SocketAddrs::new(self.iter.filter(predicate).collect()) |
| 212 | } |
| 213 | |
| 214 | pub(super) fn split_by_preference( |
| 215 | self, |
| 216 | local_addr_ipv4: Option<Ipv4Addr>, |
| 217 | local_addr_ipv6: Option<Ipv6Addr>, |
| 218 | ) -> (SocketAddrs, SocketAddrs) { |
| 219 | match (local_addr_ipv4, local_addr_ipv6) { |
| 220 | (Some(_), None) => (self.filter(SocketAddr::is_ipv4), SocketAddrs::new(vec![])), |
| 221 | (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), SocketAddrs::new(vec![])), |
| 222 | _ => { |
| 223 | let preferring_v6 = self |
| 224 | .iter |
| 225 | .as_slice() |
| 226 | .first() |
| 227 | .map(SocketAddr::is_ipv6) |
| 228 | .unwrap_or(false); |
| 229 | |
| 230 | let (preferred, fallback) = self |
| 231 | .iter |
| 232 | .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6); |
| 233 | |
| 234 | (SocketAddrs::new(preferred), SocketAddrs::new(fallback)) |
| 235 | } |
| 236 | } |
| 237 | } |
| 238 | |
| 239 | pub(super) fn is_empty(&self) -> bool { |
| 240 | self.iter.as_slice().is_empty() |
| 241 | } |
| 242 | |
| 243 | pub(super) fn len(&self) -> usize { |
| 244 | self.iter.as_slice().len() |
| 245 | } |
| 246 | } |
| 247 | |
| 248 | impl Iterator for SocketAddrs { |
| 249 | type Item = SocketAddr; |
| 250 | #[inline ] |
| 251 | fn next(&mut self) -> Option<SocketAddr> { |
| 252 | self.iter.next() |
| 253 | } |
| 254 | } |
| 255 | |
| 256 | mod sealed { |
| 257 | use std::future::Future; |
| 258 | use std::task::{self, Poll}; |
| 259 | |
| 260 | use super::{Name, SocketAddr}; |
| 261 | use tower_service::Service; |
| 262 | |
| 263 | // "Trait alias" for `Service<Name, Response = Addrs>` |
| 264 | pub trait Resolve { |
| 265 | type Addrs: Iterator<Item = SocketAddr>; |
| 266 | type Error: Into<Box<dyn std::error::Error + Send + Sync>>; |
| 267 | type Future: Future<Output = Result<Self::Addrs, Self::Error>>; |
| 268 | |
| 269 | fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; |
| 270 | fn resolve(&mut self, name: Name) -> Self::Future; |
| 271 | } |
| 272 | |
| 273 | impl<S> Resolve for S |
| 274 | where |
| 275 | S: Service<Name>, |
| 276 | S::Response: Iterator<Item = SocketAddr>, |
| 277 | S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, |
| 278 | { |
| 279 | type Addrs = S::Response; |
| 280 | type Error = S::Error; |
| 281 | type Future = S::Future; |
| 282 | |
| 283 | fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 284 | Service::poll_ready(self, cx) |
| 285 | } |
| 286 | |
| 287 | fn resolve(&mut self, name: Name) -> Self::Future { |
| 288 | Service::call(self, name) |
| 289 | } |
| 290 | } |
| 291 | } |
| 292 | |
| 293 | pub(super) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error> |
| 294 | where |
| 295 | R: Resolve, |
| 296 | { |
| 297 | futures_util::future::poll_fn(|cx: &mut Context<'_>| resolver.poll_ready(cx)).await?; |
| 298 | resolver.resolve(name).await |
| 299 | } |
| 300 | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Builds a `SocketAddrs` that yields `list` in order.
    fn addrs(list: &[SocketAddr]) -> SocketAddrs {
        SocketAddrs {
            iter: list.to_vec().into_iter(),
        }
    }

    #[test]
    fn test_ip_addrs_split_by_preference() {
        let ip_v4 = Ipv4Addr::new(127, 0, 0, 1);
        let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
        let v4_addr = SocketAddr::from((ip_v4, 80));
        let v6_addr = SocketAddr::from((ip_v6, 80));
        let v4_first = [v4_addr, v6_addr];
        let v6_first = [v6_addr, v4_addr];

        // No local addresses: the family of the first address is preferred.
        let (mut preferred, mut fallback) = addrs(&v4_first).split_by_preference(None, None);
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.next().unwrap().is_ipv6());

        let (mut preferred, mut fallback) = addrs(&v6_first).split_by_preference(None, None);
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.next().unwrap().is_ipv4());

        // Both local addresses: same rule as having none.
        let (mut preferred, mut fallback) =
            addrs(&v4_first).split_by_preference(Some(ip_v4), Some(ip_v6));
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.next().unwrap().is_ipv6());

        let (mut preferred, mut fallback) =
            addrs(&v6_first).split_by_preference(Some(ip_v4), Some(ip_v6));
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.next().unwrap().is_ipv4());

        // A single local address: keep only that family, with no fallback.
        let (mut preferred, fallback) = addrs(&v4_first).split_by_preference(Some(ip_v4), None);
        assert!(preferred.next().unwrap().is_ipv4());
        assert!(fallback.is_empty());

        let (mut preferred, fallback) = addrs(&v4_first).split_by_preference(None, Some(ip_v6));
        assert!(preferred.next().unwrap().is_ipv6());
        assert!(fallback.is_empty());
    }

    #[test]
    fn test_name_from_str() {
        const DOMAIN: &str = "test.example.com";
        let name: Name = DOMAIN.parse().expect("Should be a valid domain");
        assert_eq!(name.as_str(), DOMAIN);
        assert_eq!(name.to_string(), DOMAIN);
    }
}
| 364 | |