1 | //! Futures task based helpers to easily test futures and manually written futures. |
2 | //! |
//! The [`Spawn`] type is a mock task harness that lets you poll futures
//! without having to set up pinning or a context. Any future can be polled, but if
//! the future requires the Tokio async context, you must poll the [`Spawn`] from
//! within a Tokio context: as long as you are inside the runtime, polling through
//! [`Spawn`] works.
8 | //! |
9 | //! [`Spawn`] also supports [`Stream`] to call `poll_next` without pinning |
10 | //! or context. |
11 | //! |
//! In addition to removing the need for pinning and a context, [`Spawn`] also tracks
//! whether the future/task has been woken since it was last polled through
//! [`Spawn`]'s own `poll`, `poll_next`, or `enter` methods. This can be useful to
//! check that some leaf future notified the root task correctly.
15 | //! |
16 | //! # Example |
17 | //! |
18 | //! ``` |
19 | //! use tokio_test::task; |
20 | //! |
21 | //! let fut = async {}; |
22 | //! |
23 | //! let mut task = task::spawn(fut); |
24 | //! |
25 | //! assert!(task.poll().is_ready(), "Task was not ready!" ); |
26 | //! ``` |
27 | |
28 | #![allow (clippy::mutex_atomic)] |
29 | |
30 | use std::future::Future; |
31 | use std::mem; |
32 | use std::ops; |
33 | use std::pin::Pin; |
34 | use std::sync::{Arc, Condvar, Mutex}; |
35 | use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; |
36 | |
37 | use tokio_stream::Stream; |
38 | |
39 | /// Spawn a future into a [`Spawn`] which wraps the future in a mocked executor. |
40 | /// |
41 | /// This can be used to spawn a [`Future`] or a [`Stream`]. |
42 | /// |
43 | /// For more information, check the module docs. |
44 | pub fn spawn<T>(task: T) -> Spawn<T> { |
45 | Spawn { |
46 | task: MockTask::new(), |
47 | future: Box::pin(task), |
48 | } |
49 | } |
50 | |
51 | /// Future spawned on a mock task that can be used to poll the future or stream |
52 | /// without needing pinning or context types. |
53 | #[derive(Debug)] |
54 | pub struct Spawn<T> { |
55 | task: MockTask, |
56 | future: Pin<Box<T>>, |
57 | } |
58 | |
59 | #[derive(Debug, Clone)] |
60 | struct MockTask { |
61 | waker: Arc<ThreadWaker>, |
62 | } |
63 | |
/// Wake bookkeeping shared between a mock task and the wakers it hands out.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Notified by `wake` when the previous state was `SLEEP`.
    /// NOTE(review): nothing in this file waits on it; confirm it is still needed.
    condvar: Condvar,
}
69 | |
/// No wake notification is pending.
const IDLE: usize = 0;
/// A wake notification arrived since the state was last cleared.
const WAKE: usize = 1;
/// Marks a party blocked on the condvar that `wake` must notify
/// (per `ThreadWaker::wake`); no code in this file stores it.
const SLEEP: usize = 2;
73 | |
74 | impl<T> Spawn<T> { |
75 | /// Consumes `self` returning the inner value |
76 | pub fn into_inner(self) -> T |
77 | where |
78 | T: Unpin, |
79 | { |
80 | *Pin::into_inner(self.future) |
81 | } |
82 | |
83 | /// Returns `true` if the inner future has received a wake notification |
84 | /// since the last call to `enter`. |
85 | pub fn is_woken(&self) -> bool { |
86 | self.task.is_woken() |
87 | } |
88 | |
89 | /// Returns the number of references to the task waker |
90 | /// |
91 | /// The task itself holds a reference. The return value will never be zero. |
92 | pub fn waker_ref_count(&self) -> usize { |
93 | self.task.waker_ref_count() |
94 | } |
95 | |
96 | /// Enter the task context |
97 | pub fn enter<F, R>(&mut self, f: F) -> R |
98 | where |
99 | F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R, |
100 | { |
101 | let fut = self.future.as_mut(); |
102 | self.task.enter(|cx| f(cx, fut)) |
103 | } |
104 | } |
105 | |
106 | impl<T: Unpin> ops::Deref for Spawn<T> { |
107 | type Target = T; |
108 | |
109 | fn deref(&self) -> &T { |
110 | &self.future |
111 | } |
112 | } |
113 | |
114 | impl<T: Unpin> ops::DerefMut for Spawn<T> { |
115 | fn deref_mut(&mut self) -> &mut T { |
116 | &mut self.future |
117 | } |
118 | } |
119 | |
120 | impl<T: Future> Spawn<T> { |
121 | /// If `T` is a [`Future`] then poll it. This will handle pinning and the context |
122 | /// type for the future. |
123 | pub fn poll(&mut self) -> Poll<T::Output> { |
124 | let fut = self.future.as_mut(); |
125 | self.task.enter(|cx| fut.poll(cx)) |
126 | } |
127 | } |
128 | |
129 | impl<T: Stream> Spawn<T> { |
130 | /// If `T` is a [`Stream`] then `poll_next` it. This will handle pinning and the context |
131 | /// type for the stream. |
132 | pub fn poll_next(&mut self) -> Poll<Option<T::Item>> { |
133 | let stream = self.future.as_mut(); |
134 | self.task.enter(|cx| stream.poll_next(cx)) |
135 | } |
136 | } |
137 | |
138 | impl<T: Future> Future for Spawn<T> { |
139 | type Output = T::Output; |
140 | |
141 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
142 | self.future.as_mut().poll(cx) |
143 | } |
144 | } |
145 | |
146 | impl<T: Stream> Stream for Spawn<T> { |
147 | type Item = T::Item; |
148 | |
149 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
150 | self.future.as_mut().poll_next(cx) |
151 | } |
152 | } |
153 | |
154 | impl MockTask { |
155 | /// Creates new mock task |
156 | fn new() -> Self { |
157 | MockTask { |
158 | waker: Arc::new(ThreadWaker::new()), |
159 | } |
160 | } |
161 | |
162 | /// Runs a closure from the context of the task. |
163 | /// |
164 | /// Any wake notifications resulting from the execution of the closure are |
165 | /// tracked. |
166 | fn enter<F, R>(&mut self, f: F) -> R |
167 | where |
168 | F: FnOnce(&mut Context<'_>) -> R, |
169 | { |
170 | self.waker.clear(); |
171 | let waker = self.waker(); |
172 | let mut cx = Context::from_waker(&waker); |
173 | |
174 | f(&mut cx) |
175 | } |
176 | |
177 | /// Returns `true` if the inner future has received a wake notification |
178 | /// since the last call to `enter`. |
179 | fn is_woken(&self) -> bool { |
180 | self.waker.is_woken() |
181 | } |
182 | |
183 | /// Returns the number of references to the task waker |
184 | /// |
185 | /// The task itself holds a reference. The return value will never be zero. |
186 | fn waker_ref_count(&self) -> usize { |
187 | Arc::strong_count(&self.waker) |
188 | } |
189 | |
190 | fn waker(&self) -> Waker { |
191 | unsafe { |
192 | let raw = to_raw(self.waker.clone()); |
193 | Waker::from_raw(raw) |
194 | } |
195 | } |
196 | } |
197 | |
198 | impl Default for MockTask { |
199 | fn default() -> Self { |
200 | Self::new() |
201 | } |
202 | } |
203 | |
204 | impl ThreadWaker { |
205 | fn new() -> Self { |
206 | ThreadWaker { |
207 | state: Mutex::new(IDLE), |
208 | condvar: Condvar::new(), |
209 | } |
210 | } |
211 | |
212 | /// Clears any previously received wakes, avoiding potential spurious |
213 | /// wake notifications. This should only be called immediately before running the |
214 | /// task. |
215 | fn clear(&self) { |
216 | *self.state.lock().unwrap() = IDLE; |
217 | } |
218 | |
219 | fn is_woken(&self) -> bool { |
220 | match *self.state.lock().unwrap() { |
221 | IDLE => false, |
222 | WAKE => true, |
223 | _ => unreachable!(), |
224 | } |
225 | } |
226 | |
227 | fn wake(&self) { |
228 | // First, try transitioning from IDLE -> NOTIFY, this does not require a lock. |
229 | let mut state = self.state.lock().unwrap(); |
230 | let prev = *state; |
231 | |
232 | if prev == WAKE { |
233 | return; |
234 | } |
235 | |
236 | *state = WAKE; |
237 | |
238 | if prev == IDLE { |
239 | return; |
240 | } |
241 | |
242 | // The other half is sleeping, so we wake it up. |
243 | assert_eq!(prev, SLEEP); |
244 | self.condvar.notify_one(); |
245 | } |
246 | } |
247 | |
248 | static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker); |
249 | |
250 | unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker { |
251 | RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE) |
252 | } |
253 | |
254 | unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> { |
255 | Arc::from_raw(raw as *const ThreadWaker) |
256 | } |
257 | |
258 | unsafe fn clone(raw: *const ()) -> RawWaker { |
259 | let waker = from_raw(raw); |
260 | |
261 | // Increment the ref count |
262 | mem::forget(waker.clone()); |
263 | |
264 | to_raw(waker) |
265 | } |
266 | |
267 | unsafe fn wake(raw: *const ()) { |
268 | let waker = from_raw(raw); |
269 | waker.wake(); |
270 | } |
271 | |
272 | unsafe fn wake_by_ref(raw: *const ()) { |
273 | let waker = from_raw(raw); |
274 | waker.wake(); |
275 | |
276 | // We don't actually own a reference to the unparker |
277 | mem::forget(waker); |
278 | } |
279 | |
280 | unsafe fn drop_waker(raw: *const ()) { |
281 | let _ = from_raw(raw); |
282 | } |
283 | |