| 1 | use futures_io::{self as io, AsyncBufRead, AsyncRead, AsyncWrite}; |
| 2 | use pin_project::pin_project ; |
| 3 | use std::{ |
| 4 | cmp, |
| 5 | pin::Pin, |
| 6 | task::{Context, Poll}, |
| 7 | }; |
| 8 | |
| 9 | /// I/O wrapper that limits the number of bytes written or read on each call. |
| 10 | /// |
| 11 | /// See the [`limited`] and [`limited_write`] methods. |
| 12 | /// |
| 13 | /// [`limited`]: super::AsyncReadTestExt::limited |
| 14 | /// [`limited_write`]: super::AsyncWriteTestExt::limited_write |
| 15 | #[pin_project ] |
| 16 | #[derive(Debug)] |
| 17 | pub struct Limited<Io> { |
| 18 | #[pin] |
| 19 | io: Io, |
| 20 | limit: usize, |
| 21 | } |
| 22 | |
| 23 | impl<Io> Limited<Io> { |
| 24 | pub(crate) fn new(io: Io, limit: usize) -> Self { |
| 25 | Self { io, limit } |
| 26 | } |
| 27 | |
| 28 | /// Acquires a reference to the underlying I/O object that this adaptor is |
| 29 | /// wrapping. |
| 30 | pub fn get_ref(&self) -> &Io { |
| 31 | &self.io |
| 32 | } |
| 33 | |
| 34 | /// Acquires a mutable reference to the underlying I/O object that this |
| 35 | /// adaptor is wrapping. |
| 36 | pub fn get_mut(&mut self) -> &mut Io { |
| 37 | &mut self.io |
| 38 | } |
| 39 | |
| 40 | /// Acquires a pinned mutable reference to the underlying I/O object that |
| 41 | /// this adaptor is wrapping. |
| 42 | pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Io> { |
| 43 | self.project().io |
| 44 | } |
| 45 | |
| 46 | /// Consumes this adaptor returning the underlying I/O object. |
| 47 | pub fn into_inner(self) -> Io { |
| 48 | self.io |
| 49 | } |
| 50 | } |
| 51 | |
| 52 | impl<W: AsyncWrite> AsyncWrite for Limited<W> { |
| 53 | fn poll_write( |
| 54 | self: Pin<&mut Self>, |
| 55 | cx: &mut Context<'_>, |
| 56 | buf: &[u8], |
| 57 | ) -> Poll<io::Result<usize>> { |
| 58 | let this = self.project(); |
| 59 | this.io.poll_write(cx, &buf[..cmp::min(*this.limit, buf.len())]) |
| 60 | } |
| 61 | |
| 62 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 63 | self.project().io.poll_flush(cx) |
| 64 | } |
| 65 | |
| 66 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 67 | self.project().io.poll_close(cx) |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | impl<R: AsyncRead> AsyncRead for Limited<R> { |
| 72 | fn poll_read( |
| 73 | self: Pin<&mut Self>, |
| 74 | cx: &mut Context<'_>, |
| 75 | buf: &mut [u8], |
| 76 | ) -> Poll<io::Result<usize>> { |
| 77 | let this = self.project(); |
| 78 | let limit = cmp::min(*this.limit, buf.len()); |
| 79 | this.io.poll_read(cx, &mut buf[..limit]) |
| 80 | } |
| 81 | } |
| 82 | |
| 83 | impl<R: AsyncBufRead> AsyncBufRead for Limited<R> { |
| 84 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
| 85 | self.project().io.poll_fill_buf(cx) |
| 86 | } |
| 87 | |
| 88 | fn consume(self: Pin<&mut Self>, amount: usize) { |
| 89 | self.project().io.consume(amount) |
| 90 | } |
| 91 | } |
| 92 | |