1#![warn(rust_2018_idioms)]
2#![cfg(all(feature = "full", not(target_os = "wasi")))]
3
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::{TcpListener, TcpStream};
6use tokio::runtime;
7use tokio::sync::oneshot;
8use tokio_test::{assert_err, assert_ok};
9
10use futures::future::poll_fn;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::atomic::AtomicUsize;
14use std::sync::atomic::Ordering::Relaxed;
15use std::sync::{mpsc, Arc, Mutex};
16use std::task::{Context, Poll, Waker};
17
/// Emits the wrapped statements as a block that is only compiled under
/// `--cfg tokio_unstable`; used in this file to guard runtime-metrics checks.
macro_rules! cfg_metrics {
    ($($body:tt)*) => {
        #[cfg(tokio_unstable)]
        {
            $($body)*
        }
    };
}
26
#[test]
fn single_thread() {
    // A multi-thread runtime configured with exactly one worker must build
    // without panicking.
    let mut builder = runtime::Builder::new_multi_thread();
    builder.worker_threads(1).enable_all();
    let _ = builder.build().unwrap();
}
36
#[test]
fn many_oneshot_futures() {
    // Tasks spawned per round; whichever task brings the shared count to this
    // value notifies the main thread.
    const NUM: usize = 1_000;

    for _round in 0..5 {
        let (done_tx, done_rx) = mpsc::channel();

        let runtime = rt();
        let counter = Arc::new(AtomicUsize::new(0));

        (0..NUM).for_each(|_| {
            let counter = Arc::clone(&counter);
            let done_tx = done_tx.clone();

            runtime.spawn(async move {
                if counter.fetch_add(1, Relaxed) + 1 == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        });

        done_rx.recv().unwrap();

        // Dropping the runtime shuts the worker pool down.
        drop(runtime);
    }
}
67
#[test]
fn spawn_two() {
    let rt = rt();

    // A task spawned from `block_on` spawns a nested task, and that nested
    // task completes the oneshot the outer future is awaiting.
    let out = rt.block_on(async {
        let (sender, receiver) = oneshot::channel();

        tokio::spawn(async move {
            tokio::spawn(async move {
                sender.send("ZOMG").unwrap();
            });
        });

        assert_ok!(receiver.await)
    });

    assert_eq!(out, "ZOMG");

    cfg_metrics! {
        let metrics = rt.metrics();
        drop(rt);
        assert_eq!(1, metrics.remote_schedule_count());

        // Sum the locally-scheduled task count across every worker.
        let local: u64 = (0..metrics.num_workers())
            .map(|worker| metrics.worker_local_schedule_count(worker))
            .sum();

        assert_eq!(1, local);
    }
}
99
#[test]
fn many_multishot_futures() {
    const CHAIN: usize = 200;
    const CYCLES: usize = 5;
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let rt = rt();
        let mut heads = Vec::with_capacity(TRACKS);
        let mut tails = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (head_tx, first_rx) = tokio::sync::mpsc::channel(10);

            // Build a chain of CHAIN relay tasks; each forwards every message
            // from its input channel into a freshly created output channel.
            let mut link_rx = first_rx;
            for _ in 0..CHAIN {
                let (link_tx, next_rx) = tokio::sync::mpsc::channel(10);
                let mut input = link_rx;

                rt.spawn(async move {
                    while let Some(msg) = input.recv().await {
                        link_tx.send(msg).await.unwrap();
                    }
                });

                link_rx = next_rx;
            }

            // The last task sends each message back to the head of the chain
            // until it has gone round CYCLES times, then delivers it to this
            // track's final receiver.
            let (done_tx, done_rx) = tokio::sync::mpsc::channel(10);
            let loop_tx = head_tx.clone();

            rt.spawn(async move {
                for lap in 1..=CYCLES {
                    let msg = link_rx.recv().await.unwrap();

                    if lap == CYCLES {
                        done_tx.send(msg).await.unwrap();
                    } else {
                        loop_tx.send(msg).await.unwrap();
                    }
                }
            });

            heads.push(head_tx);
            tails.push(done_rx);
        }

        rt.block_on(async move {
            for head in heads {
                head.send("ping").await.unwrap();
            }

            for mut tail in tails {
                tail.recv().await.unwrap();
            }
        });
    }
}
163
#[test]
fn lifo_slot_budget() {
    // `respawn` and `spawn_next` call each other through `tokio::spawn`, so
    // every task immediately schedules a successor: an endless chain of
    // self-spawning tasks on a single worker.
    fn spawn_next() {
        tokio::spawn(respawn());
    }

    async fn respawn() {
        spawn_next();
    }

    let rt = runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (done_tx, done_rx) = oneshot::channel();

    rt.spawn(async move {
        tokio::spawn(respawn());
        let _ = done_tx.send(());
    });

    // Wait for the signal; the runtime is then dropped with the chain of
    // spawning tasks still live.
    let _ = rt.block_on(done_rx);
}
189
#[test]
fn spawn_shutdown() {
    let rt = rt();
    let (done_tx, done_rx) = mpsc::channel();

    // One client/server pair is spawned from inside the runtime...
    rt.block_on(async {
        tokio::spawn(client_server(done_tx.clone()));
    });

    // ...and one through the runtime's spawner from outside it.
    rt.spawn(client_server(done_tx));

    for _ in 0..2 {
        assert_ok!(done_rx.recv());
    }

    // With the runtime gone, every sender has been dropped and no further
    // completion can arrive.
    drop(rt);
    assert_err!(done_rx.try_recv());
}
208
209async fn client_server(tx: mpsc::Sender<()>) {
210 let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
211
212 // Get the assigned address
213 let addr = assert_ok!(server.local_addr());
214
215 // Spawn the server
216 tokio::spawn(async move {
217 // Accept a socket
218 let (mut socket, _) = server.accept().await.unwrap();
219
220 // Write some data
221 socket.write_all(b"hello").await.unwrap();
222 });
223
224 let mut client = TcpStream::connect(&addr).await.unwrap();
225
226 let mut buf = vec![];
227 client.read_to_end(&mut buf).await.unwrap();
228
229 assert_eq!(buf, b"hello");
230 tx.send(()).unwrap();
231}
232
#[test]
fn drop_threadpool_drops_futures() {
    /// A future that never completes. Each drop increments the shared
    /// counter, so the test can check that the runtime dropped it.
    struct Never(Arc<AtomicUsize>);

    impl Future for Never {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Pending
        }
    }

    impl Drop for Never {
        fn drop(&mut self) {
            self.0.fetch_add(1, Relaxed);
        }
    }

    for _ in 0..1_000 {
        let num_inc = Arc::new(AtomicUsize::new(0));
        let num_dec = Arc::new(AtomicUsize::new(0));
        let num_drop = Arc::new(AtomicUsize::new(0));

        let on_start = num_inc.clone();
        let on_stop = num_dec.clone();

        let rt = runtime::Builder::new_multi_thread()
            .enable_all()
            .on_thread_start(move || {
                on_start.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                on_stop.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(num_drop.clone()));

        // Dropping the runtime shuts the pool down.
        drop(rt);

        // At least one runtime thread must have been started.
        let started = num_inc.load(Relaxed);
        assert!(started >= 1);

        // Every thread that started must also have stopped.
        let stopped = num_dec.load(Relaxed);
        assert_eq!(started, stopped);

        // The never-completing future must have been dropped exactly once.
        let dropped = num_drop.load(Relaxed);
        assert_eq!(dropped, 1);
    }
}
288
#[test]
fn start_stop_callbacks_called() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    // Each callback owns one `Arc` handle to its counter and only needs
    // shared access to it, so no per-invocation clone is required.
    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .on_thread_start(move || {
            after_inner.fetch_add(1, Ordering::Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    drop(rt);

    // Both callbacks must have run at least once over the runtime's lifetime.
    assert!(after_start.load(Ordering::Relaxed) > 0);
    assert!(before_stop.load(Ordering::Relaxed) > 0);
}
322
#[test]
fn blocking() {
    // used for notifying the main thread
    const NUM: usize = 1_000;
    // Number of runtime workers. The same number of tasks are parked inside
    // `block_in_place` so that every worker is occupied.
    const WORKERS: usize = 4;

    for _ in 0..10 {
        let (tx, rx) = mpsc::channel();

        // Pin the worker count explicitly. The default runtime sizes its pool
        // from the number of CPUs, which would break the reasoning below on
        // machines with more than WORKERS cores.
        let rt = runtime::Builder::new_multi_thread()
            .worker_threads(WORKERS)
            .enable_all()
            .build()
            .unwrap();
        let cnt = Arc::new(AtomicUsize::new(0));

        // There are WORKERS workers in the pool. If WORKERS tasks are all
        // blocked inside `block_in_place`, the NUM tasks spawned below can only
        // run if each blocked task handed its worker off.
        // The barrier has one extra party for the main thread.
        let block = Arc::new(std::sync::Barrier::new(WORKERS + 1));
        for _ in 0..WORKERS {
            let block = block.clone();
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    block.wait();
                    block.wait();
                })
            });
        }
        // Wait until every blocking task has entered `block_in_place`.
        block.wait();

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Release the blocked tasks so the runtime can shut down when `rt` is
        // dropped at the end of this iteration.
        block.wait();
    }
}
367
#[test]
fn multi_threadpool() {
    use tokio::sync::oneshot;

    let (rt1, rt2) = (rt(), rt());

    let (signal_tx, signal_rx) = oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    // A task on `rt2` waits on a oneshot that a task on `rt1` completes,
    // exercising wakeups that cross runtime boundaries.
    rt2.spawn(async move {
        signal_rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    rt1.spawn(async move {
        signal_tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}
389
// When `block_in_place` returns, it attempts to reclaim the yielded runtime
// worker. In that case, the remainder of the task runs on the runtime worker
// and must take part in the cooperative task budgeting system.
//
// This test checks that, when this happens, receiving from a channel still
// yields occasionally even though values are ready to be received.
#[test]
fn coop_and_block_in_place() {
    // Limiting the blocking pool to one thread prevents another thread from
    // claiming the runtime worker yielded by `block_in_place`. It also
    // guarantees that the same thread reclaims the worker when the
    // `block_in_place` call ends.
    let rt = tokio::runtime::Builder::new_multi_thread()
        .max_blocking_threads(1)
        .build()
        .unwrap();

    rt.block_on(async move {
        const CAPACITY: usize = 1024;
        let (tx, mut rx) = tokio::sync::mpsc::channel(CAPACITY);

        // Fill the channel to capacity, then close it by dropping the sender.
        for _ in 0..CAPACITY {
            tx.send(()).await.unwrap();
        }
        drop(tx);

        let handle = tokio::spawn(async move {
            // Block in place without doing anything.
            tokio::task::block_in_place(|| {});

            // Drain the channel with a fresh `recv` future per item. The coop
            // budget must force a `Pending` before the channel reports that it
            // is closed (`None`).
            poll_fn(|cx| loop {
                let recv = rx.recv();
                tokio::pin!(recv);

                match recv.as_mut().poll(cx) {
                    Poll::Pending => return Poll::Ready(()),
                    Poll::Ready(Some(())) => {}
                    Poll::Ready(None) => panic!("did not yield"),
                }
            })
            .await
        });

        handle.await.unwrap();
    });
}
444
#[test]
fn yield_after_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .build()
        .unwrap();

    rt.block_on(async {
        let task = tokio::spawn(async move {
            // Block in place, then drive a nested current-thread runtime
            // from inside the blocking section.
            tokio::task::block_in_place(|| {
                let inner = tokio::runtime::Builder::new_current_thread()
                    .build()
                    .unwrap();

                inner.block_on(async {});
            });

            // Back on the worker: yield once, then complete.
            tokio::task::yield_now().await;
        });

        task.await.unwrap()
    });
}
470
// Building a runtime whose blocking pool is capped at one thread must not panic.
#[test]
fn max_blocking_threads() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(1);
    let _rt = builder.build().unwrap();
}
479
// A blocking pool of zero threads is rejected with a panic.
#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.max_blocking_threads(0);
    let _rt = builder.build().unwrap();
}
488
489#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
490async fn hang_on_shutdown() {
491 let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
492 tokio::spawn(async move {
493 tokio::task::block_in_place(|| sync_rx.recv().ok());
494 });
495
496 tokio::spawn(async {
497 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
498 drop(sync_tx);
499 });
500 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
501}
502
/// Demonstrates tokio-rs/tokio#3869
///
/// Two never-completing futures share one slot holding a `Waker`:
/// - the first (`put_waker: true`) stores its waker every time it is polled;
/// - the second (`put_waker: false`) takes that waker and wakes it from its
///   `Drop` impl.
///
/// Both are spawned and the runtime runs for about 20ms. `rt` is then dropped
/// at the end of the test with both tasks still pending, so one task is woken
/// from inside another task's destructor while the runtime is shutting down.
#[test]
fn wake_during_shutdown() {
    // State shared by the pair of futures created by `MyFuture::new`.
    struct Shared {
        // Waker most recently stored by the `put_waker: true` future.
        waker: Option<Waker>,
    }

    // A future that always returns `Pending`.
    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        // `true`: store the current waker on every poll.
        // `false`: wake the stored waker when this future is dropped.
        put_waker: bool,
    }

    impl MyFuture {
        // Returns a (waker-storing, waker-waking) pair backed by the same
        // `Shared` state.
        fn new() -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Shared { waker: None }));
            let f1 = MyFuture {
                shared: shared.clone(),
                put_waker: true,
            };
            let f2 = MyFuture {
                shared,
                put_waker: false,
            };
            (f1, f2)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let me = Pin::into_inner(self);
            let mut lock = me.shared.lock().unwrap();
            if me.put_waker {
                lock.waker = Some(cx.waker().clone());
            }
            // Never completes; these tasks only end when the runtime drops them.
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            let mut lock = self.shared.lock().unwrap();
            if !self.put_waker {
                // Wakes the other task while this one is being destroyed.
                // The `unwrap` panics if no waker was ever stored.
                // NOTE(review): this relies on `f1` having been polled during
                // the 20ms `block_on` below, before shutdown starts.
                lock.waker.take().unwrap().wake();
            }
            drop(lock);
        }
    }

    let rt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (f1, f2) = MyFuture::new();

    rt.spawn(f1);
    rt.spawn(f2);

    // Give the worker time to poll both futures. `rt` is then dropped at the
    // end of this function with both tasks still pending.
    rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}
