//! Compatibility between the `tokio::io` and `futures-io` versions of the
//! `AsyncRead`, `AsyncBufRead`, `AsyncWrite`, and `AsyncSeek` traits.
3use futures_core::ready;
4use pin_project_lite::pin_project;
5use std::io;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
pin_project! {
    /// A compatibility layer that allows conversion between the
    /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
    ///
    /// Depending on what the wrapped type implements, the same wrapper also
    /// bridges the `AsyncBufRead` and `AsyncSeek` traits in either direction.
    #[derive(Copy, Clone, Debug)]
    pub struct Compat<T> {
        // The wrapped I/O object. It is structurally pinned so the trait impls
        // can project `Pin<&mut Compat<T>>` down to `Pin<&mut T>`.
        #[pin]
        inner: T,
        // Seek target that has been recorded or started through this adapter but
        // whose completion has not yet been observed; `None` when no seek is in flight.
        seek_pos: Option<io::SeekFrom>,
    }
}
19
20/// Extension trait that allows converting a type implementing
21/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`.
22pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead {
23 /// Wraps `self` with a compatibility layer that implements
24 /// `tokio_io::AsyncRead`.
25 fn compat(self) -> Compat<Self>
26 where
27 Self: Sized,
28 {
29 Compat::new(self)
30 }
31}
32
33impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {}
34
35/// Extension trait that allows converting a type implementing
36/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`.
37pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite {
38 /// Wraps `self` with a compatibility layer that implements
39 /// `tokio::io::AsyncWrite`.
40 fn compat_write(self) -> Compat<Self>
41 where
42 Self: Sized,
43 {
44 Compat::new(self)
45 }
46}
47
48impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}
49
50/// Extension trait that allows converting a type implementing
51/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
52pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
53 /// Wraps `self` with a compatibility layer that implements
54 /// `futures_io::AsyncRead`.
55 fn compat(self) -> Compat<Self>
56 where
57 Self: Sized,
58 {
59 Compat::new(self)
60 }
61}
62
63impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}
64
65/// Extension trait that allows converting a type implementing
66/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
67pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
68 /// Wraps `self` with a compatibility layer that implements
69 /// `futures_io::AsyncWrite`.
70 fn compat_write(self) -> Compat<Self>
71 where
72 Self: Sized,
73 {
74 Compat::new(self)
75 }
76}
77
78impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}
79
80// === impl Compat ===
81
82impl<T> Compat<T> {
83 fn new(inner: T) -> Self {
84 Self {
85 inner,
86 seek_pos: None,
87 }
88 }
89
90 /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
91 /// contained within.
92 pub fn get_ref(&self) -> &T {
93 &self.inner
94 }
95
96 /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
97 /// contained within.
98 pub fn get_mut(&mut self) -> &mut T {
99 &mut self.inner
100 }
101
102 /// Returns the wrapped item.
103 pub fn into_inner(self) -> T {
104 self.inner
105 }
106}
107
108impl<T> tokio::io::AsyncRead for Compat<T>
109where
110 T: futures_io::AsyncRead,
111{
112 fn poll_read(
113 self: Pin<&mut Self>,
114 cx: &mut Context<'_>,
115 buf: &mut tokio::io::ReadBuf<'_>,
116 ) -> Poll<io::Result<()>> {
117 // We can't trust the inner type to not peak at the bytes,
118 // so we must defensively initialize the buffer.
119 let slice = buf.initialize_unfilled();
120 let n = ready!(futures_io::AsyncRead::poll_read(
121 self.project().inner,
122 cx,
123 slice
124 ))?;
125 buf.advance(n);
126 Poll::Ready(Ok(()))
127 }
128}
129
130impl<T> futures_io::AsyncRead for Compat<T>
131where
132 T: tokio::io::AsyncRead,
133{
134 fn poll_read(
135 self: Pin<&mut Self>,
136 cx: &mut Context<'_>,
137 slice: &mut [u8],
138 ) -> Poll<io::Result<usize>> {
139 let mut buf = tokio::io::ReadBuf::new(slice);
140 ready!(tokio::io::AsyncRead::poll_read(
141 self.project().inner,
142 cx,
143 &mut buf
144 ))?;
145 Poll::Ready(Ok(buf.filled().len()))
146 }
147}
148
149impl<T> tokio::io::AsyncBufRead for Compat<T>
150where
151 T: futures_io::AsyncBufRead,
152{
153 fn poll_fill_buf<'a>(
154 self: Pin<&'a mut Self>,
155 cx: &mut Context<'_>,
156 ) -> Poll<io::Result<&'a [u8]>> {
157 futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
158 }
159
160 fn consume(self: Pin<&mut Self>, amt: usize) {
161 futures_io::AsyncBufRead::consume(self.project().inner, amt)
162 }
163}
164
165impl<T> futures_io::AsyncBufRead for Compat<T>
166where
167 T: tokio::io::AsyncBufRead,
168{
169 fn poll_fill_buf<'a>(
170 self: Pin<&'a mut Self>,
171 cx: &mut Context<'_>,
172 ) -> Poll<io::Result<&'a [u8]>> {
173 tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
174 }
175
176 fn consume(self: Pin<&mut Self>, amt: usize) {
177 tokio::io::AsyncBufRead::consume(self.project().inner, amt)
178 }
179}
180
181impl<T> tokio::io::AsyncWrite for Compat<T>
182where
183 T: futures_io::AsyncWrite,
184{
185 fn poll_write(
186 self: Pin<&mut Self>,
187 cx: &mut Context<'_>,
188 buf: &[u8],
189 ) -> Poll<io::Result<usize>> {
190 futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf)
191 }
192
193 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194 futures_io::AsyncWrite::poll_flush(self.project().inner, cx)
195 }
196
197 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198 futures_io::AsyncWrite::poll_close(self.project().inner, cx)
199 }
200}
201
202impl<T> futures_io::AsyncWrite for Compat<T>
203where
204 T: tokio::io::AsyncWrite,
205{
206 fn poll_write(
207 self: Pin<&mut Self>,
208 cx: &mut Context<'_>,
209 buf: &[u8],
210 ) -> Poll<io::Result<usize>> {
211 tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
212 }
213
214 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
215 tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
216 }
217
218 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
219 tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
220 }
221}
222
223impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
224 fn poll_seek(
225 mut self: Pin<&mut Self>,
226 cx: &mut Context<'_>,
227 pos: io::SeekFrom,
228 ) -> Poll<io::Result<u64>> {
229 if self.seek_pos != Some(pos) {
230 // Ensure previous seeks have finished before starting a new one
231 ready!(self.as_mut().project().inner.poll_complete(cx))?;
232 self.as_mut().project().inner.start_seek(pos)?;
233 *self.as_mut().project().seek_pos = Some(pos);
234 }
235 let res = ready!(self.as_mut().project().inner.poll_complete(cx));
236 *self.as_mut().project().seek_pos = None;
237 Poll::Ready(res)
238 }
239}
240
241impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
242 fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
243 *self.as_mut().project().seek_pos = Some(pos);
244 Ok(())
245 }
246
247 fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
248 let pos = match self.seek_pos {
249 None => {
250 // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek.
251 // We don't have to guarantee that the value returned by
252 // poll_complete called without start_seek is correct,
253 // so we'll return 0.
254 return Poll::Ready(Ok(0));
255 }
256 Some(pos) => pos,
257 };
258 let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
259 *self.as_mut().project().seek_pos = None;
260 Poll::Ready(res)
261 }
262}
263
264#[cfg(unix)]
265impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
266 fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
267 self.inner.as_raw_fd()
268 }
269}
270
#[cfg(windows)]
impl<T> std::os::windows::io::AsRawHandle for Compat<T>
where
    T: std::os::windows::io::AsRawHandle,
{
    /// Exposes the raw handle of the wrapped object.
    fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
        let wrapped: &T = &self.inner;
        std::os::windows::io::AsRawHandle::as_raw_handle(wrapped)
    }
}
277