1#![warn(rust_2018_idioms)]
2#![cfg(all(feature = "rt", tokio_unstable))]
3
4use tokio::sync::oneshot;
5use tokio::time::Duration;
6use tokio_util::task::JoinMap;
7
8use futures::future::FutureExt;
9
10fn rt() -> tokio::runtime::Runtime {
11 tokio::runtime::Builder::new_current_thread()
12 .build()
13 .unwrap()
14}
15
16#[tokio::test(start_paused = true)]
17async fn test_with_sleep() {
18 let mut map = JoinMap::new();
19
20 for i in 0..10 {
21 map.spawn(i, async move { i });
22 assert_eq!(map.len(), 1 + i);
23 }
24 map.detach_all();
25 assert_eq!(map.len(), 0);
26
27 assert!(matches!(map.join_next().await, None));
28
29 for i in 0..10 {
30 map.spawn(i, async move {
31 tokio::time::sleep(Duration::from_secs(i as u64)).await;
32 i
33 });
34 assert_eq!(map.len(), 1 + i);
35 }
36
37 let mut seen = [false; 10];
38 while let Some((k, res)) = map.join_next().await {
39 seen[k] = true;
40 assert_eq!(res.expect("task should have completed successfully"), k);
41 }
42
43 for was_seen in &seen {
44 assert!(was_seen);
45 }
46 assert!(matches!(map.join_next().await, None));
47
48 // Do it again.
49 for i in 0..10 {
50 map.spawn(i, async move {
51 tokio::time::sleep(Duration::from_secs(i as u64)).await;
52 i
53 });
54 }
55
56 let mut seen = [false; 10];
57 while let Some((k, res)) = map.join_next().await {
58 seen[k] = true;
59 assert_eq!(res.expect("task should have completed successfully"), k);
60 }
61
62 for was_seen in &seen {
63 assert!(was_seen);
64 }
65 assert!(matches!(map.join_next().await, None));
66}
67
68#[tokio::test]
69async fn test_abort_on_drop() {
70 let mut map = JoinMap::new();
71
72 let mut recvs = Vec::new();
73
74 for i in 0..16 {
75 let (send, recv) = oneshot::channel::<()>();
76 recvs.push(recv);
77
78 map.spawn(i, async {
79 // This task will never complete on its own.
80 futures::future::pending::<()>().await;
81 drop(send);
82 });
83 }
84
85 drop(map);
86
87 for recv in recvs {
88 // The task is aborted soon and we will receive an error.
89 assert!(recv.await.is_err());
90 }
91}
92
93#[tokio::test]
94async fn alternating() {
95 let mut map = JoinMap::new();
96
97 assert_eq!(map.len(), 0);
98 map.spawn(1, async {});
99 assert_eq!(map.len(), 1);
100 map.spawn(2, async {});
101 assert_eq!(map.len(), 2);
102
103 for i in 0..16 {
104 let (_, res) = map.join_next().await.unwrap();
105 assert!(res.is_ok());
106 assert_eq!(map.len(), 1);
107 map.spawn(i, async {});
108 assert_eq!(map.len(), 2);
109 }
110}
111
112#[tokio::test]
113async fn test_keys() {
114 use std::collections::HashSet;
115
116 let mut map = JoinMap::new();
117
118 assert_eq!(map.len(), 0);
119 map.spawn(1, async {});
120 assert_eq!(map.len(), 1);
121 map.spawn(2, async {});
122 assert_eq!(map.len(), 2);
123
124 let keys = map.keys().collect::<HashSet<&u32>>();
125 assert!(keys.contains(&1));
126 assert!(keys.contains(&2));
127
128 let _ = map.join_next().await.unwrap();
129 let _ = map.join_next().await.unwrap();
130
131 assert_eq!(map.len(), 0);
132 let keys = map.keys().collect::<HashSet<&u32>>();
133 assert!(keys.is_empty());
134}
135
136#[tokio::test(start_paused = true)]
137async fn abort_by_key() {
138 let mut map = JoinMap::new();
139 let mut num_canceled = 0;
140 let mut num_completed = 0;
141 for i in 0..16 {
142 map.spawn(i, async move {
143 tokio::time::sleep(Duration::from_secs(i as u64)).await;
144 });
145 }
146
147 for i in 0..16 {
148 if i % 2 != 0 {
149 // abort odd-numbered tasks.
150 map.abort(&i);
151 }
152 }
153
154 while let Some((key, res)) = map.join_next().await {
155 match res {
156 Ok(()) => {
157 num_completed += 1;
158 assert_eq!(key % 2, 0);
159 assert!(!map.contains_key(&key));
160 }
161 Err(e) => {
162 num_canceled += 1;
163 assert!(e.is_cancelled());
164 assert_ne!(key % 2, 0);
165 assert!(!map.contains_key(&key));
166 }
167 }
168 }
169
170 assert_eq!(num_canceled, 8);
171 assert_eq!(num_completed, 8);
172}
173
174#[tokio::test(start_paused = true)]
175async fn abort_by_predicate() {
176 let mut map = JoinMap::new();
177 let mut num_canceled = 0;
178 let mut num_completed = 0;
179 for i in 0..16 {
180 map.spawn(i, async move {
181 tokio::time::sleep(Duration::from_secs(i as u64)).await;
182 });
183 }
184
185 // abort odd-numbered tasks.
186 map.abort_matching(|key| key % 2 != 0);
187
188 while let Some((key, res)) = map.join_next().await {
189 match res {
190 Ok(()) => {
191 num_completed += 1;
192 assert_eq!(key % 2, 0);
193 assert!(!map.contains_key(&key));
194 }
195 Err(e) => {
196 num_canceled += 1;
197 assert!(e.is_cancelled());
198 assert_ne!(key % 2, 0);
199 assert!(!map.contains_key(&key));
200 }
201 }
202 }
203
204 assert_eq!(num_canceled, 8);
205 assert_eq!(num_completed, 8);
206}
207
/// A task whose runtime was dropped before it ran is reported as cancelled
/// when joined from a different runtime.
#[test]
fn runtime_gone() {
    let mut map = JoinMap::new();

    // Spawn onto a runtime that is dropped right away, before it is ever
    // driven.
    let doomed = rt();
    map.spawn_on("key", async { 1 }, doomed.handle());
    drop(doomed);

    let joined = rt().block_on(map.join_next());
    let (key, res) = joined.unwrap();
    assert_eq!(key, "key");
    assert!(res.unwrap_err().is_cancelled());
}
221
222// This ensures that `join_next` works correctly when the coop budget is
223// exhausted.
224#[tokio::test(flavor = "current_thread")]
225async fn join_map_coop() {
226 // Large enough to trigger coop.
227 const TASK_NUM: u32 = 1000;
228
229 static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0);
230
231 let mut map = JoinMap::new();
232
233 for i in 0..TASK_NUM {
234 map.spawn(i, async move {
235 SEM.add_permits(1);
236 i
237 });
238 }
239
240 // Wait for all tasks to complete.
241 //
242 // Since this is a `current_thread` runtime, there's no race condition
243 // between the last permit being added and the task completing.
244 let _ = SEM.acquire_many(TASK_NUM).await.unwrap();
245
246 let mut count = 0;
247 let mut coop_count = 0;
248 loop {
249 match map.join_next().now_or_never() {
250 Some(Some((key, Ok(i)))) => assert_eq!(key, i),
251 Some(Some((key, Err(err)))) => panic!("failed[{}]: {}", key, err),
252 None => {
253 coop_count += 1;
254 tokio::task::yield_now().await;
255 continue;
256 }
257 Some(None) => break,
258 }
259
260 count += 1;
261 }
262 assert!(coop_count >= 1);
263 assert_eq!(count, TASK_NUM);
264}
265
266#[tokio::test(start_paused = true)]
267async fn abort_all() {
268 let mut map: JoinMap<usize, ()> = JoinMap::new();
269
270 for i in 0..5 {
271 map.spawn(i, futures::future::pending());
272 }
273 for i in 5..10 {
274 map.spawn(i, async {
275 tokio::time::sleep(Duration::from_secs(1)).await;
276 });
277 }
278
279 // The join map will now have 5 pending tasks and 5 ready tasks.
280 tokio::time::sleep(Duration::from_secs(2)).await;
281
282 map.abort_all();
283 assert_eq!(map.len(), 10);
284
285 let mut count = 0;
286 let mut seen = [false; 10];
287 while let Some((k, res)) = map.join_next().await {
288 seen[k] = true;
289 if let Err(err) = res {
290 assert!(err.is_cancelled());
291 }
292 count += 1;
293 }
294 assert_eq!(count, 10);
295 assert_eq!(map.len(), 0);
296 for was_seen in &seen {
297 assert!(was_seen);
298 }
299}
300