1 | #[cfg (feature = "__tls" )] |
2 | use http::header::HeaderValue; |
3 | use http::uri::{Authority, Scheme}; |
4 | use http::Uri; |
5 | use hyper::rt::{Read, ReadBufCursor, Write}; |
6 | use hyper_util::client::legacy::connect::{Connected, Connection}; |
7 | #[cfg (any(feature = "socks" , feature = "__tls" ))] |
8 | use hyper_util::rt::TokioIo; |
9 | #[cfg (feature = "default-tls" )] |
10 | use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; |
11 | use pin_project_lite::pin_project; |
12 | use tower::util::{BoxCloneSyncServiceLayer, MapRequestLayer}; |
13 | use tower::{timeout::TimeoutLayer, util::BoxCloneSyncService, ServiceBuilder}; |
14 | use tower_service::Service; |
15 | |
16 | use std::future::Future; |
17 | use std::io::{self, IoSlice}; |
18 | use std::net::IpAddr; |
19 | use std::pin::Pin; |
20 | use std::sync::Arc; |
21 | use std::task::{Context, Poll}; |
22 | use std::time::Duration; |
23 | |
24 | #[cfg (feature = "default-tls" )] |
25 | use self::native_tls_conn::NativeTlsConn; |
26 | #[cfg (feature = "__rustls" )] |
27 | use self::rustls_tls_conn::RustlsTlsConn; |
28 | use crate::dns::DynResolver; |
29 | use crate::error::{cast_to_internal_error, BoxError}; |
30 | use crate::proxy::{Proxy, ProxyScheme}; |
31 | use sealed::{Conn, Unnameable}; |
32 | |
/// The plain TCP connector from `hyper-util`, parameterized over this crate's
/// pluggable DNS resolver (`DynResolver`). Every `Inner` transport variant wraps one.
pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector<DynResolver>;
34 | |
/// Service that turns a destination `Uri` into a [`Conn`].
///
/// Two shapes exist so that the common case (no user-provided layers) keeps
/// concrete types and avoids type erasure; see `ConnectorBuilder::build`.
#[derive (Clone)]
pub(crate) enum Connector {
    // base service, with or without an embedded timeout
    Simple(ConnectorService),
    // at least one custom layer along with maybe an outer timeout layer
    // from `builder.connect_timeout()`
    WithLayers(BoxCloneSyncService<Unnameable, Conn, BoxError>),
}
43 | |
44 | impl Service<Uri> for Connector { |
45 | type Response = Conn; |
46 | type Error = BoxError; |
47 | type Future = Connecting; |
48 | |
49 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
50 | match self { |
51 | Connector::Simple(service: &mut ConnectorService) => service.poll_ready(cx), |
52 | Connector::WithLayers(service: &mut BoxCloneSyncService) => service.poll_ready(cx), |
53 | } |
54 | } |
55 | |
56 | fn call(&mut self, dst: Uri) -> Self::Future { |
57 | match self { |
58 | Connector::Simple(service: &mut ConnectorService) => service.call(req:dst), |
59 | Connector::WithLayers(service: &mut BoxCloneSyncService) => service.call(req:Unnameable(dst)), |
60 | } |
61 | } |
62 | } |
63 | |
/// Type-erased connector service as seen by user-provided layers: it accepts the
/// sealed `Unnameable` request and yields a `Conn`.
pub(crate) type BoxedConnectorService = BoxCloneSyncService<Unnameable, Conn, BoxError>;

/// Type-erased layer that wraps a `BoxedConnectorService`; `ConnectorBuilder::build`
/// applies a list of these, in order, around the base service.
pub(crate) type BoxedConnectorLayer =
    BoxCloneSyncServiceLayer<BoxedConnectorService, Unnameable, Conn, BoxError>;
