1 | #[cfg (feature = "__tls" )] |
2 | use http::header::HeaderValue; |
3 | use http::uri::{Authority, Scheme}; |
4 | use http::Uri; |
5 | use hyper::client::connect::{Connected, Connection}; |
6 | use hyper::service::Service; |
7 | #[cfg (feature = "native-tls-crate" )] |
8 | use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; |
9 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
10 | |
11 | use pin_project_lite::pin_project; |
12 | use std::future::Future; |
13 | use std::io::{self, IoSlice}; |
14 | use std::net::IpAddr; |
15 | use std::pin::Pin; |
16 | use std::sync::Arc; |
17 | use std::task::{Context, Poll}; |
18 | use std::time::Duration; |
19 | |
20 | #[cfg (feature = "default-tls" )] |
21 | use self::native_tls_conn::NativeTlsConn; |
22 | #[cfg (feature = "__rustls" )] |
23 | use self::rustls_tls_conn::RustlsTlsConn; |
24 | use crate::dns::DynResolver; |
25 | use crate::error::BoxError; |
26 | use crate::proxy::{Proxy, ProxyScheme}; |
27 | |
28 | pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>; |
29 | |
30 | #[derive (Clone)] |
31 | pub(crate) struct Connector { |
32 | inner: Inner, |
33 | proxies: Arc<Vec<Proxy>>, |
34 | verbose: verbose::Wrapper, |
35 | timeout: Option<Duration>, |
36 | #[cfg (feature = "__tls" )] |
37 | nodelay: bool, |
38 | #[cfg (feature = "__tls" )] |
39 | tls_info: bool, |
40 | #[cfg (feature = "__tls" )] |
41 | user_agent: Option<HeaderValue>, |
42 | } |
43 | |
44 | #[derive (Clone)] |
45 | enum Inner { |
46 | #[cfg (not(feature = "__tls" ))] |
47 | Http(HttpConnector), |
48 | #[cfg (feature = "default-tls" )] |
49 | DefaultTls(HttpConnector, TlsConnector), |
50 | #[cfg (feature = "__rustls" )] |
51 | RustlsTls { |
52 | http: HttpConnector, |
53 | tls: Arc<rustls::ClientConfig>, |
54 | tls_proxy: Arc<rustls::ClientConfig>, |
55 | }, |
56 | } |
57 | |
58 | impl Connector { |
59 | #[cfg (not(feature = "__tls" ))] |
60 | pub(crate) fn new<T>( |
61 | mut http: HttpConnector, |
62 | proxies: Arc<Vec<Proxy>>, |
63 | local_addr: T, |
64 | nodelay: bool, |
65 | ) -> Connector |
66 | where |
67 | T: Into<Option<IpAddr>>, |
68 | { |
69 | http.set_local_address(local_addr.into()); |
70 | http.set_nodelay(nodelay); |
71 | |
72 | Connector { |
73 | inner: Inner::Http(http), |
74 | verbose: verbose::OFF, |
75 | proxies, |
76 | timeout: None, |
77 | } |
78 | } |
79 | |
80 | #[cfg (feature = "default-tls" )] |
81 | pub(crate) fn new_default_tls<T>( |
82 | http: HttpConnector, |
83 | tls: TlsConnectorBuilder, |
84 | proxies: Arc<Vec<Proxy>>, |
85 | user_agent: Option<HeaderValue>, |
86 | local_addr: T, |
87 | nodelay: bool, |
88 | tls_info: bool, |
89 | ) -> crate::Result<Connector> |
90 | where |
91 | T: Into<Option<IpAddr>>, |
92 | { |
93 | let tls = tls.build().map_err(crate::error::builder)?; |
94 | Ok(Self::from_built_default_tls( |
95 | http, tls, proxies, user_agent, local_addr, nodelay, tls_info, |
96 | )) |
97 | } |
98 | |
99 | #[cfg (feature = "default-tls" )] |
100 | pub(crate) fn from_built_default_tls<T>( |
101 | mut http: HttpConnector, |
102 | tls: TlsConnector, |
103 | proxies: Arc<Vec<Proxy>>, |
104 | user_agent: Option<HeaderValue>, |
105 | local_addr: T, |
106 | nodelay: bool, |
107 | tls_info: bool, |
108 | ) -> Connector |
109 | where |
110 | T: Into<Option<IpAddr>>, |
111 | { |
112 | http.set_local_address(local_addr.into()); |
113 | http.set_nodelay(nodelay); |
114 | http.enforce_http(false); |
115 | |
116 | Connector { |
117 | inner: Inner::DefaultTls(http, tls), |
118 | proxies, |
119 | verbose: verbose::OFF, |
120 | timeout: None, |
121 | nodelay, |
122 | tls_info, |
123 | user_agent, |
124 | } |
125 | } |
126 | |
127 | #[cfg (feature = "__rustls" )] |
128 | pub(crate) fn new_rustls_tls<T>( |
129 | mut http: HttpConnector, |
130 | tls: rustls::ClientConfig, |
131 | proxies: Arc<Vec<Proxy>>, |
132 | user_agent: Option<HeaderValue>, |
133 | local_addr: T, |
134 | nodelay: bool, |
135 | tls_info: bool, |
136 | ) -> Connector |
137 | where |
138 | T: Into<Option<IpAddr>>, |
139 | { |
140 | http.set_local_address(local_addr.into()); |
141 | http.set_nodelay(nodelay); |
142 | http.enforce_http(false); |
143 | |
144 | let (tls, tls_proxy) = if proxies.is_empty() { |
145 | let tls = Arc::new(tls); |
146 | (tls.clone(), tls) |
147 | } else { |
148 | let mut tls_proxy = tls.clone(); |
149 | tls_proxy.alpn_protocols.clear(); |
150 | (Arc::new(tls), Arc::new(tls_proxy)) |
151 | }; |
152 | |
153 | Connector { |
154 | inner: Inner::RustlsTls { |
155 | http, |
156 | tls, |
157 | tls_proxy, |
158 | }, |
159 | proxies, |
160 | verbose: verbose::OFF, |
161 | timeout: None, |
162 | nodelay, |
163 | tls_info, |
164 | user_agent, |
165 | } |
166 | } |
167 | |
168 | pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) { |
169 | self.timeout = timeout; |
170 | } |
171 | |
172 | pub(crate) fn set_verbose(&mut self, enabled: bool) { |
173 | self.verbose.0 = enabled; |
174 | } |
175 | |
176 | #[cfg (feature = "socks" )] |
177 | async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> { |
178 | let dns = match proxy { |
179 | ProxyScheme::Socks5 { |
180 | remote_dns: false, .. |
181 | } => socks::DnsResolve::Local, |
182 | ProxyScheme::Socks5 { |
183 | remote_dns: true, .. |
184 | } => socks::DnsResolve::Proxy, |
185 | ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => { |
186 | unreachable!("connect_socks is only called for socks proxies" ); |
187 | } |
188 | }; |
189 | |
190 | match &self.inner { |
191 | #[cfg (feature = "default-tls" )] |
192 | Inner::DefaultTls(_http, tls) => { |
193 | if dst.scheme() == Some(&Scheme::HTTPS) { |
194 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
195 | let conn = socks::connect(proxy, dst, dns).await?; |
196 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
197 | let io = tls_connector.connect(&host, conn).await?; |
198 | return Ok(Conn { |
199 | inner: self.verbose.wrap(NativeTlsConn { inner: io }), |
200 | is_proxy: false, |
201 | tls_info: self.tls_info, |
202 | }); |
203 | } |
204 | } |
205 | #[cfg (feature = "__rustls" )] |
206 | Inner::RustlsTls { tls_proxy, .. } => { |
207 | if dst.scheme() == Some(&Scheme::HTTPS) { |
208 | use std::convert::TryFrom; |
209 | use tokio_rustls::TlsConnector as RustlsConnector; |
210 | |
211 | let tls = tls_proxy.clone(); |
212 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
213 | let conn = socks::connect(proxy, dst, dns).await?; |
214 | let server_name = rustls::ServerName::try_from(host.as_str()) |
215 | .map_err(|_| "Invalid Server Name" )?; |
216 | let io = RustlsConnector::from(tls) |
217 | .connect(server_name, conn) |
218 | .await?; |
219 | return Ok(Conn { |
220 | inner: self.verbose.wrap(RustlsTlsConn { inner: io }), |
221 | is_proxy: false, |
222 | tls_info: false, |
223 | }); |
224 | } |
225 | } |
226 | #[cfg (not(feature = "__tls" ))] |
227 | Inner::Http(_) => (), |
228 | } |
229 | |
230 | socks::connect(proxy, dst, dns).await.map(|tcp| Conn { |
231 | inner: self.verbose.wrap(tcp), |
232 | is_proxy: false, |
233 | tls_info: false, |
234 | }) |
235 | } |
236 | |
237 | async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> { |
238 | match self.inner { |
239 | #[cfg (not(feature = "__tls" ))] |
240 | Inner::Http(mut http) => { |
241 | let io = http.call(dst).await?; |
242 | Ok(Conn { |
243 | inner: self.verbose.wrap(io), |
244 | is_proxy, |
245 | tls_info: false, |
246 | }) |
247 | } |
248 | #[cfg (feature = "default-tls" )] |
249 | Inner::DefaultTls(http, tls) => { |
250 | let mut http = http.clone(); |
251 | |
252 | // Disable Nagle's algorithm for TLS handshake |
253 | // |
254 | // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES |
255 | if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) { |
256 | http.set_nodelay(true); |
257 | } |
258 | |
259 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
260 | let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); |
261 | let io = http.call(dst).await?; |
262 | |
263 | if let hyper_tls::MaybeHttpsStream::Https(stream) = io { |
264 | if !self.nodelay { |
265 | stream.get_ref().get_ref().get_ref().set_nodelay(false)?; |
266 | } |
267 | Ok(Conn { |
268 | inner: self.verbose.wrap(NativeTlsConn { inner: stream }), |
269 | is_proxy, |
270 | tls_info: self.tls_info, |
271 | }) |
272 | } else { |
273 | Ok(Conn { |
274 | inner: self.verbose.wrap(io), |
275 | is_proxy, |
276 | tls_info: false, |
277 | }) |
278 | } |
279 | } |
280 | #[cfg (feature = "__rustls" )] |
281 | Inner::RustlsTls { http, tls, .. } => { |
282 | let mut http = http.clone(); |
283 | |
284 | // Disable Nagle's algorithm for TLS handshake |
285 | // |
286 | // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES |
287 | if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) { |
288 | http.set_nodelay(true); |
289 | } |
290 | |
291 | let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone())); |
292 | let io = http.call(dst).await?; |
293 | |
294 | if let hyper_rustls::MaybeHttpsStream::Https(stream) = io { |
295 | if !self.nodelay { |
296 | let (io, _) = stream.get_ref(); |
297 | io.set_nodelay(false)?; |
298 | } |
299 | Ok(Conn { |
300 | inner: self.verbose.wrap(RustlsTlsConn { inner: stream }), |
301 | is_proxy, |
302 | tls_info: self.tls_info, |
303 | }) |
304 | } else { |
305 | Ok(Conn { |
306 | inner: self.verbose.wrap(io), |
307 | is_proxy, |
308 | tls_info: false, |
309 | }) |
310 | } |
311 | } |
312 | } |
313 | } |
314 | |
315 | async fn connect_via_proxy( |
316 | self, |
317 | dst: Uri, |
318 | proxy_scheme: ProxyScheme, |
319 | ) -> Result<Conn, BoxError> { |
320 | log::debug!("proxy( {proxy_scheme:?}) intercepts ' {dst:?}'" ); |
321 | |
322 | let (proxy_dst, _auth) = match proxy_scheme { |
323 | ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth), |
324 | ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth), |
325 | #[cfg (feature = "socks" )] |
326 | ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await, |
327 | }; |
328 | |
329 | #[cfg (feature = "__tls" )] |
330 | let auth = _auth; |
331 | |
332 | match &self.inner { |
333 | #[cfg (feature = "default-tls" )] |
334 | Inner::DefaultTls(http, tls) => { |
335 | if dst.scheme() == Some(&Scheme::HTTPS) { |
336 | let host = dst.host().to_owned(); |
337 | let port = dst.port().map(|p| p.as_u16()).unwrap_or(443); |
338 | let http = http.clone(); |
339 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
340 | let mut http = hyper_tls::HttpsConnector::from((http, tls_connector)); |
341 | let conn = http.call(proxy_dst).await?; |
342 | log::trace!("tunneling HTTPS over proxy" ); |
343 | let tunneled = tunnel( |
344 | conn, |
345 | host.ok_or("no host in url" )?.to_string(), |
346 | port, |
347 | self.user_agent.clone(), |
348 | auth, |
349 | ) |
350 | .await?; |
351 | let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone()); |
352 | let io = tls_connector |
353 | .connect(host.ok_or("no host in url" )?, tunneled) |
354 | .await?; |
355 | return Ok(Conn { |
356 | inner: self.verbose.wrap(NativeTlsConn { inner: io }), |
357 | is_proxy: false, |
358 | tls_info: false, |
359 | }); |
360 | } |
361 | } |
362 | #[cfg (feature = "__rustls" )] |
363 | Inner::RustlsTls { |
364 | http, |
365 | tls, |
366 | tls_proxy, |
367 | } => { |
368 | if dst.scheme() == Some(&Scheme::HTTPS) { |
369 | use rustls::ServerName; |
370 | use std::convert::TryFrom; |
371 | use tokio_rustls::TlsConnector as RustlsConnector; |
372 | |
373 | let host = dst.host().ok_or("no host in url" )?.to_string(); |
374 | let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); |
375 | let http = http.clone(); |
376 | let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); |
377 | let tls = tls.clone(); |
378 | let conn = http.call(proxy_dst).await?; |
379 | log::trace!("tunneling HTTPS over proxy" ); |
380 | let maybe_server_name = |
381 | ServerName::try_from(host.as_str()).map_err(|_| "Invalid Server Name" ); |
382 | let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; |
383 | let server_name = maybe_server_name?; |
384 | let io = RustlsConnector::from(tls) |
385 | .connect(server_name, tunneled) |
386 | .await?; |
387 | |
388 | return Ok(Conn { |
389 | inner: self.verbose.wrap(RustlsTlsConn { inner: io }), |
390 | is_proxy: false, |
391 | tls_info: false, |
392 | }); |
393 | } |
394 | } |
395 | #[cfg (not(feature = "__tls" ))] |
396 | Inner::Http(_) => (), |
397 | } |
398 | |
399 | self.connect_with_maybe_proxy(proxy_dst, true).await |
400 | } |
401 | |
402 | pub fn set_keepalive(&mut self, dur: Option<Duration>) { |
403 | match &mut self.inner { |
404 | #[cfg (feature = "default-tls" )] |
405 | Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), |
406 | #[cfg (feature = "__rustls" )] |
407 | Inner::RustlsTls { http, .. } => http.set_keepalive(dur), |
408 | #[cfg (not(feature = "__tls" ))] |
409 | Inner::Http(http) => http.set_keepalive(dur), |
410 | } |
411 | } |
412 | } |
413 | |
414 | fn into_uri(scheme: Scheme, host: Authority) -> Uri { |
415 | // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`? |
416 | http::Uri::builder() |
417 | .scheme(scheme) |
418 | .authority(host) |
419 | .path_and_query(http::uri::PathAndQuery::from_static("/" )) |
420 | .build() |
421 | .expect(msg:"scheme and authority is valid Uri" ) |
422 | } |
423 | |
424 | async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError> |
425 | where |
426 | F: Future<Output = Result<T, BoxError>>, |
427 | { |
428 | if let Some(to: Duration) = timeout { |
429 | match tokio::time::timeout(duration:to, future:f).await { |
430 | Err(_elapsed: Elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError), |
431 | Ok(Ok(try_res: T)) => Ok(try_res), |
432 | Ok(Err(e: Box)) => Err(e), |
433 | } |
434 | } else { |
435 | f.await |
436 | } |
437 | } |
438 | |
439 | impl Service<Uri> for Connector { |
440 | type Response = Conn; |
441 | type Error = BoxError; |
442 | type Future = Connecting; |
443 | |
444 | fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
445 | Poll::Ready(Ok(())) |
446 | } |
447 | |
448 | fn call(&mut self, dst: Uri) -> Self::Future { |
449 | log::debug!("starting new connection: {dst:?}" ); |
450 | let timeout = self.timeout; |
451 | for prox in self.proxies.iter() { |
452 | if let Some(proxy_scheme) = prox.intercept(&dst) { |
453 | return Box::pin(with_timeout( |
454 | self.clone().connect_via_proxy(dst, proxy_scheme), |
455 | timeout, |
456 | )); |
457 | } |
458 | } |
459 | |
460 | Box::pin(with_timeout( |
461 | self.clone().connect_with_maybe_proxy(dst, false), |
462 | timeout, |
463 | )) |
464 | } |
465 | } |
466 | |
/// Extracts TLS session details (currently the peer certificate) from a connection type
/// so they can be attached to the connection as `crate::tls::TlsInfo`.
///
/// Returns `None` for connections that carry no TLS session.
#[cfg(feature = "__tls")]
trait TlsInfoFactory {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}

