1//! Thread pool for blocking operations
2
3use crate::loom::sync::{Arc, Condvar, Mutex};
4use crate::loom::thread;
5use crate::runtime::blocking::schedule::BlockingSchedule;
6use crate::runtime::blocking::{shutdown, BlockingTask};
7use crate::runtime::builder::ThreadNameFn;
8use crate::runtime::task::{self, JoinHandle};
9use crate::runtime::{Builder, Callback, Handle};
10
11use std::collections::{HashMap, VecDeque};
12use std::fmt;
13use std::io;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::time::Duration;
16
/// Owner of the blocking thread pool.
///
/// Holds the `Spawner` used to submit tasks and the receiving half of the
/// shutdown channel. Every worker thread carries a clone of the matching
/// sender and drops it when it exits; `shutdown` waits on this receiver
/// before joining the worker threads.
pub(crate) struct BlockingPool {
    /// Handle used to schedule tasks onto the pool.
    spawner: Spawner,
    /// Receiver that `shutdown` waits on (optionally with a timeout).
    shutdown_rx: shutdown::Receiver,
}
21
/// Cloneable handle for submitting tasks to the blocking pool.
///
/// All clones share the same reference-counted `Inner` state.
#[derive(Clone)]
pub(crate) struct Spawner {
    /// Pool configuration and state shared by all clones and worker threads.
    inner: Arc<Inner>,
}
26
/// Atomic counters describing the pool; all start at zero via `Default`.
#[derive(Default)]
pub(crate) struct SpawnerMetrics {
    /// Worker threads currently alive: incremented after a successful thread
    /// spawn, decremented when a worker exits.
    num_threads: AtomicUsize,
    /// Worker threads currently counted as idle (waiting for work).
    num_idle_threads: AtomicUsize,
    /// Tasks pushed onto the queue that have not been popped yet.
    queue_depth: AtomicUsize,
}
33
34impl SpawnerMetrics {
35 fn num_threads(&self) -> usize {
36 self.num_threads.load(Ordering::Relaxed)
37 }
38
39 fn num_idle_threads(&self) -> usize {
40 self.num_idle_threads.load(Ordering::Relaxed)
41 }
42
43 cfg_metrics! {
44 fn queue_depth(&self) -> usize {
45 self.queue_depth.load(Ordering::Relaxed)
46 }
47 }
48
49 fn inc_num_threads(&self) {
50 self.num_threads.fetch_add(1, Ordering::Relaxed);
51 }
52
53 fn dec_num_threads(&self) {
54 self.num_threads.fetch_sub(1, Ordering::Relaxed);
55 }
56
57 fn inc_num_idle_threads(&self) {
58 self.num_idle_threads.fetch_add(1, Ordering::Relaxed);
59 }
60
61 fn dec_num_idle_threads(&self) -> usize {
62 self.num_idle_threads.fetch_sub(1, Ordering::Relaxed)
63 }
64
65 fn inc_queue_depth(&self) {
66 self.queue_depth.fetch_add(1, Ordering::Relaxed);
67 }
68
69 fn dec_queue_depth(&self) {
70 self.queue_depth.fetch_sub(1, Ordering::Relaxed);
71 }
72}
73
/// Pool configuration and state shared (behind an `Arc`) by every `Spawner`
/// clone and every worker thread.
struct Inner {
    /// State shared between worker threads.
    shared: Mutex<Shared>,

    /// Pool threads wait on this.
    condvar: Condvar,

    /// Spawned threads use this name.
    thread_name: ThreadNameFn,

    /// Spawned thread stack size.
    stack_size: Option<usize>,

    /// Call after a thread starts.
    after_start: Option<Callback>,

    /// Call before a thread stops.
    before_stop: Option<Callback>,

    /// Maximum number of threads.
    thread_cap: usize,

    /// Customizable wait timeout.
    keep_alive: Duration,

