1 | use std::io::{self, BufRead as _}; |
2 | #[cfg (unix)] |
3 | use std::os::unix::io::{AsRawFd, RawFd}; |
4 | #[cfg (windows)] |
5 | use std::os::windows::io::{AsRawSocket, RawSocket}; |
6 | use std::pin::Pin; |
7 | use std::task::{Context, Poll}; |
8 | |
9 | use rustls::ServerConnection; |
10 | use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; |
11 | |
12 | use crate::common::{IoSession, Stream, TlsState}; |
13 | |
14 | /// A wrapper around an underlying raw stream which implements the TLS or SSL |
15 | /// protocol. |
16 | #[derive (Debug)] |
17 | pub struct TlsStream<IO> { |
18 | pub(crate) io: IO, |
19 | pub(crate) session: ServerConnection, |
20 | pub(crate) state: TlsState, |
21 | } |
22 | |
23 | impl<IO> TlsStream<IO> { |
24 | #[inline ] |
25 | pub fn get_ref(&self) -> (&IO, &ServerConnection) { |
26 | (&self.io, &self.session) |
27 | } |
28 | |
29 | #[inline ] |
30 | pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) { |
31 | (&mut self.io, &mut self.session) |
32 | } |
33 | |
34 | #[inline ] |
35 | pub fn into_inner(self) -> (IO, ServerConnection) { |
36 | (self.io, self.session) |
37 | } |
38 | } |
39 | |
40 | impl<IO> IoSession for TlsStream<IO> { |
41 | type Io = IO; |
42 | type Session = ServerConnection; |
43 | |
44 | #[inline ] |
45 | fn skip_handshake(&self) -> bool { |
46 | false |
47 | } |
48 | |
49 | #[inline ] |
50 | fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { |
51 | (&mut self.state, &mut self.io, &mut self.session) |
52 | } |
53 | |
54 | #[inline ] |
55 | fn into_io(self) -> Self::Io { |
56 | self.io |
57 | } |
58 | } |
59 | |
60 | impl<IO> AsyncRead for TlsStream<IO> |
61 | where |
62 | IO: AsyncRead + AsyncWrite + Unpin, |
63 | { |
64 | fn poll_read( |
65 | mut self: Pin<&mut Self>, |
66 | cx: &mut Context<'_>, |
67 | buf: &mut ReadBuf<'_>, |
68 | ) -> Poll<io::Result<()>> { |
69 | let data: &[u8] = ready!(self.as_mut().poll_fill_buf(cx))?; |
70 | let len: usize = data.len().min(buf.remaining()); |
71 | buf.put_slice(&data[..len]); |
72 | self.consume(amt:len); |
73 | Poll::Ready(Ok(())) |
74 | } |
75 | } |
76 | |
77 | impl<IO> AsyncBufRead for TlsStream<IO> |
78 | where |
79 | IO: AsyncRead + AsyncWrite + Unpin, |
80 | { |
81 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
82 | match self.state { |
83 | TlsState::Stream | TlsState::WriteShutdown => { |
84 | let this = self.get_mut(); |
85 | let stream = |
86 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
87 | |
88 | match stream.poll_fill_buf(cx) { |
89 | Poll::Ready(Ok(buf)) => { |
90 | if buf.is_empty() { |
91 | this.state.shutdown_read(); |
92 | } |
93 | |
94 | Poll::Ready(Ok(buf)) |
95 | } |
96 | Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { |
97 | this.state.shutdown_read(); |
98 | Poll::Ready(Err(err)) |
99 | } |
100 | output => output, |
101 | } |
102 | } |
103 | TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])), |
104 | #[cfg (feature = "early-data" )] |
105 | ref s => unreachable!("server TLS can not hit this state: {:?}" , s), |
106 | } |
107 | } |
108 | |
109 | fn consume(mut self: Pin<&mut Self>, amt: usize) { |
110 | self.session.reader().consume(amt); |
111 | } |
112 | } |
113 | |
114 | impl<IO> AsyncWrite for TlsStream<IO> |
115 | where |
116 | IO: AsyncRead + AsyncWrite + Unpin, |
117 | { |
118 | /// Note: that it does not guarantee the final data to be sent. |
119 | /// To be cautious, you must manually call `flush`. |
120 | fn poll_write( |
121 | self: Pin<&mut Self>, |
122 | cx: &mut Context<'_>, |
123 | buf: &[u8], |
124 | ) -> Poll<io::Result<usize>> { |
125 | let this = self.get_mut(); |
126 | let mut stream = |
127 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
128 | stream.as_mut_pin().poll_write(cx, buf) |
129 | } |
130 | |
131 | /// Note: that it does not guarantee the final data to be sent. |
132 | /// To be cautious, you must manually call `flush`. |
133 | fn poll_write_vectored( |
134 | self: Pin<&mut Self>, |
135 | cx: &mut Context<'_>, |
136 | bufs: &[io::IoSlice<'_>], |
137 | ) -> Poll<io::Result<usize>> { |
138 | let this = self.get_mut(); |
139 | let mut stream = |
140 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
141 | stream.as_mut_pin().poll_write_vectored(cx, bufs) |
142 | } |
143 | |
144 | #[inline ] |
145 | fn is_write_vectored(&self) -> bool { |
146 | true |
147 | } |
148 | |
149 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
150 | let this = self.get_mut(); |
151 | let mut stream = |
152 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
153 | stream.as_mut_pin().poll_flush(cx) |
154 | } |
155 | |
156 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
157 | if self.state.writeable() { |
158 | self.session.send_close_notify(); |
159 | self.state.shutdown_write(); |
160 | } |
161 | |
162 | let this = self.get_mut(); |
163 | let mut stream = |
164 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
165 | stream.as_mut_pin().poll_shutdown(cx) |
166 | } |
167 | } |
168 | |
169 | #[cfg (unix)] |
170 | impl<IO> AsRawFd for TlsStream<IO> |
171 | where |
172 | IO: AsRawFd, |
173 | { |
174 | fn as_raw_fd(&self) -> RawFd { |
175 | self.get_ref().0.as_raw_fd() |
176 | } |
177 | } |
178 | |
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped transport.
    ///
    /// The handle belongs to the underlying I/O object, not to the TLS session.
    /// Any I/O done on it directly bypasses encryption.
    fn as_raw_socket(&self) -> RawSocket {
        self.io.as_raw_socket()
    }
}
188 | |