68 | |
/// Collects connector configuration; `build` turns it into a [`Connector`].
pub(crate) struct ConnectorBuilder {
    // Transport (plain HTTP or a TLS backend) wrapping the TCP connector.
    inner: Inner,
    // Proxies checked, in order, for every destination.
    proxies: Arc<Vec<Proxy>>,
    // Toggled by `set_verbose`; used to wrap each new connection.
    verbose: verbose::Wrapper,
    // Connect timeout: embedded in the base service when there are no layers,
    // otherwise applied as an outer `TimeoutLayer` by `build`.
    timeout: Option<Duration>,
    // Requested TCP_NODELAY setting. TLS paths force it on for the handshake and
    // switch it back off afterwards when this is `false`.
    #[cfg (feature = "__tls" )]
    nodelay: bool,
    // Whether connections should attach TLS peer info to `Connected`.
    #[cfg (feature = "__tls" )]
    tls_info: bool,
    // Sent as the `User-Agent` header on proxy `CONNECT` requests.
    #[cfg (feature = "__tls" )]
    user_agent: Option<HeaderValue>,
}
81 | |
82 | impl ConnectorBuilder { |
83 | pub(crate) fn build(self, layers: Vec<BoxedConnectorLayer>) -> Connector |
84 | where { |
85 | // construct the inner tower service |
86 | let mut base_service = ConnectorService { |
87 | inner: self.inner, |
88 | proxies: self.proxies, |
89 | verbose: self.verbose, |
90 | #[cfg (feature = "__tls" )] |
91 | nodelay: self.nodelay, |
92 | #[cfg (feature = "__tls" )] |
93 | tls_info: self.tls_info, |
94 | #[cfg (feature = "__tls" )] |
95 | user_agent: self.user_agent, |
96 | simple_timeout: None, |
97 | }; |
98 | |
99 | if layers.is_empty() { |
100 | // we have no user-provided layers, only use concrete types |
101 | base_service.simple_timeout = self.timeout; |
102 | return Connector::Simple(base_service); |
103 | } |
104 | |
105 | // otherwise we have user provided layers |
106 | // so we need type erasure all the way through |
107 | // as well as mapping the unnameable type of the layers back to Uri for the inner service |
108 | let unnameable_service = ServiceBuilder::new() |
109 | .layer(MapRequestLayer::new(|request: Unnameable| request.0)) |
110 | .service(base_service); |
111 | let mut service = BoxCloneSyncService::new(unnameable_service); |
112 | |
113 | for layer in layers { |
114 | service = ServiceBuilder::new().layer(layer).service(service); |
115 | } |
116 | |
117 | // now we handle the concrete stuff - any `connect_timeout`, |
118 | // plus a final map_err layer we can use to cast default tower layer |
119 | // errors to internal errors |
120 | match self.timeout { |
121 | Some(timeout) => { |
122 | let service = ServiceBuilder::new() |
123 | .layer(TimeoutLayer::new(timeout)) |
124 | .service(service); |
125 | let service = ServiceBuilder::new() |
126 | .map_err(|error: BoxError| cast_to_internal_error(error)) |
127 | .service(service); |
128 | let service = BoxCloneSyncService::new(service); |
129 | |
130 | Connector::WithLayers(service) |
131 | } |
132 | None => { |
133 | // no timeout, but still map err |
134 | // no named timeout layer but we still map errors since |
135 | // we might have user-provided timeout layer |
136 | let service = ServiceBuilder::new().service(service); |
137 | let service = ServiceBuilder::new() |
138 | .map_err(|error: BoxError| cast_to_internal_error(error)) |
139 | .service(service); |
140 | let service = BoxCloneSyncService::new(service); |
141 | Connector::WithLayers(service) |
142 | } |
143 | } |
144 | } |
145 | |
146 | #[cfg (not(feature = "__tls" ))] |
147 | pub(crate) fn new<T>( |
148 | mut http: HttpConnector, |
149 | proxies: Arc<Vec<Proxy>>, |
150 | local_addr: T, |
151 | #[cfg (any(target_os = "android" , target_os = "fuchsia" , target_os = "linux" ))] |
152 | interface: Option<&str>, |
153 | nodelay: bool, |
154 | ) -> ConnectorBuilder |
155 | where |
156 | T: Into<Option<IpAddr>>, |
157 | { |
158 | http.set_local_address(local_addr.into()); |
159 | #[cfg (any(target_os = "android" , target_os = "fuchsia" , target_os = "linux" ))] |
160 | if let Some(interface) = interface { |
161 | http.set_interface(interface.to_owned()); |
162 | } |
163 | http.set_nodelay(nodelay); |
164 | |
165 | ConnectorBuilder { |
166 | inner: Inner::Http(http), |
167 | proxies, |
168 | verbose: verbose::OFF, |
169 | timeout: None, |
170 | } |
171 | } |
172 | |
173 | #[cfg (feature = "default-tls" )] |
174 | pub(crate) fn new_default_tls<T>( |
175 | http: HttpConnector, |
176 | tls: TlsConnectorBuilder, |
177 | proxies: Arc<Vec<Proxy>>, |
178 | user_agent: Option<HeaderValue>, |
179 | local_addr: T, |
180 | #[cfg (any(target_os = "android" , target_os = "fuchsia" , target_os = "linux" ))] |
181 | interface: Option<&str>, |
182 | nodelay: bool, |
183 | tls_info: bool, |
184 | ) -> crate::Result<ConnectorBuilder> |
185 | where |
186 | T: Into<Option<IpAddr>>, |
187 | { |
188 | let tls = tls.build().map_err(crate::error::builder)?; |
189 | Ok(Self::from_built_default_tls( |
190 | http, |
191 | tls, |
192 | proxies, |
193 | user_agent, |
194 | local_addr, |
195 | #[cfg (any(target_os = "android" , target_os = "fuchsia" , target_os = "linux" ))] |
196 | interface, |
197 | nodelay, |
198 | tls_info, |
199 | )) |
200 | } |
201 | |
202 | #[cfg (feature = "default-tls" )] |
203 | pub(crate) fn from_built_default_tls<T>( |
204 | mut http: HttpConnector, |
205 | tls: TlsConnector, |
206 | proxies: Arc<Vec<Proxy>>, |
207 | user_agent: Option<HeaderValue>, |
208 | local_addr: T, |
209 | #[cfg (any(target_os = "android" , target_os = "fuchsia" , target_os = "linux" ))] |
210 | interface: Option<&str>, |
211 | nodelay: bool, |
212 | tls_info: bool, |
213 | ) -> ConnectorBuilder |
214 | where |
215 | T: Into<Option<IpAddr>>, |
216 | { |
217 | http.set_local_address(local_addr.into()); |
218 | #[cfg (any(target_os = "android" , target_os = "fuchsia" , target_os = "linux" ))] |
219 | if let Some(interface) = interface { |
220 | http.set_interface(interface); |
221 | } |
222 | http.set_nodelay(nodelay); |
223 | http.enforce_http(false); |
224 | |
225 | ConnectorBuilder { |
226 | inner: Inner::DefaultTls(http, tls), |
227 | proxies, |
228 | verbose: verbose::OFF, |
229 | nodelay, |
230 | tls_info, |
231 | user_agent, |
232 | timeout: None, |
233 | } |
234 | } |
235 | |
236 | #[cfg (feature = "__rustls" )] |
237 | pub(crate) fn new_rustls_tls<T>( |
238 | mut http: HttpConnector, |
239 | tls: rustls::ClientConfig, |
240 | proxies: Arc<Vec<Proxy>>, |
241 | user_agent: Option<HeaderValue>, |
242 | local_addr: T, |
243 | #[cfg (any(target_os = "android" , target_os = "fuchsia" , target_os = "linux" ))] |
244 | interface: Option<&str>, |
245 | nodelay: bool, |
246 | tls_info: bool, |
247 | ) -> ConnectorBuilder |
248 | where |
249 | T: Into<Option<IpAddr>>, |
250 | { |
251 | http.set_local_address(local_addr.into()); |
252 | #[cfg (any(target_os = "android" , target_os = "fuchsia" , target_os = "linux" ))] |
253 | if let Some(interface) = interface { |
254 | http.set_interface(interface.to_owned()); |
255 | } |
256 | http.set_nodelay(nodelay); |
257 | http.enforce_http(false); |
258 | |
259 | let (tls, tls_proxy) = if proxies.is_empty() { |
260 | let tls = Arc::new(tls); |
261 | (tls.clone(), tls) |
262 | } else { |
263 | let mut tls_proxy = tls.clone(); |
264 | tls_proxy.alpn_protocols.clear(); |
265 | (Arc::new(tls), Arc::new(tls_proxy)) |
266 | }; |
267 | |
268 | ConnectorBuilder { |
269 | inner: Inner::RustlsTls { |
270 | http, |
271 | tls, |
272 | tls_proxy, |
273 | }, |
274 | proxies, |
275 | verbose: verbose::OFF, |
276 | nodelay, |
277 | tls_info, |
278 | user_agent, |
279 | timeout: None, |
280 | } |
281 | } |
282 | |
283 | pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) { |
284 | self.timeout = timeout; |
285 | } |
286 | |
287 | pub(crate) fn set_verbose(&mut self, enabled: bool) { |
288 | self.verbose.0 = enabled; |
289 | } |
290 | |
291 | pub(crate) fn set_keepalive(&mut self, dur: Option<Duration>) { |
292 | match &mut self.inner { |
293 | #[cfg (feature = "default-tls" )] |
294 | Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), |
295 | #[cfg (feature = "__rustls" )] |
296 | Inner::RustlsTls { http, .. } => http.set_keepalive(dur), |
297 | #[cfg (not(feature = "__tls" ))] |
298 | Inner::Http(http) => http.set_keepalive(dur), |
299 | } |
300 | } |
301 | } |
302 | |
/// Base tower service that opens connections: directly, through an HTTP(S)
/// `CONNECT` tunnel, or through a SOCKS proxy, depending on `proxies`.
/// Each call operates on its own clone of this struct.
#[allow (missing_debug_implementations)]
#[derive (Clone)]
pub(crate) struct ConnectorService {
    // Transport (plain HTTP or a TLS backend) wrapping the TCP connector.
    inner: Inner,
    // Proxies checked, in order; the first that intercepts the destination is used.
    proxies: Arc<Vec<Proxy>>,
    // Used to wrap each new connection.
    verbose: verbose::Wrapper,
    /// When there is a single timeout layer and no other layers,
    /// we embed it directly inside our base Service::call().
    /// This lets us avoid an extra `Box::pin` indirection layer
    /// since `tokio::time::Timeout` is `Unpin`
    simple_timeout: Option<Duration>,
    // Requested TCP_NODELAY; TLS paths force it on for the handshake.
    #[cfg (feature = "__tls" )]
    nodelay: bool,
    // Whether TLS connections should expose peer info via `Connected`.
    #[cfg (feature = "__tls" )]
    tls_info: bool,
    // `User-Agent` sent on proxy `CONNECT` requests.
    #[cfg (feature = "__tls" )]
    user_agent: Option<HeaderValue>,
}
321 | |
/// Transport selected at build time; exactly the variants for the enabled
/// features are compiled in.
#[derive (Clone)]
enum Inner {
    // Plain TCP, only when no TLS backend is enabled.
    #[cfg (not(feature = "__tls" ))]
    Http(HttpConnector),
    // native-tls backend.
    #[cfg (feature = "default-tls" )]
    DefaultTls(HttpConnector, TlsConnector),
    // rustls backend. `tls` is used for the destination; `tls_proxy` for the TLS hop
    // to an HTTPS proxy (ALPN cleared when proxies are configured, see `new_rustls_tls`).
    #[cfg (feature = "__rustls" )]
    RustlsTls {
        http: HttpConnector,
        tls: Arc<rustls::ClientConfig>,
        tls_proxy: Arc<rustls::ClientConfig>,
    },
}
335 | |
336 | impl ConnectorService { |
337 | #[cfg (feature = "socks" )] |
338 | async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> { |
339 | let dns = match proxy { |
340 | ProxyScheme::Socks4 { |
341 | remote_dns: false, .. |
342 | } => socks::DnsResolve::Local, |
343 | ProxyScheme::Socks4 { |
344 | remote_dns: true, .. |
345 | } => socks::DnsResolve::Proxy, |
346 | ProxyScheme::Socks5 { |
347 | remote_dns: false, .. |
348 | } => socks::DnsResolve::Local, |
349 | ProxyScheme::Socks5 { |
350 | remote_dns: true, .. |
351 | } => socks::DnsResolve::Proxy, |
352 | ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => { |
353 | unreachable!("connect_socks is only called for socks proxies" ); |
354 | } |
355 | }; |
356 | |
357 | match &self.inner { |
358 | #[cfg (feature = "default-tls" )] |
359 | Inner::DefaultTls(_http, tls) => { |
360 | if dst.scheme() == Some(&Scheme::HTTPS) { |
361 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
362 | let conn = socks::connect(proxy, dst, dns).await?; |
363 | let conn = TokioIo::new(conn); |
364 | let conn = TokioIo::new(conn); |
365 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
366 | let io = tls_connector.connect(&host, conn).await?; |
367 | let io = TokioIo::new(io); |
368 | return Ok(Conn { |
369 | inner: self.verbose.wrap(NativeTlsConn { inner: io }), |
370 | is_proxy: false, |
371 | tls_info: self.tls_info, |
372 | }); |
373 | } |
374 | } |
375 | #[cfg (feature = "__rustls" )] |
376 | Inner::RustlsTls { tls, .. } => { |
377 | if dst.scheme() == Some(&Scheme::HTTPS) { |
378 | use std::convert::TryFrom; |
379 | use tokio_rustls::TlsConnector as RustlsConnector; |
380 | |
381 | let tls = tls.clone(); |
382 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
383 | let conn = socks::connect(proxy, dst, dns).await?; |
384 | let conn = TokioIo::new(conn); |
385 | let conn = TokioIo::new(conn); |
386 | let server_name = |
387 | rustls_pki_types::ServerName::try_from(host.as_str().to_owned()) |
388 | .map_err(|_| "Invalid Server Name" )?; |
389 | let io = RustlsConnector::from(tls) |
390 | .connect(server_name, conn) |
391 | .await?; |
392 | let io = TokioIo::new(io); |
393 | return Ok(Conn { |
394 | inner: self.verbose.wrap(RustlsTlsConn { inner: io }), |
395 | is_proxy: false, |
396 | tls_info: false, |
397 | }); |
398 | } |
399 | } |
400 | #[cfg (not(feature = "__tls" ))] |
401 | Inner::Http(_) => (), |
402 | } |
403 | |
404 | socks::connect(proxy, dst, dns).await.map(|tcp| Conn { |
405 | inner: self.verbose.wrap(TokioIo::new(tcp)), |
406 | is_proxy: false, |
407 | tls_info: false, |
408 | }) |
409 | } |
410 | |
411 | async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> { |
412 | match self.inner { |
413 | #[cfg (not(feature = "__tls" ))] |
414 | Inner::Http(mut http) => { |
415 | let io = http.call(dst).await?; |
416 | Ok(Conn { |
417 | inner: self.verbose.wrap(io), |
418 | is_proxy, |
419 | tls_info: false, |
420 | }) |
421 | } |
422 | #[cfg (feature = "default-tls" )] |
423 | Inner::DefaultTls(http, tls) => { |
424 | let mut http = http.clone(); |
425 | |
426 | // Disable Nagle's algorithm for TLS handshake |
427 | // |
428 | // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES |
429 | if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) { |
430 | http.set_nodelay(true); |
431 | } |
432 | |
433 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
434 | let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); |
435 | let io = http.call(dst).await?; |
436 | |
437 | if let hyper_tls::MaybeHttpsStream::Https(stream) = io { |
438 | if !self.nodelay { |
439 | stream |
440 | .inner() |
441 | .get_ref() |
442 | .get_ref() |
443 | .get_ref() |
444 | .inner() |
445 | .inner() |
446 | .set_nodelay(false)?; |
447 | } |
448 | Ok(Conn { |
449 | inner: self.verbose.wrap(NativeTlsConn { inner: stream }), |
450 | is_proxy, |
451 | tls_info: self.tls_info, |
452 | }) |
453 | } else { |
454 | Ok(Conn { |
455 | inner: self.verbose.wrap(io), |
456 | is_proxy, |
457 | tls_info: false, |
458 | }) |
459 | } |
460 | } |
461 | #[cfg (feature = "__rustls" )] |
462 | Inner::RustlsTls { http, tls, .. } => { |
463 | let mut http = http.clone(); |
464 | |
465 | // Disable Nagle's algorithm for TLS handshake |
466 | // |
467 | // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES |
468 | if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) { |
469 | http.set_nodelay(true); |
470 | } |
471 | |
472 | let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone())); |
473 | let io = http.call(dst).await?; |
474 | |
475 | if let hyper_rustls::MaybeHttpsStream::Https(stream) = io { |
476 | if !self.nodelay { |
477 | let (io, _) = stream.inner().get_ref(); |
478 | io.inner().inner().set_nodelay(false)?; |
479 | } |
480 | Ok(Conn { |
481 | inner: self.verbose.wrap(RustlsTlsConn { inner: stream }), |
482 | is_proxy, |
483 | tls_info: self.tls_info, |
484 | }) |
485 | } else { |
486 | Ok(Conn { |
487 | inner: self.verbose.wrap(io), |
488 | is_proxy, |
489 | tls_info: false, |
490 | }) |
491 | } |
492 | } |
493 | } |
494 | } |
495 | |
496 | async fn connect_via_proxy( |
497 | self, |
498 | dst: Uri, |
499 | proxy_scheme: ProxyScheme, |
500 | ) -> Result<Conn, BoxError> { |
501 | log::debug!("proxy( {proxy_scheme:?}) intercepts ' {dst:?}'" ); |
502 | |
503 | let (proxy_dst, _auth) = match proxy_scheme { |
504 | ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth), |
505 | ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth), |
506 | #[cfg (feature = "socks" )] |
507 | ProxyScheme::Socks4 { .. } => return self.connect_socks(dst, proxy_scheme).await, |
508 | #[cfg (feature = "socks" )] |
509 | ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await, |
510 | }; |
511 | |
512 | #[cfg (feature = "__tls" )] |
513 | let auth = _auth; |
514 | |
515 | match &self.inner { |
516 | #[cfg (feature = "default-tls" )] |
517 | Inner::DefaultTls(http, tls) => { |
518 | if dst.scheme() == Some(&Scheme::HTTPS) { |
519 | let host = dst.host().to_owned(); |
520 | let port = dst.port().map(|p| p.as_u16()).unwrap_or(443); |
521 | let http = http.clone(); |
522 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
523 | let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); |
524 | let conn = http.call(proxy_dst).await?; |
525 | log::trace!("tunneling HTTPS over proxy" ); |
526 | let tunneled = tunnel( |
527 | conn, |
528 | host.ok_or("no host in url" )?.to_string(), |
529 | port, |
530 | self.user_agent.clone(), |
531 | auth, |
532 | ) |
533 | .await?; |
534 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
535 | let io = tls_connector |
536 | .connect(host.ok_or("no host in url" )?, TokioIo::new(tunneled)) |
537 | .await?; |
538 | return Ok(Conn { |
539 | inner: self.verbose.wrap(NativeTlsConn { |
540 | inner: TokioIo::new(io), |
541 | }), |
542 | is_proxy: false, |
543 | tls_info: false, |
544 | }); |
545 | } |
546 | } |
547 | #[cfg (feature = "__rustls" )] |
548 | Inner::RustlsTls { |
549 | http, |
550 | tls, |
551 | tls_proxy, |
552 | } => { |
553 | if dst.scheme() == Some(&Scheme::HTTPS) { |
554 | use rustls_pki_types::ServerName; |
555 | use std::convert::TryFrom; |
556 | use tokio_rustls::TlsConnector as RustlsConnector; |
557 | |
558 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
559 | let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); |
560 | let http = http.clone(); |
561 | let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); |
562 | let tls = tls.clone(); |
563 | let conn = http.call(proxy_dst).await?; |
564 | log::trace!("tunneling HTTPS over proxy" ); |
565 | let maybe_server_name = ServerName::try_from(host.as_str().to_owned()) |
566 | .map_err(|_| "Invalid Server Name" ); |
567 | let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; |
568 | let server_name = maybe_server_name?; |
569 | let io = RustlsConnector::from(tls) |
570 | .connect(server_name, TokioIo::new(tunneled)) |
571 | .await?; |
572 | |
573 | return Ok(Conn { |
574 | inner: self.verbose.wrap(RustlsTlsConn { |
575 | inner: TokioIo::new(io), |
576 | }), |
577 | is_proxy: false, |
578 | tls_info: false, |
579 | }); |
580 | } |
581 | } |
582 | #[cfg (not(feature = "__tls" ))] |
583 | Inner::Http(_) => (), |
584 | } |
585 | |
586 | self.connect_with_maybe_proxy(proxy_dst, true).await |
587 | } |
588 | } |
589 | |
590 | fn into_uri(scheme: Scheme, host: Authority) -> Uri { |
591 | // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`? |
592 | http::Uri::builder() |
593 | .scheme(scheme) |
594 | .authority(host) |
595 | .path_and_query(http::uri::PathAndQuery::from_static("/" )) |
596 | .build() |
597 | .expect(msg:"scheme and authority is valid Uri" ) |
598 | } |
599 | |
600 | async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError> |
601 | where |
602 | F: Future<Output = Result<T, BoxError>>, |
603 | { |
604 | if let Some(to: Duration) = timeout { |
605 | match tokio::time::timeout(duration:to, future:f).await { |
606 | Err(_elapsed: Elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError), |
607 | Ok(Ok(try_res: T)) => Ok(try_res), |
608 | Ok(Err(e: Box)) => Err(e), |
609 | } |
610 | } else { |
611 | f.await |
612 | } |
613 | } |
614 | |
615 | impl Service<Uri> for ConnectorService { |
616 | type Response = Conn; |
617 | type Error = BoxError; |
618 | type Future = Connecting; |
619 | |
620 | fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
621 | Poll::Ready(Ok(())) |
622 | } |
623 | |
624 | fn call(&mut self, dst: Uri) -> Self::Future { |
625 | log::debug!("starting new connection: {dst:?}" ); |
626 | let timeout = self.simple_timeout; |
627 | for prox in self.proxies.iter() { |
628 | if let Some(proxy_scheme) = prox.intercept(&dst) { |
629 | return Box::pin(with_timeout( |
630 | self.clone().connect_via_proxy(dst, proxy_scheme), |
631 | timeout, |
632 | )); |
633 | } |
634 | } |
635 | |
636 | Box::pin(with_timeout( |
637 | self.clone().connect_with_maybe_proxy(dst, false), |
638 | timeout, |
639 | )) |
640 | } |
641 | } |
642 | |
/// Extracts TLS session details from a connection stream, if it has any.
///
/// Plain TCP streams report `None`; the TLS stream impls below report the
/// peer certificate.
#[cfg (feature = "__tls" )]
trait TlsInfoFactory {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}

