| 1 | #![cfg (test)] |
| 2 | |
| 3 | use crate::{ThreadPoolBuildError, ThreadPoolBuilder}; |
| 4 | use std::sync::atomic::{AtomicUsize, Ordering}; |
| 5 | use std::sync::{Arc, Barrier}; |
| 6 | |
/// Checks the thread index reported outside and inside a 22-thread pool:
/// `None` for the calling thread, and an in-range index for a worker.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn worker_thread_index() {
    let n_threads = 22;
    let builder = ThreadPoolBuilder::new().num_threads(n_threads);
    let pool = builder.build().unwrap();

    assert_eq!(pool.current_num_threads(), n_threads);
    // Outside `install`, the calling thread has no index in this pool.
    assert_eq!(pool.current_thread_index(), None);

    // Inside `install`, the closure must observe a valid worker index.
    let worker_index = pool.install(|| pool.current_thread_index().unwrap());
    assert!(worker_index < n_threads);
}
| 16 | |
/// Verifies that the start handler is invoked once for each thread of the pool.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn start_callback_called() {
    let n_threads = 16;
    let calls = Arc::new(AtomicUsize::new(0));
    // One participant per pool thread, plus one for the thread running this test.
    let gate = Arc::new(Barrier::new(n_threads + 1));

    let builder = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler({
            // The handler owns its own handles to the counter and the barrier.
            let calls = Arc::clone(&calls);
            let gate = Arc::clone(&gate);
            move |_| {
                calls.fetch_add(1, Ordering::SeqCst);
                gate.wait();
            }
        });
    // `let _` does not bind the pool, so it is dropped as soon as this
    // statement ends.
    let _ = builder.build().unwrap();

    // Block until every thread has reached the barrier from its start handler.
    gate.wait();

    // The counter must match the number of threads that were started.
    assert_eq!(calls.load(Ordering::SeqCst), n_threads);
}
| 43 | |
/// Verifies that the exit handler is invoked once for each thread of the pool
/// after the pool has been dropped.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn exit_callback_called() {
    let n_threads = 16;
    let calls = Arc::new(AtomicUsize::new(0));
    // One participant per pool thread, plus one for the thread running this test.
    let gate = Arc::new(Barrier::new(n_threads + 1));

    let builder = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .exit_handler({
            // The handler owns its own handles to the counter and the barrier.
            let calls = Arc::clone(&calls);
            let gate = Arc::clone(&gate);
            move |_| {
                calls.fetch_add(1, Ordering::SeqCst);
                gate.wait();
            }
        });
    {
        // `let _` does not bind the pool, so it is dropped right here;
        // dropping the pool is what stops the running threads.
        let _ = builder.build().unwrap();
    }

    // Block until every thread has reached the barrier from its exit handler.
    gate.wait();

    // The counter must match the number of threads that exited.
    assert_eq!(calls.load(Ordering::SeqCst), n_threads);
}
| 73 | |
/// Verifies that panics raised inside the start and exit handlers are handed
/// to the panic handler: two calls per thread are expected in total.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn handler_panics_handled_correctly() {
    let n_threads = 16;
    let panics_seen = Arc::new(AtomicUsize::new(0));
    // Each barrier expects every pool thread plus the thread running this test.
    let start_barrier = Arc::new(Barrier::new(n_threads + 1));
    let exit_barrier = Arc::new(Barrier::new(n_threads + 1));

    let panic_handler = {
        let counter = Arc::clone(&panics_seen);
        let on_start = Arc::clone(&start_barrier);
        let on_exit = Arc::clone(&exit_barrier);
        move |_| {
            // The first `n_threads` calls rendezvous on the start barrier,
            // every later call on the exit barrier.
            let earlier_calls = counter.fetch_add(1, Ordering::SeqCst);
            let barrier = if earlier_calls < n_threads {
                &on_start
            } else {
                &on_exit
            };
            barrier.wait();
        }
    };

    let builder = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(|_| {
            panic!("ensure panic handler is called when starting");
        })
        .exit_handler(|_| {
            panic!("ensure panic handler is called when exiting");
        })
        .panic_handler(panic_handler);
    {
        // `let _` does not bind the pool, so it is dropped as soon as this
        // statement ends; dropping the pool is what stops the threads.
        let _ = builder.build().unwrap();

        // Wait for all the threads to start, panic in the start handler,
        // and be taken care of by the panic handler.
        start_barrier.wait();
    }

    // Wait for all the threads to exit, panic in the exit handler,
    // and be taken care of by the panic handler.
    exit_barrier.wait();

    // The panic handler must have been called twice on every thread.
    assert_eq!(panics_seen.load(Ordering::SeqCst), 2 * n_threads);
}
| 124 | |
/// Builds a pool with an explicit thread count and checks that the pool
/// reports that same count.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn check_config_build() {
    let requested = 22;
    let builder = ThreadPoolBuilder::new().num_threads(requested);
    let pool = builder.build().unwrap();
    assert_eq!(pool.current_num_threads(), requested);
}
| 131 | |
/// Compile-time assertion helper: instantiating it with a type `T` only
/// type-checks when `T` is both `Send` and `Sync`. It does nothing at runtime.
/// Used by `check_error_send_sync` for `ThreadPoolBuildError`.
fn _send_sync<T>()
where
    T: Send + Sync,
{
}
| 134 | |
/// Ensures `ThreadPoolBuildError` is `Send + Sync`. The check happens at
/// compile time; the call itself does nothing when run.
#[test]
fn check_error_send_sync() {
    // This instantiation fails to compile unless the bounds hold.
    _send_sync::<ThreadPoolBuildError>()
}
| 139 | |
/// Smoke test for `crate::Configuration`: chains its builder methods and
/// builds the result. `allow(deprecated)` silences the deprecation lint for
/// this use.
#[allow(deprecated)]
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn configuration() {
    // Ensure we can call all public methods on Configuration.
    // The handlers are deliberate no-ops; only the calls themselves matter.
    crate::Configuration::new()
        .thread_name(|i| format!("thread_name_{}", i))
        .num_threads(5)
        .panic_handler(|_| {})
        .stack_size(4_000_000)
        .breadth_first()
        .start_handler(|_| {})
        .exit_handler(|_| {})
        .build()
        .unwrap();
}
| 161 | |
/// A builder obtained from `Default` must build successfully.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn default_pool() {
    let built = ThreadPoolBuilder::default().build();
    // Panics (failing the test) if the build returned an error.
    built.unwrap();
}
| 167 | |
/// Test that custom spawned threads get their `WorkerThread` cleared once
/// the pool is done with them, allowing them to be used with rayon again
/// later, e.g. WebAssembly users may want to keep their own pool of
/// available threads.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn cleared_current_thread() -> Result<(), ThreadPoolBuildError> {
    let n_threads = 5;
    let mut handles = vec![];

    let pool = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .spawn_handler(|thread| {
            // Run each worker on a std thread we own, keeping its join handle.
            handles.push(std::thread::spawn(move || {
                thread.run();

                // Once `run` has returned, no current thread index may remain.
                assert_eq!(crate::current_thread_index(), None);
            }));
            Ok(())
        })
        .build()?;
    // The spawn handler must have been called once per requested thread.
    assert_eq!(handles.len(), n_threads);

    // While running inside the pool, a thread index must be set.
    pool.install(|| assert!(crate::current_thread_index().is_some()));
    drop(pool);

    // Wait for all threads to make their assertions and exit; a panic in any
    // of them surfaces here through `unwrap`.
    handles
        .into_iter()
        .for_each(|handle| handle.join().unwrap());

    Ok(())
}
| 201 | |