| 1 | use futures_core::future::{FusedFuture, Future}; |
| 2 | use futures_core::stream::{FusedStream, Stream}; |
| 3 | use futures_core::task::{Context, Poll}; |
| 4 | use futures_io::{ |
| 5 | self as io, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom, |
| 6 | }; |
| 7 | use futures_sink::Sink; |
| 8 | use pin_project::{pin_project , pinned_drop }; |
| 9 | use std::pin::Pin; |
| 10 | use std::thread::panicking; |
| 11 | |
| 12 | /// Combinator that asserts that the underlying type is not moved after being polled. |
| 13 | /// |
| 14 | /// See the `assert_unmoved` methods on: |
| 15 | /// * [`FutureTestExt`](crate::future::FutureTestExt::assert_unmoved) |
| 16 | /// * [`StreamTestExt`](crate::stream::StreamTestExt::assert_unmoved) |
| 17 | /// * [`SinkTestExt`](crate::sink::SinkTestExt::assert_unmoved_sink) |
| 18 | /// * [`AsyncReadTestExt`](crate::io::AsyncReadTestExt::assert_unmoved) |
| 19 | /// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::assert_unmoved_write) |
| 20 | #[pin_project (PinnedDrop, !Unpin)] |
| 21 | #[derive(Debug, Clone)] |
| 22 | #[must_use = "futures do nothing unless you `.await` or poll them" ] |
| 23 | pub struct AssertUnmoved<T> { |
| 24 | #[pin] |
| 25 | inner: T, |
| 26 | this_addr: usize, |
| 27 | } |
| 28 | |
| 29 | impl<T> AssertUnmoved<T> { |
| 30 | pub(crate) fn new(inner: T) -> Self { |
| 31 | Self { inner, this_addr: 0 } |
| 32 | } |
| 33 | |
| 34 | fn poll_with<'a, U>(mut self: Pin<&'a mut Self>, f: impl FnOnce(Pin<&'a mut T>) -> U) -> U { |
| 35 | let cur_this = &*self as *const Self as usize; |
| 36 | if self.this_addr == 0 { |
| 37 | // First time being polled |
| 38 | *self.as_mut().project().this_addr = cur_this; |
| 39 | } else { |
| 40 | assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved between poll calls" ); |
| 41 | } |
| 42 | f(self.project().inner) |
| 43 | } |
| 44 | } |
| 45 | |
| 46 | impl<Fut: Future> Future for AssertUnmoved<Fut> { |
| 47 | type Output = Fut::Output; |
| 48 | |
| 49 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 50 | self.poll_with(|f| f.poll(cx)) |
| 51 | } |
| 52 | } |
| 53 | |
| 54 | impl<Fut: FusedFuture> FusedFuture for AssertUnmoved<Fut> { |
| 55 | fn is_terminated(&self) -> bool { |
| 56 | self.inner.is_terminated() |
| 57 | } |
| 58 | } |
| 59 | |
| 60 | impl<St: Stream> Stream for AssertUnmoved<St> { |
| 61 | type Item = St::Item; |
| 62 | |
| 63 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 64 | self.poll_with(|s| s.poll_next(cx)) |
| 65 | } |
| 66 | } |
| 67 | |
| 68 | impl<St: FusedStream> FusedStream for AssertUnmoved<St> { |
| 69 | fn is_terminated(&self) -> bool { |
| 70 | self.inner.is_terminated() |
| 71 | } |
| 72 | } |
| 73 | |
| 74 | impl<Si: Sink<Item>, Item> Sink<Item> for AssertUnmoved<Si> { |
| 75 | type Error = Si::Error; |
| 76 | |
| 77 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 78 | self.poll_with(|s| s.poll_ready(cx)) |
| 79 | } |
| 80 | |
| 81 | fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { |
| 82 | self.poll_with(|s| s.start_send(item)) |
| 83 | } |
| 84 | |
| 85 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 86 | self.poll_with(|s| s.poll_flush(cx)) |
| 87 | } |
| 88 | |
| 89 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 90 | self.poll_with(|s| s.poll_close(cx)) |
| 91 | } |
| 92 | } |
| 93 | |
| 94 | impl<R: AsyncRead> AsyncRead for AssertUnmoved<R> { |
| 95 | fn poll_read( |
| 96 | self: Pin<&mut Self>, |
| 97 | cx: &mut Context<'_>, |
| 98 | buf: &mut [u8], |
| 99 | ) -> Poll<io::Result<usize>> { |
| 100 | self.poll_with(|r| r.poll_read(cx, buf)) |
| 101 | } |
| 102 | |
| 103 | fn poll_read_vectored( |
| 104 | self: Pin<&mut Self>, |
| 105 | cx: &mut Context<'_>, |
| 106 | bufs: &mut [IoSliceMut<'_>], |
| 107 | ) -> Poll<io::Result<usize>> { |
| 108 | self.poll_with(|r| r.poll_read_vectored(cx, bufs)) |
| 109 | } |
| 110 | } |
| 111 | |
| 112 | impl<W: AsyncWrite> AsyncWrite for AssertUnmoved<W> { |
| 113 | fn poll_write( |
| 114 | self: Pin<&mut Self>, |
| 115 | cx: &mut Context<'_>, |
| 116 | buf: &[u8], |
| 117 | ) -> Poll<io::Result<usize>> { |
| 118 | self.poll_with(|w| w.poll_write(cx, buf)) |
| 119 | } |
| 120 | |
| 121 | fn poll_write_vectored( |
| 122 | self: Pin<&mut Self>, |
| 123 | cx: &mut Context<'_>, |
| 124 | bufs: &[IoSlice<'_>], |
| 125 | ) -> Poll<io::Result<usize>> { |
| 126 | self.poll_with(|w| w.poll_write_vectored(cx, bufs)) |
| 127 | } |
| 128 | |
| 129 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 130 | self.poll_with(|w| w.poll_flush(cx)) |
| 131 | } |
| 132 | |
| 133 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 134 | self.poll_with(|w| w.poll_close(cx)) |
| 135 | } |
| 136 | } |
| 137 | |
| 138 | impl<S: AsyncSeek> AsyncSeek for AssertUnmoved<S> { |
| 139 | fn poll_seek( |
| 140 | self: Pin<&mut Self>, |
| 141 | cx: &mut Context<'_>, |
| 142 | pos: SeekFrom, |
| 143 | ) -> Poll<io::Result<u64>> { |
| 144 | self.poll_with(|s| s.poll_seek(cx, pos)) |
| 145 | } |
| 146 | } |
| 147 | |
| 148 | impl<R: AsyncBufRead> AsyncBufRead for AssertUnmoved<R> { |
| 149 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
| 150 | self.poll_with(|r| r.poll_fill_buf(cx)) |
| 151 | } |
| 152 | |
| 153 | fn consume(self: Pin<&mut Self>, amt: usize) { |
| 154 | self.poll_with(|r| r.consume(amt)) |
| 155 | } |
| 156 | } |
| 157 | |
| 158 | #[pinned_drop ] |
| 159 | impl<T> PinnedDrop for AssertUnmoved<T> { |
| 160 | fn drop(self: Pin<&mut Self>) { |
| 161 | // If the thread is panicking then we can't panic again as that will |
| 162 | // cause the process to be aborted. |
| 163 | if !panicking() && self.this_addr != 0 { |
| 164 | let cur_this = &*self as *const Self as usize; |
| 165 | assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved before drop" ); |
| 166 | } |
| 167 | } |
| 168 | } |
| 169 | |
#[cfg(test)]
mod tests {
    use futures_core::future::Future;
    use futures_core::task::{Context, Poll};
    use futures_util::future::pending;
    use futures_util::task::noop_waker;
    use std::pin::Pin;

