| 1 | use std::future::Future; |
| 2 | use std::pin::Pin; |
| 3 | use std::sync::Arc; |
| 4 | use std::task::{Context, Poll}; |
| 5 | use std::{fmt, io}; |
| 6 | |
| 7 | use http::Uri; |
| 8 | use hyper::rt; |
| 9 | use hyper_util::client::legacy::connect::Connection; |
| 10 | use hyper_util::rt::TokioIo; |
| 11 | use pki_types::ServerName; |
| 12 | use tokio_rustls::TlsConnector; |
| 13 | use tower_service::Service; |
| 14 | |
| 15 | use crate::stream::MaybeHttpsStream; |
| 16 | |
| 17 | pub(crate) mod builder; |
| 18 | |
| 19 | type BoxError = Box<dyn std::error::Error + Send + Sync>; |
| 20 | |
| 21 | /// A Connector for the `https` scheme. |
| 22 | #[derive (Clone)] |
| 23 | pub struct HttpsConnector<T> { |
| 24 | force_https: bool, |
| 25 | http: T, |
| 26 | tls_config: Arc<rustls::ClientConfig>, |
| 27 | server_name_resolver: Arc<dyn ResolveServerName + Sync + Send>, |
| 28 | } |
| 29 | |
| 30 | impl<T> HttpsConnector<T> { |
| 31 | /// Creates a [`crate::HttpsConnectorBuilder`] to configure a `HttpsConnector`. |
| 32 | /// |
| 33 | /// This is the same as [`crate::HttpsConnectorBuilder::new()`]. |
| 34 | pub fn builder() -> builder::ConnectorBuilder<builder::WantsTlsConfig> { |
| 35 | builder::ConnectorBuilder::new() |
| 36 | } |
| 37 | |
| 38 | /// Force the use of HTTPS when connecting. |
| 39 | /// |
| 40 | /// If a URL is not `https` when connecting, an error is returned. |
| 41 | pub fn enforce_https(&mut self) { |
| 42 | self.force_https = true; |
| 43 | } |
| 44 | } |
| 45 | |
| 46 | impl<T> Service<Uri> for HttpsConnector<T> |
| 47 | where |
| 48 | T: Service<Uri>, |
| 49 | T::Response: Connection + rt::Read + rt::Write + Send + Unpin + 'static, |
| 50 | T::Future: Send + 'static, |
| 51 | T::Error: Into<BoxError>, |
| 52 | { |
| 53 | type Response = MaybeHttpsStream<T::Response>; |
| 54 | type Error = BoxError; |
| 55 | |
| 56 | #[allow (clippy::type_complexity)] |
| 57 | type Future = |
| 58 | Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T::Response>, BoxError>> + Send>>; |
| 59 | |
| 60 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 61 | match self.http.poll_ready(cx) { |
| 62 | Poll::Ready(Ok(())) => Poll::Ready(Ok(())), |
| 63 | Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), |
| 64 | Poll::Pending => Poll::Pending, |
| 65 | } |
| 66 | } |
| 67 | |
| 68 | fn call(&mut self, dst: Uri) -> Self::Future { |
| 69 | // dst.scheme() would need to derive Eq to be matchable; |
| 70 | // use an if cascade instead |
| 71 | match dst.scheme() { |
| 72 | Some(scheme) if scheme == &http::uri::Scheme::HTTP && !self.force_https => { |
| 73 | let future = self.http.call(dst); |
| 74 | return Box::pin(async move { |
| 75 | Ok(MaybeHttpsStream::Http(future.await.map_err(Into::into)?)) |
| 76 | }); |
| 77 | } |
| 78 | Some(scheme) if scheme != &http::uri::Scheme::HTTPS => { |
| 79 | let message = format!("unsupported scheme {scheme}" ); |
| 80 | return Box::pin(async move { |
| 81 | Err(io::Error::new(io::ErrorKind::Other, message).into()) |
| 82 | }); |
| 83 | } |
| 84 | Some(_) => {} |
| 85 | None => { |
| 86 | return Box::pin(async move { |
| 87 | Err(io::Error::new(io::ErrorKind::Other, "missing scheme" ).into()) |
| 88 | }) |
| 89 | } |
| 90 | }; |
| 91 | |
| 92 | let cfg = self.tls_config.clone(); |
| 93 | let hostname = match self.server_name_resolver.resolve(&dst) { |
| 94 | Ok(hostname) => hostname, |
| 95 | Err(e) => { |
| 96 | return Box::pin(async move { Err(e) }); |
| 97 | } |
| 98 | }; |
| 99 | |
| 100 | let connecting_future = self.http.call(dst); |
| 101 | Box::pin(async move { |
| 102 | let tcp = connecting_future |
| 103 | .await |
| 104 | .map_err(Into::into)?; |
| 105 | Ok(MaybeHttpsStream::Https(TokioIo::new( |
| 106 | TlsConnector::from(cfg) |
| 107 | .connect(hostname, TokioIo::new(tcp)) |
| 108 | .await |
| 109 | .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?, |
| 110 | ))) |
| 111 | }) |
| 112 | } |
| 113 | } |
| 114 | |
| 115 | impl<H, C> From<(H, C)> for HttpsConnector<H> |
| 116 | where |
| 117 | C: Into<Arc<rustls::ClientConfig>>, |
| 118 | { |
| 119 | fn from((http: H, cfg: C): (H, C)) -> Self { |
| 120 | Self { |
| 121 | force_https: false, |
| 122 | http, |
| 123 | tls_config: cfg.into(), |
| 124 | server_name_resolver: Arc::new(data:DefaultServerNameResolver::default()), |
| 125 | } |
| 126 | } |
| 127 | } |
| 128 | |
| 129 | impl<T> fmt::Debug for HttpsConnector<T> { |
| 130 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
| 131 | f&mut DebugStruct<'_, '_>.debug_struct("HttpsConnector" ) |
| 132 | .field(name:"force_https" , &self.force_https) |
| 133 | .finish() |
| 134 | } |
| 135 | } |
| 136 | |
| 137 | /// The default server name resolver, which uses the hostname in the URI. |
| 138 | #[derive (Default)] |
| 139 | pub struct DefaultServerNameResolver(()); |
| 140 | |
| 141 | impl ResolveServerName for DefaultServerNameResolver { |
| 142 | fn resolve( |
| 143 | &self, |
| 144 | uri: &Uri, |
| 145 | ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> { |
| 146 | let mut hostname: &str = uri.host().unwrap_or_default(); |
| 147 | |
| 148 | // Remove square brackets around IPv6 address. |
| 149 | if let Some(trimmed: &str) = hostnameOption<&str> |
| 150 | .strip_prefix('[' ) |
| 151 | .and_then(|h: &str| h.strip_suffix(']' )) |
| 152 | { |
| 153 | hostname = trimmed; |
| 154 | } |
| 155 | |
| 156 | ServerName::try_from(hostname.to_string()).map_err(|e: InvalidDnsNameError| Box::new(e) as _) |
| 157 | } |
| 158 | } |
| 159 | |
| 160 | /// A server name resolver which always returns the same fixed name. |
| 161 | pub struct FixedServerNameResolver { |
| 162 | name: ServerName<'static>, |
| 163 | } |
| 164 | |
| 165 | impl FixedServerNameResolver { |
| 166 | /// Creates a new resolver returning the specified name. |
| 167 | pub fn new(name: ServerName<'static>) -> Self { |
| 168 | Self { name } |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | impl ResolveServerName for FixedServerNameResolver { |
| 173 | fn resolve( |
| 174 | &self, |
| 175 | _: &Uri, |
| 176 | ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> { |
| 177 | Ok(self.name.clone()) |
| 178 | } |
| 179 | } |
| 180 | |
| 181 | impl<F, E> ResolveServerName for F |
| 182 | where |
| 183 | F: Fn(&Uri) -> Result<ServerName<'static>, E>, |
| 184 | E: Into<Box<dyn std::error::Error + Sync + Send>>, |
| 185 | { |
| 186 | fn resolve( |
| 187 | &self, |
| 188 | uri: &Uri, |
| 189 | ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>> { |
| 190 | self(uri).map_err(op:Into::into) |
| 191 | } |
| 192 | } |
| 193 | |
| 194 | /// A trait implemented by types that can resolve a [`ServerName`] for a request. |
| 195 | pub trait ResolveServerName { |
| 196 | /// Maps a [`Uri`] into a [`ServerName`]. |
| 197 | fn resolve( |
| 198 | &self, |
| 199 | uri: &Uri, |
| 200 | ) -> Result<ServerName<'static>, Box<dyn std::error::Error + Sync + Send>>; |
| 201 | } |
| 202 | |
#[cfg(all(
    test,
    any(feature = "ring", feature = "aws-lc-rs"),
    any(
        feature = "rustls-native-certs",
        feature = "webpki-roots",
        feature = "rustls-platform-verifier",
    )
))]
mod tests {
    use std::future::poll_fn;

    use http::Uri;
    use hyper_util::rt::TokioIo;
    use tokio::net::TcpStream;
    use tower_service::Service;

    use super::*;
    use crate::{ConfigBuilderExt, HttpsConnectorBuilder, MaybeHttpsStream};

    /// Which schemes the connector under test is configured to accept.
    enum Allow {
        Https,
        Any,
    }

    /// Scheme of the URI the test connects to.
    enum Scheme {
        Https,
        Http,
    }

    #[tokio::test]
    async fn connects_https() {
        let outcome = connect(Allow::Any, Scheme::Https).await;
        outcome.unwrap();
    }

    #[tokio::test]
    async fn connects_http() {
        let outcome = connect(Allow::Any, Scheme::Http).await;
        outcome.unwrap();
    }

    #[tokio::test]
    async fn connects_https_only() {
        let outcome = connect(Allow::Https, Scheme::Https).await;
        outcome.unwrap();
    }

    #[tokio::test]
    async fn enforces_https_only() {
        let failure = connect(Allow::Https, Scheme::Http).await.unwrap_err();
        let message = failure.to_string();

        assert_eq!(message, "unsupported scheme http");
    }

    /// Builds a connector with the given scheme policy and connects to
    /// `google.com` using the requested scheme.
    async fn connect(
        allow: Allow,
        scheme: Scheme,
    ) -> Result<MaybeHttpsStream<TokioIo<TcpStream>>, BoxError> {
        // Pick the root-certificate source according to the enabled feature.
        let tls_builder = rustls::ClientConfig::builder();
        cfg_if::cfg_if! {
            if #[cfg(feature = "rustls-platform-verifier")] {
                let tls_builder = tls_builder.with_platform_verifier();
            } else if #[cfg(feature = "rustls-native-certs")] {
                let tls_builder = tls_builder.with_native_roots().unwrap();
            } else if #[cfg(feature = "webpki-roots")] {
                let tls_builder = tls_builder.with_webpki_roots();
            }
        }
        let tls_config = tls_builder.with_no_client_auth();

        let connector_builder = HttpsConnectorBuilder::new().with_tls_config(tls_config);
        let with_schemes = match allow {
            Allow::Https => connector_builder.https_only(),
            Allow::Any => connector_builder.https_or_http(),
        };
        let mut service = with_schemes.enable_http1().build();

        poll_fn(|cx| service.poll_ready(cx)).await?;

        let target = match scheme {
            Scheme::Https => "https://google.com",
            Scheme::Http => "http://google.com",
        };
        service.call(Uri::from_static(target)).await
    }
}
| 297 | |