1 | //! Compatibility between the `tokio::io` and `futures-io` versions of the |
2 | //! `AsyncRead` and `AsyncWrite` traits. |
3 | use futures_core::ready; |
4 | use pin_project_lite::pin_project; |
5 | use std::io; |
6 | use std::pin::Pin; |
7 | use std::task::{Context, Poll}; |
8 | |
pin_project! {
    /// A compatibility layer that allows conversion between the
    /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
    ///
    /// Depending on which traits the wrapped `T` implements, `Compat<T>`
    /// implements the matching `AsyncRead`, `AsyncBufRead`, `AsyncWrite` and
    /// `AsyncSeek` traits of the *other* ecosystem. The constructor is private,
    /// so code outside this module obtains a `Compat` through the `*CompatExt`
    /// extension traits.
    #[derive(Copy, Clone, Debug)]
    pub struct Compat<T> {
        // The wrapped I/O object. It is structurally pinned so that `!Unpin`
        // readers/writers can be driven through `Pin<&mut Compat<T>>`.
        #[pin]
        inner: T,
        // Target of a seek that has been requested but not yet completed.
        // Both `AsyncSeek` adapters use it to bridge tokio's two-phase
        // `start_seek`/`poll_complete` API and futures-io's single `poll_seek`.
        seek_pos: Option<io::SeekFrom>,
    }
}
19 | |
20 | /// Extension trait that allows converting a type implementing |
21 | /// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`. |
22 | pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead { |
23 | /// Wraps `self` with a compatibility layer that implements |
24 | /// `tokio_io::AsyncRead`. |
25 | fn compat(self) -> Compat<Self> |
26 | where |
27 | Self: Sized, |
28 | { |
29 | Compat::new(self) |
30 | } |
31 | } |
32 | |
33 | impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {} |
34 | |
35 | /// Extension trait that allows converting a type implementing |
36 | /// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`. |
37 | pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite { |
38 | /// Wraps `self` with a compatibility layer that implements |
39 | /// `tokio::io::AsyncWrite`. |
40 | fn compat_write(self) -> Compat<Self> |
41 | where |
42 | Self: Sized, |
43 | { |
44 | Compat::new(self) |
45 | } |
46 | } |
47 | |
48 | impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {} |
49 | |
50 | /// Extension trait that allows converting a type implementing |
51 | /// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`. |
52 | pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead { |
53 | /// Wraps `self` with a compatibility layer that implements |
54 | /// `futures_io::AsyncRead`. |
55 | fn compat(self) -> Compat<Self> |
56 | where |
57 | Self: Sized, |
58 | { |
59 | Compat::new(self) |
60 | } |
61 | } |
62 | |
63 | impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {} |
64 | |
65 | /// Extension trait that allows converting a type implementing |
66 | /// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`. |
67 | pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite { |
68 | /// Wraps `self` with a compatibility layer that implements |
69 | /// `futures_io::AsyncWrite`. |
70 | fn compat_write(self) -> Compat<Self> |
71 | where |
72 | Self: Sized, |
73 | { |
74 | Compat::new(self) |
75 | } |
76 | } |
77 | |
78 | impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {} |
79 | |
80 | // === impl Compat === |
81 | |
82 | impl<T> Compat<T> { |
83 | fn new(inner: T) -> Self { |
84 | Self { |
85 | inner, |
86 | seek_pos: None, |
87 | } |
88 | } |
89 | |
90 | /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object |
91 | /// contained within. |
92 | pub fn get_ref(&self) -> &T { |
93 | &self.inner |
94 | } |
95 | |
96 | /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object |
97 | /// contained within. |
98 | pub fn get_mut(&mut self) -> &mut T { |
99 | &mut self.inner |
100 | } |
101 | |
102 | /// Returns the wrapped item. |
103 | pub fn into_inner(self) -> T { |
104 | self.inner |
105 | } |
106 | } |
107 | |
108 | impl<T> tokio::io::AsyncRead for Compat<T> |
109 | where |
110 | T: futures_io::AsyncRead, |
111 | { |
112 | fn poll_read( |
113 | self: Pin<&mut Self>, |
114 | cx: &mut Context<'_>, |
115 | buf: &mut tokio::io::ReadBuf<'_>, |
116 | ) -> Poll<io::Result<()>> { |
117 | // We can't trust the inner type to not peak at the bytes, |
118 | // so we must defensively initialize the buffer. |
119 | let slice = buf.initialize_unfilled(); |
120 | let n = ready!(futures_io::AsyncRead::poll_read( |
121 | self.project().inner, |
122 | cx, |
123 | slice |
124 | ))?; |
125 | buf.advance(n); |
126 | Poll::Ready(Ok(())) |
127 | } |
128 | } |
129 | |
130 | impl<T> futures_io::AsyncRead for Compat<T> |
131 | where |
132 | T: tokio::io::AsyncRead, |
133 | { |
134 | fn poll_read( |
135 | self: Pin<&mut Self>, |
136 | cx: &mut Context<'_>, |
137 | slice: &mut [u8], |
138 | ) -> Poll<io::Result<usize>> { |
139 | let mut buf = tokio::io::ReadBuf::new(slice); |
140 | ready!(tokio::io::AsyncRead::poll_read( |
141 | self.project().inner, |
142 | cx, |
143 | &mut buf |
144 | ))?; |
145 | Poll::Ready(Ok(buf.filled().len())) |
146 | } |
147 | } |
148 | |
149 | impl<T> tokio::io::AsyncBufRead for Compat<T> |
150 | where |
151 | T: futures_io::AsyncBufRead, |
152 | { |
153 | fn poll_fill_buf<'a>( |
154 | self: Pin<&'a mut Self>, |
155 | cx: &mut Context<'_>, |
156 | ) -> Poll<io::Result<&'a [u8]>> { |
157 | futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) |
158 | } |
159 | |
160 | fn consume(self: Pin<&mut Self>, amt: usize) { |
161 | futures_io::AsyncBufRead::consume(self.project().inner, amt) |
162 | } |
163 | } |
164 | |
165 | impl<T> futures_io::AsyncBufRead for Compat<T> |
166 | where |
167 | T: tokio::io::AsyncBufRead, |
168 | { |
169 | fn poll_fill_buf<'a>( |
170 | self: Pin<&'a mut Self>, |
171 | cx: &mut Context<'_>, |
172 | ) -> Poll<io::Result<&'a [u8]>> { |
173 | tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx) |
174 | } |
175 | |
176 | fn consume(self: Pin<&mut Self>, amt: usize) { |
177 | tokio::io::AsyncBufRead::consume(self.project().inner, amt) |
178 | } |
179 | } |
180 | |
181 | impl<T> tokio::io::AsyncWrite for Compat<T> |
182 | where |
183 | T: futures_io::AsyncWrite, |
184 | { |
185 | fn poll_write( |
186 | self: Pin<&mut Self>, |
187 | cx: &mut Context<'_>, |
188 | buf: &[u8], |
189 | ) -> Poll<io::Result<usize>> { |
190 | futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf) |
191 | } |
192 | |
193 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
194 | futures_io::AsyncWrite::poll_flush(self.project().inner, cx) |
195 | } |
196 | |
197 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
198 | futures_io::AsyncWrite::poll_close(self.project().inner, cx) |
199 | } |
200 | } |
201 | |
202 | impl<T> futures_io::AsyncWrite for Compat<T> |
203 | where |
204 | T: tokio::io::AsyncWrite, |
205 | { |
206 | fn poll_write( |
207 | self: Pin<&mut Self>, |
208 | cx: &mut Context<'_>, |
209 | buf: &[u8], |
210 | ) -> Poll<io::Result<usize>> { |
211 | tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) |
212 | } |
213 | |
214 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
215 | tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) |
216 | } |
217 | |
218 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
219 | tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) |
220 | } |
221 | } |
222 | |
/// Adapts tokio's two-phase seek API (`start_seek` followed by `poll_complete`)
/// to futures-io's single-call `poll_seek`.
///
/// The requested position is remembered in `seek_pos` while the seek is in
/// flight. A caller that re-polls with the same `pos` after `Pending` keeps
/// driving the already-started seek instead of starting a new one.
impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
    fn poll_seek(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        pos: io::SeekFrom,
    ) -> Poll<io::Result<u64>> {
        // No seek to `pos` is recorded as in flight, so this is a new request.
        if self.seek_pos != Some(pos) {
            // Ensure previous seeks have finished before starting a new one
            // NOTE(review): if that earlier seek fails, its error is returned for
            // this call, `pos` is never started, and `seek_pos` keeps its old value.
            ready!(self.as_mut().project().inner.poll_complete(cx))?;
            // If `start_seek` fails, the error is returned and `pos` is not recorded.
            self.as_mut().project().inner.start_seek(pos)?;
            *self.as_mut().project().seek_pos = Some(pos);
        }
        // Drive the seek to `pos` to completion. Once it resolves (success or
        // error), clear the in-flight marker so that the next call starts fresh.
        let res = ready!(self.as_mut().project().inner.poll_complete(cx));
        *self.as_mut().project().seek_pos = None;
        Poll::Ready(res)
    }
}
240 | |
241 | impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> { |
242 | fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> { |
243 | *self.as_mut().project().seek_pos = Some(pos); |
244 | Ok(()) |
245 | } |
246 | |
247 | fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { |
248 | let pos = match self.seek_pos { |
249 | None => { |
250 | // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek. |
251 | // We don't have to guarantee that the value returned by |
252 | // poll_complete called without start_seek is correct, |
253 | // so we'll return 0. |
254 | return Poll::Ready(Ok(0)); |
255 | } |
256 | Some(pos) => pos, |
257 | }; |
258 | let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos)); |
259 | *self.as_mut().project().seek_pos = None; |
260 | Poll::Ready(res) |
261 | } |
262 | } |
263 | |
264 | #[cfg (unix)] |
265 | impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> { |
266 | fn as_raw_fd(&self) -> std::os::unix::io::RawFd { |
267 | self.inner.as_raw_fd() |
268 | } |
269 | } |
270 | |
#[cfg(windows)]
impl<T> std::os::windows::io::AsRawHandle for Compat<T>
where
    T: std::os::windows::io::AsRawHandle,
{
    /// Returns the raw handle of the wrapped I/O object.
    fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
        T::as_raw_handle(&self.inner)
    }
}
277 | |