1 | use std::cmp; |
2 | use std::fmt; |
3 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
4 | use std::future::Future; |
5 | use std::io::{self, IoSlice}; |
6 | use std::marker::Unpin; |
7 | use std::mem::MaybeUninit; |
8 | use std::pin::Pin; |
9 | use std::task::{Context, Poll}; |
10 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
11 | use std::time::Duration; |
12 | |
13 | use bytes::{Buf, BufMut, Bytes, BytesMut}; |
14 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
15 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
16 | use tokio::time::Instant; |
17 | use tracing::{debug, trace}; |
18 | |
19 | use super::{Http1Transaction, ParseContext, ParsedMessage}; |
20 | use crate::common::buf::BufList; |
21 | |
/// Initial buffer size: the first read reservation made before reading from
/// IO, and the pre-allocated capacity of the outgoing headers buffer.
pub(crate) const INIT_BUFFER_SIZE: usize = 8 * 1024;

/// Smallest value accepted when configuring the max buffer size.
pub(crate) const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE;

/// Default maximum buffer size. If the read buffer grows this large while a
/// message is still incomplete, a `TooLarge` error is returned.
// Note: if this changes, update server::conn::Http::max_buf_size docs.
pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8_192 + 100 * 4_096;

/// Upper bound on how many distinct `Buf`s may sit in the write queue before
/// a flush is required. Only matters when the write strategy queues buffers.
///
/// A flush may well happen earlier; this limit merely forces one once the
/// queue gets this long.
const MAX_BUF_LIST_BUFFERS: usize = 16;
39 | |
/// An IO transport `T` paired with a read buffer and a write buffer of `B`s.
pub(crate) struct Buffered<T, B> {
    /// When set, `poll_flush` returns immediately while unread bytes remain
    /// in `read_buf`.
    flush_pipeline: bool,
    /// The underlying transport.
    io: T,
    /// Set when the most recent `poll_read_from_io` returned `Pending`.
    read_blocked: bool,
    /// Bytes read from `io` that have not been consumed yet.
    read_buf: BytesMut,
    /// Decides how much spare capacity to reserve before each read, and the
    /// maximum read-buffer size tolerated while parsing.
    read_buf_strategy: ReadStrategy,
    /// Outgoing data waiting to be written to `io`.
    write_buf: WriteBuf<B>,
}
48 | |
49 | impl<T, B> fmt::Debug for Buffered<T, B> |
50 | where |
51 | B: Buf, |
52 | { |
53 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
54 | f&mut DebugStruct<'_, '_>.debug_struct("Buffered" ) |
55 | .field("read_buf" , &self.read_buf) |
56 | .field(name:"write_buf" , &self.write_buf) |
57 | .finish() |
58 | } |
59 | } |
60 | |
61 | impl<T, B> Buffered<T, B> |
62 | where |
63 | T: AsyncRead + AsyncWrite + Unpin, |
64 | B: Buf, |
65 | { |
66 | pub(crate) fn new(io: T) -> Buffered<T, B> { |
67 | let strategy = if io.is_write_vectored() { |
68 | WriteStrategy::Queue |
69 | } else { |
70 | WriteStrategy::Flatten |
71 | }; |
72 | let write_buf = WriteBuf::new(strategy); |
73 | Buffered { |
74 | flush_pipeline: false, |
75 | io, |
76 | read_blocked: false, |
77 | read_buf: BytesMut::with_capacity(0), |
78 | read_buf_strategy: ReadStrategy::default(), |
79 | write_buf, |
80 | } |
81 | } |
82 | |
83 | #[cfg (feature = "server" )] |
84 | pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) { |
85 | debug_assert!(!self.write_buf.has_remaining()); |
86 | self.flush_pipeline = enabled; |
87 | if enabled { |
88 | self.set_write_strategy_flatten(); |
89 | } |
90 | } |
91 | |
92 | pub(crate) fn set_max_buf_size(&mut self, max: usize) { |
93 | assert!( |
94 | max >= MINIMUM_MAX_BUFFER_SIZE, |
95 | "The max_buf_size cannot be smaller than {}." , |
96 | MINIMUM_MAX_BUFFER_SIZE, |
97 | ); |
98 | self.read_buf_strategy = ReadStrategy::with_max(max); |
99 | self.write_buf.max_buf_size = max; |
100 | } |
101 | |
102 | #[cfg (feature = "client" )] |
103 | pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) { |
104 | self.read_buf_strategy = ReadStrategy::Exact(sz); |
105 | } |
106 | |
107 | pub(crate) fn set_write_strategy_flatten(&mut self) { |
108 | // this should always be called only at construction time, |
109 | // so this assert is here to catch myself |
110 | debug_assert!(self.write_buf.queue.bufs_cnt() == 0); |
111 | self.write_buf.set_strategy(WriteStrategy::Flatten); |
112 | } |
113 | |
114 | pub(crate) fn set_write_strategy_queue(&mut self) { |
115 | // this should always be called only at construction time, |
116 | // so this assert is here to catch myself |
117 | debug_assert!(self.write_buf.queue.bufs_cnt() == 0); |
118 | self.write_buf.set_strategy(WriteStrategy::Queue); |
119 | } |
120 | |
121 | pub(crate) fn read_buf(&self) -> &[u8] { |
122 | self.read_buf.as_ref() |
123 | } |
124 | |
125 | #[cfg (test)] |
126 | #[cfg (feature = "nightly" )] |
127 | pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut { |
128 | &mut self.read_buf |
129 | } |
130 | |
131 | /// Return the "allocated" available space, not the potential space |
132 | /// that could be allocated in the future. |
133 | fn read_buf_remaining_mut(&self) -> usize { |
134 | self.read_buf.capacity() - self.read_buf.len() |
135 | } |
136 | |
137 | /// Return whether we can append to the headers buffer. |
138 | /// |
139 | /// Reasons we can't: |
140 | /// - The write buf is in queue mode, and some of the past body is still |
141 | /// needing to be flushed. |
142 | pub(crate) fn can_headers_buf(&self) -> bool { |
143 | !self.write_buf.queue.has_remaining() |
144 | } |
145 | |
146 | pub(crate) fn headers_buf(&mut self) -> &mut Vec<u8> { |
147 | let buf = self.write_buf.headers_mut(); |
148 | &mut buf.bytes |
149 | } |
150 | |
151 | pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> { |
152 | &mut self.write_buf |
153 | } |
154 | |
155 | pub(crate) fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) { |
156 | self.write_buf.buffer(buf) |
157 | } |
158 | |
159 | pub(crate) fn can_buffer(&self) -> bool { |
160 | self.flush_pipeline || self.write_buf.can_buffer() |
161 | } |
162 | |
163 | pub(crate) fn consume_leading_lines(&mut self) { |
164 | if !self.read_buf.is_empty() { |
165 | let mut i = 0; |
166 | while i < self.read_buf.len() { |
167 | match self.read_buf[i] { |
168 | b' \r' | b' \n' => i += 1, |
169 | _ => break, |
170 | } |
171 | } |
172 | self.read_buf.advance(i); |
173 | } |
174 | } |
175 | |
176 | pub(super) fn parse<S>( |
177 | &mut self, |
178 | cx: &mut Context<'_>, |
179 | parse_ctx: ParseContext<'_>, |
180 | ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>> |
181 | where |
182 | S: Http1Transaction, |
183 | { |
184 | loop { |
185 | match super::role::parse_headers::<S>( |
186 | &mut self.read_buf, |
187 | ParseContext { |
188 | cached_headers: parse_ctx.cached_headers, |
189 | req_method: parse_ctx.req_method, |
190 | h1_parser_config: parse_ctx.h1_parser_config.clone(), |
191 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
192 | h1_header_read_timeout: parse_ctx.h1_header_read_timeout, |
193 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
194 | h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut, |
195 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
196 | h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running, |
197 | preserve_header_case: parse_ctx.preserve_header_case, |
198 | #[cfg (feature = "ffi" )] |
199 | preserve_header_order: parse_ctx.preserve_header_order, |
200 | h09_responses: parse_ctx.h09_responses, |
201 | #[cfg (feature = "ffi" )] |
202 | on_informational: parse_ctx.on_informational, |
203 | #[cfg (feature = "ffi" )] |
204 | raw_headers: parse_ctx.raw_headers, |
205 | }, |
206 | )? { |
207 | Some(msg) => { |
208 | debug!("parsed {} headers" , msg.head.headers.len()); |
209 | |
210 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
211 | { |
212 | *parse_ctx.h1_header_read_timeout_running = false; |
213 | |
214 | if let Some(h1_header_read_timeout_fut) = |
215 | parse_ctx.h1_header_read_timeout_fut |
216 | { |
217 | // Reset the timer in order to avoid woken up when the timeout finishes |
218 | h1_header_read_timeout_fut |
219 | .as_mut() |
220 | .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60)); |
221 | } |
222 | } |
223 | return Poll::Ready(Ok(msg)); |
224 | } |
225 | None => { |
226 | let max = self.read_buf_strategy.max(); |
227 | if self.read_buf.len() >= max { |
228 | debug!("max_buf_size ( {}) reached, closing" , max); |
229 | return Poll::Ready(Err(crate::Error::new_too_large())); |
230 | } |
231 | |
232 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
233 | if *parse_ctx.h1_header_read_timeout_running { |
234 | if let Some(h1_header_read_timeout_fut) = |
235 | parse_ctx.h1_header_read_timeout_fut |
236 | { |
237 | if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() { |
238 | *parse_ctx.h1_header_read_timeout_running = false; |
239 | |
240 | tracing::warn!("read header from client timeout" ); |
241 | return Poll::Ready(Err(crate::Error::new_header_timeout())); |
242 | } |
243 | } |
244 | } |
245 | } |
246 | } |
247 | if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { |
248 | trace!("parse eof" ); |
249 | return Poll::Ready(Err(crate::Error::new_incomplete())); |
250 | } |
251 | } |
252 | } |
253 | |
254 | pub(crate) fn poll_read_from_io(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> { |
255 | self.read_blocked = false; |
256 | let next = self.read_buf_strategy.next(); |
257 | if self.read_buf_remaining_mut() < next { |
258 | self.read_buf.reserve(next); |
259 | } |
260 | |
261 | let dst = self.read_buf.chunk_mut(); |
262 | let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; |
263 | let mut buf = ReadBuf::uninit(dst); |
264 | match Pin::new(&mut self.io).poll_read(cx, &mut buf) { |
265 | Poll::Ready(Ok(_)) => { |
266 | let n = buf.filled().len(); |
267 | trace!("received {} bytes" , n); |
268 | unsafe { |
269 | // Safety: we just read that many bytes into the |
270 | // uninitialized part of the buffer, so this is okay. |
271 | // @tokio pls give me back `poll_read_buf` thanks |
272 | self.read_buf.advance_mut(n); |
273 | } |
274 | self.read_buf_strategy.record(n); |
275 | Poll::Ready(Ok(n)) |
276 | } |
277 | Poll::Pending => { |
278 | self.read_blocked = true; |
279 | Poll::Pending |
280 | } |
281 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
282 | } |
283 | } |
284 | |
285 | pub(crate) fn into_inner(self) -> (T, Bytes) { |
286 | (self.io, self.read_buf.freeze()) |
287 | } |
288 | |
289 | pub(crate) fn io_mut(&mut self) -> &mut T { |
290 | &mut self.io |
291 | } |
292 | |
293 | pub(crate) fn is_read_blocked(&self) -> bool { |
294 | self.read_blocked |
295 | } |
296 | |
297 | pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
298 | if self.flush_pipeline && !self.read_buf.is_empty() { |
299 | Poll::Ready(Ok(())) |
300 | } else if self.write_buf.remaining() == 0 { |
301 | Pin::new(&mut self.io).poll_flush(cx) |
302 | } else { |
303 | if let WriteStrategy::Flatten = self.write_buf.strategy { |
304 | return self.poll_flush_flattened(cx); |
305 | } |
306 | |
307 | const MAX_WRITEV_BUFS: usize = 64; |
308 | loop { |
309 | let n = { |
310 | let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS]; |
311 | let len = self.write_buf.chunks_vectored(&mut iovs); |
312 | ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))? |
313 | }; |
314 | // TODO(eliza): we have to do this manually because |
315 | // `poll_write_buf` doesn't exist in Tokio 0.3 yet...when |
316 | // `poll_write_buf` comes back, the manual advance will need to leave! |
317 | self.write_buf.advance(n); |
318 | debug!("flushed {} bytes" , n); |
319 | if self.write_buf.remaining() == 0 { |
320 | break; |
321 | } else if n == 0 { |
322 | trace!( |
323 | "write returned zero, but {} bytes remaining" , |
324 | self.write_buf.remaining() |
325 | ); |
326 | return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); |
327 | } |
328 | } |
329 | Pin::new(&mut self.io).poll_flush(cx) |
330 | } |
331 | } |
332 | |
333 | /// Specialized version of `flush` when strategy is Flatten. |
334 | /// |
335 | /// Since all buffered bytes are flattened into the single headers buffer, |
336 | /// that skips some bookkeeping around using multiple buffers. |
337 | fn poll_flush_flattened(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
338 | loop { |
339 | let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?; |
340 | debug!("flushed {} bytes" , n); |
341 | self.write_buf.headers.advance(n); |
342 | if self.write_buf.headers.remaining() == 0 { |
343 | self.write_buf.headers.reset(); |
344 | break; |
345 | } else if n == 0 { |
346 | trace!( |
347 | "write returned zero, but {} bytes remaining" , |
348 | self.write_buf.remaining() |
349 | ); |
350 | return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); |
351 | } |
352 | } |
353 | Pin::new(&mut self.io).poll_flush(cx) |
354 | } |
355 | |
356 | #[cfg (test)] |
357 | fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a { |
358 | futures_util::future::poll_fn(move |cx| self.poll_flush(cx)) |
359 | } |
360 | } |
361 | |
362 | // The `B` is a `Buf`, we never project a pin to it |
363 | impl<T: Unpin, B> Unpin for Buffered<T, B> {} |
364 | |
365 | // TODO: This trait is old... at least rename to PollBytes or something... |
366 | pub(crate) trait MemRead { |
367 | fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>>; |
368 | } |
369 | |
370 | impl<T, B> MemRead for Buffered<T, B> |
371 | where |
372 | T: AsyncRead + AsyncWrite + Unpin, |
373 | B: Buf, |
374 | { |
375 | fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { |
376 | if !self.read_buf.is_empty() { |
377 | let n: usize = std::cmp::min(v1:len, self.read_buf.len()); |
378 | Poll::Ready(Ok(self.read_buf.split_to(at:n).freeze())) |
379 | } else { |
380 | let n: usize = ready!(self.poll_read_from_io(cx))?; |
381 | Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(v1:len, v2:n)).freeze())) |
382 | } |
383 | } |
384 | } |
385 | |
/// Decides how much room to reserve before each read from the transport.
#[derive(Debug, Clone, Copy)]
enum ReadStrategy {
    /// Grows or shrinks the reservation in powers of two based on how much
    /// recent reads actually returned.
    Adaptive {
        /// Set after one "small" read; a second consecutive small read
        /// shrinks `next`.
        decrease_now: bool,
        /// Bytes of spare capacity to reserve before the next read.
        next: usize,
        /// Upper bound for `next`, also used as the parse size limit.
        max: usize,
    },
    /// Always reserves exactly this many bytes.
    #[cfg(feature = "client")]
    Exact(usize),
}
396 | |
397 | impl ReadStrategy { |
398 | fn with_max(max: usize) -> ReadStrategy { |
399 | ReadStrategy::Adaptive { |
400 | decrease_now: false, |
401 | next: INIT_BUFFER_SIZE, |
402 | max, |
403 | } |
404 | } |
405 | |
406 | fn next(&self) -> usize { |
407 | match *self { |
408 | ReadStrategy::Adaptive { next, .. } => next, |
409 | #[cfg (feature = "client" )] |
410 | ReadStrategy::Exact(exact) => exact, |
411 | } |
412 | } |
413 | |
414 | fn max(&self) -> usize { |
415 | match *self { |
416 | ReadStrategy::Adaptive { max, .. } => max, |
417 | #[cfg (feature = "client" )] |
418 | ReadStrategy::Exact(exact) => exact, |
419 | } |
420 | } |
421 | |
422 | fn record(&mut self, bytes_read: usize) { |
423 | match *self { |
424 | ReadStrategy::Adaptive { |
425 | ref mut decrease_now, |
426 | ref mut next, |
427 | max, |
428 | .. |
429 | } => { |
430 | if bytes_read >= *next { |
431 | *next = cmp::min(incr_power_of_two(*next), max); |
432 | *decrease_now = false; |
433 | } else { |
434 | let decr_to = prev_power_of_two(*next); |
435 | if bytes_read < decr_to { |
436 | if *decrease_now { |
437 | *next = cmp::max(decr_to, INIT_BUFFER_SIZE); |
438 | *decrease_now = false; |
439 | } else { |
440 | // Decreasing is a two "record" process. |
441 | *decrease_now = true; |
442 | } |
443 | } else { |
444 | // A read within the current range should cancel |
445 | // a potential decrease, since we just saw proof |
446 | // that we still need this size. |
447 | *decrease_now = false; |
448 | } |
449 | } |
450 | } |
451 | #[cfg (feature = "client" )] |
452 | ReadStrategy::Exact(_) => (), |
453 | } |
454 | } |
455 | } |
456 | |
/// Doubles `n`, clamping at `usize::MAX` instead of overflowing.
fn incr_power_of_two(n: usize) -> usize {
    n.checked_mul(2).unwrap_or(usize::MAX)
}
460 | |
/// For `n` in `[2^k, 2^(k+1))`, returns `2^(k-1)`: half of the largest power
/// of two not exceeding `n`.
///
/// `n` must be at least 4. For smaller values the shift amount
/// `leading_zeros + 2` reaches or exceeds the bit width of `usize`.
fn prev_power_of_two(n: usize) -> usize {
    debug_assert!(n >= 4);
    let shift = n.leading_zeros() + 2;
    let mask = usize::MAX >> shift;
    mask + 1
}
467 | |
468 | impl Default for ReadStrategy { |
469 | fn default() -> ReadStrategy { |
470 | ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE) |
471 | } |
472 | } |
473 | |
/// A byte container `T` plus a read position into it.
#[derive(Clone)]
pub(crate) struct Cursor<T> {
    /// Index of the first unread byte of `bytes`.
    pos: usize,
    /// The underlying storage.
    bytes: T,
}
479 | |
480 | impl<T: AsRef<[u8]>> Cursor<T> { |
481 | #[inline ] |
482 | pub(crate) fn new(bytes: T) -> Cursor<T> { |
483 | Cursor { bytes, pos: 0 } |
484 | } |
485 | } |
486 | |
487 | impl Cursor<Vec<u8>> { |
488 | /// If we've advanced the position a bit in this cursor, and wish to |
489 | /// extend the underlying vector, we may wish to unshift the "read" bytes |
490 | /// off, and move everything else over. |
491 | fn maybe_unshift(&mut self, additional: usize) { |
492 | if self.pos == 0 { |
493 | // nothing to do |
494 | return; |
495 | } |
496 | |
497 | if self.bytes.capacity() - self.bytes.len() >= additional { |
498 | // there's room! |
499 | return; |
500 | } |
501 | |
502 | self.bytes.drain(range:0..self.pos); |
503 | self.pos = 0; |
504 | } |
505 | |
506 | fn reset(&mut self) { |
507 | self.pos = 0; |
508 | self.bytes.clear(); |
509 | } |
510 | } |
511 | |
512 | impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> { |
513 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
514 | f&mut DebugStruct<'_, '_>.debug_struct("Cursor" ) |
515 | .field("pos" , &self.pos) |
516 | .field(name:"len" , &self.bytes.as_ref().len()) |
517 | .finish() |
518 | } |
519 | } |
520 | |
521 | impl<T: AsRef<[u8]>> Buf for Cursor<T> { |
522 | #[inline ] |
523 | fn remaining(&self) -> usize { |
524 | self.bytes.as_ref().len() - self.pos |
525 | } |
526 | |
527 | #[inline ] |
528 | fn chunk(&self) -> &[u8] { |
529 | &self.bytes.as_ref()[self.pos..] |
530 | } |
531 | |
532 | #[inline ] |
533 | fn advance(&mut self, cnt: usize) { |
534 | debug_assert!(self.pos + cnt <= self.bytes.as_ref().len()); |
535 | self.pos += cnt; |
536 | } |
537 | } |
538 | |
/// An internal buffer to collect writes before flushes.
pub(super) struct WriteBuf<B> {
    /// Re-usable buffer that holds message headers (and, under the `Flatten`
    /// strategy, every buffered body byte as well).
    headers: Cursor<Vec<u8>>,
    /// `can_buffer` reports `false` once `remaining()` reaches this size.
    max_buf_size: usize,
    /// Deque of user buffers if strategy is Queue
    queue: BufList<B>,
    /// Whether buffered `B`s are copied into `headers` or queued as-is.
    strategy: WriteStrategy,
}
548 | |
549 | impl<B: Buf> WriteBuf<B> { |
550 | fn new(strategy: WriteStrategy) -> WriteBuf<B> { |
551 | WriteBuf { |
552 | headers: Cursor::new(bytes:Vec::with_capacity(INIT_BUFFER_SIZE)), |
553 | max_buf_size: DEFAULT_MAX_BUFFER_SIZE, |
554 | queue: BufList::new(), |
555 | strategy, |
556 | } |
557 | } |
558 | } |
559 | |
560 | impl<B> WriteBuf<B> |
561 | where |
562 | B: Buf, |
563 | { |
564 | fn set_strategy(&mut self, strategy: WriteStrategy) { |
565 | self.strategy = strategy; |
566 | } |
567 | |
568 | pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) { |
569 | debug_assert!(buf.has_remaining()); |
570 | match self.strategy { |
571 | WriteStrategy::Flatten => { |
572 | let head = self.headers_mut(); |
573 | |
574 | head.maybe_unshift(buf.remaining()); |
575 | trace!( |
576 | self.len = head.remaining(), |
577 | buf.len = buf.remaining(), |
578 | "buffer.flatten" |
579 | ); |
580 | //perf: This is a little faster than <Vec as BufMut>>::put, |
581 | //but accomplishes the same result. |
582 | loop { |
583 | let adv = { |
584 | let slice = buf.chunk(); |
585 | if slice.is_empty() { |
586 | return; |
587 | } |
588 | head.bytes.extend_from_slice(slice); |
589 | slice.len() |
590 | }; |
591 | buf.advance(adv); |
592 | } |
593 | } |
594 | WriteStrategy::Queue => { |
595 | trace!( |
596 | self.len = self.remaining(), |
597 | buf.len = buf.remaining(), |
598 | "buffer.queue" |
599 | ); |
600 | self.queue.push(buf.into()); |
601 | } |
602 | } |
603 | } |
604 | |
605 | fn can_buffer(&self) -> bool { |
606 | match self.strategy { |
607 | WriteStrategy::Flatten => self.remaining() < self.max_buf_size, |
608 | WriteStrategy::Queue => { |
609 | self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size |
610 | } |
611 | } |
612 | } |
613 | |
614 | fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> { |
615 | debug_assert!(!self.queue.has_remaining()); |
616 | &mut self.headers |
617 | } |
618 | } |
619 | |
620 | impl<B: Buf> fmt::Debug for WriteBuf<B> { |
621 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
622 | f&mut DebugStruct<'_, '_>.debug_struct("WriteBuf" ) |
623 | .field("remaining" , &self.remaining()) |
624 | .field(name:"strategy" , &self.strategy) |
625 | .finish() |
626 | } |
627 | } |
628 | |
629 | impl<B: Buf> Buf for WriteBuf<B> { |
630 | #[inline ] |
631 | fn remaining(&self) -> usize { |
632 | self.headers.remaining() + self.queue.remaining() |
633 | } |
634 | |
635 | #[inline ] |
636 | fn chunk(&self) -> &[u8] { |
637 | let headers = self.headers.chunk(); |
638 | if !headers.is_empty() { |
639 | headers |
640 | } else { |
641 | self.queue.chunk() |
642 | } |
643 | } |
644 | |
645 | #[inline ] |
646 | fn advance(&mut self, cnt: usize) { |
647 | let hrem = self.headers.remaining(); |
648 | |
649 | match hrem.cmp(&cnt) { |
650 | cmp::Ordering::Equal => self.headers.reset(), |
651 | cmp::Ordering::Greater => self.headers.advance(cnt), |
652 | cmp::Ordering::Less => { |
653 | let qcnt = cnt - hrem; |
654 | self.headers.reset(); |
655 | self.queue.advance(qcnt); |
656 | } |
657 | } |
658 | } |
659 | |
660 | #[inline ] |
661 | fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { |
662 | let n = self.headers.chunks_vectored(dst); |
663 | self.queue.chunks_vectored(&mut dst[n..]) + n |
664 | } |
665 | } |
666 | |
/// How `WriteBuf` stores buffers handed to it.
#[derive(Debug)]
enum WriteStrategy {
    /// Copy every buffer's bytes into the single headers buffer.
    Flatten,
    /// Keep each buffer as-is in a queue, for vectored writes.
    Queue,
}
672 | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use tokio_test::io::Builder as Mock;

    // #[cfg(feature = "nightly")]
    // use test::Bencher;

    /*
    impl<T: Read> MemRead for AsyncIo<T> {
        fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
            let mut v = vec![0; len];
            let n = try_nb!(self.read(v.as_mut_slice()));
            Ok(Async::Ready(BytesMut::from(&v[..n]).freeze()))
        }
    }
    */

    #[tokio::test]
    #[ignore]
    async fn iobuf_write_empty_slice() {
        // TODO(eliza): can i have writev back pls T_T
        // // First, let's just check that the Mock would normally return an
        // // error on an unexpected write, even if the buffer is empty...
        // let mut mock = Mock::new().build();
        // futures_util::future::poll_fn(|cx| {
        //     Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
        // })
        // .await
        // .expect_err("should be a broken pipe");

        // // underlying io will return the logic error upon write,
        // // so we are testing that the io_buf does not trigger a write
        // // when there is nothing to flush
        // let mock = Mock::new().build();
        // let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        // io_buf.flush().await.expect("should short-circuit flush");
    }

    #[tokio::test]
    async fn parse_reads_until_blocked() {
        use crate::proto::h1::ClientTransaction;

        let _ = pretty_env_logger::try_init();
        let mock = Mock::new()
            // Split over multiple reads will read all of it
            .read(b"HTTP/1.1 200 OK\r\n")
            .read(b"Server: hyper\r\n")
            // missing last line ending
            .wait(Duration::from_secs(1))
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        // We expect a `parse` to be not ready, and so can't await it directly.
        // Rather, this `poll_fn` will wrap the `Poll` result.
        futures_util::future::poll_fn(|cx| {
            let parse_ctx = ParseContext {
                cached_headers: &mut None,
                req_method: &mut None,
                h1_parser_config: Default::default(),
                // The header-read-timeout fields only exist when both `server`
                // and `runtime` are enabled (the same gate `Buffered::parse`
                // uses); gating on `runtime` alone breaks a
                // runtime-without-server test build.
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout: None,
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout_fut: &mut None,
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout_running: &mut false,
                preserve_header_case: false,
                #[cfg(feature = "ffi")]
                preserve_header_order: false,
                h09_responses: false,
                #[cfg(feature = "ffi")]
                on_informational: &mut None,
                #[cfg(feature = "ffi")]
                raw_headers: false,
            };
            assert!(buffered
                .parse::<ClientTransaction>(cx, parse_ctx)
                .is_pending());
            Poll::Ready(())
        })
        .await;

        assert_eq!(
            buffered.read_buf,
            b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..]
        );
    }

    #[test]
    fn read_strategy_adaptive_increments() {
        let mut strategy = ReadStrategy::default();
        assert_eq!(strategy.next(), 8192);

        // Grows if record == next
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(16384);
        assert_eq!(strategy.next(), 32768);

        // Enormous records still increment at same rate
        strategy.record(::std::usize::MAX);
        assert_eq!(strategy.next(), 65536);

        let max = strategy.max();
        while strategy.next() < max {
            strategy.record(max);
        }

        assert_eq!(strategy.next(), max, "never goes over max");
        strategy.record(max + 1);
        assert_eq!(strategy.next(), max, "never goes over max");
    }

    #[test]
    fn read_strategy_adaptive_decrements() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384, "record was with range");

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "in-range record should make this the 'first' again"
        );

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "second smaller record decrements");

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "first doesn't decrement");
        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum");
    }

    #[test]
    fn read_strategy_adaptive_stays_the_same() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(8193);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );

        strategy.record(8193);
        assert_eq!(
            strategy.next(),
            16384,
            "with current step does not decrement"
        );
    }

    #[test]
    fn read_strategy_adaptive_max_fuzz() {
        fn fuzz(max: usize) {
            let mut strategy = ReadStrategy::with_max(max);
            while strategy.next() < max {
                strategy.record(::std::usize::MAX);
            }
            let mut next = strategy.next();
            while next > 8192 {
                strategy.record(1);
                strategy.record(1);
                next = strategy.next();
                assert!(
                    next.is_power_of_two(),
                    "decrement should be powers of two: {} (max = {})",
                    next,
                    max,
                );
            }
        }

        let mut max = 8192;
        while max < std::usize::MAX {
            fuzz(max);
            max = (max / 2).saturating_mul(3);
        }
        fuzz(::std::usize::MAX);
    }

    #[test]
    #[should_panic]
    #[cfg(debug_assertions)] // needs to trigger a debug_assert
    fn write_buf_requires_non_empty_bufs() {
        let mock = Mock::new().build();
        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        buffered.buffer(Cursor::new(Vec::new()));
    }

    /*
    TODO: needs tokio_test::io to allow configure write_buf calls
    #[test]
    fn write_buf_queue() {
        let _ = pretty_env_logger::try_init();

        let mock = AsyncIo::new_buf(vec![], 1024);
        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);


        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);
        buffered.flush().unwrap();

        assert_eq!(buffered.io, b"hello world, it's hyper!");
        assert_eq!(buffered.io.num_writes(), 1);
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
    }
    */

    #[tokio::test]
    async fn write_buf_flatten() {
        let _ = pretty_env_logger::try_init();

        let mock = Mock::new().write(b"hello world, it's hyper!").build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        buffered.write_buf.set_strategy(WriteStrategy::Flatten);

        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);

        buffered.flush().await.expect("flush");
    }

    #[test]
    fn write_buf_flatten_partially_flushed() {
        let _ = pretty_env_logger::try_init();

        let b = |s: &str| Cursor::new(s.as_bytes().to_vec());

        let mut write_buf = WriteBuf::<Cursor<Vec<u8>>>::new(WriteStrategy::Flatten);

        write_buf.buffer(b("hello "));
        write_buf.buffer(b("world, "));

        assert_eq!(write_buf.chunk(), b"hello world, ");

        // advance most of the way, but not all
        write_buf.advance(11);

        assert_eq!(write_buf.chunk(), b", ");
        assert_eq!(write_buf.headers.pos, 11);
        assert_eq!(write_buf.headers.bytes.capacity(), INIT_BUFFER_SIZE);

        // there's still room in the headers buffer, so just push on the end
        write_buf.buffer(b("it's hyper!"));

        assert_eq!(write_buf.chunk(), b", it's hyper!");
        assert_eq!(write_buf.headers.pos, 11);

        let rem1 = write_buf.remaining();
        let cap = write_buf.headers.bytes.capacity();

        // but when this would go over capacity, don't copy the old bytes
        write_buf.buffer(Cursor::new(vec![b'X'; cap]));
        assert_eq!(write_buf.remaining(), cap + rem1);
        assert_eq!(write_buf.headers.pos, 0);
    }

    #[tokio::test]
    async fn write_buf_queue_disable_auto() {
        let _ = pretty_env_logger::try_init();

        let mock = Mock::new()
            .write(b"hello ")
            .write(b"world, ")
            .write(b"it's ")
            .write(b"hyper!")
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        buffered.write_buf.set_strategy(WriteStrategy::Queue);

        // we have 4 buffers, and vec IO disabled, but explicitly said
        // don't try to auto detect (via setting strategy above)

        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);

        buffered.flush().await.expect("flush");

        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
    }

    // #[cfg(feature = "nightly")]
    // #[bench]
    // fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) {
    //     let s = "Hello, World!";
    //     b.bytes = s.len() as u64;

    //     let mut write_buf = WriteBuf::<bytes::Bytes>::new();
    //     write_buf.set_strategy(WriteStrategy::Flatten);
    //     b.iter(|| {
    //         let chunk = bytes::Bytes::from(s);
    //         write_buf.buffer(chunk);
    //         ::test::black_box(&write_buf);
    //         write_buf.headers.bytes.clear();
    //     })
    // }
}
1001 | |