// A bare TCP stream carries no TLS session.
#[cfg (feature = "__tls" )]
impl TlsInfoFactory for tokio::net::TcpStream {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        None
    }
}

// `TokioIo` adds no TLS state of its own; delegate to the wrapped stream.
#[cfg (feature = "__tls" )]
impl<T: TlsInfoFactory> TlsInfoFactory for TokioIo<T> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        self.inner().tls_info()
    }
}
661 | |
/// TLS info for native-tls streams layered directly over TCP (used by the SOCKS path).
///
/// Reports the peer certificate in DER form, or `None` inside `TlsInfo` if the
/// certificate is unavailable or cannot be encoded.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = self
            .get_ref()
            .peer_certificate()
            .ok()
            .flatten()
            .and_then(|c| c.to_der().ok());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// TLS info for native-tls streams running inside a proxy `CONNECT` tunnel.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory
    for tokio_native_tls::TlsStream<
        TokioIo<hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = self
            .get_ref()
            .peer_certificate()
            .ok()
            .flatten()
            .and_then(|c| c.to_der().ok());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// TLS info for `hyper-tls` connections: present only for the `Https` variant.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        match self {
            hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(),
            hyper_tls::MaybeHttpsStream::Http(_) => None,
        }
    }
}
701 | |
/// TLS info for rustls streams layered directly over TCP.
///
/// Reports the first certificate the peer presented (as raw bytes), if any.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        // `get_ref` yields `(io, session)`; only the session knows the peer chain.
        let (_io, session) = self.get_ref();
        let first_cert = session.peer_certificates().and_then(|chain| chain.first());
        Some(crate::tls::TlsInfo {
            peer_certificate: first_cert.map(|cert| cert.to_vec()),
        })
    }
}

