1use crate::task::{waker_ref, ArcWake};
2use futures_core::future::{FusedFuture, Future};
3use futures_core::task::{Context, Poll, Waker};
4use slab::Slab;
5use std::cell::UnsafeCell;
6use std::fmt;
7use std::hash::Hasher;
8use std::pin::Pin;
9use std::ptr;
10use std::sync::atomic::AtomicUsize;
11use std::sync::atomic::Ordering::{Acquire, SeqCst};
12use std::sync::{Arc, Mutex, Weak};
13
14/// Future for the [`shared`](super::FutureExt::shared) method.
15#[must_use = "futures do nothing unless you `.await` or poll them"]
16pub struct Shared<Fut: Future> {
17 inner: Option<Arc<Inner<Fut>>>,
18 waker_key: usize,
19}
20
21struct Inner<Fut: Future> {
22 future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
23 notifier: Arc<Notifier>,
24}
25
26struct Notifier {
27 state: AtomicUsize,
28 wakers: Mutex<Option<Slab<Option<Waker>>>>,
29}
30
31/// A weak reference to a [`Shared`] that can be upgraded much like an `Arc`.
32pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>);
33
34impl<Fut: Future> Clone for WeakShared<Fut> {
35 fn clone(&self) -> Self {
36 Self(self.0.clone())
37 }
38}
39
40// The future itself is polled behind the `Arc`, so it won't be moved
41// when `Shared` is moved.
42impl<Fut: Future> Unpin for Shared<Fut> {}
43
44impl<Fut: Future> fmt::Debug for Shared<Fut> {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.debug_struct("Shared")
47 .field("inner", &self.inner)
48 .field("waker_key", &self.waker_key)
49 .finish()
50 }
51}
52
53impl<Fut: Future> fmt::Debug for Inner<Fut> {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 f.debug_struct("Inner").finish()
56 }
57}
58
59impl<Fut: Future> fmt::Debug for WeakShared<Fut> {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.debug_struct("WeakShared").finish()
62 }
63}
64
/// Contents of the shared cell: the wrapped future, replaced by its output
/// once it completes.
enum FutureOrOutput<Fut>
where
    Fut: Future,
{
    /// The future has not completed yet.
    Future(Fut),
    /// The future completed and produced this value.
    Output(Fut::Output),
}
69
70unsafe impl<Fut> Send for Inner<Fut>
71where
72 Fut: Future + Send,
73 Fut::Output: Send + Sync,
74{
75}
76
77unsafe impl<Fut> Sync for Inner<Fut>
78where
79 Fut: Future + Send,
80 Fut::Output: Send + Sync,
81{
82}
83
/// No clone is polling the wrapped future right now.
const IDLE: usize = 0;
/// Exactly one clone currently holds the right to poll the wrapped future.
const POLLING: usize = 1;
/// The wrapped future finished; its output is stored in `future_or_output`.
const COMPLETE: usize = 2;
/// The wrapped future panicked while being polled; later polls panic too.
const POISONED: usize = 3;

