1use futures_core::future::{FusedFuture, Future};
2use futures_core::stream::{FusedStream, Stream};
3use futures_core::task::{Context, Poll};
4use futures_io::{
5 self as io, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom,
6};
7use futures_sink::Sink;
8use pin_project::{pin_project, pinned_drop};
9use std::pin::Pin;
10use std::thread::panicking;
11
12/// Combinator that asserts that the underlying type is not moved after being polled.
13///
14/// See the `assert_unmoved` methods on:
15/// * [`FutureTestExt`](crate::future::FutureTestExt::assert_unmoved)
16/// * [`StreamTestExt`](crate::stream::StreamTestExt::assert_unmoved)
17/// * [`SinkTestExt`](crate::sink::SinkTestExt::assert_unmoved_sink)
18/// * [`AsyncReadTestExt`](crate::io::AsyncReadTestExt::assert_unmoved)
19/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::assert_unmoved_write)
20#[pin_project(PinnedDrop, !Unpin)]
21#[derive(Debug, Clone)]
22#[must_use = "futures do nothing unless you `.await` or poll them"]
23pub struct AssertUnmoved<T> {
24 #[pin]
25 inner: T,
26 this_addr: usize,
27}
28
29impl<T> AssertUnmoved<T> {
30 pub(crate) fn new(inner: T) -> Self {
31 Self { inner, this_addr: 0 }
32 }
33
34 fn poll_with<'a, U>(mut self: Pin<&'a mut Self>, f: impl FnOnce(Pin<&'a mut T>) -> U) -> U {
35 let cur_this = &*self as *const Self as usize;
36 if self.this_addr == 0 {
37 // First time being polled
38 *self.as_mut().project().this_addr = cur_this;
39 } else {
40 assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved between poll calls");
41 }
42 f(self.project().inner)
43 }
44}
45
46impl<Fut: Future> Future for AssertUnmoved<Fut> {
47 type Output = Fut::Output;
48
49 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
50 self.poll_with(|f| f.poll(cx))
51 }
52}
53
54impl<Fut: FusedFuture> FusedFuture for AssertUnmoved<Fut> {
55 fn is_terminated(&self) -> bool {
56 self.inner.is_terminated()
57 }
58}
59
60impl<St: Stream> Stream for AssertUnmoved<St> {
61 type Item = St::Item;
62
63 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
64 self.poll_with(|s| s.poll_next(cx))
65 }
66}
67
68impl<St: FusedStream> FusedStream for AssertUnmoved<St> {
69 fn is_terminated(&self) -> bool {
70 self.inner.is_terminated()
71 }
72}
73
74impl<Si: Sink<Item>, Item> Sink<Item> for AssertUnmoved<Si> {
75 type Error = Si::Error;
76
77 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78 self.poll_with(|s| s.poll_ready(cx))
79 }
80
81 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
82 self.poll_with(|s| s.start_send(item))
83 }
84
85 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 self.poll_with(|s| s.poll_flush(cx))
87 }
88
89 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90 self.poll_with(|s| s.poll_close(cx))
91 }
92}
93
94impl<R: AsyncRead> AsyncRead for AssertUnmoved<R> {
95 fn poll_read(
96 self: Pin<&mut Self>,
97 cx: &mut Context<'_>,
98 buf: &mut [u8],
99 ) -> Poll<io::Result<usize>> {
100 self.poll_with(|r| r.poll_read(cx, buf))
101 }
102
103 fn poll_read_vectored(
104 self: Pin<&mut Self>,
105 cx: &mut Context<'_>,
106 bufs: &mut [IoSliceMut<'_>],
107 ) -> Poll<io::Result<usize>> {
108 self.poll_with(|r| r.poll_read_vectored(cx, bufs))
109 }
110}
111
112impl<W: AsyncWrite> AsyncWrite for AssertUnmoved<W> {
113 fn poll_write(
114 self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 buf: &[u8],
117 ) -> Poll<io::Result<usize>> {
118 self.poll_with(|w| w.poll_write(cx, buf))
119 }
120
121 fn poll_write_vectored(
122 self: Pin<&mut Self>,
123 cx: &mut Context<'_>,
124 bufs: &[IoSlice<'_>],
125 ) -> Poll<io::Result<usize>> {
126 self.poll_with(|w| w.poll_write_vectored(cx, bufs))
127 }
128
129 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
130 self.poll_with(|w| w.poll_flush(cx))
131 }
132
133 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
134 self.poll_with(|w| w.poll_close(cx))
135 }
136}
137
138impl<S: AsyncSeek> AsyncSeek for AssertUnmoved<S> {
139 fn poll_seek(
140 self: Pin<&mut Self>,
141 cx: &mut Context<'_>,
142 pos: SeekFrom,
143 ) -> Poll<io::Result<u64>> {
144 self.poll_with(|s| s.poll_seek(cx, pos))
145 }
146}
147
148impl<R: AsyncBufRead> AsyncBufRead for AssertUnmoved<R> {
149 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
150 self.poll_with(|r| r.poll_fill_buf(cx))
151 }
152
153 fn consume(self: Pin<&mut Self>, amt: usize) {
154 self.poll_with(|r| r.consume(amt))
155 }
156}
157
158#[pinned_drop]
159impl<T> PinnedDrop for AssertUnmoved<T> {
160 fn drop(self: Pin<&mut Self>) {
161 // If the thread is panicking then we can't panic again as that will
162 // cause the process to be aborted.
163 if !panicking() && self.this_addr != 0 {
164 let cur_this = &*self as *const Self as usize;
165 assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved before drop");
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use futures_core::future::Future;
    use futures_core::task::{Context, Poll};
    use futures_util::future::pending;
    use futures_util::task::noop_waker;
    use std::pin::Pin;

    use super::AssertUnmoved;

    /// `AssertUnmoved` must be `Send + Sync` when its inner value is.
    #[test]
    fn assert_send_sync() {
        fn requires_send_sync<T: Send + Sync>() {}
        requires_send_sync::<AssertUnmoved<()>>();
    }

    /// Dropping a wrapper that was never polled must not trip the drop-time check.
    #[test]
    fn dont_panic_when_not_polled() {
        let never_polled = AssertUnmoved::new(pending::<()>());
        drop(never_polled);
    }

    /// Polling after a move panics exactly once. The drop-time check is skipped while
    /// unwinding, so the test panics rather than aborting the process.
    #[test]
    #[should_panic(expected = "AssertUnmoved moved between poll calls")]
    fn dont_double_panic() {
        let waker = noop_waker();
        let mut ctx = Context::from_waker(&waker);

        // Poll once on the stack so the wrapper records its stack address.
        let mut on_stack = AssertUnmoved::new(pending::<()>());
        // SAFETY: deliberately unsound. The value is moved below despite having been
        // pinned here; provoking exactly that move is the point of this test.
        let pinned = unsafe { Pin::new_unchecked(&mut on_stack) };
        assert_eq!(pinned.poll(&mut ctx), Poll::Pending);

        // Move it to the heap and poll again. This must panic because the address
        // changed, and dropping the box during unwinding must not panic a second time.
        let mut on_heap = Box::pin(on_stack);
        assert_eq!(on_heap.as_mut().poll(&mut ctx), Poll::Pending);
    }
}
213