/// TLS info for rustls streams running inside a proxy `CONNECT` tunnel.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory
    for tokio_rustls::client::TlsStream<
        TokioIo<hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_io, session) = self.get_ref();
        let first_cert = session.peer_certificates().and_then(|chain| chain.first());
        Some(crate::tls::TlsInfo {
            peer_certificate: first_cert.map(|cert| cert.to_vec()),
        })
    }
}

/// TLS info for `hyper-rustls` connections: present only for the `Https` variant.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_rustls::MaybeHttpsStream::Https(tls) = self {
            tls.tls_info()
        } else {
            None
        }
    }
}
741 | |
/// Object-usable bundle of the traits a boxed connection must provide: hyper's
/// `Read`/`Write`, hyper-util's `Connection`, plus `Send + Sync + Unpin + 'static`.
pub(crate) trait AsyncConn:
    Read + Write + Connection + Send + Sync + Unpin + 'static
{
}

// Blanket impl: every type meeting the bounds is an `AsyncConn`.
impl<T: Read + Write + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}

// With a TLS backend compiled in, boxed connections must also report TLS info;
// without one, this is just an alias for `AsyncConn`.
#[cfg (feature = "__tls" )]
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
#[cfg (not(feature = "__tls" ))]
trait AsyncConnWithInfo: AsyncConn {}

