1use super::{EnterRuntime, CONTEXT};
2
3use crate::loom::thread::AccessError;
4use crate::util::markers::NotSendOrSync;
5
6use std::marker::PhantomData;
7use std::time::Duration;
8
/// Guard tracking that a caller has entered a blocking region.
///
/// Values are produced by `BlockingRegionGuard::new` (e.g. via
/// `try_enter_blocking_region`) and expose `block_on` /
/// `block_on_timeout` for driving a future to completion on the current
/// thread.
#[must_use]
pub(crate) struct BlockingRegionGuard {
    // Zero-sized marker; presumably makes the guard !Send and !Sync so it
    // stays on the thread that created it — confirm in `util::markers`.
    _p: PhantomData<NotSendOrSync>,
}
14
/// Guard returned by `disallow_block_in_place`.
///
/// The `bool` records whether creating the guard actually flipped the
/// runtime context from `allow_block_in_place: true` to `false`; only then
/// does dropping the guard attempt to restore it. Dropping the guard
/// immediately (e.g. by not binding it) undoes the effect right away, hence
/// `#[must_use]`.
#[must_use]
pub(crate) struct DisallowBlockInPlaceGuard(bool);
16
17pub(crate) fn try_enter_blocking_region() -> Option<BlockingRegionGuard> {
18 CONTEXT
19 .try_with(|c| {
20 if c.runtime.get().is_entered() {
21 None
22 } else {
23 Some(BlockingRegionGuard::new())
24 }
25 // If accessing the thread-local fails, the thread is terminating
26 // and thread-locals are being destroyed. Because we don't know if
27 // we are currently in a runtime or not, we default to being
28 // permissive.
29 })
30 .unwrap_or_else(|_| Some(BlockingRegionGuard::new()))
31}
32
33/// Disallows blocking in the current runtime context until the guard is dropped.
34pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard {
35 let reset = CONTEXT.with(|c| {
36 if let EnterRuntime::Entered {
37 allow_block_in_place: true,
38 } = c.runtime.get()
39 {
40 c.runtime.set(EnterRuntime::Entered {
41 allow_block_in_place: false,
42 });
43 true
44 } else {
45 false
46 }
47 });
48
49 DisallowBlockInPlaceGuard(reset)
50}
51
52impl BlockingRegionGuard {
53 pub(super) fn new() -> BlockingRegionGuard {
54 BlockingRegionGuard { _p: PhantomData }
55 }
56
57 /// Blocks the thread on the specified future, returning the value with
58 /// which that future completes.
59 pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, AccessError>
60 where
61 F: std::future::Future,
62 {
63 use crate::runtime::park::CachedParkThread;
64
65 let mut park = CachedParkThread::new();
66 park.block_on(f)
67 }
68
69 /// Blocks the thread on the specified future for **at most** `timeout`
70 ///
71 /// If the future completes before `timeout`, the result is returned. If
72 /// `timeout` elapses, then `Err` is returned.
73 pub(crate) fn block_on_timeout<F>(&mut self, f: F, timeout: Duration) -> Result<F::Output, ()>
74 where
75 F: std::future::Future,
76 {
77 use crate::runtime::park::CachedParkThread;
78 use std::task::Context;
79 use std::task::Poll::Ready;
80 use std::time::Instant;
81
82 let mut park = CachedParkThread::new();
83 let waker = park.waker().map_err(|_| ())?;
84 let mut cx = Context::from_waker(&waker);
85
86 pin!(f);
87 let when = Instant::now() + timeout;
88
89 loop {
90 if let Ready(v) = crate::runtime::coop::budget(|| f.as_mut().poll(&mut cx)) {
91 return Ok(v);
92 }
93
94 let now = Instant::now();
95
96 if now >= when {
97 return Err(());
98 }
99
100 park.park_timeout(when - now);
101 }
102 }
103}
104
105impl Drop for DisallowBlockInPlaceGuard {
106 fn drop(&mut self) {
107 if self.0 {
108 // XXX: Do we want some kind of assertion here, or is "best effort" okay?
109 CONTEXT.with(|c| {
110 if let EnterRuntime::Entered {
111 allow_block_in_place: false,
112 } = c.runtime.get()
113 {
114 c.runtime.set(EnterRuntime::Entered {
115 allow_block_in_place: true,
116 });
117 }
118 });
119 }
120 }
121}
122