    /// Metrics about the pool.
    metrics: SpawnerMetrics,
}
102
/// Mutable pool state, guarded by the `Inner::shared` mutex.
struct Shared {
    /// Tasks waiting to be picked up by a worker thread.
    queue: VecDeque<Task>,
    /// Number of wakeups issued by `spawn_task` that no worker has consumed
    /// yet. A waking worker uses it to tell a real notification from a
    /// spurious wakeup.
    num_notify: u32,
    /// Set by `BlockingPool::shutdown`; once set, newly submitted tasks are
    /// rejected and idle workers stop waiting.
    shutdown: bool,
    /// Sender half of the shutdown channel. It is cloned into every spawned
    /// worker thread and cleared to `None` when shutdown begins.
    shutdown_tx: Option<shutdown::Sender>,
    /// Prior to shutdown, we clean up JoinHandles by having each timed-out
    /// thread join on the previous timed-out thread. This is not strictly
    /// necessary but helps avoid Valgrind false positives, see
    /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666>
    /// for more information.
    last_exiting_thread: Option<thread::JoinHandle<()>>,
    /// This holds the JoinHandles for all running threads; on shutdown, the thread
    /// calling shutdown handles joining on these.
    worker_threads: HashMap<usize, thread::JoinHandle<()>>,
    /// This is a counter used to iterate worker_threads in a consistent order (for loom's
    /// benefit). Its current value is used as the id of the next spawned worker.
    worker_thread_index: usize,
}
121
/// A unit of work queued on the blocking pool.
pub(crate) struct Task {
    /// The underlying runtime task to run or shut down.
    task: task::UnownedTask<BlockingSchedule>,
    /// Whether the task must still be run when the queue is drained during
    /// shutdown.
    mandatory: Mandatory,
}
126
/// Tells a draining worker what to do with a queued task once shutdown has
/// begun.
#[derive(PartialEq, Eq)]
pub(crate) enum Mandatory {
    /// The task is still run while the queue is drained at shutdown.
    #[cfg_attr(not(fs), allow(dead_code))]
    Mandatory,
    /// The task is shut down instead of run while the queue is drained at
    /// shutdown.
    NonMandatory,
}
133
/// Reasons why a task could not be scheduled on the blocking pool.
pub(crate) enum SpawnError {
    /// Pool is shutting down and the task was not scheduled
    ShuttingDown,
    /// There are no worker threads available to take the task
    /// and the OS failed to spawn a new one
    NoThreads(io::Error),
}