// A bare TCP stream has no TLS session, so there is never anything to report.
#[cfg(feature = "__tls")]
impl TlsInfoFactory for tokio::net::TcpStream {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        None
    }
}
478 | |
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<tokio::net::TcpStream> {
    /// Delegates to the TLS stream when the connection is encrypted; a plain-HTTP stream
    /// has no TLS info.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        match self {
            hyper_tls::MaybeHttpsStream::Https(tls) => tls.tls_info(),
            hyper_tls::MaybeHttpsStream::Http(_) => None,
        }
    }
}
488 | |
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::TlsStream<hyper_tls::MaybeHttpsStream<tokio::net::TcpStream>> {
    /// Reports the peer certificate, DER-encoded.
    ///
    /// Failing to read the certificate, or to encode it as DER, is treated as
    /// "no certificate" rather than as an error.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = self
            .get_ref()
            .peer_certificate()
            .ok()
            .flatten()
            .and_then(|c| c.to_der().ok());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
501 | |
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<tokio::net::TcpStream> {
    /// Reports the peer certificate, DER-encoded.
    ///
    /// Failing to read the certificate, or to encode it as DER, is treated as
    /// "no certificate" rather than as an error.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = self
            .get_ref()
            .peer_certificate()
            .ok()
            .flatten()
            .and_then(|c| c.to_der().ok());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}
514 | |
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream> {
    /// Encrypted streams report their TLS details; plain-HTTP streams report nothing.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_rustls::MaybeHttpsStream::Https(tls) = self {
            tls.tls_info()
        } else {
            None
        }
    }
}
524 | |
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::TlsStream<tokio::net::TcpStream> {
    /// Reports the first certificate the peer presented, if any, as raw bytes.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let first_cert = session.peer_certificates().and_then(|chain| chain.first());
        Some(crate::tls::TlsInfo {
            peer_certificate: first_cert.map(|cert| cert.0.clone()),
        })
    }
}
537 | |
#[cfg(feature = "__rustls")]
impl TlsInfoFactory
    for tokio_rustls::client::TlsStream<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>>
{
    /// Reports the first certificate the peer presented, if any, as raw bytes.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let first_cert = session.peer_certificates().and_then(|chain| chain.first());
        Some(crate::tls::TlsInfo {
            peer_certificate: first_cert.map(|cert| cert.0.clone()),
        })
    }
}
552 | |
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<tokio::net::TcpStream> {
    /// Reports the first certificate the peer presented, if any, as raw bytes.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let first_cert = session.peer_certificates().and_then(|chain| chain.first());
        Some(crate::tls::TlsInfo {
            peer_certificate: first_cert.map(|cert| cert.0.clone()),
        })
    }
}
565 | |
/// Everything hyper needs from a connection (async I/O plus `Connection` metadata),
/// with the auto-trait and lifetime bounds required to store it boxed.
pub(crate) trait AsyncConn:
    AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static
{
}

