1 | use futures_sink::Sink; |
2 | use std::pin::Pin; |
3 | use std::task::{Context, Poll}; |
4 | use std::{fmt, mem}; |
5 | use tokio::sync::mpsc::OwnedPermit; |
6 | use tokio::sync::mpsc::Sender; |
7 | |
8 | use super::ReusableBoxFuture; |
9 | |
/// Error produced by a `PollSender` once it can no longer send, either because the underlying
/// channel is closed or because the `PollSender` itself has been closed.
///
/// When the failure happened while handing over an item, that item travels inside the error so
/// the caller can get it back.
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Consumes the error and returns the value it carries, if any.
    ///
    /// This is `Some(item)` when the error came from `start_send`/`send_item` (the item that
    /// could not be delivered), and `None` in every other case.
    pub fn into_inner(self) -> Option<T> {
        let PollSendError(payload) = self;
        payload
    }
}

impl<T> fmt::Display for PollSendError<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
31 | |
32 | #[derive(Debug)] |
33 | enum State<T> { |
34 | Idle(Sender<T>), |
35 | Acquiring, |
36 | ReadyToSend(OwnedPermit<T>), |
37 | Closed, |
38 | } |
39 | |
40 | /// A wrapper around [`mpsc::Sender`] that can be polled. |
41 | /// |
42 | /// [`mpsc::Sender`]: tokio::sync::mpsc::Sender |
43 | #[derive(Debug)] |
44 | pub struct PollSender<T> { |
45 | sender: Option<Sender<T>>, |
46 | state: State<T>, |
47 | acquire: PollSenderFuture<T>, |
48 | } |
49 | |
50 | // Creates a future for acquiring a permit from the underlying channel. This is used to ensure |
51 | // there's capacity for a send to complete. |
52 | // |
53 | // By reusing the same async fn for both `Some` and `None`, we make sure every future passed to |
54 | // ReusableBoxFuture has the same underlying type, and hence the same size and alignment. |
55 | async fn make_acquire_future<T>( |
56 | data: Option<Sender<T>>, |
57 | ) -> Result<OwnedPermit<T>, PollSendError<T>> { |
58 | match data { |
59 | Some(sender) => sender |
60 | .reserve_owned() |
61 | .await |
62 | .map_err(|_| PollSendError(None)), |
63 | None => unreachable!("this future should not be pollable in this state" ), |
64 | } |
65 | } |
66 | |
67 | type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>; |
68 | |
69 | #[derive(Debug)] |
70 | // TODO: This should be replace with a type_alias_impl_trait to eliminate `'static` and all the transmutes |
71 | struct PollSenderFuture<T>(InnerFuture<'static, T>); |
72 | |
73 | impl<T> PollSenderFuture<T> { |
74 | /// Create with an empty inner future with no `Send` bound. |
75 | fn empty() -> Self { |
76 | // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not |
77 | // compatible with the transitive bounds required by `Sender<T>`. |
78 | Self(ReusableBoxFuture::new(async { unreachable!() })) |
79 | } |
80 | } |
81 | |
impl<T: Send> PollSenderFuture<T> {
    /// Create with an empty inner future.
    ///
    /// The placeholder is `make_acquire_future(None)`, which panics if polled. Building it from
    /// the same async fn that `set` uses keeps the boxed future's type (and hence its size and
    /// alignment) identical to the futures stored later.
    fn new() -> Self {
        let v = InnerFuture::new(make_acquire_future(None));
        // SAFETY: `make_acquire_future(None)` is given no `Sender<T>`, so the future holds no
        // data whose validity depends on `T`; it is actually `'static`, and extending its
        // type-level lifetime with this transmute is sound.
        Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) })
    }

    /// Poll the inner future.
    ///
    /// Resolves to the reserved permit, or to `PollSendError(None)` if the channel is closed.
    /// Only meaningful after `set(Some(..))`; the placeholder futures panic when polled.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> {
        self.0.poll(cx)
    }

    /// Replace the inner future with `make_acquire_future(sender)`, dropping the previous one.
    ///
    /// Passing `None` installs a placeholder that must not be polled; callers use this to
    /// cancel an in-flight acquisition.
    fn set(&mut self, sender: Option<Sender<T>>) {
        let inner: *mut InnerFuture<'static, T> = &mut self.0;
        // Cast to a pointer whose lifetime parameter is inferred, so that `inner.set` accepts a
        // future whose type is only valid while `T` is (justified by the SAFETY note below).
        let inner: *mut InnerFuture<'_, T> = inner.cast();
        // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T`
        // becomes invalid, and this casts away the type-level lifetime check for that. However, the
        // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not
        // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed
        // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so
        // this is ok.
        let inner = unsafe { &mut *inner };
        inner.set(make_acquire_future(sender));
    }
}
109 | |
110 | impl<T: Send> PollSender<T> { |
111 | /// Creates a new `PollSender`. |
112 | pub fn new(sender: Sender<T>) -> Self { |
113 | Self { |
114 | sender: Some(sender.clone()), |
115 | state: State::Idle(sender), |
116 | acquire: PollSenderFuture::new(), |
117 | } |
118 | } |
119 | |
120 | fn take_state(&mut self) -> State<T> { |
121 | mem::replace(&mut self.state, State::Closed) |
122 | } |
123 | |
124 | /// Attempts to prepare the sender to receive a value. |
125 | /// |
126 | /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to |
127 | /// `send_item`. |
128 | /// |
129 | /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, |
130 | /// by reserving a slot in the channel for the item to be sent. If this method returns |
131 | /// `Poll::Pending`, the current task is registered to be notified (via |
132 | /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. |
133 | /// |
134 | /// # Errors |
135 | /// |
136 | /// If the channel is closed, an error will be returned. This is a permanent state. |
137 | pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> { |
138 | loop { |
139 | let (result, next_state) = match self.take_state() { |
140 | State::Idle(sender) => { |
141 | // Start trying to acquire a permit to reserve a slot for our send, and |
142 | // immediately loop back around to poll it the first time. |
143 | self.acquire.set(Some(sender)); |
144 | (None, State::Acquiring) |
145 | } |
146 | State::Acquiring => match self.acquire.poll(cx) { |
147 | // Channel has capacity. |
148 | Poll::Ready(Ok(permit)) => { |
149 | (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) |
150 | } |
151 | // Channel is closed. |
152 | Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), |
153 | // Channel doesn't have capacity yet, so we need to wait. |
154 | Poll::Pending => (Some(Poll::Pending), State::Acquiring), |
155 | }, |
156 | // We're closed, either by choice or because the underlying sender was closed. |
157 | s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), |
158 | // We're already ready to send an item. |
159 | s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), |
160 | }; |
161 | |
162 | self.state = next_state; |
163 | if let Some(result) = result { |
164 | return result; |
165 | } |
166 | } |
167 | } |
168 | |
169 | /// Sends an item to the channel. |
170 | /// |
171 | /// Before calling `send_item`, `poll_reserve` must be called with a successful return |
172 | /// value of `Poll::Ready(Ok(()))`. |
173 | /// |
174 | /// # Errors |
175 | /// |
176 | /// If the channel is closed, an error will be returned. This is a permanent state. |
177 | /// |
178 | /// # Panics |
179 | /// |
180 | /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method |
181 | /// will panic. |
182 | #[track_caller ] |
183 | pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> { |
184 | let (result, next_state) = match self.take_state() { |
185 | State::Idle(_) | State::Acquiring => { |
186 | panic!("`send_item` called without first calling `poll_reserve`" ) |
187 | } |
188 | // We have a permit to send our item, so go ahead, which gets us our sender back. |
189 | State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), |
190 | // We're closed, either by choice or because the underlying sender was closed. |
191 | State::Closed => (Err(PollSendError(Some(value))), State::Closed), |
192 | }; |
193 | |
194 | // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. |
195 | self.state = if self.sender.is_some() { |
196 | next_state |
197 | } else { |
198 | State::Closed |
199 | }; |
200 | result |
201 | } |
202 | |
203 | /// Checks whether this sender is been closed. |
204 | /// |
205 | /// The underlying channel that this sender was wrapping may still be open. |
206 | pub fn is_closed(&self) -> bool { |
207 | matches!(self.state, State::Closed) || self.sender.is_none() |
208 | } |
209 | |
210 | /// Gets a reference to the `Sender` of the underlying channel. |
211 | /// |
212 | /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender |
213 | /// was wrapping may still be open. |
214 | pub fn get_ref(&self) -> Option<&Sender<T>> { |
215 | self.sender.as_ref() |
216 | } |
217 | |
218 | /// Closes this sender. |
219 | /// |
220 | /// No more messages will be able to be sent from this sender, but the underlying channel will |
221 | /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. |
222 | /// |
223 | /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made |
224 | /// to `send_item` in order to consume the reserved slot. After that, no further sends will be |
225 | /// possible. If you do not intend to send another item, you can release the reserved slot back |
226 | /// to the underlying sender by calling [`abort_send`]. |
227 | /// |
228 | /// [`abort_send`]: crate::sync::PollSender::abort_send |
229 | /// [`Receiver`]: tokio::sync::mpsc::Receiver |
230 | pub fn close(&mut self) { |
231 | // Mark ourselves officially closed by dropping our main sender. |
232 | self.sender = None; |
233 | |
234 | // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly |
235 | // transition to the closed state. Otherwise, leave the existing permit in place for the |
236 | // caller if they want to complete the send. |
237 | match self.state { |
238 | State::Idle(_) => self.state = State::Closed, |
239 | State::Acquiring => { |
240 | self.acquire.set(None); |
241 | self.state = State::Closed; |
242 | } |
243 | _ => {} |
244 | } |
245 | } |
246 | |
247 | /// Aborts the current in-progress send, if any. |
248 | /// |
249 | /// Returns `true` if a send was aborted. If the sender was closed prior to calling |
250 | /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be |
251 | /// ready to attempt another send. |
252 | pub fn abort_send(&mut self) -> bool { |
253 | // We may have been closed in the meantime, after a call to `poll_reserve` already |
254 | // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the |
255 | // closed state when we actually abort a send, rather than resetting ourselves back to idle. |
256 | |
257 | let (result, next_state) = match self.take_state() { |
258 | // We're currently trying to reserve a slot to send into. |
259 | State::Acquiring => { |
260 | // Replacing the future drops the in-flight one. |
261 | self.acquire.set(None); |
262 | |
263 | // If we haven't closed yet, we have to clone our stored sender since we have no way |
264 | // to get it back from the acquire future we just dropped. |
265 | let state = match self.sender.clone() { |
266 | Some(sender) => State::Idle(sender), |
267 | None => State::Closed, |
268 | }; |
269 | (true, state) |
270 | } |
271 | // We got the permit. If we haven't closed yet, get the sender back. |
272 | State::ReadyToSend(permit) => { |
273 | let state = if self.sender.is_some() { |
274 | State::Idle(permit.release()) |
275 | } else { |
276 | State::Closed |
277 | }; |
278 | (true, state) |
279 | } |
280 | s => (false, s), |
281 | }; |
282 | |
283 | self.state = next_state; |
284 | result |
285 | } |
286 | } |
287 | |
288 | impl<T> Clone for PollSender<T> { |
289 | /// Clones this `PollSender`. |
290 | /// |
291 | /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. |
292 | fn clone(&self) -> PollSender<T> { |
293 | let (sender, state) = match self.sender.clone() { |
294 | Some(sender) => (Some(sender.clone()), State::Idle(sender)), |
295 | None => (None, State::Closed), |
296 | }; |
297 | |
298 | Self { |
299 | sender, |
300 | state, |
301 | acquire: PollSenderFuture::empty(), |
302 | } |
303 | } |
304 | } |
305 | |
306 | impl<T: Send + 'static> Sink<T> for PollSender<T> { |
307 | type Error = PollSendError<T>; |
308 | |
309 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
310 | Pin::into_inner(self).poll_reserve(cx) |
311 | } |
312 | |
313 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
314 | Poll::Ready(Ok(())) |
315 | } |
316 | |
317 | fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { |
318 | Pin::into_inner(self).send_item(item) |
319 | } |
320 | |
321 | fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
322 | Pin::into_inner(self).close(); |
323 | Poll::Ready(Ok(())) |
324 | } |
325 | } |
326 | |