1use super::{BlockingRegionGuard, SetCurrentGuard, CONTEXT};
2
3use crate::runtime::scheduler;
4use crate::util::rand::{FastRand, RngSeed};
5
6use std::fmt;
7
/// Whether the current thread has entered a runtime context.
///
/// `enter_runtime` switches this to `Entered`; dropping the guard it creates
/// switches it back to `NotEntered`.
#[must_use]
#[derive(Clone, Copy, Debug)]
pub(crate) enum EnterRuntime {
    /// The thread is currently inside a runtime context.
    #[cfg_attr(not(feature = "rt"), allow(dead_code))]
    Entered {
        /// The `allow_block_in_place` argument given to `enter_runtime`.
        allow_block_in_place: bool,
    },

    /// The thread is in neither a runtime context **nor** a blocking region.
    NotEntered,
}
18
/// Guard tracking that a caller has entered a runtime context.
///
/// Created by `enter_runtime`. When dropped, it resets the thread's runtime
/// marker to `EnterRuntime::NotEntered` and restores the RNG seed saved in
/// `old_seed`.
///
/// NOTE: after `Drop::drop` runs, the fields are dropped in declaration order
/// (`blocking`, then `handle`, then `old_seed`). Do not reorder them without
/// re-checking what the inner guards do on drop.
#[must_use]
pub(crate) struct EnterRuntimeGuard {
    /// Tracks that the current thread has entered a blocking function call.
    /// `enter_runtime` lends it mutably to its closure.
    pub(crate) blocking: BlockingRegionGuard,

    /// Guard returned by `set_current(handle)` on entry. It is held only for
    /// its drop behavior (presumably restoring the previously current handle;
    /// confirm in `SetCurrentGuard`).
    #[allow(dead_code)] // Only tracking the guard.
    pub(crate) handle: SetCurrentGuard,

    /// RNG seed that was active before the runtime was entered; put back into
    /// the thread's RNG when this guard is dropped.
    old_seed: RngSeed,
}
31
32/// Marks the current thread as being within the dynamic extent of an
33/// executor.
34#[track_caller]
35pub(crate) fn enter_runtime<F, R>(handle: &scheduler::Handle, allow_block_in_place: bool, f: F) -> R
36where
37 F: FnOnce(&mut BlockingRegionGuard) -> R,
38{
39 let maybe_guard = CONTEXT.with(|c| {
40 if c.runtime.get().is_entered() {
41 None
42 } else {
43 // Set the entered flag
44 c.runtime.set(EnterRuntime::Entered {
45 allow_block_in_place,
46 });
47
48 // Generate a new seed
49 let rng_seed = handle.seed_generator().next_seed();
50
51 // Swap the RNG seed
52 let mut rng = c.rng.get().unwrap_or_else(FastRand::new);
53 let old_seed = rng.replace_seed(rng_seed);
54 c.rng.set(Some(rng));
55
56 Some(EnterRuntimeGuard {
57 blocking: BlockingRegionGuard::new(),
58 handle: c.set_current(handle),
59 old_seed,
60 })
61 }
62 });
63
64 if let Some(mut guard) = maybe_guard {
65 return f(&mut guard.blocking);
66 }
67
68 panic!(
69 "Cannot start a runtime from within a runtime. This happens \
70 because a function (like `block_on`) attempted to block the \
71 current thread while the thread is being used to drive \
72 asynchronous tasks."
73 );
74}
75
76impl fmt::Debug for EnterRuntimeGuard {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 f.debug_struct(name:"Enter").finish()
79 }
80}
81
82impl Drop for EnterRuntimeGuard {
83 fn drop(&mut self) {
84 CONTEXT.with(|c: &Context| {
85 assert!(c.runtime.get().is_entered());
86 c.runtime.set(val:EnterRuntime::NotEntered);
87 // Replace the previous RNG seed
88 let mut rng: FastRand = c.rng.get().unwrap_or_else(FastRand::new);
89 rng.replace_seed(self.old_seed.clone());
90 c.rng.set(val:Some(rng));
91 });
92 }
93}
94
95impl EnterRuntime {
96 pub(crate) fn is_entered(self) -> bool {
97 matches!(self, EnterRuntime::Entered { .. })
98 }
99}
100