1 | use std::io::{self, BufRead as _, IoSlice, Read, Write}; |
2 | use std::ops::{Deref, DerefMut}; |
3 | use std::pin::Pin; |
4 | use std::task::{Context, Poll}; |
5 | |
6 | use rustls::{ConnectionCommon, SideData}; |
7 | use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; |
8 | |
9 | mod handshake; |
10 | pub(crate) use handshake::{IoSession, MidHandshake}; |
11 | |
/// Which halves of a TLS stream are still open.
///
/// With the `early-data` feature enabled there is an additional `EarlyData`
/// phase.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase (only with the `early-data` feature). The meaning of
    /// the `usize` and the byte buffer is defined by the code that constructs
    /// this variant, which is not part of this impl.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Reading and writing are both allowed.
    Stream,
    /// Reading has been shut down; writing is still allowed.
    ReadShutdown,
    /// Writing has been shut down; reading is still allowed.
    WriteShutdown,
    /// Both reading and writing have been shut down.
    FullyShutdown,
}

impl TlsState {
    /// Marks the read half as closed.
    ///
    /// If the write half was already closed the state becomes
    /// `FullyShutdown`; from any other state it becomes `ReadShutdown`.
    #[inline]
    pub fn shutdown_read(&mut self) {
        *self = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Marks the write half as closed.
    ///
    /// If the read half was already closed the state becomes
    /// `FullyShutdown`; from any other state it becomes `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        *self = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Returns `true` unless the write half has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        let write_closed = matches!(self, TlsState::WriteShutdown | TlsState::FullyShutdown);
        !write_closed
    }

    /// Returns `true` unless the read half has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        let read_closed = matches!(self, TlsState::ReadShutdown | TlsState::FullyShutdown);
        !read_closed
    }

    /// Returns `true` while in the `EarlyData` phase.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    /// Always `false`: the `EarlyData` variant does not exist without the
    /// `early-data` feature.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
