1use super::copy::CopyBuffer;
2
3use crate::future::poll_fn;
4use crate::io::{AsyncRead, AsyncWrite};
5
6use std::io;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
/// Progress of copying in a single direction (reader -> writer).
///
/// The state only moves forward: `Running` -> `ShuttingDown` -> `Done`.
enum TransferState {
    /// Data is still being copied through this buffer.
    Running(CopyBuffer),
    /// Copying finished; the `u64` is the byte count copied. Waiting for
    /// the writer's shutdown to complete.
    ShuttingDown(u64),
    /// The writer has been shut down; the `u64` is the final byte count.
    Done(u64),
}
15
16fn transfer_one_direction<A, B>(
17 cx: &mut Context<'_>,
18 state: &mut TransferState,
19 r: &mut A,
20 w: &mut B,
21) -> Poll<io::Result<u64>>
22where
23 A: AsyncRead + AsyncWrite + Unpin + ?Sized,
24 B: AsyncRead + AsyncWrite + Unpin + ?Sized,
25{
26 let mut r = Pin::new(r);
27 let mut w = Pin::new(w);
28
29 loop {
30 match state {
31 TransferState::Running(buf) => {
32 let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
33 *state = TransferState::ShuttingDown(count);
34 }
35 TransferState::ShuttingDown(count) => {
36 ready!(w.as_mut().poll_shutdown(cx))?;
37
38 *state = TransferState::Done(*count);
39 }
40 TransferState::Done(count) => return Poll::Ready(Ok(*count)),
41 }
42 }
43}
44/// Copies data in both directions between `a` and `b`.
45///
46/// This function returns a future that will read from both streams,
47/// writing any data read to the opposing stream.
48/// This happens in both directions concurrently.
49///
50/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on
51/// the other, and reading from that stream will stop. Copying of data in
52/// the other direction will continue.
53///
54/// The future will complete successfully once both directions of communication has been shut down.
55/// A direction is shut down when the reader reports EOF,
56/// at which point [`shutdown()`] is called on the corresponding writer. When finished,
57/// it will return a tuple of the number of bytes copied from a to b
58/// and the number of bytes copied from b to a, in that order.
59///
60/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown
61///
62/// # Errors
63///
64/// The future will immediately return an error if any IO operation on `a`
65/// or `b` returns an error. Some data read from either stream may be lost (not
66/// written to the other stream) in this case.
67///
68/// # Return value
69///
70/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`.
71#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
72pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error>
73where
74 A: AsyncRead + AsyncWrite + Unpin + ?Sized,
75 B: AsyncRead + AsyncWrite + Unpin + ?Sized,
76{
77 let mut a_to_b = TransferState::Running(CopyBuffer::new());
78 let mut b_to_a = TransferState::Running(CopyBuffer::new());
79 poll_fn(|cx| {
80 let a_to_b = transfer_one_direction(cx, &mut a_to_b, a, b)?;
81 let b_to_a = transfer_one_direction(cx, &mut b_to_a, b, a)?;
82
83 // It is not a problem if ready! returns early because transfer_one_direction for the
84 // other direction will keep returning TransferState::Done(count) in future calls to poll
85 let a_to_b = ready!(a_to_b);
86 let b_to_a = ready!(b_to_a);
87
88 Poll::Ready(Ok((a_to_b, b_to_a)))
89 })
90 .await
91}
92