1 | use super::{EnterRuntime, CONTEXT}; |
2 | |
3 | use crate::loom::thread::AccessError; |
4 | use crate::util::markers::NotSendOrSync; |
5 | |
6 | use std::marker::PhantomData; |
7 | use std::time::Duration; |
8 | |
/// Guard tracking that a caller has entered a blocking region.
///
/// Instances are handed out by `try_enter_blocking_region` (or built directly
/// through `BlockingRegionGuard::new`) and give access to the `block_on` and
/// `block_on_timeout` helpers. The guard is `#[must_use]`, so silently
/// discarding one triggers a compiler warning.
#[must_use ]
pub(crate) struct BlockingRegionGuard {
    // Zero-sized marker field. NOTE(review): judging by the marker's name this
    // presumably keeps the guard from being `Send` or `Sync` -- confirm against
    // the definition of `NotSendOrSync`.
    _p: PhantomData<NotSendOrSync>,
}
14 | |
/// Guard returned by `disallow_block_in_place`.
///
/// The wrapped `bool` records whether creating the guard actually switched the
/// runtime context from `allow_block_in_place: true` to `false`. When it is
/// `true`, dropping the guard switches the flag back (provided the context is
/// still in the disallowed state); when it is `false`, dropping does nothing.
pub(crate) struct DisallowBlockInPlaceGuard(bool);
16 | |
17 | pub(crate) fn try_enter_blocking_region() -> Option<BlockingRegionGuard> { |
18 | CONTEXT |
19 | .try_with(|c| { |
20 | if c.runtime.get().is_entered() { |
21 | None |
22 | } else { |
23 | Some(BlockingRegionGuard::new()) |
24 | } |
25 | // If accessing the thread-local fails, the thread is terminating |
26 | // and thread-locals are being destroyed. Because we don't know if |
27 | // we are currently in a runtime or not, we default to being |
28 | // permissive. |
29 | }) |
30 | .unwrap_or_else(|_| Some(BlockingRegionGuard::new())) |
31 | } |
32 | |
33 | /// Disallows blocking in the current runtime context until the guard is dropped. |
34 | pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard { |
35 | let reset = CONTEXT.with(|c| { |
36 | if let EnterRuntime::Entered { |
37 | allow_block_in_place: true, |
38 | } = c.runtime.get() |
39 | { |
40 | c.runtime.set(EnterRuntime::Entered { |
41 | allow_block_in_place: false, |
42 | }); |
43 | true |
44 | } else { |
45 | false |
46 | } |
47 | }); |
48 | |
49 | DisallowBlockInPlaceGuard(reset) |
50 | } |
51 | |
52 | impl BlockingRegionGuard { |
53 | pub(super) fn new() -> BlockingRegionGuard { |
54 | BlockingRegionGuard { _p: PhantomData } |
55 | } |
56 | |
57 | /// Blocks the thread on the specified future, returning the value with |
58 | /// which that future completes. |
59 | pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, AccessError> |
60 | where |
61 | F: std::future::Future, |
62 | { |
63 | use crate::runtime::park::CachedParkThread; |
64 | |
65 | let mut park = CachedParkThread::new(); |
66 | park.block_on(f) |
67 | } |
68 | |
69 | /// Blocks the thread on the specified future for **at most** `timeout` |
70 | /// |
71 | /// If the future completes before `timeout`, the result is returned. If |
72 | /// `timeout` elapses, then `Err` is returned. |
73 | pub(crate) fn block_on_timeout<F>(&mut self, f: F, timeout: Duration) -> Result<F::Output, ()> |
74 | where |
75 | F: std::future::Future, |
76 | { |
77 | use crate::runtime::park::CachedParkThread; |
78 | use std::task::Context; |
79 | use std::task::Poll::Ready; |
80 | use std::time::Instant; |
81 | |
82 | let mut park = CachedParkThread::new(); |
83 | let waker = park.waker().map_err(|_| ())?; |
84 | let mut cx = Context::from_waker(&waker); |
85 | |
86 | pin!(f); |
87 | let when = Instant::now() + timeout; |
88 | |
89 | loop { |
90 | if let Ready(v) = crate::runtime::coop::budget(|| f.as_mut().poll(&mut cx)) { |
91 | return Ok(v); |
92 | } |
93 | |
94 | let now = Instant::now(); |
95 | |
96 | if now >= when { |
97 | return Err(()); |
98 | } |
99 | |
100 | park.park_timeout(when - now); |
101 | } |
102 | } |
103 | } |
104 | |
105 | impl Drop for DisallowBlockInPlaceGuard { |
106 | fn drop(&mut self) { |
107 | if self.0 { |
108 | // XXX: Do we want some kind of assertion here, or is "best effort" okay? |
109 | CONTEXT.with(|c| { |
110 | if let EnterRuntime::Entered { |
111 | allow_block_in_place: false, |
112 | } = c.runtime.get() |
113 | { |
114 | c.runtime.set(EnterRuntime::Entered { |
115 | allow_block_in_place: true, |
116 | }); |
117 | } |
118 | }); |
119 | } |
120 | } |
121 | } |
122 | |