1 | #[cfg (feature = "__tls" )] |
2 | use http::header::HeaderValue; |
3 | use http::uri::{Authority, Scheme}; |
4 | use http::Uri; |
5 | use hyper::client::connect::{Connected, Connection}; |
6 | use hyper::service::Service; |
7 | #[cfg (feature = "native-tls-crate" )] |
8 | use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; |
9 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
10 | |
11 | use pin_project_lite::pin_project; |
12 | use std::future::Future; |
13 | use std::io::{self, IoSlice}; |
14 | use std::net::IpAddr; |
15 | use std::pin::Pin; |
16 | use std::sync::Arc; |
17 | use std::task::{Context, Poll}; |
18 | use std::time::Duration; |
19 | |
20 | #[cfg (feature = "default-tls" )] |
21 | use self::native_tls_conn::NativeTlsConn; |
22 | #[cfg (feature = "__rustls" )] |
23 | use self::rustls_tls_conn::RustlsTlsConn; |
24 | use crate::dns::DynResolver; |
25 | use crate::error::BoxError; |
26 | use crate::proxy::{Proxy, ProxyScheme}; |
27 | |
28 | pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>; |
29 | |
30 | #[derive (Clone)] |
31 | pub(crate) struct Connector { |
32 | inner: Inner, |
33 | proxies: Arc<Vec<Proxy>>, |
34 | verbose: verbose::Wrapper, |
35 | timeout: Option<Duration>, |
36 | #[cfg (feature = "__tls" )] |
37 | nodelay: bool, |
38 | #[cfg (feature = "__tls" )] |
39 | user_agent: Option<HeaderValue>, |
40 | } |
41 | |
42 | #[derive (Clone)] |
43 | enum Inner { |
44 | #[cfg (not(feature = "__tls" ))] |
45 | Http(HttpConnector), |
46 | #[cfg (feature = "default-tls" )] |
47 | DefaultTls(HttpConnector, TlsConnector), |
48 | #[cfg (feature = "__rustls" )] |
49 | RustlsTls { |
50 | http: HttpConnector, |
51 | tls: Arc<rustls::ClientConfig>, |
52 | tls_proxy: Arc<rustls::ClientConfig>, |
53 | }, |
54 | } |
55 | |
56 | impl Connector { |
57 | #[cfg (not(feature = "__tls" ))] |
58 | pub(crate) fn new<T>( |
59 | mut http: HttpConnector, |
60 | proxies: Arc<Vec<Proxy>>, |
61 | local_addr: T, |
62 | nodelay: bool, |
63 | ) -> Connector |
64 | where |
65 | T: Into<Option<IpAddr>>, |
66 | { |
67 | http.set_local_address(local_addr.into()); |
68 | http.set_nodelay(nodelay); |
69 | Connector { |
70 | inner: Inner::Http(http), |
71 | verbose: verbose::OFF, |
72 | proxies, |
73 | timeout: None, |
74 | } |
75 | } |
76 | |
77 | #[cfg (feature = "default-tls" )] |
78 | pub(crate) fn new_default_tls<T>( |
79 | http: HttpConnector, |
80 | tls: TlsConnectorBuilder, |
81 | proxies: Arc<Vec<Proxy>>, |
82 | user_agent: Option<HeaderValue>, |
83 | local_addr: T, |
84 | nodelay: bool, |
85 | ) -> crate::Result<Connector> |
86 | where |
87 | T: Into<Option<IpAddr>>, |
88 | { |
89 | let tls = tls.build().map_err(crate::error::builder)?; |
90 | Ok(Self::from_built_default_tls( |
91 | http, tls, proxies, user_agent, local_addr, nodelay, |
92 | )) |
93 | } |
94 | |
95 | #[cfg (feature = "default-tls" )] |
96 | pub(crate) fn from_built_default_tls<T>( |
97 | mut http: HttpConnector, |
98 | tls: TlsConnector, |
99 | proxies: Arc<Vec<Proxy>>, |
100 | user_agent: Option<HeaderValue>, |
101 | local_addr: T, |
102 | nodelay: bool, |
103 | ) -> Connector |
104 | where |
105 | T: Into<Option<IpAddr>>, |
106 | { |
107 | http.set_local_address(local_addr.into()); |
108 | http.enforce_http(false); |
109 | |
110 | Connector { |
111 | inner: Inner::DefaultTls(http, tls), |
112 | proxies, |
113 | verbose: verbose::OFF, |
114 | timeout: None, |
115 | nodelay, |
116 | user_agent, |
117 | } |
118 | } |
119 | |
120 | #[cfg (feature = "__rustls" )] |
121 | pub(crate) fn new_rustls_tls<T>( |
122 | mut http: HttpConnector, |
123 | tls: rustls::ClientConfig, |
124 | proxies: Arc<Vec<Proxy>>, |
125 | user_agent: Option<HeaderValue>, |
126 | local_addr: T, |
127 | nodelay: bool, |
128 | ) -> Connector |
129 | where |
130 | T: Into<Option<IpAddr>>, |
131 | { |
132 | http.set_local_address(local_addr.into()); |
133 | http.enforce_http(false); |
134 | |
135 | let (tls, tls_proxy) = if proxies.is_empty() { |
136 | let tls = Arc::new(tls); |
137 | (tls.clone(), tls) |
138 | } else { |
139 | let mut tls_proxy = tls.clone(); |
140 | tls_proxy.alpn_protocols.clear(); |
141 | (Arc::new(tls), Arc::new(tls_proxy)) |
142 | }; |
143 | |
144 | Connector { |
145 | inner: Inner::RustlsTls { |
146 | http, |
147 | tls, |
148 | tls_proxy, |
149 | }, |
150 | proxies, |
151 | verbose: verbose::OFF, |
152 | timeout: None, |
153 | nodelay, |
154 | user_agent, |
155 | } |
156 | } |
157 | |
158 | pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) { |
159 | self.timeout = timeout; |
160 | } |
161 | |
162 | pub(crate) fn set_verbose(&mut self, enabled: bool) { |
163 | self.verbose.0 = enabled; |
164 | } |
165 | |
166 | #[cfg (feature = "socks" )] |
167 | async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> { |
168 | let dns = match proxy { |
169 | ProxyScheme::Socks5 { |
170 | remote_dns: false, .. |
171 | } => socks::DnsResolve::Local, |
172 | ProxyScheme::Socks5 { |
173 | remote_dns: true, .. |
174 | } => socks::DnsResolve::Proxy, |
175 | ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => { |
176 | unreachable!("connect_socks is only called for socks proxies" ); |
177 | } |
178 | }; |
179 | |
180 | match &self.inner { |
181 | #[cfg (feature = "default-tls" )] |
182 | Inner::DefaultTls(_http, tls) => { |
183 | if dst.scheme() == Some(&Scheme::HTTPS) { |
184 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
185 | let conn = socks::connect(proxy, dst, dns).await?; |
186 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
187 | let io = tls_connector.connect(&host, conn).await?; |
188 | return Ok(Conn { |
189 | inner: self.verbose.wrap(NativeTlsConn { inner: io }), |
190 | is_proxy: false, |
191 | }); |
192 | } |
193 | } |
194 | #[cfg (feature = "__rustls" )] |
195 | Inner::RustlsTls { tls_proxy, .. } => { |
196 | if dst.scheme() == Some(&Scheme::HTTPS) { |
197 | use std::convert::TryFrom; |
198 | use tokio_rustls::TlsConnector as RustlsConnector; |
199 | |
200 | let tls = tls_proxy.clone(); |
201 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
202 | let conn = socks::connect(proxy, dst, dns).await?; |
203 | let server_name = rustls::ServerName::try_from(host.as_str()) |
204 | .map_err(|_| "Invalid Server Name" )?; |
205 | let io = RustlsConnector::from(tls) |
206 | .connect(server_name, conn) |
207 | .await?; |
208 | return Ok(Conn { |
209 | inner: self.verbose.wrap(RustlsTlsConn { inner: io }), |
210 | is_proxy: false, |
211 | }); |
212 | } |
213 | } |
214 | #[cfg (not(feature = "__tls" ))] |
215 | Inner::Http(_) => (), |
216 | } |
217 | |
218 | socks::connect(proxy, dst, dns).await.map(|tcp| Conn { |
219 | inner: self.verbose.wrap(tcp), |
220 | is_proxy: false, |
221 | }) |
222 | } |
223 | |
224 | async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> { |
225 | match self.inner { |
226 | #[cfg (not(feature = "__tls" ))] |
227 | Inner::Http(mut http) => { |
228 | let io = http.call(dst).await?; |
229 | Ok(Conn { |
230 | inner: self.verbose.wrap(io), |
231 | is_proxy, |
232 | }) |
233 | } |
234 | #[cfg (feature = "default-tls" )] |
235 | Inner::DefaultTls(http, tls) => { |
236 | let mut http = http.clone(); |
237 | |
238 | // Disable Nagle's algorithm for TLS handshake |
239 | // |
240 | // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES |
241 | if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) { |
242 | http.set_nodelay(true); |
243 | } |
244 | |
245 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
246 | let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); |
247 | let io = http.call(dst).await?; |
248 | |
249 | if let hyper_tls::MaybeHttpsStream::Https(stream) = io { |
250 | if !self.nodelay { |
251 | stream.get_ref().get_ref().get_ref().set_nodelay(false)?; |
252 | } |
253 | Ok(Conn { |
254 | inner: self.verbose.wrap(NativeTlsConn { inner: stream }), |
255 | is_proxy, |
256 | }) |
257 | } else { |
258 | Ok(Conn { |
259 | inner: self.verbose.wrap(io), |
260 | is_proxy, |
261 | }) |
262 | } |
263 | } |
264 | #[cfg (feature = "__rustls" )] |
265 | Inner::RustlsTls { http, tls, .. } => { |
266 | let mut http = http.clone(); |
267 | |
268 | // Disable Nagle's algorithm for TLS handshake |
269 | // |
270 | // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES |
271 | if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) { |
272 | http.set_nodelay(true); |
273 | } |
274 | |
275 | let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone())); |
276 | let io = http.call(dst).await?; |
277 | |
278 | if let hyper_rustls::MaybeHttpsStream::Https(stream) = io { |
279 | if !self.nodelay { |
280 | let (io, _) = stream.get_ref(); |
281 | io.set_nodelay(false)?; |
282 | } |
283 | Ok(Conn { |
284 | inner: self.verbose.wrap(RustlsTlsConn { inner: stream }), |
285 | is_proxy, |
286 | }) |
287 | } else { |
288 | Ok(Conn { |
289 | inner: self.verbose.wrap(io), |
290 | is_proxy, |
291 | }) |
292 | } |
293 | } |
294 | } |
295 | } |
296 | |
297 | async fn connect_via_proxy( |
298 | self, |
299 | dst: Uri, |
300 | proxy_scheme: ProxyScheme, |
301 | ) -> Result<Conn, BoxError> { |
302 | log::debug!("proxy( {:?}) intercepts ' {:?}'" , proxy_scheme, dst); |
303 | |
304 | let (proxy_dst, _auth) = match proxy_scheme { |
305 | ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth), |
306 | ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth), |
307 | #[cfg (feature = "socks" )] |
308 | ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await, |
309 | }; |
310 | |
311 | #[cfg (feature = "__tls" )] |
312 | let auth = _auth; |
313 | |
314 | match &self.inner { |
315 | #[cfg (feature = "default-tls" )] |
316 | Inner::DefaultTls(http, tls) => { |
317 | if dst.scheme() == Some(&Scheme::HTTPS) { |
318 | let host = dst.host().to_owned(); |
319 | let port = dst.port().map(|p| p.as_u16()).unwrap_or(443); |
320 | let http = http.clone(); |
321 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
322 | let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); |
323 | let conn = http.call(proxy_dst).await?; |
324 | log::trace!("tunneling HTTPS over proxy" ); |
325 | let tunneled = tunnel( |
326 | conn, |
327 | host.ok_or("no host in url" )?.to_string(), |
328 | port, |
329 | self.user_agent.clone(), |
330 | auth, |
331 | ) |
332 | .await?; |
333 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
334 | let io = tls_connector |
335 | .connect(host.ok_or("no host in url" )?, tunneled) |
336 | .await?; |
337 | return Ok(Conn { |
338 | inner: self.verbose.wrap(NativeTlsConn { inner: io }), |
339 | is_proxy: false, |
340 | }); |
341 | } |
342 | } |
343 | #[cfg (feature = "__rustls" )] |
344 | Inner::RustlsTls { |
345 | http, |
346 | tls, |
347 | tls_proxy, |
348 | } => { |
349 | if dst.scheme() == Some(&Scheme::HTTPS) { |
350 | use rustls::ServerName; |
351 | use std::convert::TryFrom; |
352 | use tokio_rustls::TlsConnector as RustlsConnector; |
353 | |
354 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
355 | let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); |
356 | let http = http.clone(); |
357 | let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); |
358 | let tls = tls.clone(); |
359 | let conn = http.call(proxy_dst).await?; |
360 | log::trace!("tunneling HTTPS over proxy" ); |
361 | let maybe_server_name = |
362 | ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name" ); |
363 | let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; |
364 | let server_name = maybe_server_name?; |
365 | let io = RustlsConnector::from(tls) |
366 | .connect(server_name, tunneled) |
367 | .await?; |
368 | |
369 | return Ok(Conn { |
370 | inner: self.verbose.wrap(RustlsTlsConn { inner: io }), |
371 | is_proxy: false, |
372 | }); |
373 | } |
374 | } |
375 | #[cfg (not(feature = "__tls" ))] |
376 | Inner::Http(_) => (), |
377 | } |
378 | |
379 | self.connect_with_maybe_proxy(proxy_dst, true).await |
380 | } |
381 | |
382 | pub fn set_keepalive(&mut self, dur: Option<Duration>) { |
383 | match &mut self.inner { |
384 | #[cfg (feature = "default-tls" )] |
385 | Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), |
386 | #[cfg (feature = "__rustls" )] |
387 | Inner::RustlsTls { http, .. } => http.set_keepalive(dur), |
388 | #[cfg (not(feature = "__tls" ))] |
389 | Inner::Http(http) => http.set_keepalive(dur), |
390 | } |
391 | } |
392 | } |
393 | |
394 | fn into_uri(scheme: Scheme, host: Authority) -> Uri { |
395 | // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`? |
396 | http::Uri::builder() |
397 | .scheme(scheme) |
398 | .authority(host) |
399 | .path_and_query(http::uri::PathAndQuery::from_static("/" )) |
400 | .build() |
401 | .expect(msg:"scheme and authority is valid Uri" ) |
402 | } |
403 | |
404 | async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError> |
405 | where |
406 | F: Future<Output = Result<T, BoxError>>, |
407 | { |
408 | if let Some(to: Duration) = timeout { |
409 | match tokio::time::timeout(duration:to, future:f).await { |
410 | Err(_elapsed: Elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError), |
411 | Ok(Ok(try_res: T)) => Ok(try_res), |
412 | Ok(Err(e: Box)) => Err(e), |
413 | } |
414 | } else { |
415 | f.await |
416 | } |
417 | } |
418 | |
419 | impl Service<Uri> for Connector { |
420 | type Response = Conn; |
421 | type Error = BoxError; |
422 | type Future = Connecting; |
423 | |
424 | fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
425 | Poll::Ready(Ok(())) |
426 | } |
427 | |
428 | fn call(&mut self, dst: Uri) -> Self::Future { |
429 | log::debug!("starting new connection: {:?}" , dst); |
430 | let timeout = self.timeout; |
431 | for prox in self.proxies.iter() { |
432 | if let Some(proxy_scheme) = prox.intercept(&dst) { |
433 | return Box::pin(with_timeout( |
434 | self.clone().connect_via_proxy(dst, proxy_scheme), |
435 | timeout, |
436 | )); |
437 | } |
438 | } |
439 | |
440 | Box::pin(with_timeout( |
441 | self.clone().connect_with_maybe_proxy(dst, false), |
442 | timeout, |
443 | )) |
444 | } |
445 | } |
446 | |
447 | pub(crate) trait AsyncConn: |
448 | AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static |
449 | { |
450 | } |
451 | |
452 | impl<T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {} |
453 | |
454 | type BoxConn = Box<dyn AsyncConn>; |
455 | |
456 | pin_project! { |
457 | /// Note: the `is_proxy` member means *is plain text HTTP proxy*. |
458 | /// This tells hyper whether the URI should be written in |
459 | /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or |
460 | /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. |
461 | pub(crate) struct Conn { |
462 | #[pin] |
463 | inner: BoxConn, |
464 | is_proxy: bool, |
465 | } |
466 | } |
467 | |
468 | impl Connection for Conn { |
469 | fn connected(&self) -> Connected { |
470 | self.inner.connected().proxy(self.is_proxy) |
471 | } |
472 | } |
473 | |
474 | impl AsyncRead for Conn { |
475 | fn poll_read( |
476 | self: Pin<&mut Self>, |
477 | cx: &mut Context, |
478 | buf: &mut ReadBuf<'_>, |
479 | ) -> Poll<io::Result<()>> { |
480 | let this: Projection<'_> = self.project(); |
481 | AsyncRead::poll_read(self:this.inner, cx, buf) |
482 | } |
483 | } |
484 | |
485 | impl AsyncWrite for Conn { |
486 | fn poll_write( |
487 | self: Pin<&mut Self>, |
488 | cx: &mut Context, |
489 | buf: &[u8], |
490 | ) -> Poll<Result<usize, io::Error>> { |
491 | let this = self.project(); |
492 | AsyncWrite::poll_write(this.inner, cx, buf) |
493 | } |
494 | |
495 | fn poll_write_vectored( |
496 | self: Pin<&mut Self>, |
497 | cx: &mut Context<'_>, |
498 | bufs: &[IoSlice<'_>], |
499 | ) -> Poll<Result<usize, io::Error>> { |
500 | let this = self.project(); |
501 | AsyncWrite::poll_write_vectored(this.inner, cx, bufs) |
502 | } |
503 | |
504 | fn is_write_vectored(&self) -> bool { |
505 | self.inner.is_write_vectored() |
506 | } |
507 | |
508 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { |
509 | let this = self.project(); |
510 | AsyncWrite::poll_flush(this.inner, cx) |
511 | } |
512 | |
513 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { |
514 | let this = self.project(); |
515 | AsyncWrite::poll_shutdown(this.inner, cx) |
516 | } |
517 | } |
518 | |
519 | pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>; |
520 | |
/// Asks an HTTP proxy on `conn` to open a `CONNECT` tunnel to `host:port`.
///
/// On success returns `conn`, now a raw byte pipe to the destination.
///
/// # Errors
///
/// Fails on I/O errors, EOF before the response headers end, a 407 response
/// ("proxy authentication required"), any other non-200 status
/// ("unsuccessful tunnel"), or response headers larger than 8 KiB.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Length of the `HTTP/1.x NNN` prefix needed to classify the response.
    const STATUS_PREFIX_LEN: usize = 12;

    // Request target is authority-form (`host:port`, no whitespace), and each
    // line ends in a bare CRLF.
    let mut buf = format!(
        "\
         CONNECT {0}:{1} HTTP/1.1\r\n\
         Host: {0}:{1}\r\n\
         ",
        host, port
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {}:{} using basic auth", host, port);
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // An empty line terminates the header section.
    buf.extend_from_slice(b"\r\n");

    conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];

        // A short read may split the status line; keep reading while what we
        // have could still be the start of an `HTTP/1.x NNN` line.
        if recvd.len() < STATUS_PREFIX_LEN && b"HTTP/1.".iter().zip(recvd).all(|(a, b)| a == b) {
            continue;
        }

        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
            // else read more
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
590 | |
/// Error returned when the proxy closes the stream before finishing its
/// `CONNECT` response.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
595 | |
/// Adapter exposing a native-tls stream as a hyper `Connection`.
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_native_tls::TlsStream;

    pin_project! {
        /// A native-tls client stream over some transport `T`.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
        /// Transport metadata, upgraded to HTTP/2 when ALPN negotiated `h2`.
        #[cfg(feature = "native-tls-alpn")]
        fn connected(&self) -> Connected {
            let session = self.inner.get_ref();
            let transport = session.get_ref().get_ref().connected();
            let alpn = session.negotiated_alpn().ok().flatten();
            if alpn.as_deref() == Some(&b"h2"[..]) {
                transport.negotiated_h2()
            } else {
                transport
            }
        }

        /// Transport metadata (no ALPN support compiled in).
        #[cfg(not(feature = "native-tls-alpn"))]
        fn connected(&self) -> Connected {
            let transport = self.inner.get_ref().get_ref().get_ref();
            transport.connected()
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    // Writes, flushes and shutdowns forward to the TLS stream unchanged.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            let stream = &self.inner;
            stream.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }
}
686 | |
/// Adapter exposing a rustls client stream as a hyper `Connection`.
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_rustls::client::TlsStream;

    pin_project! {
        /// A rustls client stream over some transport `T`.
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
        /// Transport metadata, upgraded to HTTP/2 when ALPN negotiated `h2`.
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.get_ref();
            let metadata = transport.connected();
            if session.alpn_protocol() == Some(b"h2") {
                metadata.negotiated_h2()
            } else {
                metadata
            }
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    // Writes, flushes and shutdowns forward to the TLS stream unchanged.
    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            let stream = &self.inner;
            stream.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }
}
766 | |
/// SOCKS5 dialing support.
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::Socks5Stream;

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination hostname is resolved.
    pub(super) enum DnsResolve {
        /// Resolve locally and send the proxy an IP address.
        Local,
        /// Send the hostname and let the proxy resolve it.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS5 proxy in `proxy`.
    ///
    /// The port defaults to 443 for `https` and 80 otherwise.
    ///
    /// # Errors
    ///
    /// Fails if `dst` has no host, local resolution fails, or the SOCKS
    /// handshake fails.
    ///
    /// # Panics
    ///
    /// If `proxy` is not `ProxyScheme::Socks5` (callers guarantee it is).
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Use tokio's resolver: `std::net::ToSocketAddrs` blocks the
            // calling thread, which here is an async executor worker.
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port))
                .await?
                .next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        let (socket_addr, auth) = match proxy {
            ProxyScheme::Socks5 { addr, auth, .. } => (addr, auth),
            _ => unreachable!(),
        };

        // Get a Tokio TcpStream
        let stream = if let Some((username, password)) = auth {
            Socks5Stream::connect_with_password(
                socket_addr,
                (host.as_str(), port),
                &username,
                &password,
            )
            .await
            .map_err(|e| format!("socks connect error: {}", e))?
        } else {
            Socks5Stream::connect(socket_addr, (host.as_str(), port))
                .await
                .map_err(|e| format!("socks connect error: {}", e))?
        };

        Ok(stream.into_inner())
    }
}
831 | |
832 | mod verbose { |
833 | use hyper::client::connect::{Connected, Connection}; |
834 | use std::cmp::min; |
835 | use std::fmt; |
836 | use std::io::{self, IoSlice}; |
837 | use std::pin::Pin; |
838 | use std::task::{Context, Poll}; |
839 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
840 | |
841 | pub(super) const OFF: Wrapper = Wrapper(false); |
842 | |
843 | #[derive (Clone, Copy)] |
844 | pub(super) struct Wrapper(pub(super) bool); |
845 | |
846 | impl Wrapper { |
847 | pub(super) fn wrap<T: super::AsyncConn>(&self, conn: T) -> super::BoxConn { |
848 | if self.0 && log::log_enabled!(log::Level::Trace) { |
849 | Box::new(Verbose { |
850 | // truncate is fine |
851 | id: crate::util::fast_random() as u32, |
852 | inner: conn, |
853 | }) |
854 | } else { |
855 | Box::new(conn) |
856 | } |
857 | } |
858 | } |
859 | |
860 | struct Verbose<T> { |
861 | id: u32, |
862 | inner: T, |
863 | } |
864 | |
865 | impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for Verbose<T> { |
866 | fn connected(&self) -> Connected { |
867 | self.inner.connected() |
868 | } |
869 | } |
870 | |
871 | impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for Verbose<T> { |
872 | fn poll_read( |
873 | mut self: Pin<&mut Self>, |
874 | cx: &mut Context, |
875 | buf: &mut ReadBuf<'_>, |
876 | ) -> Poll<std::io::Result<()>> { |
877 | match Pin::new(&mut self.inner).poll_read(cx, buf) { |
878 | Poll::Ready(Ok(())) => { |
879 | log::trace!(" {:08x} read: {:?}" , self.id, Escape(buf.filled())); |
880 | Poll::Ready(Ok(())) |
881 | } |
882 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
883 | Poll::Pending => Poll::Pending, |
884 | } |
885 | } |
886 | } |
887 | |
888 | impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Verbose<T> { |
889 | fn poll_write( |
890 | mut self: Pin<&mut Self>, |
891 | cx: &mut Context, |
892 | buf: &[u8], |
893 | ) -> Poll<Result<usize, std::io::Error>> { |
894 | match Pin::new(&mut self.inner).poll_write(cx, buf) { |
895 | Poll::Ready(Ok(n)) => { |
896 | log::trace!(" {:08x} write: {:?}" , self.id, Escape(&buf[..n])); |
897 | Poll::Ready(Ok(n)) |
898 | } |
899 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
900 | Poll::Pending => Poll::Pending, |
901 | } |
902 | } |
903 | |
904 | fn poll_write_vectored( |
905 | mut self: Pin<&mut Self>, |
906 | cx: &mut Context<'_>, |
907 | bufs: &[IoSlice<'_>], |
908 | ) -> Poll<Result<usize, io::Error>> { |
909 | match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) { |
910 | Poll::Ready(Ok(nwritten)) => { |
911 | log::trace!( |
912 | " {:08x} write (vectored): {:?}" , |
913 | self.id, |
914 | Vectored { bufs, nwritten } |
915 | ); |
916 | Poll::Ready(Ok(nwritten)) |
917 | } |
918 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
919 | Poll::Pending => Poll::Pending, |
920 | } |
921 | } |
922 | |
923 | fn is_write_vectored(&self) -> bool { |
924 | self.inner.is_write_vectored() |
925 | } |
926 | |
927 | fn poll_flush( |
928 | mut self: Pin<&mut Self>, |
929 | cx: &mut Context, |
930 | ) -> Poll<Result<(), std::io::Error>> { |
931 | Pin::new(&mut self.inner).poll_flush(cx) |
932 | } |
933 | |
934 | fn poll_shutdown( |
935 | mut self: Pin<&mut Self>, |
936 | cx: &mut Context, |
937 | ) -> Poll<Result<(), std::io::Error>> { |
938 | Pin::new(&mut self.inner).poll_shutdown(cx) |
939 | } |
940 | } |
941 | |
942 | struct Escape<'a>(&'a [u8]); |
943 | |
944 | impl fmt::Debug for Escape<'_> { |
945 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
946 | write!(f, "b \"" )?; |
947 | for &c in self.0 { |
948 | // https://doc.rust-lang.org/reference.html#byte-escapes |
949 | if c == b' \n' { |
950 | write!(f, " \\n" )?; |
951 | } else if c == b' \r' { |
952 | write!(f, " \\r" )?; |
953 | } else if c == b' \t' { |
954 | write!(f, " \\t" )?; |
955 | } else if c == b' \\' || c == b'"' { |
956 | write!(f, " \\{}" , c as char)?; |
957 | } else if c == b' \0' { |
958 | write!(f, " \\0" )?; |
959 | // ASCII printable |
960 | } else if c >= 0x20 && c < 0x7f { |
961 | write!(f, " {}" , c as char)?; |
962 | } else { |
963 | write!(f, " \\x {:02x}" , c)?; |
964 | } |
965 | } |
966 | write!(f, " \"" )?; |
967 | Ok(()) |
968 | } |
969 | } |
970 | |
971 | struct Vectored<'a, 'b> { |
972 | bufs: &'a [IoSlice<'b>], |
973 | nwritten: usize, |
974 | } |
975 | |
976 | impl fmt::Debug for Vectored<'_, '_> { |
977 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
978 | let mut left = self.nwritten; |
979 | for buf in self.bufs.iter() { |
980 | if left == 0 { |
981 | break; |
982 | } |
983 | let n = min(left, buf.len()); |
984 | Escape(&buf[..n]).fmt(f)?; |
985 | left -= n; |
986 | } |
987 | Ok(()) |
988 | } |
989 | } |
990 | } |
991 | |
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::proxy;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    /// Spawns a one-shot fake proxy that asserts the exact CONNECT request it
    /// receives, replies with `$write`, and returns its address.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let connect_expected = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                let mut buf = [0u8; 4096];
                let n = sock.read(&mut buf).unwrap();
                assert_eq!(&buf[..n], &connect_expected[..]);

                sock.write_all($write).unwrap();
            });
            addr
        }};
    }

    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    fn new_rt() -> runtime::Runtime {
        runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt")
    }

    #[test]
    fn test_tunnel() {
        let addr = mock_tunnel!();

        let rt = new_rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        let addr = mock_tunnel!(b"HTTP/1.1 200 OK");

        let rt = new_rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel!(b"foo bar baz hallo");

        let rt = new_rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt.block_on(f).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
            "
        );

        let rt = new_rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        let error = rt.block_on(f).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_proxy_unauthorized_http10() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.0 407 Proxy Authentication Required\r\n\
            \r\n\
            "
        );

        let rt = new_rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        let error = rt.block_on(f).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );

        let rt = new_rt();
        let f = async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(
                tcp,
                host,
                port,
                ua(),
                Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
            )
            .await
        };

        rt.block_on(f).unwrap();
    }
}
1157 | |