1#![cfg(test)]
2
3use crate::{ThreadPoolBuildError, ThreadPoolBuilder};
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{Arc, Barrier};
6
/// A pool reports its configured size, has no worker index when queried from
/// outside, and yields an in-range index from inside `install`.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn worker_thread_index() {
    const NUM_THREADS: usize = 22;

    let pool = ThreadPoolBuilder::new()
        .num_threads(NUM_THREADS)
        .build()
        .unwrap();
    assert_eq!(pool.current_num_threads(), NUM_THREADS);

    // The test thread is not one of the pool's workers.
    assert_eq!(pool.current_thread_index(), None);

    // Inside `install` we are on a worker, so an index must be present.
    let idx = pool.install(|| pool.current_thread_index().unwrap());
    assert!(idx < NUM_THREADS);
}
16
/// Every worker thread of a freshly built pool must run the configured
/// start handler exactly once.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn start_callback_called() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Wait for all the threads in the pool plus the one running tests.
    let barrier = Arc::new(Barrier::new(n_threads + 1));

    let b = Arc::clone(&barrier);
    let nc = Arc::clone(&n_called);
    let start_handler = move |_| {
        nc.fetch_add(1, Ordering::SeqCst);
        b.wait();
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler);
    // Bind the pool (instead of `let _ = ...`, which would drop it at the end
    // of this statement) so it stays alive while we observe its start handlers.
    let _pool = conf.build().unwrap();

    // Wait for all the threads to have been scheduled to run.
    barrier.wait();

    // The handler must have been called on every started thread.
    assert_eq!(n_called.load(Ordering::SeqCst), n_threads);
}
43
/// Every worker thread must run the configured exit handler exactly once
/// when its pool is dropped.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn exit_callback_called() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Wait for all the threads in the pool plus the one running tests.
    let barrier = Arc::new(Barrier::new(n_threads + 1));

    let b = Arc::clone(&barrier);
    let nc = Arc::clone(&n_called);
    let exit_handler = move |_| {
        nc.fetch_add(1, Ordering::SeqCst);
        b.wait();
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .exit_handler(exit_handler);

    // Build the pool and drop it immediately so it stops the running threads,
    // which makes each of them run the exit handler.
    drop(conf.build().unwrap());

    // Wait for all the threads to have run their exit handler.
    barrier.wait();

    // The handler must have been called on every exiting thread.
    assert_eq!(n_called.load(Ordering::SeqCst), n_threads);
}
73
/// Panics raised by the start and exit handlers must be routed to the pool's
/// panic handler: once per worker at startup and once per worker at exit.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn handler_panics_handled_correctly() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Each barrier waits for all the threads in the pool plus the one running tests.
    let start_barrier = Arc::new(Barrier::new(n_threads + 1));
    let exit_barrier = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = |_| {
        panic!("ensure panic handler is called when starting");
    };
    let exit_handler = |_| {
        panic!("ensure panic handler is called when exiting");
    };

    let sb = Arc::clone(&start_barrier);
    let eb = Arc::clone(&exit_barrier);
    let nc = Arc::clone(&n_called);
    let panic_handler = move |_| {
        // The first `n_threads` calls are expected to come from the start
        // handlers; holding those threads on `sb` keeps them separated from
        // the later exit-handler calls, which synchronize on `eb`.
        let val = nc.fetch_add(1, Ordering::SeqCst);
        if val < n_threads {
            sb.wait();
        } else {
            eb.wait();
        }
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler)
        .exit_handler(exit_handler)
        .panic_handler(panic_handler);
    {
        // Bind the pool (rather than `let _ = ...`, which would drop it
        // immediately) so it lives until the end of this scope.
        let _pool = conf.build().unwrap();

        // Wait for all the threads to start, panic in the start handler,
        // and be handled by the panic handler.
        start_barrier.wait();

        // `_pool` is dropped here, which stops the running threads.
    }

    // Wait for all the threads to exit, panic in the exit handler,
    // and be handled by the panic handler.
    exit_barrier.wait();

    // The panic handler must have been called twice on every thread.
    assert_eq!(n_called.load(Ordering::SeqCst), 2 * n_threads);
}
124
/// Building with an explicit thread count yields a pool of exactly that size.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn check_config_build() {
    let requested = 22;
    let pool = ThreadPoolBuilder::new()
        .num_threads(requested)
        .build()
        .unwrap();
    assert_eq!(pool.current_num_threads(), requested);
}
131
/// Compile-time assertion helper: `_send_sync::<T>()` only type-checks when
/// `T` is both `Send` and `Sync`. Used by `check_error_send_sync` to pin
/// those auto traits on `ThreadPoolBuildError`.
fn _send_sync<T>()
where
    T: Send + Sync,
{
}
134
/// Fails to compile if `ThreadPoolBuildError` ever stops being `Send + Sync`.
#[test]
fn check_error_send_sync() {
    let assert_thread_safe: fn() = _send_sync::<ThreadPoolBuildError>;
    assert_thread_safe();
}
139
/// Calls every public builder method on `Configuration` and checks that the
/// resulting configuration builds successfully.
// NOTE(review): `allow(deprecated)` is presumably needed because this test
// deliberately exercises the deprecated `Configuration` API.
#[allow(deprecated)]
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn configuration() {
    let start_handler = |_| {};
    let exit_handler = |_| {};
    let panic_handler = |_| {};
    let thread_name = |i| format!("thread_name_{}", i);

    // Ensure we can call all public methods on Configuration.
    crate::Configuration::new()
        .thread_name(thread_name)
        .num_threads(5)
        .panic_handler(panic_handler)
        // Exact integer literal instead of the float-to-int cast `4e6 as usize`.
        .stack_size(4_000_000)
        .breadth_first()
        .start_handler(start_handler)
        .exit_handler(exit_handler)
        .build()
        .unwrap();
}
161
/// A pool built from `ThreadPoolBuilder::default()` must build successfully.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn default_pool() {
    let builder = ThreadPoolBuilder::default();
    let pool = builder.build().unwrap();
    drop(pool);
}
167
/// Checks that threads created through a custom `spawn_handler` have their
/// `WorkerThread` cleared once the pool is done with them. This lets those
/// threads be used with rayon again later; for example, WebAssembly may want
/// to keep its own pool of available threads.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn cleared_current_thread() -> Result<(), ThreadPoolBuildError> {
    const N_THREADS: usize = 5;
    let mut join_handles = Vec::new();

    let pool = ThreadPoolBuilder::new()
        .num_threads(N_THREADS)
        .spawn_handler(|thread| {
            join_handles.push(std::thread::spawn(move || {
                thread.run();

                // Once `run` returns, this thread must no longer report a
                // rayon worker index.
                assert_eq!(crate::current_thread_index(), None);
            }));
            Ok(())
        })
        .build()?;
    assert_eq!(join_handles.len(), N_THREADS);

    // While the pool is alive, `install` runs on one of its worker threads.
    pool.install(|| assert!(crate::current_thread_index().is_some()));
    drop(pool);

    // Join every thread so its post-`run` assertion has executed and exited.
    join_handles
        .into_iter()
        .for_each(|handle| handle.join().unwrap());

    Ok(())
}
201