1 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
2 | |
3 | use bytes::{Buf, BufMut}; |
4 | use futures_core::ready; |
5 | use std::io::{self, IoSlice}; |
6 | use std::mem::MaybeUninit; |
7 | use std::pin::Pin; |
8 | use std::task::{Context, Poll}; |
9 | |
10 | /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. |
11 | /// |
12 | /// [`BufMut`]: bytes::Buf |
13 | /// |
14 | /// # Example |
15 | /// |
16 | /// ``` |
17 | /// use bytes::{Bytes, BytesMut}; |
18 | /// use tokio_stream as stream; |
19 | /// use tokio::io::Result; |
20 | /// use tokio_util::io::{StreamReader, poll_read_buf}; |
21 | /// use futures::future::poll_fn; |
22 | /// use std::pin::Pin; |
23 | /// # #[tokio::main] |
24 | /// # async fn main() -> std::io::Result<()> { |
25 | /// |
26 | /// // Create a reader from an iterator. This particular reader will always be |
27 | /// // ready. |
28 | /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); |
29 | /// |
30 | /// let mut buf = BytesMut::new(); |
31 | /// let mut reads = 0; |
32 | /// |
33 | /// loop { |
34 | /// reads += 1; |
35 | /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; |
36 | /// |
37 | /// if n == 0 { |
38 | /// break; |
39 | /// } |
40 | /// } |
41 | /// |
42 | /// // one or more reads might be necessary. |
43 | /// assert!(reads >= 1); |
44 | /// assert_eq!(&buf[..], &[0, 1, 2, 3]); |
45 | /// # Ok(()) |
46 | /// # } |
47 | /// ``` |
48 | #[cfg_attr (not(feature = "io" ), allow(unreachable_pub))] |
49 | pub fn poll_read_buf<T: AsyncRead, B: BufMut>( |
50 | io: Pin<&mut T>, |
51 | cx: &mut Context<'_>, |
52 | buf: &mut B, |
53 | ) -> Poll<io::Result<usize>> { |
54 | if !buf.has_remaining_mut() { |
55 | return Poll::Ready(Ok(0)); |
56 | } |
57 | |
58 | let n = { |
59 | let dst = buf.chunk_mut(); |
60 | |
61 | // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a |
62 | // transparent wrapper around `[MaybeUninit<u8>]`. |
63 | let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) }; |
64 | let mut buf = ReadBuf::uninit(dst); |
65 | let ptr = buf.filled().as_ptr(); |
66 | ready!(io.poll_read(cx, &mut buf)?); |
67 | |
68 | // Ensure the pointer does not change from under us |
69 | assert_eq!(ptr, buf.filled().as_ptr()); |
70 | buf.filled().len() |
71 | }; |
72 | |
73 | // Safety: This is guaranteed to be the number of initialized (and read) |
74 | // bytes due to the invariants provided by `ReadBuf::filled`. |
75 | unsafe { |
76 | buf.advance_mut(n); |
77 | } |
78 | |
79 | Poll::Ready(Ok(n)) |
80 | } |
81 | |
82 | /// Try to write data from an implementer of the [`Buf`] trait to an |
83 | /// [`AsyncWrite`], advancing the buffer's internal cursor. |
84 | /// |
85 | /// This function will use [vectored writes] when the [`AsyncWrite`] supports |
86 | /// vectored writes. |
87 | /// |
88 | /// # Examples |
89 | /// |
90 | /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements |
91 | /// [`Buf`]: |
92 | /// |
93 | /// ```no_run |
94 | /// use tokio_util::io::poll_write_buf; |
95 | /// use tokio::io; |
96 | /// use tokio::fs::File; |
97 | /// |
98 | /// use bytes::Buf; |
99 | /// use std::io::Cursor; |
100 | /// use std::pin::Pin; |
101 | /// use futures::future::poll_fn; |
102 | /// |
103 | /// #[tokio::main] |
104 | /// async fn main() -> io::Result<()> { |
105 | /// let mut file = File::create("foo.txt" ).await?; |
106 | /// let mut buf = Cursor::new(b"data to write" ); |
107 | /// |
108 | /// // Loop until the entire contents of the buffer are written to |
109 | /// // the file. |
110 | /// while buf.has_remaining() { |
111 | /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; |
112 | /// } |
113 | /// |
114 | /// Ok(()) |
115 | /// } |
116 | /// ``` |
117 | /// |
118 | /// [`Buf`]: bytes::Buf |
119 | /// [`AsyncWrite`]: tokio::io::AsyncWrite |
120 | /// [`File`]: tokio::fs::File |
121 | /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored |
122 | #[cfg_attr (not(feature = "io" ), allow(unreachable_pub))] |
123 | pub fn poll_write_buf<T: AsyncWrite, B: Buf>( |
124 | io: Pin<&mut T>, |
125 | cx: &mut Context<'_>, |
126 | buf: &mut B, |
127 | ) -> Poll<io::Result<usize>> { |
128 | const MAX_BUFS: usize = 64; |
129 | |
130 | if !buf.has_remaining() { |
131 | return Poll::Ready(Ok(0)); |
132 | } |
133 | |
134 | let n: usize = if io.is_write_vectored() { |
135 | let mut slices: [IoSlice<'_>; 64] = [IoSlice::new(&[]); MAX_BUFS]; |
136 | let cnt: usize = buf.chunks_vectored(&mut slices); |
137 | ready!(io.poll_write_vectored(cx, &slices[..cnt]))? |
138 | } else { |
139 | ready!(io.poll_write(cx, buf.chunk()))? |
140 | }; |
141 | |
142 | buf.advance(cnt:n); |
143 | |
144 | Poll::Ready(Ok(n)) |
145 | } |
146 | |