use std::{
    num::NonZeroUsize,
    panic::AssertUnwindSafe,
    ptr::NonNull,
    sync::{
        atomic::{AtomicUsize, Ordering},
        mpsc, Mutex, PoisonError,
    },
    thread::Thread,
};

use crate::util::{defer, sync::SyncWrap};

/// Single shared thread pool for running benchmarks on.
pub(crate) static BENCH_POOL: ThreadPool = ThreadPool::new();

/// Reusable threads for broadcasting tasks.
///
/// This thread pool runs only a single task at a time, since only one benchmark
/// should run at a time. Invoking `broadcast` from two threads will cause one
/// thread to wait for the other to finish.
///
/// # How It Works
///
/// Upon calling `broadcast`:
///
/// 1. The main thread creates a `Task`, which is a pointer to a `TaskShared`
///    pinned on the stack. `TaskShared` stores the function to run, along with
///    other fields for coordinating threads.
///
/// 2. New threads are spawned if the requested number is not yet available.
///    Each receives tasks over its own dedicated channel.
///
/// 3. The main thread sends the `Task` over the channels to the requested
///    number of threads. Upon receiving the task, each auxiliary thread
///    executes it and then decrements the task's reference count.
///
/// 4. The main thread executes the `Task` like the auxiliary threads do. It
///    then waits until the reference count is 0 before returning.
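///
/// A minimal usage sketch (hypothetical values; the closure receives a thread
/// ID, where 0 is the main thread):
///
/// ```ignore
/// BENCH_POOL.broadcast(3, |thread_id| {
///     // Runs on 4 threads total: the main thread plus 3 auxiliary threads.
///     println!("hello from thread {thread_id}");
/// });
/// ```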
pub(crate) struct ThreadPool {
    threads: Mutex<Vec<mpsc::SyncSender<Task>>>,
}

impl ThreadPool {
    const fn new() -> Self {
        Self { threads: Mutex::new(Vec::new()) }
    }

    /// Performs the given task across the main thread and `aux_threads`
    /// auxiliary threads, writing each thread's result into `vec`.
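    ///
    /// A sketch of the expected behavior, assuming `aux_threads = 3` (one
    /// `Some` entry is appended per thread, indexed by thread ID):
    ///
    /// ```ignore
    /// let mut results: Vec<Option<usize>> = Vec::new();
    /// BENCH_POOL.par_extend(&mut results, 3, |thread_id| thread_id * 10);
    /// assert_eq!(results, [Some(0), Some(10), Some(20), Some(30)]);
    /// ```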
    #[inline]
    pub fn par_extend<T, F>(&self, vec: &mut Vec<Option<T>>, aux_threads: usize, task: F)
    where
        F: Sync + Fn(usize) -> T,
        T: Sync + Send,
    {
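        // SAFETY: Every spare slot is initialized with `None` before
        // `set_len`, and each thread writes through `ptr` to a distinct index,
        // so no element is read uninitialized or aliased mutably.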
        unsafe {
            let old_len = vec.len();
            let additional = aux_threads + 1;

            vec.reserve_exact(additional);
            vec.spare_capacity_mut().iter_mut().for_each(|val| {
                val.write(None);
            });
            vec.set_len(old_len + additional);

            let ptr = SyncWrap::new(vec.as_mut_ptr().add(old_len));

            self.broadcast(aux_threads, move |index| {
                ptr.add(index).write(Some(task(index)));
            });
        }
    }

    /// Performs the given task across the current thread and auxiliary worker
    /// threads.
    ///
    /// This function returns once all threads complete the task.
    #[inline]
    pub fn broadcast<F>(&self, aux_threads: usize, task: F)
    where
        F: Sync + Fn(usize),
    {
        // SAFETY: The `TaskShared` instance is guaranteed to be accessible to
        // all threads until this function returns, because this thread waits
        // until `TaskShared.ref_count` is 0 before continuing.
        unsafe {
            let task = TaskShared::new(aux_threads, task);
            let task = Task { shared: NonNull::from(&task).cast() };

            self.broadcast_task(aux_threads, task);
        }
    }

    /// Type-erased implementation of `broadcast`, shared by all closure types
    /// to avoid monomorphization bloat.
    unsafe fn broadcast_task(&self, aux_threads: usize, task: Task) {
        // Send task to auxiliary threads.
        if aux_threads > 0 {
            let threads = &mut *self.threads.lock().unwrap_or_else(PoisonError::into_inner);

            // Spawn more threads if necessary.
            if let Some(additional) = NonZeroUsize::new(aux_threads.saturating_sub(threads.len())) {
                spawn(additional, threads);
            }

            for thread in &threads[..aux_threads] {
                thread.send(task).unwrap();
            }
        }

        // Run the task on the main thread.
        let main_result = std::panic::catch_unwind(AssertUnwindSafe(|| task.run(0)));

        // Wait for other threads to finish writing their results.
        //
        // SAFETY: The acquire memory ordering ensures that all writes performed
        // by the task on other threads will become visible to this thread after
        // returning from `broadcast`.
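        //
        // A missed wakeup is impossible here: if an auxiliary thread's
        // `unpark` runs before this thread parks, it stores a token that makes
        // the next call to `park` return immediately.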
        while task.shared.as_ref().ref_count.load(Ordering::Acquire) > 0 {
            std::thread::park();
        }

        // Don't drop our result until other threads finish, in case the panic
        // error's drop handler itself also panics.
        drop(main_result);
    }

    pub fn drop_threads(&self) {
        *self.threads.lock().unwrap_or_else(PoisonError::into_inner) = Default::default();
    }

    #[cfg(test)]
    fn aux_thread_count(&self) -> usize {
        self.threads.lock().unwrap_or_else(PoisonError::into_inner).len()
    }
}

/// Type-erased function and metadata.
#[derive(Clone, Copy)]
struct Task {
    shared: NonNull<TaskShared<()>>,
}

unsafe impl Send for Task {}
unsafe impl Sync for Task {}

impl Task {
    /// Runs this task on behalf of `thread_id`.
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    ///
    /// - This task has not outlived the `TaskShared` it came from, or else
    ///   there will be a use-after-free.
    ///
    /// - `thread_id` is within the number of `broadcast` threads requested, so
    ///   that it can be used to index input or output buffers.
    #[inline]
    unsafe fn run(&self, thread_id: usize) {
        let shared_ptr: *mut TaskShared<()> = self.shared.as_ptr();
        let shared: &TaskShared<()> = &*shared_ptr;

        (shared.task_fn_ptr)(shared_ptr.cast(), thread_id);
    }
}

/// Data stored on the main thread that gets shared with auxiliary threads.
///
/// # Memory Layout
///
/// Since the benchmark may have thrashed the cache, this type's fields are
/// ordered by when they are first used. This type is also placed on its own
/// cache line.
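///
/// The `#[repr(C)]` layout is what makes the type erasure in `broadcast`
/// sound: a pointer to `TaskShared<F>` can be cast to `TaskShared<()>` and the
/// leading fields still line up. A sketch of the round trip:
///
/// ```ignore
/// let shared = TaskShared::new(aux_threads, task_fn);
/// let erased: NonNull<TaskShared<()>> = NonNull::from(&shared).cast();
/// // `task_fn_ptr` casts back to `*const TaskShared<F>` before reading
/// // the `task_fn` field.
/// ```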
#[repr(C)]
struct TaskShared<F> {
    /// Handle to the main thread. The auxiliary thread that decrements
    /// `ref_count` to 0 uses it to wake the main thread via `unpark`.
    main_thread: Thread,

    /// The number of auxiliary threads executing the task.
    ///
    /// Once this is 0, the main thread can read any results the task produced.
    ref_count: AtomicUsize,

    /// Downcasts `task` to its concrete `TaskShared<F>` type and calls its
    /// `task_fn` with the given thread ID.
    task_fn_ptr: unsafe fn(task: *const TaskShared<()>, thread: usize),

    /// Stores the closure state of the provided task.
    ///
    /// This must be stored as the last field so that all other fields are in
    /// the same place regardless of this field's type.
    task_fn: F,
}

impl<F> TaskShared<F> {
    #[inline]
    fn new(aux_threads: usize, task_fn: F) -> Self
    where
        F: Sync + Fn(usize),
    {
        unsafe fn call<F>(task: *const TaskShared<()>, thread: usize)
        where
            F: Fn(usize),
        {
            let task_fn: &F = &(*task.cast::<TaskShared<F>>()).task_fn;

            task_fn(thread);
        }

        Self {
            main_thread: std::thread::current(),
            ref_count: AtomicUsize::new(aux_threads),
            task_fn_ptr: call::<F>,
            task_fn,
        }
    }
}

/// Spawns N additional threads and appends their channels to the list.
///
/// Threads are given names in the form of `divan-$INDEX`.
#[cold]
fn spawn(additional: NonZeroUsize, threads: &mut Vec<mpsc::SyncSender<Task>>) {
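    // Thread ID 0 is reserved for the main thread, so auxiliary thread IDs
    // start at 1.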
    let next_thread_id = threads.len() + 1;

    threads.extend((next_thread_id..(next_thread_id + additional.get())).map(|thread_id| {
        // Create single-task channel. Unless another benchmark is running, the
        // current thread will be immediately unblocked after the auxiliary
        // thread accepts the task.
        //
        // This uses a rendezvous channel (capacity 0) instead of other standard
        // library channels because it reduces memory usage by many kilobytes.
        let (sender, receiver) = mpsc::sync_channel::<Task>(0);

        let work = move || {
            // Abort the process if the caught panic error itself panics when
            // dropped.
            let panic_guard = defer(|| std::process::abort());

            while let Ok(task) = receiver.recv() {
                // Run the task on this auxiliary thread.
                //
                // SAFETY: The task is valid until `ref_count == 0`.
                let result =
                    std::panic::catch_unwind(AssertUnwindSafe(|| unsafe { task.run(thread_id) }));

                // Decrement `ref_count` to notify the main thread that we
                // finished our work.
                //
                // SAFETY: The `TaskShared` is still alive here because
                // `ref_count` has not yet reached 0. The release ordering makes
                // this thread's writes visible to the main thread's acquire
                // load of `ref_count`.
                unsafe {
                    // Clone the main thread's handle for unparking because the
                    // `TaskShared` will be invalidated when `ref_count` is 0.
                    let main_thread = task.shared.as_ref().main_thread.clone();

                    if task.shared.as_ref().ref_count.fetch_sub(1, Ordering::Release) == 1 {
                        main_thread.unpark();
                    }
                }

                // Don't drop our result until after notifying the main thread,
                // in case the panic error's drop handler itself also panics.
                drop(result);
            }

            std::mem::forget(panic_guard);
        };

        std::thread::Builder::new()
            .name(format!("divan-{thread_id}"))
            .spawn(work)
            .expect("failed to spawn thread");

        sender
    }));
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Make every thread write its ID to a buffer and then check that the
    /// buffer contains all IDs.
    #[test]
    fn extend() {
        static TEST_POOL: ThreadPool = ThreadPool::new();

        fn test(aux_threads: usize, final_aux_threads: usize) {
            let total_threads = aux_threads + 1;

            let mut results = Vec::new();
            let expected = (0..total_threads).map(Some).collect::<Vec<_>>();

            TEST_POOL.par_extend(&mut results, aux_threads, |index| index);

            assert_eq!(results, expected);
            assert_eq!(TEST_POOL.aux_thread_count(), final_aux_threads);
        }

        test(0, 0);
        test(1, 1);
        test(2, 2);
        test(3, 3);
        test(4, 4);
        test(8, 8);

        // Decreasing auxiliary threads on later calls should still leave
        // previously spawned threads running.
        test(4, 8);
        test(0, 8);

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }

    /// Execute a task that takes longer on all other threads than the main
    /// thread.
    #[test]
    fn broadcast_sleep() {
        use std::time::Duration;

        static TEST_POOL: ThreadPool = ThreadPool::new();

        TEST_POOL.broadcast(10, |thread_id| {
            if thread_id > 0 {
                std::thread::sleep(Duration::from_millis(10));
            }
        });

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }

    /// Checks that thread ID 0 refers to the main thread.
    #[test]
    fn broadcast_thread_id() {
        static TEST_POOL: ThreadPool = ThreadPool::new();

        let main_thread = std::thread::current().id();

        TEST_POOL.broadcast(10, |thread_id| {
            let is_main = main_thread == std::thread::current().id();
            assert_eq!(is_main, thread_id == 0);
        });

        // Silence Miri about leaking threads.
        TEST_POOL.drop_threads();
    }
}

#[cfg(feature = "internal_benches")]
mod benches {
    use super::*;

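    /// Auxiliary thread counts to benchmark: 0..=16, plus the machine's
    /// available parallelism if it falls outside that range.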
    fn aux_thread_counts() -> impl Iterator<Item = usize> {
        let mut available_parallelism = std::thread::available_parallelism().ok().map(|n| n.get());

        let range = 0..=16;

        if let Some(n) = available_parallelism {
            if range.contains(&n) {
                available_parallelism = None;
            }
        }

        range.chain(available_parallelism)
    }

    /// Benchmarks repeatedly using `ThreadPool` for the same number of threads
    /// on every run.
    #[crate::bench(crate = crate, args = aux_thread_counts())]
    fn broadcast(bencher: crate::Bencher, aux_threads: usize) {
        let pool = ThreadPool::new();
        let benched = move || pool.broadcast(aux_threads, crate::black_box_drop);

        // Warmup to spawn threads.
        benched();

        bencher.bench(benched);
    }

    /// Benchmarks using `ThreadPool` once.
    #[crate::bench(crate = crate, args = aux_thread_counts(), sample_size = 1)]
    fn broadcast_once(bencher: crate::Bencher, aux_threads: usize) {
        bencher
            .with_inputs(ThreadPool::new)
            .bench_refs(|pool| pool.broadcast(aux_threads, crate::black_box_drop));
    }
}