1use bytes::Buf;
2use futures_core::stream::Stream;
3use futures_sink::Sink;
4use std::io;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
8
9/// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
10///
11/// This type performs the inverse operation of [`ReaderStream`].
12///
13/// This type also implements the [`AsyncBufRead`] trait, so you can use it
14/// to read a `Stream` of byte chunks line-by-line. See the examples below.
15///
16/// # Example
17///
18/// ```
19/// use bytes::Bytes;
20/// use tokio::io::{AsyncReadExt, Result};
21/// use tokio_util::io::StreamReader;
22/// # #[tokio::main(flavor = "current_thread")]
23/// # async fn main() -> std::io::Result<()> {
24///
25/// // Create a stream from an iterator.
26/// let stream = tokio_stream::iter(vec![
27/// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
28/// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
29/// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
30/// ]);
31///
32/// // Convert it to an AsyncRead.
33/// let mut read = StreamReader::new(stream);
34///
35/// // Read five bytes from the stream.
36/// let mut buf = [0; 5];
37/// read.read_exact(&mut buf).await?;
38/// assert_eq!(buf, [0, 1, 2, 3, 4]);
39///
40/// // Read the rest of the current chunk.
41/// assert_eq!(read.read(&mut buf).await?, 3);
42/// assert_eq!(&buf[..3], [5, 6, 7]);
43///
44/// // Read the next chunk.
45/// assert_eq!(read.read(&mut buf).await?, 4);
46/// assert_eq!(&buf[..4], [8, 9, 10, 11]);
47///
48/// // We have now reached the end.
49/// assert_eq!(read.read(&mut buf).await?, 0);
50///
51/// # Ok(())
52/// # }
53/// ```
54///
55/// If the stream produces errors which are not [`std::io::Error`],
56/// the errors can be converted using [`StreamExt`] to map each
57/// element.
58///
59/// ```
60/// use bytes::Bytes;
61/// use tokio::io::AsyncReadExt;
62/// use tokio_util::io::StreamReader;
63/// use tokio_stream::StreamExt;
64/// # #[tokio::main(flavor = "current_thread")]
65/// # async fn main() -> std::io::Result<()> {
66///
67/// // Create a stream from an iterator, including an error.
68/// let stream = tokio_stream::iter(vec![
69/// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
70/// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
71/// Result::Err("Something bad happened!")
72/// ]);
73///
74/// // Use StreamExt to map the stream and error to a std::io::Error
75/// let stream = stream.map(|result| result.map_err(|err| {
76/// std::io::Error::new(std::io::ErrorKind::Other, err)
77/// }));
78///
79/// // Convert it to an AsyncRead.
80/// let mut read = StreamReader::new(stream);
81///
82/// // Read five bytes from the stream.
83/// let mut buf = [0; 5];
84/// read.read_exact(&mut buf).await?;
85/// assert_eq!(buf, [0, 1, 2, 3, 4]);
86///
87/// // Read the rest of the current chunk.
88/// assert_eq!(read.read(&mut buf).await?, 3);
89/// assert_eq!(&buf[..3], [5, 6, 7]);
90///
91/// // Reading the next chunk will produce an error
92/// let error = read.read(&mut buf).await.unwrap_err();
93/// assert_eq!(error.kind(), std::io::ErrorKind::Other);
94/// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!");
95///
96/// // We have now reached the end.
97/// assert_eq!(read.read(&mut buf).await?, 0);
98///
99/// # Ok(())
100/// # }
101/// ```
102///
103/// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks
104/// line-by-line. Note that you will usually also need to convert the error
105/// type when doing this. See the second example for an explanation of how
106/// to do this.
107///
108/// ```
109/// use tokio::io::{Result, AsyncBufReadExt};
110/// use tokio_util::io::StreamReader;
111/// # #[tokio::main(flavor = "current_thread")]
112/// # async fn main() -> std::io::Result<()> {
113///
114/// // Create a stream of byte chunks.
115/// let stream = tokio_stream::iter(vec![
116/// Result::Ok(b"The first line.\n".as_slice()),
117/// Result::Ok(b"The second line.".as_slice()),
118/// Result::Ok(b"\nThe third".as_slice()),
119/// Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()),
120/// ]);
121///
122/// // Convert it to an AsyncRead.
123/// let mut read = StreamReader::new(stream);
124///
125/// // Loop through the lines from the `StreamReader`.
126/// let mut line = String::new();
127/// let mut lines = Vec::new();
128/// loop {
129/// line.clear();
130/// let len = read.read_line(&mut line).await?;
131/// if len == 0 { break; }
132/// lines.push(line.clone());
133/// }
134///
135/// // Verify that we got the lines we expected.
136/// assert_eq!(
137/// lines,
138/// vec![
139/// "The first line.\n",
140/// "The second line.\n",
141/// "The third line.\n",
142/// "The fourth line.\n",
143/// "The fifth line.\n",
144/// ]
145/// );
146/// # Ok(())
147/// # }
148/// ```
149///
150/// [`AsyncRead`]: tokio::io::AsyncRead
151/// [`AsyncBufRead`]: tokio::io::AsyncBufRead
152/// [`Stream`]: futures_core::Stream
153/// [`ReaderStream`]: crate::io::ReaderStream
154/// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html
#[derive(Debug)]
pub struct StreamReader<S, B> {
    // The wrapped stream of byte chunks. This field is structurally pinned:
    // once `Self` is pinned it is only accessed as `Pin<&mut S>` (see `project`).
    inner: S,
    // The chunk currently being read from. This field is not pinned. It is
    // `None` until the first chunk is fetched, and may hold a chunk whose
    // bytes have all been consumed (see `has_chunk`).
    chunk: Option<B>,
}
162
163impl<S, B, E> StreamReader<S, B>
164where
165 S: Stream<Item = Result<B, E>>,
166 B: Buf,
167 E: Into<std::io::Error>,
168{
169 /// Convert a stream of byte chunks into an [`AsyncRead`].
170 ///
171 /// The item should be a [`Result`] with the ok variant being something that
172 /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
173 /// should be convertible into an [io error].
174 ///
175 /// [`Result`]: std::result::Result
176 /// [`Buf`]: bytes::Buf
177 /// [io error]: std::io::Error
178 pub fn new(stream: S) -> Self {
179 Self {
180 inner: stream,
181 chunk: None,
182 }
183 }
184
185 /// Do we have a chunk and is it non-empty?
186 fn has_chunk(&self) -> bool {
187 if let Some(ref chunk) = self.chunk {
188 chunk.remaining() > 0
189 } else {
190 false
191 }
192 }
193
194 /// Consumes this `StreamReader`, returning a Tuple consisting
195 /// of the underlying stream and an Option of the internal buffer,
196 /// which is Some in case the buffer contains elements.
197 pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
198 if self.has_chunk() {
199 (self.inner, self.chunk)
200 } else {
201 (self.inner, None)
202 }
203 }
204}
205
206impl<S, B> StreamReader<S, B> {
207 /// Gets a reference to the underlying stream.
208 ///
209 /// It is inadvisable to directly read from the underlying stream.
210 pub fn get_ref(&self) -> &S {
211 &self.inner
212 }
213
214 /// Gets a mutable reference to the underlying stream.
215 ///
216 /// It is inadvisable to directly read from the underlying stream.
217 pub fn get_mut(&mut self) -> &mut S {
218 &mut self.inner
219 }
220
221 /// Gets a pinned mutable reference to the underlying stream.
222 ///
223 /// It is inadvisable to directly read from the underlying stream.
224 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
225 self.project().inner
226 }
227
228 /// Consumes this `BufWriter`, returning the underlying stream.
229 ///
230 /// Note that any leftover data in the internal buffer is lost.
231 /// If you additionally want access to the internal buffer use
232 /// [`into_inner_with_chunk`].
233 ///
234 /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
235 pub fn into_inner(self) -> S {
236 self.inner
237 }
238}
239
240impl<S, B, E> AsyncRead for StreamReader<S, B>
241where
242 S: Stream<Item = Result<B, E>>,
243 B: Buf,
244 E: Into<std::io::Error>,
245{
246 fn poll_read(
247 mut self: Pin<&mut Self>,
248 cx: &mut Context<'_>,
249 buf: &mut ReadBuf<'_>,
250 ) -> Poll<io::Result<()>> {
251 if buf.remaining() == 0 {
252 return Poll::Ready(Ok(()));
253 }
254
255 let inner_buf = match self.as_mut().poll_fill_buf(cx) {
256 Poll::Ready(Ok(buf)) => buf,
257 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
258 Poll::Pending => return Poll::Pending,
259 };
260 let len = std::cmp::min(inner_buf.len(), buf.remaining());
261 buf.put_slice(&inner_buf[..len]);
262
263 self.consume(len);
264 Poll::Ready(Ok(()))
265 }
266}
267
268impl<S, B, E> AsyncBufRead for StreamReader<S, B>
269where
270 S: Stream<Item = Result<B, E>>,
271 B: Buf,
272 E: Into<std::io::Error>,
273{
274 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
275 loop {
276 if self.as_mut().has_chunk() {
277 // This unwrap is very sad, but it can't be avoided.
278 let buf = self.project().chunk.as_ref().unwrap().chunk();
279 return Poll::Ready(Ok(buf));
280 } else {
281 match self.as_mut().project().inner.poll_next(cx) {
282 Poll::Ready(Some(Ok(chunk))) => {
283 // Go around the loop in case the chunk is empty.
284 *self.as_mut().project().chunk = Some(chunk);
285 }
286 Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
287 Poll::Ready(None) => return Poll::Ready(Ok(&[])),
288 Poll::Pending => return Poll::Pending,
289 }
290 }
291 }
292 }
293 fn consume(self: Pin<&mut Self>, amt: usize) {
294 if amt > 0 {
295 self.project()
296 .chunk
297 .as_mut()
298 .expect("No chunk present")
299 .advance(amt);
300 }
301 }
302}
303
304// The code below is a manual expansion of the code that pin-project-lite would
305// generate. This is done because pin-project-lite fails by hitting the recursion
306// limit on this struct. (Every line of documentation is handled recursively by
307// the macro.)
308
// Only `inner` is structurally pinned; `chunk` is never pinned. The reader is
// therefore `Unpin` whenever the wrapped stream is, with no bound on `B`.
impl<S: Unpin, B> Unpin for StreamReader<S, B> {}

