1// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2// file at the top-level directory of this distribution and at
3// http://rust-lang.org/COPYRIGHT.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
11//! A thread pool used to execute functions in parallel.
12//!
13//! Spawns a specified number of worker threads and replenishes the pool if any worker threads
14//! panic.
15//!
16//! # Examples
17//!
18//! ## Synchronized with a channel
19//!
//! Every job sends one message over the channel; the main thread then collects all of them with `take()`.
21//!
22//! ```
23//! use threadpool::ThreadPool;
24//! use std::sync::mpsc::channel;
25//!
26//! let n_workers = 4;
27//! let n_jobs = 8;
28//! let pool = ThreadPool::new(n_workers);
29//!
30//! let (tx, rx) = channel();
31//! for _ in 0..n_jobs {
32//! let tx = tx.clone();
33//! pool.execute(move|| {
34//! tx.send(1).expect("channel will be there waiting for the pool");
35//! });
36//! }
37//!
38//! assert_eq!(rx.iter().take(n_jobs).fold(0, |a, b| a + b), 8);
39//! ```
40//!
41//! ## Synchronized with a barrier
42//!
//! Keep in mind that if a barrier synchronizes more jobs than there are workers in the pool,
//! you will end up with a [deadlock](https://en.wikipedia.org/wiki/Deadlock)
//! at the barrier, which is [not considered unsafe](
//! https://doc.rust-lang.org/reference/behavior-not-considered-unsafe.html).
47//!
48//! ```
49//! use threadpool::ThreadPool;
50//! use std::sync::{Arc, Barrier};
51//! use std::sync::atomic::{AtomicUsize, Ordering};
52//!
53//! // create at least as many workers as jobs or you will deadlock yourself
54//! let n_workers = 42;
55//! let n_jobs = 23;
56//! let pool = ThreadPool::new(n_workers);
57//! let an_atomic = Arc::new(AtomicUsize::new(0));
58//!
59//! assert!(n_jobs <= n_workers, "too many jobs, will deadlock");
60//!
61//! // create a barrier that waits for all jobs plus the starter thread
62//! let barrier = Arc::new(Barrier::new(n_jobs + 1));
63//! for _ in 0..n_jobs {
64//! let barrier = barrier.clone();
65//! let an_atomic = an_atomic.clone();
66//!
67//! pool.execute(move|| {
68//! // do the heavy work
69//! an_atomic.fetch_add(1, Ordering::Relaxed);
70//!
71//! // then wait for the other threads
72//! barrier.wait();
73//! });
74//! }
75//!
76//! // wait for the threads to finish the work
77//! barrier.wait();
78//! assert_eq!(an_atomic.load(Ordering::SeqCst), /* n_jobs = */ 23);
79//! ```
80
81extern crate num_cpus;
82
83use std::fmt;
84use std::sync::atomic::{AtomicUsize, Ordering};
85use std::sync::mpsc::{channel, Receiver, Sender};
86use std::sync::{Arc, Condvar, Mutex};
87use std::thread;
88
89trait FnBox {
90 fn call_box(self: Box<Self>);
91}
92
93impl<F: FnOnce()> FnBox for F {
94 fn call_box(self: Box<F>) {
95 (*self)()
96 }
97}
98
99type Thunk<'a> = Box<FnBox + Send + 'a>;
100
101struct Sentinel<'a> {
102 shared_data: &'a Arc<ThreadPoolSharedData>,
103 active: bool,
104}
105
106impl<'a> Sentinel<'a> {
107 fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
108 Sentinel {
109 shared_data: shared_data,
110 active: true,
111 }
112 }
113
114 /// Cancel and destroy this sentinel.
115 fn cancel(mut self) {
116 self.active = false;
117 }
118}
119
120impl<'a> Drop for Sentinel<'a> {
121 fn drop(&mut self) {
122 if self.active {
123 self.shared_data.active_count.fetch_sub(val:1, order:Ordering::SeqCst);
124 if thread::panicking() {
125 self.shared_data.panic_count.fetch_add(val:1, order:Ordering::SeqCst);
126 }
127 self.shared_data.no_work_notify_all();
128 spawn_in_pool(self.shared_data.clone())
129 }
130 }
131}
132
/// Configures and creates a [`ThreadPool`].
///
/// Three settings are available:
///
/// * `num_threads`: the maximum number of threads the built [`ThreadPool`] keeps alive at any
///   given moment
/// * `thread_name`: the name given to every thread the built [`ThreadPool`] spawns
/// * `thread_stack_size`: the stack size (in bytes) of every thread the built [`ThreadPool`]
///   spawns
///
/// [`ThreadPool`]: struct.ThreadPool.html
///
/// # Examples
///
/// Build a [`ThreadPool`] that runs at most eight threads at once, each with an 8 MB stack:
///
/// ```
/// let pool = threadpool::Builder::new()
///     .num_threads(8)
///     .thread_stack_size(8_000_000)
///     .build();
/// ```
#[derive(Default, Clone)]
pub struct Builder {
    // `None` means "use the number of CPUs" when the pool is built.
    num_threads: Option<usize>,
    // `None` leaves the spawned threads unnamed.
    thread_name: Option<String>,
    // `None` keeps the standard library's default stack size.
    thread_stack_size: Option<usize>,
}
163
164impl Builder {
165 /// Initiate a new [`Builder`].
166 ///
167 /// [`Builder`]: struct.Builder.html
168 ///
169 /// # Examples
170 ///
171 /// ```
172 /// let builder = threadpool::Builder::new();
173 /// ```
174 pub fn new() -> Builder {
175 Builder {
176 num_threads: None,
177 thread_name: None,
178 thread_stack_size: None,
179 }
180 }
181
182 /// Set the maximum number of worker-threads that will be alive at any given moment by the built
183 /// [`ThreadPool`]. If not specified, defaults the number of threads to the number of CPUs.
184 ///
185 /// [`ThreadPool`]: struct.ThreadPool.html
186 ///
187 /// # Panics
188 ///
189 /// This method will panic if `num_threads` is 0.
190 ///
191 /// # Examples
192 ///
193 /// No more than eight threads will be alive simultaneously for this pool:
194 ///
195 /// ```
196 /// use std::thread;
197 ///
198 /// let pool = threadpool::Builder::new()
199 /// .num_threads(8)
200 /// .build();
201 ///
202 /// for _ in 0..100 {
203 /// pool.execute(|| {
204 /// println!("Hello from a worker thread!")
205 /// })
206 /// }
207 /// ```
208 pub fn num_threads(mut self, num_threads: usize) -> Builder {
209 assert!(num_threads > 0);
210 self.num_threads = Some(num_threads);
211 self
212 }
213
214 /// Set the thread name for each of the threads spawned by the built [`ThreadPool`]. If not
215 /// specified, threads spawned by the thread pool will be unnamed.
216 ///
217 /// [`ThreadPool`]: struct.ThreadPool.html
218 ///
219 /// # Examples
220 ///
221 /// Each thread spawned by this pool will have the name "foo":
222 ///
223 /// ```
224 /// use std::thread;
225 ///
226 /// let pool = threadpool::Builder::new()
227 /// .thread_name("foo".into())
228 /// .build();
229 ///
230 /// for _ in 0..100 {
231 /// pool.execute(|| {
232 /// assert_eq!(thread::current().name(), Some("foo"));
233 /// })
234 /// }
235 /// ```
236 pub fn thread_name(mut self, name: String) -> Builder {
237 self.thread_name = Some(name);
238 self
239 }
240
241 /// Set the stack size (in bytes) for each of the threads spawned by the built [`ThreadPool`].
242 /// If not specified, threads spawned by the threadpool will have a stack size [as specified in
243 /// the `std::thread` documentation][thread].
244 ///
245 /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size
246 /// [`ThreadPool`]: struct.ThreadPool.html
247 ///
248 /// # Examples
249 ///
250 /// Each thread spawned by this pool will have a 4 MB stack:
251 ///
252 /// ```
253 /// let pool = threadpool::Builder::new()
254 /// .thread_stack_size(4_000_000)
255 /// .build();
256 ///
257 /// for _ in 0..100 {
258 /// pool.execute(|| {
259 /// println!("This thread has a 4 MB stack size!");
260 /// })
261 /// }
262 /// ```
263 pub fn thread_stack_size(mut self, size: usize) -> Builder {
264 self.thread_stack_size = Some(size);
265 self
266 }
267
268 /// Finalize the [`Builder`] and build the [`ThreadPool`].
269 ///
270 /// [`Builder`]: struct.Builder.html
271 /// [`ThreadPool`]: struct.ThreadPool.html
272 ///
273 /// # Examples
274 ///
275 /// ```
276 /// let pool = threadpool::Builder::new()
277 /// .num_threads(8)
278 /// .thread_stack_size(4_000_000)
279 /// .build();
280 /// ```
281 pub fn build(self) -> ThreadPool {
282 let (tx, rx) = channel::<Thunk<'static>>();
283
284 let num_threads = self.num_threads.unwrap_or_else(num_cpus::get);
285
286 let shared_data = Arc::new(ThreadPoolSharedData {
287 name: self.thread_name,
288 job_receiver: Mutex::new(rx),
289 empty_condvar: Condvar::new(),
290 empty_trigger: Mutex::new(()),
291 join_generation: AtomicUsize::new(0),
292 queued_count: AtomicUsize::new(0),
293 active_count: AtomicUsize::new(0),
294 max_thread_count: AtomicUsize::new(num_threads),
295 panic_count: AtomicUsize::new(0),
296 stack_size: self.thread_stack_size,
297 });
298
299 // Threadpool threads
300 for _ in 0..num_threads {
301 spawn_in_pool(shared_data.clone());
302 }
303
304 ThreadPool {
305 jobs: tx,
306 shared_data: shared_data,
307 }
308 }
309}
310
311struct ThreadPoolSharedData {
312 name: Option<String>,
313 job_receiver: Mutex<Receiver<Thunk<'static>>>,
314 empty_trigger: Mutex<()>,
315 empty_condvar: Condvar,
316 join_generation: AtomicUsize,
317 queued_count: AtomicUsize,
318 active_count: AtomicUsize,
319 max_thread_count: AtomicUsize,
320 panic_count: AtomicUsize,
321 stack_size: Option<usize>,
322}
323
324impl ThreadPoolSharedData {
325 fn has_work(&self) -> bool {
326 self.queued_count.load(order:Ordering::SeqCst) > 0 || self.active_count.load(order:Ordering::SeqCst) > 0
327 }
328
329 /// Notify all observers joining this pool if there is no more work to do.
330 fn no_work_notify_all(&self) {
331 if !self.has_work() {
332 *self
333 .empty_trigger
334 .lock()
335 .expect(msg:"Unable to notify all joining threads");
336 self.empty_condvar.notify_all();
337 }
338 }
339}
340
341/// Abstraction of a thread pool for basic parallelism.
342pub struct ThreadPool {
343 // How the threadpool communicates with subthreads.
344 //
345 // This is the only such Sender, so when it is dropped all subthreads will
346 // quit.
347 jobs: Sender<Thunk<'static>>,
348 shared_data: Arc<ThreadPoolSharedData>,
349}
350
351impl ThreadPool {
352 /// Creates a new thread pool capable of executing `num_threads` number of jobs concurrently.
353 ///
354 /// # Panics
355 ///
356 /// This function will panic if `num_threads` is 0.
357 ///
358 /// # Examples
359 ///
360 /// Create a new thread pool capable of executing four jobs concurrently:
361 ///
362 /// ```
363 /// use threadpool::ThreadPool;
364 ///
365 /// let pool = ThreadPool::new(4);
366 /// ```
367 pub fn new(num_threads: usize) -> ThreadPool {
368 Builder::new().num_threads(num_threads).build()
369 }
370
371 /// Creates a new thread pool capable of executing `num_threads` number of jobs concurrently.
372 /// Each thread will have the [name][thread name] `name`.
373 ///
374 /// # Panics
375 ///
376 /// This function will panic if `num_threads` is 0.
377 ///
378 /// # Examples
379 ///
380 /// ```rust
381 /// use std::thread;
382 /// use threadpool::ThreadPool;
383 ///
384 /// let pool = ThreadPool::with_name("worker".into(), 2);
385 /// for _ in 0..2 {
386 /// pool.execute(|| {
387 /// assert_eq!(
388 /// thread::current().name(),
389 /// Some("worker")
390 /// );
391 /// });
392 /// }
393 /// pool.join();
394 /// ```
395 ///
396 /// [thread name]: https://doc.rust-lang.org/std/thread/struct.Thread.html#method.name
397 pub fn with_name(name: String, num_threads: usize) -> ThreadPool {
398 Builder::new()
399 .num_threads(num_threads)
400 .thread_name(name)
401 .build()
402 }
403
404 /// **Deprecated: Use [`ThreadPool::with_name`](#method.with_name)**
405 #[inline(always)]
406 #[deprecated(since = "1.4.0", note = "use ThreadPool::with_name")]
407 pub fn new_with_name(name: String, num_threads: usize) -> ThreadPool {
408 Self::with_name(name, num_threads)
409 }
410
411 /// Executes the function `job` on a thread in the pool.
412 ///
413 /// # Examples
414 ///
415 /// Execute four jobs on a thread pool that can run two jobs concurrently:
416 ///
417 /// ```
418 /// use threadpool::ThreadPool;
419 ///
420 /// let pool = ThreadPool::new(2);
421 /// pool.execute(|| println!("hello"));
422 /// pool.execute(|| println!("world"));
423 /// pool.execute(|| println!("foo"));
424 /// pool.execute(|| println!("bar"));
425 /// pool.join();
426 /// ```
427 pub fn execute<F>(&self, job: F)
428 where
429 F: FnOnce() + Send + 'static,
430 {
431 self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
432 self.jobs
433 .send(Box::new(job))
434 .expect("ThreadPool::execute unable to send job into queue.");
435 }
436
437 /// Returns the number of jobs waiting to executed in the pool.
438 ///
439 /// # Examples
440 ///
441 /// ```
442 /// use threadpool::ThreadPool;
443 /// use std::time::Duration;
444 /// use std::thread::sleep;
445 ///
446 /// let pool = ThreadPool::new(2);
447 /// for _ in 0..10 {
448 /// pool.execute(|| {
449 /// sleep(Duration::from_secs(100));
450 /// });
451 /// }
452 ///
453 /// sleep(Duration::from_secs(1)); // wait for threads to start
454 /// assert_eq!(8, pool.queued_count());
455 /// ```
456 pub fn queued_count(&self) -> usize {
457 self.shared_data.queued_count.load(Ordering::Relaxed)
458 }
459
460 /// Returns the number of currently active threads.
461 ///
462 /// # Examples
463 ///
464 /// ```
465 /// use threadpool::ThreadPool;
466 /// use std::time::Duration;
467 /// use std::thread::sleep;
468 ///
469 /// let pool = ThreadPool::new(4);
470 /// for _ in 0..10 {
471 /// pool.execute(move || {
472 /// sleep(Duration::from_secs(100));
473 /// });
474 /// }
475 ///
476 /// sleep(Duration::from_secs(1)); // wait for threads to start
477 /// assert_eq!(4, pool.active_count());
478 /// ```
479 pub fn active_count(&self) -> usize {
480 self.shared_data.active_count.load(Ordering::SeqCst)
481 }
482
483 /// Returns the maximum number of threads the pool will execute concurrently.
484 ///
485 /// # Examples
486 ///
487 /// ```
488 /// use threadpool::ThreadPool;
489 ///
490 /// let mut pool = ThreadPool::new(4);
491 /// assert_eq!(4, pool.max_count());
492 ///
493 /// pool.set_num_threads(8);
494 /// assert_eq!(8, pool.max_count());
495 /// ```
496 pub fn max_count(&self) -> usize {
497 self.shared_data.max_thread_count.load(Ordering::Relaxed)
498 }
499
500 /// Returns the number of panicked threads over the lifetime of the pool.
501 ///
502 /// # Examples
503 ///
504 /// ```
505 /// use threadpool::ThreadPool;
506 ///
507 /// let pool = ThreadPool::new(4);
508 /// for n in 0..10 {
509 /// pool.execute(move || {
510 /// // simulate a panic
511 /// if n % 2 == 0 {
512 /// panic!()
513 /// }
514 /// });
515 /// }
516 /// pool.join();
517 ///
518 /// assert_eq!(5, pool.panic_count());
519 /// ```
520 pub fn panic_count(&self) -> usize {
521 self.shared_data.panic_count.load(Ordering::Relaxed)
522 }
523
524 /// **Deprecated: Use [`ThreadPool::set_num_threads`](#method.set_num_threads)**
525 #[deprecated(since = "1.3.0", note = "use ThreadPool::set_num_threads")]
526 pub fn set_threads(&mut self, num_threads: usize) {
527 self.set_num_threads(num_threads)
528 }
529
530 /// Sets the number of worker-threads to use as `num_threads`.
531 /// Can be used to change the threadpool size during runtime.
532 /// Will not abort already running or waiting threads.
533 ///
534 /// # Panics
535 ///
536 /// This function will panic if `num_threads` is 0.
537 ///
538 /// # Examples
539 ///
540 /// ```
541 /// use threadpool::ThreadPool;
542 /// use std::time::Duration;
543 /// use std::thread::sleep;
544 ///
545 /// let mut pool = ThreadPool::new(4);
546 /// for _ in 0..10 {
547 /// pool.execute(move || {
548 /// sleep(Duration::from_secs(100));
549 /// });
550 /// }
551 ///
552 /// sleep(Duration::from_secs(1)); // wait for threads to start
553 /// assert_eq!(4, pool.active_count());
554 /// assert_eq!(6, pool.queued_count());
555 ///
556 /// // Increase thread capacity of the pool
557 /// pool.set_num_threads(8);
558 ///
559 /// sleep(Duration::from_secs(1)); // wait for new threads to start
560 /// assert_eq!(8, pool.active_count());
561 /// assert_eq!(2, pool.queued_count());
562 ///
563 /// // Decrease thread capacity of the pool
564 /// // No active threads are killed
565 /// pool.set_num_threads(4);
566 ///
567 /// assert_eq!(8, pool.active_count());
568 /// assert_eq!(2, pool.queued_count());
569 /// ```
570 pub fn set_num_threads(&mut self, num_threads: usize) {
571 assert!(num_threads >= 1);
572 let prev_num_threads = self
573 .shared_data
574 .max_thread_count
575 .swap(num_threads, Ordering::Release);
576 if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) {
577 // Spawn new threads
578 for _ in 0..num_spawn {
579 spawn_in_pool(self.shared_data.clone());
580 }
581 }
582 }
583
584 /// Block the current thread until all jobs in the pool have been executed.
585 ///
586 /// Calling `join` on an empty pool will cause an immediate return.
587 /// `join` may be called from multiple threads concurrently.
588 /// A `join` is an atomic point in time. All threads joining before the join
589 /// event will exit together even if the pool is processing new jobs by the
590 /// time they get scheduled.
591 ///
592 /// Calling `join` from a thread within the pool will cause a deadlock. This
593 /// behavior is considered safe.
594 ///
595 /// # Examples
596 ///
597 /// ```
598 /// use threadpool::ThreadPool;
599 /// use std::sync::Arc;
600 /// use std::sync::atomic::{AtomicUsize, Ordering};
601 ///
602 /// let pool = ThreadPool::new(8);
603 /// let test_count = Arc::new(AtomicUsize::new(0));
604 ///
605 /// for _ in 0..42 {
606 /// let test_count = test_count.clone();
607 /// pool.execute(move || {
608 /// test_count.fetch_add(1, Ordering::Relaxed);
609 /// });
610 /// }
611 ///
612 /// pool.join();
613 /// assert_eq!(42, test_count.load(Ordering::Relaxed));
614 /// ```
615 pub fn join(&self) {
616 // fast path requires no mutex
617 if self.shared_data.has_work() == false {
618 return ();
619 }
620
621 let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
622 let mut lock = self.shared_data.empty_trigger.lock().unwrap();
623
624 while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
625 && self.shared_data.has_work()
626 {
627 lock = self.shared_data.empty_condvar.wait(lock).unwrap();
628 }
629
630 // increase generation if we are the first thread to come out of the loop
631 self.shared_data.join_generation.compare_and_swap(
632 generation,
633 generation.wrapping_add(1),
634 Ordering::SeqCst,
635 );
636 }
637}
638
639impl Clone for ThreadPool {
640 /// Cloning a pool will create a new handle to the pool.
641 /// The behavior is similar to [Arc](https://doc.rust-lang.org/stable/std/sync/struct.Arc.html).
642 ///
643 /// We could for example submit jobs from multiple threads concurrently.
644 ///
645 /// ```
646 /// use threadpool::ThreadPool;
647 /// use std::thread;
648 /// use std::sync::mpsc::channel;
649 ///
650 /// let pool = ThreadPool::with_name("clone example".into(), 2);
651 ///
652 /// let results = (0..2)
653 /// .map(|i| {
654 /// let pool = pool.clone();
655 /// thread::spawn(move || {
656 /// let (tx, rx) = channel();
657 /// for i in 1..12 {
658 /// let tx = tx.clone();
659 /// pool.execute(move || {
660 /// tx.send(i).expect("channel will be waiting");
661 /// });
662 /// }
663 /// drop(tx);
664 /// if i == 0 {
665 /// rx.iter().fold(0, |accumulator, element| accumulator + element)
666 /// } else {
667 /// rx.iter().fold(1, |accumulator, element| accumulator * element)
668 /// }
669 /// })
670 /// })
671 /// .map(|join_handle| join_handle.join().expect("collect results from threads"))
672 /// .collect::<Vec<usize>>();
673 ///
674 /// assert_eq!(vec![66, 39916800], results);
675 /// ```
676 fn clone(&self) -> ThreadPool {
677 ThreadPool {
678 jobs: self.jobs.clone(),
679 shared_data: self.shared_data.clone(),
680 }
681 }
682}
683
684/// Create a thread pool with one thread per CPU.
685/// On machines with hyperthreading,
686/// this will create one thread per hyperthread.
687impl Default for ThreadPool {
688 fn default() -> Self {
689 ThreadPool::new(num_threads:num_cpus::get())
690 }
691}
692
693impl fmt::Debug for ThreadPool {
694 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
695 f&mut DebugStruct<'_, '_>.debug_struct("ThreadPool")
696 .field("name", &self.shared_data.name)
697 .field("queued_count", &self.queued_count())
698 .field("active_count", &self.active_count())
699 .field(name:"max_count", &self.max_count())
700 .finish()
701 }
702}
703
704impl PartialEq for ThreadPool {
705 /// Check if you are working with the same pool
706 ///
707 /// ```
708 /// use threadpool::ThreadPool;
709 ///
710 /// let a = ThreadPool::new(2);
711 /// let b = ThreadPool::new(2);
712 ///
713 /// assert_eq!(a, a);
714 /// assert_eq!(b, b);
715 ///
716 /// # // TODO: change this to assert_ne in the future
717 /// assert!(a != b);
718 /// assert!(b != a);
719 /// ```
720 fn eq(&self, other: &ThreadPool) -> bool {
721 let a: &ThreadPoolSharedData = &*self.shared_data;
722 let b: &ThreadPoolSharedData = &*other.shared_data;
723 a as *const ThreadPoolSharedData == b as *const ThreadPoolSharedData
724 // with rust 1.17 and late:
725 // Arc::ptr_eq(&self.shared_data, &other.shared_data)
726 }
727}
728impl Eq for ThreadPool {}
729
730fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
731 let mut builder = thread::Builder::new();
732 if let Some(ref name) = shared_data.name {
733 builder = builder.name(name.clone());
734 }
735 if let Some(ref stack_size) = shared_data.stack_size {
736 builder = builder.stack_size(stack_size.to_owned());
737 }
738 builder
739 .spawn(move || {
740 // Will spawn a new thread on panic unless it is cancelled.
741 let sentinel = Sentinel::new(&shared_data);
742
743 loop {
744 // Shutdown this thread if the pool has become smaller
745 let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
746 let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
747 if thread_counter_val >= max_thread_count_val {
748 break;
749 }
750 let message = {
751 // Only lock jobs for the time it takes
752 // to get a job, not run it.
753 let lock = shared_data
754 .job_receiver
755 .lock()
756 .expect("Worker thread unable to lock job_receiver");
757 lock.recv()
758 };
759
760 let job = match message {
761 Ok(job) => job,
762 // The ThreadPool was dropped.
763 Err(..) => break,
764 };
765 // Do not allow IR around the job execution
766 shared_data.active_count.fetch_add(1, Ordering::SeqCst);
767 shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
768
769 job.call_box();
770
771 shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
772 shared_data.no_work_notify_all();
773 }
774
775 sentinel.cancel();
776 })
777 .unwrap();
778}
779
780#[cfg(test)]
781mod test {
782 use super::{Builder, ThreadPool};
783 use std::sync::atomic::{AtomicUsize, Ordering};
784 use std::sync::mpsc::{channel, sync_channel};
785 use std::sync::{Arc, Barrier};
786 use std::thread::{self, sleep};
787 use std::time::Duration;
788
789 const TEST_TASKS: usize = 4;
790
791 #[test]
792 fn test_set_num_threads_increasing() {
793 let new_thread_amount = TEST_TASKS + 8;
794 let mut pool = ThreadPool::new(TEST_TASKS);
795 for _ in 0..TEST_TASKS {
796 pool.execute(move || sleep(Duration::from_secs(23)));
797 }
798 sleep(Duration::from_secs(1));
799 assert_eq!(pool.active_count(), TEST_TASKS);
800
801 pool.set_num_threads(new_thread_amount);
802
803 for _ in 0..(new_thread_amount - TEST_TASKS) {
804 pool.execute(move || sleep(Duration::from_secs(23)));
805 }
806 sleep(Duration::from_secs(1));
807 assert_eq!(pool.active_count(), new_thread_amount);
808
809 pool.join();
810 }
811
812 #[test]
813 fn test_set_num_threads_decreasing() {
814 let new_thread_amount = 2;
815 let mut pool = ThreadPool::new(TEST_TASKS);
816 for _ in 0..TEST_TASKS {
817 pool.execute(move || {
818 assert_eq!(1, 1);
819 });
820 }
821 pool.set_num_threads(new_thread_amount);
822 for _ in 0..new_thread_amount {
823 pool.execute(move || sleep(Duration::from_secs(23)));
824 }
825 sleep(Duration::from_secs(1));
826 assert_eq!(pool.active_count(), new_thread_amount);
827
828 pool.join();
829 }
830
831 #[test]
832 fn test_active_count() {
833 let pool = ThreadPool::new(TEST_TASKS);
834 for _ in 0..2 * TEST_TASKS {
835 pool.execute(move || loop {
836 sleep(Duration::from_secs(10))
837 });
838 }
839 sleep(Duration::from_secs(1));
840 let active_count = pool.active_count();
841 assert_eq!(active_count, TEST_TASKS);
842 let initialized_count = pool.max_count();
843 assert_eq!(initialized_count, TEST_TASKS);
844 }
845
846 #[test]
847 fn test_works() {
848 let pool = ThreadPool::new(TEST_TASKS);
849
850 let (tx, rx) = channel();
851 for _ in 0..TEST_TASKS {
852 let tx = tx.clone();
853 pool.execute(move || {
854 tx.send(1).unwrap();
855 });
856 }
857
858 assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
859 }
860
861 #[test]
862 #[should_panic]
863 fn test_zero_tasks_panic() {
864 ThreadPool::new(0);
865 }
866
867 #[test]
868 fn test_recovery_from_subtask_panic() {
869 let pool = ThreadPool::new(TEST_TASKS);
870
871 // Panic all the existing threads.
872 for _ in 0..TEST_TASKS {
873 pool.execute(move || panic!("Ignore this panic, it must!"));
874 }
875 pool.join();
876
877 assert_eq!(pool.panic_count(), TEST_TASKS);
878
879 // Ensure new threads were spawned to compensate.
880 let (tx, rx) = channel();
881 for _ in 0..TEST_TASKS {
882 let tx = tx.clone();
883 pool.execute(move || {
884 tx.send(1).unwrap();
885 });
886 }
887
888 assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
889 }
890
891 #[test]
892 fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
893 let pool = ThreadPool::new(TEST_TASKS);
894 let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
895
896 // Panic all the existing threads in a bit.
897 for _ in 0..TEST_TASKS {
898 let waiter = waiter.clone();
899 pool.execute(move || {
900 waiter.wait();
901 panic!("Ignore this panic, it should!");
902 });
903 }
904
905 drop(pool);
906
907 // Kick off the failure.
908 waiter.wait();
909 }
910
911 #[test]
912 fn test_massive_task_creation() {
913 let test_tasks = 4_200_000;
914
915 let pool = ThreadPool::new(TEST_TASKS);
916 let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
917 let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));
918
919 let (tx, rx) = channel();
920
921 for i in 0..test_tasks {
922 let tx = tx.clone();
923 let (b0, b1) = (b0.clone(), b1.clone());
924
925 pool.execute(move || {
926 // Wait until the pool has been filled once.
927 if i < TEST_TASKS {
928 b0.wait();
929 // wait so the pool can be measured
930 b1.wait();
931 }
932
933 tx.send(1).is_ok();
934 });
935 }
936
937 b0.wait();
938 assert_eq!(pool.active_count(), TEST_TASKS);
939 b1.wait();
940
941 assert_eq!(rx.iter().take(test_tasks).fold(0, |a, b| a + b), test_tasks);
942 pool.join();
943
944 let atomic_active_count = pool.active_count();
945 assert!(
946 atomic_active_count == 0,
947 "atomic_active_count: {}",
948 atomic_active_count
949 );
950 }
951
952 #[test]
953 fn test_shrink() {
954 let test_tasks_begin = TEST_TASKS + 2;
955
956 let mut pool = ThreadPool::new(test_tasks_begin);
957 let b0 = Arc::new(Barrier::new(test_tasks_begin + 1));
958 let b1 = Arc::new(Barrier::new(test_tasks_begin + 1));
959
960 for _ in 0..test_tasks_begin {
961 let (b0, b1) = (b0.clone(), b1.clone());
962 pool.execute(move || {
963 b0.wait();
964 b1.wait();
965 });
966 }
967
968 let b2 = Arc::new(Barrier::new(TEST_TASKS + 1));
969 let b3 = Arc::new(Barrier::new(TEST_TASKS + 1));
970
971 for _ in 0..TEST_TASKS {
972 let (b2, b3) = (b2.clone(), b3.clone());
973 pool.execute(move || {
974 b2.wait();
975 b3.wait();
976 });
977 }
978
979 b0.wait();
980 pool.set_num_threads(TEST_TASKS);
981
982 assert_eq!(pool.active_count(), test_tasks_begin);
983 b1.wait();
984
985 b2.wait();
986 assert_eq!(pool.active_count(), TEST_TASKS);
987 b3.wait();
988 }
989
990 #[test]
991 fn test_name() {
992 let name = "test";
993 let mut pool = ThreadPool::with_name(name.to_owned(), 2);
994 let (tx, rx) = sync_channel(0);
995
996 // initial thread should share the name "test"
997 for _ in 0..2 {
998 let tx = tx.clone();
999 pool.execute(move || {
1000 let name = thread::current().name().unwrap().to_owned();
1001 tx.send(name).unwrap();
1002 });
1003 }
1004
1005 // new spawn thread should share the name "test" too.
1006 pool.set_num_threads(3);
1007 let tx_clone = tx.clone();
1008 pool.execute(move || {
1009 let name = thread::current().name().unwrap().to_owned();
1010 tx_clone.send(name).unwrap();
1011 panic!();
1012 });
1013
1014 // recover thread should share the name "test" too.
1015 pool.execute(move || {
1016 let name = thread::current().name().unwrap().to_owned();
1017 tx.send(name).unwrap();
1018 });
1019
1020 for thread_name in rx.iter().take(4) {
1021 assert_eq!(name, thread_name);
1022 }
1023 }
1024
1025 #[test]
1026 fn test_debug() {
1027 let pool = ThreadPool::new(4);
1028 let debug = format!("{:?}", pool);
1029 assert_eq!(
1030 debug,
1031 "ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }"
1032 );
1033
1034 let pool = ThreadPool::with_name("hello".into(), 4);
1035 let debug = format!("{:?}", pool);
1036 assert_eq!(
1037 debug,
1038 "ThreadPool { name: Some(\"hello\"), queued_count: 0, active_count: 0, max_count: 4 }"
1039 );
1040
1041 let pool = ThreadPool::new(4);
1042 pool.execute(move || sleep(Duration::from_secs(5)));
1043 sleep(Duration::from_secs(1));
1044 let debug = format!("{:?}", pool);
1045 assert_eq!(
1046 debug,
1047 "ThreadPool { name: None, queued_count: 0, active_count: 1, max_count: 4 }"
1048 );
1049 }
1050
1051 #[test]
1052 fn test_repeate_join() {
1053 let pool = ThreadPool::with_name("repeate join test".into(), 8);
1054 let test_count = Arc::new(AtomicUsize::new(0));
1055
1056 for _ in 0..42 {
1057 let test_count = test_count.clone();
1058 pool.execute(move || {
1059 sleep(Duration::from_secs(2));
1060 test_count.fetch_add(1, Ordering::Release);
1061 });
1062 }
1063
1064 println!("{:?}", pool);
1065 pool.join();
1066 assert_eq!(42, test_count.load(Ordering::Acquire));
1067
1068 for _ in 0..42 {
1069 let test_count = test_count.clone();
1070 pool.execute(move || {
1071 sleep(Duration::from_secs(2));
1072 test_count.fetch_add(1, Ordering::Relaxed);
1073 });
1074 }
1075 pool.join();
1076 assert_eq!(84, test_count.load(Ordering::Relaxed));
1077 }
1078
1079 #[test]
1080 fn test_multi_join() {
1081 use std::sync::mpsc::TryRecvError::*;
1082
1083 // Toggle the following lines to debug the deadlock
1084 fn error(_s: String) {
1085 //use ::std::io::Write;
1086 //let stderr = ::std::io::stderr();
1087 //let mut stderr = stderr.lock();
1088 //stderr.write(&_s.as_bytes()).is_ok();
1089 }
1090
1091 let pool0 = ThreadPool::with_name("multi join pool0".into(), 4);
1092 let pool1 = ThreadPool::with_name("multi join pool1".into(), 4);
1093 let (tx, rx) = channel();
1094
1095 for i in 0..8 {
1096 let pool1 = pool1.clone();
1097 let pool0_ = pool0.clone();
1098 let tx = tx.clone();
1099 pool0.execute(move || {
1100 pool1.execute(move || {
1101 error(format!("p1: {} -=- {:?}\n", i, pool0_));
1102 pool0_.join();
1103 error(format!("p1: send({})\n", i));
1104 tx.send(i).expect("send i from pool1 -> main");
1105 });
1106 error(format!("p0: {}\n", i));
1107 });
1108 }
1109 drop(tx);
1110
1111 assert_eq!(rx.try_recv(), Err(Empty));
1112 error(format!("{:?}\n{:?}\n", pool0, pool1));
1113 pool0.join();
1114 error(format!("pool0.join() complete =-= {:?}", pool1));
1115 pool1.join();
1116 error("pool1.join() complete\n".into());
1117 assert_eq!(
1118 rx.iter().fold(0, |acc, i| acc + i),
1119 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7
1120 );
1121 }
1122
1123 #[test]
1124 fn test_empty_pool() {
1125 // Joining an empty pool must return imminently
1126 let pool = ThreadPool::new(4);
1127
1128 pool.join();
1129
1130 assert!(true);
1131 }
1132
#[test]
fn test_no_fun_or_joy() {
    // What happens when you keep adding jobs after a join: a second thread
    // submits more work, possibly while the main thread is already inside
    // `join()`. The test passes as long as `join()` returns.

    fn sleepy_function() {
        sleep(Duration::from_secs(6));
    }

    let pool = ThreadPool::with_name("no fun or joy".into(), 8);

    pool.execute(sleepy_function);

    let p_t = pool.clone();
    let submitter = thread::spawn(move || {
        for _ in 0..23 {
            p_t.execute(sleepy_function);
        }
    });

    pool.join();
    // Surface a panic in the submitting thread instead of silently losing it.
    submitter
        .join()
        .expect("submitter thread must not panic");
}
1152
#[test]
fn test_clone() {
    let pool = ThreadPool::with_name("clone example".into(), 2);

    // This batch of jobs will occupy the pool for some time
    for _ in 0..6 {
        pool.execute(move || {
            sleep(Duration::from_secs(2));
        });
    }

    // The jobs below are submitted by two threads, each holding its own clone
    // of the pool, so they interleave in a nondeterministic order.
    let t0 = {
        let pool = pool.clone();
        thread::spawn(move || {
            // wait for the first batch of tasks to finish
            pool.join();

            let (tx, rx) = channel();
            for i in 0..42 {
                let tx = tx.clone();
                pool.execute(move || {
                    tx.send(i).expect("channel will be waiting");
                });
            }
            // Drop the original sender so `rx.iter()` ends once every job has sent.
            drop(tx);
            rx.iter().sum::<i32>()
        })
    };
    let t1 = {
        let pool = pool.clone();
        thread::spawn(move || {
            // wait for the first batch of tasks to finish
            pool.join();

            let (tx, rx) = channel();
            for i in 1..12 {
                let tx = tx.clone();
                pool.execute(move || {
                    tx.send(i).expect("channel will be waiting");
                });
            }
            // Drop the original sender so `rx.iter()` ends once every job has sent.
            drop(tx);
            rx.iter().product::<i32>()
        })
    };

    // 0 + 1 + ... + 41
    assert_eq!(
        861,
        t0.join()
            .expect("thread 0 will return after calculating additions")
    );
    // 1 * 2 * ... * 11 (= 11!)
    assert_eq!(
        39916800,
        t1.join()
            .expect("thread 1 will return after calculating multiplications")
    );
}
1213
#[test]
fn test_sync_shared_data() {
    // Compile-time check: instantiating `requires_sync` only type-checks
    // when the pool's shared state implements `Sync`.
    fn requires_sync<S>()
    where
        S: Sync,
    {
    }
    requires_sync::<super::ThreadPoolSharedData>();
}
1219
#[test]
fn test_send_shared_data() {
    // Compile-time check: instantiating `requires_send` only type-checks
    // when the pool's shared state implements `Send`.
    fn requires_send<S>()
    where
        S: Send,
    {
    }
    requires_send::<super::ThreadPoolSharedData>();
}
1225
#[test]
fn test_send() {
    // Compile-time check that `ThreadPool` implements `Send`.
    fn requires_send<P>()
    where
        P: Send,
    {
    }
    requires_send::<ThreadPool>();
}
1231
#[test]
fn test_cloned_eq() {
    // A clone must compare equal to the pool it was cloned from.
    let original = ThreadPool::new(2);
    let copy = original.clone();

    assert_eq!(original, copy);
}
1238
/// The scenario is joining threads should not be stuck once their wave
/// of joins has completed. So once one thread joining on a pool has
/// succeeded, other threads joining on the same pool must get out even if
/// the thread is used for other jobs while the first group is finishing
/// their join.
///
/// In this example this means the waiting threads will exit the join in
/// groups of four because the waiter pool has four workers.
#[test]
fn test_join_wavesurfer() {
    let n_cycles = 4;
    let n_workers = 4;
    let (tx, rx) = channel();
    let builder = Builder::new()
        .num_threads(n_workers)
        .thread_name("join wavesurfer".into());
    let p_waiter = builder.clone().build();
    let p_clock = builder.build();

    // Released once the clock thread, the first p_clock job and the main
    // thread have all reached it.
    let barrier = Arc::new(Barrier::new(3));
    let wave_clock = Arc::new(AtomicUsize::new(0));
    let clock_thread = {
        let barrier = barrier.clone();
        let wave_clock = wave_clock.clone();
        thread::spawn(move || {
            barrier.wait();
            // Advance the wave counter once per second.
            for wave_num in 0..n_cycles {
                wave_clock.store(wave_num, Ordering::SeqCst);
                sleep(Duration::from_secs(1));
            }
        })
    };

    {
        let barrier = barrier.clone();
        p_clock.execute(move || {
            barrier.wait();
            // this sleep is for stabilisation on weaker platforms
            sleep(Duration::from_millis(100));
        });
    }

    // prepare three waves of jobs
    for i in 0..3 * n_workers {
        let p_clock = p_clock.clone();
        let tx = tx.clone();
        let wave_clock = wave_clock.clone();
        p_waiter.execute(move || {
            let now = wave_clock.load(Ordering::SeqCst);
            p_clock.join();
            // submit jobs for the second wave
            p_clock.execute(|| sleep(Duration::from_secs(1)));
            let clock = wave_clock.load(Ordering::SeqCst);
            tx.send((now, clock, i)).unwrap();
        });
    }
    println!("all scheduled at {}", wave_clock.load(Ordering::SeqCst));
    barrier.wait();

    p_clock.join();
    // `p_waiter` is not joined explicitly. Every waiter job owns a clone of
    // `tx`, so draining `rx` below blocks until all of those jobs have finished.

    drop(tx);
    let mut hist = vec![0; n_cycles];
    let mut data = vec![];
    for (now, after, i) in rx.iter() {
        // Bucket the number of waves that passed during the join. Anything at
        // or beyond `n_cycles - 1` goes into the last bucket.
        let dur = (after - now).min(n_cycles - 1);
        hist[dur] += 1;

        data.push((now, after, i));
    }
    for (i, n) in hist.iter().enumerate() {
        println!("\t{}: {} {}", i, n, "*".repeat(*n));
    }
    // The first `n_workers` waiters must leave the join in the same wave they
    // entered it. Every later waiter only gets out after the clock has advanced.
    assert!(data.iter().all(|&(cycle, stop, i)| {
        if i < n_workers {
            cycle == stop
        } else {
            cycle < stop
        }
    }));

    clock_thread.join().unwrap();
}
1329}
1330