1 | use futures_core::future::{FusedFuture, Future}; |
2 | use futures_core::stream::{FusedStream, Stream}; |
3 | use futures_core::task::{Context, Poll}; |
4 | use futures_io::{ |
5 | self as io, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom, |
6 | }; |
7 | use futures_sink::Sink; |
8 | use pin_project::{pin_project , pinned_drop }; |
9 | use std::pin::Pin; |
10 | use std::thread::panicking; |
11 | |
12 | /// Combinator that asserts that the underlying type is not moved after being polled. |
13 | /// |
14 | /// See the `assert_unmoved` methods on: |
15 | /// * [`FutureTestExt`](crate::future::FutureTestExt::assert_unmoved) |
16 | /// * [`StreamTestExt`](crate::stream::StreamTestExt::assert_unmoved) |
17 | /// * [`SinkTestExt`](crate::sink::SinkTestExt::assert_unmoved_sink) |
18 | /// * [`AsyncReadTestExt`](crate::io::AsyncReadTestExt::assert_unmoved) |
19 | /// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::assert_unmoved_write) |
20 | #[pin_project (PinnedDrop, !Unpin)] |
21 | #[derive(Debug, Clone)] |
22 | #[must_use = "futures do nothing unless you `.await` or poll them" ] |
23 | pub struct AssertUnmoved<T> { |
24 | #[pin] |
25 | inner: T, |
26 | this_addr: usize, |
27 | } |
28 | |
29 | impl<T> AssertUnmoved<T> { |
30 | pub(crate) fn new(inner: T) -> Self { |
31 | Self { inner, this_addr: 0 } |
32 | } |
33 | |
34 | fn poll_with<'a, U>(mut self: Pin<&'a mut Self>, f: impl FnOnce(Pin<&'a mut T>) -> U) -> U { |
35 | let cur_this = &*self as *const Self as usize; |
36 | if self.this_addr == 0 { |
37 | // First time being polled |
38 | *self.as_mut().project().this_addr = cur_this; |
39 | } else { |
40 | assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved between poll calls" ); |
41 | } |
42 | f(self.project().inner) |
43 | } |
44 | } |
45 | |
46 | impl<Fut: Future> Future for AssertUnmoved<Fut> { |
47 | type Output = Fut::Output; |
48 | |
49 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
50 | self.poll_with(|f| f.poll(cx)) |
51 | } |
52 | } |
53 | |
54 | impl<Fut: FusedFuture> FusedFuture for AssertUnmoved<Fut> { |
55 | fn is_terminated(&self) -> bool { |
56 | self.inner.is_terminated() |
57 | } |
58 | } |
59 | |
60 | impl<St: Stream> Stream for AssertUnmoved<St> { |
61 | type Item = St::Item; |
62 | |
63 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
64 | self.poll_with(|s| s.poll_next(cx)) |
65 | } |
66 | } |
67 | |
68 | impl<St: FusedStream> FusedStream for AssertUnmoved<St> { |
69 | fn is_terminated(&self) -> bool { |
70 | self.inner.is_terminated() |
71 | } |
72 | } |
73 | |
74 | impl<Si: Sink<Item>, Item> Sink<Item> for AssertUnmoved<Si> { |
75 | type Error = Si::Error; |
76 | |
77 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
78 | self.poll_with(|s| s.poll_ready(cx)) |
79 | } |
80 | |
81 | fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { |
82 | self.poll_with(|s| s.start_send(item)) |
83 | } |
84 | |
85 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
86 | self.poll_with(|s| s.poll_flush(cx)) |
87 | } |
88 | |
89 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
90 | self.poll_with(|s| s.poll_close(cx)) |
91 | } |
92 | } |
93 | |
94 | impl<R: AsyncRead> AsyncRead for AssertUnmoved<R> { |
95 | fn poll_read( |
96 | self: Pin<&mut Self>, |
97 | cx: &mut Context<'_>, |
98 | buf: &mut [u8], |
99 | ) -> Poll<io::Result<usize>> { |
100 | self.poll_with(|r| r.poll_read(cx, buf)) |
101 | } |
102 | |
103 | fn poll_read_vectored( |
104 | self: Pin<&mut Self>, |
105 | cx: &mut Context<'_>, |
106 | bufs: &mut [IoSliceMut<'_>], |
107 | ) -> Poll<io::Result<usize>> { |
108 | self.poll_with(|r| r.poll_read_vectored(cx, bufs)) |
109 | } |
110 | } |
111 | |
112 | impl<W: AsyncWrite> AsyncWrite for AssertUnmoved<W> { |
113 | fn poll_write( |
114 | self: Pin<&mut Self>, |
115 | cx: &mut Context<'_>, |
116 | buf: &[u8], |
117 | ) -> Poll<io::Result<usize>> { |
118 | self.poll_with(|w| w.poll_write(cx, buf)) |
119 | } |
120 | |
121 | fn poll_write_vectored( |
122 | self: Pin<&mut Self>, |
123 | cx: &mut Context<'_>, |
124 | bufs: &[IoSlice<'_>], |
125 | ) -> Poll<io::Result<usize>> { |
126 | self.poll_with(|w| w.poll_write_vectored(cx, bufs)) |
127 | } |
128 | |
129 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
130 | self.poll_with(|w| w.poll_flush(cx)) |
131 | } |
132 | |
133 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
134 | self.poll_with(|w| w.poll_close(cx)) |
135 | } |
136 | } |
137 | |
138 | impl<S: AsyncSeek> AsyncSeek for AssertUnmoved<S> { |
139 | fn poll_seek( |
140 | self: Pin<&mut Self>, |
141 | cx: &mut Context<'_>, |
142 | pos: SeekFrom, |
143 | ) -> Poll<io::Result<u64>> { |
144 | self.poll_with(|s| s.poll_seek(cx, pos)) |
145 | } |
146 | } |
147 | |
148 | impl<R: AsyncBufRead> AsyncBufRead for AssertUnmoved<R> { |
149 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
150 | self.poll_with(|r| r.poll_fill_buf(cx)) |
151 | } |
152 | |
153 | fn consume(self: Pin<&mut Self>, amt: usize) { |
154 | self.poll_with(|r| r.consume(amt)) |
155 | } |
156 | } |
157 | |
158 | #[pinned_drop ] |
159 | impl<T> PinnedDrop for AssertUnmoved<T> { |
160 | fn drop(self: Pin<&mut Self>) { |
161 | // If the thread is panicking then we can't panic again as that will |
162 | // cause the process to be aborted. |
163 | if !panicking() && self.this_addr != 0 { |
164 | let cur_this = &*self as *const Self as usize; |
165 | assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved before drop" ); |
166 | } |
167 | } |
168 | } |
169 | |
#[cfg(test)]
mod tests {
    use super::AssertUnmoved;