61 | |
62 | pub struct Stream<'a, IO, C> { |
63 | pub io: &'a mut IO, |
64 | pub session: &'a mut C, |
65 | pub eof: bool, |
66 | } |
67 | |
68 | impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C> |
69 | where |
70 | C: DerefMut + Deref<Target = ConnectionCommon<SD>>, |
71 | SD: SideData, |
72 | { |
73 | pub fn new(io: &'a mut IO, session: &'a mut C) -> Self { |
74 | Stream { |
75 | io, |
76 | session, |
77 | // The state so far is only used to detect EOF, so either Stream |
78 | // or EarlyData state should both be all right. |
79 | eof: false, |
80 | } |
81 | } |
82 | |
83 | pub fn set_eof(mut self, eof: bool) -> Self { |
84 | self.eof = eof; |
85 | self |
86 | } |
87 | |
88 | pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { |
89 | Pin::new(self) |
90 | } |
91 | |
92 | pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { |
93 | let mut reader = SyncReadAdapter { io: self.io, cx }; |
94 | |
95 | let n = match self.session.read_tls(&mut reader) { |
96 | Ok(n) => n, |
97 | Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, |
98 | Err(err) => return Poll::Ready(Err(err)), |
99 | }; |
100 | |
101 | self.session.process_new_packets().map_err(|err| { |
102 | // In case we have an alert to send describing this error, |
103 | // try a last-gasp write -- but don't predate the primary |
104 | // error. |
105 | let _ = self.write_io(cx); |
106 | |
107 | io::Error::new(io::ErrorKind::InvalidData, err) |
108 | })?; |
109 | |
110 | Poll::Ready(Ok(n)) |
111 | } |
112 | |
113 | pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { |
114 | let mut writer = SyncWriteAdapter { io: self.io, cx }; |
115 | |
116 | match self.session.write_tls(&mut writer) { |
117 | Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
118 | result => Poll::Ready(result), |
119 | } |
120 | } |
121 | |
122 | pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> { |
123 | let mut wrlen = 0; |
124 | let mut rdlen = 0; |
125 | |
126 | loop { |
127 | let mut write_would_block = false; |
128 | let mut read_would_block = false; |
129 | let mut need_flush = false; |
130 | |
131 | while self.session.wants_write() { |
132 | match self.write_io(cx) { |
133 | Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), |
134 | Poll::Ready(Ok(n)) => { |
135 | wrlen += n; |
136 | need_flush = true; |
137 | } |
138 | Poll::Pending => { |
139 | write_would_block = true; |
140 | break; |
141 | } |
142 | Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
143 | } |
144 | } |
145 | |
146 | if need_flush { |
147 | match Pin::new(&mut self.io).poll_flush(cx) { |
148 | Poll::Ready(Ok(())) => (), |
149 | Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
150 | Poll::Pending => write_would_block = true, |
151 | } |
152 | } |
153 | |
154 | while !self.eof && self.session.wants_read() { |
155 | match self.read_io(cx) { |
156 | Poll::Ready(Ok(0)) => self.eof = true, |
157 | Poll::Ready(Ok(n)) => rdlen += n, |
158 | Poll::Pending => { |
159 | read_would_block = true; |
160 | break; |
161 | } |
162 | Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
163 | } |
164 | } |
165 | |
166 | return match (self.eof, self.session.is_handshaking()) { |
167 | (true, true) => { |
168 | let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof" ); |
169 | Poll::Ready(Err(err)) |
170 | } |
171 | (_, false) => Poll::Ready(Ok((rdlen, wrlen))), |
172 | (_, true) if write_would_block || read_would_block => { |
173 | if rdlen != 0 || wrlen != 0 { |
174 | Poll::Ready(Ok((rdlen, wrlen))) |
175 | } else { |
176 | Poll::Pending |
177 | } |
178 | } |
179 | (..) => continue, |
180 | }; |
181 | } |
182 | } |
183 | |
184 | pub(crate) fn poll_fill_buf(mut self, cx: &mut Context<'_>) -> Poll<io::Result<&'a [u8]>> |
185 | where |
186 | SD: 'a, |
187 | { |
188 | let mut io_pending = false; |
189 | |
190 | // read a packet |
191 | while !self.eof && self.session.wants_read() { |
192 | match self.read_io(cx) { |
193 | Poll::Ready(Ok(0)) => { |
194 | break; |
195 | } |
196 | Poll::Ready(Ok(_)) => (), |
197 | Poll::Pending => { |
198 | io_pending = true; |
199 | break; |
200 | } |
201 | Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
202 | } |
203 | } |
204 | |
205 | match self.session.reader().into_first_chunk() { |
206 | Ok(buf) => { |
207 | // Note that this could be empty (i.e. EOF) if a `CloseNotify` has been |
208 | // received and there is no more buffered data. |
209 | Poll::Ready(Ok(buf)) |
210 | } |
211 | Err(e) if e.kind() == io::ErrorKind::WouldBlock => { |
212 | if !io_pending { |
213 | // If `wants_read()` is satisfied, rustls will not return `WouldBlock`. |
214 | // but if it does, we can try again. |
215 | // |
216 | // If the rustls state is abnormal, it may cause a cyclic wakeup. |
217 | // but tokio's cooperative budget will prevent infinite wakeup. |
218 | cx.waker().wake_by_ref(); |
219 | } |
220 | |
221 | Poll::Pending |
222 | } |
223 | Err(e) => Poll::Ready(Err(e)), |
224 | } |
225 | } |
226 | } |
227 | |
228 | impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C> |
229 | where |
230 | C: DerefMut + Deref<Target = ConnectionCommon<SD>>, |
231 | SD: SideData + 'a, |
232 | { |
233 | fn poll_read( |
234 | mut self: Pin<&mut Self>, |
235 | cx: &mut Context<'_>, |
236 | buf: &mut ReadBuf<'_>, |
237 | ) -> Poll<io::Result<()>> { |
238 | let data: &[u8] = ready!(self.as_mut().poll_fill_buf(cx))?; |
239 | let amount: usize = buf.remaining().min(data.len()); |
240 | buf.put_slice(&data[..amount]); |
241 | self.session.reader().consume(amount); |
242 | Poll::Ready(Ok(())) |
243 | } |
244 | } |
245 | |
246 | impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncBufRead for Stream<'a, IO, C> |
247 | where |
248 | C: DerefMut + Deref<Target = ConnectionCommon<SD>>, |
249 | SD: SideData + 'a, |
250 | { |
251 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
252 | let this: &mut Stream<'a, IO, C> = self.get_mut(); |
253 | Stream { |
254 | // reborrow |
255 | io: this.io, |
256 | session: this.session, |
257 | ..*this |
258 | } |
259 | .poll_fill_buf(cx) |
260 | } |
261 | |
262 | fn consume(mut self: Pin<&mut Self>, amt: usize) { |
263 | self.session.reader().consume(amount:amt); |
264 | } |
265 | } |
266 | |
267 | impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'_, IO, C> |
268 | where |
269 | C: DerefMut + Deref<Target = ConnectionCommon<SD>>, |
270 | SD: SideData, |
271 | { |
272 | fn poll_write( |
273 | mut self: Pin<&mut Self>, |
274 | cx: &mut Context, |
275 | buf: &[u8], |
276 | ) -> Poll<io::Result<usize>> { |
277 | let mut pos = 0; |
278 | |
279 | while pos != buf.len() { |
280 | let mut would_block = false; |
281 | |
282 | match self.session.writer().write(&buf[pos..]) { |
283 | Ok(n) => pos += n, |
284 | Err(err) => return Poll::Ready(Err(err)), |
285 | }; |
286 | |
287 | while self.session.wants_write() { |
288 | match self.write_io(cx) { |
289 | Poll::Ready(Ok(0)) | Poll::Pending => { |
290 | would_block = true; |
291 | break; |
292 | } |
293 | Poll::Ready(Ok(_)) => (), |
294 | Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
295 | } |
296 | } |
297 | |
298 | return match (pos, would_block) { |
299 | (0, true) => Poll::Pending, |
300 | (n, true) => Poll::Ready(Ok(n)), |
301 | (_, false) => continue, |
302 | }; |
303 | } |
304 | |
305 | Poll::Ready(Ok(pos)) |
306 | } |
307 | |
308 | fn poll_write_vectored( |
309 | mut self: Pin<&mut Self>, |
310 | cx: &mut Context<'_>, |
311 | bufs: &[IoSlice<'_>], |
312 | ) -> Poll<io::Result<usize>> { |
313 | if bufs.iter().all(|buf| buf.is_empty()) { |
314 | return Poll::Ready(Ok(0)); |
315 | } |
316 | |
317 | loop { |
318 | let mut would_block = false; |
319 | let written = self.session.writer().write_vectored(bufs)?; |
320 | |
321 | while self.session.wants_write() { |
322 | match self.write_io(cx) { |
323 | Poll::Ready(Ok(0)) | Poll::Pending => { |
324 | would_block = true; |
325 | break; |
326 | } |
327 | Poll::Ready(Ok(_)) => (), |
328 | Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
329 | } |
330 | } |
331 | |
332 | return match (written, would_block) { |
333 | (0, true) => Poll::Pending, |
334 | (0, false) => continue, |
335 | (n, _) => Poll::Ready(Ok(n)), |
336 | }; |
337 | } |
338 | } |
339 | |
340 | #[inline ] |
341 | fn is_write_vectored(&self) -> bool { |
342 | true |
343 | } |
344 | |
345 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { |
346 | self.session.writer().flush()?; |
347 | while self.session.wants_write() { |
348 | if ready!(self.write_io(cx))? == 0 { |
349 | return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); |
350 | } |
351 | } |
352 | Pin::new(&mut self.io).poll_flush(cx) |
353 | } |
354 | |
355 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
356 | while self.session.wants_write() { |
357 | if ready!(self.write_io(cx))? == 0 { |
358 | return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); |
359 | } |
360 | } |
361 | |
362 | Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) { |
363 | Ok(()) => Ok(()), |
364 | // When trying to shutdown, not being connected seems fine |
365 | Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()), |
366 | Err(err) => Err(err), |
367 | }) |
368 | } |
369 | } |
370 | |
371 | /// An adapter that implements a [`Read`] interface for [`AsyncRead`] types and an |
372 | /// associated [`Context`]. |
373 | /// |
374 | /// Turns `Poll::Pending` into `WouldBlock`. |
375 | pub struct SyncReadAdapter<'a, 'b, T> { |
376 | pub io: &'a mut T, |
377 | pub cx: &'a mut Context<'b>, |
378 | } |
379 | |
380 | impl<T: AsyncRead + Unpin> Read for SyncReadAdapter<'_, '_, T> { |
381 | #[inline ] |
382 | fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
383 | let mut buf: ReadBuf<'_> = ReadBuf::new(buf); |
384 | match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) { |
385 | Poll::Ready(Ok(())) => Ok(buf.filled().len()), |
386 | Poll::Ready(Err(err: Error)) => Err(err), |
387 | Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), |
388 | } |
389 | } |
390 | } |
391 | |
392 | /// An adapter that implements a [`Write`] interface for [`AsyncWrite`] types and an |
393 | /// associated [`Context`]. |
394 | /// |
395 | /// Turns `Poll::Pending` into `WouldBlock`. |
396 | pub struct SyncWriteAdapter<'a, 'b, T> { |
397 | pub io: &'a mut T, |
398 | pub cx: &'a mut Context<'b>, |
399 | } |
400 | |
401 | impl<T: Unpin> SyncWriteAdapter<'_, '_, T> { |
402 | #[inline ] |
403 | fn poll_with<U>( |
404 | &mut self, |
405 | f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>, |
406 | ) -> io::Result<U> { |
407 | match f(Pin::new(self.io), self.cx) { |
408 | Poll::Ready(result: Result) => result, |
409 | Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), |
410 | } |
411 | } |
412 | } |
413 | |
414 | impl<T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'_, '_, T> { |
415 | #[inline ] |
416 | fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
417 | self.poll_with(|io: Pin<&mut T>, cx: &mut Context<'_>| io.poll_write(cx, buf)) |
418 | } |
419 | |
420 | #[inline ] |
421 | fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> { |
422 | self.poll_with(|io: Pin<&mut T>, cx: &mut Context<'_>| io.poll_write_vectored(cx, bufs)) |
423 | } |
424 | |
425 | fn flush(&mut self) -> io::Result<()> { |
426 | self.poll_with(|io: Pin<&mut T>, cx: &mut Context<'_>| io.poll_flush(cx)) |
427 | } |
428 | } |
429 | |
430 | #[cfg (test)] |
431 | mod test_stream; |
432 | |