1 | use std::future::Future; |
2 | use std::pin::Pin; |
3 | use std::sync::Arc; |
4 | use std::task::{Context, Poll}; |
5 | use std::{fmt, io}; |
6 | |
7 | use http::Uri; |
8 | use hyper::rt; |
9 | use hyper_util::client::legacy::connect::Connection; |
10 | use hyper_util::rt::TokioIo; |
11 | use pki_types::ServerName; |
12 | use tokio_rustls::TlsConnector; |
13 | use tower_service::Service; |
14 | |
15 | use crate::stream::MaybeHttpsStream; |
16 | |
/// Builder for [`HttpsConnector`]; provides `ConnectorBuilder` and the
/// `WantsTlsConfig` state that [`HttpsConnector::builder`] starts from.
pub(crate) mod builder;
18 | |
/// Boxed, thread-safe (`Send + Sync`) error type used as the error of the
/// connector's `Service` implementation.
type BoxError = Box<dyn std::error::Error + Sync + Send>;
20 | |
21 | /// A Connector for the `https` scheme. |
22 | #[derive (Clone)] |
23 | pub struct HttpsConnector<T> { |
24 | force_https: bool, |
25 | http: T, |
26 | tls_config: Arc<rustls::ClientConfig>, |
27 | server_name_resolver: Arc<dyn ResolveServerName + Sync + Send>, |
28 | } |
29 | |
30 | impl<T> HttpsConnector<T> { |
31 | /// Creates a [`crate::HttpsConnectorBuilder`] to configure a `HttpsConnector`. |
32 | /// |
33 | /// This is the same as [`crate::HttpsConnectorBuilder::new()`]. |
34 | pub fn builder() -> builder::ConnectorBuilder<builder::WantsTlsConfig> { |
35 | builder::ConnectorBuilder::new() |
36 | } |
37 | |
38 | /// Force the use of HTTPS when connecting. |
39 | /// |
40 | /// If a URL is not `https` when connecting, an error is returned. |
41 | pub fn enforce_https(&mut self) { |
42 | self.force_https = true; |
43 | } |
44 | } |
45 | |
46 | impl<T> Service<Uri> for HttpsConnector<T> |
47 | where |
48 | T: Service<Uri>, |
49 | T::Response: Connection + rt::Read + rt::Write + Send + Unpin + 'static, |
50 | T::Future: Send + 'static, |
51 | T::Error: Into<BoxError>, |
52 | { |
53 | type Response = MaybeHttpsStream<T::Response>; |
54 | type Error = BoxError; |
55 | |
56 | #[allow (clippy::type_complexity)] |
57 | type Future = |
58 | Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T::Response>, BoxError>> + Send>>; |
59 | |
60 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
61 | match self.http.poll_ready(cx) { |
62 | Poll::Ready(Ok(())) => Poll::Ready(Ok(())), |
63 | Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), |
64 | Poll::Pending => Poll::Pending, |
65 | } |
66 | } |
67 | |
68 | fn call(&mut self, dst: Uri) -> Self::Future { |
69 | // dst.scheme() would need to derive Eq to be matchable; |
70 | // use an if cascade instead |
71 | match dst.scheme() { |
72 | Some(scheme) if scheme == &http::uri::Scheme::HTTP && !self.force_https => { |
73 | let future = self.http.call(dst); |
74 | return Box::pin(async move { |
75 | Ok(MaybeHttpsStream::Http(future.await.map_err(Into::into)?)) |
76 | }); |
77 | } |
78 | Some(scheme) if scheme != &http::uri::Scheme::HTTPS => { |
79 | let message = format!("unsupported scheme {scheme}" ); |
80 | return Box::pin(async move { |
81 | Err(io::Error::new(io::ErrorKind::Other, message).into()) |
82 | }); |
83 | } |
84 | Some(_) => {} |
85 | None => { |
86 | return Box::pin(async move { |
87 | Err(io::Error::new(io::ErrorKind::Other, "missing scheme" ).into()) |
88 | }) |
89 | } |
90 | }; |
91 | |
92 | let cfg = self.tls_config.clone(); |
93 | let hostname = match self.server_name_resolver.resolve(&dst) { |
94 | Ok(hostname) => hostname, |
95 | Err(e) => { |
96 | return Box::pin(async move { Err(e) }); |
97 | } |
98 | }; |
99 | |
100 | let connecting_future = self.http.call(dst); |
101 | Box::pin(async move { |
102 | let tcp = connecting_future |
103 | .await |
104 | .map_err(Into::into)?; |
105 | Ok(MaybeHttpsStream::Https(TokioIo::new( |
106 | TlsConnector::from(cfg) |
107 | .connect(hostname, TokioIo::new(tcp)) |
108 | .await |
109 | .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?, |
110 | ))) |
111 | }) |
112 | } |
113 | } |
114 | |
115 | impl<H, C> From<(H, C)> for HttpsConnector<H> |
116 | where |
117 | C: Into<Arc<rustls::ClientConfig>>, |
118 | { |
119 | fn from((http: H, cfg: C): (H, C)) -> Self { |
120 | Self { |
121 | force_https: false, |
122 | http, |
123 | tls_config: cfg.into(), |
124 | server_name_resolver: Arc::new(data:DefaultServerNameResolver::default()), |
125 | } |
126 | } |
127 | } |
128 | |
129 | impl<T> fmt::Debug for HttpsConnector<T> { |
130 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
131 | f&mut DebugStruct<'_, '_>.debug_struct("HttpsConnector" ) |
132 | .field(name:"force_https" , &self.force_https) |
133 | .finish() |
134 | } |
135 | } |
136 | |
/// The default server name resolver, which uses the hostname in the URI
/// (with the square brackets of an IPv6 literal removed).
#[derive(Clone, Debug, Default)]
pub struct DefaultServerNameResolver(());
140 | |
141 | impl ResolveServerName for DefaultServerNameResolver { |
142 | fn resolve( |
143 | &self, |
144 | uri: &Uri, |
145 | ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> { |
146 | let mut hostname: &str = uri.host().unwrap_or_default(); |
147 | |
148 | // Remove square brackets around IPv6 address. |
149 | if let Some(trimmed: &str) = hostnameOption<&str> |
150 | .strip_prefix('[' ) |
151 | .and_then(|h: &str| h.strip_suffix(']' )) |
152 | { |
153 | hostname = trimmed; |
154 | } |
155 | |
156 | ServerName::try_from(hostname.to_string()).map_err(|e: InvalidDnsNameError| Box::new(e) as _) |
157 | } |
158 | } |
159 | |
160 | /// A server name resolver which always returns the same fixed name. |
161 | pub struct FixedServerNameResolver { |
162 | name: ServerName<'static>, |
163 | } |
164 | |
165 | impl FixedServerNameResolver { |
166 | /// Creates a new resolver returning the specified name. |
167 | pub fn new(name: ServerName<'static>) -> Self { |
168 | Self { name } |
169 | } |
170 | } |
171 | |
172 | impl ResolveServerName for FixedServerNameResolver { |
173 | fn resolve( |
174 | &self, |
175 | _: &Uri, |
176 | ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> { |
177 | Ok(self.name.clone()) |
178 | } |
179 | } |
180 | |
181 | impl<F, E> ResolveServerName for F |
182 | where |
183 | F: Fn(&Uri) -> Result<ServerName<'static>, E>, |
184 | E: Into<Box<dyn std::error::Error + Sync + Send>>, |
185 | { |
186 | fn resolve( |
187 | &self, |
188 | uri: &Uri, |
189 | ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> { |
190 | self(uri).map_err(op:Into::into) |
191 | } |
192 | } |
193 | |
194 | /// A trait implemented by types that can resolve a [`ServerName`] for a request. |
195 | pub trait ResolveServerName { |
196 | /// Maps a [`Uri`] into a [`ServerName`]. |
197 | fn resolve( |
198 | &self, |
199 | uri: &Uri, |
200 | ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>>; |
201 | } |
202 | |
#[cfg(all(
    test,
    any(feature = "ring", feature = "aws-lc-rs"),
    any(
        feature = "rustls-native-certs",
        feature = "webpki-roots",
        feature = "rustls-platform-verifier",
    )
))]
mod tests {
    //! Live-network tests: every case connects to google.com through an
    //! `HttpsConnector`, so outbound network access is required.

    use std::future::poll_fn;

    use http::Uri;
    use hyper_util::rt::TokioIo;
    use tokio::net::TcpStream;
    use tower_service::Service;

    use super::*;
    use crate::{ConfigBuilderExt, HttpsConnectorBuilder, MaybeHttpsStream};

    #[tokio::test]
    async fn connects_https() {
        connect(Allow::Any, Scheme::Https).await.unwrap();
    }

    #[tokio::test]
    async fn connects_http() {
        connect(Allow::Any, Scheme::Http).await.unwrap();
    }

    #[tokio::test]
    async fn connects_https_only() {
        connect(Allow::Https, Scheme::Https).await.unwrap();
    }

    #[tokio::test]
    async fn enforces_https_only() {
        let outcome = connect(Allow::Https, Scheme::Http).await;
        let message = outcome.unwrap_err().to_string();
        assert_eq!(message, "unsupported scheme http");
    }

    /// Builds a connector restricted by `allow`, waits until it is ready and
    /// connects to google.com using `scheme`.
    async fn connect(
        allow: Allow,
        scheme: Scheme,
    ) -> Result<MaybeHttpsStream<TokioIo<TcpStream>>, BoxError> {
        let config_builder = rustls::ClientConfig::builder();
        // Use whichever certificate source the enabled features provide.
        cfg_if::cfg_if! {
            if #[cfg(feature = "rustls-platform-verifier")] {
                let config_builder = config_builder.with_platform_verifier();
            } else if #[cfg(feature = "rustls-native-certs")] {
                let config_builder = config_builder.with_native_roots().unwrap();
            } else if #[cfg(feature = "webpki-roots")] {
                let config_builder = config_builder.with_webpki_roots();
            }
        }
        let tls_config = config_builder.with_no_client_auth();

        let builder = HttpsConnectorBuilder::new().with_tls_config(tls_config);
        let builder = match allow {
            Allow::Https => builder.https_only(),
            Allow::Any => builder.https_or_http(),
        };
        let mut service = builder.enable_http1().build();

        poll_fn(|cx| service.poll_ready(cx)).await?;
        service.call(scheme.uri()).await
    }

    /// Which URI schemes the connector under test accepts.
    enum Allow {
        Https,
        Any,
    }

    /// Which scheme the request URI uses.
    enum Scheme {
        Https,
        Http,
    }

    impl Scheme {
        /// The google.com URI for this scheme.
        fn uri(&self) -> Uri {
            let text = match self {
                Self::Https => "https://google.com",
                Self::Http => "http://google.com",
            };
            Uri::from_static(text)
        }
    }
}
297 | |