1use crate::io::util::DEFAULT_BUF_SIZE;
2use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
3
4use pin_project_lite::pin_project;
5use std::fmt;
6use std::io::{self, IoSlice, SeekFrom, Write};
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
pin_project! {
    /// Wraps a writer and buffers its output.
    ///
    /// It can be excessively inefficient to work directly with something that
    /// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and
    /// writes it to an underlying writer in large, infrequent batches.
    ///
    /// `BufWriter` can improve the speed of programs that make *small* and
    /// *repeated* write calls to the same file or network socket. It does not
    /// help when writing very large amounts at once, or writing just one or a few
    /// times. It also provides no advantage when writing to a destination that is
    /// in memory, like a `Vec<u8>`.
    ///
    /// When the `BufWriter` is dropped, the contents of its buffer will be
    /// discarded. Creating multiple instances of a `BufWriter` on the same
    /// stream can cause data loss. If you need to write out the contents of its
    /// buffer, you must manually call [`flush`] before the writer is dropped.
    ///
    /// [`AsyncWrite`]: AsyncWrite
    /// [`flush`]: super::AsyncWriteExt::flush
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct BufWriter<W> {
        // The wrapped writer. It is structurally pinned so that its poll
        // methods can be called through a `Pin<&mut W>` projection.
        #[pin]
        pub(super) inner: W,
        // Bytes accepted from callers that have not yet been drained after
        // being written to `inner`.
        pub(super) buf: Vec<u8>,
        // Number of leading bytes of `buf` that a flush has already written
        // to `inner` but not yet drained from `buf`. It is reset to zero
        // whenever a flush finishes (successfully or with an error).
        pub(super) written: usize,
        // Progress of a seek requested through `start_seek`.
        pub(super) seek_state: SeekState,
    }
}
40
41impl<W: AsyncWrite> BufWriter<W> {
42 /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
43 /// but may change in the future.
44 pub fn new(inner: W) -> Self {
45 Self::with_capacity(DEFAULT_BUF_SIZE, inner)
46 }
47
48 /// Creates a new `BufWriter` with the specified buffer capacity.
49 pub fn with_capacity(cap: usize, inner: W) -> Self {
50 Self {
51 inner,
52 buf: Vec::with_capacity(cap),
53 written: 0,
54 seek_state: SeekState::Init,
55 }
56 }
57
58 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
59 let mut me = self.project();
60
61 let len = me.buf.len();
62 let mut ret = Ok(());
63 while *me.written < len {
64 match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) {
65 Ok(0) => {
66 ret = Err(io::Error::new(
67 io::ErrorKind::WriteZero,
68 "failed to write the buffered data",
69 ));
70 break;
71 }
72 Ok(n) => *me.written += n,
73 Err(e) => {
74 ret = Err(e);
75 break;
76 }
77 }
78 }
79 if *me.written > 0 {
80 me.buf.drain(..*me.written);
81 }
82 *me.written = 0;
83 Poll::Ready(ret)
84 }
85
86 /// Gets a reference to the underlying writer.
87 pub fn get_ref(&self) -> &W {
88 &self.inner
89 }
90
91 /// Gets a mutable reference to the underlying writer.
92 ///
93 /// It is inadvisable to directly write to the underlying writer.
94 pub fn get_mut(&mut self) -> &mut W {
95 &mut self.inner
96 }
97
98 /// Gets a pinned mutable reference to the underlying writer.
99 ///
100 /// It is inadvisable to directly write to the underlying writer.
101 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
102 self.project().inner
103 }
104
105 /// Consumes this `BufWriter`, returning the underlying writer.
106 ///
107 /// Note that any leftover data in the internal buffer is lost.
108 pub fn into_inner(self) -> W {
109 self.inner
110 }
111
112 /// Returns a reference to the internally buffered data.
113 pub fn buffer(&self) -> &[u8] {
114 &self.buf
115 }
116}
117
118impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
119 fn poll_write(
120 mut self: Pin<&mut Self>,
121 cx: &mut Context<'_>,
122 buf: &[u8],
123 ) -> Poll<io::Result<usize>> {
124 if self.buf.len() + buf.len() > self.buf.capacity() {
125 ready!(self.as_mut().flush_buf(cx))?;
126 }
127
128 let me = self.project();
129 if buf.len() >= me.buf.capacity() {
130 me.inner.poll_write(cx, buf)
131 } else {
132 Poll::Ready(me.buf.write(buf))
133 }
134 }
135
136 fn poll_write_vectored(
137 mut self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 mut bufs: &[IoSlice<'_>],
140 ) -> Poll<io::Result<usize>> {
141 if self.inner.is_write_vectored() {
142 let total_len = bufs
143 .iter()
144 .fold(0usize, |acc, b| acc.saturating_add(b.len()));
145 if total_len > self.buf.capacity() - self.buf.len() {
146 ready!(self.as_mut().flush_buf(cx))?;
147 }
148 let me = self.as_mut().project();
149 if total_len >= me.buf.capacity() {
150 // It's more efficient to pass the slices directly to the
151 // underlying writer than to buffer them.
152 // The case when the total_len calculation saturates at
153 // usize::MAX is also handled here.
154 me.inner.poll_write_vectored(cx, bufs)
155 } else {
156 bufs.iter().for_each(|b| me.buf.extend_from_slice(b));
157 Poll::Ready(Ok(total_len))
158 }
159 } else {
160 // Remove empty buffers at the beginning of bufs.
161 while bufs.first().map(|buf| buf.len()) == Some(0) {
162 bufs = &bufs[1..];
163 }
164 if bufs.is_empty() {
165 return Poll::Ready(Ok(0));
166 }
167 // Flush if the first buffer doesn't fit.
168 let first_len = bufs[0].len();
169 if first_len > self.buf.capacity() - self.buf.len() {
170 ready!(self.as_mut().flush_buf(cx))?;
171 debug_assert!(self.buf.is_empty());
172 }
173 let me = self.as_mut().project();
174 if first_len >= me.buf.capacity() {
175 // The slice is at least as large as the buffering capacity,
176 // so it's better to write it directly, bypassing the buffer.
177 debug_assert!(me.buf.is_empty());
178 return me.inner.poll_write(cx, &bufs[0]);
179 } else {
180 me.buf.extend_from_slice(&bufs[0]);
181 bufs = &bufs[1..];
182 }
183 let mut total_written = first_len;
184 debug_assert!(total_written != 0);
185 // Append the buffers that fit in the internal buffer.
186 for buf in bufs {
187 if buf.len() > me.buf.capacity() - me.buf.len() {
188 break;
189 } else {
190 me.buf.extend_from_slice(buf);
191 total_written += buf.len();
192 }
193 }
194 Poll::Ready(Ok(total_written))
195 }
196 }
197
198 fn is_write_vectored(&self) -> bool {
199 true
200 }
201
202 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
203 ready!(self.as_mut().flush_buf(cx))?;
204 self.get_pin_mut().poll_flush(cx)
205 }
206
207 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208 ready!(self.as_mut().flush_buf(cx))?;
209 self.get_pin_mut().poll_shutdown(cx)
210 }
211}
212
/// Progress of a seek requested on a `BufWriter` through `start_seek`.
#[derive(Debug, Clone, Copy)]
pub(super) enum SeekState {
    /// No seek is outstanding: `start_seek` has not been called, or the last
    /// seek has finished.
    Init,
    /// `start_seek` has recorded the target position, but the seek has not
    /// yet been completed or reported pending by the underlying writer.
    Start(SeekFrom),
    /// The seek was started on the underlying writer and its `poll_complete`
    /// returned `Poll::Pending`; waiting for it to complete.
    Pending,
}
222
223/// Seek to the offset, in bytes, in the underlying writer.
224///
225/// Seeking always writes out the internal buffer before seeking.
226impl<W: AsyncWrite + AsyncSeek> AsyncSeek for BufWriter<W> {
227 fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
228 // We need to flush the internal buffer before seeking.
229 // It receives a `Context` and returns a `Poll`, so it cannot be called
230 // inside `start_seek`.
231 *self.project().seek_state = SeekState::Start(pos);
232 Ok(())
233 }
234
235 fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
236 let pos = match self.seek_state {
237 SeekState::Init => {
238 return self.project().inner.poll_complete(cx);
239 }
240 SeekState::Start(pos) => Some(pos),
241 SeekState::Pending => None,
242 };
243
244 // Flush the internal buffer before seeking.
245 ready!(self.as_mut().flush_buf(cx))?;
246
247 let mut me = self.project();
248 if let Some(pos) = pos {
249 // Ensure previous seeks have finished before starting a new one
250 ready!(me.inner.as_mut().poll_complete(cx))?;
251 if let Err(e) = me.inner.as_mut().start_seek(pos) {
252 *me.seek_state = SeekState::Init;
253 return Poll::Ready(Err(e));
254 }
255 }
256 match me.inner.poll_complete(cx) {
257 Poll::Ready(res) => {
258 *me.seek_state = SeekState::Init;
259 Poll::Ready(res)
260 }
261 Poll::Pending => {
262 *me.seek_state = SeekState::Pending;
263 Poll::Pending
264 }
265 }
266 }
267}
268
269impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
270 fn poll_read(
271 self: Pin<&mut Self>,
272 cx: &mut Context<'_>,
273 buf: &mut ReadBuf<'_>,
274 ) -> Poll<io::Result<()>> {
275 self.get_pin_mut().poll_read(cx, buf)
276 }
277}
278
279impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
280 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
281 self.get_pin_mut().poll_fill_buf(cx)
282 }
283
284 fn consume(self: Pin<&mut Self>, amt: usize) {
285 self.get_pin_mut().consume(amt);
286 }
287}
288
289impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
290 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291 f.debug_struct("BufWriter")
292 .field("writer", &self.inner)
293 .field(
294 "buffer",
295 &format_args!("{}/{}", self.buf.len(), self.buf.capacity()),
296 )
297 .field("written", &self.written)
298 .finish()
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    // Checks that `BufWriter<()>` satisfies `crate::is_unpin`.
    // NOTE(review): `is_unpin` is defined elsewhere in the crate; presumably
    // it only compiles when its type argument is `Unpin` — confirm.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufWriter<()>>();
    }
}
311