1use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
2
3use std::future::Future;
4use std::io;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8#[derive(Debug)]
9pub(super) struct CopyBuffer {
10 read_done: bool,
11 need_flush: bool,
12 pos: usize,
13 cap: usize,
14 amt: u64,
15 buf: Box<[u8]>,
16}
17
18impl CopyBuffer {
19 pub(super) fn new() -> Self {
20 Self {
21 read_done: false,
22 need_flush: false,
23 pos: 0,
24 cap: 0,
25 amt: 0,
26 buf: vec![0; super::DEFAULT_BUF_SIZE].into_boxed_slice(),
27 }
28 }
29
30 fn poll_fill_buf<R>(
31 &mut self,
32 cx: &mut Context<'_>,
33 reader: Pin<&mut R>,
34 ) -> Poll<io::Result<()>>
35 where
36 R: AsyncRead + ?Sized,
37 {
38 let me = &mut *self;
39 let mut buf = ReadBuf::new(&mut me.buf);
40 buf.set_filled(me.cap);
41
42 let res = reader.poll_read(cx, &mut buf);
43 if let Poll::Ready(Ok(())) = res {
44 let filled_len = buf.filled().len();
45 me.read_done = me.cap == filled_len;
46 me.cap = filled_len;
47 }
48 res
49 }
50
51 fn poll_write_buf<R, W>(
52 &mut self,
53 cx: &mut Context<'_>,
54 mut reader: Pin<&mut R>,
55 mut writer: Pin<&mut W>,
56 ) -> Poll<io::Result<usize>>
57 where
58 R: AsyncRead + ?Sized,
59 W: AsyncWrite + ?Sized,
60 {
61 let me = &mut *self;
62 match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
63 Poll::Pending => {
64 // Top up the buffer towards full if we can read a bit more
65 // data - this should improve the chances of a large write
66 if !me.read_done && me.cap < me.buf.len() {
67 ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
68 }
69 Poll::Pending
70 }
71 res => res,
72 }
73 }
74
75 pub(super) fn poll_copy<R, W>(
76 &mut self,
77 cx: &mut Context<'_>,
78 mut reader: Pin<&mut R>,
79 mut writer: Pin<&mut W>,
80 ) -> Poll<io::Result<u64>>
81 where
82 R: AsyncRead + ?Sized,
83 W: AsyncWrite + ?Sized,
84 {
85 ready!(crate::trace::trace_leaf(cx));
86 #[cfg(any(
87 feature = "fs",
88 feature = "io-std",
89 feature = "net",
90 feature = "process",
91 feature = "rt",
92 feature = "signal",
93 feature = "sync",
94 feature = "time",
95 ))]
96 // Keep track of task budget
97 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
98 loop {
99 // If our buffer is empty, then we need to read some data to
100 // continue.
101 if self.pos == self.cap && !self.read_done {
102 self.pos = 0;
103 self.cap = 0;
104
105 match self.poll_fill_buf(cx, reader.as_mut()) {
106 Poll::Ready(Ok(())) => {
107 #[cfg(any(
108 feature = "fs",
109 feature = "io-std",
110 feature = "net",
111 feature = "process",
112 feature = "rt",
113 feature = "signal",
114 feature = "sync",
115 feature = "time",
116 ))]
117 coop.made_progress();
118 }
119 Poll::Ready(Err(err)) => {
120 #[cfg(any(
121 feature = "fs",
122 feature = "io-std",
123 feature = "net",
124 feature = "process",
125 feature = "rt",
126 feature = "signal",
127 feature = "sync",
128 feature = "time",
129 ))]
130 coop.made_progress();
131 return Poll::Ready(Err(err));
132 }
133 Poll::Pending => {
134 // Try flushing when the reader has no progress to avoid deadlock
135 // when the reader depends on buffered writer.
136 if self.need_flush {
137 ready!(writer.as_mut().poll_flush(cx))?;
138 #[cfg(any(
139 feature = "fs",
140 feature = "io-std",
141 feature = "net",
142 feature = "process",
143 feature = "rt",
144 feature = "signal",
145 feature = "sync",
146 feature = "time",
147 ))]
148 coop.made_progress();
149 self.need_flush = false;
150 }
151
152 return Poll::Pending;
153 }
154 }
155 }
156
157 // If our buffer has some data, let's write it out!
158 while self.pos < self.cap {
159 let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
160 #[cfg(any(
161 feature = "fs",
162 feature = "io-std",
163 feature = "net",
164 feature = "process",
165 feature = "rt",
166 feature = "signal",
167 feature = "sync",
168 feature = "time",
169 ))]
170 coop.made_progress();
171 if i == 0 {
172 return Poll::Ready(Err(io::Error::new(
173 io::ErrorKind::WriteZero,
174 "write zero byte into writer",
175 )));
176 } else {
177 self.pos += i;
178 self.amt += i as u64;
179 self.need_flush = true;
180 }
181 }
182
183 // If pos larger than cap, this loop will never stop.
184 // In particular, user's wrong poll_write implementation returning
185 // incorrect written length may lead to thread blocking.
186 debug_assert!(
187 self.pos <= self.cap,
188 "writer returned length larger than input slice"
189 );
190
191 // If we've written all the data and we've seen EOF, flush out the
192 // data and finish the transfer.
193 if self.pos == self.cap && self.read_done {
194 ready!(writer.as_mut().poll_flush(cx))?;
195 #[cfg(any(
196 feature = "fs",
197 feature = "io-std",
198 feature = "net",
199 feature = "process",
200 feature = "rt",
201 feature = "signal",
202 feature = "sync",
203 feature = "time",
204 ))]
205 coop.made_progress();
206 return Poll::Ready(Ok(self.amt));
207 }
208 }
209 }
210}
211
212/// A future that asynchronously copies the entire contents of a reader into a
213/// writer.
214#[derive(Debug)]
215#[must_use = "futures do nothing unless you `.await` or poll them"]
216struct Copy<'a, R: ?Sized, W: ?Sized> {
217 reader: &'a mut R,
218 writer: &'a mut W,
219 buf: CopyBuffer,
220}
221
222cfg_io_util! {
223 /// Asynchronously copies the entire contents of a reader into a writer.
224 ///
225 /// This function returns a future that will continuously read data from
226 /// `reader` and then write it into `writer` in a streaming fashion until
227 /// `reader` returns EOF or fails.
228 ///
229 /// On success, the total number of bytes that were copied from `reader` to
230 /// `writer` is returned.
231 ///
232 /// This is an asynchronous version of [`std::io::copy`][std].
233 ///
234 /// A heap-allocated copy buffer with 8 KB is created to take data from the
235 /// reader to the writer, check [`copy_buf`] if you want an alternative for
236 /// [`AsyncBufRead`]. You can use `copy_buf` with [`BufReader`] to change the
237 /// buffer capacity.
238 ///
239 /// [std]: std::io::copy
240 /// [`copy_buf`]: crate::io::copy_buf
241 /// [`AsyncBufRead`]: crate::io::AsyncBufRead
242 /// [`BufReader`]: crate::io::BufReader
243 ///
244 /// # Errors
245 ///
246 /// The returned future will return an error immediately if any call to
247 /// `poll_read` or `poll_write` returns an error.
248 ///
249 /// # Examples
250 ///
251 /// ```
252 /// use tokio::io;
253 ///
254 /// # async fn dox() -> std::io::Result<()> {
255 /// let mut reader: &[u8] = b"hello";
256 /// let mut writer: Vec<u8> = vec![];
257 ///
258 /// io::copy(&mut reader, &mut writer).await?;
259 ///
260 /// assert_eq!(&b"hello"[..], &writer[..]);
261 /// # Ok(())
262 /// # }
263 /// ```
264 pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
265 where
266 R: AsyncRead + Unpin + ?Sized,
267 W: AsyncWrite + Unpin + ?Sized,
268 {
269 Copy {
270 reader,
271 writer,
272 buf: CopyBuffer::new()
273 }.await
274 }
275}
276
277impl<R, W> Future for Copy<'_, R, W>
278where
279 R: AsyncRead + Unpin + ?Sized,
280 W: AsyncWrite + Unpin + ?Sized,
281{
282 type Output = io::Result<u64>;
283
284 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
285 let me: &mut Copy<'_, R, W> = &mut *self;
286
287 me.buf
288 .poll_copy(cx, reader:Pin::new(&mut *me.reader), writer:Pin::new(&mut *me.writer))
289 }
290}
291