1 | //! Split a single value implementing `AsyncRead + AsyncWrite` into separate |
2 | //! `AsyncRead` and `AsyncWrite` handles. |
3 | //! |
//! To restore the original read/write object from its `split::ReadHalf` and
//! `split::WriteHalf`, use `ReadHalf::unsplit`.
6 | |
7 | use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; |
8 | |
9 | use std::cell::UnsafeCell; |
10 | use std::fmt; |
11 | use std::io; |
12 | use std::pin::Pin; |
13 | use std::sync::atomic::AtomicBool; |
14 | use std::sync::atomic::Ordering::{Acquire, Release}; |
15 | use std::sync::Arc; |
16 | use std::task::{Context, Poll}; |
17 | |
18 | cfg_io_util! { |
19 | /// The readable half of a value returned from [`split`](split()). |
20 | pub struct ReadHalf<T> { |
21 | inner: Arc<Inner<T>>, |
22 | } |
23 | |
24 | /// The writable half of a value returned from [`split`](split()). |
25 | pub struct WriteHalf<T> { |
26 | inner: Arc<Inner<T>>, |
27 | } |
28 | |
29 | /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate |
30 | /// `AsyncRead` and `AsyncWrite` handles. |
31 | /// |
32 | /// To restore this read/write object from its `ReadHalf` and |
33 | /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()). |
34 | pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) |
35 | where |
36 | T: AsyncRead + AsyncWrite, |
37 | { |
38 | let is_write_vectored = stream.is_write_vectored(); |
39 | |
40 | let inner = Arc::new(Inner { |
41 | locked: AtomicBool::new(false), |
42 | stream: UnsafeCell::new(stream), |
43 | is_write_vectored, |
44 | }); |
45 | |
46 | let rd = ReadHalf { |
47 | inner: inner.clone(), |
48 | }; |
49 | |
50 | let wr = WriteHalf { inner }; |
51 | |
52 | (rd, wr) |
53 | } |
54 | } |
55 | |
/// State shared by a `ReadHalf` / `WriteHalf` pair through an `Arc`.
struct Inner<T> {
    /// The split stream. It is reached only through `Guard::stream_pin`, or
    /// by value after `unsplit` has unwrapped the `Arc`.
    stream: UnsafeCell<T>,
    /// `true` while a `Guard` holds exclusive access to `stream`.
    locked: AtomicBool,
    /// Result of `T::is_write_vectored`, captured once in `split`.
    is_write_vectored: bool,
}
61 | |
62 | struct Guard<'a, T> { |
63 | inner: &'a Inner<T>, |
64 | } |
65 | |
66 | impl<T> ReadHalf<T> { |
67 | /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same |
68 | /// stream. |
69 | pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool { |
70 | other.is_pair_of(self) |
71 | } |
72 | |
73 | /// Reunites with a previously split `WriteHalf`. |
74 | /// |
75 | /// # Panics |
76 | /// |
77 | /// If this `ReadHalf` and the given `WriteHalf` do not originate from the |
78 | /// same `split` operation this method will panic. |
79 | /// This can be checked ahead of time by comparing the stream ID |
80 | /// of the two halves. |
81 | #[track_caller ] |
82 | pub fn unsplit(self, wr: WriteHalf<T>) -> T |
83 | where |
84 | T: Unpin, |
85 | { |
86 | if self.is_pair_of(&wr) { |
87 | drop(wr); |
88 | |
89 | let inner = Arc::try_unwrap(self.inner) |
90 | .ok() |
91 | .expect("`Arc::try_unwrap` failed" ); |
92 | |
93 | inner.stream.into_inner() |
94 | } else { |
95 | panic!("Unrelated `split::Write` passed to `split::Read::unsplit`." ) |
96 | } |
97 | } |
98 | } |
99 | |
100 | impl<T> WriteHalf<T> { |
101 | /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same |
102 | /// stream. |
103 | pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool { |
104 | Arc::ptr_eq(&self.inner, &other.inner) |
105 | } |
106 | } |
107 | |
108 | impl<T: AsyncRead> AsyncRead for ReadHalf<T> { |
109 | fn poll_read( |
110 | self: Pin<&mut Self>, |
111 | cx: &mut Context<'_>, |
112 | buf: &mut ReadBuf<'_>, |
113 | ) -> Poll<io::Result<()>> { |
114 | let mut inner = ready!(self.inner.poll_lock(cx)); |
115 | inner.stream_pin().poll_read(cx, buf) |
116 | } |
117 | } |
118 | |
119 | impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { |
120 | fn poll_write( |
121 | self: Pin<&mut Self>, |
122 | cx: &mut Context<'_>, |
123 | buf: &[u8], |
124 | ) -> Poll<Result<usize, io::Error>> { |
125 | let mut inner = ready!(self.inner.poll_lock(cx)); |
126 | inner.stream_pin().poll_write(cx, buf) |
127 | } |
128 | |
129 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
130 | let mut inner = ready!(self.inner.poll_lock(cx)); |
131 | inner.stream_pin().poll_flush(cx) |
132 | } |
133 | |
134 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
135 | let mut inner = ready!(self.inner.poll_lock(cx)); |
136 | inner.stream_pin().poll_shutdown(cx) |
137 | } |
138 | |
139 | fn poll_write_vectored( |
140 | self: Pin<&mut Self>, |
141 | cx: &mut Context<'_>, |
142 | bufs: &[io::IoSlice<'_>], |
143 | ) -> Poll<Result<usize, io::Error>> { |
144 | let mut inner = ready!(self.inner.poll_lock(cx)); |
145 | inner.stream_pin().poll_write_vectored(cx, bufs) |
146 | } |
147 | |
148 | fn is_write_vectored(&self) -> bool { |
149 | self.inner.is_write_vectored |
150 | } |
151 | } |
152 | |
153 | impl<T> Inner<T> { |
154 | fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> { |
155 | if self |
156 | .locked |
157 | .compare_exchange(false, true, Acquire, Acquire) |
158 | .is_ok() |
159 | { |
160 | Poll::Ready(Guard { inner: self }) |
161 | } else { |
162 | // Spin... but investigate a better strategy |
163 | |
164 | std::thread::yield_now(); |
165 | cx.waker().wake_by_ref(); |
166 | |
167 | Poll::Pending |
168 | } |
169 | } |
170 | } |
171 | |
172 | impl<T> Guard<'_, T> { |
173 | fn stream_pin(&mut self) -> Pin<&mut T> { |
174 | // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual |
175 | // exclusion. |
176 | unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) } |
177 | } |
178 | } |
179 | |
180 | impl<T> Drop for Guard<'_, T> { |
181 | fn drop(&mut self) { |
182 | self.inner.locked.store(false, Release); |
183 | } |
184 | } |
185 | |
186 | unsafe impl<T: Send> Send for ReadHalf<T> {} |
187 | unsafe impl<T: Send> Send for WriteHalf<T> {} |
188 | unsafe impl<T: Sync> Sync for ReadHalf<T> {} |
189 | unsafe impl<T: Sync> Sync for WriteHalf<T> {} |
190 | |
191 | impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> { |
192 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
193 | fmt.debug_struct("split::ReadHalf" ).finish() |
194 | } |
195 | } |
196 | |
197 | impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> { |
198 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
199 | fmt.debug_struct("split::WriteHalf" ).finish() |
200 | } |
201 | } |
202 | |