1use futures_core::ready;
2use pin_project_lite::pin_project;
3use std::io::{IoSlice, Result};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
pin_project! {
    /// An adapter that lets you inspect the data that's being read.
    ///
    /// This is useful for things like hashing data as it's read in.
    ///
    /// On every successful `poll_read`, the closure `f` is handed the bytes
    /// that the wrapped reader just appended to the caller's buffer. If the
    /// wrapped type also implements `AsyncWrite`, write-side calls are
    /// forwarded to it unchanged and are not passed to `f`.
    pub struct InspectReader<R, F> {
        // The wrapped reader. It is structurally pinned so it can be polled
        // through `Pin<&mut Self>` via the generated projection.
        #[pin]
        reader: R,
        // Inspection callback invoked with newly read bytes; not pinned.
        f: F,
    }
}
19
20impl<R, F> InspectReader<R, F> {
21 /// Create a new InspectReader, wrapping `reader` and calling `f` for the
22 /// new data supplied by each read call.
23 ///
24 /// The closure will only be called with an empty slice if the inner reader
25 /// returns without reading data into the buffer. This happens at EOF, or if
26 /// `poll_read` is called with a zero-size buffer.
27 pub fn new(reader: R, f: F) -> InspectReader<R, F>
28 where
29 R: AsyncRead,
30 F: FnMut(&[u8]),
31 {
32 InspectReader { reader, f }
33 }
34
35 /// Consumes the `InspectReader`, returning the wrapped reader
36 pub fn into_inner(self) -> R {
37 self.reader
38 }
39}
40
41impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
42 fn poll_read(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 buf: &mut ReadBuf<'_>,
46 ) -> Poll<Result<()>> {
47 let me = self.project();
48 let filled_length = buf.filled().len();
49 ready!(me.reader.poll_read(cx, buf))?;
50 (me.f)(&buf.filled()[filled_length..]);
51 Poll::Ready(Ok(()))
52 }
53}
54
55impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> {
56 fn poll_write(
57 self: Pin<&mut Self>,
58 cx: &mut Context<'_>,
59 buf: &[u8],
60 ) -> Poll<std::result::Result<usize, std::io::Error>> {
61 self.project().reader.poll_write(cx, buf)
62 }
63
64 fn poll_flush(
65 self: Pin<&mut Self>,
66 cx: &mut Context<'_>,
67 ) -> Poll<std::result::Result<(), std::io::Error>> {
68 self.project().reader.poll_flush(cx)
69 }
70
71 fn poll_shutdown(
72 self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 ) -> Poll<std::result::Result<(), std::io::Error>> {
75 self.project().reader.poll_shutdown(cx)
76 }
77
78 fn poll_write_vectored(
79 self: Pin<&mut Self>,
80 cx: &mut Context<'_>,
81 bufs: &[IoSlice<'_>],
82 ) -> Poll<Result<usize>> {
83 self.project().reader.poll_write_vectored(cx, bufs)
84 }
85
86 fn is_write_vectored(&self) -> bool {
87 self.reader.is_write_vectored()
88 }
89}
90
pin_project! {
    /// An adapter that lets you inspect the data that's being written.
    ///
    /// This is useful for things like hashing data as it's written out.
    ///
    /// After each successful write, the closure `f` is handed exactly the
    /// bytes the wrapped writer reported as written. Flush and shutdown are
    /// forwarded without calling `f`. If the wrapped type also implements
    /// `AsyncRead`, reads are forwarded to it unchanged and are not inspected.
    pub struct InspectWriter<W, F> {
        // The wrapped writer. It is structurally pinned so it can be polled
        // through `Pin<&mut Self>` via the generated projection.
        #[pin]
        writer: W,
        // Inspection callback invoked with successfully written bytes; not
        // pinned.
        f: F,
    }
}
101
102impl<W, F> InspectWriter<W, F> {
103 /// Create a new InspectWriter, wrapping `write` and calling `f` for the
104 /// data successfully written by each write call.
105 ///
106 /// The closure `f` will never be called with an empty slice. A vectored
107 /// write can result in multiple calls to `f` - at most one call to `f` per
108 /// buffer supplied to `poll_write_vectored`.
109 pub fn new(writer: W, f: F) -> InspectWriter<W, F>
110 where
111 W: AsyncWrite,
112 F: FnMut(&[u8]),
113 {
114 InspectWriter { writer, f }
115 }
116
117 /// Consumes the `InspectWriter`, returning the wrapped writer
118 pub fn into_inner(self) -> W {
119 self.writer
120 }
121}
122
123impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
124 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
125 let me = self.project();
126 let res = me.writer.poll_write(cx, buf);
127 if let Poll::Ready(Ok(count)) = res {
128 if count != 0 {
129 (me.f)(&buf[..count]);
130 }
131 }
132 res
133 }
134
135 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
136 let me = self.project();
137 me.writer.poll_flush(cx)
138 }
139
140 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
141 let me = self.project();
142 me.writer.poll_shutdown(cx)
143 }
144
145 fn poll_write_vectored(
146 self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 bufs: &[IoSlice<'_>],
149 ) -> Poll<Result<usize>> {
150 let me = self.project();
151 let res = me.writer.poll_write_vectored(cx, bufs);
152 if let Poll::Ready(Ok(mut count)) = res {
153 for buf in bufs {
154 if count == 0 {
155 break;
156 }
157 let size = count.min(buf.len());
158 if size != 0 {
159 (me.f)(&buf[..size]);
160 count -= size;
161 }
162 }
163 }
164 res
165 }
166
167 fn is_write_vectored(&self) -> bool {
168 self.writer.is_write_vectored()
169 }
170}
171
172impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> {
173 fn poll_read(
174 self: Pin<&mut Self>,
175 cx: &mut Context<'_>,
176 buf: &mut ReadBuf<'_>,
177 ) -> Poll<std::io::Result<()>> {
178 self.project().writer.poll_read(cx, buf)
179 }
180}
181