1 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
2 | |
3 | use bytes::{Buf, BufMut}; |
4 | use std::io::{self, IoSlice}; |
5 | use std::pin::Pin; |
6 | use std::task::{ready, Context, Poll}; |
7 | |
8 | /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. |
9 | /// |
10 | /// [`BufMut`]: bytes::Buf |
11 | /// |
12 | /// # Example |
13 | /// |
14 | /// ``` |
15 | /// use bytes::{Bytes, BytesMut}; |
16 | /// use tokio_stream as stream; |
17 | /// use tokio::io::Result; |
18 | /// use tokio_util::io::{StreamReader, poll_read_buf}; |
19 | /// use std::future::poll_fn; |
20 | /// use std::pin::Pin; |
21 | /// # #[tokio::main] |
22 | /// # async fn main() -> std::io::Result<()> { |
23 | /// |
24 | /// // Create a reader from an iterator. This particular reader will always be |
25 | /// // ready. |
26 | /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); |
27 | /// |
28 | /// let mut buf = BytesMut::new(); |
29 | /// let mut reads = 0; |
30 | /// |
31 | /// loop { |
32 | /// reads += 1; |
33 | /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; |
34 | /// |
35 | /// if n == 0 { |
36 | /// break; |
37 | /// } |
38 | /// } |
39 | /// |
40 | /// // one or more reads might be necessary. |
41 | /// assert!(reads >= 1); |
42 | /// assert_eq!(&buf[..], &[0, 1, 2, 3]); |
43 | /// # Ok(()) |
44 | /// # } |
45 | /// ``` |
46 | #[cfg_attr (not(feature = "io" ), allow(unreachable_pub))] |
47 | pub fn poll_read_buf<T: AsyncRead + ?Sized, B: BufMut>( |
48 | io: Pin<&mut T>, |
49 | cx: &mut Context<'_>, |
50 | buf: &mut B, |
51 | ) -> Poll<io::Result<usize>> { |
52 | if !buf.has_remaining_mut() { |
53 | return Poll::Ready(Ok(0)); |
54 | } |
55 | |
56 | let n = { |
57 | let dst = buf.chunk_mut(); |
58 | |
59 | // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a |
60 | // transparent wrapper around `[MaybeUninit<u8>]`. |
61 | let dst = unsafe { dst.as_uninit_slice_mut() }; |
62 | let mut buf = ReadBuf::uninit(dst); |
63 | let ptr = buf.filled().as_ptr(); |
64 | ready!(io.poll_read(cx, &mut buf)?); |
65 | |
66 | // Ensure the pointer does not change from under us |
67 | assert_eq!(ptr, buf.filled().as_ptr()); |
68 | buf.filled().len() |
69 | }; |
70 | |
71 | // Safety: This is guaranteed to be the number of initialized (and read) |
72 | // bytes due to the invariants provided by `ReadBuf::filled`. |
73 | unsafe { |
74 | buf.advance_mut(n); |
75 | } |
76 | |
77 | Poll::Ready(Ok(n)) |
78 | } |
79 | |
80 | /// Try to write data from an implementer of the [`Buf`] trait to an |
81 | /// [`AsyncWrite`], advancing the buffer's internal cursor. |
82 | /// |
83 | /// This function will use [vectored writes] when the [`AsyncWrite`] supports |
84 | /// vectored writes. |
85 | /// |
86 | /// # Examples |
87 | /// |
88 | /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements |
89 | /// [`Buf`]: |
90 | /// |
91 | /// ```no_run |
92 | /// use tokio_util::io::poll_write_buf; |
93 | /// use tokio::io; |
94 | /// use tokio::fs::File; |
95 | /// |
96 | /// use bytes::Buf; |
97 | /// use std::future::poll_fn; |
98 | /// use std::io::Cursor; |
99 | /// use std::pin::Pin; |
100 | /// |
101 | /// #[tokio::main] |
102 | /// async fn main() -> io::Result<()> { |
103 | /// let mut file = File::create("foo.txt" ).await?; |
104 | /// let mut buf = Cursor::new(b"data to write" ); |
105 | /// |
106 | /// // Loop until the entire contents of the buffer are written to |
107 | /// // the file. |
108 | /// while buf.has_remaining() { |
109 | /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; |
110 | /// } |
111 | /// |
112 | /// Ok(()) |
113 | /// } |
114 | /// ``` |
115 | /// |
116 | /// [`Buf`]: bytes::Buf |
117 | /// [`AsyncWrite`]: tokio::io::AsyncWrite |
118 | /// [`File`]: tokio::fs::File |
119 | /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored |
120 | #[cfg_attr (not(feature = "io" ), allow(unreachable_pub))] |
121 | pub fn poll_write_buf<T: AsyncWrite + ?Sized, B: Buf>( |
122 | io: Pin<&mut T>, |
123 | cx: &mut Context<'_>, |
124 | buf: &mut B, |
125 | ) -> Poll<io::Result<usize>> { |
126 | const MAX_BUFS: usize = 64; |
127 | |
128 | if !buf.has_remaining() { |
129 | return Poll::Ready(Ok(0)); |
130 | } |
131 | |
132 | let n: usize = if io.is_write_vectored() { |
133 | let mut slices: [IoSlice<'_>; 64] = [IoSlice::new(&[]); MAX_BUFS]; |
134 | let cnt: usize = buf.chunks_vectored(&mut slices); |
135 | ready!(io.poll_write_vectored(cx, &slices[..cnt]))? |
136 | } else { |
137 | ready!(io.poll_write(cx, buf.chunk()))? |
138 | }; |
139 | |
140 | buf.advance(cnt:n); |
141 | |
142 | Poll::Ready(Ok(n)) |
143 | } |
144 | |