1 | use super::{BlockingRegionGuard, SetCurrentGuard, CONTEXT}; |
2 | |
3 | use crate::runtime::scheduler; |
4 | use crate::util::rand::{FastRand, RngSeed}; |
5 | |
6 | use std::fmt; |
7 | |
/// Records whether the current thread is inside a runtime context.
#[derive(Clone, Copy, Debug)]
#[must_use]
pub(crate) enum EnterRuntime {
    /// Currently in a runtime context.
    ///
    /// `allow_block_in_place` carries the flag that was passed to
    /// `enter_runtime` when the context was entered.
    #[cfg_attr(not(feature = "rt"), allow(dead_code))]
    Entered { allow_block_in_place: bool },

    /// Not in a runtime context **or** a blocking region.
    NotEntered,
}
18 | |
19 | /// Guard tracking that a caller has entered a runtime context. |
20 | #[must_use ] |
21 | pub(crate) struct EnterRuntimeGuard { |
22 | /// Tracks that the current thread has entered a blocking function call. |
23 | pub(crate) blocking: BlockingRegionGuard, |
24 | |
25 | #[allow (dead_code)] // Only tracking the guard. |
26 | pub(crate) handle: SetCurrentGuard, |
27 | |
28 | // Tracks the previous random number generator seed |
29 | old_seed: RngSeed, |
30 | } |
31 | |
32 | /// Marks the current thread as being within the dynamic extent of an |
33 | /// executor. |
34 | #[track_caller ] |
35 | pub(crate) fn enter_runtime<F, R>(handle: &scheduler::Handle, allow_block_in_place: bool, f: F) -> R |
36 | where |
37 | F: FnOnce(&mut BlockingRegionGuard) -> R, |
38 | { |
39 | let maybe_guard = CONTEXT.with(|c| { |
40 | if c.runtime.get().is_entered() { |
41 | None |
42 | } else { |
43 | // Set the entered flag |
44 | c.runtime.set(EnterRuntime::Entered { |
45 | allow_block_in_place, |
46 | }); |
47 | |
48 | // Generate a new seed |
49 | let rng_seed = handle.seed_generator().next_seed(); |
50 | |
51 | // Swap the RNG seed |
52 | let mut rng = c.rng.get().unwrap_or_else(FastRand::new); |
53 | let old_seed = rng.replace_seed(rng_seed); |
54 | c.rng.set(Some(rng)); |
55 | |
56 | Some(EnterRuntimeGuard { |
57 | blocking: BlockingRegionGuard::new(), |
58 | handle: c.set_current(handle), |
59 | old_seed, |
60 | }) |
61 | } |
62 | }); |
63 | |
64 | if let Some(mut guard) = maybe_guard { |
65 | return f(&mut guard.blocking); |
66 | } |
67 | |
68 | panic!( |
69 | "Cannot start a runtime from within a runtime. This happens \ |
70 | because a function (like `block_on`) attempted to block the \ |
71 | current thread while the thread is being used to drive \ |
72 | asynchronous tasks." |
73 | ); |
74 | } |
75 | |
76 | impl fmt::Debug for EnterRuntimeGuard { |
77 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
78 | f.debug_struct(name:"Enter" ).finish() |
79 | } |
80 | } |
81 | |
82 | impl Drop for EnterRuntimeGuard { |
83 | fn drop(&mut self) { |
84 | CONTEXT.with(|c: &Context| { |
85 | assert!(c.runtime.get().is_entered()); |
86 | c.runtime.set(val:EnterRuntime::NotEntered); |
87 | // Replace the previous RNG seed |
88 | let mut rng: FastRand = c.rng.get().unwrap_or_else(FastRand::new); |
89 | rng.replace_seed(self.old_seed.clone()); |
90 | c.rng.set(val:Some(rng)); |
91 | }); |
92 | } |
93 | } |
94 | |
95 | impl EnterRuntime { |
96 | pub(crate) fn is_entered(self) -> bool { |
97 | matches!(self, EnterRuntime::Entered { .. }) |
98 | } |
99 | } |
100 | |