1use futures_core::task::{Context, Poll};
2use futures_io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom};
3use std::pin::Pin;
4use std::{fmt, io};
5
/// A simple wrapper type which allows types which only implement
/// `std::io::Read` or `std::io::Write`
/// to be used in contexts which expect an `AsyncRead` or `AsyncWrite`.
///
/// If these types issue an error with the kind `io::ErrorKind::WouldBlock`,
/// it is expected that they will notify the current task on readiness.
/// Synchronous `std` types should not issue errors of this kind and
/// are safe to use in this context. However, using these types with
/// `AllowStdIo` will cause the event loop to block, so they should be used
/// with care.
///
/// The asynchronous trait implementations in this file call the wrapped
/// object's blocking method directly and always complete with `Poll::Ready`;
/// a call that fails with `io::ErrorKind::Interrupted` is retried, and any
/// other error is handed back to the caller unchanged.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct AllowStdIo<T>(T);
18
// `AllowStdIo` is declared `Unpin` for every `T`, including `T: !Unpin`.
// The async trait impls in this file reach the wrapped value through
// `Pin<&mut Self>` by plain mutable dereference (`self.0`), which is only
// possible when `Self: Unpin`.
impl<T> Unpin for AllowStdIo<T> {}
20
// Evaluates the `io::Result` expression `$e` in a loop and yields its `Ok`
// payload as the value of the macro invocation.
//
// * `Err` with kind `Interrupted`: the expression is evaluated again.
// * Any other `Err`: the *enclosing function* returns `Poll::Ready(Err(_))`.
//
// Because the expression is re-evaluated on every retry, it must be safe to
// run more than once.
macro_rules! try_with_interrupt {
    ($e:expr) => {
        loop {
            match $e {
                Err(ref err) if err.kind() == ::std::io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
                Ok(value) => break value,
            }
        }
    };
}
38
39impl<T> AllowStdIo<T> {
40 /// Creates a new `AllowStdIo` from an existing IO object.
41 pub fn new(io: T) -> Self {
42 Self(io)
43 }
44
45 /// Returns a reference to the contained IO object.
46 pub fn get_ref(&self) -> &T {
47 &self.0
48 }
49
50 /// Returns a mutable reference to the contained IO object.
51 pub fn get_mut(&mut self) -> &mut T {
52 &mut self.0
53 }
54
55 /// Consumes self and returns the contained IO object.
56 pub fn into_inner(self) -> T {
57 self.0
58 }
59}
60
61impl<T> io::Write for AllowStdIo<T>
62where
63 T: io::Write,
64{
65 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
66 self.0.write(buf)
67 }
68 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
69 self.0.write_vectored(bufs)
70 }
71 fn flush(&mut self) -> io::Result<()> {
72 self.0.flush()
73 }
74 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
75 self.0.write_all(buf)
76 }
77 fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
78 self.0.write_fmt(fmt)
79 }
80}
81
82impl<T> AsyncWrite for AllowStdIo<T>
83where
84 T: io::Write,
85{
86 fn poll_write(
87 mut self: Pin<&mut Self>,
88 _: &mut Context<'_>,
89 buf: &[u8],
90 ) -> Poll<io::Result<usize>> {
91 Poll::Ready(Ok(try_with_interrupt!(self.0.write(buf))))
92 }
93
94 fn poll_write_vectored(
95 mut self: Pin<&mut Self>,
96 _: &mut Context<'_>,
97 bufs: &[IoSlice<'_>],
98 ) -> Poll<io::Result<usize>> {
99 Poll::Ready(Ok(try_with_interrupt!(self.0.write_vectored(bufs))))
100 }
101
102 fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
103 try_with_interrupt!(self.0.flush());
104 Poll::Ready(Ok(()))
105 }
106
107 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
108 self.poll_flush(cx)
109 }
110}
111
112impl<T> io::Read for AllowStdIo<T>
113where
114 T: io::Read,
115{
116 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
117 self.0.read(buf)
118 }
119 fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
120 self.0.read_vectored(bufs)
121 }
122 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
123 self.0.read_to_end(buf)
124 }
125 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
126 self.0.read_to_string(buf)
127 }
128 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
129 self.0.read_exact(buf)
130 }
131}
132
133impl<T> AsyncRead for AllowStdIo<T>
134where
135 T: io::Read,
136{
137 fn poll_read(
138 mut self: Pin<&mut Self>,
139 _: &mut Context<'_>,
140 buf: &mut [u8],
141 ) -> Poll<io::Result<usize>> {
142 Poll::Ready(Ok(try_with_interrupt!(self.0.read(buf))))
143 }
144
145 fn poll_read_vectored(
146 mut self: Pin<&mut Self>,
147 _: &mut Context<'_>,
148 bufs: &mut [IoSliceMut<'_>],
149 ) -> Poll<io::Result<usize>> {
150 Poll::Ready(Ok(try_with_interrupt!(self.0.read_vectored(bufs))))
151 }
152}
153
154impl<T> io::Seek for AllowStdIo<T>
155where
156 T: io::Seek,
157{
158 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
159 self.0.seek(pos)
160 }
161}
162
163impl<T> AsyncSeek for AllowStdIo<T>
164where
165 T: io::Seek,
166{
167 fn poll_seek(
168 mut self: Pin<&mut Self>,
169 _: &mut Context<'_>,
170 pos: SeekFrom,
171 ) -> Poll<io::Result<u64>> {
172 Poll::Ready(Ok(try_with_interrupt!(self.0.seek(pos))))
173 }
174}
175
176impl<T> io::BufRead for AllowStdIo<T>
177where
178 T: io::BufRead,
179{
180 fn fill_buf(&mut self) -> io::Result<&[u8]> {
181 self.0.fill_buf()
182 }
183 fn consume(&mut self, amt: usize) {
184 self.0.consume(amt)
185 }
186}
187
/// `AsyncBufRead` on top of a blocking buffered reader. Every return path of
/// `poll_fill_buf` is `Poll::Ready`; the task context is never used.
impl<T> AsyncBufRead for AllowStdIo<T>
where
    T: io::BufRead,
{
    /// Forwards to the wrapped reader's `fill_buf`, retrying while it fails
    /// with `io::ErrorKind::Interrupted`. Any other error is returned as
    /// `Poll::Ready(Err(_))`; on success the reader's internal buffer is
    /// returned as `Poll::Ready(Ok(_))`.
    fn poll_fill_buf(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        // Take a raw pointer to `Self` so that each iteration of the retry loop
        // inside `try_with_interrupt!` can create its own fresh `&mut Self`.
        // NOTE(review): presumably this exists because the borrow checker rejects
        // the safe form, where the `fill_buf` borrow is returned from one match
        // arm while the loop may re-borrow in another — confirm before changing.
        let this: *mut Self = &mut *self as *mut _;
        // SAFETY: `this` was derived just above from the exclusive
        // `Pin<&mut Self>` owned by this call, and `self` is not touched again
        // while the pointer is in use, so the dereference produces a valid,
        // unaliased `&mut Self`. The slice handed back borrows from that same
        // exclusive reference.
        Poll::Ready(Ok(try_with_interrupt!(unsafe { &mut *this }.0.fill_buf())))
    }

    /// Marks `amt` bytes of the wrapped reader's buffer as consumed by
    /// forwarding directly to its `consume`.
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        self.0.consume(amt)
    }
}
201