//! In-process memory IO types.

use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::loom::sync::Mutex;

use bytes::{Buf, BytesMut};
use std::{
    pin::Pin,
    sync::Arc,
    task::{self, Poll, Waker},
};

/// A bidirectional pipe to read and write bytes in memory.
///
/// `DuplexStream`s are created in pairs and act as a "channel" that can be
/// used as in-memory IO types. Writing to one half of the pair allows that
/// data to be read from the other half, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer),
/// will return `Err(BrokenPipe)` immediately.
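///
/// For example, a short sketch of the behavior described above (the buffer
/// size and payloads here are arbitrary):
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"bye").await?;
/// drop(client);
///
/// // Data already in the buffer can still be read...
/// let mut buf = [0u8; 3];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"bye");
///
/// // ...after which the stream signals EOF by reading 0 bytes.
/// assert_eq!(server.read(&mut buf).await?, 0);
///
/// // Writing toward the dropped end fails with `BrokenPipe`.
/// let err = server.write_all(b"ping").await.unwrap_err();
/// assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
/// # Ok(())
/// # }
/// ```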
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct DuplexStream {
    read: Arc<Mutex<Pipe>>,
    write: Arc<Mutex<Pipe>>,
}

/// A unidirectional IO over a piece of memory.
///
/// Data can be written to the pipe, and reading will return that data.
#[derive(Debug)]
struct Pipe {
    /// The buffer storing the bytes written, also read from.
    ///
    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
    /// functionality already. Additionally, it can try to copy data in the
    /// same buffer if the read index has advanced far enough.
    buffer: BytesMut,
    /// Determines if the write side has been closed.
    is_closed: bool,
    /// The maximum number of bytes that can be written before returning
    /// `Poll::Pending`.
    max_buf_size: usize,
    /// If the `read` side has been polled and is pending, this is the waker
    /// for that parked task.
    read_waker: Option<Waker>,
    /// If the `write` side has filled the `max_buf_size` and returned
    /// `Poll::Pending`, this is the waker for that parked task.
    write_waker: Option<Waker>,
}

// ===== impl DuplexStream =====

/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
///
/// The `max_buf_size` argument is the maximum number of bytes that can be
/// written to a side before the write returns `Poll::Pending`.
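///
/// A minimal sketch of this limit (the buffer size `4` below is arbitrary):
/// a single `write` accepts at most `max_buf_size` bytes until the other end
/// reads some of them.
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::AsyncWriteExt;
/// let (mut tx, _rx) = tokio::io::duplex(4);
///
/// // A larger write completes partially, reporting how many bytes fit.
/// let written = tx.write(b"hello world").await?;
/// assert_eq!(written, 4);
/// # Ok(())
/// # }
/// ```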
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
    let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
    let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));

    (
        DuplexStream {
            read: one.clone(),
            write: two.clone(),
        },
        DuplexStream {
            read: two,
            write: one,
        },
    )
}

impl AsyncRead for DuplexStream {
    // Previous rustc required this `self` to be `mut`, even though newer
    // versions recognize it isn't needed to call `lock()`. So for
    // compatibility, we include the `mut` and `allow` the lint.
    //
    // See https://github.com/rust-lang/rust/issues/73592
    #[allow(unused_mut)]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
    }
}

impl AsyncWrite for DuplexStream {
    #[allow(unused_mut)]
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        true
    }

    #[allow(unused_mut)]
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self.write.lock()).poll_flush(cx)
    }

    #[allow(unused_mut)]
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
    }
}

impl Drop for DuplexStream {
    fn drop(&mut self) {
        // notify the other side of the closure
        self.write.lock().close_write();
        self.read.lock().close_read();
    }
}

// ===== impl Pipe =====

impl Pipe {
    fn new(max_buf_size: usize) -> Self {
        Pipe {
            buffer: BytesMut::new(),
            is_closed: false,
            max_buf_size,
            read_waker: None,
            write_waker: None,
        }
    }

    fn close_write(&mut self) {
        self.is_closed = true;
        // needs to notify any readers that no more data will come
        if let Some(waker) = self.read_waker.take() {
            waker.wake();
        }
    }

    fn close_read(&mut self) {
        self.is_closed = true;
        // needs to notify any writers that they have to abort
        if let Some(waker) = self.write_waker.take() {
            waker.wake();
        }
    }

    fn poll_read_internal(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if self.buffer.has_remaining() {
            let max = self.buffer.remaining().min(buf.remaining());
            buf.put_slice(&self.buffer[..max]);
            self.buffer.advance(max);
            if max > 0 {
                // The passed `buf` might have been empty, don't wake up if
                // no bytes have been moved.
                if let Some(waker) = self.write_waker.take() {
                    waker.wake();
                }
            }
            Poll::Ready(Ok(()))
        } else if self.is_closed {
            Poll::Ready(Ok(()))
        } else {
            self.read_waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }

    fn poll_write_internal(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        if self.is_closed {
            return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
        }
        let avail = self.max_buf_size - self.buffer.len();
        if avail == 0 {
            self.write_waker = Some(cx.waker().clone());
            return Poll::Pending;
        }

        let len = buf.len().min(avail);
        self.buffer.extend_from_slice(&buf[..len]);
        if let Some(waker) = self.read_waker.take() {
            waker.wake();
        }
        Poll::Ready(Ok(len))
    }

    fn poll_write_vectored_internal(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        if self.is_closed {
            return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
        }
        let avail = self.max_buf_size - self.buffer.len();
        if avail == 0 {
            self.write_waker = Some(cx.waker().clone());
            return Poll::Pending;
        }

        let mut rem = avail;
        for buf in bufs {
            if rem == 0 {
                break;
            }

            let len = buf.len().min(rem);
            self.buffer.extend_from_slice(&buf[..len]);
            rem -= len;
        }

        if let Some(waker) = self.read_waker.take() {
            waker.wake();
        }
        Poll::Ready(Ok(avail - rem))
    }
}

impl AsyncRead for Pipe {
    cfg_coop! {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            ready!(crate::trace::trace_leaf(cx));
            let coop = ready!(crate::runtime::coop::poll_proceed(cx));

            let ret = self.poll_read_internal(cx, buf);
            if ret.is_ready() {
                coop.made_progress();
            }
            ret
        }
    }

    cfg_not_coop! {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            ready!(crate::trace::trace_leaf(cx));
            self.poll_read_internal(cx, buf)
        }
    }
}

impl AsyncWrite for Pipe {
    cfg_coop! {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            ready!(crate::trace::trace_leaf(cx));
            let coop = ready!(crate::runtime::coop::poll_proceed(cx));

            let ret = self.poll_write_internal(cx, buf);
            if ret.is_ready() {
                coop.made_progress();
            }
            ret
        }
    }

    cfg_not_coop! {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            ready!(crate::trace::trace_leaf(cx));
            self.poll_write_internal(cx, buf)
        }
    }

    cfg_coop! {
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> Poll<Result<usize, std::io::Error>> {
            ready!(crate::trace::trace_leaf(cx));
            let coop = ready!(crate::runtime::coop::poll_proceed(cx));

            let ret = self.poll_write_vectored_internal(cx, bufs);
            if ret.is_ready() {
                coop.made_progress();
            }
            ret
        }
    }

    cfg_not_coop! {
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
            bufs: &[std::io::IoSlice<'_>],
        ) -> Poll<Result<usize, std::io::Error>> {
            ready!(crate::trace::trace_leaf(cx));
            self.poll_write_vectored_internal(cx, bufs)
        }
    }

    fn is_write_vectored(&self) -> bool {
        true
    }

    fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        _: &mut task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.close_write();
        Poll::Ready(Ok(()))
    }
}