1 | use std::future::Future; |
2 | use std::sync::Arc; |
3 | use std::task::{Context, Poll}; |
4 | use std::thread::{self, Thread}; |
5 | use std::time::Duration; |
6 | |
7 | use tokio::time::Instant; |
8 | |
9 | pub(crate) fn timeout<F, I, E>(fut: F, timeout: Option<Duration>) -> Result<I, Waited<E>> |
10 | where |
11 | F: Future<Output = Result<I, E>>, |
12 | { |
13 | enter(); |
14 | |
15 | let deadline = timeout.map(|d| { |
16 | log::trace!("wait at most {:?}" , d); |
17 | Instant::now() + d |
18 | }); |
19 | |
20 | let thread = ThreadWaker(thread::current()); |
21 | // Arc shouldn't be necessary, since `Thread` is reference counted internally, |
22 | // but let's just stay safe for now. |
23 | let waker = futures_util::task::waker(Arc::new(thread)); |
24 | let mut cx = Context::from_waker(&waker); |
25 | |
26 | futures_util::pin_mut!(fut); |
27 | |
28 | loop { |
29 | match fut.as_mut().poll(&mut cx) { |
30 | Poll::Ready(Ok(val)) => return Ok(val), |
31 | Poll::Ready(Err(err)) => return Err(Waited::Inner(err)), |
32 | Poll::Pending => (), // fallthrough |
33 | } |
34 | |
35 | if let Some(deadline) = deadline { |
36 | let now = Instant::now(); |
37 | if now >= deadline { |
38 | log::trace!("wait timeout exceeded" ); |
39 | return Err(Waited::TimedOut(crate::error::TimedOut)); |
40 | } |
41 | |
42 | log::trace!( |
43 | "( {:?}) park timeout {:?}" , |
44 | thread::current().id(), |
45 | deadline - now |
46 | ); |
47 | thread::park_timeout(deadline - now); |
48 | } else { |
49 | log::trace!("( {:?}) park without timeout" , thread::current().id()); |
50 | thread::park(); |
51 | } |
52 | } |
53 | } |
54 | |
55 | #[derive (Debug)] |
56 | pub(crate) enum Waited<E> { |
57 | TimedOut(crate::error::TimedOut), |
58 | Inner(E), |
59 | } |
60 | |
61 | struct ThreadWaker(Thread); |
62 | |
63 | impl futures_util::task::ArcWake for ThreadWaker { |
64 | fn wake_by_ref(arc_self: &Arc<Self>) { |
65 | arc_self.0.unpark(); |
66 | } |
67 | } |
68 | |
69 | fn enter() { |
70 | // Check we aren't already in a runtime |
71 | #[cfg (debug_assertions)] |
72 | { |
73 | let _enter: EnterGuard<'_> = tokioRuntime::runtime::Builder::new_current_thread() |
74 | .build() |
75 | .expect(msg:"build shell runtime" ) |
76 | .enter(); |
77 | } |
78 | } |
79 | |