1 | use crate::codec::UserError; |
2 | use crate::codec::UserError::*; |
3 | use crate::frame::{self, Frame, FrameSize}; |
4 | use crate::hpack; |
5 | |
6 | use bytes::{Buf, BufMut, BytesMut}; |
7 | use std::pin::Pin; |
8 | use std::task::{Context, Poll}; |
9 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
10 | use tokio_util::io::poll_write_buf; |
11 | |
12 | use std::io::{self, Cursor}; |
13 | |
// Produces a `Limit`-wrapped view of the write buffer capped at one full frame
// (frame header + the peer's max frame size).
//
// This is a macro rather than a `&mut self` method because callers need to
// hold this mutable borrow of `buf` while also borrowing another field
// (e.g. `hpack`) mutably; a method would borrow all of `self` at once. The
// limit is computed before `buf` is borrowed.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let max_len = frame::HEADER_LEN + $self.max_frame_size();
        $self.buf.get_mut().limit(max_len)
    }};
}
21 | |
22 | #[derive (Debug)] |
23 | pub struct FramedWrite<T, B> { |
24 | /// Upstream `AsyncWrite` |
25 | inner: T, |
26 | |
27 | encoder: Encoder<B>, |
28 | } |
29 | |
30 | #[derive (Debug)] |
31 | struct Encoder<B> { |
32 | /// HPACK encoder |
33 | hpack: hpack::Encoder, |
34 | |
35 | /// Write buffer |
36 | /// |
37 | /// TODO: Should this be a ring buffer? |
38 | buf: Cursor<BytesMut>, |
39 | |
40 | /// Next frame to encode |
41 | next: Option<Next<B>>, |
42 | |
43 | /// Last data frame |
44 | last_data_frame: Option<frame::Data<B>>, |
45 | |
46 | /// Max frame size, this is specified by the peer |
47 | max_frame_size: FrameSize, |
48 | |
49 | /// Chain payloads bigger than this. |
50 | chain_threshold: usize, |
51 | |
52 | /// Min buffer required to attempt to write a frame |
53 | min_buffer_capacity: usize, |
54 | } |
55 | |
56 | #[derive (Debug)] |
57 | enum Next<B> { |
58 | Data(frame::Data<B>), |
59 | Continuation(frame::Continuation), |
60 | } |
61 | |
/// Initial capacity, in bytes, of the connection's write buffer.
///
/// The smallest MAX_FRAME_SIZE a peer may advertise is 16KB, so the buffer
/// starts out sized for a HEADERS frame of roughly that size.
const DEFAULT_BUFFER_CAPACITY: usize = 1_024 * 16;

/// Chain payloads of at least this size when vectored I/O is enabled.
///
/// The remote can never advertise a max frame size below this: the spec's
/// minimum MAX_FRAME_SIZE is 16KB, far above this value.
const CHAIN_THRESHOLD: usize = 256;

