1//! `TcpStream` owned split support.
2//!
//! A `TcpStream` can be split into an `OwnedReadHalf` and an `OwnedWriteHalf`
4//! with the `TcpStream::into_split` method. `OwnedReadHalf` implements
5//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
6//!
7//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
8//! split has no associated overhead and enforces all invariants at the type
9//! level.
10
11use crate::future::poll_fn;
12use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
13use crate::net::TcpStream;
14
15use std::error::Error;
16use std::net::{Shutdown, SocketAddr};
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::{fmt, io};
21
22cfg_io_util! {
23 use bytes::BufMut;
24}
25
26/// Owned read half of a [`TcpStream`], created by [`into_split`].
27///
28/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found
29/// on the [`AsyncReadExt`] trait.
30///
31/// [`TcpStream`]: TcpStream
32/// [`into_split`]: TcpStream::into_split()
33/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
34#[derive(Debug)]
35pub struct OwnedReadHalf {
36 inner: Arc<TcpStream>,
37}
38
39/// Owned write half of a [`TcpStream`], created by [`into_split`].
40///
41/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
42/// shut down the TCP stream in the write direction. Dropping the write half
43/// will also shut down the write half of the TCP stream.
44///
45/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found
46/// on the [`AsyncWriteExt`] trait.
47///
48/// [`TcpStream`]: TcpStream
49/// [`into_split`]: TcpStream::into_split()
50/// [`AsyncWrite`]: trait@crate::io::AsyncWrite
51/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
52/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
53#[derive(Debug)]
54pub struct OwnedWriteHalf {
55 inner: Arc<TcpStream>,
56 shutdown_on_drop: bool,
57}
58
59pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
60 let arc = Arc::new(stream);
61 let read = OwnedReadHalf {
62 inner: Arc::clone(&arc),
63 };
64 let write = OwnedWriteHalf {
65 inner: arc,
66 shutdown_on_drop: true,
67 };
68 (read, write)
69}
70
71pub(crate) fn reunite(
72 read: OwnedReadHalf,
73 write: OwnedWriteHalf,
74) -> Result<TcpStream, ReuniteError> {
75 if Arc::ptr_eq(&read.inner, &write.inner) {
76 write.forget();
77 // This unwrap cannot fail as the api does not allow creating more than two Arcs,
78 // and we just dropped the other half.
79 Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite"))
80 } else {
81 Err(ReuniteError(read, write))
82 }
83}
84
85/// Error indicating that two halves were not from the same socket, and thus could
86/// not be reunited.
87#[derive(Debug)]
88pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
89
90impl fmt::Display for ReuniteError {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 write!(
93 f,
94 "tried to reunite halves that are not from the same socket"
95 )
96 }
97}
98
99impl Error for ReuniteError {}
100
101impl OwnedReadHalf {
102 /// Attempts to put the two halves of a `TcpStream` back together and
103 /// recover the original socket. Succeeds only if the two halves
104 /// originated from the same call to [`into_split`].
105 ///
106 /// [`into_split`]: TcpStream::into_split()
107 pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
108 reunite(self, other)
109 }
110
111 /// Attempt to receive data on the socket, without removing that data from
112 /// the queue, registering the current task for wakeup if data is not yet
113 /// available.
114 ///
115 /// Note that on multiple calls to `poll_peek` or `poll_read`, only the
116 /// `Waker` from the `Context` passed to the most recent call is scheduled
117 /// to receive a wakeup.
118 ///
119 /// See the [`TcpStream::poll_peek`] level documentation for more details.
120 ///
121 /// # Examples
122 ///
123 /// ```no_run
124 /// use tokio::io::{self, ReadBuf};
125 /// use tokio::net::TcpStream;
126 ///
127 /// use futures::future::poll_fn;
128 ///
129 /// #[tokio::main]
130 /// async fn main() -> io::Result<()> {
131 /// let stream = TcpStream::connect("127.0.0.1:8000").await?;
132 /// let (mut read_half, _) = stream.into_split();
133 /// let mut buf = [0; 10];
134 /// let mut buf = ReadBuf::new(&mut buf);
135 ///
136 /// poll_fn(|cx| {
137 /// read_half.poll_peek(cx, &mut buf)
138 /// }).await?;
139 ///
140 /// Ok(())
141 /// }
142 /// ```
143 ///
144 /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
145 pub fn poll_peek(
146 &mut self,
147 cx: &mut Context<'_>,
148 buf: &mut ReadBuf<'_>,
149 ) -> Poll<io::Result<usize>> {
150 self.inner.poll_peek(cx, buf)
151 }
152
153 /// Receives data on the socket from the remote address to which it is
154 /// connected, without removing that data from the queue. On success,
155 /// returns the number of bytes peeked.
156 ///
157 /// See the [`TcpStream::peek`] level documentation for more details.
158 ///
159 /// [`TcpStream::peek`]: TcpStream::peek
160 ///
161 /// # Examples
162 ///
163 /// ```no_run
164 /// use tokio::net::TcpStream;
165 /// use tokio::io::AsyncReadExt;
166 /// use std::error::Error;
167 ///
168 /// #[tokio::main]
169 /// async fn main() -> Result<(), Box<dyn Error>> {
170 /// // Connect to a peer
171 /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
172 /// let (mut read_half, _) = stream.into_split();
173 ///
174 /// let mut b1 = [0; 10];
175 /// let mut b2 = [0; 10];
176 ///
177 /// // Peek at the data
178 /// let n = read_half.peek(&mut b1).await?;
179 ///
180 /// // Read the data
181 /// assert_eq!(n, read_half.read(&mut b2[..n]).await?);
182 /// assert_eq!(&b1[..n], &b2[..n]);
183 ///
184 /// Ok(())
185 /// }
186 /// ```
187 ///
188 /// The [`read`] method is defined on the [`AsyncReadExt`] trait.
189 ///
190 /// [`read`]: fn@crate::io::AsyncReadExt::read
191 /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
192 pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
193 let mut buf = ReadBuf::new(buf);
194 poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
195 }
196
197 /// Waits for any of the requested ready states.
198 ///
199 /// This function is usually paired with [`try_read()`]. It can be used instead
200 /// of [`readable()`] to check the returned ready set for [`Ready::READABLE`]
201 /// and [`Ready::READ_CLOSED`] events.
202 ///
203 /// The function may complete without the socket being ready. This is a
204 /// false-positive and attempting an operation will return with
205 /// `io::ErrorKind::WouldBlock`. The function can also return with an empty
206 /// [`Ready`] set, so you should always check the returned value and possibly
207 /// wait again if the requested states are not set.
208 ///
209 /// This function is equivalent to [`TcpStream::ready`].
210 ///
211 /// [`try_read()`]: Self::try_read
212 /// [`readable()`]: Self::readable
213 ///
214 /// # Cancel safety
215 ///
216 /// This method is cancel safe. Once a readiness event occurs, the method
217 /// will continue to return immediately until the readiness event is
218 /// consumed by an attempt to read or write that fails with `WouldBlock` or
219 /// `Poll::Pending`.
220 pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
221 self.inner.ready(interest).await
222 }
223
224 /// Waits for the socket to become readable.
225 ///
226 /// This function is equivalent to `ready(Interest::READABLE)` and is usually
227 /// paired with `try_read()`.
228 ///
229 /// This function is also equivalent to [`TcpStream::ready`].
230 ///
231 /// # Cancel safety
232 ///
233 /// This method is cancel safe. Once a readiness event occurs, the method
234 /// will continue to return immediately until the readiness event is
235 /// consumed by an attempt to read that fails with `WouldBlock` or
236 /// `Poll::Pending`.
237 pub async fn readable(&self) -> io::Result<()> {
238 self.inner.readable().await
239 }
240
241 /// Tries to read data from the stream into the provided buffer, returning how
242 /// many bytes were read.
243 ///
244 /// Receives any pending data from the socket but does not wait for new data
245 /// to arrive. On success, returns the number of bytes read. Because
246 /// `try_read()` is non-blocking, the buffer does not have to be stored by
247 /// the async task and can exist entirely on the stack.
248 ///
249 /// Usually, [`readable()`] or [`ready()`] is used with this function.
250 ///
251 /// [`readable()`]: Self::readable()
252 /// [`ready()`]: Self::ready()
253 ///
254 /// # Return
255 ///
256 /// If data is successfully read, `Ok(n)` is returned, where `n` is the
257 /// number of bytes read. If `n` is `0`, then it can indicate one of two scenarios:
258 ///
259 /// 1. The stream's read half is closed and will no longer yield data.
260 /// 2. The specified buffer was 0 bytes in length.
261 ///
262 /// If the stream is not ready to read data,
263 /// `Err(io::ErrorKind::WouldBlock)` is returned.
264 pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
265 self.inner.try_read(buf)
266 }
267
268 /// Tries to read data from the stream into the provided buffers, returning
269 /// how many bytes were read.
270 ///
271 /// Data is copied to fill each buffer in order, with the final buffer
272 /// written to possibly being only partially filled. This method behaves
273 /// equivalently to a single call to [`try_read()`] with concatenated
274 /// buffers.
275 ///
276 /// Receives any pending data from the socket but does not wait for new data
277 /// to arrive. On success, returns the number of bytes read. Because
278 /// `try_read_vectored()` is non-blocking, the buffer does not have to be
279 /// stored by the async task and can exist entirely on the stack.
280 ///
281 /// Usually, [`readable()`] or [`ready()`] is used with this function.
282 ///
283 /// [`try_read()`]: Self::try_read()
284 /// [`readable()`]: Self::readable()
285 /// [`ready()`]: Self::ready()
286 ///
287 /// # Return
288 ///
289 /// If data is successfully read, `Ok(n)` is returned, where `n` is the
290 /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
291 /// and will no longer yield data. If the stream is not ready to read data
292 /// `Err(io::ErrorKind::WouldBlock)` is returned.
293 pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
294 self.inner.try_read_vectored(bufs)
295 }
296
297 cfg_io_util! {
298 /// Tries to read data from the stream into the provided buffer, advancing the
299 /// buffer's internal cursor, returning how many bytes were read.
300 ///
301 /// Receives any pending data from the socket but does not wait for new data
302 /// to arrive. On success, returns the number of bytes read. Because
303 /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by
304 /// the async task and can exist entirely on the stack.
305 ///
306 /// Usually, [`readable()`] or [`ready()`] is used with this function.
307 ///
308 /// [`readable()`]: Self::readable()
309 /// [`ready()`]: Self::ready()
310 ///
311 /// # Return
312 ///
313 /// If data is successfully read, `Ok(n)` is returned, where `n` is the
314 /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
315 /// and will no longer yield data. If the stream is not ready to read data
316 /// `Err(io::ErrorKind::WouldBlock)` is returned.
317 pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
318 self.inner.try_read_buf(buf)
319 }
320 }
321
322 /// Returns the remote address that this stream is connected to.
323 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
324 self.inner.peer_addr()
325 }
326
327 /// Returns the local address that this stream is bound to.
328 pub fn local_addr(&self) -> io::Result<SocketAddr> {
329 self.inner.local_addr()
330 }
331}
332
333impl AsyncRead for OwnedReadHalf {
334 fn poll_read(
335 self: Pin<&mut Self>,
336 cx: &mut Context<'_>,
337 buf: &mut ReadBuf<'_>,
338 ) -> Poll<io::Result<()>> {
339 self.inner.poll_read_priv(cx, buf)
340 }
341}
342
343impl OwnedWriteHalf {
344 /// Attempts to put the two halves of a `TcpStream` back together and
345 /// recover the original socket. Succeeds only if the two halves
346 /// originated from the same call to [`into_split`].
347 ///
348 /// [`into_split`]: TcpStream::into_split()
349 pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
350 reunite(other, self)
351 }
352
353 /// Destroys the write half, but don't close the write half of the stream
354 /// until the read half is dropped. If the read half has already been
355 /// dropped, this closes the stream.
356 pub fn forget(mut self) {
357 self.shutdown_on_drop = false;
358 drop(self);
359 }
360
361 /// Waits for any of the requested ready states.
362 ///
363 /// This function is usually paired with [`try_write()`]. It can be used instead
364 /// of [`writable()`] to check the returned ready set for [`Ready::WRITABLE`]
365 /// and [`Ready::WRITE_CLOSED`] events.
366 ///
367 /// The function may complete without the socket being ready. This is a
368 /// false-positive and attempting an operation will return with
369 /// `io::ErrorKind::WouldBlock`. The function can also return with an empty
370 /// [`Ready`] set, so you should always check the returned value and possibly
371 /// wait again if the requested states are not set.
372 ///
373 /// This function is equivalent to [`TcpStream::ready`].
374 ///
375 /// [`try_write()`]: Self::try_write
376 /// [`writable()`]: Self::writable
377 ///
378 /// # Cancel safety
379 ///
380 /// This method is cancel safe. Once a readiness event occurs, the method
381 /// will continue to return immediately until the readiness event is
382 /// consumed by an attempt to read or write that fails with `WouldBlock` or
383 /// `Poll::Pending`.
384 pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
385 self.inner.ready(interest).await
386 }
387
388 /// Waits for the socket to become writable.
389 ///
390 /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually
391 /// paired with `try_write()`.
392 ///
393 /// # Cancel safety
394 ///
395 /// This method is cancel safe. Once a readiness event occurs, the method
396 /// will continue to return immediately until the readiness event is
397 /// consumed by an attempt to write that fails with `WouldBlock` or
398 /// `Poll::Pending`.
399 pub async fn writable(&self) -> io::Result<()> {
400 self.inner.writable().await
401 }
402
403 /// Tries to write a buffer to the stream, returning how many bytes were
404 /// written.
405 ///
406 /// The function will attempt to write the entire contents of `buf`, but
407 /// only part of the buffer may be written.
408 ///
409 /// This function is usually paired with `writable()`.
410 ///
411 /// # Return
412 ///
413 /// If data is successfully written, `Ok(n)` is returned, where `n` is the
414 /// number of bytes written. If the stream is not ready to write data,
415 /// `Err(io::ErrorKind::WouldBlock)` is returned.
416 pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
417 self.inner.try_write(buf)
418 }
419
420 /// Tries to write several buffers to the stream, returning how many bytes
421 /// were written.
422 ///
423 /// Data is written from each buffer in order, with the final buffer read
424 /// from possible being only partially consumed. This method behaves
425 /// equivalently to a single call to [`try_write()`] with concatenated
426 /// buffers.
427 ///
428 /// This function is usually paired with `writable()`.
429 ///
430 /// [`try_write()`]: Self::try_write()
431 ///
432 /// # Return
433 ///
434 /// If data is successfully written, `Ok(n)` is returned, where `n` is the
435 /// number of bytes written. If the stream is not ready to write data,
436 /// `Err(io::ErrorKind::WouldBlock)` is returned.
437 pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
438 self.inner.try_write_vectored(bufs)
439 }
440
441 /// Returns the remote address that this stream is connected to.
442 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
443 self.inner.peer_addr()
444 }
445
446 /// Returns the local address that this stream is bound to.
447 pub fn local_addr(&self) -> io::Result<SocketAddr> {
448 self.inner.local_addr()
449 }
450}
451
452impl Drop for OwnedWriteHalf {
453 fn drop(&mut self) {
454 if self.shutdown_on_drop {
455 let _ = self.inner.shutdown_std(Shutdown::Write);
456 }
457 }
458}
459
460impl AsyncWrite for OwnedWriteHalf {
461 fn poll_write(
462 self: Pin<&mut Self>,
463 cx: &mut Context<'_>,
464 buf: &[u8],
465 ) -> Poll<io::Result<usize>> {
466 self.inner.poll_write_priv(cx, buf)
467 }
468
469 fn poll_write_vectored(
470 self: Pin<&mut Self>,
471 cx: &mut Context<'_>,
472 bufs: &[io::IoSlice<'_>],
473 ) -> Poll<io::Result<usize>> {
474 self.inner.poll_write_vectored_priv(cx, bufs)
475 }
476
477 fn is_write_vectored(&self) -> bool {
478 self.inner.is_write_vectored()
479 }
480
481 #[inline]
482 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
483 // tcp flush is a no-op
484 Poll::Ready(Ok(()))
485 }
486
487 // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
488 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
489 let res = self.inner.shutdown_std(Shutdown::Write);
490 if res.is_ok() {
491 Pin::into_inner(self).shutdown_on_drop = false;
492 }
493 res.into()
494 }
495}
496
497impl AsRef<TcpStream> for OwnedReadHalf {
498 fn as_ref(&self) -> &TcpStream {
499 &self.inner
500 }
501}
502
503impl AsRef<TcpStream> for OwnedWriteHalf {
504 fn as_ref(&self) -> &TcpStream {
505 &self.inner
506 }
507}
508