    use futures_core::future::Future;
    use futures_core::task::{Context, Poll};
    use futures_util::future::pending;
    use futures_util::task::noop_waker;
    use std::pin::Pin;

    #[test]
    fn assert_send_sync() {
        fn is_send_sync<T: Send + Sync>() {}
        is_send_sync::<AssertUnmoved<()>>();
    }

    #[test]
    fn dont_panic_when_not_polled() {
        // A wrapper that was never polled has no recorded address, so dropping it
        // must not trip the drop-time check.
        drop(AssertUnmoved::new(pending::<()>()));
    }

    #[test]
    #[should_panic(expected = "AssertUnmoved moved between poll calls")]
    fn dont_double_panic() {
        // The only expected outcome is an ordinary panic, not a process abort.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // Poll once while the future lives on the stack, which records that address.
        let mut on_stack = AssertUnmoved::new(pending::<()>());
        // SAFETY: this test intentionally breaks `Pin`'s no-move guarantee so that
        // `AssertUnmoved` can detect the move below.
        let pinned = unsafe { Pin::new_unchecked(&mut on_stack) };
        assert_eq!(pinned.poll(&mut cx), Poll::Pending);

        // Move the future to the heap and poll again. This poll must panic because
        // the address changed. Dropping the box during unwinding must not panic a
        // second time.
        let mut on_heap = Box::new(on_stack);
        // SAFETY: as above, the pinning contract is deliberately violated.
        let pinned = unsafe { Pin::new_unchecked(&mut *on_heap) };
        assert_eq!(pinned.poll(&mut cx), Poll::Pending);
    }
}
213 | |