1 | #![warn (rust_2018_idioms)] |
2 | #![cfg (all(feature = "rt" , tokio_unstable))] |
3 | |
4 | use tokio::sync::oneshot; |
5 | use tokio::time::Duration; |
6 | use tokio_util::task::JoinMap; |
7 | |
8 | use futures::future::FutureExt; |
9 | |
10 | fn rt() -> tokio::runtime::Runtime { |
11 | tokio::runtime::Builder::new_current_thread() |
12 | .build() |
13 | .unwrap() |
14 | } |
15 | |
16 | #[tokio::test (start_paused = true)] |
17 | async fn test_with_sleep() { |
18 | let mut map = JoinMap::new(); |
19 | |
20 | for i in 0..10 { |
21 | map.spawn(i, async move { i }); |
22 | assert_eq!(map.len(), 1 + i); |
23 | } |
24 | map.detach_all(); |
25 | assert_eq!(map.len(), 0); |
26 | |
27 | assert!(matches!(map.join_next().await, None)); |
28 | |
29 | for i in 0..10 { |
30 | map.spawn(i, async move { |
31 | tokio::time::sleep(Duration::from_secs(i as u64)).await; |
32 | i |
33 | }); |
34 | assert_eq!(map.len(), 1 + i); |
35 | } |
36 | |
37 | let mut seen = [false; 10]; |
38 | while let Some((k, res)) = map.join_next().await { |
39 | seen[k] = true; |
40 | assert_eq!(res.expect("task should have completed successfully" ), k); |
41 | } |
42 | |
43 | for was_seen in &seen { |
44 | assert!(was_seen); |
45 | } |
46 | assert!(matches!(map.join_next().await, None)); |
47 | |
48 | // Do it again. |
49 | for i in 0..10 { |
50 | map.spawn(i, async move { |
51 | tokio::time::sleep(Duration::from_secs(i as u64)).await; |
52 | i |
53 | }); |
54 | } |
55 | |
56 | let mut seen = [false; 10]; |
57 | while let Some((k, res)) = map.join_next().await { |
58 | seen[k] = true; |
59 | assert_eq!(res.expect("task should have completed successfully" ), k); |
60 | } |
61 | |
62 | for was_seen in &seen { |
63 | assert!(was_seen); |
64 | } |
65 | assert!(matches!(map.join_next().await, None)); |
66 | } |
67 | |
68 | #[tokio::test ] |
69 | async fn test_abort_on_drop() { |
70 | let mut map = JoinMap::new(); |
71 | |
72 | let mut recvs = Vec::new(); |
73 | |
74 | for i in 0..16 { |
75 | let (send, recv) = oneshot::channel::<()>(); |
76 | recvs.push(recv); |
77 | |
78 | map.spawn(i, async { |
79 | // This task will never complete on its own. |
80 | futures::future::pending::<()>().await; |
81 | drop(send); |
82 | }); |
83 | } |
84 | |
85 | drop(map); |
86 | |
87 | for recv in recvs { |
88 | // The task is aborted soon and we will receive an error. |
89 | assert!(recv.await.is_err()); |
90 | } |
91 | } |
92 | |
93 | #[tokio::test ] |
94 | async fn alternating() { |
95 | let mut map = JoinMap::new(); |
96 | |
97 | assert_eq!(map.len(), 0); |
98 | map.spawn(1, async {}); |
99 | assert_eq!(map.len(), 1); |
100 | map.spawn(2, async {}); |
101 | assert_eq!(map.len(), 2); |
102 | |
103 | for i in 0..16 { |
104 | let (_, res) = map.join_next().await.unwrap(); |
105 | assert!(res.is_ok()); |
106 | assert_eq!(map.len(), 1); |
107 | map.spawn(i, async {}); |
108 | assert_eq!(map.len(), 2); |
109 | } |
110 | } |
111 | |
112 | #[tokio::test ] |
113 | async fn test_keys() { |
114 | use std::collections::HashSet; |
115 | |
116 | let mut map = JoinMap::new(); |
117 | |
118 | assert_eq!(map.len(), 0); |
119 | map.spawn(1, async {}); |
120 | assert_eq!(map.len(), 1); |
121 | map.spawn(2, async {}); |
122 | assert_eq!(map.len(), 2); |
123 | |
124 | let keys = map.keys().collect::<HashSet<&u32>>(); |
125 | assert!(keys.contains(&1)); |
126 | assert!(keys.contains(&2)); |
127 | |
128 | let _ = map.join_next().await.unwrap(); |
129 | let _ = map.join_next().await.unwrap(); |
130 | |
131 | assert_eq!(map.len(), 0); |
132 | let keys = map.keys().collect::<HashSet<&u32>>(); |
133 | assert!(keys.is_empty()); |
134 | } |
135 | |
136 | #[tokio::test (start_paused = true)] |
137 | async fn abort_by_key() { |
138 | let mut map = JoinMap::new(); |
139 | let mut num_canceled = 0; |
140 | let mut num_completed = 0; |
141 | for i in 0..16 { |
142 | map.spawn(i, async move { |
143 | tokio::time::sleep(Duration::from_secs(i as u64)).await; |
144 | }); |
145 | } |
146 | |
147 | for i in 0..16 { |
148 | if i % 2 != 0 { |
149 | // abort odd-numbered tasks. |
150 | map.abort(&i); |
151 | } |
152 | } |
153 | |
154 | while let Some((key, res)) = map.join_next().await { |
155 | match res { |
156 | Ok(()) => { |
157 | num_completed += 1; |
158 | assert_eq!(key % 2, 0); |
159 | assert!(!map.contains_key(&key)); |
160 | } |
161 | Err(e) => { |
162 | num_canceled += 1; |
163 | assert!(e.is_cancelled()); |
164 | assert_ne!(key % 2, 0); |
165 | assert!(!map.contains_key(&key)); |
166 | } |
167 | } |
168 | } |
169 | |
170 | assert_eq!(num_canceled, 8); |
171 | assert_eq!(num_completed, 8); |
172 | } |
173 | |
174 | #[tokio::test (start_paused = true)] |
175 | async fn abort_by_predicate() { |
176 | let mut map = JoinMap::new(); |
177 | let mut num_canceled = 0; |
178 | let mut num_completed = 0; |
179 | for i in 0..16 { |
180 | map.spawn(i, async move { |
181 | tokio::time::sleep(Duration::from_secs(i as u64)).await; |
182 | }); |
183 | } |
184 | |
185 | // abort odd-numbered tasks. |
186 | map.abort_matching(|key| key % 2 != 0); |
187 | |
188 | while let Some((key, res)) = map.join_next().await { |
189 | match res { |
190 | Ok(()) => { |
191 | num_completed += 1; |
192 | assert_eq!(key % 2, 0); |
193 | assert!(!map.contains_key(&key)); |
194 | } |
195 | Err(e) => { |
196 | num_canceled += 1; |
197 | assert!(e.is_cancelled()); |
198 | assert_ne!(key % 2, 0); |
199 | assert!(!map.contains_key(&key)); |
200 | } |
201 | } |
202 | } |
203 | |
204 | assert_eq!(num_canceled, 8); |
205 | assert_eq!(num_completed, 8); |
206 | } |
207 | |
// A task spawned onto a runtime that is then dropped is reported as cancelled
// when the map is later joined from a different runtime.
#[test]
fn runtime_gone() {
    let mut map = JoinMap::new();

    let doomed = rt();
    map.spawn_on("key", async { 1 }, doomed.handle());
    drop(doomed);

    let joined = rt().block_on(map.join_next());
    let (key, res) = joined.unwrap();
    assert_eq!(key, "key");
    assert!(res.unwrap_err().is_cancelled());
}
221 | |
222 | // This ensures that `join_next` works correctly when the coop budget is |
223 | // exhausted. |
224 | #[tokio::test (flavor = "current_thread" )] |
225 | async fn join_map_coop() { |
226 | // Large enough to trigger coop. |
227 | const TASK_NUM: u32 = 1000; |
228 | |
229 | static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0); |
230 | |
231 | let mut map = JoinMap::new(); |
232 | |
233 | for i in 0..TASK_NUM { |
234 | map.spawn(i, async move { |
235 | SEM.add_permits(1); |
236 | i |
237 | }); |
238 | } |
239 | |
240 | // Wait for all tasks to complete. |
241 | // |
242 | // Since this is a `current_thread` runtime, there's no race condition |
243 | // between the last permit being added and the task completing. |
244 | let _ = SEM.acquire_many(TASK_NUM).await.unwrap(); |
245 | |
246 | let mut count = 0; |
247 | let mut coop_count = 0; |
248 | loop { |
249 | match map.join_next().now_or_never() { |
250 | Some(Some((key, Ok(i)))) => assert_eq!(key, i), |
251 | Some(Some((key, Err(err)))) => panic!("failed[{}]: {}" , key, err), |
252 | None => { |
253 | coop_count += 1; |
254 | tokio::task::yield_now().await; |
255 | continue; |
256 | } |
257 | Some(None) => break, |
258 | } |
259 | |
260 | count += 1; |
261 | } |
262 | assert!(coop_count >= 1); |
263 | assert_eq!(count, TASK_NUM); |
264 | } |
265 | |
266 | #[tokio::test (start_paused = true)] |
267 | async fn abort_all() { |
268 | let mut map: JoinMap<usize, ()> = JoinMap::new(); |
269 | |
270 | for i in 0..5 { |
271 | map.spawn(i, futures::future::pending()); |
272 | } |
273 | for i in 5..10 { |
274 | map.spawn(i, async { |
275 | tokio::time::sleep(Duration::from_secs(1)).await; |
276 | }); |
277 | } |
278 | |
279 | // The join map will now have 5 pending tasks and 5 ready tasks. |
280 | tokio::time::sleep(Duration::from_secs(2)).await; |
281 | |
282 | map.abort_all(); |
283 | assert_eq!(map.len(), 10); |
284 | |
285 | let mut count = 0; |
286 | let mut seen = [false; 10]; |
287 | while let Some((k, res)) = map.join_next().await { |
288 | seen[k] = true; |
289 | if let Err(err) = res { |
290 | assert!(err.is_cancelled()); |
291 | } |
292 | count += 1; |
293 | } |
294 | assert_eq!(count, 10); |
295 | assert_eq!(map.len(), 0); |
296 | for was_seen in &seen { |
297 | assert!(was_seen); |
298 | } |
299 | } |
300 | |