1#[cfg(feature = "__tls")]
2use http::header::HeaderValue;
3use http::uri::{Authority, Scheme};
4use http::Uri;
5use hyper::client::connect::{Connected, Connection};
6use hyper::service::Service;
7#[cfg(feature = "native-tls-crate")]
8use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
11use pin_project_lite::pin_project;
12use std::future::Future;
13use std::io::{self, IoSlice};
14use std::net::IpAddr;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use std::time::Duration;
19
20#[cfg(feature = "default-tls")]
21use self::native_tls_conn::NativeTlsConn;
22#[cfg(feature = "__rustls")]
23use self::rustls_tls_conn::RustlsTlsConn;
24use crate::dns::DynResolver;
25use crate::error::BoxError;
26use crate::proxy::{Proxy, ProxyScheme};
27
/// TCP connector used underneath every `Connector` backend; host names are
/// resolved through the client's `DynResolver`.
pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>;
29
/// The connector the client hands to hyper.
///
/// It opens TCP (and, with a TLS backend compiled in, TLS) connections to the
/// request URI, routing through the first entry of `proxies` that intercepts
/// the URI.
#[derive(Clone)]
pub(crate) struct Connector {
    // Transport/TLS backend, selected at compile time by feature flags.
    inner: Inner,
    // Consulted in order by `call`; the first proxy that intercepts wins.
    proxies: Arc<Vec<Proxy>>,
    // Whether new connections get the trace-logging wrapper.
    verbose: verbose::Wrapper,
    // Optional bound on the whole connect attempt (applied in `call`).
    timeout: Option<Duration>,
    // Caller's TCP_NODELAY preference. Nagle is forced off during the TLS
    // handshake and restored afterwards when this is `false`.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    // Whether TLS connections should expose `TlsInfo` via `Conn::connected`.
    #[cfg(feature = "__tls")]
    tls_info: bool,
    // Sent as `User-Agent` on CONNECT requests when tunneling through a proxy.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}
43
/// Connection backend, chosen by which TLS features are compiled in.
#[derive(Clone)]
enum Inner {
    /// Plain TCP only; exists when no TLS backend is enabled.
    #[cfg(not(feature = "__tls"))]
    Http(HttpConnector),
    /// TCP plus a native-tls connector.
    #[cfg(feature = "default-tls")]
    DefaultTls(HttpConnector, TlsConnector),
    /// TCP plus rustls client configurations.
    #[cfg(feature = "__rustls")]
    RustlsTls {
        http: HttpConnector,
        // Config for TLS with the destination server, both direct and tunneled
        // through an HTTP CONNECT proxy.
        tls: Arc<rustls::ClientConfig>,
        // Copy of `tls` with ALPN cleared when proxies are configured (see
        // `new_rustls_tls`). Used for the TLS hop to an HTTPS proxy and for TLS
        // over SOCKS.
        tls_proxy: Arc<rustls::ClientConfig>,
    },
}
57
58impl Connector {
59 #[cfg(not(feature = "__tls"))]
60 pub(crate) fn new<T>(
61 mut http: HttpConnector,
62 proxies: Arc<Vec<Proxy>>,
63 local_addr: T,
64 nodelay: bool,
65 ) -> Connector
66 where
67 T: Into<Option<IpAddr>>,
68 {
69 http.set_local_address(local_addr.into());
70 http.set_nodelay(nodelay);
71
72 Connector {
73 inner: Inner::Http(http),
74 verbose: verbose::OFF,
75 proxies,
76 timeout: None,
77 }
78 }
79
80 #[cfg(feature = "default-tls")]
81 pub(crate) fn new_default_tls<T>(
82 http: HttpConnector,
83 tls: TlsConnectorBuilder,
84 proxies: Arc<Vec<Proxy>>,
85 user_agent: Option<HeaderValue>,
86 local_addr: T,
87 nodelay: bool,
88 tls_info: bool,
89 ) -> crate::Result<Connector>
90 where
91 T: Into<Option<IpAddr>>,
92 {
93 let tls = tls.build().map_err(crate::error::builder)?;
94 Ok(Self::from_built_default_tls(
95 http, tls, proxies, user_agent, local_addr, nodelay, tls_info,
96 ))
97 }
98
99 #[cfg(feature = "default-tls")]
100 pub(crate) fn from_built_default_tls<T>(
101 mut http: HttpConnector,
102 tls: TlsConnector,
103 proxies: Arc<Vec<Proxy>>,
104 user_agent: Option<HeaderValue>,
105 local_addr: T,
106 nodelay: bool,
107 tls_info: bool,
108 ) -> Connector
109 where
110 T: Into<Option<IpAddr>>,
111 {
112 http.set_local_address(local_addr.into());
113 http.set_nodelay(nodelay);
114 http.enforce_http(false);
115
116 Connector {
117 inner: Inner::DefaultTls(http, tls),
118 proxies,
119 verbose: verbose::OFF,
120 timeout: None,
121 nodelay,
122 tls_info,
123 user_agent,
124 }
125 }
126
127 #[cfg(feature = "__rustls")]
128 pub(crate) fn new_rustls_tls<T>(
129 mut http: HttpConnector,
130 tls: rustls::ClientConfig,
131 proxies: Arc<Vec<Proxy>>,
132 user_agent: Option<HeaderValue>,
133 local_addr: T,
134 nodelay: bool,
135 tls_info: bool,
136 ) -> Connector
137 where
138 T: Into<Option<IpAddr>>,
139 {
140 http.set_local_address(local_addr.into());
141 http.set_nodelay(nodelay);
142 http.enforce_http(false);
143
144 let (tls, tls_proxy) = if proxies.is_empty() {
145 let tls = Arc::new(tls);
146 (tls.clone(), tls)
147 } else {
148 let mut tls_proxy = tls.clone();
149 tls_proxy.alpn_protocols.clear();
150 (Arc::new(tls), Arc::new(tls_proxy))
151 };
152
153 Connector {
154 inner: Inner::RustlsTls {
155 http,
156 tls,
157 tls_proxy,
158 },
159 proxies,
160 verbose: verbose::OFF,
161 timeout: None,
162 nodelay,
163 tls_info,
164 user_agent,
165 }
166 }
167
168 pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
169 self.timeout = timeout;
170 }
171
172 pub(crate) fn set_verbose(&mut self, enabled: bool) {
173 self.verbose.0 = enabled;
174 }
175
176 #[cfg(feature = "socks")]
177 async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
178 let dns = match proxy {
179 ProxyScheme::Socks5 {
180 remote_dns: false, ..
181 } => socks::DnsResolve::Local,
182 ProxyScheme::Socks5 {
183 remote_dns: true, ..
184 } => socks::DnsResolve::Proxy,
185 ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
186 unreachable!("connect_socks is only called for socks proxies");
187 }
188 };
189
190 match &self.inner {
191 #[cfg(feature = "default-tls")]
192 Inner::DefaultTls(_http, tls) => {
193 if dst.scheme() == Some(&Scheme::HTTPS) {
194 let host = dst.host().ok_or("no host in url")?.to_string();
195 let conn = socks::connect(proxy, dst, dns).await?;
196 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
197 let io = tls_connector.connect(&host, conn).await?;
198 return Ok(Conn {
199 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
200 is_proxy: false,
201 tls_info: self.tls_info,
202 });
203 }
204 }
205 #[cfg(feature = "__rustls")]
206 Inner::RustlsTls { tls_proxy, .. } => {
207 if dst.scheme() == Some(&Scheme::HTTPS) {
208 use std::convert::TryFrom;
209 use tokio_rustls::TlsConnector as RustlsConnector;
210
211 let tls = tls_proxy.clone();
212 let host = dst.host().ok_or("no host in url")?.to_string();
213 let conn = socks::connect(proxy, dst, dns).await?;
214 let server_name = rustls::ServerName::try_from(host.as_str())
215 .map_err(|_| "Invalid Server Name")?;
216 let io = RustlsConnector::from(tls)
217 .connect(server_name, conn)
218 .await?;
219 return Ok(Conn {
220 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
221 is_proxy: false,
222 tls_info: false,
223 });
224 }
225 }
226 #[cfg(not(feature = "__tls"))]
227 Inner::Http(_) => (),
228 }
229
230 socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
231 inner: self.verbose.wrap(tcp),
232 is_proxy: false,
233 tls_info: false,
234 })
235 }
236
237 async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
238 match self.inner {
239 #[cfg(not(feature = "__tls"))]
240 Inner::Http(mut http) => {
241 let io = http.call(dst).await?;
242 Ok(Conn {
243 inner: self.verbose.wrap(io),
244 is_proxy,
245 tls_info: false,
246 })
247 }
248 #[cfg(feature = "default-tls")]
249 Inner::DefaultTls(http, tls) => {
250 let mut http = http.clone();
251
252 // Disable Nagle's algorithm for TLS handshake
253 //
254 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
255 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
256 http.set_nodelay(true);
257 }
258
259 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
260 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
261 let io = http.call(dst).await?;
262
263 if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
264 if !self.nodelay {
265 stream.get_ref().get_ref().get_ref().set_nodelay(false)?;
266 }
267 Ok(Conn {
268 inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
269 is_proxy,
270 tls_info: self.tls_info,
271 })
272 } else {
273 Ok(Conn {
274 inner: self.verbose.wrap(io),
275 is_proxy,
276 tls_info: false,
277 })
278 }
279 }
280 #[cfg(feature = "__rustls")]
281 Inner::RustlsTls { http, tls, .. } => {
282 let mut http = http.clone();
283
284 // Disable Nagle's algorithm for TLS handshake
285 //
286 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
287 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
288 http.set_nodelay(true);
289 }
290
291 let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
292 let io = http.call(dst).await?;
293
294 if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
295 if !self.nodelay {
296 let (io, _) = stream.get_ref();
297 io.set_nodelay(false)?;
298 }
299 Ok(Conn {
300 inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
301 is_proxy,
302 tls_info: self.tls_info,
303 })
304 } else {
305 Ok(Conn {
306 inner: self.verbose.wrap(io),
307 is_proxy,
308 tls_info: false,
309 })
310 }
311 }
312 }
313 }
314
315 async fn connect_via_proxy(
316 self,
317 dst: Uri,
318 proxy_scheme: ProxyScheme,
319 ) -> Result<Conn, BoxError> {
320 log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'");
321
322 let (proxy_dst, _auth) = match proxy_scheme {
323 ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
324 ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
325 #[cfg(feature = "socks")]
326 ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
327 };
328
329 #[cfg(feature = "__tls")]
330 let auth = _auth;
331
332 match &self.inner {
333 #[cfg(feature = "default-tls")]
334 Inner::DefaultTls(http, tls) => {
335 if dst.scheme() == Some(&Scheme::HTTPS) {
336 let host = dst.host().to_owned();
337 let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
338 let http = http.clone();
339 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
340 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
341 let conn = http.call(proxy_dst).await?;
342 log::trace!("tunneling HTTPS over proxy");
343 let tunneled = tunnel(
344 conn,
345 host.ok_or("no host in url")?.to_string(),
346 port,
347 self.user_agent.clone(),
348 auth,
349 )
350 .await?;
351 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
352 let io = tls_connector
353 .connect(host.ok_or("no host in url")?, tunneled)
354 .await?;
355 return Ok(Conn {
356 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
357 is_proxy: false,
358 tls_info: false,
359 });
360 }
361 }
362 #[cfg(feature = "__rustls")]
363 Inner::RustlsTls {
364 http,
365 tls,
366 tls_proxy,
367 } => {
368 if dst.scheme() == Some(&Scheme::HTTPS) {
369 use rustls::ServerName;
370 use std::convert::TryFrom;
371 use tokio_rustls::TlsConnector as RustlsConnector;
372
373 let host = dst.host().ok_or("no host in url")?.to_string();
374 let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
375 let http = http.clone();
376 let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
377 let tls = tls.clone();
378 let conn = http.call(proxy_dst).await?;
379 log::trace!("tunneling HTTPS over proxy");
380 let maybe_server_name =
381 ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name");
382 let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
383 let server_name = maybe_server_name?;
384 let io = RustlsConnector::from(tls)
385 .connect(server_name, tunneled)
386 .await?;
387
388 return Ok(Conn {
389 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
390 is_proxy: false,
391 tls_info: false,
392 });
393 }
394 }
395 #[cfg(not(feature = "__tls"))]
396 Inner::Http(_) => (),
397 }
398
399 self.connect_with_maybe_proxy(proxy_dst, true).await
400 }
401
402 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
403 match &mut self.inner {
404 #[cfg(feature = "default-tls")]
405 Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
406 #[cfg(feature = "__rustls")]
407 Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
408 #[cfg(not(feature = "__tls"))]
409 Inner::Http(http) => http.set_keepalive(dur),
410 }
411 }
412}
413
414fn into_uri(scheme: Scheme, host: Authority) -> Uri {
415 // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`?
416 http::Uri::builder()
417 .scheme(scheme)
418 .authority(host)
419 .path_and_query(http::uri::PathAndQuery::from_static("/"))
420 .build()
421 .expect(msg:"scheme and authority is valid Uri")
422}
423
424async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
425where
426 F: Future<Output = Result<T, BoxError>>,
427{
428 if let Some(to: Duration) = timeout {
429 match tokio::time::timeout(duration:to, future:f).await {
430 Err(_elapsed: Elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
431 Ok(Ok(try_res: T)) => Ok(try_res),
432 Ok(Err(e: Box)) => Err(e),
433 }
434 } else {
435 f.await
436 }
437}
438
439impl Service<Uri> for Connector {
440 type Response = Conn;
441 type Error = BoxError;
442 type Future = Connecting;
443
444 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
445 Poll::Ready(Ok(()))
446 }
447
448 fn call(&mut self, dst: Uri) -> Self::Future {
449 log::debug!("starting new connection: {dst:?}");
450 let timeout = self.timeout;
451 for prox in self.proxies.iter() {
452 if let Some(proxy_scheme) = prox.intercept(&dst) {
453 return Box::pin(with_timeout(
454 self.clone().connect_via_proxy(dst, proxy_scheme),
455 timeout,
456 ));
457 }
458 }
459
460 Box::pin(with_timeout(
461 self.clone().connect_with_maybe_proxy(dst, false),
462 timeout,
463 ))
464 }
465}
466
/// Extracts TLS session details (the peer certificate) from a connection so
/// they can be attached to responses when `tls_info` is enabled.
#[cfg(feature = "__tls")]
trait TlsInfoFactory {
    /// Returns `None` when the connection has no TLS session.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}
471
/// Plain TCP carries no TLS session, so there is never anything to report.
#[cfg(feature = "__tls")]
impl TlsInfoFactory for tokio::net::TcpStream {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        None
    }
}
478
/// Delegates to the TLS stream for `https`. Plain `http` has no TLS info.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<tokio::net::TcpStream> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        match self {
            hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(),
            hyper_tls::MaybeHttpsStream::Http(_) => None,
        }
    }
}
488
/// Reports the peer certificate, DER-encoded.
///
/// Errors while fetching or encoding the certificate are treated as "no
/// certificate" rather than failing.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::TlsStream<hyper_tls::MaybeHttpsStream<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = self
            .get_ref()
            .peer_certificate()
            .ok()
            .flatten()
            .and_then(|c| c.to_der().ok());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
