1//! In-process memory IO types.
2
3use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
4use crate::loom::sync::Mutex;
5
6use bytes::{Buf, BytesMut};
7use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{self, Poll, Waker},
11};
12
13/// A bidirectional pipe to read and write bytes in memory.
14///
15/// A pair of `DuplexStream`s are created together, and they act as a "channel"
16/// that can be used as in-memory IO types. Writing to one of the pairs will
17/// allow that data to be read from the other, and vice versa.
18///
19/// # Closing a `DuplexStream`
20///
21/// If one end of the `DuplexStream` channel is dropped, any pending reads on
22/// the other side will continue to read data until the buffer is drained, then
23/// they will signal EOF by returning 0 bytes. Any writes to the other side,
24/// including pending ones (that are waiting for free space in the buffer) will
25/// return `Err(BrokenPipe)` immediately.
26///
27/// # Example
28///
29/// ```
30/// # async fn ex() -> std::io::Result<()> {
31/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
32/// let (mut client, mut server) = tokio::io::duplex(64);
33///
34/// client.write_all(b"ping").await?;
35///
36/// let mut buf = [0u8; 4];
37/// server.read_exact(&mut buf).await?;
38/// assert_eq!(&buf, b"ping");
39///
40/// server.write_all(b"pong").await?;
41///
42/// client.read_exact(&mut buf).await?;
43/// assert_eq!(&buf, b"pong");
44/// # Ok(())
45/// # }
46/// ```
47#[derive(Debug)]
48#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
49pub struct DuplexStream {
50 read: Arc<Mutex<Pipe>>,
51 write: Arc<Mutex<Pipe>>,
52}
53
54/// A unidirectional IO over a piece of memory.
55///
56/// Data can be written to the pipe, and reading will return that data.
57#[derive(Debug)]
58struct Pipe {
59 /// The buffer storing the bytes written, also read from.
60 ///
61 /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
62 /// functionality already. Additionally, it can try to copy data in the
63 /// same buffer if there read index has advanced far enough.
64 buffer: BytesMut,
65 /// Determines if the write side has been closed.
66 is_closed: bool,
67 /// The maximum amount of bytes that can be written before returning
68 /// `Poll::Pending`.
69 max_buf_size: usize,
70 /// If the `read` side has been polled and is pending, this is the waker
71 /// for that parked task.
72 read_waker: Option<Waker>,
73 /// If the `write` side has filled the `max_buf_size` and returned
74 /// `Poll::Pending`, this is the waker for that parked task.
75 write_waker: Option<Waker>,
76}
77
78// ===== impl DuplexStream =====
79
80/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
81///
82/// The `max_buf_size` argument is the maximum amount of bytes that can be
83/// written to a side before the write returns `Poll::Pending`.
84#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
85pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
86 let one: Arc> = Arc::new(data:Mutex::new(Pipe::new(max_buf_size)));
87 let two: Arc> = Arc::new(data:Mutex::new(Pipe::new(max_buf_size)));
88
89 (
90 DuplexStream {
91 read: one.clone(),
92 write: two.clone(),
93 },
94 DuplexStream {
95 read: two,
96 write: one,
97 },
98 )
99}
100
101impl AsyncRead for DuplexStream {
102 // Previous rustc required this `self` to be `mut`, even though newer
103 // versions recognize it isn't needed to call `lock()`. So for
104 // compatibility, we include the `mut` and `allow` the lint.
105 //
106 // See https://github.com/rust-lang/rust/issues/73592
107 #[allow(unused_mut)]
108 fn poll_read(
109 mut self: Pin<&mut Self>,
110 cx: &mut task::Context<'_>,
111 buf: &mut ReadBuf<'_>,
112 ) -> Poll<std::io::Result<()>> {
113 Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
114 }
115}
116
117impl AsyncWrite for DuplexStream {
118 #[allow(unused_mut)]
119 fn poll_write(
120 mut self: Pin<&mut Self>,
121 cx: &mut task::Context<'_>,
122 buf: &[u8],
123 ) -> Poll<std::io::Result<usize>> {
124 Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
125 }
126
127 #[allow(unused_mut)]
128 fn poll_flush(
129 mut self: Pin<&mut Self>,
130 cx: &mut task::Context<'_>,
131 ) -> Poll<std::io::Result<()>> {
132 Pin::new(&mut *self.write.lock()).poll_flush(cx)
133 }
134
135 #[allow(unused_mut)]
136 fn poll_shutdown(
137 mut self: Pin<&mut Self>,
138 cx: &mut task::Context<'_>,
139 ) -> Poll<std::io::Result<()>> {
140 Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
141 }
142}
143
144impl Drop for DuplexStream {
145 fn drop(&mut self) {
146 // notify the other side of the closure
147 self.write.lock().close_write();
148 self.read.lock().close_read();
149 }
150}
151
152// ===== impl Pipe =====
153
154impl Pipe {
155 fn new(max_buf_size: usize) -> Self {
156 Pipe {
157 buffer: BytesMut::new(),
158 is_closed: false,
159 max_buf_size,
160 read_waker: None,
161 write_waker: None,
162 }
163 }
164
165 fn close_write(&mut self) {
166 self.is_closed = true;
167 // needs to notify any readers that no more data will come
168 if let Some(waker) = self.read_waker.take() {
169 waker.wake();
170 }
171 }
172
173 fn close_read(&mut self) {
174 self.is_closed = true;
175 // needs to notify any writers that they have to abort
176 if let Some(waker) = self.write_waker.take() {
177 waker.wake();
178 }
179 }
180
181 fn poll_read_internal(
182 mut self: Pin<&mut Self>,
183 cx: &mut task::Context<'_>,
184 buf: &mut ReadBuf<'_>,
185 ) -> Poll<std::io::Result<()>> {
186 if self.buffer.has_remaining() {
187 let max = self.buffer.remaining().min(buf.remaining());
188 buf.put_slice(&self.buffer[..max]);
189 self.buffer.advance(max);
190 if max > 0 {
191 // The passed `buf` might have been empty, don't wake up if
192 // no bytes have been moved.
193 if let Some(waker) = self.write_waker.take() {
194 waker.wake();
195 }
196 }
197 Poll::Ready(Ok(()))
198 } else if self.is_closed {
199 Poll::Ready(Ok(()))
200 } else {
201 self.read_waker = Some(cx.waker().clone());
202 Poll::Pending
203 }
204 }
205
206 fn poll_write_internal(
207 mut self: Pin<&mut Self>,
208 cx: &mut task::Context<'_>,
209 buf: &[u8],
210 ) -> Poll<std::io::Result<usize>> {
211 if self.is_closed {
212 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
213 }
214 let avail = self.max_buf_size - self.buffer.len();
215 if avail == 0 {
216 self.write_waker = Some(cx.waker().clone());
217 return Poll::Pending;
218 }
219
220 let len = buf.len().min(avail);
221 self.buffer.extend_from_slice(&buf[..len]);
222 if let Some(waker) = self.read_waker.take() {
223 waker.wake();
224 }
225 Poll::Ready(Ok(len))
226 }
227}
228
229impl AsyncRead for Pipe {
230 cfg_coop! {
231 fn poll_read(
232 self: Pin<&mut Self>,
233 cx: &mut task::Context<'_>,
234 buf: &mut ReadBuf<'_>,
235 ) -> Poll<std::io::Result<()>> {
236 ready!(crate::trace::trace_leaf(cx));
237 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
238
239 let ret = self.poll_read_internal(cx, buf);
240 if ret.is_ready() {
241 coop.made_progress();
242 }
243 ret
244 }
245 }
246
247 cfg_not_coop! {
248 fn poll_read(
249 self: Pin<&mut Self>,
250 cx: &mut task::Context<'_>,
251 buf: &mut ReadBuf<'_>,
252 ) -> Poll<std::io::Result<()>> {
253 ready!(crate::trace::trace_leaf(cx));
254 self.poll_read_internal(cx, buf)
255 }
256 }
257}
258
259impl AsyncWrite for Pipe {
260 cfg_coop! {
261 fn poll_write(
262 self: Pin<&mut Self>,
263 cx: &mut task::Context<'_>,
264 buf: &[u8],
265 ) -> Poll<std::io::Result<usize>> {
266 ready!(crate::trace::trace_leaf(cx));
267 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
268
269 let ret = self.poll_write_internal(cx, buf);
270 if ret.is_ready() {
271 coop.made_progress();
272 }
273 ret
274 }
275 }
276
277 cfg_not_coop! {
278 fn poll_write(
279 self: Pin<&mut Self>,
280 cx: &mut task::Context<'_>,
281 buf: &[u8],
282 ) -> Poll<std::io::Result<usize>> {
283 ready!(crate::trace::trace_leaf(cx));
284 self.poll_write_internal(cx, buf)
285 }
286 }
287
288 fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
289 Poll::Ready(Ok(()))
290 }
291
292 fn poll_shutdown(
293 mut self: Pin<&mut Self>,
294 _: &mut task::Context<'_>,
295 ) -> Poll<std::io::Result<()>> {
296 self.close_write();
297 Poll::Ready(Ok(()))
298 }
299}
300