1 | use crate::codec::decoder::Decoder; |
2 | use crate::codec::encoder::Encoder; |
3 | |
4 | use futures_core::Stream; |
5 | use tokio::io::{AsyncRead, AsyncWrite}; |
6 | |
7 | use bytes::BytesMut; |
8 | use futures_core::ready; |
9 | use futures_sink::Sink; |
10 | use pin_project_lite::pin_project; |
11 | use std::borrow::{Borrow, BorrowMut}; |
12 | use std::io; |
13 | use std::pin::Pin; |
14 | use std::task::{Context, Poll}; |
15 | use tracing::trace; |
16 | |
pin_project! {
    // Generic framing adapter: an I/O object `inner`, a codec `U`, and the
    // buffering state `State` used by the codec.
    //
    // `Stream` is implemented below when `State: BorrowMut<ReadFrame>` and
    // `Sink` when `State: BorrowMut<WriteFrame>`; `RWFrames` implements both
    // borrows, so a `FramedImpl<_, _, RWFrames>` is both a `Stream` and a `Sink`.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // The underlying transport. It is structurally pinned, so the trait
        // impls can poll it through `Pin<&mut T>` after projecting.
        #[pin]
        pub(crate) inner: T,
        // Read and/or write buffering state (see `ReadFrame` / `WriteFrame`).
        pub(crate) state: State,
        // Codec that decodes frames out of, and/or encodes frames into, the
        // buffers held in `state`.
        pub(crate) codec: U,
    }
}
26 | |
/// Capacity, in bytes (8 KiB), that read and write buffers are created with or
/// grown to.
const INITIAL_CAPACITY: usize = 8192;