566
567#[should_panic]
568#[tokio::test]
569async fn test_block_in_place1() {
570 tokio::task::block_in_place(|| {});
571}
572
573#[tokio::test(flavor = "multi_thread")]
574async fn test_block_in_place2() {
575 tokio::task::block_in_place(|| {});
576}
577
578#[should_panic]
579#[tokio::main(flavor = "current_thread")]
580#[test]
581async fn test_block_in_place3() {
582 tokio::task::block_in_place(|| {});
583}
584
585#[tokio::main]
586#[test]
587async fn test_block_in_place4() {
588 tokio::task::block_in_place(|| {});
589}
590
// Repro for tokio-rs/tokio#5239: `block_in_place` called inside a
// `Handle::block_on` that is itself running inside `block_in_place`.
#[test]
fn test_nested_block_in_place_with_block_on_between() {
    let rt = runtime::Builder::new_multi_thread()
        .worker_threads(1)
        // Needs to be more than 0
        .max_blocking_threads(1)
        .build()
        .unwrap();

    // The original bug was a race condition, so repeat the scenario many times
    // to make a regression likely to show up.
    (0..100).for_each(|_| {
        let handle = rt.handle().clone();

        rt.block_on(async move {
            let outer = tokio::spawn(async move {
                tokio::task::block_in_place(|| {
                    handle.block_on(async {
                        tokio::task::block_in_place(|| {});
                    })
                })
            });

            outer.await.unwrap()
        });
    });
}
618
// Testing the tuning logic is tricky: it is inherently timing-based and more a
// heuristic than an exact behavior. This test checks that the interval changes
// over time based on load. There are no assertions; completion is sufficient.
// If there is a regression, this test will hang. In theory we could add time
// limits, but they would be likely to fail on CI.
#[test]
#[cfg(not(tokio_no_tuning_tests))]
fn test_tuning() {
    use std::sync::atomic::AtomicBool;
    use std::time::Duration;

    /// Cap on outstanding sampling tasks, measured by the strong count of the
    /// `interval` Arc, so the injection queue is not flooded without bound.
    const MAX_SAMPLERS: usize = 5_000;
    /// Consecutive in-range samples required before a phase is considered
    /// done. Jitter in the tuning can produce one "good" value without the
    /// runtime being in a representative good state.
    const GOOD_ROUNDS: usize = 3;

    let rt = runtime::Builder::new_multi_thread()
        .worker_threads(1)
        .build()
        .unwrap();

    /// Self-respawning task body. While `flag` is set, it optionally stalls
    /// the worker, bumps `counter`, and spawns the next iteration.
    fn iter(flag: Arc<AtomicBool>, counter: Arc<AtomicUsize>, stall: bool) {
        if flag.load(Relaxed) {
            if stall {
                std::thread::sleep(Duration::from_micros(5));
            }

            counter.fetch_add(1, Relaxed);
            tokio::spawn(async move { iter(flag, counter, stall) });
        }
    }

    /// Hammers the injection queue with sampling tasks from the test thread
    /// until `in_range(interval)` holds for `GOOD_ROUNDS` consecutive checks.
    ///
    /// Each sampling task stores into `interval` the number of `iter` runs
    /// recorded in `counter` since the previous sample, then resets `counter`.
    fn hammer_until(
        rt: &runtime::Runtime,
        counter: &Arc<AtomicUsize>,
        interval: &Arc<AtomicUsize>,
        in_range: impl Fn(usize) -> bool,
    ) {
        let mut good = 0;
        loop {
            if in_range(interval.load(Relaxed)) {
                good += 1;
            } else {
                good = 0;
            }

            if good == GOOD_ROUNDS {
                break;
            }

            if Arc::strong_count(interval) < MAX_SAMPLERS {
                let counter = Arc::clone(counter);
                let interval = Arc::clone(interval);

                rt.spawn(async move {
                    let prev = counter.swap(0, Relaxed);
                    interval.store(prev, Relaxed);
                });
            }

            // Always yield, even when at the sampler cap, so this loop never
            // busy-spins the test thread.
            std::thread::yield_now();
        }
    }

    // Phase 1: a stalling task keeps the worker busy.
    let flag = Arc::new(AtomicBool::new(true));
    let counter = Arc::new(AtomicUsize::new(61));
    let interval = Arc::new(AtomicUsize::new(61));

    {
        let flag = flag.clone();
        let counter = counter.clone();
        rt.spawn(async move { iter(flag, counter, true) });
    }

    // Hammer the injection queue until the interval drops.
    hammer_until(&rt, &counter, &interval, |curr| curr <= 8);

    flag.store(false, Relaxed);

    // Wait until every outstanding sampling task has released `interval`.
    let w = Arc::downgrade(&interval);
    drop(interval);

    while w.strong_count() > 0 {
        std::thread::sleep(Duration::from_micros(500));
    }

    // Phase 2: run it again with a faster, non-stalling task.
    let flag = Arc::new(AtomicBool::new(true));
    // Set it high; we know it shouldn't ever really be this high.
    let counter = Arc::new(AtomicUsize::new(10_000));
    let interval = Arc::new(AtomicUsize::new(10_000));

    {
        let flag = flag.clone();
        let counter = counter.clone();
        rt.spawn(async move { iter(flag, counter, false) });
    }

    // Hammer the injection queue until the interval reaches the expected range.
    hammer_until(&rt, &counter, &interval, |curr| curr > 32 && curr <= 1_000);

    flag.store(false, Relaxed);
}
738
739fn rt() -> runtime::Runtime {
740 runtime::Runtime::new().unwrap()
741}
742
#[cfg(tokio_unstable)]
mod unstable {
    use super::*;

