#![warn(rust_2018_idioms)]
#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads

use std::rc::Rc;
use std::sync::Arc;
use tokio::sync::Barrier;
use tokio_util::task;

/// Simple test of running a !Send future via spawn_pinned
#[tokio::test]
async fn can_spawn_not_send_future() {
    let pool = task::LocalPoolHandle::new(1);

    let output = pool
        .spawn_pinned(|| {
            // Rc is !Send + !Sync
            let local_data = Rc::new("test");

            // This future holds an Rc, so it is !Send
            async move { local_data.to_string() }
        })
        .await
        .unwrap();

    assert_eq!(output, "test");
}

/// Dropping the join handle still lets the task execute
#[test]
fn can_drop_future_and_still_get_output() {
    let pool = task::LocalPoolHandle::new(1);
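    // A blocking std channel is used because this test has no async runtime of
    // its own; the pool's worker thread drives the task to completion.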
    let (sender, receiver) = std::sync::mpsc::channel();

    let _ = pool.spawn_pinned(move || {
        // Rc is !Send + !Sync
        let local_data = Rc::new("test");

        // This future holds an Rc, so it is !Send
        async move {
            let _ = sender.send(local_data.to_string());
        }
    });

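    // The join handle was dropped immediately, yet the task still runs and
    // delivers its output through the channel.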
    assert_eq!(receiver.recv(), Ok("test".to_string()));
}

#[test]
#[should_panic(expected = "assertion failed: pool_size > 0")]
fn cannot_create_zero_sized_pool() {
    let _pool = task::LocalPoolHandle::new(0);
}

/// We should be able to spawn multiple futures onto the pool at the same time.
#[tokio::test]
async fn can_spawn_multiple_futures() {
    let pool = task::LocalPoolHandle::new(2);

    let join_handle1 = pool.spawn_pinned(|| {
        let local_data = Rc::new("test1");
        async move { local_data.to_string() }
    });
    let join_handle2 = pool.spawn_pinned(|| {
        let local_data = Rc::new("test2");
        async move { local_data.to_string() }
    });

    assert_eq!(join_handle1.await.unwrap(), "test1");
    assert_eq!(join_handle2.await.unwrap(), "test2");
}

/// A panic in the spawned task causes the join handle to return an error,
/// but the pool remains usable and you can continue to spawn tasks.
#[tokio::test]
async fn task_panic_propagates() {
    let pool = task::LocalPoolHandle::new(1);

    let join_handle = pool.spawn_pinned(|| async {
        panic!("Test panic");
    });

    let result = join_handle.await;
    assert!(result.is_err());
    let error = result.unwrap_err();
    assert!(error.is_panic());
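    // into_panic() returns the panic payload as Box<dyn Any + Send>; a panic
    // raised with a plain string literal downcasts to &'static str.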
    let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
    assert_eq!(*panic_str, "Test panic");

    // Trying again with a "safe" task still works
    let join_handle = pool.spawn_pinned(|| async { "test" });
    let result = join_handle.await;
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "test");
}

/// A panic during task creation (in the callback, before the future is even
/// created) causes the join handle to return an error, but the worker survives
/// and you can continue to spawn tasks.
#[tokio::test]
async fn callback_panic_does_not_kill_worker() {
    let pool = task::LocalPoolHandle::new(1);

    let join_handle = pool.spawn_pinned(|| {
        panic!("Test panic");
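        // The panic above means the async block below is never constructed;
        // it exists only to give the closure a concrete return type.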
        #[allow(unreachable_code)]
        async {}
    });

    let result = join_handle.await;
    assert!(result.is_err());
    let error = result.unwrap_err();
    assert!(error.is_panic());
    let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
    assert_eq!(*panic_str, "Test panic");

    // Trying again with a "safe" callback works
    let join_handle = pool.spawn_pinned(|| async { "test" });
    let result = join_handle.await;
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "test");
}

/// Canceling the task via the returned join handle cancels the spawned task
/// (which has a different, internal join handle).
#[tokio::test]
async fn task_cancellation_propagates() {
    let pool = task::LocalPoolHandle::new(1);
    let notify_dropped = Arc::new(());
    let weak_notify_dropped = Arc::downgrade(&notify_dropped);
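    // If the inner task is cancelled, its Arc clone is dropped and the Weak
    // can no longer be upgraded; the final assertion relies on this.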

    let (start_sender, start_receiver) = tokio::sync::oneshot::channel();
    let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>();
    let join_handle = pool.spawn_pinned(|| async move {
        let _drop_sender = drop_sender;
        // Move the Arc into the task
        let _notify_dropped = notify_dropped;
        let _ = start_sender.send(());

        // Keep the task running until it gets aborted
        futures::future::pending::<()>().await;
    });

    // Wait for the task to start
    let _ = start_receiver.await;

    join_handle.abort();

    // Wait for the inner task to abort, dropping the sender.
    // The top level join handle aborts quicker than the inner task (the abort
    // needs to propagate and get processed on the worker thread), so we can't
    // just await the top level join handle.
    let _ = drop_receiver.await;

    // Check that the Arc has been dropped. This verifies that the inner task
    // was canceled as well.
    assert!(weak_notify_dropped.upgrade().is_none());
}

/// Tasks should be given to the least burdened worker. When two tasks are
/// spawned on a pool with two idle workers, they should end up on separate
/// workers.
#[tokio::test]
async fn tasks_are_balanced() {
    let pool = task::LocalPoolHandle::new(2);

    // Spawn a task so one thread has a task count of 1
    let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel();
    let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel();
    let join_handle1 = pool.spawn_pinned(|| async move {
        let _ = start_sender1.send(());
        let _ = end_receiver1.await;
        std::thread::current().id()
    });

    // Wait for the first task to start up
    let _ = start_receiver1.await;

    // This task should be spawned on the other thread
    let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel();
    let join_handle2 = pool.spawn_pinned(|| async move {
        let _ = start_sender2.send(());
        std::thread::current().id()
    });

    // Wait for the second task to start up
    let _ = start_receiver2.await;

    // Allow the first task to end
    let _ = end_sender1.send(());

    let thread_id1 = join_handle1.await.unwrap();
    let thread_id2 = join_handle2.await.unwrap();

    // Since the first task was active when the second task spawned, they should
    // be on separate workers/threads.
    assert_ne!(thread_id1, thread_id2);
}

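/// spawn_pinned_by_idx pins each task to the worker with the given index,
/// bypassing load balancing: two tasks pinned to worker 0 and one to worker 1
/// should show up in the per-worker task loads and run on distinct threads.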
#[tokio::test]
async fn spawn_by_idx() {
    let pool = task::LocalPoolHandle::new(3);
    let barrier = Arc::new(Barrier::new(4));
    let barrier1 = barrier.clone();
    let barrier2 = barrier.clone();
    let barrier3 = barrier.clone();

    let handle1 = pool.spawn_pinned_by_idx(
        || async move {
            barrier1.wait().await;
            std::thread::current().id()
        },
        0,
    );
    let _ = pool.spawn_pinned_by_idx(
        || async move {
            barrier2.wait().await;
            std::thread::current().id()
        },
        0,
    );
    let handle2 = pool.spawn_pinned_by_idx(
        || async move {
            barrier3.wait().await;
            std::thread::current().id()
        },
        1,
    );

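    // All three tasks are still parked on the barrier (it needs a fourth
    // waiter), so the load snapshot below sees every task as in-flight.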
    let loads = pool.get_task_loads_for_each_worker();
    barrier.wait().await;
    assert_eq!(loads[0], 2);
    assert_eq!(loads[1], 1);
    assert_eq!(loads[2], 0);

    let thread_id1 = handle1.await.unwrap();
    let thread_id2 = handle2.await.unwrap();

    assert_ne!(thread_id1, thread_id2);
}