1 | use core::any::Any; |
2 | use core::pin::Pin; |
3 | use std::panic::{catch_unwind, AssertUnwindSafe, UnwindSafe}; |
4 | |
5 | use futures_core::future::Future; |
6 | use futures_core::task::{Context, Poll}; |
7 | use pin_project_lite::pin_project; |
8 | |
pin_project! {
    /// Future for the [`catch_unwind`](super::FutureExt::catch_unwind) method.
    ///
    /// Wraps an inner future. A panic that unwinds out of the inner future's
    /// `poll` is caught, and the future resolves to `Err(payload)` instead of
    /// the panic propagating to the caller. Normal completion resolves to
    /// `Ok(output)`.
    #[derive(Debug)]
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct CatchUnwind<Fut> {
        // The wrapped future. It is structurally pinned (`#[pin]`), so the
        // macro-generated projection yields a `Pin<&mut Fut>` that can be
        // polled from `Pin<&mut Self>`.
        #[pin]
        future: Fut,
    }
}
18 | |
19 | impl<Fut> CatchUnwind<Fut> |
20 | where |
21 | Fut: Future + UnwindSafe, |
22 | { |
23 | pub(super) fn new(future: Fut) -> Self { |
24 | Self { future } |
25 | } |
26 | } |
27 | |
28 | impl<Fut> Future for CatchUnwind<Fut> |
29 | where |
30 | Fut: Future + UnwindSafe, |
31 | { |
32 | type Output = Result<Fut::Output, Box<dyn Any + Send>>; |
33 | |
34 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
35 | let f = self.project().future; |
36 | catch_unwind(AssertUnwindSafe(|| f.poll(cx)))?.map(Ok) |
37 | } |
38 | } |
39 | |