/// Write-buffer length at which `poll_ready` flushes before reporting
/// readiness; deliberately the same as the initial buffer capacity.
const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY;
29 | |
/// Read-side state for the `Stream` implementation of `FramedImpl`.
///
/// The three flags together encode which state of the `poll_next` state
/// machine (reading, framing, pausing, paused, errored) the reader is in.
#[derive(Debug)]
pub(crate) struct ReadFrame {
    /// Set when a read from the transport returned 0 bytes; cleared again when
    /// a later read returns data.
    pub(crate) eof: bool,
    /// Whether `buffer` should be offered to the decoder before reading again.
    pub(crate) is_readable: bool,
    /// Bytes read from the transport and handed to the decoder.
    pub(crate) buffer: BytesMut,
    /// Set when reading or decoding failed; the next poll then yields `None`.
    pub(crate) has_errored: bool,
}
37 | |
38 | pub(crate) struct WriteFrame { |
39 | pub(crate) buffer: BytesMut, |
40 | } |
41 | |
42 | #[derive (Default)] |
43 | pub(crate) struct RWFrames { |
44 | pub(crate) read: ReadFrame, |
45 | pub(crate) write: WriteFrame, |
46 | } |
47 | |
48 | impl Default for ReadFrame { |
49 | fn default() -> Self { |
50 | Self { |
51 | eof: false, |
52 | is_readable: false, |
53 | buffer: BytesMut::with_capacity(INITIAL_CAPACITY), |
54 | has_errored: false, |
55 | } |
56 | } |
57 | } |
58 | |
59 | impl Default for WriteFrame { |
60 | fn default() -> Self { |
61 | Self { |
62 | buffer: BytesMut::with_capacity(INITIAL_CAPACITY), |
63 | } |
64 | } |
65 | } |
66 | |
67 | impl From<BytesMut> for ReadFrame { |
68 | fn from(mut buffer: BytesMut) -> Self { |
69 | let size: usize = buffer.capacity(); |
70 | if size < INITIAL_CAPACITY { |
71 | buffer.reserve(INITIAL_CAPACITY - size); |
72 | } |
73 | |
74 | Self { |
75 | buffer, |
76 | is_readable: size > 0, |
77 | eof: false, |
78 | has_errored: false, |
79 | } |
80 | } |
81 | } |
82 | |
83 | impl From<BytesMut> for WriteFrame { |
84 | fn from(mut buffer: BytesMut) -> Self { |
85 | let size: usize = buffer.capacity(); |
86 | if size < INITIAL_CAPACITY { |
87 | buffer.reserve(INITIAL_CAPACITY - size); |
88 | } |
89 | |
90 | Self { buffer } |
91 | } |
92 | } |
93 | |
94 | impl Borrow<ReadFrame> for RWFrames { |
95 | fn borrow(&self) -> &ReadFrame { |
96 | &self.read |
97 | } |
98 | } |
99 | impl BorrowMut<ReadFrame> for RWFrames { |
100 | fn borrow_mut(&mut self) -> &mut ReadFrame { |
101 | &mut self.read |
102 | } |
103 | } |
104 | impl Borrow<WriteFrame> for RWFrames { |
105 | fn borrow(&self) -> &WriteFrame { |
106 | &self.write |
107 | } |
108 | } |
109 | impl BorrowMut<WriteFrame> for RWFrames { |
110 | fn borrow_mut(&mut self) -> &mut WriteFrame { |
111 | &mut self.write |
112 | } |
113 | } |
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    /// Polls for the next decoded frame.
    ///
    /// Yields `Some(Ok(frame))` for each frame produced by `decode` (or, after
    /// EOF, by `decode_eof`), and `Some(Err(e))` when reading or decoding fails.
    /// Yields `None` in three cases: when `decode_eof` has no more frames, when
    /// a read returns 0 bytes after EOF was already seen, and on the poll
    /// immediately following an error.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The loop below implements a state machine. Each state corresponds
        // to a combination of the `eof`, `is_readable` and `has_errored` flags
        // stored in `ReadFrame`, so states persist across calls. Most state
        // transitions happen together with a return.
        //
        // The initial state is `reading`.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //
        // Transitions:
        //
        // - reading: read Pending -> reading; read n > 0 bytes -> framing;
        //   read 0 bytes -> pausing; read Err -> errored.
        // - framing: `decode` returns `Some` -> framing (frame returned);
        //   `decode` returns `None` -> reading; `decode` returns Err -> errored.
        // - pausing: `decode_eof` returns `Some` -> pausing (frame returned);
        //   `decode_eof` returns `None` -> paused (`None` returned);
        //   `decode_eof` returns Err -> errored.
        // - paused: read Pending -> paused; read 0 bytes -> paused (`None`
        //   returned); read n > 0 bytes -> framing; read Err -> errored.
        // - errored: `None` is returned and `has_errored`/`is_readable` are
        //   cleared. `eof` is left as it was, so the next state is `paused`
        //   if EOF had been seen and `reading` otherwise.
        loop {
            // Return `None` if we have encountered an error from the underlying decoder
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // errored -> paused (eof set) or errored -> reading (eof unset)
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of its associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> pausing
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
248 | |
249 | impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W> |
250 | where |
251 | T: AsyncWrite, |
252 | U: Encoder<I>, |
253 | U::Error: From<io::Error>, |
254 | W: BorrowMut<WriteFrame>, |
255 | { |
256 | type Error = U::Error; |
257 | |
258 | fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
259 | if self.state.borrow().buffer.len() >= BACKPRESSURE_BOUNDARY { |
260 | self.as_mut().poll_flush(cx) |
261 | } else { |
262 | Poll::Ready(Ok(())) |
263 | } |
264 | } |
265 | |
266 | fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { |
267 | let pinned = self.project(); |
268 | pinned |
269 | .codec |
270 | .encode(item, &mut pinned.state.borrow_mut().buffer)?; |
271 | Ok(()) |
272 | } |
273 | |
274 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
275 | use crate::util::poll_write_buf; |
276 | trace!("flushing framed transport" ); |
277 | let mut pinned = self.project(); |
278 | |
279 | while !pinned.state.borrow_mut().buffer.is_empty() { |
280 | let WriteFrame { buffer } = pinned.state.borrow_mut(); |
281 | trace!(remaining = buffer.len(), "writing;" ); |
282 | |
283 | let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; |
284 | |
285 | if n == 0 { |
286 | return Poll::Ready(Err(io::Error::new( |
287 | io::ErrorKind::WriteZero, |
288 | "failed to \ |
289 | write frame to transport" , |
290 | ) |
291 | .into())); |
292 | } |
293 | } |
294 | |
295 | // Try flushing the underlying IO |
296 | ready!(pinned.inner.poll_flush(cx))?; |
297 | |
298 | trace!("framed transport flushed" ); |
299 | Poll::Ready(Ok(())) |
300 | } |
301 | |
302 | fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
303 | ready!(self.as_mut().poll_flush(cx))?; |
304 | ready!(self.project().inner.poll_shutdown(cx))?; |
305 | |
306 | Poll::Ready(Ok(())) |
307 | } |
308 | } |
309 | |