/// Sentinel `waker_key` meaning "no waker slot has been allocated yet".
const NULL_WAKER_KEY: usize = usize::MAX;
90
91impl<Fut: Future> Shared<Fut> {
92 pub(super) fn new(future: Fut) -> Self {
93 let inner = Inner {
94 future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
95 notifier: Arc::new(Notifier {
96 state: AtomicUsize::new(IDLE),
97 wakers: Mutex::new(Some(Slab::new())),
98 }),
99 };
100
101 Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY }
102 }
103}
104
105impl<Fut> Shared<Fut>
106where
107 Fut: Future,
108{
109 /// Returns [`Some`] containing a reference to this [`Shared`]'s output if
110 /// it has already been computed by a clone or [`None`] if it hasn't been
111 /// computed yet or this [`Shared`] already returned its output from
112 /// [`poll`](Future::poll).
113 pub fn peek(&self) -> Option<&Fut::Output> {
114 if let Some(inner) = self.inner.as_ref() {
115 match inner.notifier.state.load(SeqCst) {
116 COMPLETE => unsafe { return Some(inner.output()) },
117 POISONED => panic!("inner future panicked during poll"),
118 _ => {}
119 }
120 }
121 None
122 }
123
124 /// Creates a new [`WeakShared`] for this [`Shared`].
125 ///
126 /// Returns [`None`] if it has already been polled to completion.
127 pub fn downgrade(&self) -> Option<WeakShared<Fut>> {
128 if let Some(inner) = self.inner.as_ref() {
129 return Some(WeakShared(Arc::downgrade(inner)));
130 }
131 None
132 }
133
134 /// Gets the number of strong pointers to this allocation.
135 ///
136 /// Returns [`None`] if it has already been polled to completion.
137 ///
138 /// # Safety
139 ///
140 /// This method by itself is safe, but using it correctly requires extra care. Another thread
141 /// can change the strong count at any time, including potentially between calling this method
142 /// and acting on the result.
143 #[allow(clippy::unnecessary_safety_doc)]
144 pub fn strong_count(&self) -> Option<usize> {
145 self.inner.as_ref().map(|arc| Arc::strong_count(arc))
146 }
147
148 /// Gets the number of weak pointers to this allocation.
149 ///
150 /// Returns [`None`] if it has already been polled to completion.
151 ///
152 /// # Safety
153 ///
154 /// This method by itself is safe, but using it correctly requires extra care. Another thread
155 /// can change the weak count at any time, including potentially between calling this method
156 /// and acting on the result.
157 #[allow(clippy::unnecessary_safety_doc)]
158 pub fn weak_count(&self) -> Option<usize> {
159 self.inner.as_ref().map(|arc| Arc::weak_count(arc))
160 }
161
162 /// Hashes the internal state of this `Shared` in a way that's compatible with `ptr_eq`.
163 pub fn ptr_hash<H: Hasher>(&self, state: &mut H) {
164 match self.inner.as_ref() {
165 Some(arc) => {
166 state.write_u8(1);
167 ptr::hash(Arc::as_ptr(arc), state);
168 }
169 None => {
170 state.write_u8(0);
171 }
172 }
173 }
174
175 /// Returns `true` if the two `Shared`s point to the same future (in a vein similar to
176 /// `Arc::ptr_eq`).
177 ///
178 /// Returns `false` if either `Shared` has terminated.
179 pub fn ptr_eq(&self, rhs: &Self) -> bool {
180 let lhs = match self.inner.as_ref() {
181 Some(lhs) => lhs,
182 None => return false,
183 };
184 let rhs = match rhs.inner.as_ref() {
185 Some(rhs) => rhs,
186 None => return false,
187 };
188 Arc::ptr_eq(lhs, rhs)
189 }
190}
191
192impl<Fut> Inner<Fut>
193where
194 Fut: Future,
195{
196 /// Safety: callers must first ensure that `self.inner.state`
197 /// is `COMPLETE`
198 unsafe fn output(&self) -> &Fut::Output {
199 match &*self.future_or_output.get() {
200 FutureOrOutput::Output(ref item) => item,
201 FutureOrOutput::Future(_) => unreachable!(),
202 }
203 }
204}
205
206impl<Fut> Inner<Fut>
207where
208 Fut: Future,
209 Fut::Output: Clone,
210{
211 /// Registers the current task to receive a wakeup when we are awoken.
212 fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
213 let mut wakers_guard = self.notifier.wakers.lock().unwrap();
214
215 let wakers = match wakers_guard.as_mut() {
216 Some(wakers) => wakers,
217 None => return,
218 };
219
220 let new_waker = cx.waker();
221
222 if *waker_key == NULL_WAKER_KEY {
223 *waker_key = wakers.insert(Some(new_waker.clone()));
224 } else {
225 match wakers[*waker_key] {
226 Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
227 // Could use clone_from here, but Waker doesn't specialize it.
228 ref mut slot => *slot = Some(new_waker.clone()),
229 }
230 }
231 debug_assert!(*waker_key != NULL_WAKER_KEY);
232 }
233
234 /// Safety: callers must first ensure that `inner.state`
235 /// is `COMPLETE`
236 unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
237 match Arc::try_unwrap(self) {
238 Ok(inner) => match inner.future_or_output.into_inner() {
239 FutureOrOutput::Output(item) => item,
240 FutureOrOutput::Future(_) => unreachable!(),
241 },
242 Err(inner) => inner.output().clone(),
243 }
244 }
245}
246
247impl<Fut> FusedFuture for Shared<Fut>
248where
249 Fut: Future,
250 Fut::Output: Clone,
251{
252 fn is_terminated(&self) -> bool {
253 self.inner.is_none()
254 }
255}
256
257impl<Fut> Future for Shared<Fut>
258where
259 Fut: Future,
260 Fut::Output: Clone,
261{
262 type Output = Fut::Output;
263
264 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265 let this = &mut *self;
266
267 let inner = this.inner.take().expect("Shared future polled again after completion");
268
269 // Fast path for when the wrapped future has already completed
270 if inner.notifier.state.load(Acquire) == COMPLETE {
271 // Safety: We're in the COMPLETE state
272 return unsafe { Poll::Ready(inner.take_or_clone_output()) };
273 }
274
275 inner.record_waker(&mut this.waker_key, cx);
276
277 match inner
278 .notifier
279 .state
280 .compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
281 .unwrap_or_else(|x| x)
282 {
283 IDLE => {
284 // Lock acquired, fall through
285 }
286 POLLING => {
287 // Another task is currently polling, at this point we just want
288 // to ensure that the waker for this task is registered
289 this.inner = Some(inner);
290 return Poll::Pending;
291 }
292 COMPLETE => {
293 // Safety: We're in the COMPLETE state
294 return unsafe { Poll::Ready(inner.take_or_clone_output()) };
295 }
296 POISONED => panic!("inner future panicked during poll"),
297 _ => unreachable!(),
298 }
299
300 let waker = waker_ref(&inner.notifier);
301 let mut cx = Context::from_waker(&waker);
302
303 struct Reset<'a> {
304 state: &'a AtomicUsize,
305 did_not_panic: bool,
306 }
307
308 impl Drop for Reset<'_> {
309 fn drop(&mut self) {
310 if !self.did_not_panic {
311 self.state.store(POISONED, SeqCst);
312 }
313 }
314 }
315
316 let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false };
317
318 let output = {
319 let future = unsafe {
320 match &mut *inner.future_or_output.get() {
321 FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
322 _ => unreachable!(),
323 }
324 };
325
326 let poll_result = future.poll(&mut cx);
327 reset.did_not_panic = true;
328
329 match poll_result {
330 Poll::Pending => {
331 if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
332 {
333 // Success
334 drop(reset);
335 this.inner = Some(inner);
336 return Poll::Pending;
337 } else {
338 unreachable!()
339 }
340 }
341 Poll::Ready(output) => output,
342 }
343 };
344
345 unsafe {
346 *inner.future_or_output.get() = FutureOrOutput::Output(output);
347 }
348
349 inner.notifier.state.store(COMPLETE, SeqCst);
350
351 // Wake all tasks and drop the slab
352 let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
353 let mut wakers = wakers_guard.take().unwrap();
354 for waker in wakers.drain().flatten() {
355 waker.wake();
356 }
357
358 drop(reset); // Make borrow checker happy
359 drop(wakers_guard);
360
361 // Safety: We're in the COMPLETE state
362 unsafe { Poll::Ready(inner.take_or_clone_output()) }
363 }
364}
365
366impl<Fut> Clone for Shared<Fut>
367where
368 Fut: Future,
369{
370 fn clone(&self) -> Self {
371 Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY }
372 }
373}
374
375impl<Fut> Drop for Shared<Fut>
376where
377 Fut: Future,
378{
379 fn drop(&mut self) {
380 if self.waker_key != NULL_WAKER_KEY {
381 if let Some(ref inner) = self.inner {
382 if let Ok(mut wakers) = inner.notifier.wakers.lock() {
383 if let Some(wakers) = wakers.as_mut() {
384 wakers.remove(self.waker_key);
385 }
386 }
387 }
388 }
389 }
390}
391
392impl ArcWake for Notifier {
393 fn wake_by_ref(arc_self: &Arc<Self>) {
394 let wakers = &mut *arc_self.wakers.lock().unwrap();
395 if let Some(wakers) = wakers.as_mut() {
396 for (_key, opt_waker) in wakers {
397 if let Some(waker) = opt_waker.take() {
398 waker.wake();
399 }
400 }
401 }
402 }
403}
404
405impl<Fut: Future> WeakShared<Fut> {
406 /// Attempts to upgrade this [`WeakShared`] into a [`Shared`].
407 ///
408 /// Returns [`None`] if all clones of the [`Shared`] have been dropped or polled
409 /// to completion.
410 pub fn upgrade(&self) -> Option<Shared<Fut>> {
411 Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY })
412 }
413}
414