1use futures_sink::Sink;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use std::{fmt, mem};
5use tokio::sync::mpsc::OwnedPermit;
6use tokio::sync::mpsc::Sender;
7
8use super::ReusableBoxFuture;
9
/// Error returned by the `PollSender` when the channel is closed.
///
/// The error may carry the item that could not be delivered; recover it with
/// [`PollSendError::into_inner`].
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Consumes the error, yielding the stored value if there is one.
    ///
    /// When the error came from `start_send`/`send_item`, this is the item the caller tried to
    /// send. Errors produced while reserving capacity carry no item, so this returns `None`.
    pub fn into_inner(self) -> Option<T> {
        let PollSendError(item) = self;
        item
    }
}

impl<T> fmt::Display for PollSendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
31
32#[derive(Debug)]
33enum State<T> {
34 Idle(Sender<T>),
35 Acquiring,
36 ReadyToSend(OwnedPermit<T>),
37 Closed,
38}
39
40/// A wrapper around [`mpsc::Sender`] that can be polled.
41///
42/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
43#[derive(Debug)]
44pub struct PollSender<T> {
45 sender: Option<Sender<T>>,
46 state: State<T>,
47 acquire: ReusableBoxFuture<'static, Result<OwnedPermit<T>, PollSendError<T>>>,
48}
49
50// Creates a future for acquiring a permit from the underlying channel. This is used to ensure
51// there's capacity for a send to complete.
52//
53// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to
54// ReusableBoxFuture has the same underlying type, and hence the same size and alignment.
55async fn make_acquire_future<T>(
56 data: Option<Sender<T>>,
57) -> Result<OwnedPermit<T>, PollSendError<T>> {
58 match data {
59 Some(sender: Sender) => sender
60 .reserve_owned()
61 .await
62 .map_err(|_| PollSendError(None)),
63 None => unreachable!("this future should not be pollable in this state"),
64 }
65}
66
67impl<T: Send + 'static> PollSender<T> {
68 /// Creates a new `PollSender`.
69 pub fn new(sender: Sender<T>) -> Self {
70 Self {
71 sender: Some(sender.clone()),
72 state: State::Idle(sender),
73 acquire: ReusableBoxFuture::new(make_acquire_future(None)),
74 }
75 }
76
77 fn take_state(&mut self) -> State<T> {
78 mem::replace(&mut self.state, State::Closed)
79 }
80
81 /// Attempts to prepare the sender to receive a value.
82 ///
83 /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to
84 /// `send_item`.
85 ///
86 /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value,
87 /// by reserving a slot in the channel for the item to be sent. If this method returns
88 /// `Poll::Pending`, the current task is registered to be notified (via
89 /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again.
90 ///
91 /// # Errors
92 ///
93 /// If the channel is closed, an error will be returned. This is a permanent state.
94 pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
95 loop {
96 let (result, next_state) = match self.take_state() {
97 State::Idle(sender) => {
98 // Start trying to acquire a permit to reserve a slot for our send, and
99 // immediately loop back around to poll it the first time.
100 self.acquire.set(make_acquire_future(Some(sender)));
101 (None, State::Acquiring)
102 }
103 State::Acquiring => match self.acquire.poll(cx) {
104 // Channel has capacity.
105 Poll::Ready(Ok(permit)) => {
106 (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit))
107 }
108 // Channel is closed.
109 Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed),
110 // Channel doesn't have capacity yet, so we need to wait.
111 Poll::Pending => (Some(Poll::Pending), State::Acquiring),
112 },
113 // We're closed, either by choice or because the underlying sender was closed.
114 s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s),
115 // We're already ready to send an item.
116 s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
117 };
118
119 self.state = next_state;
120 if let Some(result) = result {
121 return result;
122 }
123 }
124 }
125
126 /// Sends an item to the channel.
127 ///
128 /// Before calling `send_item`, `poll_reserve` must be called with a successful return
129 /// value of `Poll::Ready(Ok(()))`.
130 ///
131 /// # Errors
132 ///
133 /// If the channel is closed, an error will be returned. This is a permanent state.
134 ///
135 /// # Panics
136 ///
137 /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
138 /// will panic.
139 pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
140 let (result, next_state) = match self.take_state() {
141 State::Idle(_) | State::Acquiring => {
142 panic!("`send_item` called without first calling `poll_reserve`")
143 }
144 // We have a permit to send our item, so go ahead, which gets us our sender back.
145 State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))),
146 // We're closed, either by choice or because the underlying sender was closed.
147 State::Closed => (Err(PollSendError(Some(value))), State::Closed),
148 };
149
150 // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
151 self.state = if self.sender.is_some() {
152 next_state
153 } else {
154 State::Closed
155 };
156 result
157 }
158
159 /// Checks whether this sender is been closed.
160 ///
161 /// The underlying channel that this sender was wrapping may still be open.
162 pub fn is_closed(&self) -> bool {
163 matches!(self.state, State::Closed) || self.sender.is_none()
164 }
165
166 /// Gets a reference to the `Sender` of the underlying channel.
167 ///
168 /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender
169 /// was wrapping may still be open.
170 pub fn get_ref(&self) -> Option<&Sender<T>> {
171 self.sender.as_ref()
172 }
173
174 /// Closes this sender.
175 ///
176 /// No more messages will be able to be sent from this sender, but the underlying channel will
177 /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
178 ///
179 /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made
180 /// to `send_item` in order to consume the reserved slot. After that, no further sends will be
181 /// possible. If you do not intend to send another item, you can release the reserved slot back
182 /// to the underlying sender by calling [`abort_send`].
183 ///
184 /// [`abort_send`]: crate::sync::PollSender::abort_send
185 /// [`Receiver`]: tokio::sync::mpsc::Receiver
186 pub fn close(&mut self) {
187 // Mark ourselves officially closed by dropping our main sender.
188 self.sender = None;
189
190 // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly
191 // transition to the closed state. Otherwise, leave the existing permit in place for the
192 // caller if they want to complete the send.
193 match self.state {
194 State::Idle(_) => self.state = State::Closed,
195 State::Acquiring => {
196 self.acquire.set(make_acquire_future(None));
197 self.state = State::Closed;
198 }
199 _ => {}
200 }
201 }
202
203 /// Aborts the current in-progress send, if any.
204 ///
205 /// Returns `true` if a send was aborted. If the sender was closed prior to calling
206 /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be
207 /// ready to attempt another send.
208 pub fn abort_send(&mut self) -> bool {
209 // We may have been closed in the meantime, after a call to `poll_reserve` already
210 // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the
211 // closed state when we actually abort a send, rather than resetting ourselves back to idle.
212
213 let (result, next_state) = match self.take_state() {
214 // We're currently trying to reserve a slot to send into.
215 State::Acquiring => {
216 // Replacing the future drops the in-flight one.
217 self.acquire.set(make_acquire_future(None));
218
219 // If we haven't closed yet, we have to clone our stored sender since we have no way
220 // to get it back from the acquire future we just dropped.
221 let state = match self.sender.clone() {
222 Some(sender) => State::Idle(sender),
223 None => State::Closed,
224 };
225 (true, state)
226 }
227 // We got the permit. If we haven't closed yet, get the sender back.
228 State::ReadyToSend(permit) => {
229 let state = if self.sender.is_some() {
230 State::Idle(permit.release())
231 } else {
232 State::Closed
233 };
234 (true, state)
235 }
236 s => (false, s),
237 };
238
239 self.state = next_state;
240 result
241 }
242}
243
244impl<T> Clone for PollSender<T> {
245 /// Clones this `PollSender`.
246 ///
247 /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`.
248 fn clone(&self) -> PollSender<T> {
249 let (sender: Option>, state: State) = match self.sender.clone() {
250 Some(sender: Sender) => (Some(sender.clone()), State::Idle(sender)),
251 None => (None, State::Closed),
252 };
253
254 Self {
255 sender,
256 state,
257 // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
258 // compatible with the transitive bounds required by `Sender<T>`.
259 acquire: ReusableBoxFuture::new(future:async { unreachable!() }),
260 }
261 }
262}
263
264impl<T: Send + 'static> Sink<T> for PollSender<T> {
265 type Error = PollSendError<T>;
266
267 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
268 Pin::into_inner(self).poll_reserve(cx)
269 }
270
271 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
272 Poll::Ready(Ok(()))
273 }
274
275 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
276 Pin::into_inner(self).send_item(item)
277 }
278
279 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
280 Pin::into_inner(self).close();
281 Poll::Ready(Ok(()))
282 }
283}
284