| 1 | use super::{EnterRuntime, CONTEXT}; |
| 2 | |
| 3 | use crate::loom::thread::AccessError; |
| 4 | use crate::util::markers::NotSendOrSync; |
| 5 | |
| 6 | use std::marker::PhantomData; |
| 7 | use std::time::Duration; |
| 8 | |
| 9 | /// Guard tracking that a caller has entered a blocking region. |
| 10 | #[must_use ] |
| 11 | pub(crate) struct BlockingRegionGuard { |
| 12 | _p: PhantomData<NotSendOrSync>, |
| 13 | } |
| 14 | |
/// Guard returned by `disallow_block_in_place`.
///
/// The wrapped `bool` records whether creating the guard actually revoked the
/// `block_in_place` permission, and therefore whether dropping the guard should
/// try to restore it.
///
/// Marked `#[must_use]` (like `BlockingRegionGuard`): discarding the guard
/// drops it immediately, which would undo the restriction it was meant to hold.
#[must_use]
pub(crate) struct DisallowBlockInPlaceGuard(bool);
| 16 | |
| 17 | pub(crate) fn try_enter_blocking_region() -> Option<BlockingRegionGuard> { |
| 18 | CONTEXT |
| 19 | .try_with(|c| { |
| 20 | if c.runtime.get().is_entered() { |
| 21 | None |
| 22 | } else { |
| 23 | Some(BlockingRegionGuard::new()) |
| 24 | } |
| 25 | // If accessing the thread-local fails, the thread is terminating |
| 26 | // and thread-locals are being destroyed. Because we don't know if |
| 27 | // we are currently in a runtime or not, we default to being |
| 28 | // permissive. |
| 29 | }) |
| 30 | .unwrap_or_else(|_| Some(BlockingRegionGuard::new())) |
| 31 | } |
| 32 | |
| 33 | /// Disallows blocking in the current runtime context until the guard is dropped. |
| 34 | pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard { |
| 35 | let reset: bool = CONTEXT.with(|c: &Context| { |
| 36 | if let EnterRuntime::Entered { |
| 37 | allow_block_in_place: true, |
| 38 | } = c.runtime.get() |
| 39 | { |
| 40 | c.runtime.set(val:EnterRuntime::Entered { |
| 41 | allow_block_in_place: false, |
| 42 | }); |
| 43 | true |
| 44 | } else { |
| 45 | false |
| 46 | } |
| 47 | }); |
| 48 | |
| 49 | DisallowBlockInPlaceGuard(reset) |
| 50 | } |
| 51 | |
| 52 | impl BlockingRegionGuard { |
| 53 | pub(super) fn new() -> BlockingRegionGuard { |
| 54 | BlockingRegionGuard { _p: PhantomData } |
| 55 | } |
| 56 | |
| 57 | /// Blocks the thread on the specified future, returning the value with |
| 58 | /// which that future completes. |
| 59 | pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, AccessError> |
| 60 | where |
| 61 | F: std::future::Future, |
| 62 | { |
| 63 | use crate::runtime::park::CachedParkThread; |
| 64 | |
| 65 | let mut park = CachedParkThread::new(); |
| 66 | park.block_on(f) |
| 67 | } |
| 68 | |
| 69 | /// Blocks the thread on the specified future for **at most** `timeout` |
| 70 | /// |
| 71 | /// If the future completes before `timeout`, the result is returned. If |
| 72 | /// `timeout` elapses, then `Err` is returned. |
| 73 | pub(crate) fn block_on_timeout<F>(&mut self, f: F, timeout: Duration) -> Result<F::Output, ()> |
| 74 | where |
| 75 | F: std::future::Future, |
| 76 | { |
| 77 | use crate::runtime::park::CachedParkThread; |
| 78 | use std::task::Context; |
| 79 | use std::task::Poll::Ready; |
| 80 | use std::time::Instant; |
| 81 | |
| 82 | let mut park = CachedParkThread::new(); |
| 83 | let waker = park.waker().map_err(|_| ())?; |
| 84 | let mut cx = Context::from_waker(&waker); |
| 85 | |
| 86 | pin!(f); |
| 87 | let when = Instant::now() + timeout; |
| 88 | |
| 89 | loop { |
| 90 | if let Ready(v) = crate::task::coop::budget(|| f.as_mut().poll(&mut cx)) { |
| 91 | return Ok(v); |
| 92 | } |
| 93 | |
| 94 | let now = Instant::now(); |
| 95 | |
| 96 | if now >= when { |
| 97 | return Err(()); |
| 98 | } |
| 99 | |
| 100 | park.park_timeout(when - now); |
| 101 | } |
| 102 | } |
| 103 | } |
| 104 | |
| 105 | impl Drop for DisallowBlockInPlaceGuard { |
| 106 | fn drop(&mut self) { |
| 107 | if self.0 { |
| 108 | // XXX: Do we want some kind of assertion here, or is "best effort" okay? |
| 109 | CONTEXT.with(|c: &Context| { |
| 110 | if let EnterRuntime::Entered { |
| 111 | allow_block_in_place: false, |
| 112 | } = c.runtime.get() |
| 113 | { |
| 114 | c.runtime.set(val:EnterRuntime::Entered { |
| 115 | allow_block_in_place: true, |
| 116 | }); |
| 117 | } |
| 118 | }); |
| 119 | } |
| 120 | } |
| 121 | } |
| 122 | |