    use super::AssertUnmoved;

    #[test]
    fn assert_send_sync() {
        fn assert<T: Send + Sync>() {}
        assert::<AssertUnmoved<()>>();
    }

    #[test]
    fn dont_panic_when_not_polled() {
        // This shouldn't panic.
        let future = AssertUnmoved::new(pending::<()>());
        drop(future);
    }

    #[test]
    fn dont_panic_when_not_moved() {
        // Repeated polls and the final drop all happen at the same heap address,
        // so none of the checks should fire.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let mut future = Box::pin(AssertUnmoved::new(pending::<()>()));
        assert_eq!(future.as_mut().poll(&mut cx), Poll::Pending);
        assert_eq!(future.as_mut().poll(&mut cx), Poll::Pending);
        drop(future);
    }

    #[test]
    #[should_panic(expected = "AssertUnmoved moved between poll calls")]
    fn dont_double_panic() {
        // This test should only panic, not abort the process.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // First we allocate the future on the stack and poll it.
        let mut future = AssertUnmoved::new(pending::<()>());
        // Deliberately unsound: `future` is moved below despite having been
        // pinned. Detecting exactly that is what this test exercises.
        let pinned_future = unsafe { Pin::new_unchecked(&mut future) };
        assert_eq!(pinned_future.poll(&mut cx), Poll::Pending);

        // Next we move it to the heap and poll it again. This second call should
        // panic because the future has moved. We must not panic again while
        // dropping `AssertUnmoved` during unwinding.
        let mut future = Box::new(future);
        let pinned_boxed_future = unsafe { Pin::new_unchecked(&mut *future) };
        assert_eq!(pinned_boxed_future.poll(&mut cx), Poll::Pending);
    }

    #[test]
    #[should_panic(expected = "AssertUnmoved moved before drop")]
    fn panic_when_moved_before_drop() {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // Poll once on the stack so that the address is recorded.
        let mut future = AssertUnmoved::new(pending::<()>());
        // Deliberately unsound: `future` is moved below despite having been
        // pinned. The drop check must catch it.
        let pinned_future = unsafe { Pin::new_unchecked(&mut future) };
        assert_eq!(pinned_future.poll(&mut cx), Poll::Pending);

        // Move to the heap, which guarantees a different address, and drop it
        // without polling again.
        let future = Box::new(future);
        drop(future);
    }
}
| 213 | |