1 | use std::fmt; |
2 | use std::future::Future; |
3 | use std::pin::Pin; |
4 | use std::task::{Context, Poll}; |
5 | |
6 | use hyper::{client::connect::HttpConnector, service::Service, Uri}; |
7 | use tokio::io::{AsyncRead, AsyncWrite}; |
8 | use tokio_native_tls::TlsConnector; |
9 | |
10 | use crate::stream::MaybeHttpsStream; |
11 | |
/// Boxed, type-erased `Send + Sync` error; used as this connector's `Service::Error`
/// and as the error type of `HttpsConnecting`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
13 | |
14 | /// A Connector for the `https` scheme. |
15 | #[derive (Clone)] |
16 | pub struct HttpsConnector<T> { |
17 | force_https: bool, |
18 | http: T, |
19 | tls: TlsConnector, |
20 | } |
21 | |
22 | impl HttpsConnector<HttpConnector> { |
23 | /// Construct a new HttpsConnector. |
24 | /// |
25 | /// This uses hyper's default `HttpConnector`, and default `TlsConnector`. |
26 | /// If you wish to use something besides the defaults, use `From::from`. |
27 | /// |
28 | /// # Note |
29 | /// |
30 | /// By default this connector will use plain HTTP if the URL provded uses |
31 | /// the HTTP scheme (eg: http://example.com/). |
32 | /// |
33 | /// If you would like to force the use of HTTPS then call https_only(true) |
34 | /// on the returned connector. |
35 | /// |
36 | /// # Panics |
37 | /// |
38 | /// This will panic if the underlying TLS context could not be created. |
39 | /// |
40 | /// To handle that error yourself, you can use the `HttpsConnector::from` |
41 | /// constructor after trying to make a `TlsConnector`. |
42 | pub fn new() -> Self { |
43 | native_tls::TlsConnector::new() |
44 | .map(|tls| HttpsConnector::new_(tls.into())) |
45 | .unwrap_or_else(|e| panic!("HttpsConnector::new() failure: {}" , e)) |
46 | } |
47 | |
48 | fn new_(tls: TlsConnector) -> Self { |
49 | let mut http = HttpConnector::new(); |
50 | http.enforce_http(false); |
51 | HttpsConnector::from((http, tls)) |
52 | } |
53 | } |
54 | |
55 | impl<T: Default> Default for HttpsConnector<T> { |
56 | fn default() -> Self { |
57 | Self::new_with_connector(http:Default::default()) |
58 | } |
59 | } |
60 | |
61 | impl<T> HttpsConnector<T> { |
62 | /// Force the use of HTTPS when connecting. |
63 | /// |
64 | /// If a URL is not `https` when connecting, an error is returned. |
65 | pub fn https_only(&mut self, enable: bool) { |
66 | self.force_https = enable; |
67 | } |
68 | |
69 | /// With connector constructor |
70 | /// |
71 | pub fn new_with_connector(http: T) -> Self { |
72 | native_tls::TlsConnector::new() |
73 | .map(|tls| HttpsConnector::from((http, tls.into()))) |
74 | .unwrap_or_else(|e: Error| { |
75 | panic!( |
76 | "HttpsConnector::new_with_connector(<connector>) failure: {}" , |
77 | e |
78 | ) |
79 | }) |
80 | } |
81 | } |
82 | |
83 | impl<T> From<(T, TlsConnector)> for HttpsConnector<T> { |
84 | fn from(args: (T, TlsConnector)) -> HttpsConnector<T> { |
85 | HttpsConnector { |
86 | force_https: false, |
87 | http: args.0, |
88 | tls: args.1, |
89 | } |
90 | } |
91 | } |
92 | |
93 | impl<T: fmt::Debug> fmt::Debug for HttpsConnector<T> { |
94 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
95 | f&mut DebugStruct<'_, '_>.debug_struct("HttpsConnector" ) |
96 | .field("force_https" , &self.force_https) |
97 | .field(name:"http" , &self.http) |
98 | .finish() |
99 | } |
100 | } |
101 | |
102 | impl<T> Service<Uri> for HttpsConnector<T> |
103 | where |
104 | T: Service<Uri>, |
105 | T::Response: AsyncRead + AsyncWrite + Send + Unpin, |
106 | T::Future: Send + 'static, |
107 | T::Error: Into<BoxError>, |
108 | { |
109 | type Response = MaybeHttpsStream<T::Response>; |
110 | type Error = BoxError; |
111 | type Future = HttpsConnecting<T::Response>; |
112 | |
113 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
114 | match self.http.poll_ready(cx) { |
115 | Poll::Ready(Ok(())) => Poll::Ready(Ok(())), |
116 | Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), |
117 | Poll::Pending => Poll::Pending, |
118 | } |
119 | } |
120 | |
121 | fn call(&mut self, dst: Uri) -> Self::Future { |
122 | let is_https = dst.scheme_str() == Some("https" ); |
123 | // Early abort if HTTPS is forced but can't be used |
124 | if !is_https && self.force_https { |
125 | return err(ForceHttpsButUriNotHttps.into()); |
126 | } |
127 | |
128 | let host = dst |
129 | .host() |
130 | .unwrap_or("" ) |
131 | .trim_matches(|c| c == '[' || c == ']' ) |
132 | .to_owned(); |
133 | let connecting = self.http.call(dst); |
134 | let tls = self.tls.clone(); |
135 | let fut = async move { |
136 | let tcp = connecting.await.map_err(Into::into)?; |
137 | let maybe = if is_https { |
138 | let tls = tls.connect(&host, tcp).await?; |
139 | MaybeHttpsStream::Https(tls) |
140 | } else { |
141 | MaybeHttpsStream::Http(tcp) |
142 | }; |
143 | Ok(maybe) |
144 | }; |
145 | HttpsConnecting(Box::pin(fut)) |
146 | } |
147 | } |
148 | |
149 | fn err<T>(e: BoxError) -> HttpsConnecting<T> { |
150 | HttpsConnecting(Box::pin(async { Err(e) })) |
151 | } |
152 | |
/// Heap-allocated, pinned, `Send` future that yields either a plain or a
/// TLS-wrapped stream, or a `BoxError`.
type BoxedFut<T> = Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T>, BoxError>> + Send>>;
154 | |
/// A Future representing work to connect to a URL, and a TLS handshake.
///
/// This is the future returned by `HttpsConnector`'s `Service::call`. The TLS
/// handshake only runs when the URI scheme is `https`. When HTTPS is forced and
/// the URI is not `https`, the future resolves to an error without connecting.
pub struct HttpsConnecting<T>(BoxedFut<T>);
157 | |
158 | impl<T: AsyncRead + AsyncWrite + Unpin> Future for HttpsConnecting<T> { |
159 | type Output = Result<MaybeHttpsStream<T>, BoxError>; |
160 | |
161 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
162 | Pin::new(&mut self.0).poll(cx) |
163 | } |
164 | } |
165 | |
166 | impl<T> fmt::Debug for HttpsConnecting<T> { |
167 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
168 | f.pad("HttpsConnecting" ) |
169 | } |
170 | } |
171 | |
172 | // ===== Custom Errors ===== |
173 | |
/// Error produced when HTTPS is forced (see `HttpsConnector::https_only`) but
/// the requested URI does not use the `https` scheme.
#[derive(Debug)]
struct ForceHttpsButUriNotHttps;

impl fmt::Display for ForceHttpsButUriNotHttps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("https required but URI was not https")
    }
}

impl std::error::Error for ForceHttpsButUriNotHttps {}
184 | |