1use super::{Context, CONTEXT};
2
3use crate::runtime::{scheduler, TryCurrentError};
4use crate::util::markers::SyncNotSend;
5
6use std::cell::{Cell, RefCell};
7use std::marker::PhantomData;
8
/// Guard returned by `Context::set_current` (and `try_set_current`).
///
/// While the guard is alive, the handle passed to `set_current` is the
/// context's current scheduler handle. Dropping the guard restores `prev` and
/// decrements the nesting depth. Guards must be dropped in the reverse order
/// of creation; the `Drop` impl checks this by comparing `depth` against the
/// live depth counter.
#[derive(Debug)]
#[must_use]
pub(crate) struct SetCurrentGuard {
    // The handle that was current before this guard was created; restored on
    // drop (`None` if no handle was installed at that time).
    prev: Option<scheduler::Handle>,

    // The nesting depth immediately after this guard was created (1 for the
    // outermost guard). Compared against the live depth on drop to detect
    // out-of-order release.
    depth: usize,

    // Don't let the type move across threads.
    // NOTE(review): `SyncNotSend` lives in `crate::util::markers`; presumably
    // it makes the guard `!Send` (as its name suggests) — confirm there.
    _p: PhantomData<SyncNotSend>,
}
21
/// Storage for the current scheduler handle together with the number of live
/// `SetCurrentGuard`s stacked on it. Reached from a `Context` as `ctx.current`.
pub(super) struct HandleCell {
    /// Current handle; `None` when no handle is installed.
    handle: RefCell<Option<scheduler::Handle>>,

    /// Nesting depth: incremented by every `Context::set_current` call
    /// (including those made through `try_set_current`) and decremented when
    /// the matching `SetCurrentGuard` is dropped in order.
    depth: Cell<usize>,
}
29
30/// Sets this [`Handle`] as the current active [`Handle`].
31///
32/// [`Handle`]: crate::runtime::scheduler::Handle
33pub(crate) fn try_set_current(handle: &scheduler::Handle) -> Option<SetCurrentGuard> {
34 CONTEXT.try_with(|ctx| ctx.set_current(handle)).ok()
35}
36
37pub(crate) fn with_current<F, R>(f: F) -> Result<R, TryCurrentError>
38where
39 F: FnOnce(&scheduler::Handle) -> R,
40{
41 match CONTEXT.try_with(|ctx| ctx.current.handle.borrow().as_ref().map(f)) {
42 Ok(Some(ret)) => Ok(ret),
43 Ok(None) => Err(TryCurrentError::new_no_context()),
44 Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()),
45 }
46}
47
48impl Context {
49 pub(super) fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard {
50 let old_handle = self.current.handle.borrow_mut().replace(handle.clone());
51 let depth = self.current.depth.get();
52
53 assert!(depth != usize::MAX, "reached max `enter` depth");
54
55 let depth = depth + 1;
56 self.current.depth.set(depth);
57
58 SetCurrentGuard {
59 prev: old_handle,
60 depth,
61 _p: PhantomData,
62 }
63 }
64}
65
66impl HandleCell {
67 pub(super) const fn new() -> HandleCell {
68 HandleCell {
69 handle: RefCell::new(None),
70 depth: Cell::new(0),
71 }
72 }
73}
74
75impl Drop for SetCurrentGuard {
76 fn drop(&mut self) {
77 CONTEXT.with(|ctx| {
78 let depth = ctx.current.depth.get();
79
80 if depth != self.depth {
81 if !std::thread::panicking() {
82 panic!(
83 "`EnterGuard` values dropped out of order. Guards returned by \
84 `tokio::runtime::Handle::enter()` must be dropped in the reverse \
85 order as they were acquired."
86 );
87 } else {
88 // Just return... this will leave handles in a wonky state though...
89 return;
90 }
91 }
92
93 *ctx.current.handle.borrow_mut() = self.prev.take();
94 ctx.current.depth.set(depth - 1);
95 });
96 }
97}
98