1use futures_core::ready;
2use futures_core::task::{Context, Poll};
3use futures_io::{AsyncBufRead, AsyncRead};
4use pin_project_lite::pin_project;
5use std::pin::Pin;
6use std::{cmp, io};
7
pin_project! {
    /// Reader for the [`take`](super::AsyncReadExt::take) method.
    ///
    /// Forwards reads to the wrapped reader, but reports EOF (a zero-length
    /// read or an empty buffer) once the remaining `limit` reaches zero.
    #[derive(Debug)]
    #[must_use = "readers do nothing unless you `.await` or poll them"]
    pub struct Take<R> {
        // The wrapped reader. Marked `#[pin]` so `project()` yields it as
        // `Pin<&mut R>`, which the poll methods forward to.
        #[pin]
        inner: R,
        // Number of bytes that may still be read before EOF is reported;
        // decremented by reads and `consume`, overwritten by `set_limit`.
        limit: u64,
    }
}
18
19impl<R: AsyncRead> Take<R> {
20 pub(super) fn new(inner: R, limit: u64) -> Self {
21 Self { inner, limit }
22 }
23
24 /// Returns the remaining number of bytes that can be
25 /// read before this instance will return EOF.
26 ///
27 /// # Note
28 ///
29 /// This instance may reach `EOF` after reading fewer bytes than indicated by
30 /// this method if the underlying [`AsyncRead`] instance reaches EOF.
31 ///
32 /// # Examples
33 ///
34 /// ```
35 /// # futures::executor::block_on(async {
36 /// use futures::io::{AsyncReadExt, Cursor};
37 ///
38 /// let reader = Cursor::new(&b"12345678"[..]);
39 /// let mut buffer = [0; 2];
40 ///
41 /// let mut take = reader.take(4);
42 /// let n = take.read(&mut buffer).await?;
43 ///
44 /// assert_eq!(take.limit(), 2);
45 /// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
46 /// ```
47 pub fn limit(&self) -> u64 {
48 self.limit
49 }
50
51 /// Sets the number of bytes that can be read before this instance will
52 /// return EOF. This is the same as constructing a new `Take` instance, so
53 /// the amount of bytes read and the previous limit value don't matter when
54 /// calling this method.
55 ///
56 /// # Examples
57 ///
58 /// ```
59 /// # futures::executor::block_on(async {
60 /// use futures::io::{AsyncReadExt, Cursor};
61 ///
62 /// let reader = Cursor::new(&b"12345678"[..]);
63 /// let mut buffer = [0; 4];
64 ///
65 /// let mut take = reader.take(4);
66 /// let n = take.read(&mut buffer).await?;
67 ///
68 /// assert_eq!(n, 4);
69 /// assert_eq!(take.limit(), 0);
70 ///
71 /// take.set_limit(10);
72 /// let n = take.read(&mut buffer).await?;
73 /// assert_eq!(n, 4);
74 ///
75 /// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
76 /// ```
77 pub fn set_limit(&mut self, limit: u64) {
78 self.limit = limit
79 }
80
81 delegate_access_inner!(inner, R, ());
82}
83
84impl<R: AsyncRead> AsyncRead for Take<R> {
85 fn poll_read(
86 self: Pin<&mut Self>,
87 cx: &mut Context<'_>,
88 buf: &mut [u8],
89 ) -> Poll<Result<usize, io::Error>> {
90 let this: Projection<'_, R> = self.project();
91
92 if *this.limit == 0 {
93 return Poll::Ready(Ok(0));
94 }
95
96 let max: usize = cmp::min(v1:buf.len() as u64, *this.limit) as usize;
97 let n: usize = ready!(this.inner.poll_read(cx, &mut buf[..max]))?;
98 *this.limit -= n as u64;
99 Poll::Ready(Ok(n))
100 }
101}
102
103impl<R: AsyncBufRead> AsyncBufRead for Take<R> {
104 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
105 let this: Projection<'_, R> = self.project();
106
107 // Don't call into inner reader at all at EOF because it may still block
108 if *this.limit == 0 {
109 return Poll::Ready(Ok(&[]));
110 }
111
112 let buf: &[u8] = ready!(this.inner.poll_fill_buf(cx)?);
113 let cap: usize = cmp::min(v1:buf.len() as u64, *this.limit) as usize;
114 Poll::Ready(Ok(&buf[..cap]))
115 }
116
117 fn consume(self: Pin<&mut Self>, amt: usize) {
118 let this: Projection<'_, R> = self.project();
119
120 // Don't let callers reset the limit by passing an overlarge value
121 let amt: usize = cmp::min(v1:amt as u64, *this.limit) as usize;
122 *this.limit -= amt as u64;
123 this.inner.consume(amt);
124 }
125}
126