| 1 | use futures_sink::Sink; |
| 2 | use std::pin::Pin; |
| 3 | use std::task::{Context, Poll}; |
| 4 | use std::{fmt, mem}; |
| 5 | use tokio::sync::mpsc::OwnedPermit; |
| 6 | use tokio::sync::mpsc::Sender; |
| 7 | |
| 8 | use super::ReusableBoxFuture; |
| 9 | |
/// Error returned by the `PollSender` when the channel is closed.
///
/// The error may carry the item that could not be sent; recover it with
/// [`PollSendError::into_inner`].
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Consumes the error and hands back the stored value, if there is one.
    ///
    /// When the error came from `start_send`/`send_item`, the value is the item the caller tried
    /// to send. Errors produced anywhere else carry no item, so `None` is returned for them.
    pub fn into_inner(self) -> Option<T> {
        let Self(item) = self;
        item
    }
}

impl<T> fmt::Display for PollSendError<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message is the same regardless of whether an item is attached.
        fmt.write_str("channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
| 31 | |
| 32 | #[derive (Debug)] |
| 33 | enum State<T> { |
| 34 | Idle(Sender<T>), |
| 35 | Acquiring, |
| 36 | ReadyToSend(OwnedPermit<T>), |
| 37 | Closed, |
| 38 | } |
| 39 | |
| 40 | /// A wrapper around [`mpsc::Sender`] that can be polled. |
| 41 | /// |
| 42 | /// [`mpsc::Sender`]: tokio::sync::mpsc::Sender |
| 43 | #[derive (Debug)] |
| 44 | pub struct PollSender<T> { |
| 45 | sender: Option<Sender<T>>, |
| 46 | state: State<T>, |
| 47 | acquire: PollSenderFuture<T>, |
| 48 | } |
| 49 | |
| 50 | // Creates a future for acquiring a permit from the underlying channel. This is used to ensure |
| 51 | // there's capacity for a send to complete. |
| 52 | // |
| 53 | // By reusing the same async fn for both `Some` and `None`, we make sure every future passed to |
| 54 | // ReusableBoxFuture has the same underlying type, and hence the same size and alignment. |
| 55 | async fn make_acquire_future<T>( |
| 56 | data: Option<Sender<T>>, |
| 57 | ) -> Result<OwnedPermit<T>, PollSendError<T>> { |
| 58 | match data { |
| 59 | Some(sender: Sender) => sender |
| 60 | .reserve_owned() |
| 61 | .await |
| 62 | .map_err(|_| PollSendError(None)), |
| 63 | None => unreachable!("this future should not be pollable in this state" ), |
| 64 | } |
| 65 | } |
| 66 | |
| 67 | type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>; |
| 68 | |
| 69 | #[derive (Debug)] |
| 70 | // TODO: This should be replace with a type_alias_impl_trait to eliminate `'static` and all the transmutes |
| 71 | struct PollSenderFuture<T>(InnerFuture<'static, T>); |
| 72 | |
| 73 | impl<T> PollSenderFuture<T> { |
| 74 | /// Create with an empty inner future with no `Send` bound. |
| 75 | fn empty() -> Self { |
| 76 | // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not |
| 77 | // compatible with the transitive bounds required by `Sender<T>`. |
| 78 | Self(ReusableBoxFuture::new(future:async { unreachable!() })) |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | impl<T: Send> PollSenderFuture<T> { |
| 83 | /// Create with an empty inner future. |
| 84 | fn new() -> Self { |
| 85 | let v = InnerFuture::new(make_acquire_future(None)); |
| 86 | // This is safe because `make_acquire_future(None)` is actually `'static` |
| 87 | Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) }) |
| 88 | } |
| 89 | |
| 90 | /// Poll the inner future. |
| 91 | fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> { |
| 92 | self.0.poll(cx) |
| 93 | } |
| 94 | |
| 95 | /// Replace the inner future. |
| 96 | fn set(&mut self, sender: Option<Sender<T>>) { |
| 97 | let inner: *mut InnerFuture<'static, T> = &mut self.0; |
| 98 | let inner: *mut InnerFuture<'_, T> = inner.cast(); |
| 99 | // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T` |
| 100 | // becomes invalid, and this casts away the type-level lifetime check for that. However, the |
| 101 | // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not |
| 102 | // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed |
| 103 | // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so |
| 104 | // this is ok. |
| 105 | let inner = unsafe { &mut *inner }; |
| 106 | inner.set(make_acquire_future(sender)); |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | impl<T: Send> PollSender<T> { |
| 111 | /// Creates a new `PollSender`. |
| 112 | pub fn new(sender: Sender<T>) -> Self { |
| 113 | Self { |
| 114 | sender: Some(sender.clone()), |
| 115 | state: State::Idle(sender), |
| 116 | acquire: PollSenderFuture::new(), |
| 117 | } |
| 118 | } |
| 119 | |
| 120 | fn take_state(&mut self) -> State<T> { |
| 121 | mem::replace(&mut self.state, State::Closed) |
| 122 | } |
| 123 | |
| 124 | /// Attempts to prepare the sender to receive a value. |
| 125 | /// |
| 126 | /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to |
| 127 | /// `send_item`. |
| 128 | /// |
| 129 | /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value, |
| 130 | /// by reserving a slot in the channel for the item to be sent. If this method returns |
| 131 | /// `Poll::Pending`, the current task is registered to be notified (via |
| 132 | /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again. |
| 133 | /// |
| 134 | /// # Errors |
| 135 | /// |
| 136 | /// If the channel is closed, an error will be returned. This is a permanent state. |
| 137 | pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> { |
| 138 | loop { |
| 139 | let (result, next_state) = match self.take_state() { |
| 140 | State::Idle(sender) => { |
| 141 | // Start trying to acquire a permit to reserve a slot for our send, and |
| 142 | // immediately loop back around to poll it the first time. |
| 143 | self.acquire.set(Some(sender)); |
| 144 | (None, State::Acquiring) |
| 145 | } |
| 146 | State::Acquiring => match self.acquire.poll(cx) { |
| 147 | // Channel has capacity. |
| 148 | Poll::Ready(Ok(permit)) => { |
| 149 | (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit)) |
| 150 | } |
| 151 | // Channel is closed. |
| 152 | Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed), |
| 153 | // Channel doesn't have capacity yet, so we need to wait. |
| 154 | Poll::Pending => (Some(Poll::Pending), State::Acquiring), |
| 155 | }, |
| 156 | // We're closed, either by choice or because the underlying sender was closed. |
| 157 | s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s), |
| 158 | // We're already ready to send an item. |
| 159 | s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s), |
| 160 | }; |
| 161 | |
| 162 | self.state = next_state; |
| 163 | if let Some(result) = result { |
| 164 | return result; |
| 165 | } |
| 166 | } |
| 167 | } |
| 168 | |
| 169 | /// Sends an item to the channel. |
| 170 | /// |
| 171 | /// Before calling `send_item`, `poll_reserve` must be called with a successful return |
| 172 | /// value of `Poll::Ready(Ok(()))`. |
| 173 | /// |
| 174 | /// # Errors |
| 175 | /// |
| 176 | /// If the channel is closed, an error will be returned. This is a permanent state. |
| 177 | /// |
| 178 | /// # Panics |
| 179 | /// |
| 180 | /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method |
| 181 | /// will panic. |
| 182 | #[track_caller ] |
| 183 | pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> { |
| 184 | let (result, next_state) = match self.take_state() { |
| 185 | State::Idle(_) | State::Acquiring => { |
| 186 | panic!("`send_item` called without first calling `poll_reserve`" ) |
| 187 | } |
| 188 | // We have a permit to send our item, so go ahead, which gets us our sender back. |
| 189 | State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))), |
| 190 | // We're closed, either by choice or because the underlying sender was closed. |
| 191 | State::Closed => (Err(PollSendError(Some(value))), State::Closed), |
| 192 | }; |
| 193 | |
| 194 | // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`. |
| 195 | self.state = if self.sender.is_some() { |
| 196 | next_state |
| 197 | } else { |
| 198 | State::Closed |
| 199 | }; |
| 200 | result |
| 201 | } |
| 202 | |
| 203 | /// Checks whether this sender is been closed. |
| 204 | /// |
| 205 | /// The underlying channel that this sender was wrapping may still be open. |
| 206 | pub fn is_closed(&self) -> bool { |
| 207 | matches!(self.state, State::Closed) || self.sender.is_none() |
| 208 | } |
| 209 | |
| 210 | /// Gets a reference to the `Sender` of the underlying channel. |
| 211 | /// |
| 212 | /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender |
| 213 | /// was wrapping may still be open. |
| 214 | pub fn get_ref(&self) -> Option<&Sender<T>> { |
| 215 | self.sender.as_ref() |
| 216 | } |
| 217 | |
| 218 | /// Closes this sender. |
| 219 | /// |
| 220 | /// No more messages will be able to be sent from this sender, but the underlying channel will |
| 221 | /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel. |
| 222 | /// |
| 223 | /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made |
| 224 | /// to `send_item` in order to consume the reserved slot. After that, no further sends will be |
| 225 | /// possible. If you do not intend to send another item, you can release the reserved slot back |
| 226 | /// to the underlying sender by calling [`abort_send`]. |
| 227 | /// |
| 228 | /// [`abort_send`]: crate::sync::PollSender::abort_send |
| 229 | /// [`Receiver`]: tokio::sync::mpsc::Receiver |
| 230 | pub fn close(&mut self) { |
| 231 | // Mark ourselves officially closed by dropping our main sender. |
| 232 | self.sender = None; |
| 233 | |
| 234 | // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly |
| 235 | // transition to the closed state. Otherwise, leave the existing permit in place for the |
| 236 | // caller if they want to complete the send. |
| 237 | match self.state { |
| 238 | State::Idle(_) => self.state = State::Closed, |
| 239 | State::Acquiring => { |
| 240 | self.acquire.set(None); |
| 241 | self.state = State::Closed; |
| 242 | } |
| 243 | _ => {} |
| 244 | } |
| 245 | } |
| 246 | |
| 247 | /// Aborts the current in-progress send, if any. |
| 248 | /// |
| 249 | /// Returns `true` if a send was aborted. If the sender was closed prior to calling |
| 250 | /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be |
| 251 | /// ready to attempt another send. |
| 252 | pub fn abort_send(&mut self) -> bool { |
| 253 | // We may have been closed in the meantime, after a call to `poll_reserve` already |
| 254 | // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the |
| 255 | // closed state when we actually abort a send, rather than resetting ourselves back to idle. |
| 256 | |
| 257 | let (result, next_state) = match self.take_state() { |
| 258 | // We're currently trying to reserve a slot to send into. |
| 259 | State::Acquiring => { |
| 260 | // Replacing the future drops the in-flight one. |
| 261 | self.acquire.set(None); |
| 262 | |
| 263 | // If we haven't closed yet, we have to clone our stored sender since we have no way |
| 264 | // to get it back from the acquire future we just dropped. |
| 265 | let state = match self.sender.clone() { |
| 266 | Some(sender) => State::Idle(sender), |
| 267 | None => State::Closed, |
| 268 | }; |
| 269 | (true, state) |
| 270 | } |
| 271 | // We got the permit. If we haven't closed yet, get the sender back. |
| 272 | State::ReadyToSend(permit) => { |
| 273 | let state = if self.sender.is_some() { |
| 274 | State::Idle(permit.release()) |
| 275 | } else { |
| 276 | State::Closed |
| 277 | }; |
| 278 | (true, state) |
| 279 | } |
| 280 | s => (false, s), |
| 281 | }; |
| 282 | |
| 283 | self.state = next_state; |
| 284 | result |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | impl<T> Clone for PollSender<T> { |
| 289 | /// Clones this `PollSender`. |
| 290 | /// |
| 291 | /// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`. |
| 292 | fn clone(&self) -> PollSender<T> { |
| 293 | let (sender: Option>, state: State) = match self.sender.clone() { |
| 294 | Some(sender: Sender) => (Some(sender.clone()), State::Idle(sender)), |
| 295 | None => (None, State::Closed), |
| 296 | }; |
| 297 | |
| 298 | Self { |
| 299 | sender, |
| 300 | state, |
| 301 | acquire: PollSenderFuture::empty(), |
| 302 | } |
| 303 | } |
| 304 | } |
| 305 | |
| 306 | impl<T: Send> Sink<T> for PollSender<T> { |
| 307 | type Error = PollSendError<T>; |
| 308 | |
| 309 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 310 | Pin::into_inner(self).poll_reserve(cx) |
| 311 | } |
| 312 | |
| 313 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 314 | Poll::Ready(Ok(())) |
| 315 | } |
| 316 | |
| 317 | fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { |
| 318 | Pin::into_inner(self).send_item(item) |
| 319 | } |
| 320 | |
| 321 | fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 322 | Pin::into_inner(self).close(); |
| 323 | Poll::Ready(Ok(())) |
| 324 | } |
| 325 | } |
| 326 | |