1 | #![warn (rust_2018_idioms)] |
2 | #![cfg (not(target_os = "wasi" ))] // Wasi doesn't support threads |
3 | |
4 | use std::rc::Rc; |
5 | use std::sync::Arc; |
6 | use tokio::sync::Barrier; |
7 | use tokio_util::task; |
8 | |
9 | /// Simple test of running a !Send future via spawn_pinned |
10 | #[tokio::test ] |
11 | async fn can_spawn_not_send_future() { |
12 | let pool = task::LocalPoolHandle::new(1); |
13 | |
14 | let output = pool |
15 | .spawn_pinned(|| { |
16 | // Rc is !Send + !Sync |
17 | let local_data = Rc::new("test" ); |
18 | |
19 | // This future holds an Rc, so it is !Send |
20 | async move { local_data.to_string() } |
21 | }) |
22 | .await |
23 | .unwrap(); |
24 | |
25 | assert_eq!(output, "test" ); |
26 | } |
27 | |
/// Dropping the join handle still lets the task execute. The task's output is
/// observed through a std channel instead of the discarded handle.
#[test]
fn can_drop_future_and_still_get_output() {
    let pool = task::LocalPoolHandle::new(1);
    let (tx, rx) = std::sync::mpsc::channel();

    // The join handle is dropped immediately; only the channel reports back.
    drop(pool.spawn_pinned(move || {
        // `Rc` is `!Send` + `!Sync`, so the future below is `!Send`.
        let shared = Rc::new("test");
        async move {
            // A failed send is deliberately ignored.
            let _ = tx.send(shared.to_string());
        }
    }));

    let received = rx.recv();
    assert_eq!(received, Ok("test".to_string()));
}
46 | |
/// A pool needs at least one worker: requesting zero workers panics with the
/// pool's size assertion.
#[test]
#[should_panic(expected = "assertion failed: pool_size > 0")]
fn cannot_create_zero_sized_pool() {
    let _unused = task::LocalPoolHandle::new(0);
}
52 | |
53 | /// We should be able to spawn multiple futures onto the pool at the same time. |
54 | #[tokio::test ] |
55 | async fn can_spawn_multiple_futures() { |
56 | let pool = task::LocalPoolHandle::new(2); |
57 | |
58 | let join_handle1 = pool.spawn_pinned(|| { |
59 | let local_data = Rc::new("test1" ); |
60 | async move { local_data.to_string() } |
61 | }); |
62 | let join_handle2 = pool.spawn_pinned(|| { |
63 | let local_data = Rc::new("test2" ); |
64 | async move { local_data.to_string() } |
65 | }); |
66 | |
67 | assert_eq!(join_handle1.await.unwrap(), "test1" ); |
68 | assert_eq!(join_handle2.await.unwrap(), "test2" ); |
69 | } |
70 | |
71 | /// A panic in the spawned task causes the join handle to return an error. |
72 | /// But, you can continue to spawn tasks. |
73 | #[tokio::test ] |
74 | async fn task_panic_propagates() { |
75 | let pool = task::LocalPoolHandle::new(1); |
76 | |
77 | let join_handle = pool.spawn_pinned(|| async { |
78 | panic!("Test panic" ); |
79 | }); |
80 | |
81 | let result = join_handle.await; |
82 | assert!(result.is_err()); |
83 | let error = result.unwrap_err(); |
84 | assert!(error.is_panic()); |
85 | let panic_str = error.into_panic().downcast::<&'static str>().unwrap(); |
86 | assert_eq!(*panic_str, "Test panic" ); |
87 | |
88 | // Trying again with a "safe" task still works |
89 | let join_handle = pool.spawn_pinned(|| async { "test" }); |
90 | let result = join_handle.await; |
91 | assert!(result.is_ok()); |
92 | assert_eq!(result.unwrap(), "test" ); |
93 | } |
94 | |
95 | /// A panic during task creation causes the join handle to return an error. |
96 | /// But, you can continue to spawn tasks. |
97 | #[tokio::test ] |
98 | async fn callback_panic_does_not_kill_worker() { |
99 | let pool = task::LocalPoolHandle::new(1); |
100 | |
101 | let join_handle = pool.spawn_pinned(|| { |
102 | panic!("Test panic" ); |
103 | #[allow (unreachable_code)] |
104 | async {} |
105 | }); |
106 | |
107 | let result = join_handle.await; |
108 | assert!(result.is_err()); |
109 | let error = result.unwrap_err(); |
110 | assert!(error.is_panic()); |
111 | let panic_str = error.into_panic().downcast::<&'static str>().unwrap(); |
112 | assert_eq!(*panic_str, "Test panic" ); |
113 | |
114 | // Trying again with a "safe" callback works |
115 | let join_handle = pool.spawn_pinned(|| async { "test" }); |
116 | let result = join_handle.await; |
117 | assert!(result.is_ok()); |
118 | assert_eq!(result.unwrap(), "test" ); |
119 | } |
120 | |
121 | /// Canceling the task via the returned join handle cancels the spawned task |
122 | /// (which has a different, internal join handle). |
123 | #[tokio::test ] |
124 | async fn task_cancellation_propagates() { |
125 | let pool = task::LocalPoolHandle::new(1); |
126 | let notify_dropped = Arc::new(()); |
127 | let weak_notify_dropped = Arc::downgrade(¬ify_dropped); |
128 | |
129 | let (start_sender, start_receiver) = tokio::sync::oneshot::channel(); |
130 | let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>(); |
131 | let join_handle = pool.spawn_pinned(|| async move { |
132 | let _drop_sender = drop_sender; |
133 | // Move the Arc into the task |
134 | let _notify_dropped = notify_dropped; |
135 | let _ = start_sender.send(()); |
136 | |
137 | // Keep the task running until it gets aborted |
138 | futures::future::pending::<()>().await; |
139 | }); |
140 | |
141 | // Wait for the task to start |
142 | let _ = start_receiver.await; |
143 | |
144 | join_handle.abort(); |
145 | |
146 | // Wait for the inner task to abort, dropping the sender. |
147 | // The top level join handle aborts quicker than the inner task (the abort |
148 | // needs to propagate and get processed on the worker thread), so we can't |
149 | // just await the top level join handle. |
150 | let _ = drop_receiver.await; |
151 | |
152 | // Check that the Arc has been dropped. This verifies that the inner task |
153 | // was canceled as well. |
154 | assert!(weak_notify_dropped.upgrade().is_none()); |
155 | } |
156 | |
157 | /// Tasks should be given to the least burdened worker. When spawning two tasks |
158 | /// on a pool with two empty workers the tasks should be spawned on separate |
159 | /// workers. |
160 | #[tokio::test ] |
161 | async fn tasks_are_balanced() { |
162 | let pool = task::LocalPoolHandle::new(2); |
163 | |
164 | // Spawn a task so one thread has a task count of 1 |
165 | let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel(); |
166 | let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel(); |
167 | let join_handle1 = pool.spawn_pinned(|| async move { |
168 | let _ = start_sender1.send(()); |
169 | let _ = end_receiver1.await; |
170 | std::thread::current().id() |
171 | }); |
172 | |
173 | // Wait for the first task to start up |
174 | let _ = start_receiver1.await; |
175 | |
176 | // This task should be spawned on the other thread |
177 | let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel(); |
178 | let join_handle2 = pool.spawn_pinned(|| async move { |
179 | let _ = start_sender2.send(()); |
180 | std::thread::current().id() |
181 | }); |
182 | |
183 | // Wait for the second task to start up |
184 | let _ = start_receiver2.await; |
185 | |
186 | // Allow the first task to end |
187 | let _ = end_sender1.send(()); |
188 | |
189 | let thread_id1 = join_handle1.await.unwrap(); |
190 | let thread_id2 = join_handle2.await.unwrap(); |
191 | |
192 | // Since the first task was active when the second task spawned, they should |
193 | // be on separate workers/threads. |
194 | assert_ne!(thread_id1, thread_id2); |
195 | } |
196 | |
197 | #[tokio::test ] |
198 | async fn spawn_by_idx() { |
199 | let pool = task::LocalPoolHandle::new(3); |
200 | let barrier = Arc::new(Barrier::new(4)); |
201 | let barrier1 = barrier.clone(); |
202 | let barrier2 = barrier.clone(); |
203 | let barrier3 = barrier.clone(); |
204 | |
205 | let handle1 = pool.spawn_pinned_by_idx( |
206 | || async move { |
207 | barrier1.wait().await; |
208 | std::thread::current().id() |
209 | }, |
210 | 0, |
211 | ); |
212 | let _ = pool.spawn_pinned_by_idx( |
213 | || async move { |
214 | barrier2.wait().await; |
215 | std::thread::current().id() |
216 | }, |
217 | 0, |
218 | ); |
219 | let handle2 = pool.spawn_pinned_by_idx( |
220 | || async move { |
221 | barrier3.wait().await; |
222 | std::thread::current().id() |
223 | }, |
224 | 1, |
225 | ); |
226 | |
227 | let loads = pool.get_task_loads_for_each_worker(); |
228 | barrier.wait().await; |
229 | assert_eq!(loads[0], 2); |
230 | assert_eq!(loads[1], 1); |
231 | assert_eq!(loads[2], 0); |
232 | |
233 | let thread_id1 = handle1.await.unwrap(); |
234 | let thread_id2 = handle2.await.unwrap(); |
235 | |
236 | assert_ne!(thread_id1, thread_id2); |
237 | } |
238 | |