1 | use futures_sink::Sink; |
2 | use std::pin::Pin; |
3 | use std::task::{Context, Poll}; |
4 | use std::{fmt, mem}; |
5 | use tokio::sync::mpsc::OwnedPermit; |
6 | use tokio::sync::mpsc::Sender; |
7 | |
8 | use super::ReusableBoxFuture; |
9 | |
/// Error produced by `PollSender` when it cannot send because the channel is closed.
///
/// The wrapped `Option<T>` holds the item that could not be sent when there is one; call
/// `into_inner` to get it back.
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);
13 | |
14 | impl<T> PollSendError<T> { |
15 | /// Consumes the stored value, if any. |
16 | /// |
17 | /// If this error was encountered when calling `start_send`/`send_item`, this will be the item |
18 | /// that the caller attempted to send. Otherwise, it will be `None`. |
19 | pub fn into_inner(self) -> Option<T> { |
20 | self.0 |
21 | } |
22 | } |
23 | |
24 | impl<T> fmt::Display for PollSendError<T> { |
25 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
26 | write!(fmt, "channel closed" ) |
27 | } |
28 | } |
29 | |
30 | impl<T: fmt::Debug> std::error::Error for PollSendError<T> {} |
31 | |
32 | #[derive (Debug)] |
33 | enum State<T> { |
34 | Idle(Sender<T>), |
35 | Acquiring, |
36 | ReadyToSend(OwnedPermit<T>), |
37 | Closed, |
38 | } |
39 | |
40 | /// A wrapper around [`mpsc::Sender`] that can be polled. |
41 | /// |
42 | /// [`mpsc::Sender`]: tokio::sync::mpsc::Sender |
43 | #[derive (Debug)] |
44 | pub struct PollSender<T> { |
45 | sender: Option<Sender<T>>, |
46 | state: State<T>, |
47 | acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>, |
48 | } |
49 | |
50 | // Creates a future for acquiring a permit from the underlying channel. This is used to ensure |
51 | // there's capacity for a send to complete. |
52 | // |
53 | // By reusing the same async fn for both `Some` and `None`, we make sure every future passed to |
54 | // ReusableBoxFuture has the same underlying type, and hence the same size and alignment. |
55 | async fn make_acquire_future<T>( |
56 | data: Option<Sender<T>>, |
57 | ) -> Result<OwnedPermit<T>, PollSendError<T>> { |
58 | match data { |
59 | Some(sender: Sender) => sender |
60 | .reserve_owned() |
61 | .await |
62 | .map_err(|_| PollSendError(None)), |
63 | None => unreachable!("this future should not be pollable in this state" ), |
64 | } |
65 | } |
66 | |
67 | impl<T: Send + 'static> PollSender<T> { |
68 | /// Creates a new `PollSender`. |
69 | pub fn new(sender: Sender<T>) -> Self { |
70 | Self { |
71 | sender: Some(sender.clone()), |
72 | state: State::Idle(sender), |
73 | acquire: ReusableBoxFuture::new(make_acquire_future(None)), |
74 | } |
75 | } |
76 | |
77 | fn take_state(&mut self) -> State<T> { |
78 | mem::replace(&mut self.state, State::Closed) |
79 | } |
80 | |
81 | /// Attempts to prepare the sender to receive a value. |
82 | /// |
83 | /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to |
84 | /// `send_item`. |
85 | /// |
86 | /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, |
87 | /// by reserving a slot in the channel for the item to be sent. If this method returns |
88 | /// `Poll::Pending`, the current task is registered to be notified (via |
89 | /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. |
90 | /// |
91 | /// # Errors |
92 | /// |
93 | /// If the channel is closed, an error will be returned. This is a permanent state. |
94 | pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> { |
95 | loop { |
96 | let (result, next_state) = match self.take_state() { |
97 | State::Idle(sender) => { |
98 | // Start trying to acquire a permit to reserve a slot for our send, and |
99 | // immediately loop back around to poll it the first time. |
100 | self.acquire.set(make_acquire_future(Some(sender))); |
101 | (None, State::Acquiring) |
102 | } |
103 | State::Acquiring => match self.acquire.poll(cx) { |
104 | // Channel has capacity. |
105 | Poll::Ready(Ok(permit)) => { |
106 | (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) |
107 | } |
108 | // Channel is closed. |
109 | Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), |
110 | // Channel doesn't have capacity yet, so we need to wait. |
111 | Poll::Pending => (Some(Poll::Pending), State::Acquiring), |
112 | }, |
113 | // We're closed, either by choice or because the underlying sender was closed. |
114 | s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), |
115 | // We're already ready to send an item. |
116 | s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), |
117 | }; |
118 | |
119 | self.state = next_state; |
120 | if let Some(result) = result { |
121 | return result; |
122 | } |
123 | } |
124 | } |
125 | |
126 | /// Sends an item to the channel. |
127 | /// |
128 | /// Before calling `send_item`, `poll_reserve` must be called with a successful return |
129 | /// value of `Poll::Ready(Ok(()))`. |
130 | /// |
131 | /// # Errors |
132 | /// |
133 | /// If the channel is closed, an error will be returned. This is a permanent state. |
134 | /// |
135 | /// # Panics |
136 | /// |
137 | /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method |
138 | /// will panic. |
139 | pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> { |
140 | let (result, next_state) = match self.take_state() { |
141 | State::Idle(_) | State::Acquiring => { |
142 | panic!("`send_item` called without first calling `poll_reserve`" ) |
143 | } |
144 | // We have a permit to send our item, so go ahead, which gets us our sender back. |
145 | State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), |
146 | // We're closed, either by choice or because the underlying sender was closed. |
147 | State::Closed => (Err(PollSendError(Some(value))), State::Closed), |
148 | }; |
149 | |
150 | // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. |
151 | self.state = if self.sender.is_some() { |
152 | next_state |
153 | } else { |
154 | State::Closed |
155 | }; |
156 | result |
157 | } |
158 | |
159 | /// Checks whether this sender is been closed. |
160 | /// |
161 | /// The underlying channel that this sender was wrapping may still be open. |
162 | pub fn is_closed(&self) -> bool { |
163 | matches!(self.state, State::Closed) || self.sender.is_none() |
164 | } |
165 | |
166 | /// Gets a reference to the `Sender` of the underlying channel. |
167 | /// |
168 | /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender |
169 | /// was wrapping may still be open. |
170 | pub fn get_ref(&self) -> Option<&Sender<T>> { |
171 | self.sender.as_ref() |
172 | } |
173 | |
174 | /// Closes this sender. |
175 | /// |
176 | /// No more messages will be able to be sent from this sender, but the underlying channel will |
177 | /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. |
178 | /// |
179 | /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made |
180 | /// to `send_item` in order to consume the reserved slot. After that, no further sends will be |
181 | /// possible. If you do not intend to send another item, you can release the reserved slot back |
182 | /// to the underlying sender by calling [`abort_send`]. |
183 | /// |
184 | /// [`abort_send`]: crate::sync::PollSender::abort_send |
185 | /// [`Receiver`]: tokio::sync::mpsc::Receiver |
186 | pub fn close(&mut self) { |
187 | // Mark ourselves officially closed by dropping our main sender. |
188 | self.sender = None; |
189 | |
190 | // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly |
191 | // transition to the closed state. Otherwise, leave the existing permit in place for the |
192 | // caller if they want to complete the send. |
193 | match self.state { |
194 | State::Idle(_) => self.state = State::Closed, |
195 | State::Acquiring => { |
196 | self.acquire.set(make_acquire_future(None)); |
197 | self.state = State::Closed; |
198 | } |
199 | _ => {} |
200 | } |
201 | } |
202 | |
203 | /// Aborts the current in-progress send, if any. |
204 | /// |
205 | /// Returns `true` if a send was aborted. If the sender was closed prior to calling |
206 | /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be |
207 | /// ready to attempt another send. |
208 | pub fn abort_send(&mut self) -> bool { |
209 | // We may have been closed in the meantime, after a call to `poll_reserve` already |
210 | // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the |
211 | // closed state when we actually abort a send, rather than resetting ourselves back to idle. |
212 | |
213 | let (result, next_state) = match self.take_state() { |
214 | // We're currently trying to reserve a slot to send into. |
215 | State::Acquiring => { |
216 | // Replacing the future drops the in-flight one. |
217 | self.acquire.set(make_acquire_future(None)); |
218 | |
219 | // If we haven't closed yet, we have to clone our stored sender since we have no way |
220 | // to get it back from the acquire future we just dropped. |
221 | let state = match self.sender.clone() { |
222 | Some(sender) => State::Idle(sender), |
223 | None => State::Closed, |
224 | }; |
225 | (true, state) |
226 | } |
227 | // We got the permit. If we haven't closed yet, get the sender back. |
228 | State::ReadyToSend(permit) => { |
229 | let state = if self.sender.is_some() { |
230 | State::Idle(permit.release()) |
231 | } else { |
232 | State::Closed |
233 | }; |
234 | (true, state) |
235 | } |
236 | s => (false, s), |
237 | }; |
238 | |
239 | self.state = next_state; |
240 | result |
241 | } |
242 | } |
243 | |
244 | impl<T> Clone for PollSender<T> { |
245 | /// Clones this `PollSender`. |
246 | /// |
247 | /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. |
248 | fn clone(&self) -> PollSender<T> { |
249 | let (sender: Option>, state: State) = match self.sender.clone() { |
250 | Some(sender: Sender) => (Some(sender.clone()), State::Idle(sender)), |
251 | None => (None, State::Closed), |
252 | }; |
253 | |
254 | Self { |
255 | sender, |
256 | state, |
257 | // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not |
258 | // compatible with the transitive bounds required by `Sender<T>`. |
259 | acquire: ReusableBoxFuture::new(future:async { unreachable!() }), |
260 | } |
261 | } |
262 | } |
263 | |
264 | impl<T: Send + 'static> Sink<T> for PollSender<T> { |
265 | type Error = PollSendError<T>; |
266 | |
267 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
268 | Pin::into_inner(self).poll_reserve(cx) |
269 | } |
270 | |
271 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
272 | Poll::Ready(Ok(())) |
273 | } |
274 | |
275 | fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { |
276 | Pin::into_inner(self).send_item(item) |
277 | } |
278 | |
279 | fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
280 | Pin::into_inner(self).close(); |
281 | Poll::Ready(Ok(())) |
282 | } |
283 | } |
284 | |