1use crate::codec::decoder::Decoder;
2use crate::codec::encoder::Encoder;
3
4use futures_core::Stream;
5use tokio::io::{AsyncRead, AsyncWrite};
6
7use bytes::BytesMut;
8use futures_core::ready;
9use futures_sink::Sink;
10use pin_project_lite::pin_project;
11use std::borrow::{Borrow, BorrowMut};
12use std::io;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use tracing::trace;
16
pin_project! {
    // Shared core of the framed transport types: an I/O object (`inner`), a codec,
    // and per-direction buffering state.
    //
    // The `Stream` impl below requires `State: BorrowMut<ReadFrame>` and the `Sink`
    // impl requires `State: BorrowMut<WriteFrame>`. `ReadFrame` / `WriteFrame`
    // satisfy these through std's blanket `BorrowMut<T> for T`, and `RWFrames`
    // satisfies both through its explicit impls.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // Structurally pinned: the `Stream`/`Sink` impls poll it as `Pin<&mut T>`.
        #[pin]
        pub(crate) inner: T,
        // Read and/or write buffering state (see the bounds described above).
        pub(crate) state: State,
        // Decoder and/or encoder used to turn bytes into items and back.
        pub(crate) codec: U,
    }
}
26
/// Capacity (8 KiB) reserved up front for the read and write buffers.
const INITIAL_CAPACITY: usize = 8_192;
/// Queued-write threshold: once the write buffer holds at least this many bytes,
/// `poll_ready` flushes before reporting readiness.
const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY;
29
/// Read-side state for `FramedImpl`: the buffered input plus the flags that
/// drive the decoding state machine in `Stream::poll_next`.
#[derive(Debug)]
pub(crate) struct ReadFrame {
    /// Set when the underlying reader returns a 0-byte read; cleared again when a
    /// later read returns data.
    pub(crate) eof: bool,
    /// Whether `buffer` might still contain data for `decode` / `decode_eof`.
    pub(crate) is_readable: bool,
    /// Bytes read from the underlying reader, handed to the codec for decoding.
    pub(crate) buffer: BytesMut,
    /// Set when reading or decoding fails; the next poll yields `None` and clears it.
    pub(crate) has_errored: bool,
}
37
38pub(crate) struct WriteFrame {
39 pub(crate) buffer: BytesMut,
40}
41
42#[derive(Default)]
43pub(crate) struct RWFrames {
44 pub(crate) read: ReadFrame,
45 pub(crate) write: WriteFrame,
46}
47
48impl Default for ReadFrame {
49 fn default() -> Self {
50 Self {
51 eof: false,
52 is_readable: false,
53 buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
54 has_errored: false,
55 }
56 }
57}
58
59impl Default for WriteFrame {
60 fn default() -> Self {
61 Self {
62 buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
63 }
64 }
65}
66
67impl From<BytesMut> for ReadFrame {
68 fn from(mut buffer: BytesMut) -> Self {
69 let size: usize = buffer.capacity();
70 if size < INITIAL_CAPACITY {
71 buffer.reserve(INITIAL_CAPACITY - size);
72 }
73
74 Self {
75 buffer,
76 is_readable: size > 0,
77 eof: false,
78 has_errored: false,
79 }
80 }
81}
82
83impl From<BytesMut> for WriteFrame {
84 fn from(mut buffer: BytesMut) -> Self {
85 let size: usize = buffer.capacity();
86 if size < INITIAL_CAPACITY {
87 buffer.reserve(INITIAL_CAPACITY - size);
88 }
89
90 Self { buffer }
91 }
92}
93
94impl Borrow<ReadFrame> for RWFrames {
95 fn borrow(&self) -> &ReadFrame {
96 &self.read
97 }
98}
99impl BorrowMut<ReadFrame> for RWFrames {
100 fn borrow_mut(&mut self) -> &mut ReadFrame {
101 &mut self.read
102 }
103}
104impl Borrow<WriteFrame> for RWFrames {
105 fn borrow(&self) -> &WriteFrame {
106 &self.write
107 }
108}
109impl BorrowMut<WriteFrame> for RWFrames {
110 fn borrow_mut(&mut self) -> &mut WriteFrame {
111 &mut self.write
112 }
113}
/// Turns the bytes read from `inner` into a stream of decoded items.
///
/// Each item is `Ok(frame)` from the codec, or `Err` when reading or decoding
/// fails. After an error the stream yields `None` once, and polling it again
/// afterwards resumes reading.
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loop implements a state machine in which each state
        // corresponds to a combination of the `is_readable`, `eof` and
        // `has_errored` flags. States persist across calls and most state
        // transitions occur with a return.
        //
        // The initial state is `reading`.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //
        // Transitions:
        //
        // * reading --read n>0 bytes--> framing
        // * reading --read 0 bytes--> pausing (without returning)
        // * reading --read pending--> reading (returns `Pending`)
        // * framing --`decode` returns `Ok(Some)`--> framing (yields the frame)
        // * framing --`decode` returns `Ok(None)`--> reading
        // * pausing --`decode_eof` returns `Ok(Some)`--> pausing (yields the frame)
        // * pausing --`decode_eof` returns `Ok(None)`--> paused (yields `None`)
        // * paused  --read n>0 bytes--> framing
        // * paused  --read 0 bytes--> paused (yields `None`)
        // * paused  --read pending--> paused (returns `Pending`)
        // * any error from `decode`, `decode_eof` or the read --> errored (yields the error)
        // * errored --> yields `None` once, then continues as `paused` if `eof` is
        //   set, otherwise as `reading`
        loop {
            // Return `None` if we have encountered an error from the underlying decoder
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // errored -> paused (if `eof` is set) or errored -> reading (if not);
                // `eof` is left untouched here.
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of its associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            // The `?` below propagates a read error (converted into `U::Error`)
            // after the closure has marked the state as errored.
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> pausing
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
248
249impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
250where
251 T: AsyncWrite,
252 U: Encoder<I>,
253 U::Error: From<io::Error>,
254 W: BorrowMut<WriteFrame>,
255{
256 type Error = U::Error;
257
258 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
259 if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY {
260 self.as_mut().poll_flush(cx)
261 } else {
262 Poll::Ready(Ok(()))
263 }
264 }
265
266 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
267 let pinned = self.project();
268 pinned
269 .codec
270 .encode(item, &mut pinned.state.borrow_mut().buffer)?;
271 Ok(())
272 }
273
274 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
275 use crate::util::poll_write_buf;
276 trace!("flushing framed transport");
277 let mut pinned = self.project();
278
279 while !pinned.state.borrow_mut().buffer.is_empty() {
280 let WriteFrame { buffer } = pinned.state.borrow_mut();
281 trace!(remaining = buffer.len(), "writing;");
282
283 let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
284
285 if n == 0 {
286 return Poll::Ready(Err(io::Error::new(
287 io::ErrorKind::WriteZero,
288 "failed to \
289 write frame to transport",
290 )
291 .into()));
292 }
293 }
294
295 // Try flushing the underlying IO
296 ready!(pinned.inner.poll_flush(cx))?;
297
298 trace!("framed transport flushed");
299 Poll::Ready(Ok(()))
300 }
301
302 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303 ready!(self.as_mut().poll_flush(cx))?;
304 ready!(self.project().inner.poll_shutdown(cx))?;
305
306 Poll::Ready(Ok(()))
307 }
308}
309