1#[cfg(feature = "__tls")]
2use http::header::HeaderValue;
3use http::uri::{Authority, Scheme};
4use http::Uri;
5use hyper::client::connect::{Connected, Connection};
6use hyper::service::Service;
7#[cfg(feature = "native-tls-crate")]
8use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
11use pin_project_lite::pin_project;
12use std::future::Future;
13use std::io::{self, IoSlice};
14use std::net::IpAddr;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use std::time::Duration;
19
20#[cfg(feature = "default-tls")]
21use self::native_tls_conn::NativeTlsConn;
22#[cfg(feature = "__rustls")]
23use self::rustls_tls_conn::RustlsTlsConn;
24use crate::dns::DynResolver;
25use crate::error::BoxError;
26use crate::proxy::{Proxy, ProxyScheme};
27
/// hyper's TCP connector, parameterized with the crate's `DynResolver`
/// (see `crate::dns`) for host-name resolution.
pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>;
29
/// Establishes outgoing connections for the client.
///
/// Implements `Service<Uri>`: each call dials the destination directly or
/// through the first matching proxy, optionally layering TLS on top, and
/// yields a `Conn`.
#[derive(Clone)]
pub(crate) struct Connector {
    /// Transport backend; which variants exist depends on the enabled TLS features.
    inner: Inner,
    /// Proxies consulted in order on each call; the first whose `intercept`
    /// matches the destination is used.
    proxies: Arc<Vec<Proxy>>,
    /// Whether connections get the trace-logging wrapper from `verbose`.
    verbose: verbose::Wrapper,
    /// Optional deadline applied to the whole connect future.
    timeout: Option<Duration>,
    /// Caller's TCP_NODELAY preference, consulted when connecting with a TLS backend.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    /// Sent as the `User-Agent` header of proxy `CONNECT` requests.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}
41
/// Transport backend chosen at construction time; the available variants
/// depend on which TLS features are compiled in.
#[derive(Clone)]
enum Inner {
    /// Plain TCP only (built without any TLS backend).
    #[cfg(not(feature = "__tls"))]
    Http(HttpConnector),
    /// native-tls backend: the TCP connector plus an already-built `TlsConnector`.
    #[cfg(feature = "default-tls")]
    DefaultTls(HttpConnector, TlsConnector),
    /// rustls backend.
    #[cfg(feature = "__rustls")]
    RustlsTls {
        /// TCP connector used underneath TLS.
        http: HttpConnector,
        /// Config for direct TLS connections and for TLS inside a proxy tunnel.
        tls: Arc<rustls::ClientConfig>,
        /// Config for the TLS hop to an HTTPS proxy and for HTTPS over SOCKS.
        /// The same `Arc` as `tls` when no proxies are configured; otherwise a
        /// copy of it with `alpn_protocols` cleared.
        tls_proxy: Arc<rustls::ClientConfig>,
    },
}
55
56impl Connector {
57 #[cfg(not(feature = "__tls"))]
58 pub(crate) fn new<T>(
59 mut http: HttpConnector,
60 proxies: Arc<Vec<Proxy>>,
61 local_addr: T,
62 nodelay: bool,
63 ) -> Connector
64 where
65 T: Into<Option<IpAddr>>,
66 {
67 http.set_local_address(local_addr.into());
68 http.set_nodelay(nodelay);
69 Connector {
70 inner: Inner::Http(http),
71 verbose: verbose::OFF,
72 proxies,
73 timeout: None,
74 }
75 }
76
77 #[cfg(feature = "default-tls")]
78 pub(crate) fn new_default_tls<T>(
79 http: HttpConnector,
80 tls: TlsConnectorBuilder,
81 proxies: Arc<Vec<Proxy>>,
82 user_agent: Option<HeaderValue>,
83 local_addr: T,
84 nodelay: bool,
85 ) -> crate::Result<Connector>
86 where
87 T: Into<Option<IpAddr>>,
88 {
89 let tls = tls.build().map_err(crate::error::builder)?;
90 Ok(Self::from_built_default_tls(
91 http, tls, proxies, user_agent, local_addr, nodelay,
92 ))
93 }
94
95 #[cfg(feature = "default-tls")]
96 pub(crate) fn from_built_default_tls<T>(
97 mut http: HttpConnector,
98 tls: TlsConnector,
99 proxies: Arc<Vec<Proxy>>,
100 user_agent: Option<HeaderValue>,
101 local_addr: T,
102 nodelay: bool,
103 ) -> Connector
104 where
105 T: Into<Option<IpAddr>>,
106 {
107 http.set_local_address(local_addr.into());
108 http.enforce_http(false);
109
110 Connector {
111 inner: Inner::DefaultTls(http, tls),
112 proxies,
113 verbose: verbose::OFF,
114 timeout: None,
115 nodelay,
116 user_agent,
117 }
118 }
119
120 #[cfg(feature = "__rustls")]
121 pub(crate) fn new_rustls_tls<T>(
122 mut http: HttpConnector,
123 tls: rustls::ClientConfig,
124 proxies: Arc<Vec<Proxy>>,
125 user_agent: Option<HeaderValue>,
126 local_addr: T,
127 nodelay: bool,
128 ) -> Connector
129 where
130 T: Into<Option<IpAddr>>,
131 {
132 http.set_local_address(local_addr.into());
133 http.enforce_http(false);
134
135 let (tls, tls_proxy) = if proxies.is_empty() {
136 let tls = Arc::new(tls);
137 (tls.clone(), tls)
138 } else {
139 let mut tls_proxy = tls.clone();
140 tls_proxy.alpn_protocols.clear();
141 (Arc::new(tls), Arc::new(tls_proxy))
142 };
143
144 Connector {
145 inner: Inner::RustlsTls {
146 http,
147 tls,
148 tls_proxy,
149 },
150 proxies,
151 verbose: verbose::OFF,
152 timeout: None,
153 nodelay,
154 user_agent,
155 }
156 }
157
158 pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
159 self.timeout = timeout;
160 }
161
162 pub(crate) fn set_verbose(&mut self, enabled: bool) {
163 self.verbose.0 = enabled;
164 }
165
166 #[cfg(feature = "socks")]
167 async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
168 let dns = match proxy {
169 ProxyScheme::Socks5 {
170 remote_dns: false, ..
171 } => socks::DnsResolve::Local,
172 ProxyScheme::Socks5 {
173 remote_dns: true, ..
174 } => socks::DnsResolve::Proxy,
175 ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
176 unreachable!("connect_socks is only called for socks proxies");
177 }
178 };
179
180 match &self.inner {
181 #[cfg(feature = "default-tls")]
182 Inner::DefaultTls(_http, tls) => {
183 if dst.scheme() == Some(&Scheme::HTTPS) {
184 let host = dst.host().ok_or("no host in url")?.to_string();
185 let conn = socks::connect(proxy, dst, dns).await?;
186 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
187 let io = tls_connector.connect(&host, conn).await?;
188 return Ok(Conn {
189 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
190 is_proxy: false,
191 });
192 }
193 }
194 #[cfg(feature = "__rustls")]
195 Inner::RustlsTls { tls_proxy, .. } => {
196 if dst.scheme() == Some(&Scheme::HTTPS) {
197 use std::convert::TryFrom;
198 use tokio_rustls::TlsConnector as RustlsConnector;
199
200 let tls = tls_proxy.clone();
201 let host = dst.host().ok_or("no host in url")?.to_string();
202 let conn = socks::connect(proxy, dst, dns).await?;
203 let server_name = rustls::ServerName::try_from(host.as_str())
204 .map_err(|_| "Invalid Server Name")?;
205 let io = RustlsConnector::from(tls)
206 .connect(server_name, conn)
207 .await?;
208 return Ok(Conn {
209 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
210 is_proxy: false,
211 });
212 }
213 }
214 #[cfg(not(feature = "__tls"))]
215 Inner::Http(_) => (),
216 }
217
218 socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
219 inner: self.verbose.wrap(tcp),
220 is_proxy: false,
221 })
222 }
223
224 async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
225 match self.inner {
226 #[cfg(not(feature = "__tls"))]
227 Inner::Http(mut http) => {
228 let io = http.call(dst).await?;
229 Ok(Conn {
230 inner: self.verbose.wrap(io),
231 is_proxy,
232 })
233 }
234 #[cfg(feature = "default-tls")]
235 Inner::DefaultTls(http, tls) => {
236 let mut http = http.clone();
237
238 // Disable Nagle's algorithm for TLS handshake
239 //
240 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
241 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
242 http.set_nodelay(true);
243 }
244
245 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
246 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
247 let io = http.call(dst).await?;
248
249 if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
250 if !self.nodelay {
251 stream.get_ref().get_ref().get_ref().set_nodelay(false)?;
252 }
253 Ok(Conn {
254 inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
255 is_proxy,
256 })
257 } else {
258 Ok(Conn {
259 inner: self.verbose.wrap(io),
260 is_proxy,
261 })
262 }
263 }
264 #[cfg(feature = "__rustls")]
265 Inner::RustlsTls { http, tls, .. } => {
266 let mut http = http.clone();
267
268 // Disable Nagle's algorithm for TLS handshake
269 //
270 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
271 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
272 http.set_nodelay(true);
273 }
274
275 let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
276 let io = http.call(dst).await?;
277
278 if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
279 if !self.nodelay {
280 let (io, _) = stream.get_ref();
281 io.set_nodelay(false)?;
282 }
283 Ok(Conn {
284 inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
285 is_proxy,
286 })
287 } else {
288 Ok(Conn {
289 inner: self.verbose.wrap(io),
290 is_proxy,
291 })
292 }
293 }
294 }
295 }
296
297 async fn connect_via_proxy(
298 self,
299 dst: Uri,
300 proxy_scheme: ProxyScheme,
301 ) -> Result<Conn, BoxError> {
302 log::debug!("proxy({:?}) intercepts '{:?}'", proxy_scheme, dst);
303
304 let (proxy_dst, _auth) = match proxy_scheme {
305 ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
306 ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
307 #[cfg(feature = "socks")]
308 ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
309 };
310
311 #[cfg(feature = "__tls")]
312 let auth = _auth;
313
314 match &self.inner {
315 #[cfg(feature = "default-tls")]
316 Inner::DefaultTls(http, tls) => {
317 if dst.scheme() == Some(&Scheme::HTTPS) {
318 let host = dst.host().to_owned();
319 let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
320 let http = http.clone();
321 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
322 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
323 let conn = http.call(proxy_dst).await?;
324 log::trace!("tunneling HTTPS over proxy");
325 let tunneled = tunnel(
326 conn,
327 host.ok_or("no host in url")?.to_string(),
328 port,
329 self.user_agent.clone(),
330 auth,
331 )
332 .await?;
333 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
334 let io = tls_connector
335 .connect(host.ok_or("no host in url")?, tunneled)
336 .await?;
337 return Ok(Conn {
338 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
339 is_proxy: false,
340 });
341 }
342 }
343 #[cfg(feature = "__rustls")]
344 Inner::RustlsTls {
345 http,
346 tls,
347 tls_proxy,
348 } => {
349 if dst.scheme() == Some(&Scheme::HTTPS) {
350 use rustls::ServerName;
351 use std::convert::TryFrom;
352 use tokio_rustls::TlsConnector as RustlsConnector;
353
354 let host = dst.host().ok_or("no host in url")?.to_string();
355 let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
356 let http = http.clone();
357 let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
358 let tls = tls.clone();
359 let conn = http.call(proxy_dst).await?;
360 log::trace!("tunneling HTTPS over proxy");
361 let maybe_server_name =
362 ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name");
363 let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
364 let server_name = maybe_server_name?;
365 let io = RustlsConnector::from(tls)
366 .connect(server_name, tunneled)
367 .await?;
368
369 return Ok(Conn {
370 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
371 is_proxy: false,
372 });
373 }
374 }
375 #[cfg(not(feature = "__tls"))]
376 Inner::Http(_) => (),
377 }
378
379 self.connect_with_maybe_proxy(proxy_dst, true).await
380 }
381
382 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
383 match &mut self.inner {
384 #[cfg(feature = "default-tls")]
385 Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
386 #[cfg(feature = "__rustls")]
387 Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
388 #[cfg(not(feature = "__tls"))]
389 Inner::Http(http) => http.set_keepalive(dur),
390 }
391 }
392}
393
394fn into_uri(scheme: Scheme, host: Authority) -> Uri {
395 // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`?
396 http::Uri::builder()
397 .scheme(scheme)
398 .authority(host)
399 .path_and_query(http::uri::PathAndQuery::from_static("/"))
400 .build()
401 .expect(msg:"scheme and authority is valid Uri")
402}
403
404async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
405where
406 F: Future<Output = Result<T, BoxError>>,
407{
408 if let Some(to: Duration) = timeout {
409 match tokio::time::timeout(duration:to, future:f).await {
410 Err(_elapsed: Elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
411 Ok(Ok(try_res: T)) => Ok(try_res),
412 Ok(Err(e: Box)) => Err(e),
413 }
414 } else {
415 f.await
416 }
417}
418
419impl Service<Uri> for Connector {
420 type Response = Conn;
421 type Error = BoxError;
422 type Future = Connecting;
423
424 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
425 Poll::Ready(Ok(()))
426 }
427
428 fn call(&mut self, dst: Uri) -> Self::Future {
429 log::debug!("starting new connection: {:?}", dst);
430 let timeout = self.timeout;
431 for prox in self.proxies.iter() {
432 if let Some(proxy_scheme) = prox.intercept(&dst) {
433 return Box::pin(with_timeout(
434 self.clone().connect_via_proxy(dst, proxy_scheme),
435 timeout,
436 ));
437 }
438 }
439
440 Box::pin(with_timeout(
441 self.clone().connect_with_maybe_proxy(dst, false),
442 timeout,
443 ))
444 }
445}
446
/// The trait set required of every connection this module hands out: async
/// read/write, `Connection` metadata, and `Send + Sync + Unpin + 'static`
/// so it can be boxed as a trait object (`BoxConn`).
pub(crate) trait AsyncConn:
    AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static
{
}

