use crate::codec::UserError;
use crate::codec::UserError::*;
use crate::frame::{self, Frame, FrameSize};
use crate::hpack;

use bytes::{Buf, BufMut, BytesMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::io::poll_write_buf;

use std::io::{self, Cursor};

// A macro to get around a method needing to borrow &mut self
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let limit = $self.max_frame_size() + frame::HEADER_LEN;
        $self.buf.get_mut().limit(limit)
    }};
}
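// Roughly, a call like `limited_write_buf!(self)` expands to:
//
//     let limit = self.max_frame_size() + frame::HEADER_LEN;
//     self.buf.get_mut().limit(limit)
//
// Writing this as a method would require borrowing all of `&mut self`, which
// would conflict with the other fields of `self` still in use at the call
// sites (e.g. `self.hpack`).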

#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,
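
    /// Whether the flush performed before shutdown has already completed, so
    /// repeated `shutdown` polls do not flush again.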
    final_flush_done: bool,

    encoder: Encoder<B>,
}

#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder
    hpack: hpack::Encoder,

    /// Write buffer
    ///
    /// TODO: Should this be a ring buffer?
    buf: Cursor<BytesMut>,

    /// Next frame to encode
    next: Option<Next<B>>,

    /// Last data frame
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size; this is specified by the peer
    max_frame_size: FrameSize,

    /// Chain payloads bigger than this.
    chain_threshold: usize,

    /// Min buffer required to attempt to write a frame
    min_buffer_capacity: usize,
}

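/// Frame state carried across write passes: either a data frame whose payload
/// is chained to the write buffer rather than copied into it, or a
/// CONTINUATION frame that still needs to be encoded.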
#[derive(Debug)]
enum Next<B> {
    Data(frame::Data<B>),
    Continuation(frame::Continuation),
}

/// Initialize the connection with this amount of write buffer.
///
/// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS
/// frame that big.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

/// Chain payloads bigger than this when vectored I/O is enabled. The remote
/// will never advertise a max frame size less than this (well, the spec says
/// the max frame size can't be less than 16kb, so not even close).
const CHAIN_THRESHOLD: usize = 256;

/// Chain payloads bigger than this when vectored I/O is **not** enabled.
/// A larger value in this scenario reduces the number of small, fragmented
/// data writes, and thereby improves throughput.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;

// TODO: Make generic
impl<T, B> FramedWrite<T, B>
where
    T: AsyncWrite + Unpin,
    B: Buf,
{
    pub fn new(inner: T) -> FramedWrite<T, B> {
        let chain_threshold = if inner.is_write_vectored() {
            CHAIN_THRESHOLD
        } else {
            CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
        };
        FramedWrite {
            inner,
            final_flush_done: false,
            encoder: Encoder {
                hpack: hpack::Encoder::default(),
                buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
                next: None,
                last_data_frame: None,
                max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
                chain_threshold,
                min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
            },
        }
    }

    /// Returns `Ready` when `buffer` is able to accept a frame.
    ///
    /// Calling this function may result in the current contents of the
    /// buffer being flushed to `T`.
    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if !self.encoder.has_capacity() {
            // Try flushing
            ready!(self.flush(cx))?;

            if !self.encoder.has_capacity() {
                return Poll::Pending;
            }
        }

        Poll::Ready(Ok(()))
    }

    /// Buffer a frame.
    ///
    /// `poll_ready` must be called first to ensure that a frame may be
    /// accepted.
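    ///
    /// A rough sketch of the intended call sequence (illustrative only, not
    /// compiled as a doctest; `framed`, `frame`, and `cx` are assumed to be
    /// supplied by the caller):
    ///
    /// ```ignore
    /// ready!(framed.poll_ready(cx))?; // make room, flushing if necessary
    /// framed.buffer(frame)?;          // encode the frame into the write buffer
    /// ready!(framed.flush(cx))?;      // drive the buffered bytes to the I/O
    /// ```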
    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        self.encoder.buffer(item)
    }

    /// Flush buffered data to the wire
    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let span = tracing::trace_span!("FramedWrite::flush");
        let _e = span.enter();

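        // Each pass writes everything currently buffered (including any
        // chained data payload); `unset_frame` then decides whether another
        // pass is needed, e.g. to encode a pending CONTINUATION frame.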
        loop {
            while !self.encoder.is_empty() {
                match self.encoder.next {
                    Some(Next::Data(ref mut frame)) => {
                        tracing::trace!(queued_data_frame = true);
                        let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
                        ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
                    }
                    _ => {
                        tracing::trace!(queued_data_frame = false);
                        ready!(poll_write_buf(
                            Pin::new(&mut self.inner),
                            cx,
                            &mut self.encoder.buf
                        ))?
                    }
                };
            }

            match self.encoder.unset_frame() {
                ControlFlow::Continue => (),
                ControlFlow::Break => break,
            }
        }

        tracing::trace!("flushing buffer");
        // Flush the upstream
        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;

        Poll::Ready(Ok(()))
    }

    /// Close the codec
    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if !self.final_flush_done {
            ready!(self.flush(cx))?;
            self.final_flush_done = true;
        }
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}

#[must_use]
enum ControlFlow {
    Continue,
    Break,
}

impl<B> Encoder<B>
where
    B: Buf,
{
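    /// Called once a write pass has drained the buffer (and any chained
    /// payload): resets the write buffer and reports whether `flush` should
    /// run another pass, e.g. to encode a pending CONTINUATION frame.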
    fn unset_frame(&mut self) -> ControlFlow {
        // Clear internal buffer
        self.buf.set_position(0);
        self.buf.get_mut().clear();

        // The data frame has been written, so unset it
        match self.next.take() {
            Some(Next::Data(frame)) => {
                self.last_data_frame = Some(frame);
                debug_assert!(self.is_empty());
                ControlFlow::Break
            }
            Some(Next::Continuation(frame)) => {
                // Buffer the continuation frame, then try to write again
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = frame.encode(&mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
                ControlFlow::Continue
            }
            None => ControlFlow::Break,
        }
    }

    fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        // Ensure that we have enough capacity to accept the write.
        assert!(self.has_capacity());
        let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
        let _e = span.enter();

        tracing::debug!(frame = ?item, "send");

        match item {
            Frame::Data(mut v) => {
                // Ensure that the payload is not greater than the max frame.
                let len = v.payload().remaining();

                if len > self.max_frame_size() {
                    return Err(PayloadTooBig);
                }

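                // Payloads at or above the chain threshold are not copied
                // into the write buffer: only the frame head (plus enough
                // payload bytes to top the buffer up to the threshold) is
                // buffered here, and the rest is chained during the write.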
                if len >= self.chain_threshold {
                    let head = v.head();

                    // Encode the frame head to the buffer
                    head.encode(len, self.buf.get_mut());

                    if self.buf.get_ref().remaining() < self.chain_threshold {
                        let extra_bytes = self.chain_threshold - self.buf.remaining();
                        self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
                    }

                    // Save the data frame
                    self.next = Some(Next::Data(v));
                } else {
                    v.encode_chunk(self.buf.get_mut());

                    // The chunk has been fully encoded, so there is no need to
                    // keep it around
                    assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");

                    // Save off the last frame...
                    self.last_data_frame = Some(v);
                }
            }
            Frame::Headers(v) => {
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::PushPromise(v) => {
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::Settings(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded settings");
            }
            Frame::GoAway(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
            }
            Frame::Ping(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded ping");
            }
            Frame::WindowUpdate(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
            }

            Frame::Priority(_) => {
                /*
                v.encode(self.buf.get_mut());
                tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
                */
                unimplemented!();
            }
            Frame::Reset(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded reset");
            }
        }

        Ok(())
    }

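    /// The encoder can accept another frame only when no frame is pending in
    /// `next` and the write buffer has at least `min_buffer_capacity` bytes
    /// of spare capacity.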
    fn has_capacity(&self) -> bool {
        self.next.is_none()
            && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
                >= self.min_buffer_capacity)
    }

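    /// Returns `true` once everything buffered (including the payload of a
    /// chained data frame, if any) has been written out.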
    fn is_empty(&self) -> bool {
        match self.next {
            Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
            _ => !self.buf.has_remaining(),
        }
    }
}

impl<B> Encoder<B> {
    fn max_frame_size(&self) -> usize {
        self.max_frame_size as usize
    }
}

impl<T, B> FramedWrite<T, B> {
    /// Returns the max frame size that can be sent
    pub fn max_frame_size(&self) -> usize {
        self.encoder.max_frame_size()
    }

    /// Set the peer's max frame size.
    pub fn set_max_frame_size(&mut self, val: usize) {
        assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
        self.encoder.max_frame_size = val as FrameSize;
    }

    /// Set the peer's header table size.
    pub fn set_header_table_size(&mut self, val: usize) {
        self.encoder.hpack.update_max_size(val);
    }

    /// Retrieve the last data frame that has been sent
    pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
        self.encoder.last_data_frame.take()
    }

    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}

impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}

// We never project the Pin to `B`.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}

#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        pub fn get_ref(&self) -> &T {
            &self.inner
        }
    }
}