// Blanket impl: any type satisfying the bounds is an `AsyncConn`.
impl<T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}

// `AsyncConn` plus, when TLS support is compiled in, the ability to report TLS details.
#[cfg(feature = "__tls")]
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
#[cfg(not(feature = "__tls"))]
trait AsyncConnWithInfo: AsyncConn {}

#[cfg(feature = "__tls")]
impl<T: AsyncConn + TlsInfoFactory> AsyncConnWithInfo for T {}
#[cfg(not(feature = "__tls"))]
impl<T: AsyncConn> AsyncConnWithInfo for T {}

/// Type-erased connection held inside [`Conn`].
type BoxConn = Box<dyn AsyncConnWithInfo>;
584 | |
pin_project! {
    /// Note: the `is_proxy` member means *is plain text HTTP proxy*.
    /// This tells hyper whether the URI should be written in
    /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
    /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
    pub(crate) struct Conn {
        // The boxed transport (possibly TLS-wrapped and/or trace-logged); all I/O goes here.
        #[pin]
        inner: BoxConn,
        is_proxy: bool,
        // Only needed for __tls, but #[cfg()] on fields breaks pin_project!
        // When true, `connected()` attaches the transport's TLS info, if any.
        tls_info: bool,
    }
}
598 | |
599 | impl Connection for Conn { |
600 | fn connected(&self) -> Connected { |
601 | let connected: Connected = self.inner.connected().proxy(self.is_proxy); |
602 | #[cfg (feature = "__tls" )] |
603 | if self.tls_info { |
604 | if let Some(tls_info: TlsInfo) = self.inner.tls_info() { |
605 | connected.extra(tls_info) |
606 | } else { |
607 | connected |
608 | } |
609 | } else { |
610 | connected |
611 | } |
612 | #[cfg (not(feature = "__tls" ))] |
613 | connected |
614 | } |
615 | } |
616 | |
617 | impl AsyncRead for Conn { |
618 | fn poll_read( |
619 | self: Pin<&mut Self>, |
620 | cx: &mut Context, |
621 | buf: &mut ReadBuf<'_>, |
622 | ) -> Poll<io::Result<()>> { |
623 | let this: Projection<'_> = self.project(); |
624 | AsyncRead::poll_read(self:this.inner, cx, buf) |
625 | } |
626 | } |
627 | |
628 | impl AsyncWrite for Conn { |
629 | fn poll_write( |
630 | self: Pin<&mut Self>, |
631 | cx: &mut Context, |
632 | buf: &[u8], |
633 | ) -> Poll<Result<usize, io::Error>> { |
634 | let this = self.project(); |
635 | AsyncWrite::poll_write(this.inner, cx, buf) |
636 | } |
637 | |
638 | fn poll_write_vectored( |
639 | self: Pin<&mut Self>, |
640 | cx: &mut Context<'_>, |
641 | bufs: &[IoSlice<'_>], |
642 | ) -> Poll<Result<usize, io::Error>> { |
643 | let this = self.project(); |
644 | AsyncWrite::poll_write_vectored(this.inner, cx, bufs) |
645 | } |
646 | |
647 | fn is_write_vectored(&self) -> bool { |
648 | self.inner.is_write_vectored() |
649 | } |
650 | |
651 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { |
652 | let this = self.project(); |
653 | AsyncWrite::poll_flush(this.inner, cx) |
654 | } |
655 | |
656 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { |
657 | let this = self.project(); |
658 | AsyncWrite::poll_shutdown(this.inner, cx) |
659 | } |
660 | } |
661 | |
/// Boxed, `Send` future returned by `Connector::call`, resolving to an established [`Conn`].
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
663 | |
/// Opens an HTTP `CONNECT` tunnel to `host:port` over an existing proxy connection.
///
/// Writes the `CONNECT` request, including `User-Agent` and `Proxy-Authorization` headers
/// when provided, then reads the proxy's reply into a fixed 8 KiB buffer. Once a `200`
/// status line and the end of the header block (`\r\n\r\n`) have arrived, the same `conn`
/// is returned, now carrying raw bytes to the target.
///
/// # Errors
///
/// - I/O errors while writing the request or reading the reply;
/// - the proxy closing the connection before the reply is complete;
/// - a `407` reply ("proxy authentication required") or any other non-`200` status;
/// - a `200` reply whose headers do not fit in the buffer.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Length of `HTTP/1.x NNN`: the prefix needed to classify the proxy's reply.
    const STATUS_PREFIX_LEN: usize = 12;

    // Request line and Host header must be `host:port` with no extra whitespace, and each
    // line must end in a bare CRLF.
    let mut buf = format!(
        "\
         CONNECT {host}:{port} HTTP/1.1\r\n\
         Host: {host}:{port}\r\n\
         "
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {host}: {port} using basic auth");
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // headers end
    buf.extend_from_slice(b"\r\n");

    conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];
        if recvd.len() < STATUS_PREFIX_LEN {
            // The status line may arrive split across several reads; wait for enough
            // bytes before judging it.
            continue;
        }
        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
            // else read more
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
732 | |
/// Error for a proxy that closes the connection before finishing its `CONNECT` reply.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
737 | |
/// Adapter that lets a native-tls `TlsStream` be boxed as a client connection.
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use super::TlsInfoFactory;
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_native_tls::TlsStream;

    pin_project! {
        // Thin wrapper: all I/O is forwarded to the pinned TLS stream.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> NativeTlsConn<T> {
        /// Connection metadata of the transport underneath the TLS layer.
        fn transport_connected(&self) -> Connected {
            self.inner.get_ref().get_ref().get_ref().connected()
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
        /// Transport metadata, marked as HTTP/2 when ALPN selected `h2`.
        #[cfg(feature = "native-tls-alpn")]
        fn connected(&self) -> Connected {
            let selected = self.inner.get_ref().negotiated_alpn().ok();
            let connected = self.transport_connected();
            match selected {
                Some(Some(protocol)) if protocol == b"h2" => connected.negotiated_h2(),
                _ => connected,
            }
        }

        /// Transport metadata; without ALPN support no protocol is negotiated.
        #[cfg(not(feature = "native-tls-alpn"))]
        fn connected(&self) -> Connected {
            self.transport_connected()
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }

    // TLS details come from the wrapped TLS stream.
    impl TlsInfoFactory for NativeTlsConn<tokio::net::TcpStream> {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }

    impl TlsInfoFactory for NativeTlsConn<hyper_tls::MaybeHttpsStream<tokio::net::TcpStream>> {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
841 | |
/// Adapter that lets a rustls client `TlsStream` be boxed as a client connection.
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use super::TlsInfoFactory;
    use hyper::client::connect::{Connected, Connection};
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_rustls::client::TlsStream;

    pin_project! {
        // Thin wrapper: all I/O is forwarded to the pinned TLS stream.
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TlsStream<T>,
        }
    }

    impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
        /// Transport metadata, marked as HTTP/2 when ALPN selected `h2`.
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.get_ref();
            let is_h2 = session.alpn_protocol() == Some(b"h2");
            let connected = transport.connected();
            if is_h2 {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            AsyncRead::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            AsyncWrite::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            AsyncWrite::poll_shutdown(self.project().inner, cx)
        }
    }

    // TLS details come from the wrapped TLS stream.
    impl TlsInfoFactory for RustlsTlsConn<tokio::net::TcpStream> {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }

    impl TlsInfoFactory for RustlsTlsConn<hyper_rustls::MaybeHttpsStream<tokio::net::TcpStream>> {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
934 | |
/// SOCKS5 proxy support.
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::Socks5Stream;

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name is resolved for a SOCKS5 connection.
    pub(super) enum DnsResolve {
        // Resolve locally and send the proxy an IP address.
        Local,
        // Pass the host name through to the proxy unresolved.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS5 proxy described by `proxy`.
    ///
    /// The port defaults to 443 for `https` and to 80 otherwise. Username/password
    /// authentication is used when the proxy scheme carries credentials.
    ///
    /// # Errors
    ///
    /// - `dst` has no host;
    /// - local DNS resolution fails (with `DnsResolve::Local`);
    /// - the SOCKS handshake or the connection to the proxy fails.
    ///
    /// # Panics
    ///
    /// Panics if `proxy` is not `ProxyScheme::Socks5`.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        // `Uri::host` keeps the brackets around IPv6 literals ("[::1]"). Neither DNS
        // resolution nor the SOCKS target-address parser accepts that form.
        let original_host = original_host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(original_host);
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Resolve off the async executor. A blocking `std::net::ToSocketAddrs` lookup
            // here would stall the runtime thread for the duration of the DNS query.
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port))
                .await?
                .next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        let (socket_addr, auth) = match proxy {
            ProxyScheme::Socks5 { addr, auth, .. } => (addr, auth),
            _ => unreachable!(),
        };

        // Get a Tokio TcpStream
        let stream = if let Some((username, password)) = auth {
            Socks5Stream::connect_with_password(
                socket_addr,
                (host.as_str(), port),
                &username,
                &password,
            )
            .await
            .map_err(|e| format!("socks connect error: {e}"))?
        } else {
            Socks5Stream::connect(socket_addr, (host.as_str(), port))
                .await
                .map_err(|e| format!("socks connect error: {e}"))?
        };

        Ok(stream.into_inner())
    }
}
999 | |
1000 | mod verbose { |
1001 | use hyper::client::connect::{Connected, Connection}; |
1002 | use std::cmp::min; |
1003 | use std::fmt; |
1004 | use std::io::{self, IoSlice}; |
1005 | use std::pin::Pin; |
1006 | use std::task::{Context, Poll}; |
1007 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
1008 | |
1009 | pub(super) const OFF: Wrapper = Wrapper(false); |
1010 | |
1011 | #[derive (Clone, Copy)] |
1012 | pub(super) struct Wrapper(pub(super) bool); |
1013 | |
1014 | impl Wrapper { |
1015 | pub(super) fn wrap<T: super::AsyncConnWithInfo>(&self, conn: T) -> super::BoxConn { |
1016 | if self.0 && log::log_enabled!(log::Level::Trace) { |
1017 | Box::new(Verbose { |
1018 | // truncate is fine |
1019 | id: crate::util::fast_random() as u32, |
1020 | inner: conn, |
1021 | }) |
1022 | } else { |
1023 | Box::new(conn) |
1024 | } |
1025 | } |
1026 | } |
1027 | |
1028 | struct Verbose<T> { |
1029 | id: u32, |
1030 | inner: T, |
1031 | } |
1032 | |
1033 | impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for Verbose<T> { |
1034 | fn connected(&self) -> Connected { |
1035 | self.inner.connected() |
1036 | } |
1037 | } |
1038 | |
1039 | impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for Verbose<T> { |
1040 | fn poll_read( |
1041 | mut self: Pin<&mut Self>, |
1042 | cx: &mut Context, |
1043 | buf: &mut ReadBuf<'_>, |
1044 | ) -> Poll<std::io::Result<()>> { |
1045 | match Pin::new(&mut self.inner).poll_read(cx, buf) { |
1046 | Poll::Ready(Ok(())) => { |
1047 | log::trace!(" {:08x} read: {:?}" , self.id, Escape(buf.filled())); |
1048 | Poll::Ready(Ok(())) |
1049 | } |
1050 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
1051 | Poll::Pending => Poll::Pending, |
1052 | } |
1053 | } |
1054 | } |
1055 | |
1056 | impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Verbose<T> { |
1057 | fn poll_write( |
1058 | mut self: Pin<&mut Self>, |
1059 | cx: &mut Context, |
1060 | buf: &[u8], |
1061 | ) -> Poll<Result<usize, std::io::Error>> { |
1062 | match Pin::new(&mut self.inner).poll_write(cx, buf) { |
1063 | Poll::Ready(Ok(n)) => { |
1064 | log::trace!(" {:08x} write: {:?}" , self.id, Escape(&buf[..n])); |
1065 | Poll::Ready(Ok(n)) |
1066 | } |
1067 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
1068 | Poll::Pending => Poll::Pending, |
1069 | } |
1070 | } |
1071 | |
1072 | fn poll_write_vectored( |
1073 | mut self: Pin<&mut Self>, |
1074 | cx: &mut Context<'_>, |
1075 | bufs: &[IoSlice<'_>], |
1076 | ) -> Poll<Result<usize, io::Error>> { |
1077 | match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) { |
1078 | Poll::Ready(Ok(nwritten)) => { |
1079 | log::trace!( |
1080 | " {:08x} write (vectored): {:?}" , |
1081 | self.id, |
1082 | Vectored { bufs, nwritten } |
1083 | ); |
1084 | Poll::Ready(Ok(nwritten)) |
1085 | } |
1086 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
1087 | Poll::Pending => Poll::Pending, |
1088 | } |
1089 | } |
1090 | |
1091 | fn is_write_vectored(&self) -> bool { |
1092 | self.inner.is_write_vectored() |
1093 | } |
1094 | |
1095 | fn poll_flush( |
1096 | mut self: Pin<&mut Self>, |
1097 | cx: &mut Context, |
1098 | ) -> Poll<Result<(), std::io::Error>> { |
1099 | Pin::new(&mut self.inner).poll_flush(cx) |
1100 | } |
1101 | |
1102 | fn poll_shutdown( |
1103 | mut self: Pin<&mut Self>, |
1104 | cx: &mut Context, |
1105 | ) -> Poll<Result<(), std::io::Error>> { |
1106 | Pin::new(&mut self.inner).poll_shutdown(cx) |
1107 | } |
1108 | } |
1109 | |
1110 | #[cfg (feature = "__tls" )] |
1111 | impl<T: super::TlsInfoFactory> super::TlsInfoFactory for Verbose<T> { |
1112 | fn tls_info(&self) -> Option<crate::tls::TlsInfo> { |
1113 | self.inner.tls_info() |
1114 | } |
1115 | } |
1116 | |
1117 | struct Escape<'a>(&'a [u8]); |
1118 | |
1119 | impl fmt::Debug for Escape<'_> { |
1120 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
1121 | write!(f, "b \"" )?; |
1122 | for &c in self.0 { |
1123 | // https://doc.rust-lang.org/reference.html#byte-escapes |
1124 | if c == b' \n' { |
1125 | write!(f, " \\n" )?; |
1126 | } else if c == b' \r' { |
1127 | write!(f, " \\r" )?; |
1128 | } else if c == b' \t' { |
1129 | write!(f, " \\t" )?; |
1130 | } else if c == b' \\' || c == b'"' { |
1131 | write!(f, " \\{}" , c as char)?; |
1132 | } else if c == b' \0' { |
1133 | write!(f, " \\0" )?; |
1134 | // ASCII printable |
1135 | } else if c >= 0x20 && c < 0x7f { |
1136 | write!(f, " {}" , c as char)?; |
1137 | } else { |
1138 | write!(f, " \\x {c:02x}" )?; |
1139 | } |
1140 | } |
1141 | write!(f, " \"" )?; |
1142 | Ok(()) |
1143 | } |
1144 | } |
1145 | |
1146 | struct Vectored<'a, 'b> { |
1147 | bufs: &'a [IoSlice<'b>], |
1148 | nwritten: usize, |
1149 | } |
1150 | |
1151 | impl fmt::Debug for Vectored<'_, '_> { |
1152 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
1153 | let mut left = self.nwritten; |
1154 | for buf in self.bufs.iter() { |
1155 | if left == 0 { |
1156 | break; |
1157 | } |
1158 | let n = min(left, buf.len()); |
1159 | Escape(&buf[..n]).fmt(f)?; |
1160 | left -= n; |
1161 | } |
1162 | Ok(()) |
1163 | } |
1164 | } |
1165 | } |
1166 | |
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::error::BoxError;
    use crate::proxy;
    use std::io::{Read, Write};
    use std::net::{SocketAddr, TcpListener};
    use std::thread;
    use std::time::Duration;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    /// Starts a one-shot mock proxy on a loopback port.
    ///
    /// The proxy reads a `CONNECT` request, asserts it equals the request
    /// `tunnel` should send (with `auth` inserted as extra header lines), then
    /// replies with `response` and closes the connection. The returned handle
    /// must be joined so a failed assertion in the proxy thread fails the test.
    fn mock_tunnel(
        response: &'static [u8],
        auth: &str,
    ) -> (SocketAddr, thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_expected = format!(
            "\
             CONNECT {0}:{1} HTTP/1.1\r\n\
             Host: {0}:{1}\r\n\
             User-Agent: {2}\r\n\
             {3}\
             \r\n\
             ",
            addr.ip(),
            addr.port(),
            TUNNEL_UA,
            auth
        )
        .into_bytes();

        let server = thread::spawn(move || {
            let (mut sock, _) = listener.accept().unwrap();
            // Fail (and close the socket) instead of hanging if the client
            // stalls mid-request.
            sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();

            // A single `read` may return only part of the request, so keep
            // reading until the header terminator, the expected length, or EOF.
            let mut received = Vec::new();
            let mut chunk = [0u8; 4096];
            while !received.ends_with(b"\r\n\r\n") && received.len() < connect_expected.len() {
                let n = sock.read(&mut chunk).unwrap();
                if n == 0 {
                    break;
                }
                received.extend_from_slice(&chunk[..n]);
            }
            assert_eq!(received, connect_expected);

            sock.write_all(response).unwrap();
        });
        (addr, server)
    }

    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Runs `tunnel` against the mock proxy at `addr` with the given
    /// `Proxy-Authorization` value, then joins the proxy thread so its
    /// assertions are enforced. The tunneled stream is dropped inside the runtime.
    fn run_tunnel(
        addr: SocketAddr,
        server: thread::JoinHandle<()>,
        auth: Option<http::header::HeaderValue>,
    ) -> Result<(), BoxError> {
        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt");
        let result = rt.block_on(async move {
            let tcp = TcpStream::connect(&addr).await?;
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), auth).await.map(drop)
        });
        server.join().expect("mock proxy thread panicked");
        result
    }

    #[test]
    fn test_tunnel() {
        let (addr, server) = mock_tunnel(TUNNEL_OK, "");
        run_tunnel(addr, server, None).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        let (addr, server) = mock_tunnel(b"HTTP/1.1 200 OK", "");
        run_tunnel(addr, server, None).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let (addr, server) = mock_tunnel(b"foo bar baz hallo", "");
        run_tunnel(addr, server, None).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let (addr, server) = mock_tunnel(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
            ",
            "",
        );

        let error = run_tunnel(addr, server, None).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let (addr, server) = mock_tunnel(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n",
        );

        run_tunnel(
            addr,
            server,
            Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
        )
        .unwrap();
    }
}
1332 | |