/// Chain payloads of at least this size when vectored I/O is **not** enabled.
///
/// A larger value here reduces the number of small, fragmented writes and
/// thereby improves throughput.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;
77 | |
78 | // TODO: Make generic |
79 | impl<T, B> FramedWrite<T, B> |
80 | where |
81 | T: AsyncWrite + Unpin, |
82 | B: Buf, |
83 | { |
84 | pub fn new(inner: T) -> FramedWrite<T, B> { |
85 | let chain_threshold = if inner.is_write_vectored() { |
86 | CHAIN_THRESHOLD |
87 | } else { |
88 | CHAIN_THRESHOLD_WITHOUT_VECTORED_IO |
89 | }; |
90 | FramedWrite { |
91 | inner, |
92 | encoder: Encoder { |
93 | hpack: hpack::Encoder::default(), |
94 | buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), |
95 | next: None, |
96 | last_data_frame: None, |
97 | max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, |
98 | chain_threshold, |
99 | min_buffer_capacity: chain_threshold + frame::HEADER_LEN, |
100 | }, |
101 | } |
102 | } |
103 | |
104 | /// Returns `Ready` when `send` is able to accept a frame |
105 | /// |
106 | /// Calling this function may result in the current contents of the buffer |
107 | /// to be flushed to `T`. |
108 | pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { |
109 | if !self.encoder.has_capacity() { |
110 | // Try flushing |
111 | ready!(self.flush(cx))?; |
112 | |
113 | if !self.encoder.has_capacity() { |
114 | return Poll::Pending; |
115 | } |
116 | } |
117 | |
118 | Poll::Ready(Ok(())) |
119 | } |
120 | |
121 | /// Buffer a frame. |
122 | /// |
123 | /// `poll_ready` must be called first to ensure that a frame may be |
124 | /// accepted. |
125 | pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { |
126 | self.encoder.buffer(item) |
127 | } |
128 | |
129 | /// Flush buffered data to the wire |
130 | pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { |
131 | let span = tracing::trace_span!("FramedWrite::flush" ); |
132 | let _e = span.enter(); |
133 | |
134 | loop { |
135 | while !self.encoder.is_empty() { |
136 | match self.encoder.next { |
137 | Some(Next::Data(ref mut frame)) => { |
138 | tracing::trace!(queued_data_frame = true); |
139 | let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut()); |
140 | ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))? |
141 | } |
142 | _ => { |
143 | tracing::trace!(queued_data_frame = false); |
144 | ready!(poll_write_buf( |
145 | Pin::new(&mut self.inner), |
146 | cx, |
147 | &mut self.encoder.buf |
148 | ))? |
149 | } |
150 | }; |
151 | } |
152 | |
153 | match self.encoder.unset_frame() { |
154 | ControlFlow::Continue => (), |
155 | ControlFlow::Break => break, |
156 | } |
157 | } |
158 | |
159 | tracing::trace!("flushing buffer" ); |
160 | // Flush the upstream |
161 | ready!(Pin::new(&mut self.inner).poll_flush(cx))?; |
162 | |
163 | Poll::Ready(Ok(())) |
164 | } |
165 | |
166 | /// Close the codec |
167 | pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { |
168 | ready!(self.flush(cx))?; |
169 | Pin::new(&mut self.inner).poll_shutdown(cx) |
170 | } |
171 | } |
172 | |
/// Result of `Encoder::unset_frame`: tells the flush loop whether more encoded
/// data was queued and must still be written.
#[must_use]
enum ControlFlow {
    /// Another frame was encoded into the buffer; keep writing.
    Continue,
    /// Nothing else is pending; leave the write loop.
    Break,
}
178 | |
179 | impl<B> Encoder<B> |
180 | where |
181 | B: Buf, |
182 | { |
183 | fn unset_frame(&mut self) -> ControlFlow { |
184 | // Clear internal buffer |
185 | self.buf.set_position(0); |
186 | self.buf.get_mut().clear(); |
187 | |
188 | // The data frame has been written, so unset it |
189 | match self.next.take() { |
190 | Some(Next::Data(frame)) => { |
191 | self.last_data_frame = Some(frame); |
192 | debug_assert!(self.is_empty()); |
193 | ControlFlow::Break |
194 | } |
195 | Some(Next::Continuation(frame)) => { |
196 | // Buffer the continuation frame, then try to write again |
197 | let mut buf = limited_write_buf!(self); |
198 | if let Some(continuation) = frame.encode(&mut buf) { |
199 | self.next = Some(Next::Continuation(continuation)); |
200 | } |
201 | ControlFlow::Continue |
202 | } |
203 | None => ControlFlow::Break, |
204 | } |
205 | } |
206 | |
207 | fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { |
208 | // Ensure that we have enough capacity to accept the write. |
209 | assert!(self.has_capacity()); |
210 | let span = tracing::trace_span!("FramedWrite::buffer" , frame = ?item); |
211 | let _e = span.enter(); |
212 | |
213 | tracing::debug!(frame = ?item, "send" ); |
214 | |
215 | match item { |
216 | Frame::Data(mut v) => { |
217 | // Ensure that the payload is not greater than the max frame. |
218 | let len = v.payload().remaining(); |
219 | |
220 | if len > self.max_frame_size() { |
221 | return Err(PayloadTooBig); |
222 | } |
223 | |
224 | if len >= self.chain_threshold { |
225 | let head = v.head(); |
226 | |
227 | // Encode the frame head to the buffer |
228 | head.encode(len, self.buf.get_mut()); |
229 | |
230 | if self.buf.get_ref().remaining() < self.chain_threshold { |
231 | let extra_bytes = self.chain_threshold - self.buf.remaining(); |
232 | self.buf.get_mut().put(v.payload_mut().take(extra_bytes)); |
233 | } |
234 | |
235 | // Save the data frame |
236 | self.next = Some(Next::Data(v)); |
237 | } else { |
238 | v.encode_chunk(self.buf.get_mut()); |
239 | |
240 | // The chunk has been fully encoded, so there is no need to |
241 | // keep it around |
242 | assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded" ); |
243 | |
244 | // Save off the last frame... |
245 | self.last_data_frame = Some(v); |
246 | } |
247 | } |
248 | Frame::Headers(v) => { |
249 | let mut buf = limited_write_buf!(self); |
250 | if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { |
251 | self.next = Some(Next::Continuation(continuation)); |
252 | } |
253 | } |
254 | Frame::PushPromise(v) => { |
255 | let mut buf = limited_write_buf!(self); |
256 | if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { |
257 | self.next = Some(Next::Continuation(continuation)); |
258 | } |
259 | } |
260 | Frame::Settings(v) => { |
261 | v.encode(self.buf.get_mut()); |
262 | tracing::trace!(rem = self.buf.remaining(), "encoded settings" ); |
263 | } |
264 | Frame::GoAway(v) => { |
265 | v.encode(self.buf.get_mut()); |
266 | tracing::trace!(rem = self.buf.remaining(), "encoded go_away" ); |
267 | } |
268 | Frame::Ping(v) => { |
269 | v.encode(self.buf.get_mut()); |
270 | tracing::trace!(rem = self.buf.remaining(), "encoded ping" ); |
271 | } |
272 | Frame::WindowUpdate(v) => { |
273 | v.encode(self.buf.get_mut()); |
274 | tracing::trace!(rem = self.buf.remaining(), "encoded window_update" ); |
275 | } |
276 | |
277 | Frame::Priority(_) => { |
278 | /* |
279 | v.encode(self.buf.get_mut()); |
280 | tracing::trace!("encoded priority; rem={:?}", self.buf.remaining()); |
281 | */ |
282 | unimplemented!(); |
283 | } |
284 | Frame::Reset(v) => { |
285 | v.encode(self.buf.get_mut()); |
286 | tracing::trace!(rem = self.buf.remaining(), "encoded reset" ); |
287 | } |
288 | } |
289 | |
290 | Ok(()) |
291 | } |
292 | |
293 | fn has_capacity(&self) -> bool { |
294 | self.next.is_none() |
295 | && (self.buf.get_ref().capacity() - self.buf.get_ref().len() |
296 | >= self.min_buffer_capacity) |
297 | } |
298 | |
299 | fn is_empty(&self) -> bool { |
300 | match self.next { |
301 | Some(Next::Data(ref frame)) => !frame.payload().has_remaining(), |
302 | _ => !self.buf.has_remaining(), |
303 | } |
304 | } |
305 | } |
306 | |
307 | impl<B> Encoder<B> { |
308 | fn max_frame_size(&self) -> usize { |
309 | self.max_frame_size as usize |
310 | } |
311 | } |
312 | |
313 | impl<T, B> FramedWrite<T, B> { |
314 | /// Returns the max frame size that can be sent |
315 | pub fn max_frame_size(&self) -> usize { |
316 | self.encoder.max_frame_size() |
317 | } |
318 | |
319 | /// Set the peer's max frame size. |
320 | pub fn set_max_frame_size(&mut self, val: usize) { |
321 | assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); |
322 | self.encoder.max_frame_size = val as FrameSize; |
323 | } |
324 | |
325 | /// Set the peer's header table size. |
326 | pub fn set_header_table_size(&mut self, val: usize) { |
327 | self.encoder.hpack.update_max_size(val); |
328 | } |
329 | |
330 | /// Retrieve the last data frame that has been sent |
331 | pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> { |
332 | self.encoder.last_data_frame.take() |
333 | } |
334 | |
335 | pub fn get_mut(&mut self) -> &mut T { |
336 | &mut self.inner |
337 | } |
338 | } |
339 | |
340 | impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> { |
341 | fn poll_read( |
342 | mut self: Pin<&mut Self>, |
343 | cx: &mut Context<'_>, |
344 | buf: &mut ReadBuf, |
345 | ) -> Poll<io::Result<()>> { |
346 | Pin::new(&mut self.inner).poll_read(cx, buf) |
347 | } |
348 | } |
349 | |
350 | // We never project the Pin to `B`. |
351 | impl<T: Unpin, B> Unpin for FramedWrite<T, B> {} |
352 | |
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Returns a shared reference to the wrapped I/O object.
        pub fn get_ref(&self) -> &T {
            let FramedWrite { inner, .. } = self;
            inner
        }
    }
}
363 | |