| 1 | use hyper_util::client::legacy::connect::dns::Name as HyperName; |
| 2 | use tower_service::Service; |
| 3 | |
| 4 | use std::collections::HashMap; |
| 5 | use std::future::Future; |
| 6 | use std::net::SocketAddr; |
| 7 | use std::pin::Pin; |
| 8 | use std::str::FromStr; |
| 9 | use std::sync::Arc; |
| 10 | use std::task::{Context, Poll}; |
| 11 | |
| 12 | use crate::error::BoxError; |
| 13 | |
| 14 | /// Alias for an `Iterator` trait object over `SocketAddr`. |
| 15 | pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Send>; |
| 16 | |
| 17 | /// Alias for the `Future` type returned by a DNS resolver. |
| 18 | pub type Resolving = Pin<Box<dyn Future<Output = Result<Addrs, BoxError>> + Send>>; |
| 19 | |
| 20 | /// Trait for customizing DNS resolution in reqwest. |
| 21 | pub trait Resolve: Send + Sync { |
| 22 | /// Performs DNS resolution on a `Name`. |
| 23 | /// The return type is a future containing an iterator of `SocketAddr`. |
| 24 | /// |
| 25 | /// It differs from `tower_service::Service<Name>` in several ways: |
| 26 | /// * It is assumed that `resolve` will always be ready to poll. |
| 27 | /// * It does not need a mutable reference to `self`. |
| 28 | /// * Since trait objects cannot make use of associated types, it requires |
| 29 | /// wrapping the returned `Future` and its contained `Iterator` with `Box`. |
| 30 | /// |
| 31 | /// Explicitly specified port in the URL will override any port in the resolved `SocketAddr`s. |
| 32 | /// Otherwise, port `0` will be replaced by the conventional port for the given scheme (e.g. 80 for http). |
| 33 | fn resolve(&self, name: Name) -> Resolving; |
| 34 | } |
| 35 | |
| 36 | /// A name that must be resolved to addresses. |
| 37 | #[derive (Debug)] |
| 38 | pub struct Name(pub(super) HyperName); |
| 39 | |
| 40 | impl Name { |
| 41 | /// View the name as a string. |
| 42 | pub fn as_str(&self) -> &str { |
| 43 | self.0.as_str() |
| 44 | } |
| 45 | } |
| 46 | |
| 47 | impl FromStr for Name { |
| 48 | type Err = sealed::InvalidNameError; |
| 49 | |
| 50 | fn from_str(host: &str) -> Result<Self, Self::Err> { |
| 51 | HyperName::from_str(host) |
| 52 | .map(Name) |
| 53 | .map_err(|_| sealed::InvalidNameError { _ext: () }) |
| 54 | } |
| 55 | } |
| 56 | |
| 57 | #[derive (Clone)] |
| 58 | pub(crate) struct DynResolver { |
| 59 | resolver: Arc<dyn Resolve>, |
| 60 | } |
| 61 | |
| 62 | impl DynResolver { |
| 63 | pub(crate) fn new(resolver: Arc<dyn Resolve>) -> Self { |
| 64 | Self { resolver } |
| 65 | } |
| 66 | } |
| 67 | |
| 68 | impl Service<HyperName> for DynResolver { |
| 69 | type Response = Addrs; |
| 70 | type Error = BoxError; |
| 71 | type Future = Resolving; |
| 72 | |
| 73 | fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 74 | Poll::Ready(Ok(())) |
| 75 | } |
| 76 | |
| 77 | fn call(&mut self, name: HyperName) -> Self::Future { |
| 78 | self.resolver.resolve(Name(name)) |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | pub(crate) struct DnsResolverWithOverrides { |
| 83 | dns_resolver: Arc<dyn Resolve>, |
| 84 | overrides: Arc<HashMap<String, Vec<SocketAddr>>>, |
| 85 | } |
| 86 | |
| 87 | impl DnsResolverWithOverrides { |
| 88 | pub(crate) fn new( |
| 89 | dns_resolver: Arc<dyn Resolve>, |
| 90 | overrides: HashMap<String, Vec<SocketAddr>>, |
| 91 | ) -> Self { |
| 92 | DnsResolverWithOverrides { |
| 93 | dns_resolver, |
| 94 | overrides: Arc::new(data:overrides), |
| 95 | } |
| 96 | } |
| 97 | } |
| 98 | |
| 99 | impl Resolve for DnsResolverWithOverrides { |
| 100 | fn resolve(&self, name: Name) -> Resolving { |
| 101 | match self.overrides.get(name.as_str()) { |
| 102 | Some(dest: &Vec) => { |
| 103 | let addrs: Addrs = Box::new(dest.clone().into_iter()); |
| 104 | Box::pin(std::future::ready(Ok(addrs))) |
| 105 | } |
| 106 | None => self.dns_resolver.resolve(name), |
| 107 | } |
| 108 | } |
| 109 | } |
| 110 | |
mod sealed {
    use std::fmt;

    /// Error returned when a string cannot be parsed as a DNS `Name`.
    ///
    /// The type lives in a private module, and its only field is visible just
    /// to the parent module, so only the parent module can construct it.
    #[derive(Debug)]
    pub struct InvalidNameError {
        pub(super) _ext: (),
    }

    impl fmt::Display for InvalidNameError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("invalid DNS name")
        }
    }

    impl std::error::Error for InvalidNameError {}
}
| 127 | |