#[cfg (feature = "__tls" )]
impl<T: AsyncConn + TlsInfoFactory> AsyncConnWithInfo for T {}
#[cfg (not(feature = "__tls" ))]
impl<T: AsyncConn> AsyncConnWithInfo for T {}

// Type-erased connection stored inside `Conn`.
type BoxConn = Box<dyn AsyncConnWithInfo>;
760 | |
761 | pub(crate) mod sealed { |
762 | use super::*; |
763 | #[derive (Debug)] |
764 | pub struct Unnameable(pub(super) Uri); |
765 | |
766 | pin_project! { |
767 | /// Note: the `is_proxy` member means *is plain text HTTP proxy*. |
768 | /// This tells hyper whether the URI should be written in |
769 | /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or |
770 | /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. |
771 | #[allow (missing_debug_implementations)] |
772 | pub struct Conn { |
773 | #[pin] |
774 | pub(super)inner: BoxConn, |
775 | pub(super) is_proxy: bool, |
776 | // Only needed for __tls, but #[cfg()] on fields breaks pin_project! |
777 | pub(super) tls_info: bool, |
778 | } |
779 | } |
780 | |
781 | impl Connection for Conn { |
782 | fn connected(&self) -> Connected { |
783 | let connected = self.inner.connected().proxy(self.is_proxy); |
784 | #[cfg (feature = "__tls" )] |
785 | if self.tls_info { |
786 | if let Some(tls_info) = self.inner.tls_info() { |
787 | connected.extra(tls_info) |
788 | } else { |
789 | connected |
790 | } |
791 | } else { |
792 | connected |
793 | } |
794 | #[cfg (not(feature = "__tls" ))] |
795 | connected |
796 | } |
797 | } |
798 | |
799 | impl Read for Conn { |
800 | fn poll_read( |
801 | self: Pin<&mut Self>, |
802 | cx: &mut Context, |
803 | buf: ReadBufCursor<'_>, |
804 | ) -> Poll<io::Result<()>> { |
805 | let this = self.project(); |
806 | Read::poll_read(this.inner, cx, buf) |
807 | } |
808 | } |
809 | |
810 | impl Write for Conn { |
811 | fn poll_write( |
812 | self: Pin<&mut Self>, |
813 | cx: &mut Context, |
814 | buf: &[u8], |
815 | ) -> Poll<Result<usize, io::Error>> { |
816 | let this = self.project(); |
817 | Write::poll_write(this.inner, cx, buf) |
818 | } |
819 | |
820 | fn poll_write_vectored( |
821 | self: Pin<&mut Self>, |
822 | cx: &mut Context<'_>, |
823 | bufs: &[IoSlice<'_>], |
824 | ) -> Poll<Result<usize, io::Error>> { |
825 | let this = self.project(); |
826 | Write::poll_write_vectored(this.inner, cx, bufs) |
827 | } |
828 | |
829 | fn is_write_vectored(&self) -> bool { |
830 | self.inner.is_write_vectored() |
831 | } |
832 | |
833 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { |
834 | let this = self.project(); |
835 | Write::poll_flush(this.inner, cx) |
836 | } |
837 | |
838 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { |
839 | let this = self.project(); |
840 | Write::poll_shutdown(this.inner, cx) |
841 | } |
842 | } |
843 | } |
844 | |
/// Boxed, `Send` future returned by the connector services; resolves to a `Conn`.
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
846 | |
/// Opens an HTTP `CONNECT` tunnel to `host:port` over an established proxy
/// connection `conn`.
///
/// Sends the `CONNECT` request, with optional `User-Agent` and
/// `Proxy-Authorization` headers, then reads the proxy's reply up to the end of
/// its header block. On a `200` status the same `conn` is returned; it now
/// carries the tunneled byte stream.
///
/// # Errors
///
/// * I/O errors while writing the request or reading the reply.
/// * `tunnel_eof()` if the proxy closes the connection before the reply is complete.
/// * `"proxy authentication required"` on a `407` status.
/// * `"proxy headers too long for tunnel"` if a `200` reply's headers overflow the 8 KiB buffer.
/// * `"unsuccessful tunnel"` for any other status.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: Read + Write + Unpin,
{
    use hyper_util::rt::TokioIo;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Bytes needed to classify the reply's status line ("HTTP/1.x NNN").
    const STATUS_PREFIX_LEN: usize = b"HTTP/1.1 200".len();

    // Request line and Host header. The `\` continuations swallow the newline and
    // leading indentation, so no stray whitespace ends up inside the request.
    let mut buf = format!(
        "\
         CONNECT {host}:{port} HTTP/1.1\r\n\
         Host: {host}:{port}\r\n\
         "
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {host}:{port} using basic auth");
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // headers end: an empty line
    buf.extend_from_slice(b"\r\n");

    let mut tokio_conn = TokioIo::new(&mut conn);

    tokio_conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = tokio_conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        if pos < STATUS_PREFIX_LEN {
            // A short read may split the status line; wait for enough bytes to classify it.
            continue;
        }

        let recvd = &buf[..pos];
        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
            // else read more
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
918 | |
/// Error for a proxy that closes the connection before the `CONNECT` tunnel
/// is established.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
923 | |
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_tls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_native_tls::TlsStream;

    // A native-tls client stream exposed through hyper's `Read`/`Write`
    // traits; `inner` adapts the tokio-based TLS stream back to hyper I/O.
    pin_project! {
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    impl Connection for NativeTlsConn<TokioIo<TokioIo<TcpStream>>> {
        /// Reports the underlying transport's `Connected` info, upgraded to
        /// HTTP/2 when (with `native-tls-alpn`) the server negotiated `h2`.
        fn connected(&self) -> Connected {
            // tokio-native-tls wrapper -> native_tls session
            let session = self.inner.inner().get_ref();
            // native_tls session -> AllowStd adapter -> transport -> its `Connected`
            let connected = session.get_ref().get_ref().inner().connected();
            #[cfg(feature = "native-tls-alpn")]
            {
                if matches!(session.negotiated_alpn(), Ok(Some(protocol)) if protocol == b"h2") {
                    return connected.negotiated_h2();
                }
            }
            connected
        }
    }

    impl Connection for NativeTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        /// Same as the plain-TCP variant, for TLS layered over a
        /// `MaybeHttpsStream` (e.g. TLS inside an HTTPS proxy tunnel).
        fn connected(&self) -> Connected {
            let session = self.inner.inner().get_ref();
            let connected = session.get_ref().get_ref().inner().connected();
            #[cfg(feature = "native-tls-alpn")]
            {
                if matches!(session.negotiated_alpn(), Ok(Some(protocol)) if protocol == b"h2") {
                    return connected.negotiated_h2();
                }
            }
            connected
        }
    }

    // All I/O is forwarded unchanged to the pinned `inner` stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> Read for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            Read::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Write for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            Write::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            Write::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            Write::is_write_vectored(&self.inner)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_shutdown(self.project().inner, cx)
        }
    }

    impl<T> TlsInfoFactory for NativeTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        /// Delegates TLS info extraction to the wrapped stream.
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }
}
1047 | |
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_rustls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;

    // A rustls client stream exposed through hyper's `Read`/`Write` traits;
    // `inner` adapts the tokio-based TLS stream back to hyper I/O.
    pin_project! {
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    impl Connection for RustlsTlsConn<TokioIo<TokioIo<TcpStream>>> {
        /// Reports the transport's `Connected` info, marked as HTTP/2 when
        /// the rustls session negotiated the `h2` ALPN protocol.
        fn connected(&self) -> Connected {
            // `get_ref` yields (transport, client session).
            let (transport, session) = self.inner.inner().get_ref();
            let connected = transport.inner().connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    impl Connection for RustlsTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        /// Same as the plain-TCP variant, for TLS layered over a
        /// `MaybeHttpsStream` (e.g. TLS inside an HTTPS proxy tunnel).
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.inner().get_ref();
            let connected = transport.inner().connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    // All I/O is forwarded unchanged to the pinned `inner` stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            Read::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            Write::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            Write::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            Write::is_write_vectored(&self.inner)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_shutdown(self.project().inner, cx)
        }
    }

    impl<T> TlsInfoFactory for RustlsTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        /// Delegates TLS info extraction to the wrapped stream.
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }
}
1161 | |
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::{Socks4Stream, Socks5Stream};

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name is resolved.
    pub(super) enum DnsResolve {
        /// Resolve locally (via `tokio::net::lookup_host`) and give the proxy
        /// an IP address.
        Local,
        /// Give the proxy the host name unchanged.
        Proxy,
    }

    /// Wraps any SOCKS handshake failure in the crate's boxed error type.
    fn socks_error(e: impl std::fmt::Display) -> BoxError {
        format!("socks connect error: {e}").into()
    }

    /// Opens a TCP stream to `dst` through the SOCKS4/SOCKS5 proxy `proxy`.
    ///
    /// The port defaults to 443 for `https` URIs and 80 otherwise. With
    /// `DnsResolve::Local`, the first address returned by local resolution is
    /// used as the target; if resolution yields no addresses, the host name is
    /// sent as-is.
    ///
    /// # Errors
    /// Fails if `dst` has no host, if local resolution errors, or if the SOCKS
    /// handshake fails.
    ///
    /// # Panics
    /// Panics if `proxy` is not a SOCKS scheme; callers must only route SOCKS
    /// proxies here.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        // `Uri::host` keeps the brackets around IPv6 literals ("[::1]"). Strip
        // them so the host parses as an IP address, both for local resolution
        // and for the SOCKS target (otherwise it would be sent as a domain).
        let mut host = original_host
            .trim_start_matches('[')
            .trim_end_matches(']')
            .to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port)).await?.next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        let target = (host.as_str(), port);

        match proxy {
            ProxyScheme::Socks4 { addr, .. } => {
                let stream = Socks4Stream::connect(addr, target)
                    .await
                    .map_err(socks_error)?;
                Ok(stream.into_inner())
            }
            ProxyScheme::Socks5 { addr, ref auth, .. } => {
                let stream = match auth {
                    Some((username, password)) => {
                        Socks5Stream::connect_with_password(addr, target, username, password)
                            .await
                    }
                    None => Socks5Stream::connect(addr, target).await,
                }
                .map_err(socks_error)?;
                Ok(stream.into_inner())
            }
            _ => unreachable!("socks::connect called with a non-SOCKS proxy scheme"),
        }
    }
}
1230 | |
1231 | mod verbose { |
1232 | use hyper::rt::{Read, ReadBufCursor, Write}; |
1233 | use hyper_util::client::legacy::connect::{Connected, Connection}; |
1234 | use std::cmp::min; |
1235 | use std::fmt; |
1236 | use std::io::{self, IoSlice}; |
1237 | use std::pin::Pin; |
1238 | use std::task::{Context, Poll}; |
1239 | |
1240 | pub(super) const OFF: Wrapper = Wrapper(false); |
1241 | |
1242 | #[derive (Clone, Copy)] |
1243 | pub(super) struct Wrapper(pub(super) bool); |
1244 | |
1245 | impl Wrapper { |
1246 | pub(super) fn wrap<T: super::AsyncConnWithInfo>(&self, conn: T) -> super::BoxConn { |
1247 | if self.0 && log::log_enabled!(log::Level::Trace) { |
1248 | Box::new(Verbose { |
1249 | // truncate is fine |
1250 | id: crate::util::fast_random() as u32, |
1251 | inner: conn, |
1252 | }) |
1253 | } else { |
1254 | Box::new(conn) |
1255 | } |
1256 | } |
1257 | } |
1258 | |
1259 | struct Verbose<T> { |
1260 | id: u32, |
1261 | inner: T, |
1262 | } |
1263 | |
1264 | impl<T: Connection + Read + Write + Unpin> Connection for Verbose<T> { |
1265 | fn connected(&self) -> Connected { |
1266 | self.inner.connected() |
1267 | } |
1268 | } |
1269 | |
1270 | impl<T: Read + Write + Unpin> Read for Verbose<T> { |
1271 | fn poll_read( |
1272 | mut self: Pin<&mut Self>, |
1273 | cx: &mut Context, |
1274 | mut buf: ReadBufCursor<'_>, |
1275 | ) -> Poll<std::io::Result<()>> { |
1276 | // TODO: This _does_ forget the `init` len, so it could result in |
1277 | // re-initializing twice. Needs upstream support, perhaps. |
1278 | // SAFETY: Passing to a ReadBuf will never de-initialize any bytes. |
1279 | let mut vbuf = hyper::rt::ReadBuf::uninit(unsafe { buf.as_mut() }); |
1280 | match Pin::new(&mut self.inner).poll_read(cx, vbuf.unfilled()) { |
1281 | Poll::Ready(Ok(())) => { |
1282 | log::trace!(" {:08x} read: {:?}" , self.id, Escape(vbuf.filled())); |
1283 | let len = vbuf.filled().len(); |
1284 | // SAFETY: The two cursors were for the same buffer. What was |
1285 | // filled in one is safe in the other. |
1286 | unsafe { |
1287 | buf.advance(len); |
1288 | } |
1289 | Poll::Ready(Ok(())) |
1290 | } |
1291 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
1292 | Poll::Pending => Poll::Pending, |
1293 | } |
1294 | } |
1295 | } |
1296 | |
1297 | impl<T: Read + Write + Unpin> Write for Verbose<T> { |
1298 | fn poll_write( |
1299 | mut self: Pin<&mut Self>, |
1300 | cx: &mut Context, |
1301 | buf: &[u8], |
1302 | ) -> Poll<Result<usize, std::io::Error>> { |
1303 | match Pin::new(&mut self.inner).poll_write(cx, buf) { |
1304 | Poll::Ready(Ok(n)) => { |
1305 | log::trace!(" {:08x} write: {:?}" , self.id, Escape(&buf[..n])); |
1306 | Poll::Ready(Ok(n)) |
1307 | } |
1308 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
1309 | Poll::Pending => Poll::Pending, |
1310 | } |
1311 | } |
1312 | |
1313 | fn poll_write_vectored( |
1314 | mut self: Pin<&mut Self>, |
1315 | cx: &mut Context<'_>, |
1316 | bufs: &[IoSlice<'_>], |
1317 | ) -> Poll<Result<usize, io::Error>> { |
1318 | match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) { |
1319 | Poll::Ready(Ok(nwritten)) => { |
1320 | log::trace!( |
1321 | " {:08x} write (vectored): {:?}" , |
1322 | self.id, |
1323 | Vectored { bufs, nwritten } |
1324 | ); |
1325 | Poll::Ready(Ok(nwritten)) |
1326 | } |
1327 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
1328 | Poll::Pending => Poll::Pending, |
1329 | } |
1330 | } |
1331 | |
1332 | fn is_write_vectored(&self) -> bool { |
1333 | self.inner.is_write_vectored() |
1334 | } |
1335 | |
1336 | fn poll_flush( |
1337 | mut self: Pin<&mut Self>, |
1338 | cx: &mut Context, |
1339 | ) -> Poll<Result<(), std::io::Error>> { |
1340 | Pin::new(&mut self.inner).poll_flush(cx) |
1341 | } |
1342 | |
1343 | fn poll_shutdown( |
1344 | mut self: Pin<&mut Self>, |
1345 | cx: &mut Context, |
1346 | ) -> Poll<Result<(), std::io::Error>> { |
1347 | Pin::new(&mut self.inner).poll_shutdown(cx) |
1348 | } |
1349 | } |
1350 | |
1351 | #[cfg (feature = "__tls" )] |
1352 | impl<T: super::TlsInfoFactory> super::TlsInfoFactory for Verbose<T> { |
1353 | fn tls_info(&self) -> Option<crate::tls::TlsInfo> { |
1354 | self.inner.tls_info() |
1355 | } |
1356 | } |
1357 | |
1358 | struct Escape<'a>(&'a [u8]); |
1359 | |
1360 | impl fmt::Debug for Escape<'_> { |
1361 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
1362 | write!(f, "b \"" )?; |
1363 | for &c in self.0 { |
1364 | // https://doc.rust-lang.org/reference.html#byte-escapes |
1365 | if c == b' \n' { |
1366 | write!(f, " \\n" )?; |
1367 | } else if c == b' \r' { |
1368 | write!(f, " \\r" )?; |
1369 | } else if c == b' \t' { |
1370 | write!(f, " \\t" )?; |
1371 | } else if c == b' \\' || c == b'"' { |
1372 | write!(f, " \\{}" , c as char)?; |
1373 | } else if c == b' \0' { |
1374 | write!(f, " \\0" )?; |
1375 | // ASCII printable |
1376 | } else if c >= 0x20 && c < 0x7f { |
1377 | write!(f, " {}" , c as char)?; |
1378 | } else { |
1379 | write!(f, " \\x {c:02x}" )?; |
1380 | } |
1381 | } |
1382 | write!(f, " \"" )?; |
1383 | Ok(()) |
1384 | } |
1385 | } |
1386 | |
1387 | struct Vectored<'a, 'b> { |
1388 | bufs: &'a [IoSlice<'b>], |
1389 | nwritten: usize, |
1390 | } |
1391 | |
1392 | impl fmt::Debug for Vectored<'_, '_> { |
1393 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
1394 | let mut left = self.nwritten; |
1395 | for buf in self.bufs.iter() { |
1396 | if left == 0 { |
1397 | break; |
1398 | } |
1399 | let n = min(left, buf.len()); |
1400 | Escape(&buf[..n]).fmt(f)?; |
1401 | left -= n; |
1402 | } |
1403 | Ok(()) |
1404 | } |
1405 | } |
1406 | } |
1407 | |
// Tests for `tunnel`: each one drives it against a mock proxy listening on a
// local TCP port.
#[cfg (feature = "__tls" )]
#[cfg (test)]
mod tests {
    use super::tunnel;
    use crate::proxy;
    use hyper_util::rt::TokioIo;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    // User-Agent value passed to `tunnel` and expected in the CONNECT request.
    static TUNNEL_UA: &str = "tunnel-test/x.y" ;
    // Canned successful proxy reply: a 200 status line followed by the blank
    // line that ends the header block.
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK \r\n\
        \r\n\
    " ;

