1 | use futures::future::poll_fn; |
2 | use std::{ |
3 | io::IoSlice, |
4 | pin::Pin, |
5 | task::{Context, Poll}, |
6 | }; |
7 | use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; |
8 | use tokio_util::io::{InspectReader, InspectWriter}; |
9 | |
10 | /// An AsyncRead implementation that works byte-by-byte, to catch out callers |
11 | /// who don't allow for `buf` being part-filled before the call |
12 | struct SmallReader { |
13 | contents: Vec<u8>, |
14 | } |
15 | |
16 | impl Unpin for SmallReader {} |
17 | |
18 | impl AsyncRead for SmallReader { |
19 | fn poll_read( |
20 | mut self: Pin<&mut Self>, |
21 | _cx: &mut Context<'_>, |
22 | buf: &mut ReadBuf<'_>, |
23 | ) -> Poll<std::io::Result<()>> { |
24 | if let Some(byte) = self.contents.pop() { |
25 | buf.put_slice(&[byte]) |
26 | } |
27 | Poll::Ready(Ok(())) |
28 | } |
29 | } |
30 | |
31 | #[tokio::test ] |
32 | async fn read_tee() { |
33 | let contents = b"This could be really long, you know" .to_vec(); |
34 | let reader = SmallReader { |
35 | contents: contents.clone(), |
36 | }; |
37 | let mut altout: Vec<u8> = Vec::new(); |
38 | let mut teeout = Vec::new(); |
39 | { |
40 | let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes)); |
41 | tee.read_to_end(&mut teeout).await.unwrap(); |
42 | } |
43 | assert_eq!(teeout, altout); |
44 | assert_eq!(altout.len(), contents.len()); |
45 | } |
46 | |
47 | /// An AsyncWrite implementation that works byte-by-byte for poll_write, and |
48 | /// that reads the whole of the first buffer plus one byte from the second in |
49 | /// poll_write_vectored. |
50 | /// |
51 | /// This is designed to catch bugs in handling partially written buffers |
52 | #[derive(Debug)] |
53 | struct SmallWriter { |
54 | contents: Vec<u8>, |
55 | } |
56 | |
57 | impl Unpin for SmallWriter {} |
58 | |
59 | impl AsyncWrite for SmallWriter { |
60 | fn poll_write( |
61 | mut self: Pin<&mut Self>, |
62 | _cx: &mut Context<'_>, |
63 | buf: &[u8], |
64 | ) -> Poll<Result<usize, std::io::Error>> { |
65 | // Just write one byte at a time |
66 | if buf.is_empty() { |
67 | return Poll::Ready(Ok(0)); |
68 | } |
69 | self.contents.push(buf[0]); |
70 | Poll::Ready(Ok(1)) |
71 | } |
72 | |
73 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { |
74 | Poll::Ready(Ok(())) |
75 | } |
76 | |
77 | fn poll_shutdown( |
78 | self: Pin<&mut Self>, |
79 | _cx: &mut Context<'_>, |
80 | ) -> Poll<Result<(), std::io::Error>> { |
81 | Poll::Ready(Ok(())) |
82 | } |
83 | |
84 | fn poll_write_vectored( |
85 | mut self: Pin<&mut Self>, |
86 | _cx: &mut Context<'_>, |
87 | bufs: &[IoSlice<'_>], |
88 | ) -> Poll<Result<usize, std::io::Error>> { |
89 | // Write all of the first buffer, then one byte from the second buffer |
90 | // This should trip up anything that doesn't correctly handle multiple |
91 | // buffers. |
92 | if bufs.is_empty() { |
93 | return Poll::Ready(Ok(0)); |
94 | } |
95 | let mut written_len = bufs[0].len(); |
96 | self.contents.extend_from_slice(&bufs[0]); |
97 | |
98 | if bufs.len() > 1 { |
99 | let buf = bufs[1]; |
100 | if !buf.is_empty() { |
101 | written_len += 1; |
102 | self.contents.push(buf[0]); |
103 | } |
104 | } |
105 | Poll::Ready(Ok(written_len)) |
106 | } |
107 | |
108 | fn is_write_vectored(&self) -> bool { |
109 | true |
110 | } |
111 | } |
112 | |
113 | #[tokio::test ] |
114 | async fn write_tee() { |
115 | let mut altout: Vec<u8> = Vec::new(); |
116 | let mut writeout = SmallWriter { |
117 | contents: Vec::new(), |
118 | }; |
119 | { |
120 | let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes)); |
121 | tee.write_all(b"A testing string, very testing" ) |
122 | .await |
123 | .unwrap(); |
124 | } |
125 | assert_eq!(altout, writeout.contents); |
126 | } |
127 | |
128 | // This is inefficient, but works well enough for test use. |
129 | // If you want something similar for real code, you'll want to avoid all the |
130 | // fun of manipulating `bufs` - ideally, by the time you read this, |
131 | // IoSlice::advance_slices will be stable, and you can use that. |
132 | async fn write_all_vectored<W: AsyncWrite + Unpin>( |
133 | mut writer: W, |
134 | mut bufs: Vec<Vec<u8>>, |
135 | ) -> Result<usize, std::io::Error> { |
136 | let mut res = 0; |
137 | while !bufs.is_empty() { |
138 | let mut written = poll_fn(|cx| { |
139 | let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect(); |
140 | Pin::new(&mut writer).poll_write_vectored(cx, &bufs) |
141 | }) |
142 | .await?; |
143 | res += written; |
144 | while written > 0 { |
145 | let buf_len = bufs[0].len(); |
146 | if buf_len <= written { |
147 | bufs.remove(0); |
148 | written -= buf_len; |
149 | } else { |
150 | let buf = &mut bufs[0]; |
151 | let drain_len = written.min(buf.len()); |
152 | buf.drain(..drain_len); |
153 | written -= drain_len; |
154 | } |
155 | } |
156 | } |
157 | Ok(res) |
158 | } |
159 | |
160 | #[tokio::test ] |
161 | async fn write_tee_vectored() { |
162 | let mut altout: Vec<u8> = Vec::new(); |
163 | let mut writeout = SmallWriter { |
164 | contents: Vec::new(), |
165 | }; |
166 | let original = b"A very long string split up" ; |
167 | let bufs: Vec<Vec<u8>> = original |
168 | .split(|b| b.is_ascii_whitespace()) |
169 | .map(Vec::from) |
170 | .collect(); |
171 | assert!(bufs.len() > 1); |
172 | let expected: Vec<u8> = { |
173 | let mut out = Vec::new(); |
174 | for item in &bufs { |
175 | out.extend_from_slice(item) |
176 | } |
177 | out |
178 | }; |
179 | { |
180 | let mut bufcount = 0; |
181 | let tee = InspectWriter::new(&mut writeout, |bytes| { |
182 | bufcount += 1; |
183 | altout.extend(bytes) |
184 | }); |
185 | |
186 | assert!(tee.is_write_vectored()); |
187 | |
188 | write_all_vectored(tee, bufs.clone()).await.unwrap(); |
189 | |
190 | assert!(bufcount >= bufs.len()); |
191 | } |
192 | assert_eq!(altout, writeout.contents); |
193 | assert_eq!(writeout.contents, expected); |
194 | } |
195 | |