use std::{
    num::NonZeroUsize,
    panic::AssertUnwindSafe,
    ptr::NonNull,
    sync::{
        atomic::{AtomicUsize, Ordering},
        mpsc, Mutex, PoisonError,
    },
    thread::Thread,
};

use crate::util::{defer, sync::SyncWrap};

/// Single shared thread pool for running benchmarks on.
pub(crate) static BENCH_POOL: ThreadPool = ThreadPool::new();

/// Reusable threads for broadcasting tasks.
///
/// This thread pool runs only a single task at a time, since only one benchmark
/// should run at a time. Invoking `broadcast` from two threads will cause one
/// thread to wait for the other to finish.
///
/// # How It Works
///
/// Upon calling `broadcast`:
///
/// 1. The main thread creates a `Task`, which is a pointer to a `TaskShared`
///    pinned on the stack. `TaskShared` stores the function to run, along with
///    other fields for coordinating threads.
///
/// 2. New threads are spawned if the requested number is not available. Each
///    receives tasks over an associated channel.
///
/// 3. The main thread sends the `Task` over the channels to the requested
///    number of threads. Upon receiving the task, each auxiliary thread will
///    execute it and then decrement the task's reference count.
///
/// 4. The main thread executes the `Task` just as the auxiliary threads do. It
///    then waits until the reference count is 0 before returning.
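///
/// A hypothetical usage sketch (not a doctest, since this type is
/// crate-private); the closure runs on the main thread (ID 0) and on each of
/// the 3 requested auxiliary threads (IDs 1 through 3):
///
/// ```ignore
/// BENCH_POOL.broadcast(3, |thread_id| {
///     println!("hello from thread {thread_id}");
/// });
/// ```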
pub(crate) struct ThreadPool {
    threads: Mutex<Vec<mpsc::SyncSender<Task>>>,
}

impl ThreadPool {
    const fn new() -> Self {
        Self { threads: Mutex::new(Vec::new()) }
    }

    /// Performs the given task and pushes the results into `vec`.
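    ///
    /// A hypothetical sketch of the resulting layout: one `Some` entry is
    /// appended per participating thread, indexed by thread ID.
    ///
    /// ```ignore
    /// let mut results = Vec::new();
    /// BENCH_POOL.par_extend(&mut results, 3, |thread_id| thread_id * 2);
    /// assert_eq!(results, vec![Some(0), Some(2), Some(4), Some(6)]);
    /// ```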
    #[inline]
    pub fn par_extend<T, F>(&self, vec: &mut Vec<Option<T>>, aux_threads: usize, task: F)
    where
        F: Sync + Fn(usize) -> T,
        T: Sync + Send,
    {
        unsafe {
            let old_len = vec.len();
            let additional = aux_threads + 1;

            vec.reserve_exact(additional);
            vec.spare_capacity_mut().iter_mut().for_each(|val| {
                val.write(None);
            });
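            // SAFETY: Every new element up to `old_len + additional` was just
            // initialized to `None` above, so extending the length is sound.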
            vec.set_len(old_len + additional);

            let ptr = SyncWrap::new(vec.as_mut_ptr().add(old_len));

            self.broadcast(aux_threads, move |index| {
                ptr.add(index).write(Some(task(index)));
            });
        }
    }

    /// Performs the given task across the current thread and auxiliary worker
    /// threads. The task receives a thread index, where 0 denotes the current
    /// (main) thread.
    ///
    /// This function returns once all threads complete the task.
    #[inline]
    pub fn broadcast<F>(&self, aux_threads: usize, task: F)
    where
        F: Sync + Fn(usize),
    {
        // SAFETY: The `TaskShared` instance is guaranteed to be accessible to
        // all threads until this function returns, because this thread waits
        // until `TaskShared.ref_count` is 0 before continuing.
        unsafe {
            let task = TaskShared::new(aux_threads, task);
            let task = Task { shared: NonNull::from(&task).cast() };

            self.broadcast_task(aux_threads, task);
        }
    }

    /// Type-erased, monomorphic implementation of `broadcast`, kept separate
    /// so the broadcast machinery is not re-instantiated for every closure
    /// type.
    unsafe fn broadcast_task(&self, aux_threads: usize, task: Task) {
        // Send task to auxiliary threads.
        if aux_threads > 0 {
            let threads = &mut *self.threads.lock().unwrap_or_else(PoisonError::into_inner);

            // Spawn more threads if necessary.
            if let Some(additional) = NonZeroUsize::new(aux_threads.saturating_sub(threads.len())) {
                spawn(additional, threads);
            }

            for thread in &threads[..aux_threads] {
                thread.send(task).unwrap();
            }
        }

        // Run the task on the main thread.
        let main_result = std::panic::catch_unwind(AssertUnwindSafe(|| task.run(0)));

        // Wait for other threads to finish writing their results.
        //
        // SAFETY: The acquire memory ordering ensures that all writes performed
        // by the task on other threads will become visible to this thread after
        // returning from `broadcast`.
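        //
        // `park` may also return spuriously, so the reference count is
        // re-checked in a loop rather than parking exactly once.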
        while task.shared.as_ref().ref_count.load(Ordering::Acquire) > 0 {
            std::thread::park();
        }

        // Don't drop our result until other threads finish, in case the panic
        // error's drop handler itself also panics.
        drop(main_result);
    }

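    /// Drops all task senders, causing each auxiliary thread to exit its
    /// receive loop and terminate once it finishes any in-flight task.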
    pub fn drop_threads(&self) {
        *self.threads.lock().unwrap_or_else(PoisonError::into_inner) = Default::default();
    }

    #[cfg(test)]
    fn aux_thread_count(&self) -> usize {
        self.threads.lock().unwrap_or_else(PoisonError::into_inner).len()
    }
}

/// Type-erased function and metadata.
#[derive(Clone, Copy)]
struct Task {
    shared: NonNull<TaskShared<()>>,
}

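// SAFETY: `Task` is just a pointer to a `TaskShared`, whose cross-thread state
// is safe to share: `ref_count` is atomic, `Thread` is `Send + Sync`, and the
// closure is required to be `Sync`.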
unsafe impl Send for Task {}
unsafe impl Sync for Task {}

impl Task {
    /// Runs this task on behalf of `thread_id`.
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    ///
    /// - This task has not outlived the `TaskShared` it came from, or else
    ///   there will be a use-after-free.
    ///
    /// - `thread_id` is within the number of `broadcast` threads requested, so
    ///   that it can be used to index input or output buffers.
    #[inline]
    unsafe fn run(&self, thread_id: usize) {
        let shared_ptr: *mut TaskShared<()> = self.shared.as_ptr();
        let shared: &TaskShared<()> = &*shared_ptr;

        (shared.task_fn_ptr)(shared_ptr.cast(), thread_id);
    }
}

/// Data stored on the main thread that gets shared with auxiliary threads.
///
/// # Memory Layout
///
/// Since the benchmark may have thrashed the cache, this type's fields are
/// ordered by usage order. This type is also placed on its own cache line.
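///
/// Keeping `task_fn` last under `#[repr(C)]` is also what makes the type
/// erasure in `Task` sound: every other field sits at the same offset
/// regardless of `F`. A hypothetical sketch of that cast:
///
/// ```ignore
/// let shared = TaskShared::new(0, |_| {});
/// // Erase the closure type; the shared fields keep their offsets.
/// let erased: *const TaskShared<()> = (&shared as *const TaskShared<_>).cast();
/// ```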
#[repr(C)]
struct TaskShared<F> {
    /// Once an auxiliary thread sets `ref_count` to 0, it should notify the
    /// main thread to wake up.
    main_thread: Thread,

    /// The number of auxiliary threads executing the task.
    ///
    /// Once this is 0, the main thread can read any results the task produced.
    ref_count: AtomicUsize,

    /// Calls `task_fn(thread)` through the type-erased `task` pointer.
    task_fn_ptr: unsafe fn(task: *const TaskShared<()>, thread: usize),

    /// Stores the closure state of the provided task.
    ///
    /// This must be stored as the last field so that all other fields are in
    /// the same place regardless of this field's type.
    task_fn: F,
}

impl<F> TaskShared<F> {
    #[inline]
    fn new(aux_threads: usize, task_fn: F) -> Self
    where
        F: Sync + Fn(usize),
    {
        unsafe fn call<F>(task: *const TaskShared<()>, thread: usize)
        where
            F: Fn(usize),
        {
            let task_fn: &F = &(*task.cast::<TaskShared<F>>()).task_fn;

            task_fn(thread);
        }

        Self {
            main_thread: std::thread::current(),
            ref_count: AtomicUsize::new(aux_threads),
            task_fn_ptr: call::<F>,
            task_fn,
        }
    }
}

/// Spawns N additional threads and appends their channels to the list.
///
/// Threads are given names in the form of `divan-$INDEX`.
#[cold]
fn spawn(additional: NonZeroUsize, threads: &mut Vec<mpsc::SyncSender<Task>>) {
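    // Thread ID 0 refers to the main thread, so auxiliary thread IDs start at 1.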
    let next_thread_id = threads.len() + 1;

    threads.extend((next_thread_id..(next_thread_id + additional.get())).map(|thread_id| {
        // Create single-task channel. Unless another benchmark is running, the
        // current thread will be immediately unblocked after the auxiliary
        // thread accepts the task.
        //
        // This uses a rendezvous channel (capacity 0) instead of other standard
        // library channels because it reduces memory usage by many kilobytes.
        let (sender, receiver) = mpsc::sync_channel::<Task>(0);

        let work = move || {
            // Abort the process if the caught panic error itself panics when
            // dropped.
            let panic_guard = defer(|| std::process::abort());

            while let Ok(task) = receiver.recv() {
                // Run the task on this auxiliary thread.
                //
                // SAFETY: The task is valid until `ref_count == 0`.
                let result =
                    std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { task.run(thread_id) }));

                // Decrement `ref_count` to notify the main thread that we
                // finished our work.
                //
                // SAFETY: This release operation makes writes within the task
                // become visible to the main thread.
                unsafe {
                    // Clone the main thread's handle for unparking because the
                    // `TaskShared` will be invalidated when `ref_count` is 0.
                    let main_thread = task.shared.as_ref().main_thread.clone();

                    if task.shared.as_ref().ref_count.fetch_sub(1, Ordering::Release) == 1 {
                        main_thread.unpark();
                    }
                }

                // Don't drop our result until after notifying the main thread,
                // in case the panic error's drop handler itself also panics.
                drop(result);
            }

            std::mem::forget(panic_guard);
        };

        std::thread::Builder::new()
            .name(format!("divan-{thread_id}"))
            .spawn(work)
            .expect("failed to spawn thread");

        sender
    }));
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Make every thread write its ID to a buffer and then check that the
    /// buffer contains all IDs.
    #[test]
    fn extend() {
        static TEST_POOL: ThreadPool = ThreadPool::new();

        fn test(aux_threads: usize, final_aux_threads: usize) {
            let total_threads = aux_threads + 1;

            let mut results = Vec::new();
            let expected = (0..total_threads).map(Some).collect::<Vec<_>>();

            TEST_POOL.par_extend(&mut results, aux_threads, |index| index);

            assert_eq!(results, expected);
            assert_eq!(TEST_POOL.aux_thread_count(), final_aux_threads);
        }

        test(0, 0);
        test(1, 1);
        test(2, 2);
        test(3, 3);
        test(4, 4);
        test(8, 8);

        // Decreasing auxiliary threads on later calls should still leave
        // previously spawned threads running.
        test(4, 8);
        test(0, 8);

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }

    /// Execute a task that takes longer on all other threads than the main
    /// thread.
    #[test]
    fn broadcast_sleep() {
        use std::time::Duration;

        static TEST_POOL: ThreadPool = ThreadPool::new();

        TEST_POOL.broadcast(10, |thread_id| {
            if thread_id > 0 {
                std::thread::sleep(Duration::from_millis(10));
            }
        });

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }

    /// Checks that thread ID 0 refers to the main thread.
    #[test]
    fn broadcast_thread_id() {
        static TEST_POOL: ThreadPool = ThreadPool::new();

        let main_thread = std::thread::current().id();

        TEST_POOL.broadcast(10, |thread_id| {
            let is_main = main_thread == std::thread::current().id();
            assert_eq!(is_main, thread_id == 0);
        });

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }
}

#[cfg(feature = "internal_benches")]
mod benches {
    use super::*;

    fn aux_thread_counts() -> impl Iterator<Item = usize> {
        let mut available_parallelism = std::thread::available_parallelism().ok().map(|n| n.get());

        let range = 0..=16;

        if let Some(n) = available_parallelism {
            if range.contains(&n) {
                available_parallelism = None;
            }
        }

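        // E.g. with 24 hardware threads this yields 0..=16 followed by 24; with
        // 8, it yields just 0..=16, since 8 is already covered by the range.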
        range.chain(available_parallelism)
    }

    /// Benchmarks repeatedly using `ThreadPool` for the same number of threads
    /// on every run.
    #[crate::bench(crate = crate, args = aux_thread_counts())]
    fn broadcast(bencher: crate::Bencher, aux_threads: usize) {
        let pool = ThreadPool::new();
        let benched = move || pool.broadcast(aux_threads, crate::black_box_drop);

        // Warmup to spawn threads.
        benched();

        bencher.bench(benched);
    }

    /// Benchmarks using `ThreadPool` once.
    #[crate::bench(crate = crate, args = aux_thread_counts(), sample_size = 1)]
    fn broadcast_once(bencher: crate::Bencher, aux_threads: usize) {
        bencher
            .with_inputs(ThreadPool::new)
            .bench_refs(|pool| pool.broadcast(aux_threads, crate::black_box_drop));
    }
}