1use crate::codec::decoder::Decoder;
2use crate::codec::encoder::Encoder;
3
4use futures_core::Stream;
5use tokio::io::{AsyncRead, AsyncWrite};
6
7use bytes::BytesMut;
8use futures_core::ready;
9use futures_sink::Sink;
10use pin_project_lite::pin_project;
11use std::borrow::{Borrow, BorrowMut};
12use std::io;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use tracing::trace;
16
pin_project! {
    /// Generic pairing of an I/O object, a codec and buffering state.
    ///
    /// In this file it implements `Stream` when `T: AsyncRead`, `U: Decoder` and
    /// `State: BorrowMut<ReadFrame>`, and `Sink<I>` when `T: AsyncWrite`, `U: Encoder<I>`
    /// and `State: BorrowMut<WriteFrame>`.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // Underlying I/O object. Marked `#[pin]` so the projection hands out `Pin<&mut T>`,
        // which the poll-based read/write/flush/shutdown calls require.
        #[pin]
        pub(crate) inner: T,
        // Buffering state; the trait impls reach the half they need through `BorrowMut`.
        pub(crate) state: State,
        // Codec used to decode incoming bytes and/or encode outgoing items.
        pub(crate) codec: U,
    }
}
26
// Initial capacity, in bytes, requested for the read and write buffers (8 KiB).
// Also used as the default `backpressure_boundary` of a `WriteFrame`.
const INITIAL_CAPACITY: usize = 8 * 1024;
28
/// Read-side state driven by the `Stream` implementation of `FramedImpl`.
#[derive(Debug)]
pub(crate) struct ReadFrame {
    /// Set when a read returns zero bytes; cleared again when a later read returns data.
    pub(crate) eof: bool,
    /// Whether `buffer` might currently hold data the decoder can turn into a frame.
    pub(crate) is_readable: bool,
    /// Bytes read from the I/O object and handed to the decoder.
    pub(crate) buffer: BytesMut,
    /// Set when decoding or reading fails; the next poll then yields `None` and clears it.
    pub(crate) has_errored: bool,
}
36
37pub(crate) struct WriteFrame {
38 pub(crate) buffer: BytesMut,
39 pub(crate) backpressure_boundary: usize,
40}
41
42#[derive(Default)]
43pub(crate) struct RWFrames {
44 pub(crate) read: ReadFrame,
45 pub(crate) write: WriteFrame,
46}
47
48impl Default for ReadFrame {
49 fn default() -> Self {
50 Self {
51 eof: false,
52 is_readable: false,
53 buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
54 has_errored: false,
55 }
56 }
57}
58
59impl Default for WriteFrame {
60 fn default() -> Self {
61 Self {
62 buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
63 backpressure_boundary: INITIAL_CAPACITY,
64 }
65 }
66}
67
68impl From<BytesMut> for ReadFrame {
69 fn from(mut buffer: BytesMut) -> Self {
70 let size = buffer.capacity();
71 if size < INITIAL_CAPACITY {
72 buffer.reserve(INITIAL_CAPACITY - size);
73 }
74
75 Self {
76 buffer,
77 is_readable: size > 0,
78 eof: false,
79 has_errored: false,
80 }
81 }
82}
83
84impl From<BytesMut> for WriteFrame {
85 fn from(mut buffer: BytesMut) -> Self {
86 let size = buffer.capacity();
87 if size < INITIAL_CAPACITY {
88 buffer.reserve(INITIAL_CAPACITY - size);
89 }
90
91 Self {
92 buffer,
93 backpressure_boundary: INITIAL_CAPACITY,
94 }
95 }
96}
97
98impl Borrow<ReadFrame> for RWFrames {
99 fn borrow(&self) -> &ReadFrame {
100 &self.read
101 }
102}
103impl BorrowMut<ReadFrame> for RWFrames {
104 fn borrow_mut(&mut self) -> &mut ReadFrame {
105 &mut self.read
106 }
107}
108impl Borrow<WriteFrame> for RWFrames {
109 fn borrow(&self) -> &WriteFrame {
110 &self.write
111 }
112}
113impl BorrowMut<WriteFrame> for RWFrames {
114 fn borrow_mut(&mut self) -> &mut WriteFrame {
115 &mut self.write
116 }
117}
// Decoding half of `FramedImpl`: reads bytes from `inner` into the read buffer and asks
// the codec to split them into frames.
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    // Each item is either a decoded frame or the codec's error.
    type Item = Result<U::Item, U::Error>;

    /// Drives the read-side state machine until a frame, an error, end of stream (`None`)
    /// or `Poll::Pending` can be returned.
    ///
    /// After an error has been returned, the next call returns `None` once and clears the
    /// error flag; calls after that poll the underlying reader again.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loop implements a state machine with each state corresponding
        // to a combination of the `is_readable`, `eof` and `has_errored` flags. States
        // persist across loop entries and most state transitions occur with a return.
        //
        // The initial state is `reading`.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //                                                           `decode_eof` returns Err
        //                                          ┌────────────────────────────────────────────────────────┐
        //                   `decode_eof` returns   │                                                        │
        //                   `Ok(Some)`             │                                                        │
        //                                 ┌─────┐  │  `decode_eof` returns                After returning   │
        //                 Read 0 bytes    ├─────▼──┴┐ `Ok(None)`             ┌────────┐ ◄───┐ `None`    ┌───▼─────┐
        //               ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐   └───────────┤ Errored │
        //               │                 └─────────┘                        └─┬──▲───┘ │               └───▲───▲─┘
        // Pending read  │                                                      │  │     │                   │   │
        //     ┌──────┐  │           `decode` returns `Some`                    │  └─────┘                   │   │
        //     │      │  │                   ┌──────┐                           │  Pending                   │   │
        //     │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐    read n>0 bytes       │  read                      │   │
        //     └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘                            │   │
        //       └──┬─▲────┘                └─────┬──┬┘                                                      │   │
        //          │ │                           │  │ `decode` returns Err                                  │   │
        //          │ └──`decode` returns `None`──┘  └───────────────────────────────────────────────────────┘   │
        //          │                             read returns Err                                               │
        //          └────────────────────────────────────────────────────────────────────────────────────────────┘
        loop {
            // Return `None` if we have encountered an error from the underlying decoder
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // preparing has_errored -> paused
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of its associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    // On error, record it so the next poll reports `None` before continuing.
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    // `frame.map(Ok)` turns `None` into end-of-stream and `Some(f)` into an item.
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            // A read error is recorded in `has_errored` and then returned via `?`.
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> paused
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
252
253impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
254where
255 T: AsyncWrite,
256 U: Encoder<I>,
257 U::Error: From<io::Error>,
258 W: BorrowMut<WriteFrame>,
259{
260 type Error = U::Error;
261
262 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
263 if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
264 self.as_mut().poll_flush(cx)
265 } else {
266 Poll::Ready(Ok(()))
267 }
268 }
269
270 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
271 let pinned = self.project();
272 pinned
273 .codec
274 .encode(item, &mut pinned.state.borrow_mut().buffer)?;
275 Ok(())
276 }
277
278 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
279 use crate::util::poll_write_buf;
280 trace!("flushing framed transport");
281 let mut pinned = self.project();
282
283 while !pinned.state.borrow_mut().buffer.is_empty() {
284 let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
285 trace!(remaining = buffer.len(), "writing;");
286
287 let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
288
289 if n == 0 {
290 return Poll::Ready(Err(io::Error::new(
291 io::ErrorKind::WriteZero,
292 "failed to \
293 write frame to transport",
294 )
295 .into()));
296 }
297 }
298
299 // Try flushing the underlying IO
300 ready!(pinned.inner.poll_flush(cx))?;
301
302 trace!("framed transport flushed");
303 Poll::Ready(Ok(()))
304 }
305
306 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
307 ready!(self.as_mut().poll_flush(cx))?;
308 ready!(self.project().inner.poll_shutdown(cx))?;
309
310 Poll::Ready(Ok(()))
311 }
312}
313