1#![warn(rust_2018_idioms)]
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::io::{AsyncRead, ReadBuf};
6use tokio_stream::StreamExt;
7
/// Test reader that yields exactly `remaining` zero bytes and then fails.
///
/// Each successful read produces at most 31 bytes (fewer if the caller's
/// buffer or the remaining budget is smaller). Once the budget reaches zero,
/// every further read returns an error.
#[derive(Debug)]
struct Reader {
    /// Number of zero bytes still to be produced before reads start failing.
    remaining: usize,
}
13
14impl AsyncRead for Reader {
15 fn poll_read(
16 self: Pin<&mut Self>,
17 _cx: &mut Context<'_>,
18 buf: &mut ReadBuf<'_>,
19 ) -> Poll<std::io::Result<()>> {
20 let this = Pin::into_inner(self);
21 assert_ne!(buf.remaining(), 0);
22 if this.remaining > 0 {
23 let n = std::cmp::min(this.remaining, buf.remaining());
24 let n = std::cmp::min(n, 31);
25 for x in &mut buf.initialize_unfilled_to(n)[..n] {
26 *x = 0;
27 }
28 buf.advance(n);
29 this.remaining -= n;
30 Poll::Ready(Ok(()))
31 } else {
32 Poll::Ready(Err(std::io::Error::from_raw_os_error(22)))
33 }
34 }
35}
36
37#[tokio::test]
38async fn correct_behavior_on_errors() {
39 let reader = Reader { remaining: 8000 };
40 let mut stream = tokio_util::io::ReaderStream::new(reader);
41 let mut zeros_received = 0;
42 let mut had_error = false;
43 loop {
44 let item = stream.next().await.unwrap();
45 println!("{:?}", item);
46 match item {
47 Ok(bytes) => {
48 let bytes = &*bytes;
49 for byte in bytes {
50 assert_eq!(*byte, 0);
51 zeros_received += 1;
52 }
53 }
54 Err(_) => {
55 assert!(!had_error);
56 had_error = true;
57 break;
58 }
59 }
60 }
61
62 assert!(had_error);
63 assert_eq!(zeros_received, 8000);
64 assert!(stream.next().await.is_none());
65}
66