| 1 | use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| 2 | |
| 3 | use std::future::Future; |
| 4 | use std::io; |
| 5 | use std::pin::Pin; |
| 6 | use std::task::{ready, Context, Poll}; |
| 7 | |
| 8 | #[derive (Debug)] |
| 9 | pub(super) struct CopyBuffer { |
| 10 | read_done: bool, |
| 11 | need_flush: bool, |
| 12 | pos: usize, |
| 13 | cap: usize, |
| 14 | amt: u64, |
| 15 | buf: Box<[u8]>, |
| 16 | } |
| 17 | |
| 18 | impl CopyBuffer { |
| 19 | pub(super) fn new(buf_size: usize) -> Self { |
| 20 | Self { |
| 21 | read_done: false, |
| 22 | need_flush: false, |
| 23 | pos: 0, |
| 24 | cap: 0, |
| 25 | amt: 0, |
| 26 | buf: vec![0; buf_size].into_boxed_slice(), |
| 27 | } |
| 28 | } |
| 29 | |
| 30 | fn poll_fill_buf<R>( |
| 31 | &mut self, |
| 32 | cx: &mut Context<'_>, |
| 33 | reader: Pin<&mut R>, |
| 34 | ) -> Poll<io::Result<()>> |
| 35 | where |
| 36 | R: AsyncRead + ?Sized, |
| 37 | { |
| 38 | let me = &mut *self; |
| 39 | let mut buf = ReadBuf::new(&mut me.buf); |
| 40 | buf.set_filled(me.cap); |
| 41 | |
| 42 | let res = reader.poll_read(cx, &mut buf); |
| 43 | if let Poll::Ready(Ok(())) = res { |
| 44 | let filled_len = buf.filled().len(); |
| 45 | me.read_done = me.cap == filled_len; |
| 46 | me.cap = filled_len; |
| 47 | } |
| 48 | res |
| 49 | } |
| 50 | |
| 51 | fn poll_write_buf<R, W>( |
| 52 | &mut self, |
| 53 | cx: &mut Context<'_>, |
| 54 | mut reader: Pin<&mut R>, |
| 55 | mut writer: Pin<&mut W>, |
| 56 | ) -> Poll<io::Result<usize>> |
| 57 | where |
| 58 | R: AsyncRead + ?Sized, |
| 59 | W: AsyncWrite + ?Sized, |
| 60 | { |
| 61 | let me = &mut *self; |
| 62 | match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { |
| 63 | Poll::Pending => { |
| 64 | // Top up the buffer towards full if we can read a bit more |
| 65 | // data - this should improve the chances of a large write |
| 66 | if !me.read_done && me.cap < me.buf.len() { |
| 67 | ready!(me.poll_fill_buf(cx, reader.as_mut()))?; |
| 68 | } |
| 69 | Poll::Pending |
| 70 | } |
| 71 | res => res, |
| 72 | } |
| 73 | } |
| 74 | |
| 75 | pub(super) fn poll_copy<R, W>( |
| 76 | &mut self, |
| 77 | cx: &mut Context<'_>, |
| 78 | mut reader: Pin<&mut R>, |
| 79 | mut writer: Pin<&mut W>, |
| 80 | ) -> Poll<io::Result<u64>> |
| 81 | where |
| 82 | R: AsyncRead + ?Sized, |
| 83 | W: AsyncWrite + ?Sized, |
| 84 | { |
| 85 | ready!(crate::trace::trace_leaf(cx)); |
| 86 | #[cfg (any( |
| 87 | feature = "fs" , |
| 88 | feature = "io-std" , |
| 89 | feature = "net" , |
| 90 | feature = "process" , |
| 91 | feature = "rt" , |
| 92 | feature = "signal" , |
| 93 | feature = "sync" , |
| 94 | feature = "time" , |
| 95 | ))] |
| 96 | // Keep track of task budget |
| 97 | let coop = ready!(crate::runtime::coop::poll_proceed(cx)); |
| 98 | loop { |
| 99 | // If there is some space left in our buffer, then we try to read some |
| 100 | // data to continue, thus maximizing the chances of a large write. |
| 101 | if self.cap < self.buf.len() && !self.read_done { |
| 102 | match self.poll_fill_buf(cx, reader.as_mut()) { |
| 103 | Poll::Ready(Ok(())) => { |
| 104 | #[cfg (any( |
| 105 | feature = "fs" , |
| 106 | feature = "io-std" , |
| 107 | feature = "net" , |
| 108 | feature = "process" , |
| 109 | feature = "rt" , |
| 110 | feature = "signal" , |
| 111 | feature = "sync" , |
| 112 | feature = "time" , |
| 113 | ))] |
| 114 | coop.made_progress(); |
| 115 | } |
| 116 | Poll::Ready(Err(err)) => { |
| 117 | #[cfg (any( |
| 118 | feature = "fs" , |
| 119 | feature = "io-std" , |
| 120 | feature = "net" , |
| 121 | feature = "process" , |
| 122 | feature = "rt" , |
| 123 | feature = "signal" , |
| 124 | feature = "sync" , |
| 125 | feature = "time" , |
| 126 | ))] |
| 127 | coop.made_progress(); |
| 128 | return Poll::Ready(Err(err)); |
| 129 | } |
| 130 | Poll::Pending => { |
| 131 | // Ignore pending reads when our buffer is not empty, because |
| 132 | // we can try to write data immediately. |
| 133 | if self.pos == self.cap { |
| 134 | // Try flushing when the reader has no progress to avoid deadlock |
| 135 | // when the reader depends on buffered writer. |
| 136 | if self.need_flush { |
| 137 | ready!(writer.as_mut().poll_flush(cx))?; |
| 138 | #[cfg (any( |
| 139 | feature = "fs" , |
| 140 | feature = "io-std" , |
| 141 | feature = "net" , |
| 142 | feature = "process" , |
| 143 | feature = "rt" , |
| 144 | feature = "signal" , |
| 145 | feature = "sync" , |
| 146 | feature = "time" , |
| 147 | ))] |
| 148 | coop.made_progress(); |
| 149 | self.need_flush = false; |
| 150 | } |
| 151 | |
| 152 | return Poll::Pending; |
| 153 | } |
| 154 | } |
| 155 | } |
| 156 | } |
| 157 | |
| 158 | // If our buffer has some data, let's write it out! |
| 159 | while self.pos < self.cap { |
| 160 | let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; |
| 161 | #[cfg (any( |
| 162 | feature = "fs" , |
| 163 | feature = "io-std" , |
| 164 | feature = "net" , |
| 165 | feature = "process" , |
| 166 | feature = "rt" , |
| 167 | feature = "signal" , |
| 168 | feature = "sync" , |
| 169 | feature = "time" , |
| 170 | ))] |
| 171 | coop.made_progress(); |
| 172 | if i == 0 { |
| 173 | return Poll::Ready(Err(io::Error::new( |
| 174 | io::ErrorKind::WriteZero, |
| 175 | "write zero byte into writer" , |
| 176 | ))); |
| 177 | } else { |
| 178 | self.pos += i; |
| 179 | self.amt += i as u64; |
| 180 | self.need_flush = true; |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | // If pos larger than cap, this loop will never stop. |
| 185 | // In particular, user's wrong poll_write implementation returning |
| 186 | // incorrect written length may lead to thread blocking. |
| 187 | debug_assert!( |
| 188 | self.pos <= self.cap, |
| 189 | "writer returned length larger than input slice" |
| 190 | ); |
| 191 | |
| 192 | // All data has been written, the buffer can be considered empty again |
| 193 | self.pos = 0; |
| 194 | self.cap = 0; |
| 195 | |
| 196 | // If we've written all the data and we've seen EOF, flush out the |
| 197 | // data and finish the transfer. |
| 198 | if self.read_done { |
| 199 | ready!(writer.as_mut().poll_flush(cx))?; |
| 200 | #[cfg (any( |
| 201 | feature = "fs" , |
| 202 | feature = "io-std" , |
| 203 | feature = "net" , |
| 204 | feature = "process" , |
| 205 | feature = "rt" , |
| 206 | feature = "signal" , |
| 207 | feature = "sync" , |
| 208 | feature = "time" , |
| 209 | ))] |
| 210 | coop.made_progress(); |
| 211 | return Poll::Ready(Ok(self.amt)); |
| 212 | } |
| 213 | } |
| 214 | } |
| 215 | } |
| 216 | |
| 217 | /// A future that asynchronously copies the entire contents of a reader into a |
| 218 | /// writer. |
| 219 | #[derive (Debug)] |
| 220 | #[must_use = "futures do nothing unless you `.await` or poll them" ] |
| 221 | struct Copy<'a, R: ?Sized, W: ?Sized> { |
| 222 | reader: &'a mut R, |
| 223 | writer: &'a mut W, |
| 224 | buf: CopyBuffer, |
| 225 | } |
| 226 | |
| 227 | cfg_io_util! { |
| 228 | /// Asynchronously copies the entire contents of a reader into a writer. |
| 229 | /// |
| 230 | /// This function returns a future that will continuously read data from |
| 231 | /// `reader` and then write it into `writer` in a streaming fashion until |
| 232 | /// `reader` returns EOF or fails. |
| 233 | /// |
| 234 | /// On success, the total number of bytes that were copied from `reader` to |
| 235 | /// `writer` is returned. |
| 236 | /// |
| 237 | /// This is an asynchronous version of [`std::io::copy`][std]. |
| 238 | /// |
| 239 | /// A heap-allocated copy buffer with 8 KB is created to take data from the |
| 240 | /// reader to the writer, check [`copy_buf`] if you want an alternative for |
| 241 | /// [`AsyncBufRead`]. You can use `copy_buf` with [`BufReader`] to change the |
| 242 | /// buffer capacity. |
| 243 | /// |
| 244 | /// [std]: std::io::copy |
| 245 | /// [`copy_buf`]: crate::io::copy_buf |
| 246 | /// [`AsyncBufRead`]: crate::io::AsyncBufRead |
| 247 | /// [`BufReader`]: crate::io::BufReader |
| 248 | /// |
| 249 | /// # Errors |
| 250 | /// |
| 251 | /// The returned future will return an error immediately if any call to |
| 252 | /// `poll_read` or `poll_write` returns an error. |
| 253 | /// |
| 254 | /// # Examples |
| 255 | /// |
| 256 | /// ``` |
| 257 | /// use tokio::io; |
| 258 | /// |
| 259 | /// # async fn dox() -> std::io::Result<()> { |
| 260 | /// let mut reader: &[u8] = b"hello"; |
| 261 | /// let mut writer: Vec<u8> = vec![]; |
| 262 | /// |
| 263 | /// io::copy(&mut reader, &mut writer).await?; |
| 264 | /// |
| 265 | /// assert_eq!(&b"hello"[..], &writer[..]); |
| 266 | /// # Ok(()) |
| 267 | /// # } |
| 268 | /// ``` |
| 269 | pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> |
| 270 | where |
| 271 | R: AsyncRead + Unpin + ?Sized, |
| 272 | W: AsyncWrite + Unpin + ?Sized, |
| 273 | { |
| 274 | Copy { |
| 275 | reader, |
| 276 | writer, |
| 277 | buf: CopyBuffer::new(super::DEFAULT_BUF_SIZE) |
| 278 | }.await |
| 279 | } |
| 280 | } |
| 281 | |
| 282 | impl<R, W> Future for Copy<'_, R, W> |
| 283 | where |
| 284 | R: AsyncRead + Unpin + ?Sized, |
| 285 | W: AsyncWrite + Unpin + ?Sized, |
| 286 | { |
| 287 | type Output = io::Result<u64>; |
| 288 | |
| 289 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { |
| 290 | let me: &mut Copy<'_, R, W> = &mut *self; |
| 291 | |
| 292 | me.buf |
| 293 | .poll_copy(cx, reader:Pin::new(&mut *me.reader), writer:Pin::new(&mut *me.writer)) |
| 294 | } |
| 295 | } |
| 296 | |