impl From<SpawnError> for io::Error {
    /// Converts a spawn failure into an `io::Error`.
    ///
    /// An OS thread-spawn failure is passed through unchanged; a shutdown is
    /// reported as an `ErrorKind::Other` error with a fixed message.
    fn from(err: SpawnError) -> Self {
        const SHUTTING_DOWN: &str = "blocking pool shutting down";

        match err {
            SpawnError::NoThreads(os_err) => os_err,
            SpawnError::ShuttingDown => io::Error::new(io::ErrorKind::Other, SHUTTING_DOWN),
        }
    }
}
152
153impl Task {
154 pub(crate) fn new(task: task::UnownedTask<BlockingSchedule>, mandatory: Mandatory) -> Task {
155 Task { task, mandatory }
156 }
157
158 fn run(self) {
159 self.task.run();
160 }
161
162 fn shutdown_or_run_if_mandatory(self) {
163 match self.mandatory {
164 Mandatory::NonMandatory => self.task.shutdown(),
165 Mandatory::Mandatory => self.task.run(),
166 }
167 }
168}
169
/// Default time an idle worker thread waits for new work before exiting;
/// used when the runtime builder does not specify a keep-alive.
const KEEP_ALIVE: Duration = Duration::from_secs(10);
171
172/// Runs the provided function on an executor dedicated to blocking operations.
173/// Tasks will be scheduled as non-mandatory, meaning they may not get executed
174/// in case of runtime shutdown.
175#[track_caller]
176#[cfg_attr(target_os = "wasi", allow(dead_code))]
177pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
178where
179 F: FnOnce() -> R + Send + 'static,
180 R: Send + 'static,
181{
182 let rt = Handle::current();
183 rt.spawn_blocking(func)
184}
185
186cfg_fs! {
187 #[cfg_attr(any(
188 all(loom, not(test)), // the function is covered by loom tests
189 test
190 ), allow(dead_code))]
191 /// Runs the provided function on an executor dedicated to blocking
192 /// operations. Tasks will be scheduled as mandatory, meaning they are
193 /// guaranteed to run unless a shutdown is already taking place. In case a
194 /// shutdown is already taking place, `None` will be returned.
195 pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
196 where
197 F: FnOnce() -> R + Send + 'static,
198 R: Send + 'static,
199 {
200 let rt = Handle::current();
201 rt.inner.blocking_spawner().spawn_mandatory_blocking(&rt, func)
202 }
203}
204
205// ===== impl BlockingPool =====
206
207impl BlockingPool {
208 pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
209 let (shutdown_tx, shutdown_rx) = shutdown::channel();
210 let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);
211
212 BlockingPool {
213 spawner: Spawner {
214 inner: Arc::new(Inner {
215 shared: Mutex::new(Shared {
216 queue: VecDeque::new(),
217 num_notify: 0,
218 shutdown: false,
219 shutdown_tx: Some(shutdown_tx),
220 last_exiting_thread: None,
221 worker_threads: HashMap::new(),
222 worker_thread_index: 0,
223 }),
224 condvar: Condvar::new(),
225 thread_name: builder.thread_name.clone(),
226 stack_size: builder.thread_stack_size,
227 after_start: builder.after_start.clone(),
228 before_stop: builder.before_stop.clone(),
229 thread_cap,
230 keep_alive,
231 metrics: SpawnerMetrics::default(),
232 }),
233 },
234 shutdown_rx,
235 }
236 }
237
238 pub(crate) fn spawner(&self) -> &Spawner {
239 &self.spawner
240 }
241
242 pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
243 let mut shared = self.spawner.inner.shared.lock();
244
245 // The function can be called multiple times. First, by explicitly
246 // calling `shutdown` then by the drop handler calling `shutdown`. This
247 // prevents shutting down twice.
248 if shared.shutdown {
249 return;
250 }
251
252 shared.shutdown = true;
253 shared.shutdown_tx = None;
254 self.spawner.inner.condvar.notify_all();
255
256 let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
257 let workers = std::mem::take(&mut shared.worker_threads);
258
259 drop(shared);
260
261 if self.shutdown_rx.wait(timeout) {
262 let _ = last_exited_thread.map(thread::JoinHandle::join);
263
264 // Loom requires that execution be deterministic, so sort by thread ID before joining.
265 // (HashMaps use a randomly-seeded hash function, so the order is nondeterministic)
266 let mut workers: Vec<(usize, thread::JoinHandle<()>)> = workers.into_iter().collect();
267 workers.sort_by_key(|(id, _)| *id);
268
269 for (_id, handle) in workers {
270 let _ = handle.join();
271 }
272 }
273 }
274}
275
276impl Drop for BlockingPool {
277 fn drop(&mut self) {
278 self.shutdown(None);
279 }
280}
281
282impl fmt::Debug for BlockingPool {
283 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
284 fmt.debug_struct("BlockingPool").finish()
285 }
286}
287
288// ===== impl Spawner =====
289
290impl Spawner {
291 #[track_caller]
292 pub(crate) fn spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R>
293 where
294 F: FnOnce() -> R + Send + 'static,
295 R: Send + 'static,
296 {
297 let (join_handle, spawn_result) =
298 if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
299 self.spawn_blocking_inner(Box::new(func), Mandatory::NonMandatory, None, rt)
300 } else {
301 self.spawn_blocking_inner(func, Mandatory::NonMandatory, None, rt)
302 };
303
304 match spawn_result {
305 Ok(()) => join_handle,
306 // Compat: do not panic here, return the join_handle even though it will never resolve
307 Err(SpawnError::ShuttingDown) => join_handle,
308 Err(SpawnError::NoThreads(e)) => {
309 panic!("OS can't spawn worker thread: {}", e)
310 }
311 }
312 }
313
314 cfg_fs! {
315 #[track_caller]
316 #[cfg_attr(any(
317 all(loom, not(test)), // the function is covered by loom tests
318 test
319 ), allow(dead_code))]
320 pub(crate) fn spawn_mandatory_blocking<F, R>(&self, rt: &Handle, func: F) -> Option<JoinHandle<R>>
321 where
322 F: FnOnce() -> R + Send + 'static,
323 R: Send + 'static,
324 {
325 let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
326 self.spawn_blocking_inner(
327 Box::new(func),
328 Mandatory::Mandatory,
329 None,
330 rt,
331 )
332 } else {
333 self.spawn_blocking_inner(
334 func,
335 Mandatory::Mandatory,
336 None,
337 rt,
338 )
339 };
340
341 if spawn_result.is_ok() {
342 Some(join_handle)
343 } else {
344 None
345 }
346 }
347 }
348
349 #[track_caller]
350 pub(crate) fn spawn_blocking_inner<F, R>(
351 &self,
352 func: F,
353 is_mandatory: Mandatory,
354 name: Option<&str>,
355 rt: &Handle,
356 ) -> (JoinHandle<R>, Result<(), SpawnError>)
357 where
358 F: FnOnce() -> R + Send + 'static,
359 R: Send + 'static,
360 {
361 let fut = BlockingTask::new(func);
362 let id = task::Id::next();
363 #[cfg(all(tokio_unstable, feature = "tracing"))]
364 let fut = {
365 use tracing::Instrument;
366 let location = std::panic::Location::caller();
367 let span = tracing::trace_span!(
368 target: "tokio::task::blocking",
369 "runtime.spawn",
370 kind = %"blocking",
371 task.name = %name.unwrap_or_default(),
372 task.id = id.as_u64(),
373 "fn" = %std::any::type_name::<F>(),
374 loc.file = location.file(),
375 loc.line = location.line(),
376 loc.col = location.column(),
377 );
378 fut.instrument(span)
379 };
380
381 #[cfg(not(all(tokio_unstable, feature = "tracing")))]
382 let _ = name;
383
384 let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id);
385
386 let spawned = self.spawn_task(Task::new(task, is_mandatory), rt);
387 (handle, spawned)
388 }
389
390 fn spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError> {
391 let mut shared = self.inner.shared.lock();
392
393 if shared.shutdown {
394 // Shutdown the task: it's fine to shutdown this task (even if
395 // mandatory) because it was scheduled after the shutdown of the
396 // runtime began.
397 task.task.shutdown();
398
399 // no need to even push this task; it would never get picked up
400 return Err(SpawnError::ShuttingDown);
401 }
402
403 shared.queue.push_back(task);
404 self.inner.metrics.inc_queue_depth();
405
406 if self.inner.metrics.num_idle_threads() == 0 {
407 // No threads are able to process the task.
408
409 if self.inner.metrics.num_threads() == self.inner.thread_cap {
410 // At max number of threads
411 } else {
412 assert!(shared.shutdown_tx.is_some());
413 let shutdown_tx = shared.shutdown_tx.clone();
414
415 if let Some(shutdown_tx) = shutdown_tx {
416 let id = shared.worker_thread_index;
417
418 match self.spawn_thread(shutdown_tx, rt, id) {
419 Ok(handle) => {
420 self.inner.metrics.inc_num_threads();
421 shared.worker_thread_index += 1;
422 shared.worker_threads.insert(id, handle);
423 }
424 Err(ref e)
425 if is_temporary_os_thread_error(e)
426 && self.inner.metrics.num_threads() > 0 =>
427 {
428 // OS temporarily failed to spawn a new thread.
429 // The task will be picked up eventually by a currently
430 // busy thread.
431 }
432 Err(e) => {
433 // The OS refused to spawn the thread and there is no thread
434 // to pick up the task that has just been pushed to the queue.
435 return Err(SpawnError::NoThreads(e));
436 }
437 }
438 }
439 }
440 } else {
441 // Notify an idle worker thread. The notification counter
442 // is used to count the needed amount of notifications
443 // exactly. Thread libraries may generate spurious
444 // wakeups, this counter is used to keep us in a
445 // consistent state.
446 self.inner.metrics.dec_num_idle_threads();
447 shared.num_notify += 1;
448 self.inner.condvar.notify_one();
449 }
450
451 Ok(())
452 }
453
454 fn spawn_thread(
455 &self,
456 shutdown_tx: shutdown::Sender,
457 rt: &Handle,
458 id: usize,
459 ) -> std::io::Result<thread::JoinHandle<()>> {
460 let mut builder = thread::Builder::new().name((self.inner.thread_name)());
461
462 if let Some(stack_size) = self.inner.stack_size {
463 builder = builder.stack_size(stack_size);
464 }
465
466 let rt = rt.clone();
467
468 builder.spawn(move || {
469 // Only the reference should be moved into the closure
470 let _enter = rt.enter();
471 rt.inner.blocking_spawner().inner.run(id);
472 drop(shutdown_tx);
473 })
474 }
475}
476
477cfg_metrics! {
478 impl Spawner {
479 pub(crate) fn num_threads(&self) -> usize {
480 self.inner.metrics.num_threads()
481 }
482
483 pub(crate) fn num_idle_threads(&self) -> usize {
484 self.inner.metrics.num_idle_threads()
485 }
486
487 pub(crate) fn queue_depth(&self) -> usize {
488 self.inner.metrics.queue_depth()
489 }
490 }
491}
492
// Reports whether a thread-spawn failure is considered temporary: true only
// when the error kind is `WouldBlock`.
#[inline]
fn is_temporary_os_thread_error(error: &std::io::Error) -> bool {
    error.kind() == std::io::ErrorKind::WouldBlock
}
498
impl Inner {
    /// Main loop of a worker thread.
    ///
    /// `worker_thread_id` is the key under which this thread's `JoinHandle`
    /// is stored in `Shared::worker_threads`.
    ///
    /// The worker alternates between two states:
    ///
    /// * BUSY: pops tasks off the queue and runs each one with the pool lock
    ///   released.
    /// * IDLE: waits on the condvar for up to `keep_alive`.
    ///
    /// The loop ends either when an idle wait times out (and the pool is not
    /// shutting down) or when shutdown is observed, in which case the queue
    /// is drained first. The `after_start` and `before_stop` callbacks, if
    /// set, are invoked at the beginning and near the end, both without the
    /// lock held.
    fn run(&self, worker_thread_id: usize) {
        if let Some(f) = &self.after_start {
            f();
        }

        let mut shared = self.shared.lock();
        // Handle of the previously timed-out thread, if this thread ends up
        // being responsible for joining it (set only on the timeout path).
        let mut join_on_thread = None;

        'main: loop {
            // BUSY
            while let Some(task) = shared.queue.pop_front() {
                self.metrics.dec_queue_depth();
                // Release the lock while the task runs.
                drop(shared);
                task.run();

                shared = self.shared.lock();
            }

            // IDLE
            self.metrics.inc_num_idle_threads();

            while !shared.shutdown {
                // NOTE(review): the `unwrap` presumably only fails on a
                // poisoned lock; confirm against the `Condvar` wrapper.
                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();

                shared = lock_result.0;
                let timeout_result = lock_result.1;

                if shared.num_notify != 0 {
                    // We have received a legitimate wakeup,
                    // acknowledge it by decrementing the counter
                    // and transition to the BUSY state.
                    // (The notifier already decremented the idle count on
                    // our behalf.)
                    shared.num_notify -= 1;
                    break;
                }

                // Even if the condvar "timed out", if the pool is entering the
                // shutdown phase, we want to perform the cleanup logic.
                if !shared.shutdown && timeout_result.timed_out() {
                    // We'll join the prior timed-out thread's JoinHandle after dropping the lock.
                    // This isn't done when shutting down, because the thread calling shutdown will
                    // handle joining everything.
                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);

                    break 'main;
                }

                // Spurious wakeup detected, go back to sleep.
            }

