1use futures_core::future::{FusedFuture, Future};
2use futures_core::stream::{FusedStream, Stream};
3use futures_io::{
4 self as io, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom,
5};
6use futures_sink::Sink;
7use pin_project::pin_project;
8use std::{
9 pin::Pin,
10 task::{Context, Poll},
11};
12
/// Wrapper that interleaves [`Poll::Pending`] in calls to poll.
///
/// Every poll-style operation routed through this wrapper first returns
/// [`Poll::Pending`] (after waking the task) without touching the inner
/// value. The next attempt is forwarded to the inner value. Once the inner
/// value reports ready, the cycle starts over, so the following operation is
/// delayed again. A newly created wrapper therefore always pends on its
/// first poll.
///
/// A single flag is shared by all poll methods (e.g. `poll_read` and
/// `poll_write`, or `poll_ready` and `poll_flush`), so the injected `Pending`
/// is not tracked separately per method. Non-poll methods such as
/// `Sink::start_send` and `AsyncBufRead::consume` are forwarded directly.
///
/// See the `interleave_pending` methods on:
/// * [`FutureTestExt`](crate::future::FutureTestExt::interleave_pending)
/// * [`StreamTestExt`](crate::stream::StreamTestExt::interleave_pending)
/// * [`SinkTestExt`](crate::sink::SinkTestExt::interleave_pending_sink)
/// * [`AsyncReadTestExt`](crate::io::AsyncReadTestExt::interleave_pending)
/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::interleave_pending_write)
#[pin_project]
#[derive(Debug)]
pub struct InterleavePending<T> {
    // The wrapped value; structurally pinned so it can be polled in place.
    #[pin]
    inner: T,
    // `true` once a `Pending` has been injected and the next poll should be
    // forwarded to `inner`; cleared again when `inner` returns `Ready`.
    pended: bool,
}
28
29impl<T> InterleavePending<T> {
30 pub(crate) fn new(inner: T) -> Self {
31 Self { inner, pended: false }
32 }
33
34 /// Acquires a reference to the underlying I/O object that this adaptor is
35 /// wrapping.
36 pub fn get_ref(&self) -> &T {
37 &self.inner
38 }
39
40 /// Acquires a mutable reference to the underlying I/O object that this
41 /// adaptor is wrapping.
42 pub fn get_mut(&mut self) -> &mut T {
43 &mut self.inner
44 }
45
46 /// Acquires a pinned mutable reference to the underlying I/O object that
47 /// this adaptor is wrapping.
48 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
49 self.project().inner
50 }
51
52 /// Consumes this adaptor returning the underlying I/O object.
53 pub fn into_inner(self) -> T {
54 self.inner
55 }
56
57 fn poll_with<'a, U>(
58 self: Pin<&'a mut Self>,
59 cx: &mut Context<'_>,
60 f: impl FnOnce(Pin<&'a mut T>, &mut Context<'_>) -> Poll<U>,
61 ) -> Poll<U> {
62 let this = self.project();
63 if *this.pended {
64 let next = f(this.inner, cx);
65 if next.is_ready() {
66 *this.pended = false;
67 }
68 next
69 } else {
70 cx.waker().wake_by_ref();
71 *this.pended = true;
72 Poll::Pending
73 }
74 }
75}
76
77impl<Fut: Future> Future for InterleavePending<Fut> {
78 type Output = Fut::Output;
79
80 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81 self.poll_with(cx, Fut::poll)
82 }
83}
84
85impl<Fut: FusedFuture> FusedFuture for InterleavePending<Fut> {
86 fn is_terminated(&self) -> bool {
87 self.inner.is_terminated()
88 }
89}
90
91impl<St: Stream> Stream for InterleavePending<St> {
92 type Item = St::Item;
93
94 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 self.poll_with(cx, St::poll_next)
96 }
97
98 fn size_hint(&self) -> (usize, Option<usize>) {
99 self.inner.size_hint()
100 }
101}
102
103impl<St: FusedStream> FusedStream for InterleavePending<St> {
104 fn is_terminated(&self) -> bool {
105 self.inner.is_terminated()
106 }
107}
108
109impl<Si: Sink<Item>, Item> Sink<Item> for InterleavePending<Si> {
110 type Error = Si::Error;
111
112 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
113 self.poll_with(cx, Si::poll_ready)
114 }
115
116 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
117 self.project().inner.start_send(item)
118 }
119
120 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121 self.poll_with(cx, Si::poll_flush)
122 }
123
124 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125 self.poll_with(cx, Si::poll_close)
126 }
127}
128
129impl<R: AsyncRead> AsyncRead for InterleavePending<R> {
130 fn poll_read(
131 self: Pin<&mut Self>,
132 cx: &mut Context<'_>,
133 buf: &mut [u8],
134 ) -> Poll<io::Result<usize>> {
135 self.poll_with(cx, |r, cx| r.poll_read(cx, buf))
136 }
137
138 fn poll_read_vectored(
139 self: Pin<&mut Self>,
140 cx: &mut Context<'_>,
141 bufs: &mut [IoSliceMut<'_>],
142 ) -> Poll<io::Result<usize>> {
143 self.poll_with(cx, |r, cx| r.poll_read_vectored(cx, bufs))
144 }
145}
146
147impl<W: AsyncWrite> AsyncWrite for InterleavePending<W> {
148 fn poll_write(
149 self: Pin<&mut Self>,
150 cx: &mut Context<'_>,
151 buf: &[u8],
152 ) -> Poll<io::Result<usize>> {
153 self.poll_with(cx, |w, cx| w.poll_write(cx, buf))
154 }
155
156 fn poll_write_vectored(
157 self: Pin<&mut Self>,
158 cx: &mut Context<'_>,
159 bufs: &[IoSlice<'_>],
160 ) -> Poll<io::Result<usize>> {
161 self.poll_with(cx, |w, cx| w.poll_write_vectored(cx, bufs))
162 }
163
164 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
165 self.poll_with(cx, W::poll_flush)
166 }
167
168 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
169 self.poll_with(cx, W::poll_close)
170 }
171}
172
173impl<S: AsyncSeek> AsyncSeek for InterleavePending<S> {
174 fn poll_seek(
175 self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 pos: SeekFrom,
178 ) -> Poll<io::Result<u64>> {
179 self.poll_with(cx, |s, cx| s.poll_seek(cx, pos))
180 }
181}
182
183impl<R: AsyncBufRead> AsyncBufRead for InterleavePending<R> {
184 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
185 self.poll_with(cx, R::poll_fill_buf)
186 }
187
188 fn consume(self: Pin<&mut Self>, amount: usize) {
189 self.project().inner.consume(amount)
190 }
191}
192