1 | use crate::task::{waker_ref, ArcWake}; |
2 | use futures_core::future::{FusedFuture, Future}; |
3 | use futures_core::task::{Context, Poll, Waker}; |
4 | use slab::Slab; |
5 | use std::cell::UnsafeCell; |
6 | use std::fmt; |
7 | use std::hash::Hasher; |
8 | use std::pin::Pin; |
9 | use std::ptr; |
10 | use std::sync::atomic::AtomicUsize; |
11 | use std::sync::atomic::Ordering::{Acquire, SeqCst}; |
12 | use std::sync::{Arc, Mutex, Weak}; |
13 | |
14 | /// Future for the [`shared`](super::FutureExt::shared) method. |
15 | #[must_use = "futures do nothing unless you `.await` or poll them" ] |
16 | pub struct Shared<Fut: Future> { |
17 | inner: Option<Arc<Inner<Fut>>>, |
18 | waker_key: usize, |
19 | } |
20 | |
21 | struct Inner<Fut: Future> { |
22 | future_or_output: UnsafeCell<FutureOrOutput<Fut>>, |
23 | notifier: Arc<Notifier>, |
24 | } |
25 | |
26 | struct Notifier { |
27 | state: AtomicUsize, |
28 | wakers: Mutex<Option<Slab<Option<Waker>>>>, |
29 | } |
30 | |
31 | /// A weak reference to a [`Shared`] that can be upgraded much like an `Arc`. |
32 | pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>); |
33 | |
34 | impl<Fut: Future> Clone for WeakShared<Fut> { |
35 | fn clone(&self) -> Self { |
36 | Self(self.0.clone()) |
37 | } |
38 | } |
39 | |
40 | // The future itself is polled behind the `Arc`, so it won't be moved |
41 | // when `Shared` is moved. |
42 | impl<Fut: Future> Unpin for Shared<Fut> {} |
43 | |
44 | impl<Fut: Future> fmt::Debug for Shared<Fut> { |
45 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
46 | f.debug_struct("Shared" ) |
47 | .field("inner" , &self.inner) |
48 | .field("waker_key" , &self.waker_key) |
49 | .finish() |
50 | } |
51 | } |
52 | |
53 | impl<Fut: Future> fmt::Debug for Inner<Fut> { |
54 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
55 | f.debug_struct("Inner" ).finish() |
56 | } |
57 | } |
58 | |
59 | impl<Fut: Future> fmt::Debug for WeakShared<Fut> { |
60 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
61 | f.debug_struct("WeakShared" ).finish() |
62 | } |
63 | } |
64 | |
/// Contents of the shared cell: the pending future, or its output once it is ready.
enum FutureOrOutput<Fut: Future> {
    // Still running; polled by whichever handle wins the `IDLE -> POLLING` race.
    Future(Fut),
    // Finished; handed out by move or clone.
    Output(Fut::Output),
}
69 | |
70 | unsafe impl<Fut> Send for Inner<Fut> |
71 | where |
72 | Fut: Future + Send, |
73 | Fut::Output: Send + Sync, |
74 | { |
75 | } |
76 | |
77 | unsafe impl<Fut> Sync for Inner<Fut> |
78 | where |
79 | Fut: Future + Send, |
80 | Fut::Output: Send + Sync, |
81 | { |
82 | } |
83 | |
/// Nobody is polling the wrapped future right now.
const IDLE: usize = 0;
/// Some handle is currently polling the wrapped future.
const POLLING: usize = 1;
/// The output has been stored in the shared cell.
const COMPLETE: usize = 2;
/// The wrapped future panicked while being polled.
const POISONED: usize = 3;
88 | |
/// Sentinel `waker_key` meaning "this handle has no slot in the waker slab".
const NULL_WAKER_KEY: usize = usize::MAX;
90 | |
91 | impl<Fut: Future> Shared<Fut> { |
92 | pub(super) fn new(future: Fut) -> Self { |
93 | let inner = Inner { |
94 | future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)), |
95 | notifier: Arc::new(Notifier { |
96 | state: AtomicUsize::new(IDLE), |
97 | wakers: Mutex::new(Some(Slab::new())), |
98 | }), |
99 | }; |
100 | |
101 | Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY } |
102 | } |
103 | } |
104 | |
105 | impl<Fut> Shared<Fut> |
106 | where |
107 | Fut: Future, |
108 | { |
109 | /// Returns [`Some`] containing a reference to this [`Shared`]'s output if |
110 | /// it has already been computed by a clone or [`None`] if it hasn't been |
111 | /// computed yet or this [`Shared`] already returned its output from |
112 | /// [`poll`](Future::poll). |
113 | pub fn peek(&self) -> Option<&Fut::Output> { |
114 | if let Some(inner) = self.inner.as_ref() { |
115 | match inner.notifier.state.load(SeqCst) { |
116 | COMPLETE => unsafe { return Some(inner.output()) }, |
117 | POISONED => panic!("inner future panicked during poll" ), |
118 | _ => {} |
119 | } |
120 | } |
121 | None |
122 | } |
123 | |
124 | /// Creates a new [`WeakShared`] for this [`Shared`]. |
125 | /// |
126 | /// Returns [`None`] if it has already been polled to completion. |
127 | pub fn downgrade(&self) -> Option<WeakShared<Fut>> { |
128 | if let Some(inner) = self.inner.as_ref() { |
129 | return Some(WeakShared(Arc::downgrade(inner))); |
130 | } |
131 | None |
132 | } |
133 | |
134 | /// Gets the number of strong pointers to this allocation. |
135 | /// |
136 | /// Returns [`None`] if it has already been polled to completion. |
137 | /// |
138 | /// # Safety |
139 | /// |
140 | /// This method by itself is safe, but using it correctly requires extra care. Another thread |
141 | /// can change the strong count at any time, including potentially between calling this method |
142 | /// and acting on the result. |
143 | #[allow (clippy::unnecessary_safety_doc)] |
144 | pub fn strong_count(&self) -> Option<usize> { |
145 | self.inner.as_ref().map(|arc| Arc::strong_count(arc)) |
146 | } |
147 | |
148 | /// Gets the number of weak pointers to this allocation. |
149 | /// |
150 | /// Returns [`None`] if it has already been polled to completion. |
151 | /// |
152 | /// # Safety |
153 | /// |
154 | /// This method by itself is safe, but using it correctly requires extra care. Another thread |
155 | /// can change the weak count at any time, including potentially between calling this method |
156 | /// and acting on the result. |
157 | #[allow (clippy::unnecessary_safety_doc)] |
158 | pub fn weak_count(&self) -> Option<usize> { |
159 | self.inner.as_ref().map(|arc| Arc::weak_count(arc)) |
160 | } |
161 | |
162 | /// Hashes the internal state of this `Shared` in a way that's compatible with `ptr_eq`. |
163 | pub fn ptr_hash<H: Hasher>(&self, state: &mut H) { |
164 | match self.inner.as_ref() { |
165 | Some(arc) => { |
166 | state.write_u8(1); |
167 | ptr::hash(Arc::as_ptr(arc), state); |
168 | } |
169 | None => { |
170 | state.write_u8(0); |
171 | } |
172 | } |
173 | } |
174 | |
175 | /// Returns `true` if the two `Shared`s point to the same future (in a vein similar to |
176 | /// `Arc::ptr_eq`). |
177 | /// |
178 | /// Returns `false` if either `Shared` has terminated. |
179 | pub fn ptr_eq(&self, rhs: &Self) -> bool { |
180 | let lhs = match self.inner.as_ref() { |
181 | Some(lhs) => lhs, |
182 | None => return false, |
183 | }; |
184 | let rhs = match rhs.inner.as_ref() { |
185 | Some(rhs) => rhs, |
186 | None => return false, |
187 | }; |
188 | Arc::ptr_eq(lhs, rhs) |
189 | } |
190 | } |
191 | |
192 | impl<Fut> Inner<Fut> |
193 | where |
194 | Fut: Future, |
195 | { |
196 | /// Safety: callers must first ensure that `self.inner.state` |
197 | /// is `COMPLETE` |
198 | unsafe fn output(&self) -> &Fut::Output { |
199 | match &*self.future_or_output.get() { |
200 | FutureOrOutput::Output(ref item) => item, |
201 | FutureOrOutput::Future(_) => unreachable!(), |
202 | } |
203 | } |
204 | } |
205 | |
206 | impl<Fut> Inner<Fut> |
207 | where |
208 | Fut: Future, |
209 | Fut::Output: Clone, |
210 | { |
211 | /// Registers the current task to receive a wakeup when we are awoken. |
212 | fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) { |
213 | let mut wakers_guard = self.notifier.wakers.lock().unwrap(); |
214 | |
215 | let wakers = match wakers_guard.as_mut() { |
216 | Some(wakers) => wakers, |
217 | None => return, |
218 | }; |
219 | |
220 | let new_waker = cx.waker(); |
221 | |
222 | if *waker_key == NULL_WAKER_KEY { |
223 | *waker_key = wakers.insert(Some(new_waker.clone())); |
224 | } else { |
225 | match wakers[*waker_key] { |
226 | Some(ref old_waker) if new_waker.will_wake(old_waker) => {} |
227 | // Could use clone_from here, but Waker doesn't specialize it. |
228 | ref mut slot => *slot = Some(new_waker.clone()), |
229 | } |
230 | } |
231 | debug_assert!(*waker_key != NULL_WAKER_KEY); |
232 | } |
233 | |
234 | /// Safety: callers must first ensure that `inner.state` |
235 | /// is `COMPLETE` |
236 | unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output { |
237 | match Arc::try_unwrap(self) { |
238 | Ok(inner) => match inner.future_or_output.into_inner() { |
239 | FutureOrOutput::Output(item) => item, |
240 | FutureOrOutput::Future(_) => unreachable!(), |
241 | }, |
242 | Err(inner) => inner.output().clone(), |
243 | } |
244 | } |
245 | } |
246 | |
247 | impl<Fut> FusedFuture for Shared<Fut> |
248 | where |
249 | Fut: Future, |
250 | Fut::Output: Clone, |
251 | { |
252 | fn is_terminated(&self) -> bool { |
253 | self.inner.is_none() |
254 | } |
255 | } |
256 | |
257 | impl<Fut> Future for Shared<Fut> |
258 | where |
259 | Fut: Future, |
260 | Fut::Output: Clone, |
261 | { |
262 | type Output = Fut::Output; |
263 | |
264 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
265 | let this = &mut *self; |
266 | |
267 | let inner = this.inner.take().expect("Shared future polled again after completion" ); |
268 | |
269 | // Fast path for when the wrapped future has already completed |
270 | if inner.notifier.state.load(Acquire) == COMPLETE { |
271 | // Safety: We're in the COMPLETE state |
272 | return unsafe { Poll::Ready(inner.take_or_clone_output()) }; |
273 | } |
274 | |
275 | inner.record_waker(&mut this.waker_key, cx); |
276 | |
277 | match inner |
278 | .notifier |
279 | .state |
280 | .compare_exchange(IDLE, POLLING, SeqCst, SeqCst) |
281 | .unwrap_or_else(|x| x) |
282 | { |
283 | IDLE => { |
284 | // Lock acquired, fall through |
285 | } |
286 | POLLING => { |
287 | // Another task is currently polling, at this point we just want |
288 | // to ensure that the waker for this task is registered |
289 | this.inner = Some(inner); |
290 | return Poll::Pending; |
291 | } |
292 | COMPLETE => { |
293 | // Safety: We're in the COMPLETE state |
294 | return unsafe { Poll::Ready(inner.take_or_clone_output()) }; |
295 | } |
296 | POISONED => panic!("inner future panicked during poll" ), |
297 | _ => unreachable!(), |
298 | } |
299 | |
300 | let waker = waker_ref(&inner.notifier); |
301 | let mut cx = Context::from_waker(&waker); |
302 | |
303 | struct Reset<'a> { |
304 | state: &'a AtomicUsize, |
305 | did_not_panic: bool, |
306 | } |
307 | |
308 | impl Drop for Reset<'_> { |
309 | fn drop(&mut self) { |
310 | if !self.did_not_panic { |
311 | self.state.store(POISONED, SeqCst); |
312 | } |
313 | } |
314 | } |
315 | |
316 | let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false }; |
317 | |
318 | let output = { |
319 | let future = unsafe { |
320 | match &mut *inner.future_or_output.get() { |
321 | FutureOrOutput::Future(fut) => Pin::new_unchecked(fut), |
322 | _ => unreachable!(), |
323 | } |
324 | }; |
325 | |
326 | let poll_result = future.poll(&mut cx); |
327 | reset.did_not_panic = true; |
328 | |
329 | match poll_result { |
330 | Poll::Pending => { |
331 | if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok() |
332 | { |
333 | // Success |
334 | drop(reset); |
335 | this.inner = Some(inner); |
336 | return Poll::Pending; |
337 | } else { |
338 | unreachable!() |
339 | } |
340 | } |
341 | Poll::Ready(output) => output, |
342 | } |
343 | }; |
344 | |
345 | unsafe { |
346 | *inner.future_or_output.get() = FutureOrOutput::Output(output); |
347 | } |
348 | |
349 | inner.notifier.state.store(COMPLETE, SeqCst); |
350 | |
351 | // Wake all tasks and drop the slab |
352 | let mut wakers_guard = inner.notifier.wakers.lock().unwrap(); |
353 | let mut wakers = wakers_guard.take().unwrap(); |
354 | for waker in wakers.drain().flatten() { |
355 | waker.wake(); |
356 | } |
357 | |
358 | drop(reset); // Make borrow checker happy |
359 | drop(wakers_guard); |
360 | |
361 | // Safety: We're in the COMPLETE state |
362 | unsafe { Poll::Ready(inner.take_or_clone_output()) } |
363 | } |
364 | } |
365 | |
366 | impl<Fut> Clone for Shared<Fut> |
367 | where |
368 | Fut: Future, |
369 | { |
370 | fn clone(&self) -> Self { |
371 | Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY } |
372 | } |
373 | } |
374 | |
375 | impl<Fut> Drop for Shared<Fut> |
376 | where |
377 | Fut: Future, |
378 | { |
379 | fn drop(&mut self) { |
380 | if self.waker_key != NULL_WAKER_KEY { |
381 | if let Some(ref inner) = self.inner { |
382 | if let Ok(mut wakers) = inner.notifier.wakers.lock() { |
383 | if let Some(wakers) = wakers.as_mut() { |
384 | wakers.remove(self.waker_key); |
385 | } |
386 | } |
387 | } |
388 | } |
389 | } |
390 | } |
391 | |
392 | impl ArcWake for Notifier { |
393 | fn wake_by_ref(arc_self: &Arc<Self>) { |
394 | let wakers = &mut *arc_self.wakers.lock().unwrap(); |
395 | if let Some(wakers) = wakers.as_mut() { |
396 | for (_key, opt_waker) in wakers { |
397 | if let Some(waker) = opt_waker.take() { |
398 | waker.wake(); |
399 | } |
400 | } |
401 | } |
402 | } |
403 | } |
404 | |
405 | impl<Fut: Future> WeakShared<Fut> { |
406 | /// Attempts to upgrade this [`WeakShared`] into a [`Shared`]. |
407 | /// |
408 | /// Returns [`None`] if all clones of the [`Shared`] have been dropped or polled |
409 | /// to completion. |
410 | pub fn upgrade(&self) -> Option<Shared<Fut>> { |
411 | Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY }) |
412 | } |
413 | } |
414 | |