1use crate::io::util::DEFAULT_BUF_SIZE;
2use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
3
4use pin_project_lite::pin_project;
5use std::io::{self, IoSlice, SeekFrom};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::{cmp, fmt, mem};
9
pin_project! {
    /// The `BufReader` struct adds buffering to any reader.
    ///
    /// It can be excessively inefficient to work directly with an [`AsyncRead`]
    /// instance. A `BufReader` performs large, infrequent reads on the underlying
    /// [`AsyncRead`] and maintains an in-memory buffer of the results.
    ///
    /// `BufReader` can improve the speed of programs that make *small* and
    /// *repeated* read calls to the same file or network socket. It does not
    /// help when reading very large amounts at once, or reading just one or a few
    /// times. It also provides no advantage when reading from a source that is
    /// already in memory, like a `Vec<u8>`.
    ///
    /// When the `BufReader` is dropped, the contents of its buffer will be
    /// discarded. Creating multiple instances of a `BufReader` on the same
    /// stream can cause data loss.
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct BufReader<R> {
        // The wrapped reader. Marked `#[pin]` so it can be reached as a
        // `Pin<&mut R>` through the generated projection.
        #[pin]
        pub(super) inner: R,
        // Backing storage for buffered bytes; allocated in `with_capacity`.
        pub(super) buf: Box<[u8]>,
        // Index into `buf` of the next byte to hand out.
        pub(super) pos: usize,
        // Index into `buf` one past the last valid byte, so `buf[pos..cap]`
        // is the data that has been read but not yet consumed.
        pub(super) cap: usize,
        // Progress of a seek requested through `AsyncSeek::start_seek`.
        pub(super) seek_state: SeekState,
    }
}
36
37impl<R: AsyncRead> BufReader<R> {
38 /// Creates a new `BufReader` with a default buffer capacity. The default is currently 8 KB,
39 /// but may change in the future.
40 pub fn new(inner: R) -> Self {
41 Self::with_capacity(DEFAULT_BUF_SIZE, inner)
42 }
43
44 /// Creates a new `BufReader` with the specified buffer capacity.
45 pub fn with_capacity(capacity: usize, inner: R) -> Self {
46 let buffer = vec![0; capacity];
47 Self {
48 inner,
49 buf: buffer.into_boxed_slice(),
50 pos: 0,
51 cap: 0,
52 seek_state: SeekState::Init,
53 }
54 }
55
56 /// Gets a reference to the underlying reader.
57 ///
58 /// It is inadvisable to directly read from the underlying reader.
59 pub fn get_ref(&self) -> &R {
60 &self.inner
61 }
62
63 /// Gets a mutable reference to the underlying reader.
64 ///
65 /// It is inadvisable to directly read from the underlying reader.
66 pub fn get_mut(&mut self) -> &mut R {
67 &mut self.inner
68 }
69
70 /// Gets a pinned mutable reference to the underlying reader.
71 ///
72 /// It is inadvisable to directly read from the underlying reader.
73 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
74 self.project().inner
75 }
76
77 /// Consumes this `BufReader`, returning the underlying reader.
78 ///
79 /// Note that any leftover data in the internal buffer is lost.
80 pub fn into_inner(self) -> R {
81 self.inner
82 }
83
84 /// Returns a reference to the internally buffered data.
85 ///
86 /// Unlike `fill_buf`, this will not attempt to fill the buffer if it is empty.
87 pub fn buffer(&self) -> &[u8] {
88 &self.buf[self.pos..self.cap]
89 }
90
91 /// Invalidates all data in the internal buffer.
92 #[inline]
93 fn discard_buffer(self: Pin<&mut Self>) {
94 let me = self.project();
95 *me.pos = 0;
96 *me.cap = 0;
97 }
98}
99
100impl<R: AsyncRead> AsyncRead for BufReader<R> {
101 fn poll_read(
102 mut self: Pin<&mut Self>,
103 cx: &mut Context<'_>,
104 buf: &mut ReadBuf<'_>,
105 ) -> Poll<io::Result<()>> {
106 // If we don't have any buffered data and we're doing a massive read
107 // (larger than our internal buffer), bypass our internal buffer
108 // entirely.
109 if self.pos == self.cap && buf.remaining() >= self.buf.len() {
110 let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf));
111 self.discard_buffer();
112 return Poll::Ready(res);
113 }
114 let rem = ready!(self.as_mut().poll_fill_buf(cx))?;
115 let amt = std::cmp::min(rem.len(), buf.remaining());
116 buf.put_slice(&rem[..amt]);
117 self.consume(amt);
118 Poll::Ready(Ok(()))
119 }
120}
121
122impl<R: AsyncRead> AsyncBufRead for BufReader<R> {
123 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
124 let me = self.project();
125
126 // If we've reached the end of our internal buffer then we need to fetch
127 // some more data from the underlying reader.
128 // Branch using `>=` instead of the more correct `==`
129 // to tell the compiler that the pos..cap slice is always valid.
130 if *me.pos >= *me.cap {
131 debug_assert!(*me.pos == *me.cap);
132 let mut buf = ReadBuf::new(me.buf);
133 ready!(me.inner.poll_read(cx, &mut buf))?;
134 *me.cap = buf.filled().len();
135 *me.pos = 0;
136 }
137 Poll::Ready(Ok(&me.buf[*me.pos..*me.cap]))
138 }
139
140 fn consume(self: Pin<&mut Self>, amt: usize) {
141 let me = self.project();
142 *me.pos = cmp::min(*me.pos + amt, *me.cap);
143 }
144}
145
/// Progress of a seek on a `BufReader`, kept between `start_seek` and the
/// `poll_complete` calls that drive it.
#[derive(Debug, Clone, Copy)]
pub(super) enum SeekState {
    /// `start_seek` has not been called.
    Init,
    /// `start_seek` has been called, but `poll_complete` has not yet been called.
    Start(SeekFrom),
    /// Waiting for completion of the first `poll_complete` in the
    /// `n.checked_sub(remainder).is_none()` branch. The payload is the
    /// originally requested `SeekFrom::Current` offset, needed for the
    /// second seek.
    PendingOverflowed(i64),
    /// Waiting for completion of `poll_complete`.
    Pending,
}
157
/// Seeks to an offset, in bytes, in the underlying reader.
///
/// The position used for seeking with `SeekFrom::Current(_)` is the
/// position the underlying reader would be at if the `BufReader` had no
/// internal buffer.
///
/// Seeking always discards the internal buffer, even if the seek position
/// would otherwise fall within it. This guarantees that calling
/// `.into_inner()` immediately after a seek yields the underlying reader
/// at the same position.
///
/// See [`AsyncSeek`] for more details.
///
/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)`
/// where `n` minus the internal buffer length overflows an `i64`, two
/// seeks will be performed instead of one. If the second seek returns
/// `Err`, the underlying reader will be left at the same position it would
/// have if you called `seek` with `SeekFrom::Current(0)`.
impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> {
    /// Records the requested seek target. The seek on the underlying reader
    /// is issued later, by `poll_complete`.
    fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
        // We may need to call the seek operation multiple times.
        // And we should always call both start_seek and poll_complete,
        // as start_seek alone cannot guarantee that the operation will be completed.
        // poll_complete receives a Context and returns a Poll, so it cannot be called
        // inside start_seek.
        *self.project().seek_state = SeekState::Start(pos);
        Ok(())
    }

    /// Drives the seek recorded by `start_seek` on the underlying reader and,
    /// once it completes, discards the internal buffer and returns the
    /// position reported by the underlying reader.
    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        // Take the current state and leave `Init` behind. Every path that
        // returns `Pending` stores the state to resume from before returning;
        // paths that exit early through `?` leave the state as `Init`.
        let res = match mem::replace(self.as_mut().project().seek_state, SeekState::Init) {
            SeekState::Init => {
                // 1.x AsyncSeek recommends calling poll_complete before start_seek.
                // We don't have to guarantee that the value returned by
                // poll_complete called without start_seek is correct,
                // so we'll return 0.
                return Poll::Ready(Ok(0));
            }
            SeekState::Start(SeekFrom::Current(n)) => {
                // Bytes that were read from the underlying reader but not yet
                // consumed; the underlying reader is ahead by this much.
                let remainder = (self.cap - self.pos) as i64;
                // it should be safe to assume that remainder fits within an i64 as the alternative
                // means we managed to allocate 8 exbibytes and that's absurd.
                // But it's not out of the realm of possibility for some weird underlying reader to
                // support seeking by i64::MIN so we need to handle underflow when subtracting
                // remainder.
                if let Some(offset) = n.checked_sub(remainder) {
                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(offset))?;
                } else {
                    // seek backwards by our remainder, and then by the offset
                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(-remainder))?;
                    if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() {
                        // Remember `n` so the second seek can be issued once
                        // the first one finishes.
                        *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n);
                        return Poll::Pending;
                    }

                    // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676
                    self.as_mut().discard_buffer();

                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(n))?;
                }
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::PendingOverflowed(n) => {
                // Resuming the two-step seek: finish the first (backwards)
                // seek, then issue the second one by the original offset.
                if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() {
                    *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n);
                    return Poll::Pending;
                }

                // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676
                self.as_mut().discard_buffer();

                self.as_mut()
                    .get_pin_mut()
                    .start_seek(SeekFrom::Current(n))?;
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::Start(pos) => {
                // Seeking with Start/End doesn't care about our buffer length.
                self.as_mut().get_pin_mut().start_seek(pos)?;
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            // The seek was already issued; just keep polling it.
            SeekState::Pending => self.as_mut().get_pin_mut().poll_complete(cx)?,
        };

        match res {
            Poll::Ready(res) => {
                // The buffered bytes no longer correspond to the reader's
                // position, so drop them.
                self.discard_buffer();
                Poll::Ready(Ok(res))
            }
            Poll::Pending => {
                *self.as_mut().project().seek_state = SeekState::Pending;
                Poll::Pending
            }
        }
    }
}
260
261impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
262 fn poll_write(
263 self: Pin<&mut Self>,
264 cx: &mut Context<'_>,
265 buf: &[u8],
266 ) -> Poll<io::Result<usize>> {
267 self.get_pin_mut().poll_write(cx, buf)
268 }
269
270 fn poll_write_vectored(
271 self: Pin<&mut Self>,
272 cx: &mut Context<'_>,
273 bufs: &[IoSlice<'_>],
274 ) -> Poll<io::Result<usize>> {
275 self.get_pin_mut().poll_write_vectored(cx, bufs)
276 }
277
278 fn is_write_vectored(&self) -> bool {
279 self.get_ref().is_write_vectored()
280 }
281
282 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
283 self.get_pin_mut().poll_flush(cx)
284 }
285
286 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
287 self.get_pin_mut().poll_shutdown(cx)
288 }
289}
290
291impl<R: fmt::Debug> fmt::Debug for BufReader<R> {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 f.debug_struct("BufReader")
294 .field("reader", &self.inner)
295 .field(
296 "buffer",
297 &format_args!("{}/{}", self.cap - self.pos, self.buf.len()),
298 )
299 .finish()
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::BufReader;

    // Compile-time check: a `BufReader` over an `Unpin` inner type must itself
    // be `Unpin`.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufReader<()>>();
    }
}
312