1 | #![cfg (test)] |
2 | |
3 | use crate::{ThreadPoolBuildError, ThreadPoolBuilder}; |
4 | use std::sync::atomic::{AtomicUsize, Ordering}; |
5 | use std::sync::{Arc, Barrier}; |
6 | |
/// A pool built with a fixed thread count reports that count, reports no
/// worker index for the (non-pool) test thread, and reports an in-range index
/// from inside `install`.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn worker_thread_index() {
    let pool_size = 22;
    let pool = ThreadPoolBuilder::new()
        .num_threads(pool_size)
        .build()
        .unwrap();

    assert_eq!(pool.current_num_threads(), pool_size);
    // The test thread itself is not one of this pool's workers.
    assert_eq!(pool.current_thread_index(), None);

    // The closure passed to `install` is expected to observe a worker index.
    let worker_index = pool.install(|| pool.current_thread_index().unwrap());
    assert!(worker_index < pool_size);
}
16 | |
/// The start handler is invoked once on each thread the pool starts.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn start_callback_called() {
    let n_threads = 16;
    let calls = Arc::new(AtomicUsize::new(0));
    // One barrier slot per pool thread, plus one for the thread running the test.
    let rendezvous = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = {
        let calls = Arc::clone(&calls);
        let rendezvous = Arc::clone(&rendezvous);
        move |_| {
            calls.fetch_add(1, Ordering::SeqCst);
            rendezvous.wait();
        }
    };

    // The pool is dropped immediately; the barrier below is what keeps the
    // test from finishing before every thread has run its start handler.
    drop(
        ThreadPoolBuilder::new()
            .num_threads(n_threads)
            .start_handler(start_handler)
            .build()
            .unwrap(),
    );

    // Block until every pool thread has reached the barrier in its handler.
    rendezvous.wait();

    // Exactly one handler call per started thread.
    assert_eq!(calls.load(Ordering::SeqCst), n_threads);
}
43 | |
/// The exit handler is invoked once on each pool thread when the pool is
/// dropped and its threads stop.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn exit_callback_called() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Wait for all the threads in the pool plus the one running tests.
    let barrier = Arc::new(Barrier::new(n_threads + 1));

    let b = Arc::clone(&barrier);
    let nc = Arc::clone(&n_called);
    let exit_handler = move |_| {
        nc.fetch_add(1, Ordering::SeqCst);
        b.wait();
    };

    let pool = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .exit_handler(exit_handler)
        .build()
        .unwrap();
    // Drop the pool so it stops the running threads; this test checks that
    // each of them then runs the exit handler. (`drop` is explicit because a
    // `let _ = …` binding would drop the pool at the same point anyway.)
    drop(pool);

    // Wait for all the threads to have entered the exit handler.
    barrier.wait();

    // The handler must have been called on every exiting thread.
    assert_eq!(n_called.load(Ordering::SeqCst), n_threads);
}
73 | |
/// Panics raised inside the start handler and the exit handler are routed to
/// the configured panic handler, once per thread for each of the two handlers.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn handler_panics_handled_correctly() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Each barrier waits for all the threads in the pool plus the one running tests.
    let start_barrier = Arc::new(Barrier::new(n_threads + 1));
    let exit_barrier = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = |_| {
        panic!("ensure panic handler is called when starting");
    };
    let exit_handler = |_| {
        panic!("ensure panic handler is called when exiting");
    };

    let sb = Arc::clone(&start_barrier);
    let eb = Arc::clone(&exit_barrier);
    let nc = Arc::clone(&n_called);
    let panic_handler = move |_| {
        // The first `n_threads` calls are the start-handler panics: no caller
        // can get past `sb.wait()` until all `n_threads` of them (plus the test
        // thread) have arrived, so every start panic is counted before any
        // exit panic can be.
        let val = nc.fetch_add(1, Ordering::SeqCst);
        if val < n_threads {
            sb.wait();
        } else {
            eb.wait();
        }
    };

    // Keep the pool alive (bound to `pool`, not `_`) until the start phase
    // has been observed, so the drop below really happens where it is written.
    let pool = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler)
        .exit_handler(exit_handler)
        .panic_handler(panic_handler)
        .build()
        .unwrap();

    // Wait for all the threads to start, panic in the start handler,
    // and have that panic handled by the panic handler.
    start_barrier.wait();

    // Drop the pool so it stops the running threads.
    drop(pool);

    // Wait for all the threads to exit, panic in the exit handler,
    // and have that panic handled by the panic handler.
    exit_barrier.wait();

    // The panic handler must have been called twice on every thread.
    assert_eq!(n_called.load(Ordering::SeqCst), 2 * n_threads);
}
124 | |
/// Building with an explicit thread count yields a pool reporting that count.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn check_config_build() {
    let requested = 22;
    let builder = ThreadPoolBuilder::new().num_threads(requested);
    let pool = builder.build().unwrap();
    assert_eq!(pool.current_num_threads(), requested);
}
131 | |
/// Compile-time assertion helper: `_send_sync::<T>()` only type-checks when
/// `T` is both `Send` and `Sync`. Used by `check_error_send_sync`.
fn _send_sync<T>()
where
    T: Send + Sync,
{
}
134 | |
/// `ThreadPoolBuildError` must be `Send + Sync`; if it is not, this test
/// fails to compile rather than failing at run time.
#[test]
fn check_error_send_sync() {
    _send_sync::<ThreadPoolBuildError>()
}
139 | |
/// Smoke test: every builder method on the deprecated `Configuration` type
/// can still be chained together and the result built successfully.
// `Configuration` is deprecated; this test deliberately exercises it, so the
// deprecation warnings are silenced here.
#[allow(deprecated)]
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn configuration() {
    crate::Configuration::new()
        .thread_name(|i| format!("thread_name_{}", i))
        .num_threads(5)
        .panic_handler(|_| {})
        .stack_size(4_000_000)
        .breadth_first()
        .start_handler(|_| {})
        .exit_handler(|_| {})
        .build()
        .unwrap();
}
161 | |
/// A builder with every option left at its default still builds a pool.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn default_pool() {
    let builder = ThreadPoolBuilder::default();
    builder.build().unwrap();
}
167 | |
/// Threads supplied through a custom `spawn_handler` must have their
/// `WorkerThread` cleared once the pool is finished with them, so that they
/// can be used with rayon again later (for example on WebAssembly, where the
/// embedder wants to manage its own pool of available threads).
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn cleared_current_thread() -> Result<(), ThreadPoolBuildError> {
    let n_threads = 5;
    let mut join_handles = Vec::new();

    let pool = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .spawn_handler(|worker| {
            join_handles.push(std::thread::spawn(move || {
                worker.run();

                // Once `run` has returned, this thread must no longer report
                // a current rayon thread index.
                assert_eq!(crate::current_thread_index(), None);
            }));
            Ok(())
        })
        .build()?;
    assert_eq!(join_handles.len(), n_threads);

    // While the pool is alive, work it runs does observe a thread index.
    pool.install(|| assert!(crate::current_thread_index().is_some()));
    drop(pool);

    // Join every spawned thread so its post-`run` assertion is checked.
    join_handles
        .into_iter()
        .for_each(|handle| handle.join().unwrap());

    Ok(())
}
201 | |