1 | use pin_project_lite::pin_project; |
2 | use std::{ |
3 | pin::Pin, |
4 | task::{Context, Poll}, |
5 | }; |
6 | |
7 | pin_project! { |
8 | /// Extends an underlying [`tokio`] I/O with [`hyper`] I/O implementations. |
9 | /// |
10 | /// This implements [`Read`] and [`Write`] given an inner type that implements [`AsyncRead`] |
11 | /// and [`AsyncWrite`], respectively. |
12 | #[derive (Debug)] |
13 | pub struct WithHyperIo<I> { |
14 | #[pin] |
15 | inner: I, |
16 | } |
17 | } |
18 | |
19 | // ==== impl WithHyperIo ===== |
20 | |
21 | impl<I> WithHyperIo<I> { |
22 | /// Wraps the inner I/O in an [`WithHyperIo<I>`] |
23 | pub fn new(inner: I) -> Self { |
24 | Self { inner } |
25 | } |
26 | |
27 | /// Returns a reference to the inner type. |
28 | pub fn inner(&self) -> &I { |
29 | &self.inner |
30 | } |
31 | |
32 | /// Returns a mutable reference to the inner type. |
33 | pub fn inner_mut(&mut self) -> &mut I { |
34 | &mut self.inner |
35 | } |
36 | |
37 | /// Consumes this wrapper and returns the inner type. |
38 | pub fn into_inner(self) -> I { |
39 | self.inner |
40 | } |
41 | } |
42 | |
43 | /// [`WithHyperIo<I>`] is [`Read`] if `I` is [`AsyncRead`]. |
44 | /// |
45 | /// [`AsyncRead`]: tokio::io::AsyncRead |
46 | /// [`Read`]: hyper::rt::Read |
47 | impl<I> hyper::rt::Read for WithHyperIo<I> |
48 | where |
49 | I: tokio::io::AsyncRead, |
50 | { |
51 | fn poll_read( |
52 | self: Pin<&mut Self>, |
53 | cx: &mut Context<'_>, |
54 | mut buf: hyper::rt::ReadBufCursor<'_>, |
55 | ) -> Poll<Result<(), std::io::Error>> { |
56 | let n: usize = unsafe { |
57 | let mut tbuf: ReadBuf<'_> = tokio::io::ReadBuf::uninit(buf.as_mut()); |
58 | match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { |
59 | Poll::Ready(Ok(())) => tbuf.filled().len(), |
60 | other: Poll> => return other, |
61 | } |
62 | }; |
63 | |
64 | unsafe { |
65 | buf.advance(n); |
66 | } |
67 | Poll::Ready(Ok(())) |
68 | } |
69 | } |
70 | |
71 | /// [`WithHyperIo<I>`] is [`Write`] if `I` is [`AsyncWrite`]. |
72 | /// |
73 | /// [`AsyncWrite`]: tokio::io::AsyncWrite |
74 | /// [`Write`]: hyper::rt::Write |
75 | impl<I> hyper::rt::Write for WithHyperIo<I> |
76 | where |
77 | I: tokio::io::AsyncWrite, |
78 | { |
79 | fn poll_write( |
80 | self: Pin<&mut Self>, |
81 | cx: &mut Context<'_>, |
82 | buf: &[u8], |
83 | ) -> Poll<Result<usize, std::io::Error>> { |
84 | tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) |
85 | } |
86 | |
87 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { |
88 | tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) |
89 | } |
90 | |
91 | fn poll_shutdown( |
92 | self: Pin<&mut Self>, |
93 | cx: &mut Context<'_>, |
94 | ) -> Poll<Result<(), std::io::Error>> { |
95 | tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) |
96 | } |
97 | |
98 | fn is_write_vectored(&self) -> bool { |
99 | tokio::io::AsyncWrite::is_write_vectored(&self.inner) |
100 | } |
101 | |
102 | fn poll_write_vectored( |
103 | self: Pin<&mut Self>, |
104 | cx: &mut Context<'_>, |
105 | bufs: &[std::io::IoSlice<'_>], |
106 | ) -> Poll<Result<usize, std::io::Error>> { |
107 | tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) |
108 | } |
109 | } |
110 | |
111 | /// [`WithHyperIo<I>`] exposes its inner `I`'s [`AsyncRead`] implementation. |
112 | /// |
113 | /// [`AsyncRead`]: tokio::io::AsyncRead |
114 | impl<I> tokio::io::AsyncRead for WithHyperIo<I> |
115 | where |
116 | I: tokio::io::AsyncRead, |
117 | { |
118 | #[inline ] |
119 | fn poll_read( |
120 | self: Pin<&mut Self>, |
121 | cx: &mut Context<'_>, |
122 | buf: &mut tokio::io::ReadBuf<'_>, |
123 | ) -> Poll<Result<(), std::io::Error>> { |
124 | self.project().inner.poll_read(cx, buf) |
125 | } |
126 | } |
127 | |
128 | /// [`WithHyperIo<I>`] exposes its inner `I`'s [`AsyncWrite`] implementation. |
129 | /// |
130 | /// [`AsyncWrite`]: tokio::io::AsyncWrite |
131 | impl<I> tokio::io::AsyncWrite for WithHyperIo<I> |
132 | where |
133 | I: tokio::io::AsyncWrite, |
134 | { |
135 | #[inline ] |
136 | fn poll_write( |
137 | self: Pin<&mut Self>, |
138 | cx: &mut Context<'_>, |
139 | buf: &[u8], |
140 | ) -> Poll<Result<usize, std::io::Error>> { |
141 | self.project().inner.poll_write(cx, buf) |
142 | } |
143 | |
144 | #[inline ] |
145 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { |
146 | self.project().inner.poll_flush(cx) |
147 | } |
148 | |
149 | #[inline ] |
150 | fn poll_shutdown( |
151 | self: Pin<&mut Self>, |
152 | cx: &mut Context<'_>, |
153 | ) -> Poll<Result<(), std::io::Error>> { |
154 | self.project().inner.poll_shutdown(cx) |
155 | } |
156 | |
157 | #[inline ] |
158 | fn is_write_vectored(&self) -> bool { |
159 | self.inner.is_write_vectored() |
160 | } |
161 | |
162 | #[inline ] |
163 | fn poll_write_vectored( |
164 | self: Pin<&mut Self>, |
165 | cx: &mut Context<'_>, |
166 | bufs: &[std::io::IoSlice<'_>], |
167 | ) -> Poll<Result<usize, std::io::Error>> { |
168 | self.project().inner.poll_write_vectored(cx, bufs) |
169 | } |
170 | } |
171 | |