1//! In-process memory IO types.
2
3use crate::io::{split, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
4use crate::loom::sync::Mutex;
5
6use bytes::{Buf, BytesMut};
7use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{self, ready, Poll, Waker},
11};
12
/// A bidirectional pipe to read and write bytes in memory.
///
/// A pair of `DuplexStream`s is created together by [`duplex`], and the two
/// ends act as a "channel" that can be used as in-memory IO types. Writing to
/// one end of the pair makes that data available for reading from the other
/// end, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer), will
/// return `Err(BrokenPipe)` immediately.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct DuplexStream {
    /// The shared buffer this end reads from; the peer end writes into it.
    read: Arc<Mutex<SimplexStream>>,
    /// The shared buffer this end writes into; the peer end reads from it.
    write: Arc<Mutex<SimplexStream>>,
}
53
/// A unidirectional pipe to read and write bytes in memory.
///
/// It can be constructed with the [`simplex`] function, which creates a
/// separate reader and writer pair, or by calling
/// [`SimplexStream::new_unsplit`], which creates a single handle for both
/// reading and writing.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut receiver, mut sender) = tokio::io::simplex(64);
///
/// sender.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// receiver.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct SimplexStream {
    /// The buffer storing the bytes written, also read from.
    ///
    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
    /// functionality already. Additionally, it can try to copy data in the
    /// same buffer if the read index has advanced far enough.
    buffer: BytesMut,
    /// Whether the stream has been closed, from either side.
    ///
    /// Set by both `close_write` and `close_read`. Once set, reads drain any
    /// remaining buffered bytes and then report EOF, and writes fail with
    /// `BrokenPipe`.
    is_closed: bool,
    /// The maximum number of unread bytes the buffer may hold. Writes return
    /// `Poll::Pending` while the buffer is full.
    max_buf_size: usize,
    /// If the `read` side has been polled and is pending, this is the waker
    /// for that parked task.
    read_waker: Option<Waker>,
    /// If the `write` side has filled the `max_buf_size` and returned
    /// `Poll::Pending`, this is the waker for that parked task.
    write_waker: Option<Waker>,
}
96
97// ===== impl DuplexStream =====
98
99/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
100///
101/// The `max_buf_size` argument is the maximum amount of bytes that can be
102/// written to a side before the write returns `Poll::Pending`.
103#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
104pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
105 let one: Arc> = Arc::new(data:Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
106 let two: Arc> = Arc::new(data:Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
107
108 (
109 DuplexStream {
110 read: one.clone(),
111 write: two.clone(),
112 },
113 DuplexStream {
114 read: two,
115 write: one,
116 },
117 )
118}
119
120impl AsyncRead for DuplexStream {
121 // Previous rustc required this `self` to be `mut`, even though newer
122 // versions recognize it isn't needed to call `lock()`. So for
123 // compatibility, we include the `mut` and `allow` the lint.
124 //
125 // See https://github.com/rust-lang/rust/issues/73592
126 #[allow(unused_mut)]
127 fn poll_read(
128 mut self: Pin<&mut Self>,
129 cx: &mut task::Context<'_>,
130 buf: &mut ReadBuf<'_>,
131 ) -> Poll<std::io::Result<()>> {
132 Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
133 }
134}
135
136impl AsyncWrite for DuplexStream {
137 #[allow(unused_mut)]
138 fn poll_write(
139 mut self: Pin<&mut Self>,
140 cx: &mut task::Context<'_>,
141 buf: &[u8],
142 ) -> Poll<std::io::Result<usize>> {
143 Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
144 }
145
146 fn poll_write_vectored(
147 self: Pin<&mut Self>,
148 cx: &mut task::Context<'_>,
149 bufs: &[std::io::IoSlice<'_>],
150 ) -> Poll<Result<usize, std::io::Error>> {
151 Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
152 }
153
154 fn is_write_vectored(&self) -> bool {
155 true
156 }
157
158 #[allow(unused_mut)]
159 fn poll_flush(
160 mut self: Pin<&mut Self>,
161 cx: &mut task::Context<'_>,
162 ) -> Poll<std::io::Result<()>> {
163 Pin::new(&mut *self.write.lock()).poll_flush(cx)
164 }
165
166 #[allow(unused_mut)]
167 fn poll_shutdown(
168 mut self: Pin<&mut Self>,
169 cx: &mut task::Context<'_>,
170 ) -> Poll<std::io::Result<()>> {
171 Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
172 }
173}
174
175impl Drop for DuplexStream {
176 fn drop(&mut self) {
177 // notify the other side of the closure
178 self.write.lock().close_write();
179 self.read.lock().close_read();
180 }
181}
182
183// ===== impl SimplexStream =====
184
185/// Creates unidirectional buffer that acts like in memory pipe.
186///
187/// The `max_buf_size` argument is the maximum amount of bytes that can be
188/// written to a buffer before the it returns `Poll::Pending`.
189///
190/// # Unify reader and writer
191///
192/// The reader and writer half can be unified into a single structure
193/// of `SimplexStream` that supports both reading and writing or
194/// the `SimplexStream` can be already created as unified structure
195/// using [`SimplexStream::new_unsplit()`].
196///
197/// ```
198/// # async fn ex() -> std::io::Result<()> {
199/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
200/// let (reader, writer) = tokio::io::simplex(64);
201/// let mut simplex_stream = reader.unsplit(writer);
202/// simplex_stream.write_all(b"hello").await?;
203///
204/// let mut buf = [0u8; 5];
205/// simplex_stream.read_exact(&mut buf).await?;
206/// assert_eq!(&buf, b"hello");
207/// # Ok(())
208/// # }
209/// ```
210#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
211pub fn simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>) {
212 split(stream:SimplexStream::new_unsplit(max_buf_size))
213}
214
215impl SimplexStream {
216 /// Creates unidirectional buffer that acts like in memory pipe. To create split
217 /// version with separate reader and writer you can use [`simplex`] function.
218 ///
219 /// The `max_buf_size` argument is the maximum amount of bytes that can be
220 /// written to a buffer before the it returns `Poll::Pending`.
221 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
222 pub fn new_unsplit(max_buf_size: usize) -> SimplexStream {
223 SimplexStream {
224 buffer: BytesMut::new(),
225 is_closed: false,
226 max_buf_size,
227 read_waker: None,
228 write_waker: None,
229 }
230 }
231
232 fn close_write(&mut self) {
233 self.is_closed = true;
234 // needs to notify any readers that no more data will come
235 if let Some(waker) = self.read_waker.take() {
236 waker.wake();
237 }
238 }
239
240 fn close_read(&mut self) {
241 self.is_closed = true;
242 // needs to notify any writers that they have to abort
243 if let Some(waker) = self.write_waker.take() {
244 waker.wake();
245 }
246 }
247
248 fn poll_read_internal(
249 mut self: Pin<&mut Self>,
250 cx: &mut task::Context<'_>,
251 buf: &mut ReadBuf<'_>,
252 ) -> Poll<std::io::Result<()>> {
253 if self.buffer.has_remaining() {
254 let max = self.buffer.remaining().min(buf.remaining());
255 buf.put_slice(&self.buffer[..max]);
256 self.buffer.advance(max);
257 if max > 0 {
258 // The passed `buf` might have been empty, don't wake up if
259 // no bytes have been moved.
260 if let Some(waker) = self.write_waker.take() {
261 waker.wake();
262 }
263 }
264 Poll::Ready(Ok(()))
265 } else if self.is_closed {
266 Poll::Ready(Ok(()))
267 } else {
268 self.read_waker = Some(cx.waker().clone());
269 Poll::Pending
270 }
271 }
272
273 fn poll_write_internal(
274 mut self: Pin<&mut Self>,
275 cx: &mut task::Context<'_>,
276 buf: &[u8],
277 ) -> Poll<std::io::Result<usize>> {
278 if self.is_closed {
279 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
280 }
281 let avail = self.max_buf_size - self.buffer.len();
282 if avail == 0 {
283 self.write_waker = Some(cx.waker().clone());
284 return Poll::Pending;
285 }
286
287 let len = buf.len().min(avail);
288 self.buffer.extend_from_slice(&buf[..len]);
289 if let Some(waker) = self.read_waker.take() {
290 waker.wake();
291 }
292 Poll::Ready(Ok(len))
293 }
294
295 fn poll_write_vectored_internal(
296 mut self: Pin<&mut Self>,
297 cx: &mut task::Context<'_>,
298 bufs: &[std::io::IoSlice<'_>],
299 ) -> Poll<Result<usize, std::io::Error>> {
300 if self.is_closed {
301 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
302 }
303 let avail = self.max_buf_size - self.buffer.len();
304 if avail == 0 {
305 self.write_waker = Some(cx.waker().clone());
306 return Poll::Pending;
307 }
308
309 let mut rem = avail;
310 for buf in bufs {
311 if rem == 0 {
312 break;
313 }
314
315 let len = buf.len().min(rem);
316 self.buffer.extend_from_slice(&buf[..len]);
317 rem -= len;
318 }
319
320 if let Some(waker) = self.read_waker.take() {
321 waker.wake();
322 }
323 Poll::Ready(Ok(avail - rem))
324 }
325}
326
327impl AsyncRead for SimplexStream {
328 cfg_coop! {
329 fn poll_read(
330 self: Pin<&mut Self>,
331 cx: &mut task::Context<'_>,
332 buf: &mut ReadBuf<'_>,
333 ) -> Poll<std::io::Result<()>> {
334 ready!(crate::trace::trace_leaf(cx));
335 let coop = ready!(crate::task::coop::poll_proceed(cx));
336
337 let ret = self.poll_read_internal(cx, buf);
338 if ret.is_ready() {
339 coop.made_progress();
340 }
341 ret
342 }
343 }
344
345 cfg_not_coop! {
346 fn poll_read(
347 self: Pin<&mut Self>,
348 cx: &mut task::Context<'_>,
349 buf: &mut ReadBuf<'_>,
350 ) -> Poll<std::io::Result<()>> {
351 ready!(crate::trace::trace_leaf(cx));
352 self.poll_read_internal(cx, buf)
353 }
354 }
355}
356
357impl AsyncWrite for SimplexStream {
358 cfg_coop! {
359 fn poll_write(
360 self: Pin<&mut Self>,
361 cx: &mut task::Context<'_>,
362 buf: &[u8],
363 ) -> Poll<std::io::Result<usize>> {
364 ready!(crate::trace::trace_leaf(cx));
365 let coop = ready!(crate::task::coop::poll_proceed(cx));
366
367 let ret = self.poll_write_internal(cx, buf);
368 if ret.is_ready() {
369 coop.made_progress();
370 }
371 ret
372 }
373 }
374
375 cfg_not_coop! {
376 fn poll_write(
377 self: Pin<&mut Self>,
378 cx: &mut task::Context<'_>,
379 buf: &[u8],
380 ) -> Poll<std::io::Result<usize>> {
381 ready!(crate::trace::trace_leaf(cx));
382 self.poll_write_internal(cx, buf)
383 }
384 }
385
386 cfg_coop! {
387 fn poll_write_vectored(
388 self: Pin<&mut Self>,
389 cx: &mut task::Context<'_>,
390 bufs: &[std::io::IoSlice<'_>],
391 ) -> Poll<Result<usize, std::io::Error>> {
392 ready!(crate::trace::trace_leaf(cx));
393 let coop = ready!(crate::task::coop::poll_proceed(cx));
394
395 let ret = self.poll_write_vectored_internal(cx, bufs);
396 if ret.is_ready() {
397 coop.made_progress();
398 }
399 ret
400 }
401 }
402
403 cfg_not_coop! {
404 fn poll_write_vectored(
405 self: Pin<&mut Self>,
406 cx: &mut task::Context<'_>,
407 bufs: &[std::io::IoSlice<'_>],
408 ) -> Poll<Result<usize, std::io::Error>> {
409 ready!(crate::trace::trace_leaf(cx));
410 self.poll_write_vectored_internal(cx, bufs)
411 }
412 }
413
414 fn is_write_vectored(&self) -> bool {
415 true
416 }
417
418 fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
419 Poll::Ready(Ok(()))
420 }
421
422 fn poll_shutdown(
423 mut self: Pin<&mut Self>,
424 _: &mut task::Context<'_>,
425 ) -> Poll<std::io::Result<()>> {
426 self.close_write();
427 Poll::Ready(Ok(()))
428 }
429}
430