1 | use hyper::{ |
2 | rt::{Read, Write}, |
3 | Uri, |
4 | }; |
5 | use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioIo}; |
6 | use std::fmt; |
7 | use std::future::Future; |
8 | use std::pin::Pin; |
9 | use std::task::{Context, Poll}; |
10 | use tokio_native_tls::TlsConnector; |
11 | use tower_service::Service; |
12 | |
13 | use crate::stream::MaybeHttpsStream; |
14 | |
/// Boxed, thread-safe error type used as the connector's `Service::Error`.
///
/// `Box<dyn Trait>` already defaults to a `'static` object bound; it is spelled
/// out here only to make the requirement visible.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
16 | |
17 | /// A Connector for the `https` scheme. |
18 | #[derive (Clone)] |
19 | pub struct HttpsConnector<T> { |
20 | force_https: bool, |
21 | http: T, |
22 | tls: TlsConnector, |
23 | } |
24 | |
25 | impl HttpsConnector<HttpConnector> { |
26 | /// Construct a new `HttpsConnector`. |
27 | /// |
28 | /// This uses hyper's default `HttpConnector`, and default `TlsConnector`. |
29 | /// If you wish to use something besides the defaults, use `From::from`. |
30 | /// |
31 | /// # Note |
32 | /// |
33 | /// By default this connector will use plain HTTP if the URL provided uses |
34 | /// the HTTP scheme (eg: <http://example.com/>). |
35 | /// |
36 | /// If you would like to force the use of HTTPS then call `https_only(true)` |
37 | /// on the returned connector. |
38 | /// |
39 | /// # Panics |
40 | /// |
41 | /// This will panic if the underlying TLS context could not be created. |
42 | /// |
43 | /// To handle that error yourself, you can use the `HttpsConnector::from` |
44 | /// constructor after trying to make a `TlsConnector`. |
45 | #[must_use ] |
46 | pub fn new() -> Self { |
47 | native_tls::TlsConnector::new().map_or_else( |
48 | |e| panic!("HttpsConnector::new() failure: {}" , e), |
49 | |tls| HttpsConnector::new_(tls.into()), |
50 | ) |
51 | } |
52 | |
53 | fn new_(tls: TlsConnector) -> Self { |
54 | let mut http = HttpConnector::new(); |
55 | http.enforce_http(false); |
56 | HttpsConnector::from((http, tls)) |
57 | } |
58 | } |
59 | |
60 | impl<T: Default> Default for HttpsConnector<T> { |
61 | fn default() -> Self { |
62 | Self::new_with_connector(http:Default::default()) |
63 | } |
64 | } |
65 | |
66 | impl<T> HttpsConnector<T> { |
67 | /// Force the use of HTTPS when connecting. |
68 | /// |
69 | /// If a URL is not `https` when connecting, an error is returned. |
70 | pub fn https_only(&mut self, enable: bool) { |
71 | self.force_https = enable; |
72 | } |
73 | |
74 | /// With connector constructor |
75 | /// |
76 | /// # Panics |
77 | /// |
78 | /// This will panic if the underlying TLS context could not be created. |
79 | /// |
80 | /// To handle that error yourself, you can use the `HttpsConnector::from` |
81 | /// constructor after trying to make a `TlsConnector`. |
82 | pub fn new_with_connector(http: T) -> Self { |
83 | native_tls::TlsConnector::new().map_or_else( |
84 | |e| { |
85 | panic!( |
86 | "HttpsConnector::new_with_connector(<connector>) failure: {}" , |
87 | e |
88 | ) |
89 | }, |
90 | |tls| HttpsConnector::from((http, tls.into())), |
91 | ) |
92 | } |
93 | } |
94 | |
95 | impl<T> From<(T, TlsConnector)> for HttpsConnector<T> { |
96 | fn from(args: (T, TlsConnector)) -> HttpsConnector<T> { |
97 | HttpsConnector { |
98 | force_https: false, |
99 | http: args.0, |
100 | tls: args.1, |
101 | } |
102 | } |
103 | } |
104 | |
105 | impl<T: fmt::Debug> fmt::Debug for HttpsConnector<T> { |
106 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
107 | f&mut DebugStruct<'_, '_>.debug_struct("HttpsConnector" ) |
108 | .field("force_https" , &self.force_https) |
109 | .field(name:"http" , &self.http) |
110 | .finish_non_exhaustive() |
111 | } |
112 | } |
113 | |
114 | impl<T> Service<Uri> for HttpsConnector<T> |
115 | where |
116 | T: Service<Uri>, |
117 | T::Response: Read + Write + Send + Unpin, |
118 | T::Future: Send + 'static, |
119 | T::Error: Into<BoxError>, |
120 | { |
121 | type Response = MaybeHttpsStream<T::Response>; |
122 | type Error = BoxError; |
123 | type Future = HttpsConnecting<T::Response>; |
124 | |
125 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
126 | match self.http.poll_ready(cx) { |
127 | Poll::Ready(Ok(())) => Poll::Ready(Ok(())), |
128 | Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), |
129 | Poll::Pending => Poll::Pending, |
130 | } |
131 | } |
132 | |
133 | fn call(&mut self, dst: Uri) -> Self::Future { |
134 | let is_https = dst.scheme_str() == Some("https" ); |
135 | // Early abort if HTTPS is forced but can't be used |
136 | if !is_https && self.force_https { |
137 | return err(ForceHttpsButUriNotHttps.into()); |
138 | } |
139 | |
140 | let host = dst |
141 | .host() |
142 | .unwrap_or("" ) |
143 | .trim_matches(|c| c == '[' || c == ']' ) |
144 | .to_owned(); |
145 | let connecting = self.http.call(dst); |
146 | |
147 | let tls_connector = self.tls.clone(); |
148 | |
149 | let fut = async move { |
150 | let tcp = connecting.await.map_err(Into::into)?; |
151 | |
152 | let maybe = if is_https { |
153 | let stream = TokioIo::new(tcp); |
154 | |
155 | let tls = TokioIo::new(tls_connector.connect(&host, stream).await?); |
156 | MaybeHttpsStream::Https(tls) |
157 | } else { |
158 | MaybeHttpsStream::Http(tcp) |
159 | }; |
160 | Ok(maybe) |
161 | }; |
162 | HttpsConnecting(Box::pin(fut)) |
163 | } |
164 | } |
165 | |
166 | fn err<T>(e: BoxError) -> HttpsConnecting<T> { |
167 | HttpsConnecting(Box::pin(async { Err(e) })) |
168 | } |
169 | |
170 | type BoxedFut<T> = Pin<Box<dyn Future<Output = Result<MaybeHttpsStream<T>, BoxError>> + Send>>; |
171 | |
172 | /// A Future representing work to connect to a URL, and a TLS handshake. |
173 | pub struct HttpsConnecting<T>(BoxedFut<T>); |
174 | |
175 | impl<T: Read + Write + Unpin> Future for HttpsConnecting<T> { |
176 | type Output = Result<MaybeHttpsStream<T>, BoxError>; |
177 | |
178 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
179 | Pin::new(&mut self.0).poll(cx) |
180 | } |
181 | } |
182 | |
183 | impl<T> fmt::Debug for HttpsConnecting<T> { |
184 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
185 | f.pad("HttpsConnecting" ) |
186 | } |
187 | } |
188 | |
// ===== Custom Errors =====

/// Error produced when `https_only` is enabled but the requested URI does not
/// use the `https` scheme.
#[derive(Debug)]
struct ForceHttpsButUriNotHttps;

impl fmt::Display for ForceHttpsButUriNotHttps {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("https required but URI was not https")
    }
}

impl std::error::Error for ForceHttpsButUriNotHttps {}
201 | |