1 | use hyper_util::client::legacy::connect::dns::Name as HyperName; |
2 | use tower_service::Service; |
3 | |
4 | use std::collections::HashMap; |
5 | use std::future::Future; |
6 | use std::net::SocketAddr; |
7 | use std::pin::Pin; |
8 | use std::str::FromStr; |
9 | use std::sync::Arc; |
10 | use std::task::{Context, Poll}; |
11 | |
12 | use crate::error::BoxError; |
13 | |
/// Boxed, type-erased iterator over the [`SocketAddr`]s produced by a resolver.
///
/// The iterator is `Send`, so a finished resolution can be handed to another thread.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Send + 'static>;
16 | |
17 | /// Alias for the `Future` type returned by a DNS resolver. |
18 | pub type Resolving = Pin<Box<dyn Future<Output = Result<Addrs, BoxError>> + Send>>; |
19 | |
20 | /// Trait for customizing DNS resolution in reqwest. |
21 | pub trait Resolve: Send + Sync { |
22 | /// Performs DNS resolution on a `Name`. |
23 | /// The return type is a future containing an iterator of `SocketAddr`. |
24 | /// |
25 | /// It differs from `tower_service::Service<Name>` in several ways: |
26 | /// * It is assumed that `resolve` will always be ready to poll. |
27 | /// * It does not need a mutable reference to `self`. |
28 | /// * Since trait objects cannot make use of associated types, it requires |
29 | /// wrapping the returned `Future` and its contained `Iterator` with `Box`. |
30 | /// |
31 | /// Explicitly specified port in the URL will override any port in the resolved `SocketAddr`s. |
32 | /// Otherwise, port `0` will be replaced by the conventional port for the given scheme (e.g. 80 for http). |
33 | fn resolve(&self, name: Name) -> Resolving; |
34 | } |
35 | |
36 | /// A name that must be resolved to addresses. |
37 | #[derive (Debug)] |
38 | pub struct Name(pub(super) HyperName); |
39 | |
40 | impl Name { |
41 | /// View the name as a string. |
42 | pub fn as_str(&self) -> &str { |
43 | self.0.as_str() |
44 | } |
45 | } |
46 | |
47 | impl FromStr for Name { |
48 | type Err = sealed::InvalidNameError; |
49 | |
50 | fn from_str(host: &str) -> Result<Self, Self::Err> { |
51 | HyperName::from_str(host) |
52 | .map(Name) |
53 | .map_err(|_| sealed::InvalidNameError { _ext: () }) |
54 | } |
55 | } |
56 | |
57 | #[derive (Clone)] |
58 | pub(crate) struct DynResolver { |
59 | resolver: Arc<dyn Resolve>, |
60 | } |
61 | |
62 | impl DynResolver { |
63 | pub(crate) fn new(resolver: Arc<dyn Resolve>) -> Self { |
64 | Self { resolver } |
65 | } |
66 | } |
67 | |
68 | impl Service<HyperName> for DynResolver { |
69 | type Response = Addrs; |
70 | type Error = BoxError; |
71 | type Future = Resolving; |
72 | |
73 | fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
74 | Poll::Ready(Ok(())) |
75 | } |
76 | |
77 | fn call(&mut self, name: HyperName) -> Self::Future { |
78 | self.resolver.resolve(Name(name)) |
79 | } |
80 | } |
81 | |
82 | pub(crate) struct DnsResolverWithOverrides { |
83 | dns_resolver: Arc<dyn Resolve>, |
84 | overrides: Arc<HashMap<String, Vec<SocketAddr>>>, |
85 | } |
86 | |
87 | impl DnsResolverWithOverrides { |
88 | pub(crate) fn new( |
89 | dns_resolver: Arc<dyn Resolve>, |
90 | overrides: HashMap<String, Vec<SocketAddr>>, |
91 | ) -> Self { |
92 | DnsResolverWithOverrides { |
93 | dns_resolver, |
94 | overrides: Arc::new(data:overrides), |
95 | } |
96 | } |
97 | } |
98 | |
99 | impl Resolve for DnsResolverWithOverrides { |
100 | fn resolve(&self, name: Name) -> Resolving { |
101 | match self.overrides.get(name.as_str()) { |
102 | Some(dest: &Vec) => { |
103 | let addrs: Addrs = Box::new(dest.clone().into_iter()); |
104 | Box::pin(std::future::ready(Ok(addrs))) |
105 | } |
106 | None => self.dns_resolver.resolve(name), |
107 | } |
108 | } |
109 | } |
110 | |
mod sealed {
    use std::fmt;

    /// Error returned when a string cannot be parsed into a `Name`.
    ///
    /// `_ext` is visible only to the parent module (`pub(super)`). Code outside
    /// that module therefore cannot build this error with a struct literal.
    #[derive(Debug)]
    pub struct InvalidNameError {
        pub(super) _ext: (),
    }

    impl fmt::Display for InvalidNameError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("invalid DNS name")
        }
    }

    impl std::error::Error for InvalidNameError {}
}
127 | |