| 1 | use crate::codec::decoder::Decoder; |
| 2 | use crate::codec::encoder::Encoder; |
| 3 | |
| 4 | use futures_core::Stream; |
| 5 | use tokio::io::{AsyncRead, AsyncWrite}; |
| 6 | |
| 7 | use bytes::BytesMut; |
| 8 | use futures_sink::Sink; |
| 9 | use pin_project_lite::pin_project; |
| 10 | use std::borrow::{Borrow, BorrowMut}; |
| 11 | use std::io; |
| 12 | use std::pin::Pin; |
| 13 | use std::task::{ready, Context, Poll}; |
| 14 | |
pin_project! {
    /// Generic framing core: an I/O object (`inner`), a codec, and the
    /// buffering `state` for the direction(s) in use.
    ///
    /// `Stream` is implemented when `State: BorrowMut<ReadFrame>` and `Sink`
    /// when `State: BorrowMut<WriteFrame>`, so one `State` type can carry read
    /// state, write state, or both.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // The underlying I/O object. It is structurally pinned so it can be
        // polled through `Pin<&mut T>`.
        #[pin]
        pub(crate) inner: T,
        // Read and/or write buffering state.
        pub(crate) state: State,
        // The decoder and/or encoder applied to the buffers in `state`.
        pub(crate) codec: U,
    }
}
| 24 | |
/// Size in bytes used to pre-allocate the read and write buffers. It is also
/// the default backpressure boundary of the write side.
const INITIAL_CAPACITY: usize = 8 * 1024;
| 26 | |
/// Buffer and state-machine flags for the decoding (read) side.
#[derive(Debug)]
pub(crate) struct ReadFrame {
    /// Set once the reader returns 0 bytes (EOF). It is cleared again if a
    /// later read returns data.
    pub(crate) eof: bool,
    /// Whether the buffer might hold data that the decoder can turn into a
    /// frame (or a closing frame at EOF).
    pub(crate) is_readable: bool,
    /// Bytes waiting to be decoded.
    pub(crate) buffer: BytesMut,
    /// Set when decoding or reading fails. The next poll then yields `None`
    /// and clears it.
    pub(crate) has_errored: bool,
}
| 34 | |
| 35 | pub(crate) struct WriteFrame { |
| 36 | pub(crate) buffer: BytesMut, |
| 37 | pub(crate) backpressure_boundary: usize, |
| 38 | } |
| 39 | |
| 40 | #[derive (Default)] |
| 41 | pub(crate) struct RWFrames { |
| 42 | pub(crate) read: ReadFrame, |
| 43 | pub(crate) write: WriteFrame, |
| 44 | } |
| 45 | |
| 46 | impl Default for ReadFrame { |
| 47 | fn default() -> Self { |
| 48 | Self { |
| 49 | eof: false, |
| 50 | is_readable: false, |
| 51 | buffer: BytesMut::with_capacity(INITIAL_CAPACITY), |
| 52 | has_errored: false, |
| 53 | } |
| 54 | } |
| 55 | } |
| 56 | |
| 57 | impl Default for WriteFrame { |
| 58 | fn default() -> Self { |
| 59 | Self { |
| 60 | buffer: BytesMut::with_capacity(INITIAL_CAPACITY), |
| 61 | backpressure_boundary: INITIAL_CAPACITY, |
| 62 | } |
| 63 | } |
| 64 | } |
| 65 | |
| 66 | impl From<BytesMut> for ReadFrame { |
| 67 | fn from(mut buffer: BytesMut) -> Self { |
| 68 | let size: usize = buffer.capacity(); |
| 69 | if size < INITIAL_CAPACITY { |
| 70 | buffer.reserve(INITIAL_CAPACITY - size); |
| 71 | } |
| 72 | |
| 73 | Self { |
| 74 | buffer, |
| 75 | is_readable: size > 0, |
| 76 | eof: false, |
| 77 | has_errored: false, |
| 78 | } |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | impl From<BytesMut> for WriteFrame { |
| 83 | fn from(mut buffer: BytesMut) -> Self { |
| 84 | let size: usize = buffer.capacity(); |
| 85 | if size < INITIAL_CAPACITY { |
| 86 | buffer.reserve(INITIAL_CAPACITY - size); |
| 87 | } |
| 88 | |
| 89 | Self { |
| 90 | buffer, |
| 91 | backpressure_boundary: INITIAL_CAPACITY, |
| 92 | } |
| 93 | } |
| 94 | } |
| 95 | |
| 96 | impl Borrow<ReadFrame> for RWFrames { |
| 97 | fn borrow(&self) -> &ReadFrame { |
| 98 | &self.read |
| 99 | } |
| 100 | } |
| 101 | impl BorrowMut<ReadFrame> for RWFrames { |
| 102 | fn borrow_mut(&mut self) -> &mut ReadFrame { |
| 103 | &mut self.read |
| 104 | } |
| 105 | } |
| 106 | impl Borrow<WriteFrame> for RWFrames { |
| 107 | fn borrow(&self) -> &WriteFrame { |
| 108 | &self.write |
| 109 | } |
| 110 | } |
| 111 | impl BorrowMut<WriteFrame> for RWFrames { |
| 112 | fn borrow_mut(&mut self) -> &mut WriteFrame { |
| 113 | &mut self.write |
| 114 | } |
| 115 | } |
/// Decoding side of the framed transport.
///
/// `R` only needs to expose a [`ReadFrame`] through `BorrowMut`, so this impl
/// works for any state type that carries read state.
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    /// Polls for the next decoded frame.
    ///
    /// Possible results:
    /// - `Ready(Some(Ok(frame)))` for each decoded frame.
    /// - `Ready(Some(Err(_)))` when decoding or reading fails. The error is
    ///   propagated through `?`.
    /// - `Ready(None)` when the input is exhausted at EOF, and also once on the
    ///   poll right after an error was reported.
    /// - `Pending` while waiting for more input.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loop implements a state machine. Each state
        // corresponds to a combination of the `is_readable`, `eof` and
        // `has_errored` flags. The flags live in `ReadFrame`, so states persist
        // across calls, and most state transitions occur with a return.
        //
        // The default initial state is `reading`. A `ReadFrame` built from a
        // buffer with non-zero capacity starts in `framing` instead.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //
        //                                                 `decode_eof` returns Err
        //                                          ┌────────────────────────────────────────────────────────┐
        //                   `decode_eof` returns   │                                                        │
        //                             `Ok(Some)`   │                                                        │
        //                                 ┌─────┐  │     `decode_eof` returns               After returning │
        //                Read 0 bytes     ├─────▼──┴┐    `Ok(None)`          ┌────────┐ ◄───┐ `None`    ┌───▼─────┐
        //               ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐   └───────────┤ Errored │
        //               │                 └─────────┘                        └─┬──▲───┘ │               └───▲───▲─┘
        // Pending read  │                                                      │  │     │                   │   │
        //     ┌──────┐  │            `decode` returns `Some`                   │  └─────┘                   │   │
        //     │      │  │                   ┌──────┐                           │  Pending                   │   │
        //     │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐     read n>0 bytes      │  read                      │   │
        //     └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘                            │   │
        //       └──┬─▲────┘                └─────┬──┬┘                                                      │   │
        //          │ │                           │  │                 `decode` returns Err                  │   │
        //          │ └──`decode` returns `None`──┘  └───────────────────────────────────────────────────────┘   │
        //          │                      read returns Err                                                      │
        //          └────────────────────────────────────────────────────────────────────────────────────────────┘
        loop {
            // Return `None` if a previous call hit an error from the decoder or
            // the underlying reader.
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // Prepare errored -> paused. Only `is_readable` and
                // `has_errored` are cleared and `eof` is left as-is, so this
                // lands in `reading` if no EOF had been seen.
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is
            // "readable", i.e. it _might_ contain data consumable as a frame or
            // closing frame. Both signal that there is no such data by
            // returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has
            // returned EOF, `decode_eof` will attempt to decode the remaining
            // bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an
            // EOF, but must finish emitting all of its associated `decode_eof`
            // frames. Furthermore, we don't want to emit any `decode_eof` frames
            // on retried reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // Implicit pausing -> pausing or pausing -> paused. An
                    // `Ok(None)` from `decode_eof` is reported as `Ready(None)`.
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            // The match scrutinee below contains a closure with a block body,
            // which clippy's `blocks_in_conditions` lint would otherwise flag.
            #[allow(clippy::blocks_in_conditions)]
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already
                    // finished our `decode_eof` handling, so we can simply
                    // return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> pausing
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
| 251 | |
| 252 | impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W> |
| 253 | where |
| 254 | T: AsyncWrite, |
| 255 | U: Encoder<I>, |
| 256 | U::Error: From<io::Error>, |
| 257 | W: BorrowMut<WriteFrame>, |
| 258 | { |
| 259 | type Error = U::Error; |
| 260 | |
| 261 | fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 262 | if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary { |
| 263 | self.as_mut().poll_flush(cx) |
| 264 | } else { |
| 265 | Poll::Ready(Ok(())) |
| 266 | } |
| 267 | } |
| 268 | |
| 269 | fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { |
| 270 | let pinned = self.project(); |
| 271 | pinned |
| 272 | .codec |
| 273 | .encode(item, &mut pinned.state.borrow_mut().buffer)?; |
| 274 | Ok(()) |
| 275 | } |
| 276 | |
| 277 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 278 | use crate::util::poll_write_buf; |
| 279 | trace!("flushing framed transport" ); |
| 280 | let mut pinned = self.project(); |
| 281 | |
| 282 | while !pinned.state.borrow_mut().buffer.is_empty() { |
| 283 | let WriteFrame { buffer, .. } = pinned.state.borrow_mut(); |
| 284 | trace!(remaining = buffer.len(), "writing;" ); |
| 285 | |
| 286 | let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; |
| 287 | |
| 288 | if n == 0 { |
| 289 | return Poll::Ready(Err(io::Error::new( |
| 290 | io::ErrorKind::WriteZero, |
| 291 | "failed to \ |
| 292 | write frame to transport" , |
| 293 | ) |
| 294 | .into())); |
| 295 | } |
| 296 | } |
| 297 | |
| 298 | // Try flushing the underlying IO |
| 299 | ready!(pinned.inner.poll_flush(cx))?; |
| 300 | |
| 301 | trace!("framed transport flushed" ); |
| 302 | Poll::Ready(Ok(())) |
| 303 | } |
| 304 | |
| 305 | fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 306 | ready!(self.as_mut().poll_flush(cx))?; |
| 307 | ready!(self.project().inner.poll_shutdown(cx))?; |
| 308 | |
| 309 | Poll::Ready(Ok(())) |
| 310 | } |
| 311 | } |
| 312 | |