1 | use std::io::{self, BufRead as _}; |
2 | #[cfg (unix)] |
3 | use std::os::unix::io::{AsRawFd, RawFd}; |
4 | #[cfg (windows)] |
5 | use std::os::windows::io::{AsRawSocket, RawSocket}; |
6 | use std::pin::Pin; |
7 | #[cfg (feature = "early-data" )] |
8 | use std::task::Waker; |
9 | use std::task::{Context, Poll}; |
10 | |
11 | use rustls::ClientConnection; |
12 | use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; |
13 | |
14 | use crate::common::{IoSession, Stream, TlsState}; |
15 | |
16 | /// A wrapper around an underlying raw stream which implements the TLS or SSL |
17 | /// protocol. |
18 | #[derive (Debug)] |
19 | pub struct TlsStream<IO> { |
20 | pub(crate) io: IO, |
21 | pub(crate) session: ClientConnection, |
22 | pub(crate) state: TlsState, |
23 | |
24 | #[cfg (feature = "early-data" )] |
25 | pub(crate) early_waker: Option<Waker>, |
26 | } |
27 | |
28 | impl<IO> TlsStream<IO> { |
29 | #[inline ] |
30 | pub fn get_ref(&self) -> (&IO, &ClientConnection) { |
31 | (&self.io, &self.session) |
32 | } |
33 | |
34 | #[inline ] |
35 | pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) { |
36 | (&mut self.io, &mut self.session) |
37 | } |
38 | |
39 | #[inline ] |
40 | pub fn into_inner(self) -> (IO, ClientConnection) { |
41 | (self.io, self.session) |
42 | } |
43 | } |
44 | |
45 | #[cfg (unix)] |
46 | impl<S> AsRawFd for TlsStream<S> |
47 | where |
48 | S: AsRawFd, |
49 | { |
50 | fn as_raw_fd(&self) -> RawFd { |
51 | self.get_ref().0.as_raw_fd() |
52 | } |
53 | } |
54 | |
/// On Windows, the stream exposes the raw socket of the wrapped transport.
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}
64 | |
65 | impl<IO> IoSession for TlsStream<IO> { |
66 | type Io = IO; |
67 | type Session = ClientConnection; |
68 | |
69 | #[inline ] |
70 | fn skip_handshake(&self) -> bool { |
71 | self.state.is_early_data() |
72 | } |
73 | |
74 | #[inline ] |
75 | fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { |
76 | (&mut self.state, &mut self.io, &mut self.session) |
77 | } |
78 | |
79 | #[inline ] |
80 | fn into_io(self) -> Self::Io { |
81 | self.io |
82 | } |
83 | } |
84 | |
#[cfg(feature = "early-data")]
impl<IO> TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Parks a reader while the stream is still in the early-data state.
    ///
    /// Until a write/flush finishes the handshake there is no established TLS
    /// connection to read from, so reads report `Pending`. To avoid losing that
    /// wakeup, the caller's waker is remembered here and woken when the early-data
    /// phase ends. The stored waker is replaced only if it would not wake the
    /// same task as the current one.
    fn poll_early_data(&mut self, cx: &mut Context<'_>) {
        let current = cx.waker();
        let already_registered =
            matches!(&self.early_waker, Some(registered) if current.will_wake(registered));

        if !already_registered {
            self.early_waker = Some(current.clone());
        }
    }
}
107 | |
108 | impl<IO> AsyncRead for TlsStream<IO> |
109 | where |
110 | IO: AsyncRead + AsyncWrite + Unpin, |
111 | { |
112 | fn poll_read( |
113 | mut self: Pin<&mut Self>, |
114 | cx: &mut Context<'_>, |
115 | buf: &mut ReadBuf<'_>, |
116 | ) -> Poll<io::Result<()>> { |
117 | let data: &[u8] = ready!(self.as_mut().poll_fill_buf(cx))?; |
118 | let len: usize = data.len().min(buf.remaining()); |
119 | buf.put_slice(&data[..len]); |
120 | self.consume(amt:len); |
121 | Poll::Ready(Ok(())) |
122 | } |
123 | } |
124 | |
125 | impl<IO> AsyncBufRead for TlsStream<IO> |
126 | where |
127 | IO: AsyncRead + AsyncWrite + Unpin, |
128 | { |
129 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
130 | match self.state { |
131 | #[cfg (feature = "early-data" )] |
132 | TlsState::EarlyData(..) => { |
133 | self.get_mut().poll_early_data(cx); |
134 | Poll::Pending |
135 | } |
136 | TlsState::Stream | TlsState::WriteShutdown => { |
137 | let this = self.get_mut(); |
138 | let stream = |
139 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
140 | |
141 | match stream.poll_fill_buf(cx) { |
142 | Poll::Ready(Ok(buf)) => { |
143 | if buf.is_empty() { |
144 | this.state.shutdown_read(); |
145 | } |
146 | |
147 | Poll::Ready(Ok(buf)) |
148 | } |
149 | Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { |
150 | this.state.shutdown_read(); |
151 | Poll::Ready(Err(err)) |
152 | } |
153 | output => output, |
154 | } |
155 | } |
156 | TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(&[])), |
157 | } |
158 | } |
159 | |
160 | fn consume(mut self: Pin<&mut Self>, amt: usize) { |
161 | self.session.reader().consume(amt); |
162 | } |
163 | } |
164 | |
165 | impl<IO> AsyncWrite for TlsStream<IO> |
166 | where |
167 | IO: AsyncRead + AsyncWrite + Unpin, |
168 | { |
169 | /// Note: that it does not guarantee the final data to be sent. |
170 | /// To be cautious, you must manually call `flush`. |
171 | fn poll_write( |
172 | self: Pin<&mut Self>, |
173 | cx: &mut Context<'_>, |
174 | buf: &[u8], |
175 | ) -> Poll<io::Result<usize>> { |
176 | let this = self.get_mut(); |
177 | let mut stream = |
178 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
179 | |
180 | #[cfg (feature = "early-data" )] |
181 | { |
182 | let bufs = [io::IoSlice::new(buf)]; |
183 | let written = ready!(poll_handle_early_data( |
184 | &mut this.state, |
185 | &mut stream, |
186 | &mut this.early_waker, |
187 | cx, |
188 | &bufs |
189 | ))?; |
190 | if written != 0 { |
191 | return Poll::Ready(Ok(written)); |
192 | } |
193 | } |
194 | |
195 | stream.as_mut_pin().poll_write(cx, buf) |
196 | } |
197 | |
198 | /// Note: that it does not guarantee the final data to be sent. |
199 | /// To be cautious, you must manually call `flush`. |
200 | fn poll_write_vectored( |
201 | self: Pin<&mut Self>, |
202 | cx: &mut Context<'_>, |
203 | bufs: &[io::IoSlice<'_>], |
204 | ) -> Poll<io::Result<usize>> { |
205 | let this = self.get_mut(); |
206 | let mut stream = |
207 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
208 | |
209 | #[cfg (feature = "early-data" )] |
210 | { |
211 | let written = ready!(poll_handle_early_data( |
212 | &mut this.state, |
213 | &mut stream, |
214 | &mut this.early_waker, |
215 | cx, |
216 | bufs |
217 | ))?; |
218 | if written != 0 { |
219 | return Poll::Ready(Ok(written)); |
220 | } |
221 | } |
222 | |
223 | stream.as_mut_pin().poll_write_vectored(cx, bufs) |
224 | } |
225 | |
226 | #[inline ] |
227 | fn is_write_vectored(&self) -> bool { |
228 | true |
229 | } |
230 | |
231 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
232 | let this = self.get_mut(); |
233 | let mut stream = |
234 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
235 | |
236 | #[cfg (feature = "early-data" )] |
237 | ready!(poll_handle_early_data( |
238 | &mut this.state, |
239 | &mut stream, |
240 | &mut this.early_waker, |
241 | cx, |
242 | &[] |
243 | ))?; |
244 | |
245 | stream.as_mut_pin().poll_flush(cx) |
246 | } |
247 | |
248 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
249 | #[cfg (feature = "early-data" )] |
250 | { |
251 | // complete handshake |
252 | if matches!(self.state, TlsState::EarlyData(..)) { |
253 | ready!(self.as_mut().poll_flush(cx))?; |
254 | } |
255 | } |
256 | |
257 | if self.state.writeable() { |
258 | self.session.send_close_notify(); |
259 | self.state.shutdown_write(); |
260 | } |
261 | |
262 | let this = self.get_mut(); |
263 | let mut stream = |
264 | Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); |
265 | stream.as_mut_pin().poll_shutdown(cx) |
266 | } |
267 | } |
268 | |
/// Drives the 0-RTT (early data) phase of a client connection.
///
/// When `state` is not `TlsState::EarlyData`, this does nothing and returns
/// `Ready(Ok(0))`. Otherwise it proceeds in four steps:
///
/// 1. While the session still hands out an early-data writer, as much of `bufs`
///    as it accepts is written as early data. A copy is appended to the state's
///    buffer, and the byte count is returned if it is non-zero.
/// 2. If nothing was written that way, the handshake is driven to completion.
/// 3. If the early data was not accepted (`is_early_data_accepted()` is false),
///    the buffered copy is re-sent as ordinary data, resuming at the recorded
///    position.
/// 4. The state becomes `TlsState::Stream`, any reader parked by
///    `poll_early_data` is woken, and `Ready(Ok(0))` is returned.
///
/// `Ok(0)` therefore means "nothing from `bufs` was consumed; use the normal write
/// path". Returning `Pending` leaves the stream in the early-data state, so the
/// call can be retried.
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    use std::io::Write;

    let (pos, data) = match state {
        TlsState::EarlyData(pos, data) => (pos, data),
        _ => return Poll::Ready(Ok(0)),
    };

    // 1. Offer the caller's bytes as early data, keeping a copy for a possible replay.
    if let Some(mut early_writer) = stream.session.early_data() {
        let mut accepted_total = 0;

        for slice in bufs.iter().filter(|slice| !slice.is_empty()) {
            let accepted = match early_writer.write(slice) {
                Ok(0) => break,
                Ok(n) => n,
                Err(err) => return Poll::Ready(Err(err)),
            };

            accepted_total += accepted;
            data.extend_from_slice(&slice[..accepted]);

            // A short write means the writer took all it would; stop here.
            if accepted < slice.len() {
                break;
            }
        }

        if accepted_total > 0 {
            return Poll::Ready(Ok(accepted_total));
        }
    }

    // 2. Finish the handshake.
    while stream.session.is_handshaking() {
        ready!(stream.handshake(cx))?;
    }

    // 3. Replay the buffered bytes as ordinary data if early data was not accepted.
    //    `pos` survives a `Pending`, so bytes already sent are not repeated.
    if !stream.session.is_early_data_accepted() {
        while *pos < data.len() {
            let sent = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
            *pos += sent;
        }
    }

    // 4. Leave the early-data state and wake any reader parked in the meantime.
    *state = TlsState::Stream;
    if let Some(waker) = early_waker.take() {
        waker.wake();
    }

    Poll::Ready(Ok(0))
}
334 | |