501
/// Reports the peer certificate, DER-encoded.
///
/// Errors while fetching or encoding the certificate are treated as "no
/// certificate" rather than failing.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<tokio::net::TcpStream> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = self
            .get_ref()
            .peer_certificate()
            .ok()
            .flatten()
            .and_then(|c| c.to_der().ok());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
514
/// Only the `Https` variant has a TLS session to report on.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_rustls::MaybeHttpsStream::Https(session) = self {
            session.tls_info()
        } else {
            None
        }
    }
}
524
/// Reports the raw bytes of the first certificate in the peer's chain.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::TlsStream<tokio::net::TcpStream> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|leaf| leaf.0.clone());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
537
/// Reports the raw bytes of the first certificate in the peer's chain.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory
    for tokio_rustls::client::TlsStream<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>>
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|leaf| leaf.0.clone());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
552
/// Reports the raw bytes of the first certificate in the peer's chain.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<tokio::net::TcpStream> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|leaf| leaf.0.clone());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
565
/// Everything the client needs from a connection.
///
/// This covers async reads/writes and hyper `Connection` metadata. The
/// `Send + Sync + Unpin + 'static` bounds let it be boxed as a trait object.
pub(crate) trait AsyncConn:
    AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static
{
}

// Blanket impl: any type meeting the bounds is an `AsyncConn`.
impl<T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}
572
// `AsyncConn`, plus `TlsInfoFactory` when a TLS backend is compiled in, so a
// boxed connection can be asked for TLS details without knowing its type.
#[cfg(feature = "__tls")]
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
#[cfg(not(feature = "__tls"))]
trait AsyncConnWithInfo: AsyncConn {}

