1#![warn(rust_2018_idioms)]
2#![cfg(all(feature = "full", not(target_os = "wasi")))]
3#![cfg(tokio_unstable)]
4
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::{TcpListener, TcpStream};
7use tokio::runtime;
8use tokio::sync::oneshot;
9use tokio_test::{assert_err, assert_ok};
10
11use futures::future::poll_fn;
12use std::future::Future;
13use std::pin::Pin;
14use std::sync::atomic::AtomicUsize;
15use std::sync::atomic::Ordering::Relaxed;
16use std::sync::{mpsc, Arc, Mutex};
17use std::task::{Context, Poll, Waker};
18
/// Wraps arbitrary statements in a block that is compiled only under
/// `--cfg tokio_unstable`. Used to fence assertions that read runtime metrics.
macro_rules! cfg_metrics {
    ($($tokens:tt)*) => {
        #[cfg(tokio_unstable)]
        {
            $($tokens)*
        }
    };
}
27
#[test]
fn single_thread() {
    // Building a runtime with exactly one worker thread must not panic.
    let built = runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .enable_all()
        .build();
    let _ = built.unwrap();
}
37
#[test]
fn many_oneshot_futures() {
    // Tasks spawned per round; whichever runs last signals the main thread.
    const NUM: usize = 1_000;

    for _round in 0..5 {
        let (done_tx, done_rx) = mpsc::channel();

        let rt = rt();
        let completed = Arc::new(AtomicUsize::new(0));

        (0..NUM).for_each(|_| {
            let completed = Arc::clone(&completed);
            let done_tx = done_tx.clone();

            rt.spawn(async move {
                if completed.fetch_add(1, Relaxed) + 1 == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        });

        done_rx.recv().unwrap();

        // Wait for the pool to shut down before starting the next round.
        drop(rt);
    }
}
68
#[test]
fn spawn_two() {
    let rt = rt();

    let received = rt.block_on(async {
        let (sender, receiver) = oneshot::channel();

        // Grandchild completes the oneshot; child (spawned from `block_on`)
        // spawns the grandchild from inside the runtime.
        let grandchild = async move {
            sender.send("ZOMG").unwrap();
        };
        let child = async move {
            tokio::spawn(grandchild);
        };
        tokio::spawn(child);

        assert_ok!(receiver.await)
    });

    assert_eq!(received, "ZOMG");

    cfg_metrics! {
        let metrics = rt.metrics();
        drop(rt);
        assert_eq!(1, metrics.remote_schedule_count());

        // Sum the per-worker local schedule counts.
        let local = (0..metrics.num_workers())
            .fold(0, |acc, worker| acc + metrics.worker_local_schedule_count(worker));

        assert_eq!(1, local);
    }
}
100
#[test]
fn many_multishot_futures() {
    const CHAIN: usize = 200;
    const CYCLES: usize = 5;
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let rt = rt();
        let mut start_txs = Vec::with_capacity(TRACKS);
        let mut final_rxs = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (start_tx, head_rx) = tokio::sync::mpsc::channel(10);

            // Build CHAIN forwarding links; the fold yields the receiver at the
            // far end of the chain.
            let mut tail_rx = (0..CHAIN).fold(head_rx, |mut upstream, _| {
                let (link_tx, link_rx) = tokio::sync::mpsc::channel(10);

                rt.spawn(async move {
                    while let Some(v) = upstream.recv().await {
                        link_tx.send(v).await.unwrap();
                    }
                });

                link_rx
            });

            // The last task feeds each message back to the start of the chain,
            // except on the final round, where it is delivered to `final_rx`.
            let (final_tx, final_rx) = tokio::sync::mpsc::channel(10);
            let cycle_tx = start_tx.clone();

            rt.spawn(async move {
                for round in 1..=CYCLES {
                    let msg = tail_rx.recv().await.unwrap();

                    if round == CYCLES {
                        final_tx.send(msg).await.unwrap();
                    } else {
                        cycle_tx.send(msg).await.unwrap();
                    }
                }
            });

            start_txs.push(start_tx);
            final_rxs.push(final_rx);
        }

        rt.block_on(async move {
            for start_tx in start_txs {
                start_tx.send("ping").await.unwrap();
            }

            for mut final_rx in final_rxs {
                final_rx.recv().await.unwrap();
            }
        });
    }
}
164
#[test]
fn lifo_slot_budget() {
    // `my_fn` and `spawn_another` together form a task that endlessly
    // re-spawns itself.
    fn spawn_another() {
        tokio::spawn(my_fn());
    }

    async fn my_fn() {
        spawn_another();
    }

    let rt = runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (done_tx, done_rx) = oneshot::channel();

    rt.spawn(async move {
        tokio::spawn(my_fn());
        let _ = done_tx.send(());
    });

    // Returning here means the signalling task still ran alongside the
    // self-respawning chain (presumably exercising the LIFO-slot budget, per the
    // test name).
    let _ = rt.block_on(done_rx);
}
190
#[test]
fn spawn_shutdown() {
    let rt = rt();
    let (done_tx, done_rx) = mpsc::channel();

    // One client/server pair spawned from inside the runtime...
    let inner_tx = done_tx.clone();
    rt.block_on(async move {
        tokio::spawn(client_server(inner_tx));
    });

    // ...and one spawned through the runtime from outside.
    rt.spawn(client_server(done_tx));

    assert_ok!(done_rx.recv());
    assert_ok!(done_rx.recv());

    // After the runtime is dropped, nothing further is waiting in the channel.
    drop(rt);
    assert_err!(done_rx.try_recv());
}
209
210async fn client_server(tx: mpsc::Sender<()>) {
211 let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
212
213 // Get the assigned address
214 let addr = assert_ok!(server.local_addr());
215
216 // Spawn the server
217 tokio::spawn(async move {
218 // Accept a socket
219 let (mut socket, _) = server.accept().await.unwrap();
220
221 // Write some data
222 socket.write_all(b"hello").await.unwrap();
223 });
224
225 let mut client = TcpStream::connect(&addr).await.unwrap();
226
227 let mut buf = vec![];
228 client.read_to_end(&mut buf).await.unwrap();
229
230 assert_eq!(buf, b"hello");
231 tx.send(()).unwrap();
232}
233
#[test]
fn drop_threadpool_drops_futures() {
    /// A future that never completes and bumps its counter when dropped.
    struct Never(Arc<AtomicUsize>);

    impl Future for Never {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Pending
        }
    }

    impl Drop for Never {
        fn drop(&mut self) {
            self.0.fetch_add(1, Relaxed);
        }
    }

    for _ in 0..1_000 {
        let num_inc = Arc::new(AtomicUsize::new(0));
        let num_dec = Arc::new(AtomicUsize::new(0));
        let num_drop = Arc::new(AtomicUsize::new(0));

        let a = num_inc.clone();
        let b = num_dec.clone();

        let rt = runtime::Builder::new_multi_thread_alt()
            .enable_all()
            .on_thread_start(move || {
                a.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                b.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(num_drop.clone()));

        // Wait for the pool to shutdown
        drop(rt);

        // Assert that at least one runtime thread was started (the
        // `on_thread_start` callback ran).
        let started = num_inc.load(Relaxed);
        assert!(started >= 1);

        // Assert that every started thread also ran its stop callback.
        let stopped = num_dec.load(Relaxed);
        assert_eq!(started, stopped);

        // Assert that the never-completing future was dropped exactly once.
        let dropped = num_drop.load(Relaxed);
        assert_eq!(dropped, 1);
    }
}
289
#[test]
fn start_stop_callbacks_called() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        .enable_all()
        // Each callback owns its `Arc`; incrementing through it needs no
        // per-call clone.
        .on_thread_start(move || {
            after_inner.fetch_add(1, Ordering::Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    drop(rt);

    // Both callbacks must have run at least once by the time the runtime is gone.
    assert!(after_start.load(Ordering::Relaxed) > 0);
    assert!(before_stop.load(Ordering::Relaxed) > 0);
}
323
#[test]
fn blocking_task() {
    // Number of short tasks spawned while the blocking sections are active;
    // the task that brings the counter to NUM notifies the main thread.
    const NUM: usize = 1_000;

    for _ in 0..10 {
        let (tx, rx) = mpsc::channel();

        let rt = rt();
        let cnt = Arc::new(AtomicUsize::new(0));

        // Park four tasks inside `block_in_place`. The barrier has five
        // participants (those four tasks plus this thread), so the first
        // `block.wait()` below returns only once all four are inside their
        // blocking sections. NOTE(review): `rt()` uses the default worker
        // count, which is not necessarily four.
        let block = Arc::new(std::sync::Barrier::new(5));
        for _ in 0..4 {
            let block = block.clone();
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    block.wait();
                    block.wait();
                })
            });
        }
        block.wait();

        // With the blocking sections still active, the runtime must keep
        // running ordinary tasks.
        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Second barrier round: release the four blocking sections. The
        // runtime itself is dropped at the end of this iteration.
        block.wait();
    }
}
368
#[test]
fn multi_threadpool() {
    use tokio::sync::oneshot;

    let (rt_a, rt_b) = (rt(), rt());

    let (signal_tx, signal_rx) = oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    // A task on the second runtime waits for a signal sent from a task on the
    // first runtime, then reports back to this thread.
    rt_b.spawn(async move {
        signal_rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    rt_a.spawn(async move {
        signal_tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}
390
// When `block_in_place` returns, it attempts to reclaim the yielded runtime
// worker. In this case, the remainder of the task is on the runtime worker and
// must take part in the cooperative task budgeting system.
//
// The test ensures that, when this happens, attempting to consume from a
// channel yields occasionally even if there are values ready to receive.
#[test]
fn coop_and_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        // Setting max threads to 1 prevents another thread from claiming the
        // runtime worker yielded as part of `block_in_place` and guarantees the
        // same thread will reclaim the worker at the end of the
        // `block_in_place` call.
        .max_blocking_threads(1)
        .build()
        .unwrap();

    rt.block_on(async move {
        // Capacity matches the number of values queued below, so every `send`
        // completes without waiting for a receiver.
        let (tx, mut rx) = tokio::sync::mpsc::channel(1024);

        // Fill the channel
        for _ in 0..1024 {
            tx.send(()).await.unwrap();
        }

        // Drop the only sender: once the 1024 values are drained, `recv`
        // yields `None`.
        drop(tx);

        tokio::spawn(async move {
            // Block in place without doing anything
            tokio::task::block_in_place(|| {});

            // Receive all the values, this should trigger a `Pending` as the
            // coop limit will be reached.
            poll_fn(|cx| {
                // Each iteration creates and polls a fresh `recv()` future;
                // the loop only exits when a poll returns `Pending`.
                while let Poll::Ready(v) = {
                    tokio::pin! {
                        let fut = rx.recv();
                    }

                    Pin::new(&mut fut).poll(cx)
                } {
                    // `None` means the whole channel was drained in a single
                    // poll without ever yielding, which is the failure case.
                    if v.is_none() {
                        panic!("did not yield");
                    }
                }

                Poll::Ready(())
            })
            .await
        })
        .await
        .unwrap();
    });
}
445
#[test]
fn yield_after_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .build()
        .unwrap();

    let task = async move {
        // While blocked in place, build and drive a brand-new current-thread
        // runtime.
        tokio::task::block_in_place(|| {
            let nested = tokio::runtime::Builder::new_current_thread()
                .build()
                .unwrap();

            nested.block_on(async {});
        });

        // Then yield back to the scheduler once before finishing.
        tokio::task::yield_now().await;
    };

    rt.block_on(async { tokio::spawn(task).await.unwrap() });
}
471
// Building a runtime with a blocking-thread cap of 1 must succeed without panicking.
#[test]
fn max_blocking_threads() {
    let mut builder = tokio::runtime::Builder::new_multi_thread_alt();
    builder.max_blocking_threads(1);
    let _rt = builder.build().unwrap();
}
480
// A blocking-thread cap of zero is rejected with a panic.
#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    let mut builder = tokio::runtime::Builder::new_multi_thread_alt();
    builder.max_blocking_threads(0);
    let _rt = builder.build().unwrap();
}
489
490#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
491async fn hang_on_shutdown() {
492 let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
493 tokio::spawn(async move {
494 tokio::task::block_in_place(|| sync_rx.recv().ok());
495 });
496
497 tokio::spawn(async {
498 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
499 drop(sync_tx);
500 });
501 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
502}
503
/// Demonstrates tokio-rs/tokio#3869
#[test]
fn wake_during_shutdown() {
    struct Shared {
        waker: Option<Waker>,
    }

    /// Never-ready future created in pairs that share one `Shared` slot.
    /// - The `put_waker == true` half stores its waker on every poll.
    /// - The other half wakes that stored waker when it is dropped.
    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        put_waker: bool,
    }

    impl MyFuture {
        fn new() -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Shared { waker: None }));
            let storer = MyFuture {
                shared: Arc::clone(&shared),
                put_waker: true,
            };
            let waker_caller = MyFuture {
                shared,
                put_waker: false,
            };
            (storer, waker_caller)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let this = self.get_mut();
            let mut guard = this.shared.lock().unwrap();
            if this.put_waker {
                guard.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            let mut guard = self.shared.lock().unwrap();
            // The wake happens while the lock is still held; it is released afterwards.
            if !self.put_waker {
                guard.waker.take().unwrap().wake();
            }
            drop(guard);
        }
    }

    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (storer, waker_caller) = MyFuture::new();

    rt.spawn(storer);
    rt.spawn(waker_caller);

    rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}
