1 | use std::cmp; |
2 | use std::fmt; |
3 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
4 | use std::future::Future; |
5 | use std::io::{self, IoSlice}; |
6 | use std::marker::Unpin; |
7 | use std::mem::MaybeUninit; |
8 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
9 | use std::time::Duration; |
10 | |
11 | use bytes::{Buf, BufMut, Bytes, BytesMut}; |
12 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
13 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
14 | use tokio::time::Instant; |
15 | use tracing::{debug, trace}; |
16 | |
17 | use super::{Http1Transaction, ParseContext, ParsedMessage}; |
18 | use crate::common::buf::BufList; |
19 | use crate::common::{task, Pin, Poll}; |
20 | |
/// How many bytes are reserved for the read buffer before the first read from IO.
pub(crate) const INIT_BUFFER_SIZE: usize = 8192;

/// Lower bound accepted by `set_max_buf_size`; equal to the initial read size.
pub(crate) const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE;

/// Default cap for the read buffer. When this many bytes are buffered and
/// the message head is still incomplete, parsing fails with `TooLarge`.
// Note: if this changes, update server::conn::Http::max_buf_size docs.
pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE + 100 * 4096;

/// Upper bound on how many separate `Buf`s the write queue may hold before a
/// flush is forced. Only relevant for the `Queue` write strategy.
///
/// Flushes may still happen earlier; this limit only guarantees one happens
/// once the queue grows this long.
const MAX_BUF_LIST_BUFFERS: usize = 16;
38 | |
39 | pub(crate) struct Buffered<T, B> { |
40 | flush_pipeline: bool, |
41 | io: T, |
42 | read_blocked: bool, |
43 | read_buf: BytesMut, |
44 | read_buf_strategy: ReadStrategy, |
45 | write_buf: WriteBuf<B>, |
46 | } |
47 | |
48 | impl<T, B> fmt::Debug for Buffered<T, B> |
49 | where |
50 | B: Buf, |
51 | { |
52 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
53 | f&mut DebugStruct<'_, '_>.debug_struct("Buffered" ) |
54 | .field("read_buf" , &self.read_buf) |
55 | .field(name:"write_buf" , &self.write_buf) |
56 | .finish() |
57 | } |
58 | } |
59 | |
60 | impl<T, B> Buffered<T, B> |
61 | where |
62 | T: AsyncRead + AsyncWrite + Unpin, |
63 | B: Buf, |
64 | { |
65 | pub(crate) fn new(io: T) -> Buffered<T, B> { |
66 | let strategy = if io.is_write_vectored() { |
67 | WriteStrategy::Queue |
68 | } else { |
69 | WriteStrategy::Flatten |
70 | }; |
71 | let write_buf = WriteBuf::new(strategy); |
72 | Buffered { |
73 | flush_pipeline: false, |
74 | io, |
75 | read_blocked: false, |
76 | read_buf: BytesMut::with_capacity(0), |
77 | read_buf_strategy: ReadStrategy::default(), |
78 | write_buf, |
79 | } |
80 | } |
81 | |
82 | #[cfg (feature = "server" )] |
83 | pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) { |
84 | debug_assert!(!self.write_buf.has_remaining()); |
85 | self.flush_pipeline = enabled; |
86 | if enabled { |
87 | self.set_write_strategy_flatten(); |
88 | } |
89 | } |
90 | |
91 | pub(crate) fn set_max_buf_size(&mut self, max: usize) { |
92 | assert!( |
93 | max >= MINIMUM_MAX_BUFFER_SIZE, |
94 | "The max_buf_size cannot be smaller than {}." , |
95 | MINIMUM_MAX_BUFFER_SIZE, |
96 | ); |
97 | self.read_buf_strategy = ReadStrategy::with_max(max); |
98 | self.write_buf.max_buf_size = max; |
99 | } |
100 | |
101 | #[cfg (feature = "client" )] |
102 | pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) { |
103 | self.read_buf_strategy = ReadStrategy::Exact(sz); |
104 | } |
105 | |
106 | pub(crate) fn set_write_strategy_flatten(&mut self) { |
107 | // this should always be called only at construction time, |
108 | // so this assert is here to catch myself |
109 | debug_assert!(self.write_buf.queue.bufs_cnt() == 0); |
110 | self.write_buf.set_strategy(WriteStrategy::Flatten); |
111 | } |
112 | |
113 | pub(crate) fn set_write_strategy_queue(&mut self) { |
114 | // this should always be called only at construction time, |
115 | // so this assert is here to catch myself |
116 | debug_assert!(self.write_buf.queue.bufs_cnt() == 0); |
117 | self.write_buf.set_strategy(WriteStrategy::Queue); |
118 | } |
119 | |
120 | pub(crate) fn read_buf(&self) -> &[u8] { |
121 | self.read_buf.as_ref() |
122 | } |
123 | |
124 | #[cfg (test)] |
125 | #[cfg (feature = "nightly" )] |
126 | pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut { |
127 | &mut self.read_buf |
128 | } |
129 | |
130 | /// Return the "allocated" available space, not the potential space |
131 | /// that could be allocated in the future. |
132 | fn read_buf_remaining_mut(&self) -> usize { |
133 | self.read_buf.capacity() - self.read_buf.len() |
134 | } |
135 | |
136 | /// Return whether we can append to the headers buffer. |
137 | /// |
138 | /// Reasons we can't: |
139 | /// - The write buf is in queue mode, and some of the past body is still |
140 | /// needing to be flushed. |
141 | pub(crate) fn can_headers_buf(&self) -> bool { |
142 | !self.write_buf.queue.has_remaining() |
143 | } |
144 | |
145 | pub(crate) fn headers_buf(&mut self) -> &mut Vec<u8> { |
146 | let buf = self.write_buf.headers_mut(); |
147 | &mut buf.bytes |
148 | } |
149 | |
150 | pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> { |
151 | &mut self.write_buf |
152 | } |
153 | |
154 | pub(crate) fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) { |
155 | self.write_buf.buffer(buf) |
156 | } |
157 | |
158 | pub(crate) fn can_buffer(&self) -> bool { |
159 | self.flush_pipeline || self.write_buf.can_buffer() |
160 | } |
161 | |
162 | pub(crate) fn consume_leading_lines(&mut self) { |
163 | if !self.read_buf.is_empty() { |
164 | let mut i = 0; |
165 | while i < self.read_buf.len() { |
166 | match self.read_buf[i] { |
167 | b' \r' | b' \n' => i += 1, |
168 | _ => break, |
169 | } |
170 | } |
171 | self.read_buf.advance(i); |
172 | } |
173 | } |
174 | |
175 | pub(super) fn parse<S>( |
176 | &mut self, |
177 | cx: &mut task::Context<'_>, |
178 | parse_ctx: ParseContext<'_>, |
179 | ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>> |
180 | where |
181 | S: Http1Transaction, |
182 | { |
183 | loop { |
184 | match super::role::parse_headers::<S>( |
185 | &mut self.read_buf, |
186 | ParseContext { |
187 | cached_headers: parse_ctx.cached_headers, |
188 | req_method: parse_ctx.req_method, |
189 | h1_parser_config: parse_ctx.h1_parser_config.clone(), |
190 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
191 | h1_header_read_timeout: parse_ctx.h1_header_read_timeout, |
192 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
193 | h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut, |
194 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
195 | h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running, |
196 | preserve_header_case: parse_ctx.preserve_header_case, |
197 | #[cfg (feature = "ffi" )] |
198 | preserve_header_order: parse_ctx.preserve_header_order, |
199 | h09_responses: parse_ctx.h09_responses, |
200 | #[cfg (feature = "ffi" )] |
201 | on_informational: parse_ctx.on_informational, |
202 | #[cfg (feature = "ffi" )] |
203 | raw_headers: parse_ctx.raw_headers, |
204 | }, |
205 | )? { |
206 | Some(msg) => { |
207 | debug!("parsed {} headers" , msg.head.headers.len()); |
208 | |
209 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
210 | { |
211 | *parse_ctx.h1_header_read_timeout_running = false; |
212 | |
213 | if let Some(h1_header_read_timeout_fut) = |
214 | parse_ctx.h1_header_read_timeout_fut |
215 | { |
216 | // Reset the timer in order to avoid woken up when the timeout finishes |
217 | h1_header_read_timeout_fut |
218 | .as_mut() |
219 | .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60)); |
220 | } |
221 | } |
222 | return Poll::Ready(Ok(msg)); |
223 | } |
224 | None => { |
225 | let max = self.read_buf_strategy.max(); |
226 | if self.read_buf.len() >= max { |
227 | debug!("max_buf_size ( {}) reached, closing" , max); |
228 | return Poll::Ready(Err(crate::Error::new_too_large())); |
229 | } |
230 | |
231 | #[cfg (all(feature = "server" , feature = "runtime" ))] |
232 | if *parse_ctx.h1_header_read_timeout_running { |
233 | if let Some(h1_header_read_timeout_fut) = |
234 | parse_ctx.h1_header_read_timeout_fut |
235 | { |
236 | if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() { |
237 | *parse_ctx.h1_header_read_timeout_running = false; |
238 | |
239 | tracing::warn!("read header from client timeout" ); |
240 | return Poll::Ready(Err(crate::Error::new_header_timeout())); |
241 | } |
242 | } |
243 | } |
244 | } |
245 | } |
246 | if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 { |
247 | trace!("parse eof" ); |
248 | return Poll::Ready(Err(crate::Error::new_incomplete())); |
249 | } |
250 | } |
251 | } |
252 | |
253 | pub(crate) fn poll_read_from_io( |
254 | &mut self, |
255 | cx: &mut task::Context<'_>, |
256 | ) -> Poll<io::Result<usize>> { |
257 | self.read_blocked = false; |
258 | let next = self.read_buf_strategy.next(); |
259 | if self.read_buf_remaining_mut() < next { |
260 | self.read_buf.reserve(next); |
261 | } |
262 | |
263 | let dst = self.read_buf.chunk_mut(); |
264 | let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; |
265 | let mut buf = ReadBuf::uninit(dst); |
266 | match Pin::new(&mut self.io).poll_read(cx, &mut buf) { |
267 | Poll::Ready(Ok(_)) => { |
268 | let n = buf.filled().len(); |
269 | trace!("received {} bytes" , n); |
270 | unsafe { |
271 | // Safety: we just read that many bytes into the |
272 | // uninitialized part of the buffer, so this is okay. |
273 | // @tokio pls give me back `poll_read_buf` thanks |
274 | self.read_buf.advance_mut(n); |
275 | } |
276 | self.read_buf_strategy.record(n); |
277 | Poll::Ready(Ok(n)) |
278 | } |
279 | Poll::Pending => { |
280 | self.read_blocked = true; |
281 | Poll::Pending |
282 | } |
283 | Poll::Ready(Err(e)) => Poll::Ready(Err(e)), |
284 | } |
285 | } |
286 | |
287 | pub(crate) fn into_inner(self) -> (T, Bytes) { |
288 | (self.io, self.read_buf.freeze()) |
289 | } |
290 | |
291 | pub(crate) fn io_mut(&mut self) -> &mut T { |
292 | &mut self.io |
293 | } |
294 | |
295 | pub(crate) fn is_read_blocked(&self) -> bool { |
296 | self.read_blocked |
297 | } |
298 | |
299 | pub(crate) fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { |
300 | if self.flush_pipeline && !self.read_buf.is_empty() { |
301 | Poll::Ready(Ok(())) |
302 | } else if self.write_buf.remaining() == 0 { |
303 | Pin::new(&mut self.io).poll_flush(cx) |
304 | } else { |
305 | if let WriteStrategy::Flatten = self.write_buf.strategy { |
306 | return self.poll_flush_flattened(cx); |
307 | } |
308 | |
309 | const MAX_WRITEV_BUFS: usize = 64; |
310 | loop { |
311 | let n = { |
312 | let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS]; |
313 | let len = self.write_buf.chunks_vectored(&mut iovs); |
314 | ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))? |
315 | }; |
316 | // TODO(eliza): we have to do this manually because |
317 | // `poll_write_buf` doesn't exist in Tokio 0.3 yet...when |
318 | // `poll_write_buf` comes back, the manual advance will need to leave! |
319 | self.write_buf.advance(n); |
320 | debug!("flushed {} bytes" , n); |
321 | if self.write_buf.remaining() == 0 { |
322 | break; |
323 | } else if n == 0 { |
324 | trace!( |
325 | "write returned zero, but {} bytes remaining" , |
326 | self.write_buf.remaining() |
327 | ); |
328 | return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); |
329 | } |
330 | } |
331 | Pin::new(&mut self.io).poll_flush(cx) |
332 | } |
333 | } |
334 | |
335 | /// Specialized version of `flush` when strategy is Flatten. |
336 | /// |
337 | /// Since all buffered bytes are flattened into the single headers buffer, |
338 | /// that skips some bookkeeping around using multiple buffers. |
339 | fn poll_flush_flattened(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { |
340 | loop { |
341 | let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?; |
342 | debug!("flushed {} bytes" , n); |
343 | self.write_buf.headers.advance(n); |
344 | if self.write_buf.headers.remaining() == 0 { |
345 | self.write_buf.headers.reset(); |
346 | break; |
347 | } else if n == 0 { |
348 | trace!( |
349 | "write returned zero, but {} bytes remaining" , |
350 | self.write_buf.remaining() |
351 | ); |
352 | return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); |
353 | } |
354 | } |
355 | Pin::new(&mut self.io).poll_flush(cx) |
356 | } |
357 | |
358 | #[cfg (test)] |
359 | fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a { |
360 | futures_util::future::poll_fn(move |cx| self.poll_flush(cx)) |
361 | } |
362 | } |
363 | |
364 | // The `B` is a `Buf`, we never project a pin to it |
365 | impl<T: Unpin, B> Unpin for Buffered<T, B> {} |
366 | |
367 | // TODO: This trait is old... at least rename to PollBytes or something... |
368 | pub(crate) trait MemRead { |
369 | fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>>; |
370 | } |
371 | |
372 | impl<T, B> MemRead for Buffered<T, B> |
373 | where |
374 | T: AsyncRead + AsyncWrite + Unpin, |
375 | B: Buf, |
376 | { |
377 | fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> { |
378 | if !self.read_buf.is_empty() { |
379 | let n: usize = std::cmp::min(v1:len, self.read_buf.len()); |
380 | Poll::Ready(Ok(self.read_buf.split_to(at:n).freeze())) |
381 | } else { |
382 | let n: usize = ready!(self.poll_read_from_io(cx))?; |
383 | Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(v1:len, v2:n)).freeze())) |
384 | } |
385 | } |
386 | } |
387 | |
/// How much spare capacity to reserve before each read, and the read limit.
#[derive(Clone, Copy, Debug)]
enum ReadStrategy {
    /// Read size grows and shrinks in powers of two according to observed reads.
    Adaptive {
        /// A smaller read was seen once; a second one in a row shrinks `next`.
        decrease_now: bool,
        /// Spare capacity to reserve for the next read.
        next: usize,
        /// Ceiling for `next` and the read buffer length limit.
        max: usize,
    },
    /// Always reserve this many bytes; also used as the limit.
    #[cfg(feature = "client")]
    Exact(usize),
}
398 | |
399 | impl ReadStrategy { |
400 | fn with_max(max: usize) -> ReadStrategy { |
401 | ReadStrategy::Adaptive { |
402 | decrease_now: false, |
403 | next: INIT_BUFFER_SIZE, |
404 | max, |
405 | } |
406 | } |
407 | |
408 | fn next(&self) -> usize { |
409 | match *self { |
410 | ReadStrategy::Adaptive { next, .. } => next, |
411 | #[cfg (feature = "client" )] |
412 | ReadStrategy::Exact(exact) => exact, |
413 | } |
414 | } |
415 | |
416 | fn max(&self) -> usize { |
417 | match *self { |
418 | ReadStrategy::Adaptive { max, .. } => max, |
419 | #[cfg (feature = "client" )] |
420 | ReadStrategy::Exact(exact) => exact, |
421 | } |
422 | } |
423 | |
424 | fn record(&mut self, bytes_read: usize) { |
425 | match *self { |
426 | ReadStrategy::Adaptive { |
427 | ref mut decrease_now, |
428 | ref mut next, |
429 | max, |
430 | .. |
431 | } => { |
432 | if bytes_read >= *next { |
433 | *next = cmp::min(incr_power_of_two(*next), max); |
434 | *decrease_now = false; |
435 | } else { |
436 | let decr_to = prev_power_of_two(*next); |
437 | if bytes_read < decr_to { |
438 | if *decrease_now { |
439 | *next = cmp::max(decr_to, INIT_BUFFER_SIZE); |
440 | *decrease_now = false; |
441 | } else { |
442 | // Decreasing is a two "record" process. |
443 | *decrease_now = true; |
444 | } |
445 | } else { |
446 | // A read within the current range should cancel |
447 | // a potential decrease, since we just saw proof |
448 | // that we still need this size. |
449 | *decrease_now = false; |
450 | } |
451 | } |
452 | } |
453 | #[cfg (feature = "client" )] |
454 | ReadStrategy::Exact(_) => (), |
455 | } |
456 | } |
457 | } |
458 | |
/// Doubles `n`, clamping to `usize::MAX` instead of overflowing.
fn incr_power_of_two(n: usize) -> usize {
    n.checked_mul(2).unwrap_or(usize::MAX)
}
462 | |
/// Returns half of the highest power of two contained in `n`
/// (for a power of two, that is `n / 2`).
fn prev_power_of_two(n: usize) -> usize {
    // For `n < 4` the shift amount below reaches the bit width of `usize`,
    // which overflows the shift (a panic in debug builds).
    debug_assert!(n >= 4);
    let shift = n.leading_zeros() + 2;
    let low_mask = usize::MAX >> shift;
    low_mask + 1
}
469 | |
470 | impl Default for ReadStrategy { |
471 | fn default() -> ReadStrategy { |
472 | ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE) |
473 | } |
474 | } |
475 | |
/// A byte container plus a read position into it.
#[derive(Clone)]
pub(crate) struct Cursor<T> {
    /// The underlying bytes.
    bytes: T,
    /// Index of the first byte not yet consumed.
    pos: usize,
}
481 | |
482 | impl<T: AsRef<[u8]>> Cursor<T> { |
483 | #[inline ] |
484 | pub(crate) fn new(bytes: T) -> Cursor<T> { |
485 | Cursor { bytes, pos: 0 } |
486 | } |
487 | } |
488 | |
489 | impl Cursor<Vec<u8>> { |
490 | /// If we've advanced the position a bit in this cursor, and wish to |
491 | /// extend the underlying vector, we may wish to unshift the "read" bytes |
492 | /// off, and move everything else over. |
493 | fn maybe_unshift(&mut self, additional: usize) { |
494 | if self.pos == 0 { |
495 | // nothing to do |
496 | return; |
497 | } |
498 | |
499 | if self.bytes.capacity() - self.bytes.len() >= additional { |
500 | // there's room! |
501 | return; |
502 | } |
503 | |
504 | self.bytes.drain(range:0..self.pos); |
505 | self.pos = 0; |
506 | } |
507 | |
508 | fn reset(&mut self) { |
509 | self.pos = 0; |
510 | self.bytes.clear(); |
511 | } |
512 | } |
513 | |
514 | impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> { |
515 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
516 | f&mut DebugStruct<'_, '_>.debug_struct("Cursor" ) |
517 | .field("pos" , &self.pos) |
518 | .field(name:"len" , &self.bytes.as_ref().len()) |
519 | .finish() |
520 | } |
521 | } |
522 | |
523 | impl<T: AsRef<[u8]>> Buf for Cursor<T> { |
524 | #[inline ] |
525 | fn remaining(&self) -> usize { |
526 | self.bytes.as_ref().len() - self.pos |
527 | } |
528 | |
529 | #[inline ] |
530 | fn chunk(&self) -> &[u8] { |
531 | &self.bytes.as_ref()[self.pos..] |
532 | } |
533 | |
534 | #[inline ] |
535 | fn advance(&mut self, cnt: usize) { |
536 | debug_assert!(self.pos + cnt <= self.bytes.as_ref().len()); |
537 | self.pos += cnt; |
538 | } |
539 | } |
540 | |
541 | // an internal buffer to collect writes before flushes |
542 | pub(super) struct WriteBuf<B> { |
543 | /// Re-usable buffer that holds message headers |
544 | headers: Cursor<Vec<u8>>, |
545 | max_buf_size: usize, |
546 | /// Deque of user buffers if strategy is Queue |
547 | queue: BufList<B>, |
548 | strategy: WriteStrategy, |
549 | } |
550 | |
551 | impl<B: Buf> WriteBuf<B> { |
552 | fn new(strategy: WriteStrategy) -> WriteBuf<B> { |
553 | WriteBuf { |
554 | headers: Cursor::new(bytes:Vec::with_capacity(INIT_BUFFER_SIZE)), |
555 | max_buf_size: DEFAULT_MAX_BUFFER_SIZE, |
556 | queue: BufList::new(), |
557 | strategy, |
558 | } |
559 | } |
560 | } |
561 | |
562 | impl<B> WriteBuf<B> |
563 | where |
564 | B: Buf, |
565 | { |
566 | fn set_strategy(&mut self, strategy: WriteStrategy) { |
567 | self.strategy = strategy; |
568 | } |
569 | |
570 | pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) { |
571 | debug_assert!(buf.has_remaining()); |
572 | match self.strategy { |
573 | WriteStrategy::Flatten => { |
574 | let head = self.headers_mut(); |
575 | |
576 | head.maybe_unshift(buf.remaining()); |
577 | trace!( |
578 | self.len = head.remaining(), |
579 | buf.len = buf.remaining(), |
580 | "buffer.flatten" |
581 | ); |
582 | //perf: This is a little faster than <Vec as BufMut>>::put, |
583 | //but accomplishes the same result. |
584 | loop { |
585 | let adv = { |
586 | let slice = buf.chunk(); |
587 | if slice.is_empty() { |
588 | return; |
589 | } |
590 | head.bytes.extend_from_slice(slice); |
591 | slice.len() |
592 | }; |
593 | buf.advance(adv); |
594 | } |
595 | } |
596 | WriteStrategy::Queue => { |
597 | trace!( |
598 | self.len = self.remaining(), |
599 | buf.len = buf.remaining(), |
600 | "buffer.queue" |
601 | ); |
602 | self.queue.push(buf.into()); |
603 | } |
604 | } |
605 | } |
606 | |
607 | fn can_buffer(&self) -> bool { |
608 | match self.strategy { |
609 | WriteStrategy::Flatten => self.remaining() < self.max_buf_size, |
610 | WriteStrategy::Queue => { |
611 | self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size |
612 | } |
613 | } |
614 | } |
615 | |
616 | fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> { |
617 | debug_assert!(!self.queue.has_remaining()); |
618 | &mut self.headers |
619 | } |
620 | } |
621 | |
622 | impl<B: Buf> fmt::Debug for WriteBuf<B> { |
623 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
624 | f&mut DebugStruct<'_, '_>.debug_struct("WriteBuf" ) |
625 | .field("remaining" , &self.remaining()) |
626 | .field(name:"strategy" , &self.strategy) |
627 | .finish() |
628 | } |
629 | } |
630 | |
631 | impl<B: Buf> Buf for WriteBuf<B> { |
632 | #[inline ] |
633 | fn remaining(&self) -> usize { |
634 | self.headers.remaining() + self.queue.remaining() |
635 | } |
636 | |
637 | #[inline ] |
638 | fn chunk(&self) -> &[u8] { |
639 | let headers = self.headers.chunk(); |
640 | if !headers.is_empty() { |
641 | headers |
642 | } else { |
643 | self.queue.chunk() |
644 | } |
645 | } |
646 | |
647 | #[inline ] |
648 | fn advance(&mut self, cnt: usize) { |
649 | let hrem = self.headers.remaining(); |
650 | |
651 | match hrem.cmp(&cnt) { |
652 | cmp::Ordering::Equal => self.headers.reset(), |
653 | cmp::Ordering::Greater => self.headers.advance(cnt), |
654 | cmp::Ordering::Less => { |
655 | let qcnt = cnt - hrem; |
656 | self.headers.reset(); |
657 | self.queue.advance(qcnt); |
658 | } |
659 | } |
660 | } |
661 | |
662 | #[inline ] |
663 | fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { |
664 | let n = self.headers.chunks_vectored(dst); |
665 | self.queue.chunks_vectored(&mut dst[n..]) + n |
666 | } |
667 | } |
668 | |
/// How `WriteBuf::buffer` stores body buffers.
#[derive(Debug)]
enum WriteStrategy {
    /// Copy every buffer into the single headers buffer.
    Flatten,
    /// Keep each buffer separately, for vectored writes.
    Queue,
}
674 | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use tokio_test::io::Builder as Mock;

    // #[cfg(feature = "nightly")]
    // use test::Bencher;

    /*
    impl<T: Read> MemRead for AsyncIo<T> {
        fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
            let mut v = vec![0; len];
            let n = try_nb!(self.read(v.as_mut_slice()));
            Ok(Async::Ready(BytesMut::from(&v[..n]).freeze()))
        }
    }
    */

    #[tokio::test]
    #[ignore]
    async fn iobuf_write_empty_slice() {
        // TODO(eliza): can i have writev back pls T_T
        // // First, let's just check that the Mock would normally return an
        // // error on an unexpected write, even if the buffer is empty...
        // let mut mock = Mock::new().build();
        // futures_util::future::poll_fn(|cx| {
        //     Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
        // })
        // .await
        // .expect_err("should be a broken pipe");

        // // underlying io will return the logic error upon write,
        // // so we are testing that the io_buf does not trigger a write
        // // when there is nothing to flush
        // let mock = Mock::new().build();
        // let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        // io_buf.flush().await.expect("should short-circuit flush");
    }

    #[tokio::test]
    async fn parse_reads_until_blocked() {
        use crate::proto::h1::ClientTransaction;

        let _ = pretty_env_logger::try_init();
        let mock = Mock::new()
            // Split over multiple reads will read all of it
            .read(b"HTTP/1.1 200 OK\r\n")
            .read(b"Server: hyper\r\n")
            // missing last line ending
            .wait(Duration::from_secs(1))
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        // We expect a `parse` to be not ready, and so can't await it directly.
        // Rather, this `poll_fn` will wrap the `Poll` result.
        futures_util::future::poll_fn(|cx| {
            let parse_ctx = ParseContext {
                cached_headers: &mut None,
                req_method: &mut None,
                h1_parser_config: Default::default(),
                // The header-timeout fields only exist when both `server` and
                // `runtime` are enabled (same gate `Buffered::parse` uses);
                // gating on `runtime` alone breaks `client` + `runtime` builds.
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout: None,
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout_fut: &mut None,
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout_running: &mut false,
                preserve_header_case: false,
                #[cfg(feature = "ffi")]
                preserve_header_order: false,
                h09_responses: false,
                #[cfg(feature = "ffi")]
                on_informational: &mut None,
                #[cfg(feature = "ffi")]
                raw_headers: false,
            };
            assert!(buffered
                .parse::<ClientTransaction>(cx, parse_ctx)
                .is_pending());
            Poll::Ready(())
        })
        .await;

        assert_eq!(
            buffered.read_buf,
            b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..]
        );
    }

    #[test]
    fn read_strategy_adaptive_increments() {
        let mut strategy = ReadStrategy::default();
        assert_eq!(strategy.next(), 8192);

        // Grows if record == next
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(16384);
        assert_eq!(strategy.next(), 32768);

        // Enormous records still increment at same rate
        strategy.record(::std::usize::MAX);
        assert_eq!(strategy.next(), 65536);

        let max = strategy.max();
        while strategy.next() < max {
            strategy.record(max);
        }

        assert_eq!(strategy.next(), max, "never goes over max");
        strategy.record(max + 1);
        assert_eq!(strategy.next(), max, "never goes over max");
    }

    #[test]
    fn read_strategy_adaptive_decrements() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384, "record was with range");

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "in-range record should make this the 'first' again"
        );

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "second smaller record decrements");

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "first doesn't decrement");
        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum");
    }

    #[test]
    fn read_strategy_adaptive_stays_the_same() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(8193);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );

        strategy.record(8193);
        assert_eq!(
            strategy.next(),
            16384,
            "with current step does not decrement"
        );
    }

    #[test]
    fn read_strategy_adaptive_max_fuzz() {
        fn fuzz(max: usize) {
            let mut strategy = ReadStrategy::with_max(max);
            while strategy.next() < max {
                strategy.record(::std::usize::MAX);
            }
            let mut next = strategy.next();
            while next > 8192 {
                strategy.record(1);
                strategy.record(1);
                next = strategy.next();
                assert!(
                    next.is_power_of_two(),
                    "decrement should be powers of two: {} (max = {})",
                    next,
                    max,
                );
            }
        }

        let mut max = 8192;
        while max < std::usize::MAX {
            fuzz(max);
            max = (max / 2).saturating_mul(3);
        }
        fuzz(::std::usize::MAX);
    }

    #[test]
    #[should_panic]
    #[cfg(debug_assertions)] // needs to trigger a debug_assert
    fn write_buf_requires_non_empty_bufs() {
        let mock = Mock::new().build();
        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        buffered.buffer(Cursor::new(Vec::new()));
    }

    /*
    TODO: needs tokio_test::io to allow configure write_buf calls
    #[test]
    fn write_buf_queue() {
        let _ = pretty_env_logger::try_init();

        let mock = AsyncIo::new_buf(vec![], 1024);
        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);


        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);
        buffered.flush().unwrap();

        assert_eq!(buffered.io, b"hello world, it's hyper!");
        assert_eq!(buffered.io.num_writes(), 1);
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
    }
    */

    #[tokio::test]
    async fn write_buf_flatten() {
        let _ = pretty_env_logger::try_init();

        let mock = Mock::new().write(b"hello world, it's hyper!").build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        buffered.write_buf.set_strategy(WriteStrategy::Flatten);

        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);

        buffered.flush().await.expect("flush");
    }

    #[test]
    fn write_buf_flatten_partially_flushed() {
        let _ = pretty_env_logger::try_init();

        let b = |s: &str| Cursor::new(s.as_bytes().to_vec());

        let mut write_buf = WriteBuf::<Cursor<Vec<u8>>>::new(WriteStrategy::Flatten);

        write_buf.buffer(b("hello "));
        write_buf.buffer(b("world, "));

        assert_eq!(write_buf.chunk(), b"hello world, ");

        // advance most of the way, but not all
        write_buf.advance(11);

        assert_eq!(write_buf.chunk(), b", ");
        assert_eq!(write_buf.headers.pos, 11);
        assert_eq!(write_buf.headers.bytes.capacity(), INIT_BUFFER_SIZE);

        // there's still room in the headers buffer, so just push on the end
        write_buf.buffer(b("it's hyper!"));

        assert_eq!(write_buf.chunk(), b", it's hyper!");
        assert_eq!(write_buf.headers.pos, 11);

        let rem1 = write_buf.remaining();
        let cap = write_buf.headers.bytes.capacity();

        // but when this would go over capacity, don't copy the old bytes
        write_buf.buffer(Cursor::new(vec![b'X'; cap]));
        assert_eq!(write_buf.remaining(), cap + rem1);
        assert_eq!(write_buf.headers.pos, 0);
    }

    #[tokio::test]
    async fn write_buf_queue_disable_auto() {
        let _ = pretty_env_logger::try_init();

        let mock = Mock::new()
            .write(b"hello ")
            .write(b"world, ")
            .write(b"it's ")
            .write(b"hyper!")
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        buffered.write_buf.set_strategy(WriteStrategy::Queue);

        // we have 4 buffers, and vec IO disabled, but explicitly said
        // don't try to auto detect (via setting strategy above)

        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);

        buffered.flush().await.expect("flush");

        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
    }

    // #[cfg(feature = "nightly")]
    // #[bench]
    // fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) {
    //     let s = "Hello, World!";
    //     b.bytes = s.len() as u64;

    //     let mut write_buf = WriteBuf::<bytes::Bytes>::new();
    //     write_buf.set_strategy(WriteStrategy::Flatten);
    //     b.iter(|| {
    //         let chunk = bytes::Bytes::from(s);
    //         write_buf.buffer(chunk);
    //         ::test::black_box(&write_buf);
    //         write_buf.headers.bytes.clear();
    //     })
    // }
}
1003 | |