use std::{
    num::NonZeroUsize,
    panic::AssertUnwindSafe,
    ptr::NonNull,
    sync::{
        atomic::{AtomicUsize, Ordering},
        mpsc, Mutex, PoisonError,
    },
    thread::Thread,
};

use crate::util::{defer, sync::SyncWrap};

/// Single shared thread pool for running benchmarks on.
pub(crate) static BENCH_POOL: ThreadPool = ThreadPool::new();

/// Reusable threads for broadcasting tasks.
///
/// This thread pool runs only a single task at a time, since only one benchmark
/// should run at a time. Invoking `broadcast` from two threads will cause one
/// thread to wait for the other to finish.
///
/// # How It Works
///
/// Upon calling `broadcast`:
///
/// 1. The main thread creates a `Task`, which is a pointer to a `TaskShared`
///    pinned on the stack. `TaskShared` stores the function to run, along with
///    other fields for coordinating threads.
///
/// 2. New threads are spawned if the requested number is not yet available.
///    Each receives tasks over an associated channel.
///
/// 3. The main thread sends the `Task` over the channels to the requested
///    number of threads. Upon receiving the task, each auxiliary thread
///    executes it and then decrements the task's reference count.
///
/// 4. The main thread executes the `Task` like the auxiliary threads do. It
///    then waits until the reference count is 0 before returning.
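///
/// A simplified sketch of that handshake (illustrative pseudocode only;
/// `senders` stands in for the pool's stored channels):
///
/// ```ignore
/// let shared = TaskShared::new(aux_threads, task); // pinned on this stack frame
/// let task = Task { shared: NonNull::from(&shared).cast() };
/// for sender in &senders[..aux_threads] {
///     sender.send(task).unwrap(); // each worker runs it, then decrements `ref_count`
/// }
/// task.run(0); // the main thread participates as thread 0
/// while shared.ref_count.load(Ordering::Acquire) > 0 {
///     std::thread::park(); // the worker that reaches 0 unparks us
/// }
/// ```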
pub(crate) struct ThreadPool {
    threads: Mutex<Vec<mpsc::SyncSender<Task>>>,
}

impl ThreadPool {
    const fn new() -> Self {
        Self { threads: Mutex::new(Vec::new()) }
    }

    /// Runs the given task on the main thread and `aux_threads` auxiliary
    /// threads, pushing each thread's result into `vec`.
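    ///
    /// A usage sketch (marked `ignore` since the pool is crate-private; the
    /// values shown are what the appended slots hold after the call):
    ///
    /// ```ignore
    /// let mut out = Vec::new();
    /// // With 3 auxiliary threads, 4 slots are appended: one per thread.
    /// BENCH_POOL.par_extend(&mut out, 3, |thread_id| thread_id * 2);
    /// assert_eq!(out, vec![Some(0), Some(2), Some(4), Some(6)]);
    /// ```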
    #[inline]
    pub fn par_extend<T, F>(&self, vec: &mut Vec<Option<T>>, aux_threads: usize, task: F)
    where
        F: Sync + Fn(usize) -> T,
        T: Sync + Send,
    {
        unsafe {
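            // SAFETY: Every appended slot is initialized to `None` before
            // `set_len`, and `broadcast` hands each thread a distinct `index`
            // in `0..additional`, so the writes below are in-bounds and
            // disjoint.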
            let old_len = vec.len();
            let additional = aux_threads + 1;

            vec.reserve_exact(additional);
            vec.spare_capacity_mut().iter_mut().for_each(|val| {
                val.write(None);
            });
            vec.set_len(old_len + additional);

            let ptr = SyncWrap::new(vec.as_mut_ptr().add(old_len));

            self.broadcast(aux_threads, move |index| {
                ptr.add(index).write(Some(task(index)));
            });
        }
    }

    /// Performs the given task across the current thread and auxiliary worker
    /// threads.
    ///
    /// This function returns once all threads complete the task.
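    ///
    /// For example (illustrative; thread 0 is always the calling thread):
    ///
    /// ```ignore
    /// BENCH_POOL.broadcast(2, |thread_id| {
    ///     // Runs with `thread_id` 0 (main), 1, and 2; `broadcast` returns
    ///     // only after all three calls complete.
    ///     println!("hello from thread {thread_id}");
    /// });
    /// ```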
    #[inline]
    pub fn broadcast<F>(&self, aux_threads: usize, task: F)
    where
        F: Sync + Fn(usize),
    {
        // SAFETY: The `TaskShared` instance is guaranteed to be accessible to
        // all threads until this function returns, because this thread waits
        // until `TaskShared.ref_count` is 0 before continuing.
        unsafe {
            let task = TaskShared::new(aux_threads, task);
            let task = Task { shared: NonNull::from(&task).cast() };

            self.broadcast_task(aux_threads, task);
        }
    }

    /// Type-erased implementation for `broadcast`, shared by all of its
    /// monomorphizations.
    unsafe fn broadcast_task(&self, aux_threads: usize, task: Task) {
        // Send the task to the auxiliary threads.
        if aux_threads > 0 {
            let threads = &mut *self.threads.lock().unwrap_or_else(PoisonError::into_inner);

            // Spawn more threads if necessary.
            if let Some(additional) = NonZeroUsize::new(aux_threads.saturating_sub(threads.len())) {
                spawn(additional, threads);
            }

            for thread in &threads[..aux_threads] {
                thread.send(task).unwrap();
            }
        }

        // Run the task on the main thread.
        let main_result = std::panic::catch_unwind(AssertUnwindSafe(|| task.run(0)));

        // Wait for the other threads to finish writing their results. `park`
        // can wake spuriously, so the count is re-checked on every iteration.
        //
        // SAFETY: The acquire memory ordering ensures that all writes performed
        // by the task on other threads become visible to this thread after
        // returning from `broadcast`.
        while task.shared.as_ref().ref_count.load(Ordering::Acquire) > 0 {
            std::thread::park();
        }

        // Don't drop our result until the other threads finish, in case the
        // panic payload's drop handler itself also panics.
        drop(main_result);
    }

    /// Drops all worker channel senders, disconnecting each auxiliary
    /// thread's receiver so that its task loop exits.
    pub fn drop_threads(&self) {
        *self.threads.lock().unwrap_or_else(PoisonError::into_inner) = Default::default();
    }

    #[cfg(test)]
    fn aux_thread_count(&self) -> usize {
        self.threads.lock().unwrap_or_else(PoisonError::into_inner).len()
    }
}

/// Type-erased function and metadata.
#[derive(Clone, Copy)]
struct Task {
    shared: NonNull<TaskShared<()>>,
}

// SAFETY: Although `Task` holds a raw pointer, it may be sent and shared
// across threads: the pointee is never mutated (its `ref_count` is atomic),
// and the stored closure is required to be `Sync` where tasks are built.
unsafe impl Send for Task {}
unsafe impl Sync for Task {}

impl Task {
    /// Runs this task on behalf of `thread_id`.
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    ///
    /// - This task has not outlived the `TaskShared` it came from, or else
    ///   there will be a use-after-free.
    ///
    /// - `thread_id` is within the number of `broadcast` threads requested, so
    ///   that it can be used to index input or output buffers.
    #[inline]
    unsafe fn run(&self, thread_id: usize) {
        let shared_ptr: *mut TaskShared<()> = self.shared.as_ptr();
        let shared: &TaskShared<()> = &*shared_ptr;

        (shared.task_fn_ptr)(shared_ptr.cast(), thread_id);
    }
}

/// Data stored on the main thread that gets shared with auxiliary threads.
///
/// # Memory Layout
///
/// Since the benchmark may have thrashed the cache, this type's fields are
/// ordered by usage order. This type is also placed on its own cache line.
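///
/// Keeping `task_fn` last is what makes the type erasure in `Task` sound. A
/// sketch, assuming some concrete closure type `F` (illustrative only):
///
/// ```ignore
/// let shared: TaskShared<F> = TaskShared::new(aux_threads, f);
/// let erased: *const TaskShared<()> = (&shared as *const TaskShared<F>).cast();
/// // With `#[repr(C)]`, the fields before `task_fn` sit at the same offsets
/// // for every `F`, so the erased pointer can still reach `task_fn_ptr`:
/// unsafe { ((*erased).task_fn_ptr)(erased, 0) };
/// ```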
#[repr(C)]
struct TaskShared<F> {
    /// Once an auxiliary thread sets `ref_count` to 0, it should notify the
    /// main thread to wake up.
    main_thread: Thread,

    /// The number of auxiliary threads executing the task.
    ///
    /// Once this is 0, the main thread can read any results the task produced.
    ref_count: AtomicUsize,

    /// Performs `*result = Some(task_fn(thread))`.
    task_fn_ptr: unsafe fn(task: *const TaskShared<()>, thread: usize),

    /// Stores the closure state of the provided task.
    ///
    /// This must be stored as the last field so that all other fields are in
    /// the same place regardless of this field's type.
    task_fn: F,
}

impl<F> TaskShared<F> {
    #[inline]
    fn new(aux_threads: usize, task_fn: F) -> Self
    where
        F: Sync + Fn(usize),
    {
        unsafe fn call<F>(task: *const TaskShared<()>, thread: usize)
        where
            F: Fn(usize),
        {
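            // Restore the concrete type. This cast is the inverse of the one
            // in `broadcast`, and is sound because `#[repr(C)]` fixes the
            // offset of every field before `task_fn` for any `F`.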
            let task_fn: &F = &(*task.cast::<TaskShared<F>>()).task_fn;

            task_fn(thread);
        }

        Self {
            main_thread: std::thread::current(),
            ref_count: AtomicUsize::new(aux_threads),
            task_fn_ptr: call::<F>,
            task_fn,
        }
    }
}

/// Spawns N additional threads and appends their channels to the list.
///
/// Threads are given names in the form of `divan-$INDEX`.
#[cold]
fn spawn(additional: NonZeroUsize, threads: &mut Vec<mpsc::SyncSender<Task>>) {
    let next_thread_id = threads.len() + 1;

    threads.extend((next_thread_id..(next_thread_id + additional.get())).map(|thread_id| {
        // Create a single-task channel. Unless another benchmark is running,
        // the current thread will be immediately unblocked after the auxiliary
        // thread accepts the task.
        //
        // This uses a rendezvous channel (capacity 0) instead of other
        // standard library channels because it reduces memory usage by many
        // kilobytes.
        let (sender, receiver) = mpsc::sync_channel::<Task>(0);

        let work = move || {
            // Abort the process if the caught panic payload itself panics when
            // dropped.
            let panic_guard = defer(|| std::process::abort());

            while let Ok(task) = receiver.recv() {
                // Run the task on this auxiliary thread.
                //
                // SAFETY: The task is valid until `ref_count == 0`.
                let result =
                    std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { task.run(thread_id) }));

                // Decrement `ref_count` to notify the main thread that we
                // finished our work.
                //
                // SAFETY: This release operation makes writes within the task
                // become visible to the main thread.
                unsafe {
                    // Clone the main thread's handle for unparking because the
                    // `TaskShared` will be invalidated when `ref_count` is 0.
                    let main_thread = task.shared.as_ref().main_thread.clone();

                    if task.shared.as_ref().ref_count.fetch_sub(1, Ordering::Release) == 1 {
                        main_thread.unpark();
                    }
                }

                // Don't drop our result until after notifying the main thread,
                // in case the panic payload's drop handler itself also panics.
                drop(result);
            }

            // The channel disconnected, so the pool is shutting down; exit
            // normally by disarming the abort guard.
            std::mem::forget(panic_guard);
        };

        std::thread::Builder::new()
            .name(format!("divan-{thread_id}"))
            .spawn(work)
            .expect("failed to spawn thread");

        sender
    }));
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Make every thread write its ID to a buffer and then check that the
    /// buffer contains all IDs.
    #[test]
    fn extend() {
        static TEST_POOL: ThreadPool = ThreadPool::new();

        fn test(aux_threads: usize, final_aux_threads: usize) {
            let total_threads = aux_threads + 1;

            let mut results = Vec::new();
            let expected = (0..total_threads).map(Some).collect::<Vec<_>>();

            TEST_POOL.par_extend(&mut results, aux_threads, |index| index);

            assert_eq!(results, expected);
            assert_eq!(TEST_POOL.aux_thread_count(), final_aux_threads);
        }

        test(0, 0);
        test(1, 1);
        test(2, 2);
        test(3, 3);
        test(4, 4);
        test(8, 8);

        // Decreasing the auxiliary thread count on later calls should still
        // leave previously spawned threads running.
        test(4, 8);
        test(0, 8);

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }

    /// Execute a task that takes longer on all other threads than on the main
    /// thread.
    #[test]
    fn broadcast_sleep() {
        use std::time::Duration;

        static TEST_POOL: ThreadPool = ThreadPool::new();

        TEST_POOL.broadcast(10, |thread_id| {
            if thread_id > 0 {
                std::thread::sleep(Duration::from_millis(10));
            }
        });

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }

    /// Checks that thread ID 0 refers to the main thread.
    #[test]
    fn broadcast_thread_id() {
        static TEST_POOL: ThreadPool = ThreadPool::new();

        let main_thread = std::thread::current().id();

        TEST_POOL.broadcast(10, |thread_id| {
            let is_main = main_thread == std::thread::current().id();
            assert_eq!(is_main, thread_id == 0);
        });

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }
}

#[cfg(feature = "internal_benches")]
mod benches {
    use super::*;

    /// Yields the auxiliary thread counts to benchmark: 0..=16, plus
    /// `available_parallelism` when it falls outside that range.
    fn aux_thread_counts() -> impl Iterator<Item = usize> {
        let mut available_parallelism = std::thread::available_parallelism().ok().map(|n| n.get());

        let range = 0..=16;

        if let Some(n) = available_parallelism {
            if range.contains(&n) {
                available_parallelism = None;
            }
        }

        range.chain(available_parallelism)
    }

    /// Benchmarks repeatedly using `ThreadPool` with the same number of
    /// threads on every run.
    #[crate::bench(crate = crate, args = aux_thread_counts())]
    fn broadcast(bencher: crate::Bencher, aux_threads: usize) {
        let pool = ThreadPool::new();
        let benched = move || pool.broadcast(aux_threads, crate::black_box_drop);

        // Warm up to spawn the threads.
        benched();

        bencher.bench(benched);
    }

    /// Benchmarks using `ThreadPool` once.
    #[crate::bench(crate = crate, args = aux_thread_counts(), sample_size = 1)]
    fn broadcast_once(bencher: crate::Bencher, aux_threads: usize) {
        bencher
            .with_inputs(ThreadPool::new)
            .bench_refs(|pool| pool.broadcast(aux_threads, crate::black_box_drop));
    }
}