/// Pinned projection of a `StreamReader`.
///
/// Holds a pinned reference to the wrapped stream and an ordinary mutable
/// reference to the buffered chunk.
struct StreamReaderProject<'a, S, B> {
    // Pinned: the stream may be `!Unpin` and must not be moved.
    inner: Pin<&'a mut S>,
    // Unpinned: the buffered chunk can be freely replaced or mutated.
    chunk: &'a mut Option<B>,
}

impl<S, B> StreamReader<S, B> {
    /// Splits a pinned `StreamReader` into its projected fields.
    ///
    /// This is a hand-written stand-in for what `pin-project-lite` would
    /// generate (see the comment above).
    #[inline]
    fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
        // SAFETY: We define that only `inner` should be pinned when `Self` is
        // and have an appropriate `impl Unpin` for this.
        // NOTE(review): soundness also relies on nothing moving `inner` out of
        // a pinned `StreamReader`, such as a `Drop` impl. None is visible in
        // this file; confirm if a `Drop` impl is ever added.
        let me = unsafe { Pin::into_inner_unchecked(self) };
        StreamReaderProject {
            // SAFETY: `me` came from a `Pin<&mut Self>`, and `inner` is the
            // structurally pinned field, so re-pinning it upholds the contract.
            inner: unsafe { Pin::new_unchecked(&mut me.inner) },
            chunk: &mut me.chunk,
        }
    }
}
328
329impl<S: Sink<T, Error = E>, E, T> Sink<T> for StreamReader<S, E> {
330 type Error = E;
331 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
332 self.project().inner.poll_ready(cx)
333 }
334
335 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
336 self.project().inner.start_send(item)
337 }
338
339 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
340 self.project().inner.poll_flush(cx)
341 }
342
343 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
344 self.project().inner.poll_close(cx)
345 }
346}
347