1#![warn(rust_2018_idioms)]
2#![cfg(feature = "full")]
3
4use tokio::io::{AsyncWrite, AsyncWriteExt};
5use tokio_test::{assert_err, assert_ok};
6
7use bytes::{Buf, Bytes, BytesMut};
8use std::cmp;
9use std::io;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13#[tokio::test]
14async fn write_all_buf() {
15 struct Wr {
16 buf: BytesMut,
17 cnt: usize,
18 }
19
20 impl AsyncWrite for Wr {
21 fn poll_write(
22 mut self: Pin<&mut Self>,
23 _cx: &mut Context<'_>,
24 buf: &[u8],
25 ) -> Poll<io::Result<usize>> {
26 let n = cmp::min(4, buf.len());
27 dbg!(buf);
28 let buf = &buf[0..n];
29
30 self.cnt += 1;
31 self.buf.extend(buf);
32 Ok(buf.len()).into()
33 }
34
35 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
36 Ok(()).into()
37 }
38
39 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
40 Ok(()).into()
41 }
42 }
43
44 let mut wr = Wr {
45 buf: BytesMut::with_capacity(64),
46 cnt: 0,
47 };
48
49 let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));
50
51 assert_ok!(wr.write_all_buf(&mut buf).await);
52 assert_eq!(wr.buf, b"helloworld"[..]);
53 // expect 4 writes, [hell],[o],[worl],[d]
54 assert_eq!(wr.cnt, 4);
55 assert!(!buf.has_remaining());
56}
57
58#[tokio::test]
59async fn write_buf_err() {
60 /// Error out after writing the first 4 bytes
61 struct Wr {
62 cnt: usize,
63 }
64
65 impl AsyncWrite for Wr {
66 fn poll_write(
67 mut self: Pin<&mut Self>,
68 _cx: &mut Context<'_>,
69 _buf: &[u8],
70 ) -> Poll<io::Result<usize>> {
71 self.cnt += 1;
72 if self.cnt == 2 {
73 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops")));
74 }
75 Poll::Ready(Ok(4))
76 }
77
78 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
79 Ok(()).into()
80 }
81
82 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83 Ok(()).into()
84 }
85 }
86
87 let mut wr = Wr { cnt: 0 };
88
89 let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));
90
91 assert_err!(wr.write_all_buf(&mut buf).await);
92 assert_eq!(
93 buf.copy_to_bytes(buf.remaining()),
94 Bytes::from_static(b"oworld")
95 );
96}
97