1 | #![warn (rust_2018_idioms)] |
2 | #![cfg (feature = "full" )] |
3 | |
4 | use tokio::io::{AsyncWrite, AsyncWriteExt}; |
5 | use tokio_test::{assert_err, assert_ok}; |
6 | |
7 | use bytes::{Buf, Bytes, BytesMut}; |
8 | use std::cmp; |
9 | use std::io; |
10 | use std::pin::Pin; |
11 | use std::task::{Context, Poll}; |
12 | |
13 | #[tokio::test ] |
14 | async fn write_all_buf() { |
15 | struct Wr { |
16 | buf: BytesMut, |
17 | cnt: usize, |
18 | } |
19 | |
20 | impl AsyncWrite for Wr { |
21 | fn poll_write( |
22 | mut self: Pin<&mut Self>, |
23 | _cx: &mut Context<'_>, |
24 | buf: &[u8], |
25 | ) -> Poll<io::Result<usize>> { |
26 | let n = cmp::min(4, buf.len()); |
27 | dbg!(buf); |
28 | let buf = &buf[0..n]; |
29 | |
30 | self.cnt += 1; |
31 | self.buf.extend(buf); |
32 | Ok(buf.len()).into() |
33 | } |
34 | |
35 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
36 | Ok(()).into() |
37 | } |
38 | |
39 | fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
40 | Ok(()).into() |
41 | } |
42 | } |
43 | |
44 | let mut wr = Wr { |
45 | buf: BytesMut::with_capacity(64), |
46 | cnt: 0, |
47 | }; |
48 | |
49 | let mut buf = Bytes::from_static(b"hello" ).chain(Bytes::from_static(b"world" )); |
50 | |
51 | assert_ok!(wr.write_all_buf(&mut buf).await); |
52 | assert_eq!(wr.buf, b"helloworld" [..]); |
53 | // expect 4 writes, [hell],[o],[worl],[d] |
54 | assert_eq!(wr.cnt, 4); |
55 | assert!(!buf.has_remaining()); |
56 | } |
57 | |
58 | #[tokio::test ] |
59 | async fn write_buf_err() { |
60 | /// Error out after writing the first 4 bytes |
61 | struct Wr { |
62 | cnt: usize, |
63 | } |
64 | |
65 | impl AsyncWrite for Wr { |
66 | fn poll_write( |
67 | mut self: Pin<&mut Self>, |
68 | _cx: &mut Context<'_>, |
69 | _buf: &[u8], |
70 | ) -> Poll<io::Result<usize>> { |
71 | self.cnt += 1; |
72 | if self.cnt == 2 { |
73 | return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops" ))); |
74 | } |
75 | Poll::Ready(Ok(4)) |
76 | } |
77 | |
78 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
79 | Ok(()).into() |
80 | } |
81 | |
82 | fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
83 | Ok(()).into() |
84 | } |
85 | } |
86 | |
87 | let mut wr = Wr { cnt: 0 }; |
88 | |
89 | let mut buf = Bytes::from_static(b"hello" ).chain(Bytes::from_static(b"world" )); |
90 | |
91 | assert_err!(wr.write_all_buf(&mut buf).await); |
92 | assert_eq!( |
93 | buf.copy_to_bytes(buf.remaining()), |
94 | Bytes::from_static(b"oworld" ) |
95 | ); |
96 | } |
97 | |