//! Split a single value implementing `AsyncRead + AsyncWrite` into separate
//! `AsyncRead` and `AsyncWrite` handles.
//!
//! To restore the original read/write object from its `split::ReadHalf` and
//! `split::WriteHalf`, use `ReadHalf::unsplit`.

use crate::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::cell::UnsafeCell;
use std::fmt;
use std::io;
use std::io::IoSlice;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::{Acquire, Release};
use std::sync::Arc;
use std::task::{Context, Poll};

18 | cfg_io_util! { |
19 | /// The readable half of a value returned from [`split`](split()). |
20 | pub struct ReadHalf<T> { |
21 | inner: Arc<Inner<T>>, |
22 | } |
23 | |
24 | /// The writable half of a value returned from [`split`](split()). |
25 | pub struct WriteHalf<T> { |
26 | inner: Arc<Inner<T>>, |
27 | } |
28 | |
29 | /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate |
30 | /// `AsyncRead` and `AsyncWrite` handles. |
31 | /// |
32 | /// To restore this read/write object from its `ReadHalf` and |
33 | /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()). |
34 | pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) |
35 | where |
36 | T: AsyncRead + AsyncWrite, |
37 | { |
38 | let inner = Arc::new(Inner { |
39 | locked: AtomicBool::new(false), |
40 | stream: UnsafeCell::new(stream), |
41 | }); |
42 | |
43 | let rd = ReadHalf { |
44 | inner: inner.clone(), |
45 | }; |
46 | |
47 | let wr = WriteHalf { inner }; |
48 | |
49 | (rd, wr) |
50 | } |
51 | } |
52 | |
/// State shared by a `ReadHalf` and its `WriteHalf` through an `Arc`.
struct Inner<T> {
    /// Lock flag: `true` while a `Guard` grants access to `stream`.
    locked: AtomicBool,
    /// The split stream. While split, it is only reached through a `Guard`;
    /// `ReadHalf::unsplit` moves it back out.
    stream: UnsafeCell<T>,
}

58 | struct Guard<'a, T> { |
59 | inner: &'a Inner<T>, |
60 | } |
61 | |
62 | impl<T> ReadHalf<T> { |
63 | /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same |
64 | /// stream. |
65 | pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool { |
66 | other.is_pair_of(self) |
67 | } |
68 | |
69 | /// Reunites with a previously split `WriteHalf`. |
70 | /// |
71 | /// # Panics |
72 | /// |
73 | /// If this `ReadHalf` and the given `WriteHalf` do not originate from the |
74 | /// same `split` operation this method will panic. |
75 | /// This can be checked ahead of time by comparing the stream ID |
76 | /// of the two halves. |
77 | #[track_caller ] |
78 | pub fn unsplit(self, wr: WriteHalf<T>) -> T |
79 | where |
80 | T: Unpin, |
81 | { |
82 | if self.is_pair_of(&wr) { |
83 | drop(wr); |
84 | |
85 | let inner = Arc::try_unwrap(self.inner) |
86 | .ok() |
87 | .expect("`Arc::try_unwrap` failed" ); |
88 | |
89 | inner.stream.into_inner() |
90 | } else { |
91 | panic!("Unrelated `split::Write` passed to `split::Read::unsplit`." ) |
92 | } |
93 | } |
94 | } |
95 | |
96 | impl<T> WriteHalf<T> { |
97 | /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same |
98 | /// stream. |
99 | pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool { |
100 | Arc::ptr_eq(&self.inner, &other.inner) |
101 | } |
102 | } |
103 | |
104 | impl<T: AsyncRead> AsyncRead for ReadHalf<T> { |
105 | fn poll_read( |
106 | self: Pin<&mut Self>, |
107 | cx: &mut Context<'_>, |
108 | buf: &mut ReadBuf<'_>, |
109 | ) -> Poll<io::Result<()>> { |
110 | let mut inner: Guard<'_, T> = ready!(self.inner.poll_lock(cx)); |
111 | inner.stream_pin().poll_read(cx, buf) |
112 | } |
113 | } |
114 | |
115 | impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { |
116 | fn poll_write( |
117 | self: Pin<&mut Self>, |
118 | cx: &mut Context<'_>, |
119 | buf: &[u8], |
120 | ) -> Poll<Result<usize, io::Error>> { |
121 | let mut inner: Guard<'_, T> = ready!(self.inner.poll_lock(cx)); |
122 | inner.stream_pin().poll_write(cx, buf) |
123 | } |
124 | |
125 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
126 | let mut inner: Guard<'_, T> = ready!(self.inner.poll_lock(cx)); |
127 | inner.stream_pin().poll_flush(cx) |
128 | } |
129 | |
130 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { |
131 | let mut inner: Guard<'_, T> = ready!(self.inner.poll_lock(cx)); |
132 | inner.stream_pin().poll_shutdown(cx) |
133 | } |
134 | } |
135 | |
136 | impl<T> Inner<T> { |
137 | fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> { |
138 | if self |
139 | .locked |
140 | .compare_exchange(current:false, new:true, success:Acquire, failure:Acquire) |
141 | .is_ok() |
142 | { |
143 | Poll::Ready(Guard { inner: self }) |
144 | } else { |
145 | // Spin... but investigate a better strategy |
146 | |
147 | std::thread::yield_now(); |
148 | cx.waker().wake_by_ref(); |
149 | |
150 | Poll::Pending |
151 | } |
152 | } |
153 | } |
154 | |
155 | impl<T> Guard<'_, T> { |
156 | fn stream_pin(&mut self) -> Pin<&mut T> { |
157 | // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual |
158 | // exclusion. |
159 | unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) } |
160 | } |
161 | } |
162 | |
163 | impl<T> Drop for Guard<'_, T> { |
164 | fn drop(&mut self) { |
165 | self.inner.locked.store(val:false, order:Release); |
166 | } |
167 | } |
168 | |
169 | unsafe impl<T: Send> Send for ReadHalf<T> {} |
170 | unsafe impl<T: Send> Send for WriteHalf<T> {} |
171 | unsafe impl<T: Sync> Sync for ReadHalf<T> {} |
172 | unsafe impl<T: Sync> Sync for WriteHalf<T> {} |
173 | |
174 | impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> { |
175 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
176 | fmt.debug_struct(name:"split::ReadHalf" ).finish() |
177 | } |
178 | } |
179 | |
180 | impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> { |
181 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
182 | fmt.debug_struct(name:"split::WriteHalf" ).finish() |
183 | } |
184 | } |
185 | |