1 | use std::future::Future; |
2 | use std::ops::{Deref, DerefMut}; |
3 | use std::pin::Pin; |
4 | use std::task::{Context, Poll}; |
5 | use std::{io, mem}; |
6 | |
7 | use rustls::server::AcceptedAlert; |
8 | use rustls::{ConnectionCommon, SideData}; |
9 | use tokio::io::{AsyncRead, AsyncWrite}; |
10 | |
11 | use crate::common::{Stream, SyncWriteAdapter, TlsState}; |
12 | |
13 | pub(crate) trait IoSession { |
14 | type Io; |
15 | type Session; |
16 | |
17 | fn skip_handshake(&self) -> bool; |
18 | fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session); |
19 | fn into_io(self) -> Self::Io; |
20 | } |
21 | |
22 | pub(crate) enum MidHandshake<IS: IoSession> { |
23 | Handshaking(IS), |
24 | End, |
25 | SendAlert { |
26 | io: IS::Io, |
27 | alert: AcceptedAlert, |
28 | error: io::Error, |
29 | }, |
30 | Error { |
31 | io: IS::Io, |
32 | error: io::Error, |
33 | }, |
34 | } |
35 | |
36 | impl<IS, SD> Future for MidHandshake<IS> |
37 | where |
38 | IS: IoSession + Unpin, |
39 | IS::Io: AsyncRead + AsyncWrite + Unpin, |
40 | IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin, |
41 | SD: SideData, |
42 | { |
43 | type Output = Result<IS, (io::Error, IS::Io)>; |
44 | |
45 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
46 | let this = self.get_mut(); |
47 | |
48 | let mut stream = match mem::replace(this, MidHandshake::End) { |
49 | MidHandshake::Handshaking(stream) => stream, |
50 | MidHandshake::SendAlert { |
51 | mut io, |
52 | mut alert, |
53 | error, |
54 | } => loop { |
55 | match alert.write(&mut SyncWriteAdapter { io: &mut io, cx }) { |
56 | Err(e) if e.kind() == io::ErrorKind::WouldBlock => { |
57 | *this = MidHandshake::SendAlert { io, error, alert }; |
58 | return Poll::Pending; |
59 | } |
60 | Err(_) | Ok(0) => return Poll::Ready(Err((error, io))), |
61 | Ok(_) => {} |
62 | }; |
63 | }, |
64 | // Starting the handshake returned an error; fail the future immediately. |
65 | MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))), |
66 | _ => panic!("unexpected polling after handshake" ), |
67 | }; |
68 | |
69 | if !stream.skip_handshake() { |
70 | let (state, io, session) = stream.get_mut(); |
71 | let mut tls_stream = Stream::new(io, session).set_eof(!state.readable()); |
72 | |
73 | macro_rules! try_poll { |
74 | ( $e:expr ) => { |
75 | match $e { |
76 | Poll::Ready(Ok(_)) => (), |
77 | Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), |
78 | Poll::Pending => { |
79 | *this = MidHandshake::Handshaking(stream); |
80 | return Poll::Pending; |
81 | } |
82 | } |
83 | }; |
84 | } |
85 | |
86 | while tls_stream.session.is_handshaking() { |
87 | try_poll!(tls_stream.handshake(cx)); |
88 | } |
89 | |
90 | try_poll!(Pin::new(&mut tls_stream).poll_flush(cx)); |
91 | } |
92 | |
93 | Poll::Ready(Ok(stream)) |
94 | } |
95 | } |
96 | |