| 1 | use crate::task::AtomicWaker; |
| 2 | use alloc::sync::Arc; |
| 3 | use core::fmt; |
| 4 | use core::pin::Pin; |
| 5 | use core::sync::atomic::{AtomicBool, Ordering}; |
| 6 | use futures_core::future::Future; |
| 7 | use futures_core::task::{Context, Poll}; |
| 8 | use futures_core::Stream; |
| 9 | use pin_project_lite::pin_project; |
| 10 | |
| 11 | pin_project! { |
| 12 | /// A future/stream which can be remotely short-circuited using an `AbortHandle`. |
| 13 | #[derive (Debug, Clone)] |
| 14 | #[must_use = "futures/streams do nothing unless you poll them" ] |
| 15 | pub struct Abortable<T> { |
| 16 | #[pin] |
| 17 | task: T, |
| 18 | inner: Arc<AbortInner>, |
| 19 | } |
| 20 | } |
| 21 | |
| 22 | impl<T> Abortable<T> { |
| 23 | /// Creates a new `Abortable` future/stream using an existing `AbortRegistration`. |
| 24 | /// `AbortRegistration`s can be acquired through `AbortHandle::new`. |
| 25 | /// |
| 26 | /// When `abort` is called on the handle tied to `reg` or if `abort` has |
| 27 | /// already been called, the future/stream will complete immediately without making |
| 28 | /// any further progress. |
| 29 | /// |
| 30 | /// # Examples: |
| 31 | /// |
| 32 | /// Usage with futures: |
| 33 | /// |
| 34 | /// ``` |
| 35 | /// # futures::executor::block_on(async { |
| 36 | /// use futures::future::{Abortable, AbortHandle, Aborted}; |
| 37 | /// |
| 38 | /// let (abort_handle, abort_registration) = AbortHandle::new_pair(); |
| 39 | /// let future = Abortable::new(async { 2 }, abort_registration); |
| 40 | /// abort_handle.abort(); |
| 41 | /// assert_eq!(future.await, Err(Aborted)); |
| 42 | /// # }); |
| 43 | /// ``` |
| 44 | /// |
| 45 | /// Usage with streams: |
| 46 | /// |
| 47 | /// ``` |
| 48 | /// # futures::executor::block_on(async { |
| 49 | /// # use futures::future::{Abortable, AbortHandle}; |
| 50 | /// # use futures::stream::{self, StreamExt}; |
| 51 | /// |
| 52 | /// let (abort_handle, abort_registration) = AbortHandle::new_pair(); |
| 53 | /// let mut stream = Abortable::new(stream::iter(vec![1, 2, 3]), abort_registration); |
| 54 | /// abort_handle.abort(); |
| 55 | /// assert_eq!(stream.next().await, None); |
| 56 | /// # }); |
| 57 | /// ``` |
| 58 | pub fn new(task: T, reg: AbortRegistration) -> Self { |
| 59 | Self { task, inner: reg.inner } |
| 60 | } |
| 61 | |
| 62 | /// Checks whether the task has been aborted. Note that all this |
| 63 | /// method indicates is whether [`AbortHandle::abort`] was *called*. |
| 64 | /// This means that it will return `true` even if: |
| 65 | /// * `abort` was called after the task had completed. |
| 66 | /// * `abort` was called while the task was being polled - the task may still be running and |
| 67 | /// will not be stopped until `poll` returns. |
| 68 | pub fn is_aborted(&self) -> bool { |
| 69 | self.inner.aborted.load(Ordering::Relaxed) |
| 70 | } |
| 71 | } |
| 72 | |
| 73 | /// A registration handle for an `Abortable` task. |
| 74 | /// Values of this type can be acquired from `AbortHandle::new` and are used |
| 75 | /// in calls to `Abortable::new`. |
| 76 | #[derive (Debug)] |
| 77 | pub struct AbortRegistration { |
| 78 | pub(crate) inner: Arc<AbortInner>, |
| 79 | } |
| 80 | |
| 81 | impl AbortRegistration { |
| 82 | /// Create an [`AbortHandle`] from the given [`AbortRegistration`]. |
| 83 | /// |
| 84 | /// The created [`AbortHandle`] is functionally the same as any other |
| 85 | /// [`AbortHandle`]s that are associated with the same [`AbortRegistration`], |
| 86 | /// such as the one created by [`AbortHandle::new_pair`]. |
| 87 | pub fn handle(&self) -> AbortHandle { |
| 88 | AbortHandle { inner: self.inner.clone() } |
| 89 | } |
| 90 | } |
| 91 | |
| 92 | /// A handle to an `Abortable` task. |
| 93 | #[derive (Debug, Clone)] |
| 94 | pub struct AbortHandle { |
| 95 | inner: Arc<AbortInner>, |
| 96 | } |
| 97 | |
| 98 | impl AbortHandle { |
| 99 | /// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used |
| 100 | /// to abort a running future or stream. |
| 101 | /// |
| 102 | /// This function is usually paired with a call to [`Abortable::new`]. |
| 103 | pub fn new_pair() -> (Self, AbortRegistration) { |
| 104 | let inner: Arc = |
| 105 | Arc::new(data:AbortInner { waker: AtomicWaker::new(), aborted: AtomicBool::new(false) }); |
| 106 | |
| 107 | (Self { inner: inner.clone() }, AbortRegistration { inner }) |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | // Inner type storing the waker to awaken and a bool indicating that it |
| 112 | // should be aborted. |
| 113 | #[derive (Debug)] |
| 114 | pub(crate) struct AbortInner { |
| 115 | pub(crate) waker: AtomicWaker, |
| 116 | pub(crate) aborted: AtomicBool, |
| 117 | } |
| 118 | |
/// Marker value reporting that an `Abortable` task was aborted before completing.
///
/// An `Abortable` future yields it as the `Err` variant of its output.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Aborted;
| 122 | |
| 123 | impl fmt::Display for Aborted { |
| 124 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 125 | write!(f, "`Abortable` future has been aborted" ) |
| 126 | } |
| 127 | } |
| 128 | |
// With the `std` feature enabled, `Aborted` can be used as a standard error type
// (e.g. boxed into `Box<dyn std::error::Error>`).
#[cfg(feature = "std")]
impl std::error::Error for Aborted {}
| 131 | |
| 132 | impl<T> Abortable<T> { |
| 133 | fn try_poll<I>( |
| 134 | mut self: Pin<&mut Self>, |
| 135 | cx: &mut Context<'_>, |
| 136 | poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>, |
| 137 | ) -> Poll<Result<I, Aborted>> { |
| 138 | // Check if the task has been aborted |
| 139 | if self.is_aborted() { |
| 140 | return Poll::Ready(Err(Aborted)); |
| 141 | } |
| 142 | |
| 143 | // attempt to complete the task |
| 144 | if let Poll::Ready(x) = poll(self.as_mut().project().task, cx) { |
| 145 | return Poll::Ready(Ok(x)); |
| 146 | } |
| 147 | |
| 148 | // Register to receive a wakeup if the task is aborted in the future |
| 149 | self.inner.waker.register(cx.waker()); |
| 150 | |
| 151 | // Check to see if the task was aborted between the first check and |
| 152 | // registration. |
| 153 | // Checking with `is_aborted` which uses `Relaxed` is sufficient because |
| 154 | // `register` introduces an `AcqRel` barrier. |
| 155 | if self.is_aborted() { |
| 156 | return Poll::Ready(Err(Aborted)); |
| 157 | } |
| 158 | |
| 159 | Poll::Pending |
| 160 | } |
| 161 | } |
| 162 | |
| 163 | impl<Fut> Future for Abortable<Fut> |
| 164 | where |
| 165 | Fut: Future, |
| 166 | { |
| 167 | type Output = Result<Fut::Output, Aborted>; |
| 168 | |
| 169 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 170 | self.try_poll(cx, |fut: Pin<&mut Fut>, cx: &mut Context<'_>| fut.poll(cx)) |
| 171 | } |
| 172 | } |
| 173 | |
| 174 | impl<St> Stream for Abortable<St> |
| 175 | where |
| 176 | St: Stream, |
| 177 | { |
| 178 | type Item = St::Item; |
| 179 | |
| 180 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 181 | self.try_poll(cx, |stream: Pin<&mut St>, cx: &mut Context<'_>| stream.poll_next(cx)).map(Result::ok).map(Option::flatten) |
| 182 | } |
| 183 | } |
| 184 | |
| 185 | impl AbortHandle { |
| 186 | /// Abort the `Abortable` stream/future associated with this handle. |
| 187 | /// |
| 188 | /// Notifies the Abortable task associated with this handle that it |
| 189 | /// should abort. Note that if the task is currently being polled on |
| 190 | /// another thread, it will not immediately stop running. Instead, it will |
| 191 | /// continue to run until its poll method returns. |
| 192 | pub fn abort(&self) { |
| 193 | self.inner.aborted.store(true, Ordering::Relaxed); |
| 194 | self.inner.waker.wake(); |
| 195 | } |
| 196 | |
| 197 | /// Checks whether [`AbortHandle::abort`] was *called* on any associated |
| 198 | /// [`AbortHandle`]s, which includes all the [`AbortHandle`]s linked with |
| 199 | /// the same [`AbortRegistration`]. This means that it will return `true` |
| 200 | /// even if: |
| 201 | /// * `abort` was called after the task had completed. |
| 202 | /// * `abort` was called while the task was being polled - the task may still be running and |
| 203 | /// will not be stopped until `poll` returns. |
| 204 | /// |
| 205 | /// This operation has a Relaxed ordering. |
| 206 | pub fn is_aborted(&self) -> bool { |
| 207 | self.inner.aborted.load(Ordering::Relaxed) |
| 208 | } |
| 209 | } |
| 210 | |