1use std::future::Future;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::thread::{self, Thread};
5use std::time::Duration;
6
7use tokio::time::Instant;
8
9pub(crate) fn timeout<F, I, E>(fut: F, timeout: Option<Duration>) -> Result<I, Waited<E>>
10where
11 F: Future<Output = Result<I, E>>,
12{
13 enter();
14
15 let deadline = timeout.map(|d| {
16 log::trace!("wait at most {:?}", d);
17 Instant::now() + d
18 });
19
20 let thread = ThreadWaker(thread::current());
21 // Arc shouldn't be necessary, since `Thread` is reference counted internally,
22 // but let's just stay safe for now.
23 let waker = futures_util::task::waker(Arc::new(thread));
24 let mut cx = Context::from_waker(&waker);
25
26 futures_util::pin_mut!(fut);
27
28 loop {
29 match fut.as_mut().poll(&mut cx) {
30 Poll::Ready(Ok(val)) => return Ok(val),
31 Poll::Ready(Err(err)) => return Err(Waited::Inner(err)),
32 Poll::Pending => (), // fallthrough
33 }
34
35 if let Some(deadline) = deadline {
36 let now = Instant::now();
37 if now >= deadline {
38 log::trace!("wait timeout exceeded");
39 return Err(Waited::TimedOut(crate::error::TimedOut));
40 }
41
42 log::trace!(
43 "({:?}) park timeout {:?}",
44 thread::current().id(),
45 deadline - now
46 );
47 thread::park_timeout(deadline - now);
48 } else {
49 log::trace!("({:?}) park without timeout", thread::current().id());
50 thread::park();
51 }
52 }
53}
54
55#[derive(Debug)]
56pub(crate) enum Waited<E> {
57 TimedOut(crate::error::TimedOut),
58 Inner(E),
59}
60
61struct ThreadWaker(Thread);
62
63impl futures_util::task::ArcWake for ThreadWaker {
64 fn wake_by_ref(arc_self: &Arc<Self>) {
65 arc_self.0.unpark();
66 }
67}
68
69fn enter() {
70 // Check we aren't already in a runtime
71 #[cfg(debug_assertions)]
72 {
73 let _enter: EnterGuard<'_> = tokioRuntime::runtime::Builder::new_current_thread()
74 .build()
75 .expect(msg:"build shell runtime")
76 .enter();
77 }
78}
79