1use std::marker::Unpin;
2use std::{cmp, io};
3
4use bytes::{Buf, Bytes};
5use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
6
7use crate::common::{task, Pin, Poll};
8
9/// Combine a buffer with an IO, rewinding reads to use the buffer.
10#[derive(Debug)]
11pub(crate) struct Rewind<T> {
12 pre: Option<Bytes>,
13 inner: T,
14}
15
16impl<T> Rewind<T> {
17 #[cfg(any(all(feature = "http2", feature = "server"), test))]
18 pub(crate) fn new(io: T) -> Self {
19 Rewind {
20 pre: None,
21 inner: io,
22 }
23 }
24
25 pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
26 Rewind {
27 pre: Some(buf),
28 inner: io,
29 }
30 }
31
32 #[cfg(any(all(feature = "http1", feature = "http2", feature = "server"), test))]
33 pub(crate) fn rewind(&mut self, bs: Bytes) {
34 debug_assert!(self.pre.is_none());
35 self.pre = Some(bs);
36 }
37
38 pub(crate) fn into_inner(self) -> (T, Bytes) {
39 (self.inner, self.pre.unwrap_or_else(Bytes::new))
40 }
41
42 // pub(crate) fn get_mut(&mut self) -> &mut T {
43 // &mut self.inner
44 // }
45}
46
47impl<T> AsyncRead for Rewind<T>
48where
49 T: AsyncRead + Unpin,
50{
51 fn poll_read(
52 mut self: Pin<&mut Self>,
53 cx: &mut task::Context<'_>,
54 buf: &mut ReadBuf<'_>,
55 ) -> Poll<io::Result<()>> {
56 if let Some(mut prefix: Bytes) = self.pre.take() {
57 // If there are no remaining bytes, let the bytes get dropped.
58 if !prefix.is_empty() {
59 let copy_len: usize = cmp::min(v1:prefix.len(), v2:buf.remaining());
60 // TODO: There should be a way to do following two lines cleaner...
61 buf.put_slice(&prefix[..copy_len]);
62 prefix.advance(cnt:copy_len);
63 // Put back what's left
64 if !prefix.is_empty() {
65 self.pre = Some(prefix);
66 }
67
68 return Poll::Ready(Ok(()));
69 }
70 }
71 Pin::new(&mut self.inner).poll_read(cx, buf)
72 }
73}
74
75impl<T> AsyncWrite for Rewind<T>
76where
77 T: AsyncWrite + Unpin,
78{
79 fn poll_write(
80 mut self: Pin<&mut Self>,
81 cx: &mut task::Context<'_>,
82 buf: &[u8],
83 ) -> Poll<io::Result<usize>> {
84 Pin::new(&mut self.inner).poll_write(cx, buf)
85 }
86
87 fn poll_write_vectored(
88 mut self: Pin<&mut Self>,
89 cx: &mut task::Context<'_>,
90 bufs: &[io::IoSlice<'_>],
91 ) -> Poll<io::Result<usize>> {
92 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
93 }
94
95 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
96 Pin::new(&mut self.inner).poll_flush(cx)
97 }
98
99 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
100 Pin::new(&mut self.inner).poll_shutdown(cx)
101 }
102
103 fn is_write_vectored(&self) -> bool {
104 self.inner.is_write_vectored()
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::Rewind;
    use bytes::Bytes;
    use tokio::io::AsyncReadExt;

    /// Rewinding part of what was read replays those bytes ahead of the rest
    /// of the underlying stream.
    #[tokio::test]
    async fn partial_rewind() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        let mut stream = Rewind::new(mock);

        // Read off some bytes, ensure we filled the buffer.
        let mut buf = [0; 2];
        stream.read_exact(&mut buf).await.expect("read1");
        assert_eq!(&buf, &underlying[..2]);

        // Rewind the stream so that it is as if we never read in the first place.
        stream.rewind(Bytes::copy_from_slice(&buf[..]));

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read2");

        // At this point we should have read everything that was in the mock stream.
        assert_eq!(&buf, &underlying);
    }

    /// Rewinding everything that was read replays the whole stream.
    #[tokio::test]
    async fn full_rewind() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        let mut stream = Rewind::new(mock);

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read1");

        // Rewind the stream so that it is as if we never read in the first place.
        stream.rewind(Bytes::copy_from_slice(&buf[..]));

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read2");

        // The replayed bytes must match what the mock stream produced.
        assert_eq!(&buf, &underlying);
    }

    /// A prefix larger than the read buffer is served across several reads,
    /// with the unread remainder kept between them.
    #[tokio::test]
    async fn prefix_larger_than_read_buffer() {
        let mock = tokio_test::io::Builder::new().build();

        let mut stream = Rewind::new_buffered(mock, Bytes::from_static(b"hello"));

        let mut buf = [0; 2];
        stream.read_exact(&mut buf).await.expect("read1");
        assert_eq!(&buf, b"he");

        let mut buf = [0; 3];
        stream.read_exact(&mut buf).await.expect("read2");
        assert_eq!(&buf, b"llo");
    }

    /// `into_inner` hands back the part of the prefix that was not read.
    #[tokio::test]
    async fn into_inner_returns_unread_prefix() {
        let mock = tokio_test::io::Builder::new().build();

        let mut stream = Rewind::new_buffered(mock, Bytes::from_static(b"hello"));

        let mut buf = [0; 2];
        stream.read_exact(&mut buf).await.expect("read1");

        let (_io, rest) = stream.into_inner();
        assert_eq!(rest, Bytes::from_static(b"llo"));
    }
}
156