            if shared.shutdown {
                // Drain the queue: mandatory tasks are run, the others are
                // shut down. The lock is released around each task.
                while let Some(task) = shared.queue.pop_front() {
                    self.metrics.dec_queue_depth();
                    drop(shared);

                    task.shutdown_or_run_if_mandatory();

                    shared = self.shared.lock();
                }

                // Work was produced, and we "took" it (by decrementing num_notify).
                // This means that num_idle was decremented once for our wakeup.
                // But, since we are exiting, we need to "undo" that, as we'll stay idle.
                self.metrics.inc_num_idle_threads();
                // NOTE: Technically we should also do num_notify++ and notify again,
                // but since we're shutting down anyway, that won't be necessary.
                break;
            }
        }

        // Thread exit
        self.metrics.dec_num_threads();

        // num_idle should now be tracked exactly, panic
        // with a descriptive message if it is not the
        // case.
        // `prev_idle` is the idle count before the decrement; if the counter
        // wrapped around, the value read back afterwards would be larger.
        let prev_idle = self.metrics.dec_num_idle_threads();
        assert!(
            prev_idle >= self.metrics.num_idle_threads(),
            "num_idle_threads underflowed on thread exit"
        );

        // When shutting down and no worker threads remain, signal the condvar
        // once more.
        if shared.shutdown && self.metrics.num_threads() == 0 {
            self.condvar.notify_one();
        }

        drop(shared);

        if let Some(f) = &self.before_stop {
            f();
        }

        // Join the previously timed-out thread (if any) now that the lock has
        // been released.
        if let Some(handle) = join_on_thread {
            let _ = handle.join();
        }
    }
}
598
599impl fmt::Debug for Spawner {
600 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
601 fmt.debug_struct("blocking::Spawner").finish()
602 }
603}
604