    #[test]
    fn test_disable_lifo_slot() {
        use std::sync::mpsc::{channel, RecvTimeoutError};
        use std::time::Duration;

        let rt = runtime::Builder::new_multi_thread()
            .disable_lifo_slot()
            .worker_threads(2)
            .build()
            .unwrap();

        // Spawn a background thread that pokes the runtime periodically.
        //
        // This is necessary because we may end up triggering the issue in:
        // <https://github.com/tokio-rs/tokio/issues/4730>
        //
        // Spawning a task will wake up the second worker, which will then steal
        // the task. However, the steal will fail if the task is in the LIFO
        // slot, because the LIFO slot cannot be stolen.
        //
        // Note that this only happens rarely. Most of the time, this thread is
        // not necessary.
        let (stop_poker, poker_rx) = channel::<()>();
        let poke_handle = rt.handle().clone();
        let poker = std::thread::spawn(move || {
            let period = Duration::from_secs(1);
            loop {
                match poker_rx.recv_timeout(period) {
                    Err(RecvTimeoutError::Timeout) => {
                        poke_handle.spawn(async {});
                    }
                    _ => break,
                }
            }
        });

        rt.block_on(async {
            tokio::spawn(async {
                // Spawn another task and block this thread until it completes.
                // If the LIFO slot is used, the test never completes.
                futures::executor::block_on(tokio::spawn(async {})).unwrap();
            })
            .await
            .unwrap();
        });

        drop(stop_poker);
        poker.join().unwrap();
    }

    #[test]
    fn runtime_id_is_same() {
        let rt = rt();

        let (first, second) = (rt.handle(), rt.handle());

        assert_eq!(first.id(), second.id());
    }

    #[test]
    fn runtime_ids_different() {
        let (rt1, rt2) = (rt(), rt());

        assert_ne!(rt1.handle().id(), rt2.handle().id());
    }
}
809