// Blanket impls matching the two feature-gated definitions above.
#[cfg(feature = "__tls")]
impl<T: AsyncConn + TlsInfoFactory> AsyncConnWithInfo for T {}
#[cfg(not(feature = "__tls"))]
impl<T: AsyncConn> AsyncConnWithInfo for T {}
582
/// Type-erased connection stored inside `Conn`.
type BoxConn = Box<dyn AsyncConnWithInfo>;
584
pin_project! {
    /// Note: the `is_proxy` member means *is plain text HTTP proxy*.
    /// This tells hyper whether the URI should be written in
    /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
    /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
    pub(crate) struct Conn {
        // The boxed transport: TCP or TLS, possibly in the verbose wrapper.
        #[pin]
        inner: BoxConn,
        is_proxy: bool,
        // Whether `connected()` should attach `TlsInfo` taken from `inner`.
        // Only needed for __tls, but #[cfg()] on fields breaks pin_project!
        tls_info: bool,
    }
}
598
599impl Connection for Conn {
600 fn connected(&self) -> Connected {
601 let connected: Connected = self.inner.connected().proxy(self.is_proxy);
602 #[cfg(feature = "__tls")]
603 if self.tls_info {
604 if let Some(tls_info: TlsInfo) = self.inner.tls_info() {
605 connected.extra(tls_info)
606 } else {
607 connected
608 }
609 } else {
610 connected
611 }
612 #[cfg(not(feature = "__tls"))]
613 connected
614 }
615}
616
617impl AsyncRead for Conn {
618 fn poll_read(
619 self: Pin<&mut Self>,
620 cx: &mut Context,
621 buf: &mut ReadBuf<'_>,
622 ) -> Poll<io::Result<()>> {
623 let this: Projection<'_> = self.project();
624 AsyncRead::poll_read(self:this.inner, cx, buf)
625 }
626}
627
628impl AsyncWrite for Conn {
629 fn poll_write(
630 self: Pin<&mut Self>,
631 cx: &mut Context,
632 buf: &[u8],
633 ) -> Poll<Result<usize, io::Error>> {
634 let this = self.project();
635 AsyncWrite::poll_write(this.inner, cx, buf)
636 }
637
638 fn poll_write_vectored(
639 self: Pin<&mut Self>,
640 cx: &mut Context<'_>,
641 bufs: &[IoSlice<'_>],
642 ) -> Poll<Result<usize, io::Error>> {
643 let this = self.project();
644 AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
645 }
646
647 fn is_write_vectored(&self) -> bool {
648 self.inner.is_write_vectored()
649 }
650
651 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
652 let this = self.project();
653 AsyncWrite::poll_flush(this.inner, cx)
654 }
655
656 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
657 let this = self.project();
658 AsyncWrite::poll_shutdown(this.inner, cx)
659 }
660}
661
/// Boxed, `Send` future returned by `Connector`'s `Service::call`.
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
663
/// Asks an HTTP proxy to open a tunnel to `host:port` with a CONNECT request.
///
/// On a `200` response, `conn` is returned, ready to carry the tunneled bytes.
/// `user_agent` and `auth` are sent as `User-Agent` and
/// `Proxy-Authorization` headers when present.
///
/// # Errors
/// Fails in any of these cases:
/// - an I/O error occurs;
/// - the proxy closes the connection before finishing its response;
/// - the proxy answers `407` ("proxy authentication required");
/// - the proxy answers with any other status or a non-HTTP reply;
/// - the `200` response headers do not fit in the 8 KiB buffer.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Status-line prefixes we act on. They all have the same length, which is
    // the minimum number of bytes needed before a reply can be classified.
    const SUCCESS: [&[u8]; 2] = [b"HTTP/1.1 200", b"HTTP/1.0 200"];
    const AUTH_REQUIRED: [&[u8]; 2] = [b"HTTP/1.1 407", b"HTTP/1.0 407"];
    const STATUS_PREFIX_LEN: usize = 12;

    let mut buf = format!(
        "\
         CONNECT {host}:{port} HTTP/1.1\r\n\
         Host: {host}:{port}\r\n\
         "
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {host}:{port} using basic auth");
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // headers end
    buf.extend_from_slice(b"\r\n");

    conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];

        if recvd.len() < STATUS_PREFIX_LEN {
            // The status line may arrive split across several reads. Keep
            // reading while what we have could still become a known status.
            let may_match = SUCCESS
                .iter()
                .chain(AUTH_REQUIRED.iter())
                .any(|&status| status.starts_with(recvd));
            if may_match {
                continue;
            }
            return Err("unsuccessful tunnel".into());
        }

        if SUCCESS.iter().any(|&status| recvd.starts_with(status)) {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
            // else read more
        } else if AUTH_REQUIRED.iter().any(|&status| recvd.starts_with(status)) {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
732
/// Error reported when the proxy closes the connection before finishing its
/// reply to the CONNECT request.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
737
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use super::TlsInfoFactory;
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_native_tls::TlsStream;

    // A native-tls stream over `T`, adapted to hyper's `Connection` trait and
    // the crate's `TlsInfoFactory`.
    pin_project! {
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
        /// Returns the transport's metadata, marked as HTTP/2 when ALPN
        /// negotiated `h2`.
        #[cfg(feature = "native-tls-alpn")]
        fn connected(&self) -> Connected {
            let tls = self.inner.get_ref();
            let negotiated_h2 =
                matches!(tls.negotiated_alpn(), Ok(Some(proto)) if proto == b"h2");
            let connected = tls.get_ref().get_ref().connected();
            if negotiated_h2 {
                connected.negotiated_h2()
            } else {
                connected
            }
        }

        /// Returns the underlying transport's metadata.
        #[cfg(not(feature = "native-tls-alpn"))]
        fn connected(&self) -> Connected {
            let transport = self.inner.get_ref().get_ref().get_ref();
            transport.connected()
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    // All write operations delegate to the pinned TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            AsyncWrite::is_write_vectored(&self.inner)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }

    impl TlsInfoFactory for NativeTlsConn<tokio::net::TcpStream> {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }

    impl TlsInfoFactory for NativeTlsConn<hyper_tls::MaybeHttpsStream<tokio::net::TcpStream>> {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }
}
841
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use super::TlsInfoFactory;
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_rustls::client::TlsStream;

    // A rustls client stream over `T`, adapted to hyper's `Connection` trait
    // and the crate's `TlsInfoFactory`.
    pin_project! {
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
        /// Returns the transport's metadata, marked as HTTP/2 when ALPN
        /// negotiated `h2`.
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.get_ref();
            let connected = transport.connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    // All write operations delegate to the pinned TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            AsyncWrite::is_write_vectored(&self.inner)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }

    impl TlsInfoFactory for RustlsTlsConn<tokio::net::TcpStream> {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }

    impl TlsInfoFactory for RustlsTlsConn<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>> {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            TlsInfoFactory::tls_info(&self.inner)
        }
    }
}
934
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::Socks5Stream;

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name gets resolved.
    pub(super) enum DnsResolve {
        /// Resolve locally and hand the proxy an IP address.
        Local,
        /// Hand the proxy the host name and let it resolve it.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS5 proxy in `proxy`.
    ///
    /// The port defaults to 443 for `https` and 80 otherwise. With
    /// `DnsResolve::Local`, the host is resolved here and the first address is
    /// sent to the proxy. If resolution yields no addresses, the host name is
    /// sent unchanged.
    ///
    /// # Errors
    /// Fails if `dst` has no host, local DNS resolution fails, or the SOCKS
    /// handshake fails.
    ///
    /// # Panics
    /// Panics if `proxy` is not `ProxyScheme::Socks5`.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Resolve asynchronously: `std::net::ToSocketAddrs` would block
            // the executor thread for the whole lookup.
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port))
                .await?
                .next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        let (socket_addr, auth) = match proxy {
            ProxyScheme::Socks5 { addr, auth, .. } => (addr, auth),
            _ => unreachable!(),
        };

        // Get a Tokio TcpStream
        let stream = if let Some((username, password)) = auth {
            Socks5Stream::connect_with_password(
                socket_addr,
                (host.as_str(), port),
                &username,
                &password,
            )
            .await
            .map_err(|e| format!("socks connect error: {e}"))?
        } else {
            Socks5Stream::connect(socket_addr, (host.as_str(), port))
                .await
                .map_err(|e| format!("socks connect error: {e}"))?
        };

        Ok(stream.into_inner())
    }
}
999
1000mod verbose {
1001 use hyper::client::connect::{Connected, Connection};
1002 use std::cmp::min;
1003 use std::fmt;
1004 use std::io::{self, IoSlice};
1005 use std::pin::Pin;
1006 use std::task::{Context, Poll};
1007 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1008
1009 pub(super) const OFF: Wrapper = Wrapper(false);
1010
1011 #[derive(Clone, Copy)]
1012 pub(super) struct Wrapper(pub(super) bool);
1013
1014 impl Wrapper {
1015 pub(super) fn wrap<T: super::AsyncConnWithInfo>(&self, conn: T) -> super::BoxConn {
1016 if self.0 && log::log_enabled!(log::Level::Trace) {
1017 Box::new(Verbose {
1018 // truncate is fine
1019 id: crate::util::fast_random() as u32,
1020 inner: conn,
1021 })
1022 } else {
1023 Box::new(conn)
1024 }
1025 }
1026 }
1027
1028 struct Verbose<T> {
1029 id: u32,
1030 inner: T,
1031 }
1032
1033 impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for Verbose<T> {
1034 fn connected(&self) -> Connected {
1035 self.inner.connected()
1036 }
1037 }
1038
1039 impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for Verbose<T> {
1040 fn poll_read(
1041 mut self: Pin<&mut Self>,
1042 cx: &mut Context,
1043 buf: &mut ReadBuf<'_>,
1044 ) -> Poll<std::io::Result<()>> {
1045 match Pin::new(&mut self.inner).poll_read(cx, buf) {
1046 Poll::Ready(Ok(())) => {
1047 log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled()));
1048 Poll::Ready(Ok(()))
1049 }
1050 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1051 Poll::Pending => Poll::Pending,
1052 }
1053 }
1054 }
1055
1056 impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Verbose<T> {
1057 fn poll_write(
1058 mut self: Pin<&mut Self>,
1059 cx: &mut Context,
1060 buf: &[u8],
1061 ) -> Poll<Result<usize, std::io::Error>> {
1062 match Pin::new(&mut self.inner).poll_write(cx, buf) {
1063 Poll::Ready(Ok(n)) => {
1064 log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
1065 Poll::Ready(Ok(n))
1066 }
1067 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1068 Poll::Pending => Poll::Pending,
1069 }
1070 }
1071
1072 fn poll_write_vectored(
1073 mut self: Pin<&mut Self>,
1074 cx: &mut Context<'_>,
1075 bufs: &[IoSlice<'_>],
1076 ) -> Poll<Result<usize, io::Error>> {
1077 match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) {
1078 Poll::Ready(Ok(nwritten)) => {
1079 log::trace!(
1080 "{:08x} write (vectored): {:?}",
1081 self.id,
1082 Vectored { bufs, nwritten }
1083 );
1084 Poll::Ready(Ok(nwritten))
1085 }
1086 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1087 Poll::Pending => Poll::Pending,
1088 }
1089 }
1090
1091 fn is_write_vectored(&self) -> bool {
1092 self.inner.is_write_vectored()
1093 }
1094
1095 fn poll_flush(
1096 mut self: Pin<&mut Self>,
1097 cx: &mut Context,
1098 ) -> Poll<Result<(), std::io::Error>> {
1099 Pin::new(&mut self.inner).poll_flush(cx)
1100 }
1101
1102 fn poll_shutdown(
1103 mut self: Pin<&mut Self>,
1104 cx: &mut Context,
1105 ) -> Poll<Result<(), std::io::Error>> {
1106 Pin::new(&mut self.inner).poll_shutdown(cx)
1107 }
1108 }
1109
1110 #[cfg(feature = "__tls")]
1111 impl<T: super::TlsInfoFactory> super::TlsInfoFactory for Verbose<T> {
1112 fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
1113 self.inner.tls_info()
1114 }
1115 }
1116
1117 struct Escape<'a>(&'a [u8]);
1118
1119 impl fmt::Debug for Escape<'_> {
1120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1121 write!(f, "b\"")?;
1122 for &c in self.0 {
1123 // https://doc.rust-lang.org/reference.html#byte-escapes
1124 if c == b'\n' {
1125 write!(f, "\\n")?;
1126 } else if c == b'\r' {
1127 write!(f, "\\r")?;
1128 } else if c == b'\t' {
1129 write!(f, "\\t")?;
1130 } else if c == b'\\' || c == b'"' {
1131 write!(f, "\\{}", c as char)?;
1132 } else if c == b'\0' {
1133 write!(f, "\\0")?;
1134 // ASCII printable
1135 } else if c >= 0x20 && c < 0x7f {
1136 write!(f, "{}", c as char)?;
1137 } else {
1138 write!(f, "\\x{c:02x}")?;
1139 }
1140 }
1141 write!(f, "\"")?;
1142 Ok(())
1143 }
1144 }
1145
1146 struct Vectored<'a, 'b> {
1147 bufs: &'a [IoSlice<'b>],
1148 nwritten: usize,
1149 }
1150
1151 impl fmt::Debug for Vectored<'_, '_> {
1152 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1153 let mut left = self.nwritten;
1154 for buf in self.bufs.iter() {
1155 if left == 0 {
1156 break;
1157 }
1158 let n = min(left, buf.len());
1159 Escape(&buf[..n]).fmt(f)?;
1160 left -= n;
1161 }
1162 Ok(())
1163 }
1164 }
1165}
1166
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::{tunnel, BoxError};
    use crate::proxy;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener};
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    /// Starts a one-shot mock proxy on a loopback port and returns its address.
    ///
    /// The mock reads the client's CONNECT request up to the blank line that
    /// ends the headers. It asserts that the request matches the expected one,
    /// including the optional `auth` header line, and then writes `response`.
    fn mock_tunnel(response: &'static [u8], auth: &'static str) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_expected = format!(
            "\
             CONNECT {0}:{1} HTTP/1.1\r\n\
             Host: {0}:{1}\r\n\
             User-Agent: {2}\r\n\
             {3}\
             \r\n\
             ",
            addr.ip(),
            addr.port(),
            TUNNEL_UA,
            auth
        )
        .into_bytes();

        thread::spawn(move || {
            let (mut sock, _) = listener.accept().unwrap();
            // A single `read` may return only part of the request, so keep
            // reading until the end-of-headers marker (or EOF).
            let mut received: Vec<u8> = Vec::new();
            let mut chunk = [0u8; 4096];
            while !received.ends_with(b"\r\n\r\n") {
                let n = sock.read(&mut chunk).unwrap();
                if n == 0 {
                    break;
                }
                received.extend_from_slice(&chunk[..n]);
            }
            assert_eq!(received, connect_expected);

            sock.write_all(response).unwrap();
        });
        addr
    }

    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Connects to the mock proxy at `addr` and runs `tunnel` on a fresh
    /// current-thread runtime.
    ///
    /// The tunneled stream is dropped inside the runtime; only success or the
    /// error is returned.
    fn run_tunnel(
        addr: SocketAddr,
        auth: Option<http::header::HeaderValue>,
    ) -> Result<(), BoxError> {
        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt");
        rt.block_on(async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), auth).await.map(drop)
        })
    }

    #[test]
    fn test_tunnel() {
        let addr = mock_tunnel(TUNNEL_OK, "");
        run_tunnel(addr, None).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        let addr = mock_tunnel(b"HTTP/1.1 200 OK", "");
        run_tunnel(addr, None).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel(b"foo bar baz hallo", "");
        run_tunnel(addr, None).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
            ",
            "",
        );

        let error = run_tunnel(addr, None).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n",
        );

        run_tunnel(
            addr,
            Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
        )
        .unwrap();
    }
}
1332