1 | use futures_core::stream::{FusedStream, Stream}; |
2 | use futures_core::task::{Context, Poll}; |
3 | use pin_project_lite::pin_project; |
4 | use std::any::Any; |
5 | use std::panic::{catch_unwind, AssertUnwindSafe, UnwindSafe}; |
6 | use std::pin::Pin; |
7 | |
8 | pin_project! { |
9 | /// Stream for the [`catch_unwind`](super::StreamExt::catch_unwind) method. |
10 | #[derive(Debug)] |
11 | #[must_use = "streams do nothing unless polled" ] |
12 | pub struct CatchUnwind<St> { |
13 | #[pin] |
14 | stream: St, |
15 | caught_unwind: bool, |
16 | } |
17 | } |
18 | |
19 | impl<St: Stream + UnwindSafe> CatchUnwind<St> { |
20 | pub(super) fn new(stream: St) -> Self { |
21 | Self { stream, caught_unwind: false } |
22 | } |
23 | |
24 | delegate_access_inner!(stream, St, ()); |
25 | } |
26 | |
27 | impl<St: Stream + UnwindSafe> Stream for CatchUnwind<St> { |
28 | type Item = Result<St::Item, Box<dyn Any + Send>>; |
29 | |
30 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
31 | let mut this = self.project(); |
32 | |
33 | if *this.caught_unwind { |
34 | Poll::Ready(None) |
35 | } else { |
36 | let res = catch_unwind(AssertUnwindSafe(|| this.stream.as_mut().poll_next(cx))); |
37 | |
38 | match res { |
39 | Ok(poll) => poll.map(|opt| opt.map(Ok)), |
40 | Err(e) => { |
41 | *this.caught_unwind = true; |
42 | Poll::Ready(Some(Err(e))) |
43 | } |
44 | } |
45 | } |
46 | } |
47 | |
48 | fn size_hint(&self) -> (usize, Option<usize>) { |
49 | if self.caught_unwind { |
50 | (0, Some(0)) |
51 | } else { |
52 | self.stream.size_hint() |
53 | } |
54 | } |
55 | } |
56 | |
57 | impl<St: FusedStream + UnwindSafe> FusedStream for CatchUnwind<St> { |
58 | fn is_terminated(&self) -> bool { |
59 | self.caught_unwind || self.stream.is_terminated() |
60 | } |
61 | } |
62 | |