1 | use std::fmt; |
2 | use std::io; |
3 | use std::io::IoSlice; |
4 | use std::pin::Pin; |
5 | use std::task::{Context, Poll}; |
6 | |
7 | use hyper::rt::{Read, ReadBufCursor, Write}; |
8 | |
9 | use hyper_util::{ |
10 | client::legacy::connect::{Connected, Connection}, |
11 | rt::TokioIo, |
12 | }; |
13 | pub use tokio_native_tls::TlsStream; |
14 | |
15 | /// A stream that might be protected with TLS. |
16 | pub enum MaybeHttpsStream<T> { |
17 | /// A stream over plain text. |
18 | Http(T), |
19 | /// A stream protected with TLS. |
20 | Https(TokioIo<TlsStream<TokioIo<T>>>), |
21 | } |
22 | |
23 | // ===== impl MaybeHttpsStream ===== |
24 | |
25 | impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> { |
26 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
27 | match self { |
28 | MaybeHttpsStream::Http(s: &T) => f.debug_tuple(name:"Http" ).field(s).finish(), |
29 | MaybeHttpsStream::Https(s: &TokioIo>>) => f.debug_tuple(name:"Https" ).field(s).finish(), |
30 | } |
31 | } |
32 | } |
33 | |
34 | impl<T> From<T> for MaybeHttpsStream<T> { |
35 | fn from(inner: T) -> Self { |
36 | MaybeHttpsStream::Http(inner) |
37 | } |
38 | } |
39 | |
40 | impl<T> From<TlsStream<TokioIo<T>>> for MaybeHttpsStream<T> { |
41 | fn from(inner: TlsStream<TokioIo<T>>) -> Self { |
42 | MaybeHttpsStream::Https(TokioIo::new(inner)) |
43 | } |
44 | } |
45 | |
46 | impl<T> From<TokioIo<TlsStream<TokioIo<T>>>> for MaybeHttpsStream<T> { |
47 | fn from(inner: TokioIo<TlsStream<TokioIo<T>>>) -> Self { |
48 | MaybeHttpsStream::Https(inner) |
49 | } |
50 | } |
51 | |
52 | impl<T: Read + Write + Unpin> Read for MaybeHttpsStream<T> { |
53 | #[inline ] |
54 | fn poll_read( |
55 | self: Pin<&mut Self>, |
56 | cx: &mut Context, |
57 | buf: ReadBufCursor<'_>, |
58 | ) -> Poll<Result<(), io::Error>> { |
59 | match Pin::get_mut(self) { |
60 | MaybeHttpsStream::Http(s: &mut T) => Pin::new(pointer:s).poll_read(cx, buf), |
61 | MaybeHttpsStream::Https(s: &mut TokioIo>>) => Pin::new(pointer:s).poll_read(cx, buf), |
62 | } |
63 | } |
64 | } |
65 | |
66 | impl<T: Write + Read + Unpin> Write for MaybeHttpsStream<T> { |
67 | #[inline ] |
68 | fn poll_write( |
69 | self: Pin<&mut Self>, |
70 | cx: &mut Context<'_>, |
71 | buf: &[u8], |
72 | ) -> Poll<Result<usize, io::Error>> { |
73 | match Pin::get_mut(self) { |
74 | MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(cx, buf), |
75 | MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(cx, buf), |
76 | } |
77 | } |
78 | |
79 | fn poll_write_vectored( |
80 | self: Pin<&mut Self>, |
81 | cx: &mut Context<'_>, |
82 | bufs: &[IoSlice<'_>], |
83 | ) -> Poll<Result<usize, io::Error>> { |
84 | match Pin::get_mut(self) { |
85 | MaybeHttpsStream::Http(s) => Pin::new(s).poll_write_vectored(cx, bufs), |
86 | MaybeHttpsStream::Https(s) => Pin::new(s).poll_write_vectored(cx, bufs), |
87 | } |
88 | } |
89 | |
90 | fn is_write_vectored(&self) -> bool { |
91 | match self { |
92 | MaybeHttpsStream::Http(s) => s.is_write_vectored(), |
93 | MaybeHttpsStream::Https(s) => s.is_write_vectored(), |
94 | } |
95 | } |
96 | |
97 | #[inline ] |
98 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
99 | match Pin::get_mut(self) { |
100 | MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(cx), |
101 | MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(cx), |
102 | } |
103 | } |
104 | |
105 | #[inline ] |
106 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
107 | match Pin::get_mut(self) { |
108 | MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(cx), |
109 | MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx), |
110 | } |
111 | } |
112 | } |
113 | |
114 | impl<T: Write + Read + Connection + Unpin> Connection for MaybeHttpsStream<T> { |
115 | fn connected(&self) -> Connected { |
116 | match self { |
117 | MaybeHttpsStream::Http(s: &T) => s.connected(), |
118 | MaybeHttpsStream::Https(s: &TokioIo>>) => { |
119 | let c: Connected = s.inner().get_ref().get_ref().get_ref().inner().connected(); |
120 | #[cfg (feature = "alpn" )] |
121 | { |
122 | if negotiated_h2(s.inner().get_ref()) { |
123 | return c.negotiated_h2(); |
124 | } |
125 | } |
126 | c |
127 | } |
128 | } |
129 | } |
130 | } |
131 | |
/// Returns `true` only when the TLS session reports a negotiated ALPN
/// protocol equal to `h2`.
///
/// An error from the ALPN lookup and an absent protocol both yield `false`.
#[cfg(feature = "alpn")]
fn negotiated_h2<T: std::io::Read + std::io::Write>(s: &native_tls::TlsStream<T>) -> bool {
    match s.negotiated_alpn() {
        Ok(Some(protocol)) => protocol == &b"h2"[..],
        _ => false,
    }
}
139 | |