1//! Futures task based helpers to easily test futures and manually written futures.
2//!
//! The [`Spawn`] type is used as a mock task harness that allows you to poll futures
//! without needing to set up pinning or context. Any future can be polled, but if the
//! future requires the tokio async context, you will need to ensure that you poll the
//! [`Spawn`] within a tokio context; in other words, as long as you are inside the
//! runtime, polling it via [`Spawn`] will work.
8//!
9//! [`Spawn`] also supports [`Stream`] to call `poll_next` without pinning
10//! or context.
11//!
//! In addition to circumventing the need for pinning and context, [`Spawn`] also tracks
//! whether the future/task was woken since it was last polled. This can be useful to check
//! that some leaf future notified the root task correctly.
15//!
16//! # Example
17//!
18//! ```
19//! use tokio_test::task;
20//!
21//! let fut = async {};
22//!
23//! let mut task = task::spawn(fut);
24//!
25//! assert!(task.poll().is_ready(), "Task was not ready!");
26//! ```
27
28#![allow(clippy::mutex_atomic)]
29
30use std::future::Future;
31use std::mem;
32use std::ops;
33use std::pin::Pin;
34use std::sync::{Arc, Condvar, Mutex};
35use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
36
37use tokio_stream::Stream;
38
39/// Spawn a future into a [`Spawn`] which wraps the future in a mocked executor.
40///
41/// This can be used to spawn a [`Future`] or a [`Stream`].
42///
43/// For more information, check the module docs.
44pub fn spawn<T>(task: T) -> Spawn<T> {
45 Spawn {
46 task: MockTask::new(),
47 future: Box::pin(task),
48 }
49}
50
51/// Future spawned on a mock task that can be used to poll the future or stream
52/// without needing pinning or context types.
53#[derive(Debug)]
54pub struct Spawn<T> {
55 task: MockTask,
56 future: Pin<Box<T>>,
57}
58
59#[derive(Debug, Clone)]
60struct MockTask {
61 waker: Arc<ThreadWaker>,
62}
63
/// Wake-tracking state shared between a [`MockTask`] and the wakers it hands out.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Notified by `wake` when it finds the state set to `SLEEP`.
    condvar: Condvar,
}
69
/// No wake notification has been received since the state was last cleared.
const IDLE: usize = 0;
/// A wake notification has been received.
const WAKE: usize = 1;
/// The other party is waiting on the condvar; `wake` notifies it when it sees
/// this state. (Nothing visible in this chunk ever stores it.)
const SLEEP: usize = 2;
73
74impl<T> Spawn<T> {
75 /// Consumes `self` returning the inner value
76 pub fn into_inner(self) -> T
77 where
78 T: Unpin,
79 {
80 *Pin::into_inner(self.future)
81 }
82
83 /// Returns `true` if the inner future has received a wake notification
84 /// since the last call to `enter`.
85 pub fn is_woken(&self) -> bool {
86 self.task.is_woken()
87 }
88
89 /// Returns the number of references to the task waker
90 ///
91 /// The task itself holds a reference. The return value will never be zero.
92 pub fn waker_ref_count(&self) -> usize {
93 self.task.waker_ref_count()
94 }
95
96 /// Enter the task context
97 pub fn enter<F, R>(&mut self, f: F) -> R
98 where
99 F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
100 {
101 let fut = self.future.as_mut();
102 self.task.enter(|cx| f(cx, fut))
103 }
104}
105
106impl<T: Unpin> ops::Deref for Spawn<T> {
107 type Target = T;
108
109 fn deref(&self) -> &T {
110 &self.future
111 }
112}
113
114impl<T: Unpin> ops::DerefMut for Spawn<T> {
115 fn deref_mut(&mut self) -> &mut T {
116 &mut self.future
117 }
118}
119
120impl<T: Future> Spawn<T> {
121 /// If `T` is a [`Future`] then poll it. This will handle pinning and the context
122 /// type for the future.
123 pub fn poll(&mut self) -> Poll<T::Output> {
124 let fut = self.future.as_mut();
125 self.task.enter(|cx| fut.poll(cx))
126 }
127}
128
129impl<T: Stream> Spawn<T> {
130 /// If `T` is a [`Stream`] then `poll_next` it. This will handle pinning and the context
131 /// type for the stream.
132 pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
133 let stream = self.future.as_mut();
134 self.task.enter(|cx| stream.poll_next(cx))
135 }
136}
137
138impl<T: Future> Future for Spawn<T> {
139 type Output = T::Output;
140
141 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142 self.future.as_mut().poll(cx)
143 }
144}
145
146impl<T: Stream> Stream for Spawn<T> {
147 type Item = T::Item;
148
149 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150 self.future.as_mut().poll_next(cx)
151 }
152}
153
154impl MockTask {
155 /// Creates new mock task
156 fn new() -> Self {
157 MockTask {
158 waker: Arc::new(ThreadWaker::new()),
159 }
160 }
161
162 /// Runs a closure from the context of the task.
163 ///
164 /// Any wake notifications resulting from the execution of the closure are
165 /// tracked.
166 fn enter<F, R>(&mut self, f: F) -> R
167 where
168 F: FnOnce(&mut Context<'_>) -> R,
169 {
170 self.waker.clear();
171 let waker = self.waker();
172 let mut cx = Context::from_waker(&waker);
173
174 f(&mut cx)
175 }
176
177 /// Returns `true` if the inner future has received a wake notification
178 /// since the last call to `enter`.
179 fn is_woken(&self) -> bool {
180 self.waker.is_woken()
181 }
182
183 /// Returns the number of references to the task waker
184 ///
185 /// The task itself holds a reference. The return value will never be zero.
186 fn waker_ref_count(&self) -> usize {
187 Arc::strong_count(&self.waker)
188 }
189
190 fn waker(&self) -> Waker {
191 unsafe {
192 let raw = to_raw(self.waker.clone());
193 Waker::from_raw(raw)
194 }
195 }
196}
197
198impl Default for MockTask {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl ThreadWaker {
205 fn new() -> Self {
206 ThreadWaker {
207 state: Mutex::new(IDLE),
208 condvar: Condvar::new(),
209 }
210 }
211
212 /// Clears any previously received wakes, avoiding potential spurious
213 /// wake notifications. This should only be called immediately before running the
214 /// task.
215 fn clear(&self) {
216 *self.state.lock().unwrap() = IDLE;
217 }
218
219 fn is_woken(&self) -> bool {
220 match *self.state.lock().unwrap() {
221 IDLE => false,
222 WAKE => true,
223 _ => unreachable!(),
224 }
225 }
226
227 fn wake(&self) {
228 // First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
229 let mut state = self.state.lock().unwrap();
230 let prev = *state;
231
232 if prev == WAKE {
233 return;
234 }
235
236 *state = WAKE;
237
238 if prev == IDLE {
239 return;
240 }
241
242 // The other half is sleeping, so we wake it up.
243 assert_eq!(prev, SLEEP);
244 self.condvar.notify_one();
245 }
246}
247
248static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
249
250unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
251 RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
252}
253
254unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
255 Arc::from_raw(raw as *const ThreadWaker)
256}
257
258unsafe fn clone(raw: *const ()) -> RawWaker {
259 let waker = from_raw(raw);
260
261 // Increment the ref count
262 mem::forget(waker.clone());
263
264 to_raw(waker)
265}
266
267unsafe fn wake(raw: *const ()) {
268 let waker = from_raw(raw);
269 waker.wake();
270}
271
272unsafe fn wake_by_ref(raw: *const ()) {
273 let waker = from_raw(raw);
274 waker.wake();
275
276 // We don't actually own a reference to the unparker
277 mem::forget(waker);
278}
279
280unsafe fn drop_waker(raw: *const ()) {
281 let _ = from_raw(raw);
282}
283