1 | // Copied from hyperium/hyper-tls#62e3376/src/stream.rs |
2 | use std::fmt; |
3 | use std::io; |
4 | use std::pin::Pin; |
5 | use std::task::{Context, Poll}; |
6 | |
7 | use hyper::rt; |
8 | use hyper_util::client::legacy::connect::{Connected, Connection}; |
9 | |
10 | use hyper_util::rt::TokioIo; |
11 | use tokio_rustls::client::TlsStream; |
12 | |
13 | /// A stream that might be protected with TLS. |
14 | #[allow (clippy::large_enum_variant)] |
15 | pub enum MaybeHttpsStream<T> { |
16 | /// A stream over plain text. |
17 | Http(T), |
18 | /// A stream protected with TLS. |
19 | Https(TokioIo<TlsStream<TokioIo<T>>>), |
20 | } |
21 | |
22 | impl<T: rt::Read + rt::Write + Connection + Unpin> Connection for MaybeHttpsStream<T> { |
23 | fn connected(&self) -> Connected { |
24 | match self { |
25 | Self::Http(s: &T) => s.connected(), |
26 | Self::Https(s: &TokioIo>>) => { |
27 | let (tcp: &TokioIo, tls: &ClientConnection) = s.inner().get_ref(); |
28 | if tls.alpn_protocol() == Some(b"h2" ) { |
29 | tcp.inner().connected().negotiated_h2() |
30 | } else { |
31 | tcp.inner().connected() |
32 | } |
33 | } |
34 | } |
35 | } |
36 | } |
37 | |
38 | impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> { |
39 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
40 | match *self { |
41 | Self::Http(..) => f.pad("Http(..)" ), |
42 | Self::Https(..) => f.pad("Https(..)" ), |
43 | } |
44 | } |
45 | } |
46 | |
47 | impl<T> From<T> for MaybeHttpsStream<T> { |
48 | fn from(inner: T) -> Self { |
49 | Self::Http(inner) |
50 | } |
51 | } |
52 | |
53 | impl<T> From<TlsStream<TokioIo<T>>> for MaybeHttpsStream<T> { |
54 | fn from(inner: TlsStream<TokioIo<T>>) -> Self { |
55 | Self::Https(TokioIo::new(inner)) |
56 | } |
57 | } |
58 | |
59 | impl<T: rt::Read + rt::Write + Unpin> rt::Read for MaybeHttpsStream<T> { |
60 | #[inline ] |
61 | fn poll_read( |
62 | self: Pin<&mut Self>, |
63 | cx: &mut Context, |
64 | buf: rt::ReadBufCursor<'_>, |
65 | ) -> Poll<Result<(), io::Error>> { |
66 | match Pin::get_mut(self) { |
67 | Self::Http(s: &mut T) => Pin::new(pointer:s).poll_read(cx, buf), |
68 | Self::Https(s: &mut TokioIo>>) => Pin::new(pointer:s).poll_read(cx, buf), |
69 | } |
70 | } |
71 | } |
72 | |
73 | impl<T: rt::Write + rt::Read + Unpin> rt::Write for MaybeHttpsStream<T> { |
74 | #[inline ] |
75 | fn poll_write( |
76 | self: Pin<&mut Self>, |
77 | cx: &mut Context<'_>, |
78 | buf: &[u8], |
79 | ) -> Poll<Result<usize, io::Error>> { |
80 | match Pin::get_mut(self) { |
81 | Self::Http(s) => Pin::new(s).poll_write(cx, buf), |
82 | Self::Https(s) => Pin::new(s).poll_write(cx, buf), |
83 | } |
84 | } |
85 | |
86 | #[inline ] |
87 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
88 | match Pin::get_mut(self) { |
89 | Self::Http(s) => Pin::new(s).poll_flush(cx), |
90 | Self::Https(s) => Pin::new(s).poll_flush(cx), |
91 | } |
92 | } |
93 | |
94 | #[inline ] |
95 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
96 | match Pin::get_mut(self) { |
97 | Self::Http(s) => Pin::new(s).poll_shutdown(cx), |
98 | Self::Https(s) => Pin::new(s).poll_shutdown(cx), |
99 | } |
100 | } |
101 | |
102 | #[inline ] |
103 | fn is_write_vectored(&self) -> bool { |
104 | match self { |
105 | Self::Http(s) => s.is_write_vectored(), |
106 | Self::Https(s) => s.is_write_vectored(), |
107 | } |
108 | } |
109 | |
110 | #[inline ] |
111 | fn poll_write_vectored( |
112 | self: Pin<&mut Self>, |
113 | cx: &mut Context<'_>, |
114 | bufs: &[io::IoSlice<'_>], |
115 | ) -> Poll<Result<usize, io::Error>> { |
116 | match Pin::get_mut(self) { |
117 | Self::Http(s) => Pin::new(s).poll_write_vectored(cx, bufs), |
118 | Self::Https(s) => Pin::new(s).poll_write_vectored(cx, bufs), |
119 | } |
120 | } |
121 | } |
122 | |