| 1 | use std::future::Future; |
| 2 | use std::ops::{Deref, DerefMut}; |
| 3 | use std::pin::Pin; |
| 4 | use std::task::{Context, Poll}; |
| 5 | use std::{io, mem}; |
| 6 | |
| 7 | use rustls::server::AcceptedAlert; |
| 8 | use rustls::{ConnectionCommon, SideData}; |
| 9 | use tokio::io::{AsyncRead, AsyncWrite}; |
| 10 | |
| 11 | use crate::common::{Stream, SyncWriteAdapter, TlsState}; |
| 12 | |
| 13 | pub(crate) trait IoSession { |
| 14 | type Io; |
| 15 | type Session; |
| 16 | |
| 17 | fn skip_handshake(&self) -> bool; |
| 18 | fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session); |
| 19 | fn into_io(self) -> Self::Io; |
| 20 | } |
| 21 | |
| 22 | pub(crate) enum MidHandshake<IS: IoSession> { |
| 23 | Handshaking(IS), |
| 24 | End, |
| 25 | SendAlert { |
| 26 | io: IS::Io, |
| 27 | alert: AcceptedAlert, |
| 28 | error: io::Error, |
| 29 | }, |
| 30 | Error { |
| 31 | io: IS::Io, |
| 32 | error: io::Error, |
| 33 | }, |
| 34 | } |
| 35 | |
| 36 | impl<IS, SD> Future for MidHandshake<IS> |
| 37 | where |
| 38 | IS: IoSession + Unpin, |
| 39 | IS::Io: AsyncRead + AsyncWrite + Unpin, |
| 40 | IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin, |
| 41 | SD: SideData, |
| 42 | { |
| 43 | type Output = Result<IS, (io::Error, IS::Io)>; |
| 44 | |
| 45 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 46 | let this = self.get_mut(); |
| 47 | |
| 48 | let mut stream = match mem::replace(this, MidHandshake::End) { |
| 49 | MidHandshake::Handshaking(stream) => stream, |
| 50 | MidHandshake::SendAlert { |
| 51 | mut io, |
| 52 | mut alert, |
| 53 | error, |
| 54 | } => loop { |
| 55 | match alert.write(&mut SyncWriteAdapter { io: &mut io, cx }) { |
| 56 | Err(e) if e.kind() == io::ErrorKind::WouldBlock => { |
| 57 | *this = MidHandshake::SendAlert { io, error, alert }; |
| 58 | return Poll::Pending; |
| 59 | } |
| 60 | Err(_) | Ok(0) => return Poll::Ready(Err((error, io))), |
| 61 | Ok(_) => {} |
| 62 | }; |
| 63 | }, |
| 64 | // Starting the handshake returned an error; fail the future immediately. |
| 65 | MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))), |
| 66 | _ => panic!("unexpected polling after handshake" ), |
| 67 | }; |
| 68 | |
| 69 | if !stream.skip_handshake() { |
| 70 | let (state, io, session) = stream.get_mut(); |
| 71 | let mut tls_stream = Stream::new(io, session).set_eof(!state.readable()); |
| 72 | |
| 73 | macro_rules! try_poll { |
| 74 | ( $e:expr ) => { |
| 75 | match $e { |
| 76 | Poll::Ready(Ok(_)) => (), |
| 77 | Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), |
| 78 | Poll::Pending => { |
| 79 | *this = MidHandshake::Handshaking(stream); |
| 80 | return Poll::Pending; |
| 81 | } |
| 82 | } |
| 83 | }; |
| 84 | } |
| 85 | |
| 86 | while tls_stream.session.is_handshaking() { |
| 87 | try_poll!(tls_stream.handshake(cx)); |
| 88 | } |
| 89 | |
| 90 | try_poll!(Pin::new(&mut tls_stream).poll_flush(cx)); |
| 91 | } |
| 92 | |
| 93 | Poll::Ready(Ok(stream)) |
| 94 | } |
| 95 | } |
| 96 | |