//! `TcpStream` owned split support.
//!
//! A `TcpStream` can be split into an `OwnedReadHalf` and an `OwnedWriteHalf`
//! with the `TcpStream::into_split` method. `OwnedReadHalf` implements
//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
//!
//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
//! split has no associated overhead and enforces all invariants at the type
//! level.
10 | |
11 | use crate::future::poll_fn; |
12 | use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; |
13 | use crate::net::TcpStream; |
14 | |
15 | use std::error::Error; |
16 | use std::net::{Shutdown, SocketAddr}; |
17 | use std::pin::Pin; |
18 | use std::sync::Arc; |
19 | use std::task::{Context, Poll}; |
20 | use std::{fmt, io}; |
21 | |
22 | cfg_io_util! { |
23 | use bytes::BufMut; |
24 | } |
25 | |
26 | /// Owned read half of a [`TcpStream`], created by [`into_split`]. |
27 | /// |
28 | /// Reading from an `OwnedReadHalf` is usually done using the convenience methods found |
29 | /// on the [`AsyncReadExt`] trait. |
30 | /// |
31 | /// [`TcpStream`]: TcpStream |
32 | /// [`into_split`]: TcpStream::into_split() |
33 | /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt |
34 | #[derive (Debug)] |
35 | pub struct OwnedReadHalf { |
36 | inner: Arc<TcpStream>, |
37 | } |
38 | |
39 | /// Owned write half of a [`TcpStream`], created by [`into_split`]. |
40 | /// |
41 | /// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will |
42 | /// shut down the TCP stream in the write direction. Dropping the write half |
43 | /// will also shut down the write half of the TCP stream. |
44 | /// |
45 | /// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found |
46 | /// on the [`AsyncWriteExt`] trait. |
47 | /// |
48 | /// [`TcpStream`]: TcpStream |
49 | /// [`into_split`]: TcpStream::into_split() |
50 | /// [`AsyncWrite`]: trait@crate::io::AsyncWrite |
51 | /// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown |
52 | /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt |
53 | #[derive (Debug)] |
54 | pub struct OwnedWriteHalf { |
55 | inner: Arc<TcpStream>, |
56 | shutdown_on_drop: bool, |
57 | } |
58 | |
59 | pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) { |
60 | let arc: Arc = Arc::new(data:stream); |
61 | let read: OwnedReadHalf = OwnedReadHalf { |
62 | inner: Arc::clone(&arc), |
63 | }; |
64 | let write: OwnedWriteHalf = OwnedWriteHalf { |
65 | inner: arc, |
66 | shutdown_on_drop: true, |
67 | }; |
68 | (read, write) |
69 | } |
70 | |
71 | pub(crate) fn reunite( |
72 | read: OwnedReadHalf, |
73 | write: OwnedWriteHalf, |
74 | ) -> Result<TcpStream, ReuniteError> { |
75 | if Arc::ptr_eq(&read.inner, &write.inner) { |
76 | write.forget(); |
77 | // This unwrap cannot fail as the api does not allow creating more than two Arcs, |
78 | // and we just dropped the other half. |
79 | Ok(Arc::try_unwrap(read.inner).expect(msg:"TcpStream: try_unwrap failed in reunite" )) |
80 | } else { |
81 | Err(ReuniteError(read, write)) |
82 | } |
83 | } |
84 | |
85 | /// Error indicating that two halves were not from the same socket, and thus could |
86 | /// not be reunited. |
87 | #[derive (Debug)] |
88 | pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf); |
89 | |
90 | impl fmt::Display for ReuniteError { |
91 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
92 | write!( |
93 | f, |
94 | "tried to reunite halves that are not from the same socket" |
95 | ) |
96 | } |
97 | } |
98 | |
99 | impl Error for ReuniteError {} |
100 | |
101 | impl OwnedReadHalf { |
102 | /// Attempts to put the two halves of a `TcpStream` back together and |
103 | /// recover the original socket. Succeeds only if the two halves |
104 | /// originated from the same call to [`into_split`]. |
105 | /// |
106 | /// [`into_split`]: TcpStream::into_split() |
107 | pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> { |
108 | reunite(self, other) |
109 | } |
110 | |
111 | /// Attempt to receive data on the socket, without removing that data from |
112 | /// the queue, registering the current task for wakeup if data is not yet |
113 | /// available. |
114 | /// |
115 | /// Note that on multiple calls to `poll_peek` or `poll_read`, only the |
116 | /// `Waker` from the `Context` passed to the most recent call is scheduled |
117 | /// to receive a wakeup. |
118 | /// |
119 | /// See the [`TcpStream::poll_peek`] level documentation for more details. |
120 | /// |
121 | /// # Examples |
122 | /// |
123 | /// ```no_run |
124 | /// use tokio::io::{self, ReadBuf}; |
125 | /// use tokio::net::TcpStream; |
126 | /// |
127 | /// use futures::future::poll_fn; |
128 | /// |
129 | /// #[tokio::main] |
130 | /// async fn main() -> io::Result<()> { |
131 | /// let stream = TcpStream::connect("127.0.0.1:8000" ).await?; |
132 | /// let (mut read_half, _) = stream.into_split(); |
133 | /// let mut buf = [0; 10]; |
134 | /// let mut buf = ReadBuf::new(&mut buf); |
135 | /// |
136 | /// poll_fn(|cx| { |
137 | /// read_half.poll_peek(cx, &mut buf) |
138 | /// }).await?; |
139 | /// |
140 | /// Ok(()) |
141 | /// } |
142 | /// ``` |
143 | /// |
144 | /// [`TcpStream::poll_peek`]: TcpStream::poll_peek |
145 | pub fn poll_peek( |
146 | &mut self, |
147 | cx: &mut Context<'_>, |
148 | buf: &mut ReadBuf<'_>, |
149 | ) -> Poll<io::Result<usize>> { |
150 | self.inner.poll_peek(cx, buf) |
151 | } |
152 | |
153 | /// Receives data on the socket from the remote address to which it is |
154 | /// connected, without removing that data from the queue. On success, |
155 | /// returns the number of bytes peeked. |
156 | /// |
157 | /// See the [`TcpStream::peek`] level documentation for more details. |
158 | /// |
159 | /// [`TcpStream::peek`]: TcpStream::peek |
160 | /// |
161 | /// # Examples |
162 | /// |
163 | /// ```no_run |
164 | /// use tokio::net::TcpStream; |
165 | /// use tokio::io::AsyncReadExt; |
166 | /// use std::error::Error; |
167 | /// |
168 | /// #[tokio::main] |
169 | /// async fn main() -> Result<(), Box<dyn Error>> { |
170 | /// // Connect to a peer |
171 | /// let stream = TcpStream::connect("127.0.0.1:8080" ).await?; |
172 | /// let (mut read_half, _) = stream.into_split(); |
173 | /// |
174 | /// let mut b1 = [0; 10]; |
175 | /// let mut b2 = [0; 10]; |
176 | /// |
177 | /// // Peek at the data |
178 | /// let n = read_half.peek(&mut b1).await?; |
179 | /// |
180 | /// // Read the data |
181 | /// assert_eq!(n, read_half.read(&mut b2[..n]).await?); |
182 | /// assert_eq!(&b1[..n], &b2[..n]); |
183 | /// |
184 | /// Ok(()) |
185 | /// } |
186 | /// ``` |
187 | /// |
188 | /// The [`read`] method is defined on the [`AsyncReadExt`] trait. |
189 | /// |
190 | /// [`read`]: fn@crate::io::AsyncReadExt::read |
191 | /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt |
192 | pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
193 | let mut buf = ReadBuf::new(buf); |
194 | poll_fn(|cx| self.poll_peek(cx, &mut buf)).await |
195 | } |
196 | |
197 | /// Waits for any of the requested ready states. |
198 | /// |
199 | /// This function is usually paired with [`try_read()`]. It can be used instead |
200 | /// of [`readable()`] to check the returned ready set for [`Ready::READABLE`] |
201 | /// and [`Ready::READ_CLOSED`] events. |
202 | /// |
203 | /// The function may complete without the socket being ready. This is a |
204 | /// false-positive and attempting an operation will return with |
205 | /// `io::ErrorKind::WouldBlock`. The function can also return with an empty |
206 | /// [`Ready`] set, so you should always check the returned value and possibly |
207 | /// wait again if the requested states are not set. |
208 | /// |
209 | /// This function is equivalent to [`TcpStream::ready`]. |
210 | /// |
211 | /// [`try_read()`]: Self::try_read |
212 | /// [`readable()`]: Self::readable |
213 | /// |
214 | /// # Cancel safety |
215 | /// |
216 | /// This method is cancel safe. Once a readiness event occurs, the method |
217 | /// will continue to return immediately until the readiness event is |
218 | /// consumed by an attempt to read or write that fails with `WouldBlock` or |
219 | /// `Poll::Pending`. |
220 | pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { |
221 | self.inner.ready(interest).await |
222 | } |
223 | |
224 | /// Waits for the socket to become readable. |
225 | /// |
226 | /// This function is equivalent to `ready(Interest::READABLE)` and is usually |
227 | /// paired with `try_read()`. |
228 | /// |
229 | /// This function is also equivalent to [`TcpStream::ready`]. |
230 | /// |
231 | /// # Cancel safety |
232 | /// |
233 | /// This method is cancel safe. Once a readiness event occurs, the method |
234 | /// will continue to return immediately until the readiness event is |
235 | /// consumed by an attempt to read that fails with `WouldBlock` or |
236 | /// `Poll::Pending`. |
237 | pub async fn readable(&self) -> io::Result<()> { |
238 | self.inner.readable().await |
239 | } |
240 | |
241 | /// Tries to read data from the stream into the provided buffer, returning how |
242 | /// many bytes were read. |
243 | /// |
244 | /// Receives any pending data from the socket but does not wait for new data |
245 | /// to arrive. On success, returns the number of bytes read. Because |
246 | /// `try_read()` is non-blocking, the buffer does not have to be stored by |
247 | /// the async task and can exist entirely on the stack. |
248 | /// |
249 | /// Usually, [`readable()`] or [`ready()`] is used with this function. |
250 | /// |
251 | /// [`readable()`]: Self::readable() |
252 | /// [`ready()`]: Self::ready() |
253 | /// |
254 | /// # Return |
255 | /// |
256 | /// If data is successfully read, `Ok(n)` is returned, where `n` is the |
257 | /// number of bytes read. If `n` is `0`, then it can indicate one of two scenarios: |
258 | /// |
259 | /// 1. The stream's read half is closed and will no longer yield data. |
260 | /// 2. The specified buffer was 0 bytes in length. |
261 | /// |
262 | /// If the stream is not ready to read data, |
263 | /// `Err(io::ErrorKind::WouldBlock)` is returned. |
264 | pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { |
265 | self.inner.try_read(buf) |
266 | } |
267 | |
268 | /// Tries to read data from the stream into the provided buffers, returning |
269 | /// how many bytes were read. |
270 | /// |
271 | /// Data is copied to fill each buffer in order, with the final buffer |
272 | /// written to possibly being only partially filled. This method behaves |
273 | /// equivalently to a single call to [`try_read()`] with concatenated |
274 | /// buffers. |
275 | /// |
276 | /// Receives any pending data from the socket but does not wait for new data |
277 | /// to arrive. On success, returns the number of bytes read. Because |
278 | /// `try_read_vectored()` is non-blocking, the buffer does not have to be |
279 | /// stored by the async task and can exist entirely on the stack. |
280 | /// |
281 | /// Usually, [`readable()`] or [`ready()`] is used with this function. |
282 | /// |
283 | /// [`try_read()`]: Self::try_read() |
284 | /// [`readable()`]: Self::readable() |
285 | /// [`ready()`]: Self::ready() |
286 | /// |
287 | /// # Return |
288 | /// |
289 | /// If data is successfully read, `Ok(n)` is returned, where `n` is the |
290 | /// number of bytes read. `Ok(0)` indicates the stream's read half is closed |
291 | /// and will no longer yield data. If the stream is not ready to read data |
292 | /// `Err(io::ErrorKind::WouldBlock)` is returned. |
293 | pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { |
294 | self.inner.try_read_vectored(bufs) |
295 | } |
296 | |
297 | cfg_io_util! { |
298 | /// Tries to read data from the stream into the provided buffer, advancing the |
299 | /// buffer's internal cursor, returning how many bytes were read. |
300 | /// |
301 | /// Receives any pending data from the socket but does not wait for new data |
302 | /// to arrive. On success, returns the number of bytes read. Because |
303 | /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by |
304 | /// the async task and can exist entirely on the stack. |
305 | /// |
306 | /// Usually, [`readable()`] or [`ready()`] is used with this function. |
307 | /// |
308 | /// [`readable()`]: Self::readable() |
309 | /// [`ready()`]: Self::ready() |
310 | /// |
311 | /// # Return |
312 | /// |
313 | /// If data is successfully read, `Ok(n)` is returned, where `n` is the |
314 | /// number of bytes read. `Ok(0)` indicates the stream's read half is closed |
315 | /// and will no longer yield data. If the stream is not ready to read data |
316 | /// `Err(io::ErrorKind::WouldBlock)` is returned. |
317 | pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { |
318 | self.inner.try_read_buf(buf) |
319 | } |
320 | } |
321 | |
322 | /// Returns the remote address that this stream is connected to. |
323 | pub fn peer_addr(&self) -> io::Result<SocketAddr> { |
324 | self.inner.peer_addr() |
325 | } |
326 | |
327 | /// Returns the local address that this stream is bound to. |
328 | pub fn local_addr(&self) -> io::Result<SocketAddr> { |
329 | self.inner.local_addr() |
330 | } |
331 | } |
332 | |
333 | impl AsyncRead for OwnedReadHalf { |
334 | fn poll_read( |
335 | self: Pin<&mut Self>, |
336 | cx: &mut Context<'_>, |
337 | buf: &mut ReadBuf<'_>, |
338 | ) -> Poll<io::Result<()>> { |
339 | self.inner.poll_read_priv(cx, buf) |
340 | } |
341 | } |
342 | |
343 | impl OwnedWriteHalf { |
344 | /// Attempts to put the two halves of a `TcpStream` back together and |
345 | /// recover the original socket. Succeeds only if the two halves |
346 | /// originated from the same call to [`into_split`]. |
347 | /// |
348 | /// [`into_split`]: TcpStream::into_split() |
349 | pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> { |
350 | reunite(other, self) |
351 | } |
352 | |
353 | /// Destroys the write half, but don't close the write half of the stream |
354 | /// until the read half is dropped. If the read half has already been |
355 | /// dropped, this closes the stream. |
356 | pub fn forget(mut self) { |
357 | self.shutdown_on_drop = false; |
358 | drop(self); |
359 | } |
360 | |
361 | /// Waits for any of the requested ready states. |
362 | /// |
363 | /// This function is usually paired with [`try_write()`]. It can be used instead |
364 | /// of [`writable()`] to check the returned ready set for [`Ready::WRITABLE`] |
365 | /// and [`Ready::WRITE_CLOSED`] events. |
366 | /// |
367 | /// The function may complete without the socket being ready. This is a |
368 | /// false-positive and attempting an operation will return with |
369 | /// `io::ErrorKind::WouldBlock`. The function can also return with an empty |
370 | /// [`Ready`] set, so you should always check the returned value and possibly |
371 | /// wait again if the requested states are not set. |
372 | /// |
373 | /// This function is equivalent to [`TcpStream::ready`]. |
374 | /// |
375 | /// [`try_write()`]: Self::try_write |
376 | /// [`writable()`]: Self::writable |
377 | /// |
378 | /// # Cancel safety |
379 | /// |
380 | /// This method is cancel safe. Once a readiness event occurs, the method |
381 | /// will continue to return immediately until the readiness event is |
382 | /// consumed by an attempt to read or write that fails with `WouldBlock` or |
383 | /// `Poll::Pending`. |
384 | pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { |
385 | self.inner.ready(interest).await |
386 | } |
387 | |
388 | /// Waits for the socket to become writable. |
389 | /// |
390 | /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually |
391 | /// paired with `try_write()`. |
392 | /// |
393 | /// # Cancel safety |
394 | /// |
395 | /// This method is cancel safe. Once a readiness event occurs, the method |
396 | /// will continue to return immediately until the readiness event is |
397 | /// consumed by an attempt to write that fails with `WouldBlock` or |
398 | /// `Poll::Pending`. |
399 | pub async fn writable(&self) -> io::Result<()> { |
400 | self.inner.writable().await |
401 | } |
402 | |
403 | /// Tries to write a buffer to the stream, returning how many bytes were |
404 | /// written. |
405 | /// |
406 | /// The function will attempt to write the entire contents of `buf`, but |
407 | /// only part of the buffer may be written. |
408 | /// |
409 | /// This function is usually paired with `writable()`. |
410 | /// |
411 | /// # Return |
412 | /// |
413 | /// If data is successfully written, `Ok(n)` is returned, where `n` is the |
414 | /// number of bytes written. If the stream is not ready to write data, |
415 | /// `Err(io::ErrorKind::WouldBlock)` is returned. |
416 | pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { |
417 | self.inner.try_write(buf) |
418 | } |
419 | |
420 | /// Tries to write several buffers to the stream, returning how many bytes |
421 | /// were written. |
422 | /// |
423 | /// Data is written from each buffer in order, with the final buffer read |
424 | /// from possible being only partially consumed. This method behaves |
425 | /// equivalently to a single call to [`try_write()`] with concatenated |
426 | /// buffers. |
427 | /// |
428 | /// This function is usually paired with `writable()`. |
429 | /// |
430 | /// [`try_write()`]: Self::try_write() |
431 | /// |
432 | /// # Return |
433 | /// |
434 | /// If data is successfully written, `Ok(n)` is returned, where `n` is the |
435 | /// number of bytes written. If the stream is not ready to write data, |
436 | /// `Err(io::ErrorKind::WouldBlock)` is returned. |
437 | pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { |
438 | self.inner.try_write_vectored(bufs) |
439 | } |
440 | |
441 | /// Returns the remote address that this stream is connected to. |
442 | pub fn peer_addr(&self) -> io::Result<SocketAddr> { |
443 | self.inner.peer_addr() |
444 | } |
445 | |
446 | /// Returns the local address that this stream is bound to. |
447 | pub fn local_addr(&self) -> io::Result<SocketAddr> { |
448 | self.inner.local_addr() |
449 | } |
450 | } |
451 | |
452 | impl Drop for OwnedWriteHalf { |
453 | fn drop(&mut self) { |
454 | if self.shutdown_on_drop { |
455 | let _ = self.inner.shutdown_std(how:Shutdown::Write); |
456 | } |
457 | } |
458 | } |
459 | |
460 | impl AsyncWrite for OwnedWriteHalf { |
461 | fn poll_write( |
462 | self: Pin<&mut Self>, |
463 | cx: &mut Context<'_>, |
464 | buf: &[u8], |
465 | ) -> Poll<io::Result<usize>> { |
466 | self.inner.poll_write_priv(cx, buf) |
467 | } |
468 | |
469 | fn poll_write_vectored( |
470 | self: Pin<&mut Self>, |
471 | cx: &mut Context<'_>, |
472 | bufs: &[io::IoSlice<'_>], |
473 | ) -> Poll<io::Result<usize>> { |
474 | self.inner.poll_write_vectored_priv(cx, bufs) |
475 | } |
476 | |
477 | fn is_write_vectored(&self) -> bool { |
478 | self.inner.is_write_vectored() |
479 | } |
480 | |
481 | #[inline ] |
482 | fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { |
483 | // tcp flush is a no-op |
484 | Poll::Ready(Ok(())) |
485 | } |
486 | |
487 | // `poll_shutdown` on a write half shutdowns the stream in the "write" direction. |
488 | fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> { |
489 | let res = self.inner.shutdown_std(Shutdown::Write); |
490 | if res.is_ok() { |
491 | Pin::into_inner(self).shutdown_on_drop = false; |
492 | } |
493 | res.into() |
494 | } |
495 | } |
496 | |
497 | impl AsRef<TcpStream> for OwnedReadHalf { |
498 | fn as_ref(&self) -> &TcpStream { |
499 | &self.inner |
500 | } |
501 | } |
502 | |
503 | impl AsRef<TcpStream> for OwnedWriteHalf { |
504 | fn as_ref(&self) -> &TcpStream { |
505 | &self.inner |
506 | } |
507 | } |
508 | |