use crate::codec::UserError;
use crate::codec::UserError::*;
use crate::frame::{self, Frame, FrameSize};
use crate::hpack;

use bytes::{Buf, BufMut, BytesMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::io::poll_write_buf;

use std::io::{self, Cursor};

// A macro to get around the borrow checker: a method returning the limited
// buffer would borrow all of `&mut self`, which would conflict with the
// simultaneous `&mut self.hpack` borrow needed when encoding headers.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let limit = $self.max_frame_size() + frame::HEADER_LEN;
        $self.buf.get_mut().limit(limit)
    }};
}

#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,

    encoder: Encoder<B>,
}

#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder
    hpack: hpack::Encoder,

    /// Write buffer
    ///
    /// TODO: Should this be a ring buffer?
    buf: Cursor<BytesMut>,

    /// Next frame to encode
    next: Option<Next<B>>,

    /// Last data frame
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size; this is specified by the peer
    max_frame_size: FrameSize,

    /// Chain payloads bigger than this.
    chain_threshold: usize,

    /// Min buffer required to attempt to write a frame
    min_buffer_capacity: usize,
}

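/// A frame that has been partially written and must be finished before the
/// codec can accept another frame: either a DATA frame whose remaining
/// payload is chained at write time, or a CONTINUATION frame that still
/// needs to be encoded into the buffer.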
#[derive(Debug)]
enum Next<B> {
    Data(frame::Data<B>),
    Continuation(frame::Continuation),
}

/// Initialize the connection with this amount of write buffer.
///
/// The minimum MAX_FRAME_SIZE is 16 KB, so we should always be able to send
/// a HEADERS frame that big.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

/// Chain payloads bigger than this when vectored I/O is enabled. The remote
/// will never advertise a max frame size less than this (the spec says the
/// max frame size can't be less than 16 KB, so it's not even close).
const CHAIN_THRESHOLD: usize = 256;

/// Chain payloads bigger than this when vectored I/O is **not** enabled.
/// A larger value in this scenario reduces the number of small, fragmented
/// writes, thereby improving throughput.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;

// TODO: Make generic
impl<T, B> FramedWrite<T, B>
where
    T: AsyncWrite + Unpin,
    B: Buf,
{
    pub fn new(inner: T) -> FramedWrite<T, B> {
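        // With vectored I/O, the buffered frame head and the payload can be
        // handed to the OS as separate slices, so chaining is cheap even
        // for small payloads. Without it, chaining costs an extra small
        // write, so only larger payloads are chained.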
        let chain_threshold = if inner.is_write_vectored() {
            CHAIN_THRESHOLD
        } else {
            CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
        };
        FramedWrite {
            inner,
            encoder: Encoder {
                hpack: hpack::Encoder::default(),
                buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
                next: None,
                last_data_frame: None,
                max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
                chain_threshold,
                min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
            },
        }
    }

    /// Returns `Ready` when `send` is able to accept a frame
    ///
    /// Calling this function may result in the current contents of the
    /// buffer being flushed to `T`.
    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if !self.encoder.has_capacity() {
            // Try flushing
            ready!(self.flush(cx))?;

            if !self.encoder.has_capacity() {
                return Poll::Pending;
            }
        }

        Poll::Ready(Ok(()))
    }

    /// Buffer a frame.
    ///
    /// `poll_ready` must be called first to ensure that a frame may be
    /// accepted.
    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        self.encoder.buffer(item)
    }

    /// Flush buffered data to the wire
    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let span = tracing::trace_span!("FramedWrite::flush");
        let _e = span.enter();

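        // Write out everything that is buffered, including any in-flight
        // data payload. When the buffer drains, encode the next queued
        // CONTINUATION frame (if any) and keep writing.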
        loop {
            while !self.encoder.is_empty() {
                match self.encoder.next {
                    Some(Next::Data(ref mut frame)) => {
                        tracing::trace!(queued_data_frame = true);
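                        // Chain the buffered frame head with the payload so
                        // both can go out in a single (possibly vectored)
                        // write.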
                        let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
                        ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
                    }
                    _ => {
                        tracing::trace!(queued_data_frame = false);
                        ready!(poll_write_buf(
                            Pin::new(&mut self.inner),
                            cx,
                            &mut self.encoder.buf
                        ))?
                    }
                };
            }

            match self.encoder.unset_frame() {
                ControlFlow::Continue => (),
                ControlFlow::Break => break,
            }
        }

        tracing::trace!("flushing buffer");
        // Flush the upstream
        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;

        Poll::Ready(Ok(()))
    }

    /// Close the codec
    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        ready!(self.flush(cx))?;
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}

#[must_use]
enum ControlFlow {
    Continue,
    Break,
}

impl<B> Encoder<B>
where
    B: Buf,
{
    fn unset_frame(&mut self) -> ControlFlow {
        // Clear internal buffer
        self.buf.set_position(0);
        self.buf.get_mut().clear();

        // The data frame has been written, so unset it
        match self.next.take() {
            Some(Next::Data(frame)) => {
                self.last_data_frame = Some(frame);
                debug_assert!(self.is_empty());
                ControlFlow::Break
            }
            Some(Next::Continuation(frame)) => {
                // Buffer the continuation frame, then try to write again
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = frame.encode(&mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
                ControlFlow::Continue
            }
            None => ControlFlow::Break,
        }
    }

    fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        // Ensure that we have enough capacity to accept the write.
        assert!(self.has_capacity());
        let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
        let _e = span.enter();

        tracing::debug!(frame = ?item, "send");

        match item {
            Frame::Data(mut v) => {
                // Ensure that the payload is not greater than the max frame size.
                let len = v.payload().remaining();

                if len > self.max_frame_size() {
                    return Err(PayloadTooBig);
                }

                if len >= self.chain_threshold {
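                    // The payload is big enough to chain: encode only the
                    // frame head now and write the payload directly from
                    // the frame at flush time.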
                    let head = v.head();

                    // Encode the frame head to the buffer
                    head.encode(len, self.buf.get_mut());

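                    // If the buffer holds less than `chain_threshold` bytes,
                    // top it up from the payload so the first write isn't
                    // needlessly small.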
                    if self.buf.get_ref().remaining() < self.chain_threshold {
                        let extra_bytes = self.chain_threshold - self.buf.remaining();
                        self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
                    }

                    // Save the data frame
                    self.next = Some(Next::Data(v));
                } else {
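                    // Small payload: copy the whole chunk into the write
                    // buffer rather than chaining it.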
                    v.encode_chunk(self.buf.get_mut());

                    // The chunk has been fully encoded, so there is no need to
                    // keep it around
                    assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");

                    // Save off the last frame...
                    self.last_data_frame = Some(v);
                }
            }
            Frame::Headers(v) => {
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::PushPromise(v) => {
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::Settings(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded settings");
            }
            Frame::GoAway(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
            }
            Frame::Ping(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded ping");
            }
            Frame::WindowUpdate(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
            }

            Frame::Priority(_) => {
                /*
                v.encode(self.buf.get_mut());
                tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
                */
                unimplemented!();
            }
            Frame::Reset(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded reset");
            }
        }

        Ok(())
    }

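    /// A frame can be accepted when no frame is mid-encode and the buffer
    /// has room for at least a frame head plus `chain_threshold` bytes.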
    fn has_capacity(&self) -> bool {
        self.next.is_none()
            && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
                >= self.min_buffer_capacity)
    }

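    /// The codec is empty once the buffered bytes, and the payload of any
    /// chained DATA frame, have been fully written.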
    fn is_empty(&self) -> bool {
        match self.next {
            Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
            _ => !self.buf.has_remaining(),
        }
    }
}

impl<B> Encoder<B> {
    fn max_frame_size(&self) -> usize {
        self.max_frame_size as usize
    }
}

impl<T, B> FramedWrite<T, B> {
    /// Returns the max frame size that can be sent
    pub fn max_frame_size(&self) -> usize {
        self.encoder.max_frame_size()
    }

    /// Set the peer's max frame size.
    pub fn set_max_frame_size(&mut self, val: usize) {
        assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
        self.encoder.max_frame_size = val as FrameSize;
    }

    /// Set the peer's header table size.
    pub fn set_header_table_size(&mut self, val: usize) {
        self.encoder.hpack.update_max_size(val);
    }

    /// Retrieve the last data frame that has been sent
    pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
        self.encoder.last_data_frame.take()
    }

    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}

impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}

// We never project the Pin to `B`.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}

#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        pub fn get_ref(&self) -> &T {
            &self.inner
        }
    }
}
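
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal usage sketch, not from the upstream file: buffer a SETTINGS
    // frame and flush it into an in-memory writer. This assumes that
    // `frame::Settings` implements `Default` and that tokio provides an
    // `AsyncWrite` impl for `std::io::Cursor<Vec<u8>>`; adjust if either
    // assumption doesn't hold for the dependency versions in use.
    #[tokio::test]
    async fn buffer_settings_then_flush() {
        use std::future::poll_fn;

        let mut framed: FramedWrite<Cursor<Vec<u8>>, bytes::Bytes> =
            FramedWrite::new(Cursor::new(Vec::new()));

        // A freshly created codec has capacity for a frame.
        poll_fn(|cx| framed.poll_ready(cx)).await.unwrap();

        // An empty SETTINGS frame encodes to just the 9-byte frame head.
        framed
            .buffer(Frame::Settings(frame::Settings::default()))
            .unwrap();
        poll_fn(|cx| framed.flush(cx)).await.unwrap();

        assert_eq!(framed.get_mut().get_ref().len(), frame::HEADER_LEN);
    }
}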