1use crate::task::AtomicWaker;
2use alloc::sync::Arc;
3use core::fmt;
4use core::pin::Pin;
5use core::sync::atomic::{AtomicBool, Ordering};
6use futures_core::future::Future;
7use futures_core::task::{Context, Poll};
8use futures_core::Stream;
9use pin_project_lite::pin_project;
10
11pin_project! {
12 /// A future/stream which can be remotely short-circuited using an `AbortHandle`.
13 #[derive(Debug, Clone)]
14 #[must_use = "futures/streams do nothing unless you poll them"]
15 pub struct Abortable<T> {
16 #[pin]
17 task: T,
18 inner: Arc<AbortInner>,
19 }
20}
21
22impl<T> Abortable<T> {
23 /// Creates a new `Abortable` future/stream using an existing `AbortRegistration`.
24 /// `AbortRegistration`s can be acquired through `AbortHandle::new`.
25 ///
26 /// When `abort` is called on the handle tied to `reg` or if `abort` has
27 /// already been called, the future/stream will complete immediately without making
28 /// any further progress.
29 ///
30 /// # Examples:
31 ///
32 /// Usage with futures:
33 ///
34 /// ```
35 /// # futures::executor::block_on(async {
36 /// use futures::future::{Abortable, AbortHandle, Aborted};
37 ///
38 /// let (abort_handle, abort_registration) = AbortHandle::new_pair();
39 /// let future = Abortable::new(async { 2 }, abort_registration);
40 /// abort_handle.abort();
41 /// assert_eq!(future.await, Err(Aborted));
42 /// # });
43 /// ```
44 ///
45 /// Usage with streams:
46 ///
47 /// ```
48 /// # futures::executor::block_on(async {
49 /// # use futures::future::{Abortable, AbortHandle};
50 /// # use futures::stream::{self, StreamExt};
51 ///
52 /// let (abort_handle, abort_registration) = AbortHandle::new_pair();
53 /// let mut stream = Abortable::new(stream::iter(vec![1, 2, 3]), abort_registration);
54 /// abort_handle.abort();
55 /// assert_eq!(stream.next().await, None);
56 /// # });
57 /// ```
58 pub fn new(task: T, reg: AbortRegistration) -> Self {
59 Self { task, inner: reg.inner }
60 }
61
62 /// Checks whether the task has been aborted. Note that all this
63 /// method indicates is whether [`AbortHandle::abort`] was *called*.
64 /// This means that it will return `true` even if:
65 /// * `abort` was called after the task had completed.
66 /// * `abort` was called while the task was being polled - the task may still be running and
67 /// will not be stopped until `poll` returns.
68 pub fn is_aborted(&self) -> bool {
69 self.inner.aborted.load(Ordering::Relaxed)
70 }
71}
72
73/// A registration handle for an `Abortable` task.
74/// Values of this type can be acquired from `AbortHandle::new` and are used
75/// in calls to `Abortable::new`.
76#[derive(Debug)]
77pub struct AbortRegistration {
78 pub(crate) inner: Arc<AbortInner>,
79}
80
81impl AbortRegistration {
82 /// Create an [`AbortHandle`] from the given [`AbortRegistration`].
83 ///
84 /// The created [`AbortHandle`] is functionally the same as any other
85 /// [`AbortHandle`]s that are associated with the same [`AbortRegistration`],
86 /// such as the one created by [`AbortHandle::new_pair`].
87 pub fn handle(&self) -> AbortHandle {
88 AbortHandle { inner: self.inner.clone() }
89 }
90}
91
92/// A handle to an `Abortable` task.
93#[derive(Debug, Clone)]
94pub struct AbortHandle {
95 inner: Arc<AbortInner>,
96}
97
98impl AbortHandle {
99 /// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used
100 /// to abort a running future or stream.
101 ///
102 /// This function is usually paired with a call to [`Abortable::new`].
103 pub fn new_pair() -> (Self, AbortRegistration) {
104 let inner =
105 Arc::new(AbortInner { waker: AtomicWaker::new(), aborted: AtomicBool::new(false) });
106
107 (Self { inner: inner.clone() }, AbortRegistration { inner })
108 }
109}
110
111// Inner type storing the waker to awaken and a bool indicating that it
112// should be aborted.
113#[derive(Debug)]
114pub(crate) struct AbortInner {
115 pub(crate) waker: AtomicWaker,
116 pub(crate) aborted: AtomicBool,
117}
118
/// Marker error reported when an `Abortable` task observes that it has been
/// aborted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Aborted;
122
123impl fmt::Display for Aborted {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 write!(f, "`Abortable` future has been aborted")
126 }
127}
128
/// Lets `Aborted` be used as a standard error (only with the `std` feature).
#[cfg(feature = "std")]
impl std::error::Error for Aborted {
    // `Aborted` wraps no underlying cause, so every `Error` method keeps its
    // default implementation.
}
131
132impl<T> Abortable<T> {
133 fn try_poll<I>(
134 mut self: Pin<&mut Self>,
135 cx: &mut Context<'_>,
136 poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>,
137 ) -> Poll<Result<I, Aborted>> {
138 // Check if the task has been aborted
139 if self.is_aborted() {
140 return Poll::Ready(Err(Aborted));
141 }
142
143 // attempt to complete the task
144 if let Poll::Ready(x) = poll(self.as_mut().project().task, cx) {
145 return Poll::Ready(Ok(x));
146 }
147
148 // Register to receive a wakeup if the task is aborted in the future
149 self.inner.waker.register(cx.waker());
150
151 // Check to see if the task was aborted between the first check and
152 // registration.
153 // Checking with `is_aborted` which uses `Relaxed` is sufficient because
154 // `register` introduces an `AcqRel` barrier.
155 if self.is_aborted() {
156 return Poll::Ready(Err(Aborted));
157 }
158
159 Poll::Pending
160 }
161}
162
163impl<Fut> Future for Abortable<Fut>
164where
165 Fut: Future,
166{
167 type Output = Result<Fut::Output, Aborted>;
168
169 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
170 self.try_poll(cx, |fut, cx| fut.poll(cx))
171 }
172}
173
174impl<St> Stream for Abortable<St>
175where
176 St: Stream,
177{
178 type Item = St::Item;
179
180 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181 self.try_poll(cx, |stream, cx| stream.poll_next(cx)).map(Result::ok).map(Option::flatten)
182 }
183}
184
185impl AbortHandle {
186 /// Abort the `Abortable` stream/future associated with this handle.
187 ///
188 /// Notifies the Abortable task associated with this handle that it
189 /// should abort. Note that if the task is currently being polled on
190 /// another thread, it will not immediately stop running. Instead, it will
191 /// continue to run until its poll method returns.
192 pub fn abort(&self) {
193 self.inner.aborted.store(true, Ordering::Relaxed);
194 self.inner.waker.wake();
195 }
196
197 /// Checks whether [`AbortHandle::abort`] was *called* on any associated
198 /// [`AbortHandle`]s, which includes all the [`AbortHandle`]s linked with
199 /// the same [`AbortRegistration`]. This means that it will return `true`
200 /// even if:
201 /// * `abort` was called after the task had completed.
202 /// * `abort` was called while the task was being polled - the task may still be running and
203 /// will not be stopped until `poll` returns.
204 ///
205 /// This operation has a Relaxed ordering.
206 pub fn is_aborted(&self) -> bool {
207 self.inner.aborted.load(Ordering::Relaxed)
208 }
209}
210