    // Starts a one-shot mock proxy and evaluates to its `SocketAddr`.
    //
    // The spawned thread accepts a single connection and asserts that the
    // bytes of one `read` equal the expected CONNECT request. That request
    // targets the listener's own ip:port, carries `TUNNEL_UA`, and includes
    // the optional `$auth` header line. The thread then writes `$write` and
    // drops the socket when the thread ends.
    // NOTE(review): assumes the whole request arrives in a single `read`.
    //
    // Forms: `mock_tunnel!()` replies with `TUNNEL_OK`; `mock_tunnel!(bytes)`
    // replies with `bytes` and expects no auth; `mock_tunnel!(bytes, auth)`
    // also expects the `auth` header text.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "" )
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0" ).unwrap();
            let addr = listener.local_addr().unwrap();
            let connect_expected = format!(
                "\
                CONNECT {0}:{1} HTTP/1.1 \r\n\
                Host: {0}:{1} \r\n\
                User-Agent: {2} \r\n\
                {3}\
                \r\n\
                " ,
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                let mut buf = [0u8; 4096];
                let n = sock.read(&mut buf).unwrap();
                assert_eq!(&buf[..n], &connect_expected[..]);

                sock.write_all($write).unwrap();
            });
            addr
        }};
    }

    // User-Agent header value handed to `tunnel` in every test.
    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    // A complete 200 reply establishes the tunnel.
    #[test ]
    fn test_tunnel() {
        let addr = mock_tunnel!();

        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt" );
        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap();
    }

    // The proxy sends a 200 status line without the terminating blank line
    // and then closes, so `tunnel` must fail rather than succeed.
    #[test ]
    fn test_tunnel_eof() {
        let addr = mock_tunnel!(b"HTTP/1.1 200 OK" );

        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt" );
        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap_err();
    }

    // A reply that is not an HTTP status line is rejected.
    #[test ]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel!(b"foo bar baz hallo" );

        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt" );
        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap_err();
    }

    // A 407 reply surfaces the dedicated "proxy authentication required" error.
    #[test ]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required \r\n\
            Proxy-Authenticate: Basic realm= \"nope \"\r\n\
            \r\n\
            "
        );

        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt" );
        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        let error = rt.block_on(f).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required" );
    }

    // Passing credentials adds a `Proxy-Authorization: Basic ...` header to
    // the CONNECT request; the mock asserts its exact value.
    #[test ]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ== \r\n"
        );

        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt" );
        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(
                tcp,
                host,
                port,
                ua(),
                Some(proxy::encode_basic_auth("Aladdin" , "open sesame" )),
            )
            .await
        };

        rt.block_on(f).unwrap();
    }
}
1574 | |