1 | use crate::io::util::DEFAULT_BUF_SIZE; |
2 | use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; |
3 | |
4 | use pin_project_lite::pin_project; |
5 | use std::fmt; |
6 | use std::io::{self, IoSlice, SeekFrom, Write}; |
7 | use std::pin::Pin; |
8 | use std::task::{Context, Poll}; |
9 | |
10 | pin_project! { |
11 | /// Wraps a writer and buffers its output. |
12 | /// |
13 | /// It can be excessively inefficient to work directly with something that |
14 | /// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and |
15 | /// writes it to an underlying writer in large, infrequent batches. |
16 | /// |
17 | /// `BufWriter` can improve the speed of programs that make *small* and |
18 | /// *repeated* write calls to the same file or network socket. It does not |
19 | /// help when writing very large amounts at once, or writing just one or a few |
20 | /// times. It also provides no advantage when writing to a destination that is |
21 | /// in memory, like a `Vec<u8>`. |
22 | /// |
23 | /// When the `BufWriter` is dropped, the contents of its buffer will be |
24 | /// discarded. Creating multiple instances of a `BufWriter` on the same |
25 | /// stream can cause data loss. If you need to write out the contents of its |
26 | /// buffer, you must manually call flush before the writer is dropped. |
27 | /// |
28 | /// [`AsyncWrite`]: AsyncWrite |
29 | /// [`flush`]: super::AsyncWriteExt::flush |
30 | /// |
31 | #[cfg_attr(docsrs, doc(cfg(feature = "io-util" )))] |
32 | pub struct BufWriter<W> { |
33 | #[pin] |
34 | pub(super) inner: W, |
35 | pub(super) buf: Vec<u8>, |
36 | pub(super) written: usize, |
37 | pub(super) seek_state: SeekState, |
38 | } |
39 | } |
40 | |
41 | impl<W: AsyncWrite> BufWriter<W> { |
42 | /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, |
43 | /// but may change in the future. |
44 | pub fn new(inner: W) -> Self { |
45 | Self::with_capacity(DEFAULT_BUF_SIZE, inner) |
46 | } |
47 | |
48 | /// Creates a new `BufWriter` with the specified buffer capacity. |
49 | pub fn with_capacity(cap: usize, inner: W) -> Self { |
50 | Self { |
51 | inner, |
52 | buf: Vec::with_capacity(cap), |
53 | written: 0, |
54 | seek_state: SeekState::Init, |
55 | } |
56 | } |
57 | |
58 | fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
59 | let mut me = self.project(); |
60 | |
61 | let len = me.buf.len(); |
62 | let mut ret = Ok(()); |
63 | while *me.written < len { |
64 | match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) { |
65 | Ok(0) => { |
66 | ret = Err(io::Error::new( |
67 | io::ErrorKind::WriteZero, |
68 | "failed to write the buffered data" , |
69 | )); |
70 | break; |
71 | } |
72 | Ok(n) => *me.written += n, |
73 | Err(e) => { |
74 | ret = Err(e); |
75 | break; |
76 | } |
77 | } |
78 | } |
79 | if *me.written > 0 { |
80 | me.buf.drain(..*me.written); |
81 | } |
82 | *me.written = 0; |
83 | Poll::Ready(ret) |
84 | } |
85 | |
86 | /// Gets a reference to the underlying writer. |
87 | pub fn get_ref(&self) -> &W { |
88 | &self.inner |
89 | } |
90 | |
91 | /// Gets a mutable reference to the underlying writer. |
92 | /// |
93 | /// It is inadvisable to directly write to the underlying writer. |
94 | pub fn get_mut(&mut self) -> &mut W { |
95 | &mut self.inner |
96 | } |
97 | |
98 | /// Gets a pinned mutable reference to the underlying writer. |
99 | /// |
100 | /// It is inadvisable to directly write to the underlying writer. |
101 | pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { |
102 | self.project().inner |
103 | } |
104 | |
105 | /// Consumes this `BufWriter`, returning the underlying writer. |
106 | /// |
107 | /// Note that any leftover data in the internal buffer is lost. |
108 | pub fn into_inner(self) -> W { |
109 | self.inner |
110 | } |
111 | |
112 | /// Returns a reference to the internally buffered data. |
113 | pub fn buffer(&self) -> &[u8] { |
114 | &self.buf |
115 | } |
116 | } |
117 | |
118 | impl<W: AsyncWrite> AsyncWrite for BufWriter<W> { |
119 | fn poll_write( |
120 | mut self: Pin<&mut Self>, |
121 | cx: &mut Context<'_>, |
122 | buf: &[u8], |
123 | ) -> Poll<io::Result<usize>> { |
124 | if self.buf.len() + buf.len() > self.buf.capacity() { |
125 | ready!(self.as_mut().flush_buf(cx))?; |
126 | } |
127 | |
128 | let me = self.project(); |
129 | if buf.len() >= me.buf.capacity() { |
130 | me.inner.poll_write(cx, buf) |
131 | } else { |
132 | Poll::Ready(me.buf.write(buf)) |
133 | } |
134 | } |
135 | |
136 | fn poll_write_vectored( |
137 | mut self: Pin<&mut Self>, |
138 | cx: &mut Context<'_>, |
139 | mut bufs: &[IoSlice<'_>], |
140 | ) -> Poll<io::Result<usize>> { |
141 | if self.inner.is_write_vectored() { |
142 | let total_len = bufs |
143 | .iter() |
144 | .fold(0usize, |acc, b| acc.saturating_add(b.len())); |
145 | if total_len > self.buf.capacity() - self.buf.len() { |
146 | ready!(self.as_mut().flush_buf(cx))?; |
147 | } |
148 | let me = self.as_mut().project(); |
149 | if total_len >= me.buf.capacity() { |
150 | // It's more efficient to pass the slices directly to the |
151 | // underlying writer than to buffer them. |
152 | // The case when the total_len calculation saturates at |
153 | // usize::MAX is also handled here. |
154 | me.inner.poll_write_vectored(cx, bufs) |
155 | } else { |
156 | bufs.iter().for_each(|b| me.buf.extend_from_slice(b)); |
157 | Poll::Ready(Ok(total_len)) |
158 | } |
159 | } else { |
160 | // Remove empty buffers at the beginning of bufs. |
161 | while bufs.first().map(|buf| buf.len()) == Some(0) { |
162 | bufs = &bufs[1..]; |
163 | } |
164 | if bufs.is_empty() { |
165 | return Poll::Ready(Ok(0)); |
166 | } |
167 | // Flush if the first buffer doesn't fit. |
168 | let first_len = bufs[0].len(); |
169 | if first_len > self.buf.capacity() - self.buf.len() { |
170 | ready!(self.as_mut().flush_buf(cx))?; |
171 | debug_assert!(self.buf.is_empty()); |
172 | } |
173 | let me = self.as_mut().project(); |
174 | if first_len >= me.buf.capacity() { |
175 | // The slice is at least as large as the buffering capacity, |
176 | // so it's better to write it directly, bypassing the buffer. |
177 | debug_assert!(me.buf.is_empty()); |
178 | return me.inner.poll_write(cx, &bufs[0]); |
179 | } else { |
180 | me.buf.extend_from_slice(&bufs[0]); |
181 | bufs = &bufs[1..]; |
182 | } |
183 | let mut total_written = first_len; |
184 | debug_assert!(total_written != 0); |
185 | // Append the buffers that fit in the internal buffer. |
186 | for buf in bufs { |
187 | if buf.len() > me.buf.capacity() - me.buf.len() { |
188 | break; |
189 | } else { |
190 | me.buf.extend_from_slice(buf); |
191 | total_written += buf.len(); |
192 | } |
193 | } |
194 | Poll::Ready(Ok(total_written)) |
195 | } |
196 | } |
197 | |
198 | fn is_write_vectored(&self) -> bool { |
199 | true |
200 | } |
201 | |
202 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
203 | ready!(self.as_mut().flush_buf(cx))?; |
204 | self.get_pin_mut().poll_flush(cx) |
205 | } |
206 | |
207 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
208 | ready!(self.as_mut().flush_buf(cx))?; |
209 | self.get_pin_mut().poll_shutdown(cx) |
210 | } |
211 | } |
212 | |
213 | #[derive(Debug, Clone, Copy)] |
214 | pub(super) enum SeekState { |
215 | /// start_seek has not been called. |
216 | Init, |
217 | /// start_seek has been called, but poll_complete has not yet been called. |
218 | Start(SeekFrom), |
219 | /// Waiting for completion of poll_complete. |
220 | Pending, |
221 | } |
222 | |
223 | /// Seek to the offset, in bytes, in the underlying writer. |
224 | /// |
225 | /// Seeking always writes out the internal buffer before seeking. |
226 | impl<W: AsyncWrite + AsyncSeek> AsyncSeek for BufWriter<W> { |
227 | fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { |
228 | // We need to flush the internal buffer before seeking. |
229 | // It receives a `Context` and returns a `Poll`, so it cannot be called |
230 | // inside `start_seek`. |
231 | *self.project().seek_state = SeekState::Start(pos); |
232 | Ok(()) |
233 | } |
234 | |
235 | fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { |
236 | let pos = match self.seek_state { |
237 | SeekState::Init => { |
238 | return self.project().inner.poll_complete(cx); |
239 | } |
240 | SeekState::Start(pos) => Some(pos), |
241 | SeekState::Pending => None, |
242 | }; |
243 | |
244 | // Flush the internal buffer before seeking. |
245 | ready!(self.as_mut().flush_buf(cx))?; |
246 | |
247 | let mut me = self.project(); |
248 | if let Some(pos) = pos { |
249 | // Ensure previous seeks have finished before starting a new one |
250 | ready!(me.inner.as_mut().poll_complete(cx))?; |
251 | if let Err(e) = me.inner.as_mut().start_seek(pos) { |
252 | *me.seek_state = SeekState::Init; |
253 | return Poll::Ready(Err(e)); |
254 | } |
255 | } |
256 | match me.inner.poll_complete(cx) { |
257 | Poll::Ready(res) => { |
258 | *me.seek_state = SeekState::Init; |
259 | Poll::Ready(res) |
260 | } |
261 | Poll::Pending => { |
262 | *me.seek_state = SeekState::Pending; |
263 | Poll::Pending |
264 | } |
265 | } |
266 | } |
267 | } |
268 | |
269 | impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> { |
270 | fn poll_read( |
271 | self: Pin<&mut Self>, |
272 | cx: &mut Context<'_>, |
273 | buf: &mut ReadBuf<'_>, |
274 | ) -> Poll<io::Result<()>> { |
275 | self.get_pin_mut().poll_read(cx, buf) |
276 | } |
277 | } |
278 | |
279 | impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> { |
280 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
281 | self.get_pin_mut().poll_fill_buf(cx) |
282 | } |
283 | |
284 | fn consume(self: Pin<&mut Self>, amt: usize) { |
285 | self.get_pin_mut().consume(amt); |
286 | } |
287 | } |
288 | |
289 | impl<W: fmt::Debug> fmt::Debug for BufWriter<W> { |
290 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
291 | f.debug_struct("BufWriter" ) |
292 | .field("writer" , &self.inner) |
293 | .field( |
294 | "buffer" , |
295 | &format_args!("{}/{}" , self.buf.len(), self.buf.capacity()), |
296 | ) |
297 | .field("written" , &self.written) |
298 | .finish() |
299 | } |
300 | } |
301 | |
#[cfg(test)]
mod tests {
    use super::BufWriter;

    // Only `inner` is `#[pin]`-projected; with an `Unpin` writer such as `()`,
    // `BufWriter` itself must be `Unpin`.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufWriter<()>>();
    }
}
311 | |