567
568#[should_panic]
569#[tokio::test]
570async fn test_block_in_place1() {
571 tokio::task::block_in_place(|| {});
572}
573
574#[tokio::test(flavor = "multi_thread")]
575async fn test_block_in_place2() {
576 tokio::task::block_in_place(|| {});
577}
578
579#[should_panic]
580#[tokio::main(flavor = "current_thread")]
581#[test]
582async fn test_block_in_place3() {
583 tokio::task::block_in_place(|| {});
584}
585
586#[tokio::main]
587#[test]
588async fn test_block_in_place4() {
589 tokio::task::block_in_place(|| {});
590}
591
// Testing the tuning logic is tricky as it is inherently timing based, and more
// of a heuristic than an exact behavior. This test checks that the interval
// changes over time based on load factors. There are no assertions, completion
// is sufficient. If there is a regression, this test will hang. In theory, we
// could add limits, but that would be likely to fail on CI.
#[test]
#[cfg(not(tokio_no_tuning_tests))]
fn test_tuning() {
    use std::sync::atomic::AtomicBool;
    use std::time::Duration;

    // Consecutive in-range samples required before a phase counts as settled.
    // Jitter in the tuning could produce one "good" value without the runtime
    // having actually reached a good state.
    const GOOD_ROUNDS: usize = 3;

    // Cap on outstanding probe tasks (measured via the `interval` Arc's strong
    // count), so hammering the injection queue cannot grow without bound.
    const MAX_PROBES: usize = 5_000;

    let rt = runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .build()
        .unwrap();

    // Self-respawning task: while `flag` is set, optionally stall the worker,
    // bump `counter`, and spawn the next iteration.
    fn iter(flag: Arc<AtomicBool>, counter: Arc<AtomicUsize>, stall: bool) {
        if flag.load(Relaxed) {
            if stall {
                std::thread::sleep(Duration::from_micros(5));
            }

            counter.fetch_add(1, Relaxed);
            tokio::spawn(async move { iter(flag, counter, stall) });
        }
    }

    // Repeatedly push probe tasks through the injection queue. Each probe
    // stores into `interval` how many `iter` steps ran since the previous
    // probe. Returns once `in_range(interval)` holds for GOOD_ROUNDS
    // consecutive samples. Always yields between samples, so the loop never
    // busy-spins while the probe cap is reached.
    fn hammer_until(
        rt: &runtime::Runtime,
        counter: &Arc<AtomicUsize>,
        interval: &Arc<AtomicUsize>,
        in_range: impl Fn(usize) -> bool,
    ) {
        let mut good = 0;
        loop {
            if in_range(interval.load(Relaxed)) {
                good += 1;
            } else {
                good = 0;
            }

            if good == GOOD_ROUNDS {
                return;
            }

            if Arc::strong_count(interval) < MAX_PROBES {
                let counter = counter.clone();
                let interval = interval.clone();

                rt.spawn(async move {
                    let prev = counter.swap(0, Relaxed);
                    interval.store(prev, Relaxed);
                });
            }

            std::thread::yield_now();
        }
    }

    // Phase 1: a stalling task; hammer until the interval drops.
    let flag = Arc::new(AtomicBool::new(true));
    let counter = Arc::new(AtomicUsize::new(61));
    let interval = Arc::new(AtomicUsize::new(61));

    {
        let flag = flag.clone();
        let counter = counter.clone();
        rt.spawn(async move { iter(flag, counter, true) });
    }

    hammer_until(&rt, &counter, &interval, |curr| curr <= 8);

    flag.store(false, Relaxed);

    // Wait until every outstanding phase-1 probe has run (each holds a clone
    // of `interval`) before starting phase 2.
    let w = Arc::downgrade(&interval);
    drop(interval);

    while w.strong_count() > 0 {
        std::thread::sleep(Duration::from_micros(500));
    }

    // Phase 2: run it again with a faster task. Start the counters high; we
    // know the interval shouldn't ever really be this high.
    let flag = Arc::new(AtomicBool::new(true));
    let counter = Arc::new(AtomicUsize::new(10_000));
    let interval = Arc::new(AtomicUsize::new(10_000));

    {
        let flag = flag.clone();
        let counter = counter.clone();
        rt.spawn(async move { iter(flag, counter, false) });
    }

    // Hammer until the interval reaches the expected range.
    hammer_until(&rt, &counter, &interval, |curr| curr <= 1_000 && curr > 32);

    flag.store(false, Relaxed);
}
711
712fn rt() -> runtime::Runtime {
713 runtime::Builder::new_multi_thread_alt()
714 .enable_all()
715 .build()
716 .unwrap()
717}
718