| 1 | #![warn (rust_2018_idioms)] |
| 2 | #![cfg (all(feature = "rt" , tokio_unstable))] |
| 3 | |
| 4 | use tokio::sync::oneshot; |
| 5 | use tokio::time::Duration; |
| 6 | use tokio_util::task::JoinMap; |
| 7 | |
| 8 | use futures::future::FutureExt; |
| 9 | |
| 10 | fn rt() -> tokio::runtime::Runtime { |
| 11 | tokio::runtime::Builder::new_current_thread() |
| 12 | .build() |
| 13 | .unwrap() |
| 14 | } |
| 15 | |
| 16 | #[tokio::test (start_paused = true)] |
| 17 | async fn test_with_sleep() { |
| 18 | let mut map = JoinMap::new(); |
| 19 | |
| 20 | for i in 0..10 { |
| 21 | map.spawn(i, async move { i }); |
| 22 | assert_eq!(map.len(), 1 + i); |
| 23 | } |
| 24 | map.detach_all(); |
| 25 | assert_eq!(map.len(), 0); |
| 26 | |
| 27 | assert!(matches!(map.join_next().await, None)); |
| 28 | |
| 29 | for i in 0..10 { |
| 30 | map.spawn(i, async move { |
| 31 | tokio::time::sleep(Duration::from_secs(i as u64)).await; |
| 32 | i |
| 33 | }); |
| 34 | assert_eq!(map.len(), 1 + i); |
| 35 | } |
| 36 | |
| 37 | let mut seen = [false; 10]; |
| 38 | while let Some((k, res)) = map.join_next().await { |
| 39 | seen[k] = true; |
| 40 | assert_eq!(res.expect("task should have completed successfully" ), k); |
| 41 | } |
| 42 | |
| 43 | for was_seen in &seen { |
| 44 | assert!(was_seen); |
| 45 | } |
| 46 | assert!(matches!(map.join_next().await, None)); |
| 47 | |
| 48 | // Do it again. |
| 49 | for i in 0..10 { |
| 50 | map.spawn(i, async move { |
| 51 | tokio::time::sleep(Duration::from_secs(i as u64)).await; |
| 52 | i |
| 53 | }); |
| 54 | } |
| 55 | |
| 56 | let mut seen = [false; 10]; |
| 57 | while let Some((k, res)) = map.join_next().await { |
| 58 | seen[k] = true; |
| 59 | assert_eq!(res.expect("task should have completed successfully" ), k); |
| 60 | } |
| 61 | |
| 62 | for was_seen in &seen { |
| 63 | assert!(was_seen); |
| 64 | } |
| 65 | assert!(matches!(map.join_next().await, None)); |
| 66 | } |
| 67 | |
| 68 | #[tokio::test ] |
| 69 | async fn test_abort_on_drop() { |
| 70 | let mut map = JoinMap::new(); |
| 71 | |
| 72 | let mut recvs = Vec::new(); |
| 73 | |
| 74 | for i in 0..16 { |
| 75 | let (send, recv) = oneshot::channel::<()>(); |
| 76 | recvs.push(recv); |
| 77 | |
| 78 | map.spawn(i, async { |
| 79 | // This task will never complete on its own. |
| 80 | futures::future::pending::<()>().await; |
| 81 | drop(send); |
| 82 | }); |
| 83 | } |
| 84 | |
| 85 | drop(map); |
| 86 | |
| 87 | for recv in recvs { |
| 88 | // The task is aborted soon and we will receive an error. |
| 89 | assert!(recv.await.is_err()); |
| 90 | } |
| 91 | } |
| 92 | |
| 93 | #[tokio::test ] |
| 94 | async fn alternating() { |
| 95 | let mut map = JoinMap::new(); |
| 96 | |
| 97 | assert_eq!(map.len(), 0); |
| 98 | map.spawn(1, async {}); |
| 99 | assert_eq!(map.len(), 1); |
| 100 | map.spawn(2, async {}); |
| 101 | assert_eq!(map.len(), 2); |
| 102 | |
| 103 | for i in 0..16 { |
| 104 | let (_, res) = map.join_next().await.unwrap(); |
| 105 | assert!(res.is_ok()); |
| 106 | assert_eq!(map.len(), 1); |
| 107 | map.spawn(i, async {}); |
| 108 | assert_eq!(map.len(), 2); |
| 109 | } |
| 110 | } |
| 111 | |
| 112 | #[tokio::test ] |
| 113 | async fn test_keys() { |
| 114 | use std::collections::HashSet; |
| 115 | |
| 116 | let mut map = JoinMap::new(); |
| 117 | |
| 118 | assert_eq!(map.len(), 0); |
| 119 | map.spawn(1, async {}); |
| 120 | assert_eq!(map.len(), 1); |
| 121 | map.spawn(2, async {}); |
| 122 | assert_eq!(map.len(), 2); |
| 123 | |
| 124 | let keys = map.keys().collect::<HashSet<&u32>>(); |
| 125 | assert!(keys.contains(&1)); |
| 126 | assert!(keys.contains(&2)); |
| 127 | |
| 128 | let _ = map.join_next().await.unwrap(); |
| 129 | let _ = map.join_next().await.unwrap(); |
| 130 | |
| 131 | assert_eq!(map.len(), 0); |
| 132 | let keys = map.keys().collect::<HashSet<&u32>>(); |
| 133 | assert!(keys.is_empty()); |
| 134 | } |
| 135 | |
| 136 | #[tokio::test (start_paused = true)] |
| 137 | async fn abort_by_key() { |
| 138 | let mut map = JoinMap::new(); |
| 139 | let mut num_canceled = 0; |
| 140 | let mut num_completed = 0; |
| 141 | for i in 0..16 { |
| 142 | map.spawn(i, async move { |
| 143 | tokio::time::sleep(Duration::from_secs(i as u64)).await; |
| 144 | }); |
| 145 | } |
| 146 | |
| 147 | for i in 0..16 { |
| 148 | if i % 2 != 0 { |
| 149 | // abort odd-numbered tasks. |
| 150 | map.abort(&i); |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | while let Some((key, res)) = map.join_next().await { |
| 155 | match res { |
| 156 | Ok(()) => { |
| 157 | num_completed += 1; |
| 158 | assert_eq!(key % 2, 0); |
| 159 | assert!(!map.contains_key(&key)); |
| 160 | } |
| 161 | Err(e) => { |
| 162 | num_canceled += 1; |
| 163 | assert!(e.is_cancelled()); |
| 164 | assert_ne!(key % 2, 0); |
| 165 | assert!(!map.contains_key(&key)); |
| 166 | } |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | assert_eq!(num_canceled, 8); |
| 171 | assert_eq!(num_completed, 8); |
| 172 | } |
| 173 | |
| 174 | #[tokio::test (start_paused = true)] |
| 175 | async fn abort_by_predicate() { |
| 176 | let mut map = JoinMap::new(); |
| 177 | let mut num_canceled = 0; |
| 178 | let mut num_completed = 0; |
| 179 | for i in 0..16 { |
| 180 | map.spawn(i, async move { |
| 181 | tokio::time::sleep(Duration::from_secs(i as u64)).await; |
| 182 | }); |
| 183 | } |
| 184 | |
| 185 | // abort odd-numbered tasks. |
| 186 | map.abort_matching(|key| key % 2 != 0); |
| 187 | |
| 188 | while let Some((key, res)) = map.join_next().await { |
| 189 | match res { |
| 190 | Ok(()) => { |
| 191 | num_completed += 1; |
| 192 | assert_eq!(key % 2, 0); |
| 193 | assert!(!map.contains_key(&key)); |
| 194 | } |
| 195 | Err(e) => { |
| 196 | num_canceled += 1; |
| 197 | assert!(e.is_cancelled()); |
| 198 | assert_ne!(key % 2, 0); |
| 199 | assert!(!map.contains_key(&key)); |
| 200 | } |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | assert_eq!(num_canceled, 8); |
| 205 | assert_eq!(num_completed, 8); |
| 206 | } |
| 207 | |
/// A task spawned onto a runtime that is dropped before that runtime is ever
/// driven must be reported as cancelled when joined from another runtime.
#[test]
fn runtime_gone() {
    let mut map = JoinMap::new();

    let doomed = rt();
    map.spawn_on("key", async { 1 }, doomed.handle());
    // Dropped without ever calling `block_on` on it.
    drop(doomed);

    let joined = rt().block_on(map.join_next());
    let (key, res) = joined.unwrap();
    assert_eq!(key, "key");
    let err = res.unwrap_err();
    assert!(err.is_cancelled());
}
| 221 | |
| 222 | // This ensures that `join_next` works correctly when the coop budget is |
| 223 | // exhausted. |
| 224 | #[tokio::test (flavor = "current_thread" )] |
| 225 | async fn join_map_coop() { |
| 226 | // Large enough to trigger coop. |
| 227 | const TASK_NUM: u32 = 1000; |
| 228 | |
| 229 | static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0); |
| 230 | |
| 231 | let mut map = JoinMap::new(); |
| 232 | |
| 233 | for i in 0..TASK_NUM { |
| 234 | map.spawn(i, async move { |
| 235 | SEM.add_permits(1); |
| 236 | i |
| 237 | }); |
| 238 | } |
| 239 | |
| 240 | // Wait for all tasks to complete. |
| 241 | // |
| 242 | // Since this is a `current_thread` runtime, there's no race condition |
| 243 | // between the last permit being added and the task completing. |
| 244 | let _ = SEM.acquire_many(TASK_NUM).await.unwrap(); |
| 245 | |
| 246 | let mut count = 0; |
| 247 | let mut coop_count = 0; |
| 248 | loop { |
| 249 | match map.join_next().now_or_never() { |
| 250 | Some(Some((key, Ok(i)))) => assert_eq!(key, i), |
| 251 | Some(Some((key, Err(err)))) => panic!("failed[{}]: {}" , key, err), |
| 252 | None => { |
| 253 | coop_count += 1; |
| 254 | tokio::task::yield_now().await; |
| 255 | continue; |
| 256 | } |
| 257 | Some(None) => break, |
| 258 | } |
| 259 | |
| 260 | count += 1; |
| 261 | } |
| 262 | assert!(coop_count >= 1); |
| 263 | assert_eq!(count, TASK_NUM); |
| 264 | } |
| 265 | |
| 266 | #[tokio::test (start_paused = true)] |
| 267 | async fn abort_all() { |
| 268 | let mut map: JoinMap<usize, ()> = JoinMap::new(); |
| 269 | |
| 270 | for i in 0..5 { |
| 271 | map.spawn(i, futures::future::pending()); |
| 272 | } |
| 273 | for i in 5..10 { |
| 274 | map.spawn(i, async { |
| 275 | tokio::time::sleep(Duration::from_secs(1)).await; |
| 276 | }); |
| 277 | } |
| 278 | |
| 279 | // The join map will now have 5 pending tasks and 5 ready tasks. |
| 280 | tokio::time::sleep(Duration::from_secs(2)).await; |
| 281 | |
| 282 | map.abort_all(); |
| 283 | assert_eq!(map.len(), 10); |
| 284 | |
| 285 | let mut count = 0; |
| 286 | let mut seen = [false; 10]; |
| 287 | while let Some((k, res)) = map.join_next().await { |
| 288 | seen[k] = true; |
| 289 | if let Err(err) = res { |
| 290 | assert!(err.is_cancelled()); |
| 291 | } |
| 292 | count += 1; |
| 293 | } |
| 294 | assert_eq!(count, 10); |
| 295 | assert_eq!(map.len(), 0); |
| 296 | for was_seen in &seen { |
| 297 | assert!(was_seen); |
| 298 | } |
| 299 | } |
| 300 | |