1use futures::future::poll_fn;
2use std::{
3 io::IoSlice,
4 pin::Pin,
5 task::{Context, Poll},
6};
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
8use tokio_util::io::{InspectReader, InspectWriter};
9
/// Test double for `AsyncRead` that hands out at most one byte per
/// `poll_read` call.
///
/// Its purpose is to expose wrappers that wrongly assume the `ReadBuf` they
/// pass down is empty before the call, or that a single call fills it.
struct SmallReader {
    /// Bytes still to be served; `poll_read` takes them from the back.
    contents: Vec<u8>,
}

// Explicit marker so the reader can always be wrapped with `Pin::new`.
impl Unpin for SmallReader {}
17
18impl AsyncRead for SmallReader {
19 fn poll_read(
20 mut self: Pin<&mut Self>,
21 _cx: &mut Context<'_>,
22 buf: &mut ReadBuf<'_>,
23 ) -> Poll<std::io::Result<()>> {
24 if let Some(byte) = self.contents.pop() {
25 buf.put_slice(&[byte])
26 }
27 Poll::Ready(Ok(()))
28 }
29}
30
31#[tokio::test]
32async fn read_tee() {
33 let contents = b"This could be really long, you know".to_vec();
34 let reader = SmallReader {
35 contents: contents.clone(),
36 };
37 let mut altout: Vec<u8> = Vec::new();
38 let mut teeout = Vec::new();
39 {
40 let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
41 tee.read_to_end(&mut teeout).await.unwrap();
42 }
43 assert_eq!(teeout, altout);
44 assert_eq!(altout.len(), contents.len());
45}
46
/// Test double for `AsyncWrite` that deliberately performs short writes.
///
/// * `poll_write` accepts a single byte per call.
/// * `poll_write_vectored` accepts the whole of the first buffer plus one
///   byte of the second.
///
/// Both behaviours exist to flush out wrappers that mishandle partially
/// written buffers.
#[derive(Debug)]
struct SmallWriter {
    /// Every byte accepted so far, in the order it was written.
    contents: Vec<u8>,
}

// Explicit marker so the writer can always be wrapped with `Pin::new`.
impl Unpin for SmallWriter {}
58
59impl AsyncWrite for SmallWriter {
60 fn poll_write(
61 mut self: Pin<&mut Self>,
62 _cx: &mut Context<'_>,
63 buf: &[u8],
64 ) -> Poll<Result<usize, std::io::Error>> {
65 // Just write one byte at a time
66 if buf.is_empty() {
67 return Poll::Ready(Ok(0));
68 }
69 self.contents.push(buf[0]);
70 Poll::Ready(Ok(1))
71 }
72
73 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
74 Poll::Ready(Ok(()))
75 }
76
77 fn poll_shutdown(
78 self: Pin<&mut Self>,
79 _cx: &mut Context<'_>,
80 ) -> Poll<Result<(), std::io::Error>> {
81 Poll::Ready(Ok(()))
82 }
83
84 fn poll_write_vectored(
85 mut self: Pin<&mut Self>,
86 _cx: &mut Context<'_>,
87 bufs: &[IoSlice<'_>],
88 ) -> Poll<Result<usize, std::io::Error>> {
89 // Write all of the first buffer, then one byte from the second buffer
90 // This should trip up anything that doesn't correctly handle multiple
91 // buffers.
92 if bufs.is_empty() {
93 return Poll::Ready(Ok(0));
94 }
95 let mut written_len = bufs[0].len();
96 self.contents.extend_from_slice(&bufs[0]);
97
98 if bufs.len() > 1 {
99 let buf = bufs[1];
100 if !buf.is_empty() {
101 written_len += 1;
102 self.contents.push(buf[0]);
103 }
104 }
105 Poll::Ready(Ok(written_len))
106 }
107
108 fn is_write_vectored(&self) -> bool {
109 true
110 }
111}
112
113#[tokio::test]
114async fn write_tee() {
115 let mut altout: Vec<u8> = Vec::new();
116 let mut writeout = SmallWriter {
117 contents: Vec::new(),
118 };
119 {
120 let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
121 tee.write_all(b"A testing string, very testing")
122 .await
123 .unwrap();
124 }
125 assert_eq!(altout, writeout.contents);
126}
127
128// This is inefficient, but works well enough for test use.
129// If you want something similar for real code, you'll want to avoid all the
130// fun of manipulating `bufs` - ideally, by the time you read this,
131// IoSlice::advance_slices will be stable, and you can use that.
132async fn write_all_vectored<W: AsyncWrite + Unpin>(
133 mut writer: W,
134 mut bufs: Vec<Vec<u8>>,
135) -> Result<usize, std::io::Error> {
136 let mut res = 0;
137 while !bufs.is_empty() {
138 let mut written = poll_fn(|cx| {
139 let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
140 Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
141 })
142 .await?;
143 res += written;
144 while written > 0 {
145 let buf_len = bufs[0].len();
146 if buf_len <= written {
147 bufs.remove(0);
148 written -= buf_len;
149 } else {
150 let buf = &mut bufs[0];
151 let drain_len = written.min(buf.len());
152 buf.drain(..drain_len);
153 written -= drain_len;
154 }
155 }
156 }
157 Ok(res)
158}
159
160#[tokio::test]
161async fn write_tee_vectored() {
162 let mut altout: Vec<u8> = Vec::new();
163 let mut writeout = SmallWriter {
164 contents: Vec::new(),
165 };
166 let original = b"A very long string split up";
167 let bufs: Vec<Vec<u8>> = original
168 .split(|b| b.is_ascii_whitespace())
169 .map(Vec::from)
170 .collect();
171 assert!(bufs.len() > 1);
172 let expected: Vec<u8> = {
173 let mut out = Vec::new();
174 for item in &bufs {
175 out.extend_from_slice(item)
176 }
177 out
178 };
179 {
180 let mut bufcount = 0;
181 let tee = InspectWriter::new(&mut writeout, |bytes| {
182 bufcount += 1;
183 altout.extend(bytes)
184 });
185
186 assert!(tee.is_write_vectored());
187
188 write_all_vectored(tee, bufs.clone()).await.unwrap();
189
190 assert!(bufcount >= bufs.len());
191 }
192 assert_eq!(altout, writeout.contents);
193 assert_eq!(writeout.contents, expected);
194}
195