// Blanket impl: any type meeting the bounds is automatically an `AsyncConn`.
impl<T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}
453
/// Type-erased connection stored inside `Conn`.
type BoxConn = Box<dyn AsyncConn>;
455
pin_project! {
    /// A type-erased connection returned by `Connector`.
    ///
    /// Note: the `is_proxy` member means *is plain text HTTP proxy*.
    /// This tells hyper whether the URI should be written in
    /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
    /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
    pub(crate) struct Conn {
        // Underlying transport: raw TCP/SOCKS or TLS, possibly wrapped for
        // verbose logging.
        #[pin]
        inner: BoxConn,
        // Reported to hyper through `Connected::proxy` in the `Connection` impl.
        is_proxy: bool,
    }
}
467
468impl Connection for Conn {
469 fn connected(&self) -> Connected {
470 self.inner.connected().proxy(self.is_proxy)
471 }
472}
473
474impl AsyncRead for Conn {
475 fn poll_read(
476 self: Pin<&mut Self>,
477 cx: &mut Context,
478 buf: &mut ReadBuf<'_>,
479 ) -> Poll<io::Result<()>> {
480 let this: Projection<'_> = self.project();
481 AsyncRead::poll_read(self:this.inner, cx, buf)
482 }
483}
484
485impl AsyncWrite for Conn {
486 fn poll_write(
487 self: Pin<&mut Self>,
488 cx: &mut Context,
489 buf: &[u8],
490 ) -> Poll<Result<usize, io::Error>> {
491 let this = self.project();
492 AsyncWrite::poll_write(this.inner, cx, buf)
493 }
494
495 fn poll_write_vectored(
496 self: Pin<&mut Self>,
497 cx: &mut Context<'_>,
498 bufs: &[IoSlice<'_>],
499 ) -> Poll<Result<usize, io::Error>> {
500 let this = self.project();
501 AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
502 }
503
504 fn is_write_vectored(&self) -> bool {
505 self.inner.is_write_vectored()
506 }
507
508 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
509 let this = self.project();
510 AsyncWrite::poll_flush(this.inner, cx)
511 }
512
513 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
514 let this = self.project();
515 AsyncWrite::poll_shutdown(this.inner, cx)
516 }
517}
518
/// Boxed, `Send` future resolving to a `Conn`; the `Future` type of
/// `Connector`'s `Service<Uri>` impl.
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
520
/// Performs an HTTP `CONNECT` handshake over `conn`, asking the proxy to open
/// a tunnel to `host:port`, and returns `conn` once the tunnel is up.
///
/// The request always carries `Host`, plus `User-Agent` and
/// `Proxy-Authorization` when `user_agent` / `auth` are given. Success is an
/// `HTTP/1.x 200` status line whose received bytes end with the blank line
/// `\r\n\r\n` within the 8 KiB read buffer.
///
/// # Errors
/// * the proxy closes the connection before the response is complete;
/// * the response headers do not fit in the 8 KiB buffer;
/// * the proxy answers `407` (authentication required);
/// * any other status line;
/// * I/O errors from `conn`.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Bytes needed to classify the response: `HTTP/1.x NNN`.
    const STATUS_PREFIX_LEN: usize = 12;

    let mut buf = format!(
        "\
         CONNECT {0}:{1} HTTP/1.1\r\n\
         Host: {0}:{1}\r\n\
         ",
        host, port
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {}:{} using basic auth", host, port);
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // headers end
    buf.extend_from_slice(b"\r\n");

    conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];
        // A read may return only part of the status line (e.g. `HTTP/1.1 2`);
        // keep reading instead of misclassifying it as a failed tunnel.
        // `pos` is far below the buffer size here, so the next read has room.
        if recvd.len() < STATUS_PREFIX_LEN {
            continue;
        }

        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
            // else read more
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
590
/// Error for a proxy that closes the connection before finishing its
/// `CONNECT` response.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
595
/// Adapter that lets a `tokio_native_tls` stream report hyper `Connected`
/// metadata and be used as an `AsyncConn`.
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_native_tls::TlsStream;

    pin_project! {
        /// A native-tls stream over some transport `T`.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
        /// Reports the transport's metadata, marked as HTTP/2 when ALPN
        /// negotiated `h2`.
        #[cfg(feature = "native-tls-alpn")]
        fn connected(&self) -> Connected {
            let negotiated_h2 = matches!(
                self.inner.get_ref().negotiated_alpn(),
                Ok(Some(protocol)) if protocol == b"h2"
            );
            let transport = self.inner.get_ref().get_ref().get_ref().connected();
            if negotiated_h2 {
                transport.negotiated_h2()
            } else {
                transport
            }
        }

        /// Reports the underlying transport's metadata unchanged.
        #[cfg(not(feature = "native-tls-alpn"))]
        fn connected(&self) -> Connected {
            let transport = self.inner.get_ref().get_ref().get_ref();
            transport.connected()
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            self.project().inner.poll_read(cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            self.project().inner.poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            self.project().inner.poll_write_vectored(cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            self.project().inner.poll_flush(cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            self.project().inner.poll_shutdown(cx)
        }
    }
}
686
/// Adapter that lets a `tokio_rustls` client stream report hyper `Connected`
/// metadata and be used as an `AsyncConn`.
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_rustls::client::TlsStream;

    pin_project! {
        /// A rustls client stream over some transport `T`.
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
        /// Reports the transport's metadata, marked as HTTP/2 when ALPN
        /// negotiated `h2`.
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.get_ref();
            let negotiated_h2 = session.alpn_protocol() == Some(b"h2");
            let connected = transport.connected();
            if negotiated_h2 {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            self.project().inner.poll_read(cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            self.project().inner.poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            self.project().inner.poll_write_vectored(cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            self.project().inner.poll_flush(cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            self.project().inner.poll_shutdown(cx)
        }
    }
}
766
/// SOCKS5 proxy support.
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::Socks5Stream;

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name is resolved.
    pub(super) enum DnsResolve {
        /// Resolve locally and send the proxy an IP address (falling back to
        /// the host name if the lookup yields no address).
        Local,
        /// Send the host name to the proxy unresolved.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS5 proxy in `proxy`.
    ///
    /// The destination port defaults to 443 for `https` and 80 otherwise.
    ///
    /// # Errors
    /// Fails if `dst` has no host, if local resolution fails, or if the SOCKS
    /// connection/handshake fails.
    ///
    /// # Panics
    /// Panics if `proxy` is not `ProxyScheme::Socks5`.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        // `Uri::host` keeps the brackets around IPv6 literals (`[::1]`). They
        // are URI syntax, not part of the address, and prevent it from being
        // parsed as an IP.
        let original_host = original_host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(original_host);
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Resolve through tokio so the lookup does not block the async
            // executor thread (std's `ToSocketAddrs` is synchronous).
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port))
                .await?
                .next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        let (socket_addr, auth) = match proxy {
            ProxyScheme::Socks5 { addr, auth, .. } => (addr, auth),
            _ => unreachable!(),
        };

        // Get a Tokio TcpStream
        let stream = if let Some((username, password)) = auth {
            Socks5Stream::connect_with_password(
                socket_addr,
                (host.as_str(), port),
                &username,
                &password,
            )
            .await
            .map_err(|e| format!("socks connect error: {}", e))?
        } else {
            Socks5Stream::connect(socket_addr, (host.as_str(), port))
                .await
                .map_err(|e| format!("socks connect error: {}", e))?
        };

        Ok(stream.into_inner())
    }
}
831
832mod verbose {
833 use hyper::client::connect::{Connected, Connection};
834 use std::cmp::min;
835 use std::fmt;
836 use std::io::{self, IoSlice};
837 use std::pin::Pin;
838 use std::task::{Context, Poll};
839 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
840
841 pub(super) const OFF: Wrapper = Wrapper(false);
842
843 #[derive(Clone, Copy)]
844 pub(super) struct Wrapper(pub(super) bool);
845
846 impl Wrapper {
847 pub(super) fn wrap<T: super::AsyncConn>(&self, conn: T) -> super::BoxConn {
848 if self.0 && log::log_enabled!(log::Level::Trace) {
849 Box::new(Verbose {
850 // truncate is fine
851 id: crate::util::fast_random() as u32,
852 inner: conn,
853 })
854 } else {
855 Box::new(conn)
856 }
857 }
858 }
859
860 struct Verbose<T> {
861 id: u32,
862 inner: T,
863 }
864
865 impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for Verbose<T> {
866 fn connected(&self) -> Connected {
867 self.inner.connected()
868 }
869 }
870
871 impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for Verbose<T> {
872 fn poll_read(
873 mut self: Pin<&mut Self>,
874 cx: &mut Context,
875 buf: &mut ReadBuf<'_>,
876 ) -> Poll<std::io::Result<()>> {
877 match Pin::new(&mut self.inner).poll_read(cx, buf) {
878 Poll::Ready(Ok(())) => {
879 log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled()));
880 Poll::Ready(Ok(()))
881 }
882 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
883 Poll::Pending => Poll::Pending,
884 }
885 }
886 }
887
888 impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Verbose<T> {
889 fn poll_write(
890 mut self: Pin<&mut Self>,
891 cx: &mut Context,
892 buf: &[u8],
893 ) -> Poll<Result<usize, std::io::Error>> {
894 match Pin::new(&mut self.inner).poll_write(cx, buf) {
895 Poll::Ready(Ok(n)) => {
896 log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
897 Poll::Ready(Ok(n))
898 }
899 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
900 Poll::Pending => Poll::Pending,
901 }
902 }
903
904 fn poll_write_vectored(
905 mut self: Pin<&mut Self>,
906 cx: &mut Context<'_>,
907 bufs: &[IoSlice<'_>],
908 ) -> Poll<Result<usize, io::Error>> {
909 match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) {
910 Poll::Ready(Ok(nwritten)) => {
911 log::trace!(
912 "{:08x} write (vectored): {:?}",
913 self.id,
914 Vectored { bufs, nwritten }
915 );
916 Poll::Ready(Ok(nwritten))
917 }
918 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
919 Poll::Pending => Poll::Pending,
920 }
921 }
922
923 fn is_write_vectored(&self) -> bool {
924 self.inner.is_write_vectored()
925 }
926
927 fn poll_flush(
928 mut self: Pin<&mut Self>,
929 cx: &mut Context,
930 ) -> Poll<Result<(), std::io::Error>> {
931 Pin::new(&mut self.inner).poll_flush(cx)
932 }
933
934 fn poll_shutdown(
935 mut self: Pin<&mut Self>,
936 cx: &mut Context,
937 ) -> Poll<Result<(), std::io::Error>> {
938 Pin::new(&mut self.inner).poll_shutdown(cx)
939 }
940 }
941
942 struct Escape<'a>(&'a [u8]);
943
944 impl fmt::Debug for Escape<'_> {
945 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
946 write!(f, "b\"")?;
947 for &c in self.0 {
948 // https://doc.rust-lang.org/reference.html#byte-escapes
949 if c == b'\n' {
950 write!(f, "\\n")?;
951 } else if c == b'\r' {
952 write!(f, "\\r")?;
953 } else if c == b'\t' {
954 write!(f, "\\t")?;
955 } else if c == b'\\' || c == b'"' {
956 write!(f, "\\{}", c as char)?;
957 } else if c == b'\0' {
958 write!(f, "\\0")?;
959 // ASCII printable
960 } else if c >= 0x20 && c < 0x7f {
961 write!(f, "{}", c as char)?;
962 } else {
963 write!(f, "\\x{:02x}", c)?;
964 }
965 }
966 write!(f, "\"")?;
967 Ok(())
968 }
969 }
970
971 struct Vectored<'a, 'b> {
972 bufs: &'a [IoSlice<'b>],
973 nwritten: usize,
974 }
975
976 impl fmt::Debug for Vectored<'_, '_> {
977 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
978 let mut left = self.nwritten;
979 for buf in self.bufs.iter() {
980 if left == 0 {
981 break;
982 }
983 let n = min(left, buf.len());
984 Escape(&buf[..n]).fmt(f)?;
985 left -= n;
986 }
987 Ok(())
988 }
989 }
990}
991
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::{tunnel, BoxError};
    use crate::proxy;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener};
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    /// Spawns a one-shot mock proxy that asserts it receives the expected
    /// `CONNECT` request (with optional extra header text `$auth`) and then
    /// replies with `$write`. Evaluates to the proxy's address.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let connect_expected = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                let mut buf = [0u8; 4096];
                let n = sock.read(&mut buf).unwrap();
                assert_eq!(&buf[..n], &connect_expected[..]);

                sock.write_all($write).unwrap();
            });
            addr
        }};
    }

    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Connects to the mock proxy at `addr` on a fresh current-thread runtime
    /// and runs `tunnel` with the test user agent and `auth`. The tunneled
    /// stream is dropped while the runtime is still alive.
    fn run_tunnel(
        addr: SocketAddr,
        auth: Option<http::header::HeaderValue>,
    ) -> Result<(), BoxError> {
        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt");
        rt.block_on(async move {
            let tcp = TcpStream::connect(&addr).await?;
            tunnel(tcp, addr.ip().to_string(), addr.port(), ua(), auth)
                .await
                .map(|_tunneled| ())
        })
    }

    #[test]
    fn test_tunnel() {
        let addr = mock_tunnel!();
        run_tunnel(addr, None).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        let addr = mock_tunnel!(b"HTTP/1.1 200 OK");
        run_tunnel(addr, None).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel!(b"foo bar baz hallo");
        run_tunnel(addr, None).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
            "
        );

        let error = run_tunnel(addr, None).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );

        let auth = proxy::encode_basic_auth("Aladdin", "open sesame");
        run_tunnel(addr, Some(auth)).unwrap();
    }
}
1157