1 | use crate::stream_ext::Fuse; |
2 | use crate::Stream; |
3 | use tokio::time::{Instant, Sleep}; |
4 | |
5 | use core::future::Future; |
6 | use core::pin::Pin; |
7 | use core::task::{Context, Poll}; |
8 | use pin_project_lite::pin_project; |
9 | use std::fmt; |
10 | use std::time::Duration; |
11 | |
pin_project! {
    /// Stream returned by the [`timeout`](super::StreamExt::timeout) method.
    ///
    /// Every item of the wrapped stream is forwarded as `Ok(item)`. If the
    /// wrapped stream stays pending past the current deadline, a single
    /// `Err(Elapsed)` is yielded; the deadline is re-armed each time the
    /// wrapped stream produces an item.
    #[must_use = "streams do nothing unless polled"]
    #[derive(Debug)]
    pub struct Timeout<S> {
        // The wrapped stream, held behind the crate's `Fuse` adapter.
        #[pin]
        stream: Fuse<S>,
        // Timer for the current deadline; armed at construction and re-armed
        // after every item produced by `stream`.
        #[pin]
        deadline: Sleep,
        // Maximum time to wait for each item; used to compute every new deadline.
        duration: Duration,
        // Whether `deadline` should still be polled. Cleared once an `Elapsed`
        // has been yielded and set again after the next item, so at most one
        // error is produced per gap between items.
        poll_deadline: bool,
    }
}
25 | |
/// Error returned by `Timeout` and `TimeoutRepeating`.
///
/// It carries no data: it only signals that the wrapped stream did not produce
/// an item before the deadline. Because it is a zero-sized marker it is cheap to
/// copy, compare and hash. The unit field is private, so outside this module
/// values are obtained through `Elapsed::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Elapsed(());
29 | |
30 | impl<S: Stream> Timeout<S> { |
31 | pub(super) fn new(stream: S, duration: Duration) -> Self { |
32 | let next = Instant::now() + duration; |
33 | let deadline = tokio::time::sleep_until(next); |
34 | |
35 | Timeout { |
36 | stream: Fuse::new(stream), |
37 | deadline, |
38 | duration, |
39 | poll_deadline: true, |
40 | } |
41 | } |
42 | } |
43 | |
44 | impl<S: Stream> Stream for Timeout<S> { |
45 | type Item = Result<S::Item, Elapsed>; |
46 | |
47 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
48 | let me = self.project(); |
49 | |
50 | match me.stream.poll_next(cx) { |
51 | Poll::Ready(v) => { |
52 | if v.is_some() { |
53 | let next = Instant::now() + *me.duration; |
54 | me.deadline.reset(next); |
55 | *me.poll_deadline = true; |
56 | } |
57 | return Poll::Ready(v.map(Ok)); |
58 | } |
59 | Poll::Pending => {} |
60 | }; |
61 | |
62 | if *me.poll_deadline { |
63 | ready!(me.deadline.poll(cx)); |
64 | *me.poll_deadline = false; |
65 | return Poll::Ready(Some(Err(Elapsed::new()))); |
66 | } |
67 | |
68 | Poll::Pending |
69 | } |
70 | |
71 | fn size_hint(&self) -> (usize, Option<usize>) { |
72 | let (lower, upper) = self.stream.size_hint(); |
73 | |
74 | // The timeout stream may insert an error before and after each message |
75 | // from the underlying stream, but no more than one error between each |
76 | // message. Hence the upper bound is computed as 2x+1. |
77 | |
78 | // Using a helper function to enable use of question mark operator. |
79 | fn twice_plus_one(value: Option<usize>) -> Option<usize> { |
80 | value?.checked_mul(2)?.checked_add(1) |
81 | } |
82 | |
83 | (lower, twice_plus_one(upper)) |
84 | } |
85 | } |
86 | |
87 | // ===== impl Elapsed ===== |
88 | |
89 | impl Elapsed { |
90 | pub(crate) fn new() -> Self { |
91 | Elapsed(()) |
92 | } |
93 | } |
94 | |
95 | impl fmt::Display for Elapsed { |
96 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
97 | "deadline has elapsed" .fmt(fmt) |
98 | } |
99 | } |
100 | |
101 | impl std::error::Error for Elapsed {} |
102 | |
103 | impl From<Elapsed> for std::io::Error { |
104 | fn from(_err: Elapsed) -> std::io::Error { |
105 | std::io::ErrorKind::TimedOut.into() |
106 | } |
107 | } |
108 | |