1 | use crate::job::{JobFifo, JobRef, StackJob}; |
2 | use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch}; |
3 | use crate::sleep::Sleep; |
4 | use crate::unwind; |
5 | use crate::{ |
6 | ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, |
7 | Yield, |
8 | }; |
9 | use crossbeam_deque::{Injector, Steal, Stealer, Worker}; |
10 | use std::cell::Cell; |
11 | use std::collections::hash_map::DefaultHasher; |
12 | use std::fmt; |
13 | use std::hash::Hasher; |
14 | use std::io; |
15 | use std::mem; |
16 | use std::ptr; |
17 | use std::sync::atomic::{AtomicUsize, Ordering}; |
18 | use std::sync::{Arc, Mutex, Once}; |
19 | use std::thread; |
21 | |
22 | /// Thread builder used for customization via |
23 | /// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler). |
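///
/// A minimal sketch of a custom spawn handler, mirroring what `DefaultSpawn`
/// does below (the builder chain and error handling here are illustrative,
/// not prescribed):
///
/// ```ignore
/// let pool = ThreadPoolBuilder::new()
///     .spawn_handler(|thread: ThreadBuilder| {
///         let mut b = std::thread::Builder::new();
///         if let Some(name) = thread.name() {
///             b = b.name(name.to_owned());
///         }
///         if let Some(stack_size) = thread.stack_size() {
///             b = b.stack_size(stack_size);
///         }
///         b.spawn(|| thread.run())?;
///         Ok(())
///     })
///     .build()?;
/// ```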
24 | pub struct ThreadBuilder { |
25 | name: Option<String>, |
26 | stack_size: Option<usize>, |
27 | worker: Worker<JobRef>, |
28 | stealer: Stealer<JobRef>, |
29 | registry: Arc<Registry>, |
30 | index: usize, |
31 | } |
32 | |
33 | impl ThreadBuilder { |
34 | /// Gets the index of this thread in the pool, within `0..num_threads`. |
35 | pub fn index(&self) -> usize { |
36 | self.index |
37 | } |
38 | |
39 | /// Gets the string that was specified by `ThreadPoolBuilder::name()`. |
40 | pub fn name(&self) -> Option<&str> { |
41 | self.name.as_deref() |
42 | } |
43 | |
44 | /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`. |
45 | pub fn stack_size(&self) -> Option<usize> { |
46 | self.stack_size |
47 | } |
48 | |
49 | /// Executes the main loop for this thread. This will not return until the |
50 | /// thread pool is dropped. |
51 | pub fn run(self) { |
52 | unsafe { main_loop(self) } |
53 | } |
54 | } |
55 | |
56 | impl fmt::Debug for ThreadBuilder { |
57 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
58 | f.debug_struct("ThreadBuilder" ) |
59 | .field("pool" , &self.registry.id()) |
60 | .field("index" , &self.index) |
61 | .field("name" , &self.name) |
62 | .field("stack_size" , &self.stack_size) |
63 | .finish() |
64 | } |
65 | } |
66 | |
67 | /// Generalized trait for spawning a thread in the `Registry`. |
68 | /// |
69 | /// This trait is pub-in-private -- E0445 forces us to make it public, |
70 | /// but we don't actually want to expose these details in the API. |
71 | pub trait ThreadSpawn { |
72 | private_decl! {} |
73 | |
74 | /// Spawn a thread with the `ThreadBuilder` parameters, and then |
75 | /// call `ThreadBuilder::run()`. |
76 | fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>; |
77 | } |
78 | |
79 | /// Spawns a thread in the "normal" way with `std::thread::Builder`. |
80 | /// |
81 | /// This type is pub-in-private -- E0445 forces us to make it public, |
82 | /// but we don't actually want to expose these details in the API. |
83 | #[derive(Debug, Default)] |
84 | pub struct DefaultSpawn; |
85 | |
86 | impl ThreadSpawn for DefaultSpawn { |
87 | private_impl! {} |
88 | |
89 | fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> { |
90 | let mut b = thread::Builder::new(); |
91 | if let Some(name) = thread.name() { |
92 | b = b.name(name.to_owned()); |
93 | } |
94 | if let Some(stack_size) = thread.stack_size() { |
95 | b = b.stack_size(stack_size); |
96 | } |
97 | b.spawn(|| thread.run())?; |
98 | Ok(()) |
99 | } |
100 | } |
101 | |
102 | /// Spawns a thread with a user's custom callback. |
103 | /// |
104 | /// This type is pub-in-private -- E0445 forces us to make it public, |
105 | /// but we don't actually want to expose these details in the API. |
106 | #[derive(Debug)] |
107 | pub struct CustomSpawn<F>(F); |
108 | |
109 | impl<F> CustomSpawn<F> |
110 | where |
111 | F: FnMut(ThreadBuilder) -> io::Result<()>, |
112 | { |
113 | pub(super) fn new(spawn: F) -> Self { |
114 | CustomSpawn(spawn) |
115 | } |
116 | } |
117 | |
118 | impl<F> ThreadSpawn for CustomSpawn<F> |
119 | where |
120 | F: FnMut(ThreadBuilder) -> io::Result<()>, |
121 | { |
122 | private_impl! {} |
123 | |
    #[inline]
125 | fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> { |
126 | (self.0)(thread) |
127 | } |
128 | } |
129 | |
130 | pub(super) struct Registry { |
131 | thread_infos: Vec<ThreadInfo>, |
132 | sleep: Sleep, |
133 | injected_jobs: Injector<JobRef>, |
134 | broadcasts: Mutex<Vec<Worker<JobRef>>>, |
135 | panic_handler: Option<Box<PanicHandler>>, |
136 | start_handler: Option<Box<StartHandler>>, |
137 | exit_handler: Option<Box<ExitHandler>>, |
138 | |
    // When this count reaches 0, it means that all work on this
140 | // registry must be complete. This is ensured in the following ways: |
141 | // |
142 | // - if this is the global registry, there is a ref-count that never |
143 | // gets released. |
144 | // - if this is a user-created thread-pool, then so long as the thread-pool |
145 | // exists, it holds a reference. |
146 | // - when we inject a "blocking job" into the registry with `ThreadPool::install()`, |
147 | // no adjustment is needed; the `ThreadPool` holds the reference, and since we won't |
148 | // return until the blocking job is complete, that ref will continue to be held. |
149 | // - when `join()` or `scope()` is invoked, similarly, no adjustments are needed. |
150 | // These are always owned by some other job (e.g., one injected by `ThreadPool::install()`) |
151 | // and that job will keep the pool alive. |
152 | terminate_count: AtomicUsize, |
153 | } |
154 | |
155 | /// //////////////////////////////////////////////////////////////////////// |
156 | /// Initialization |
157 | |
158 | static mut THE_REGISTRY: Option<Arc<Registry>> = None; |
159 | static THE_REGISTRY_SET: Once = Once::new(); |
160 | |
161 | /// Starts the worker threads (if that has not already happened). If |
162 | /// initialization has not already occurred, use the default |
163 | /// configuration. |
164 | pub(super) fn global_registry() -> &'static Arc<Registry> { |
165 | set_global_registry(default_global_registry) |
166 | .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) }) |
167 | .expect("The global thread pool has not been initialized." ) |
168 | } |
169 | |
170 | /// Starts the worker threads (if that has not already happened) with |
171 | /// the given builder. |
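///
/// From the public API this is reached via `ThreadPoolBuilder::build_global()`;
/// a rough usage sketch:
///
/// ```ignore
/// ThreadPoolBuilder::new().num_threads(8).build_global()?;
/// ```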
172 | pub(super) fn init_global_registry<S>( |
173 | builder: ThreadPoolBuilder<S>, |
174 | ) -> Result<&'static Arc<Registry>, ThreadPoolBuildError> |
175 | where |
176 | S: ThreadSpawn, |
177 | { |
178 | set_global_registry(|| Registry::new(builder)) |
179 | } |
180 | |
181 | /// Starts the worker threads (if that has not already happened) |
182 | /// by creating a registry with the given callback. |
183 | fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError> |
184 | where |
185 | F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>, |
186 | { |
187 | let mut result = Err(ThreadPoolBuildError::new( |
188 | ErrorKind::GlobalPoolAlreadyInitialized, |
189 | )); |
190 | |
191 | THE_REGISTRY_SET.call_once(|| { |
192 | result = registry() |
193 | .map(|registry: Arc<Registry>| unsafe { &*THE_REGISTRY.get_or_insert(registry) }) |
194 | }); |
195 | |
196 | result |
197 | } |
198 | |
199 | fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> { |
200 | let result = Registry::new(ThreadPoolBuilder::new()); |
201 | |
202 | // If we're running in an environment that doesn't support threads at all, we can fall back to |
203 | // using the current thread alone. This is crude, and probably won't work for non-blocking |
    // calls like `spawn` or `spawn_broadcast`, but a lot of stuff does work fine.
205 | // |
206 | // Notably, this allows current WebAssembly targets to work even though their threading support |
207 | // is stubbed out, and we won't have to change anything if they do add real threading. |
208 | let unsupported = matches!(&result, Err(e) if e.is_unsupported()); |
209 | if unsupported && WorkerThread::current().is_null() { |
210 | let builder = ThreadPoolBuilder::new().num_threads(1).use_current_thread(); |
211 | let fallback_result = Registry::new(builder); |
212 | if fallback_result.is_ok() { |
213 | return fallback_result; |
214 | } |
215 | } |
216 | |
217 | result |
218 | } |
219 | |
220 | struct Terminator<'a>(&'a Arc<Registry>); |
221 | |
222 | impl<'a> Drop for Terminator<'a> { |
223 | fn drop(&mut self) { |
224 | self.0.terminate() |
225 | } |
226 | } |
227 | |
228 | impl Registry { |
229 | pub(super) fn new<S>( |
230 | mut builder: ThreadPoolBuilder<S>, |
231 | ) -> Result<Arc<Self>, ThreadPoolBuildError> |
232 | where |
233 | S: ThreadSpawn, |
234 | { |
235 | // Soft-limit the number of threads that we can actually support. |
236 | let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads()); |
237 | |
238 | let breadth_first = builder.get_breadth_first(); |
239 | |
240 | let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads) |
241 | .map(|_| { |
242 | let worker = if breadth_first { |
243 | Worker::new_fifo() |
244 | } else { |
245 | Worker::new_lifo() |
246 | }; |
247 | |
248 | let stealer = worker.stealer(); |
249 | (worker, stealer) |
250 | }) |
251 | .unzip(); |
252 | |
253 | let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads) |
254 | .map(|_| { |
255 | let worker = Worker::new_fifo(); |
256 | let stealer = worker.stealer(); |
257 | (worker, stealer) |
258 | }) |
259 | .unzip(); |
260 | |
261 | let registry = Arc::new(Registry { |
262 | thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(), |
263 | sleep: Sleep::new(n_threads), |
264 | injected_jobs: Injector::new(), |
265 | broadcasts: Mutex::new(broadcasts), |
266 | terminate_count: AtomicUsize::new(1), |
267 | panic_handler: builder.take_panic_handler(), |
268 | start_handler: builder.take_start_handler(), |
269 | exit_handler: builder.take_exit_handler(), |
270 | }); |
271 | |
272 | // If we return early or panic, make sure to terminate existing threads. |
273 | let t1000 = Terminator(®istry); |
274 | |
275 | for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() { |
276 | let thread = ThreadBuilder { |
277 | name: builder.get_thread_name(index), |
278 | stack_size: builder.get_stack_size(), |
279 | registry: Arc::clone(®istry), |
280 | worker, |
281 | stealer, |
282 | index, |
283 | }; |
284 | |
285 | if index == 0 && builder.use_current_thread { |
286 | if !WorkerThread::current().is_null() { |
287 | return Err(ThreadPoolBuildError::new( |
288 | ErrorKind::CurrentThreadAlreadyInPool, |
289 | )); |
290 | } |
291 | // Rather than starting a new thread, we're just taking over the current thread |
292 | // *without* running the main loop, so we can still return from here. |
                // The WorkerThread is leaked, but we never shut down the global pool anyway.
294 | let worker_thread = Box::into_raw(Box::new(WorkerThread::from(thread))); |
295 | |
296 | unsafe { |
297 | WorkerThread::set_current(worker_thread); |
298 | Latch::set(®istry.thread_infos[index].primed); |
299 | } |
300 | continue; |
301 | } |
302 | |
303 | if let Err(e) = builder.get_spawn_handler().spawn(thread) { |
304 | return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e))); |
305 | } |
306 | } |
307 | |
308 | // Returning normally now, without termination. |
309 | mem::forget(t1000); |
310 | |
311 | Ok(registry) |
312 | } |
313 | |
314 | pub(super) fn current() -> Arc<Registry> { |
315 | unsafe { |
316 | let worker_thread = WorkerThread::current(); |
317 | let registry = if worker_thread.is_null() { |
318 | global_registry() |
319 | } else { |
320 | &(*worker_thread).registry |
321 | }; |
322 | Arc::clone(registry) |
323 | } |
324 | } |
325 | |
326 | /// Returns the number of threads in the current registry. This |
327 | /// is better than `Registry::current().num_threads()` because it |
328 | /// avoids incrementing the `Arc`. |
329 | pub(super) fn current_num_threads() -> usize { |
330 | unsafe { |
331 | let worker_thread = WorkerThread::current(); |
332 | if worker_thread.is_null() { |
333 | global_registry().num_threads() |
334 | } else { |
335 | (*worker_thread).registry.num_threads() |
336 | } |
337 | } |
338 | } |
339 | |
340 | /// Returns the current `WorkerThread` if it's part of this `Registry`. |
341 | pub(super) fn current_thread(&self) -> Option<&WorkerThread> { |
342 | unsafe { |
343 | let worker = WorkerThread::current().as_ref()?; |
344 | if worker.registry().id() == self.id() { |
345 | Some(worker) |
346 | } else { |
347 | None |
348 | } |
349 | } |
350 | } |
351 | |
352 | /// Returns an opaque identifier for this registry. |
353 | pub(super) fn id(&self) -> RegistryId { |
        // We can rely on the address of `self` not changing, since we only
        // ever create registries that are boxed up in an `Arc` (see `new()` above).
356 | RegistryId { |
357 | addr: self as *const Self as usize, |
358 | } |
359 | } |
360 | |
361 | pub(super) fn num_threads(&self) -> usize { |
362 | self.thread_infos.len() |
363 | } |
364 | |
365 | pub(super) fn catch_unwind(&self, f: impl FnOnce()) { |
366 | if let Err(err) = unwind::halt_unwinding(f) { |
367 | // If there is no handler, or if that handler itself panics, then we abort. |
368 | let abort_guard = unwind::AbortIfPanic; |
369 | if let Some(ref handler) = self.panic_handler { |
370 | handler(err); |
371 | mem::forget(abort_guard); |
372 | } |
373 | } |
374 | } |
375 | |
376 | /// Waits for the worker threads to get up and running. This is |
377 | /// meant to be used for benchmarking purposes, primarily, so that |
378 | /// you can get more consistent numbers by having everything |
379 | /// "ready to go". |
380 | pub(super) fn wait_until_primed(&self) { |
381 | for info in &self.thread_infos { |
382 | info.primed.wait(); |
383 | } |
384 | } |
385 | |
386 | /// Waits for the worker threads to stop. This is used for testing |
387 | /// -- so we can check that termination actually works. |
    #[cfg(test)]
389 | pub(super) fn wait_until_stopped(&self) { |
390 | for info in &self.thread_infos { |
391 | info.stopped.wait(); |
392 | } |
393 | } |
394 | |
395 | /// //////////////////////////////////////////////////////////////////////// |
396 | /// MAIN LOOP |
397 | /// |
398 | /// So long as all of the worker threads are hanging out in their |
399 | /// top-level loop, there is no work to be done. |
400 | |
    /// Push a job into this registry. If we are running on a
402 | /// worker thread for the registry, this will push onto the |
403 | /// deque. Else, it will inject from the outside (which is slower). |
404 | pub(super) fn inject_or_push(&self, job_ref: JobRef) { |
405 | let worker_thread = WorkerThread::current(); |
406 | unsafe { |
407 | if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() { |
408 | (*worker_thread).push(job_ref); |
409 | } else { |
410 | self.inject(job_ref); |
411 | } |
412 | } |
413 | } |
414 | |
415 | /// Push a job into the "external jobs" queue; it will be taken by |
416 | /// whatever worker has nothing to do. Use this if you know that |
417 | /// you are not on a worker of this registry. |
418 | pub(super) fn inject(&self, injected_job: JobRef) { |
        // It should not be possible for the terminate count to have reached
        // zero here. It only does so after the user creates (and then
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees terminate count of zero"
        );
429 | |
430 | let queue_was_empty = self.injected_jobs.is_empty(); |
431 | |
432 | self.injected_jobs.push(injected_job); |
433 | self.sleep.new_injected_jobs(1, queue_was_empty); |
434 | } |
435 | |
436 | fn has_injected_job(&self) -> bool { |
437 | !self.injected_jobs.is_empty() |
438 | } |
439 | |
440 | fn pop_injected_job(&self) -> Option<JobRef> { |
441 | loop { |
442 | match self.injected_jobs.steal() { |
443 | Steal::Success(job) => return Some(job), |
444 | Steal::Empty => return None, |
445 | Steal::Retry => {} |
446 | } |
447 | } |
448 | } |
449 | |
450 | /// Push a job into each thread's own "external jobs" queue; it will be |
451 | /// executed only on that thread, when it has nothing else to do locally, |
452 | /// before it tries to steal other work. |
453 | /// |
454 | /// **Panics** if not given exactly as many jobs as there are threads. |
455 | pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) { |
456 | assert_eq!(self.num_threads(), injected_jobs.len()); |
457 | { |
458 | let broadcasts = self.broadcasts.lock().unwrap(); |
459 | |
            // It should not be possible for the terminate count to have reached
            // zero here. It only does so after the user creates (and then
            // drops) a `ThreadPool`; and, in that case, they cannot be
            // calling `inject_broadcast()` later, since they dropped their
            // `ThreadPool`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees terminate count of zero"
            );
470 | |
471 | assert_eq!(broadcasts.len(), injected_jobs.len()); |
472 | for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) { |
473 | worker.push(job_ref); |
474 | } |
475 | } |
476 | for i in 0..self.num_threads() { |
477 | self.sleep.notify_worker_latch_is_set(i); |
478 | } |
479 | } |
480 | |
481 | /// If already in a worker-thread of this registry, just execute `op`. |
482 | /// Otherwise, inject `op` in this thread-pool. Either way, block until `op` |
483 | /// completes and return its return value. If `op` panics, that panic will |
484 | /// be propagated as well. The second argument indicates `true` if injection |
485 | /// was performed, `false` if executed directly. |
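    ///
    /// A sketch of typical use from inside this crate (the closure body and
    /// `do_work` are illustrative):
    ///
    /// ```ignore
    /// let result = registry.in_worker(|worker_thread, injected| {
    ///     // `injected` is true if the closure had to be queued and the caller
    ///     // blocked; false if we were already on a worker of this registry.
    ///     do_work(worker_thread)
    /// });
    /// ```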
486 | pub(super) fn in_worker<OP, R>(&self, op: OP) -> R |
487 | where |
488 | OP: FnOnce(&WorkerThread, bool) -> R + Send, |
489 | R: Send, |
490 | { |
491 | unsafe { |
492 | let worker_thread = WorkerThread::current(); |
493 | if worker_thread.is_null() { |
494 | self.in_worker_cold(op) |
495 | } else if (*worker_thread).registry().id() != self.id() { |
496 | self.in_worker_cross(&*worker_thread, op) |
497 | } else { |
498 | // Perfectly valid to give them a `&T`: this is the |
499 | // current thread, so we know the data structure won't be |
500 | // invalidated until we return. |
501 | op(&*worker_thread, false) |
502 | } |
503 | } |
504 | } |
505 | |
    #[cold]
507 | unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R |
508 | where |
509 | OP: FnOnce(&WorkerThread, bool) -> R + Send, |
510 | R: Send, |
511 | { |
512 | thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new()); |
513 | |
514 | LOCK_LATCH.with(|l| { |
515 | // This thread isn't a member of *any* thread pool, so just block. |
516 | debug_assert!(WorkerThread::current().is_null()); |
517 | let job = StackJob::new( |
518 | |injected| { |
519 | let worker_thread = WorkerThread::current(); |
520 | assert!(injected && !worker_thread.is_null()); |
521 | op(&*worker_thread, true) |
522 | }, |
523 | LatchRef::new(l), |
524 | ); |
525 | self.inject(job.as_job_ref()); |
526 | job.latch.wait_and_reset(); // Make sure we can use the same latch again next time. |
527 | |
528 | job.into_result() |
529 | }) |
530 | } |
531 | |
    #[cold]
533 | unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R |
534 | where |
535 | OP: FnOnce(&WorkerThread, bool) -> R + Send, |
536 | R: Send, |
537 | { |
538 | // This thread is a member of a different pool, so let it process |
539 | // other work while waiting for this `op` to complete. |
540 | debug_assert!(current_thread.registry().id() != self.id()); |
541 | let latch = SpinLatch::cross(current_thread); |
542 | let job = StackJob::new( |
543 | |injected| { |
544 | let worker_thread = WorkerThread::current(); |
545 | assert!(injected && !worker_thread.is_null()); |
546 | op(&*worker_thread, true) |
547 | }, |
548 | latch, |
549 | ); |
550 | self.inject(job.as_job_ref()); |
551 | current_thread.wait_until(&job.latch); |
552 | job.into_result() |
553 | } |
554 | |
555 | /// Increments the terminate counter. This increment should be |
556 | /// balanced by a call to `terminate`, which will decrement. This |
557 | /// is used when spawning asynchronous work, which needs to |
558 | /// prevent the registry from terminating so long as it is active. |
559 | /// |
560 | /// Note that blocking functions such as `join` and `scope` do not |
561 | /// need to concern themselves with this fn; their context is |
562 | /// responsible for ensuring the current thread-pool will not |
563 | /// terminate until they return. |
564 | /// |
565 | /// The global thread-pool always has an outstanding reference |
566 | /// (the initial one). Custom thread-pools have one outstanding |
567 | /// reference that is dropped when the `ThreadPool` is dropped: |
568 | /// since installing the thread-pool blocks until any joins/scopes |
569 | /// complete, this ensures that joins/scopes are covered. |
570 | /// |
571 | /// The exception is `::spawn()`, which can create a job outside |
572 | /// of any blocking scope. In that case, the job itself holds a |
573 | /// terminate count and is responsible for invoking `terminate()` |
574 | /// when finished. |
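    ///
    /// A sketch of the pairing that `spawn()` relies on (the actual code lives
    /// elsewhere in this crate; this is illustrative only):
    ///
    /// ```ignore
    /// registry.increment_terminate_count();
    /// // ... hand off the job; when it eventually completes ...
    /// registry.terminate(); // balances the increment above
    /// ```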
575 | pub(super) fn increment_terminate_count(&self) { |
576 | let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel); |
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(previous != usize::MAX, "overflow in registry ref count");
582 | } |
583 | |
584 | /// Signals that the thread-pool which owns this registry has been |
585 | /// dropped. The worker threads will gradually terminate, once any |
586 | /// extant work is completed. |
587 | pub(super) fn terminate(&self) { |
588 | if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 { |
589 | for (i, thread_info) in self.thread_infos.iter().enumerate() { |
590 | unsafe { OnceLatch::set_and_tickle_one(&thread_info.terminate, self, i) }; |
591 | } |
592 | } |
593 | } |
594 | |
595 | /// Notify the worker that the latch they are sleeping on has been "set". |
596 | pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) { |
597 | self.sleep.notify_worker_latch_is_set(target_worker_index); |
598 | } |
599 | } |
600 | |
601 | #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] |
602 | pub(super) struct RegistryId { |
603 | addr: usize, |
604 | } |
605 | |
606 | struct ThreadInfo { |
    /// Latch set once the thread has started and entered the main
    /// loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
610 | primed: LockLatch, |
611 | |
612 | /// Latch is set once worker thread has completed. Used to wait |
613 | /// until workers have stopped; only used for tests. |
614 | stopped: LockLatch, |
615 | |
    /// The latch used to signal that termination has been requested.
617 | /// This latch is *set* by the `terminate` method on the |
618 | /// `Registry`, once the registry's main "terminate" counter |
619 | /// reaches zero. |
620 | terminate: OnceLatch, |
621 | |
622 | /// the "stealer" half of the worker's deque |
623 | stealer: Stealer<JobRef>, |
624 | } |
625 | |
626 | impl ThreadInfo { |
627 | fn new(stealer: Stealer<JobRef>) -> ThreadInfo { |
628 | ThreadInfo { |
629 | primed: LockLatch::new(), |
630 | stopped: LockLatch::new(), |
631 | terminate: OnceLatch::new(), |
632 | stealer, |
633 | } |
634 | } |
635 | } |
636 | |
637 | /// //////////////////////////////////////////////////////////////////////// |
638 | /// WorkerThread identifiers |
639 | |
640 | pub(super) struct WorkerThread { |
641 | /// the "worker" half of our local deque |
642 | worker: Worker<JobRef>, |
643 | |
644 | /// the "stealer" half of the worker's broadcast deque |
645 | stealer: Stealer<JobRef>, |
646 | |
647 | /// local queue used for `spawn_fifo` indirection |
648 | fifo: JobFifo, |
649 | |
650 | index: usize, |
651 | |
652 | /// A weak random number generator. |
653 | rng: XorShift64Star, |
654 | |
655 | registry: Arc<Registry>, |
656 | } |
657 | |
658 | // This is a bit sketchy, but basically: the WorkerThread is |
659 | // allocated on the stack of the worker on entry and stored into this |
660 | // thread local variable. So it will remain valid at least until the |
661 | // worker is fully unwound. Using an unsafe pointer avoids the need |
662 | // for a RefCell<T> etc. |
663 | thread_local! { |
664 | static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) }; |
665 | } |
666 | |
667 | impl From<ThreadBuilder> for WorkerThread { |
668 | fn from(thread: ThreadBuilder) -> Self { |
669 | Self { |
670 | worker: thread.worker, |
671 | stealer: thread.stealer, |
672 | fifo: JobFifo::new(), |
673 | index: thread.index, |
674 | rng: XorShift64Star::new(), |
675 | registry: thread.registry, |
676 | } |
677 | } |
678 | } |
679 | |
680 | impl Drop for WorkerThread { |
681 | fn drop(&mut self) { |
682 | // Undo `set_current` |
683 | WORKER_THREAD_STATE.with(|t| { |
684 | assert!(t.get().eq(&(self as *const _))); |
685 | t.set(ptr::null()); |
686 | }); |
687 | } |
688 | } |
689 | |
690 | impl WorkerThread { |
    /// Gets a pointer to the `WorkerThread` for the current thread; returns
    /// null if this is not a worker thread. This pointer is valid
    /// anywhere on the current thread.
    #[inline]
695 | pub(super) fn current() -> *const WorkerThread { |
696 | WORKER_THREAD_STATE.with(Cell::get) |
697 | } |
698 | |
    /// Sets `thread` as the current thread's `WorkerThread`.
    /// This is done during worker thread startup.
701 | unsafe fn set_current(thread: *const WorkerThread) { |
702 | WORKER_THREAD_STATE.with(|t| { |
703 | assert!(t.get().is_null()); |
704 | t.set(thread); |
705 | }); |
706 | } |
707 | |
708 | /// Returns the registry that owns this worker thread. |
    #[inline]
710 | pub(super) fn registry(&self) -> &Arc<Registry> { |
711 | &self.registry |
712 | } |
713 | |
714 | /// Our index amongst the worker threads (ranges from `0..self.num_threads()`). |
    #[inline]
716 | pub(super) fn index(&self) -> usize { |
717 | self.index |
718 | } |
719 | |
    #[inline]
721 | pub(super) unsafe fn push(&self, job: JobRef) { |
722 | let queue_was_empty = self.worker.is_empty(); |
723 | self.worker.push(job); |
724 | self.registry.sleep.new_internal_jobs(1, queue_was_empty); |
725 | } |
726 | |
    #[inline]
728 | pub(super) unsafe fn push_fifo(&self, job: JobRef) { |
729 | self.push(self.fifo.push(job)); |
730 | } |
731 | |
    #[inline]
733 | pub(super) fn local_deque_is_empty(&self) -> bool { |
734 | self.worker.is_empty() |
735 | } |
736 | |
737 | /// Attempts to obtain a "local" job -- typically this means |
738 | /// popping from the top of the stack, though if we are configured |
739 | /// for breadth-first execution, it would mean dequeuing from the |
740 | /// bottom. |
    #[inline]
742 | pub(super) fn take_local_job(&self) -> Option<JobRef> { |
743 | let popped_job = self.worker.pop(); |
744 | |
745 | if popped_job.is_some() { |
746 | return popped_job; |
747 | } |
748 | |
749 | loop { |
750 | match self.stealer.steal() { |
751 | Steal::Success(job) => return Some(job), |
752 | Steal::Empty => return None, |
753 | Steal::Retry => {} |
754 | } |
755 | } |
756 | } |
757 | |
758 | fn has_injected_job(&self) -> bool { |
759 | !self.stealer.is_empty() || self.registry.has_injected_job() |
760 | } |
761 | |
762 | /// Wait until the latch is set. Try to keep busy by popping and |
763 | /// stealing tasks as necessary. |
    #[inline]
765 | pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) { |
766 | let latch = latch.as_core_latch(); |
767 | if !latch.probe() { |
768 | self.wait_until_cold(latch); |
769 | } |
770 | } |
771 | |
    #[cold]
773 | unsafe fn wait_until_cold(&self, latch: &CoreLatch) { |
774 | // the code below should swallow all panics and hence never |
        // unwind; but if something goes wrong, we want to abort,
776 | // because otherwise other code in rayon may assume that the |
777 | // latch has been signaled, and that can lead to random memory |
778 | // accesses, which would be *very bad* |
779 | let abort_guard = unwind::AbortIfPanic; |
780 | |
781 | 'outer: while !latch.probe() { |
            // Check for local work *before* we start marking ourselves idle,
783 | // especially to avoid modifying shared sleep state. |
784 | if let Some(job) = self.take_local_job() { |
785 | self.execute(job); |
786 | continue; |
787 | } |
788 | |
789 | let mut idle_state = self.registry.sleep.start_looking(self.index); |
790 | while !latch.probe() { |
791 | if let Some(job) = self.find_work() { |
792 | self.registry.sleep.work_found(); |
793 | self.execute(job); |
794 | // The job might have injected local work, so go back to the outer loop. |
795 | continue 'outer; |
796 | } else { |
797 | self.registry |
798 | .sleep |
799 | .no_work_found(&mut idle_state, latch, || self.has_injected_job()) |
800 | } |
801 | } |
802 | |
803 | // If we were sleepy, we are not anymore. We "found work" -- |
804 | // whatever the surrounding thread was doing before it had to wait. |
805 | self.registry.sleep.work_found(); |
806 | break; |
807 | } |
808 | |
809 | mem::forget(abort_guard); // successful execution, do not abort |
810 | } |
811 | |
812 | unsafe fn wait_until_out_of_work(&self) { |
813 | debug_assert_eq!(self as *const _, WorkerThread::current()); |
814 | let registry = &*self.registry; |
815 | let index = self.index; |
816 | |
817 | self.wait_until(®istry.thread_infos[index].terminate); |
818 | |
819 | // Should not be any work left in our queue. |
820 | debug_assert!(self.take_local_job().is_none()); |
821 | |
822 | // Let registry know we are done |
823 | Latch::set(®istry.thread_infos[index].stopped); |
824 | } |
825 | |
826 | fn find_work(&self) -> Option<JobRef> { |
827 | // Try to find some work to do. We give preference first |
        // to things in our local deque, then in other workers'
829 | // deques, and finally to injected jobs from the |
830 | // outside. The idea is to finish what we started before |
831 | // we take on something new. |
832 | self.take_local_job() |
833 | .or_else(|| self.steal()) |
834 | .or_else(|| self.registry.pop_injected_job()) |
835 | } |
836 | |
837 | pub(super) fn yield_now(&self) -> Yield { |
838 | match self.find_work() { |
839 | Some(job) => unsafe { |
840 | self.execute(job); |
841 | Yield::Executed |
842 | }, |
843 | None => Yield::Idle, |
844 | } |
845 | } |
846 | |
847 | pub(super) fn yield_local(&self) -> Yield { |
848 | match self.take_local_job() { |
849 | Some(job) => unsafe { |
850 | self.execute(job); |
851 | Yield::Executed |
852 | }, |
853 | None => Yield::Idle, |
854 | } |
855 | } |
856 | |
    #[inline]
858 | pub(super) unsafe fn execute(&self, job: JobRef) { |
859 | job.execute(); |
860 | } |
861 | |
862 | /// Try to steal a single job and return it. |
863 | /// |
864 | /// This should only be done as a last resort, when there is no |
865 | /// local work to do. |
866 | fn steal(&self) -> Option<JobRef> { |
867 | // we only steal when we don't have any work to do locally |
868 | debug_assert!(self.local_deque_is_empty()); |
869 | |
870 | // otherwise, try to steal |
871 | let thread_infos = &self.registry.thread_infos.as_slice(); |
872 | let num_threads = thread_infos.len(); |
873 | if num_threads <= 1 { |
874 | return None; |
875 | } |
876 | |
877 | loop { |
878 | let mut retry = false; |
879 | let start = self.rng.next_usize(num_threads); |
880 | let job = (start..num_threads) |
881 | .chain(0..start) |
882 | .filter(move |&i| i != self.index) |
883 | .find_map(|victim_index| { |
884 | let victim = &thread_infos[victim_index]; |
885 | match victim.stealer.steal() { |
886 | Steal::Success(job) => Some(job), |
887 | Steal::Empty => None, |
888 | Steal::Retry => { |
889 | retry = true; |
890 | None |
891 | } |
892 | } |
893 | }); |
894 | if job.is_some() || !retry { |
895 | return job; |
896 | } |
897 | } |
898 | } |
899 | } |
900 | |
901 | /// //////////////////////////////////////////////////////////////////////// |
902 | |
903 | unsafe fn main_loop(thread: ThreadBuilder) { |
904 | let worker_thread = &WorkerThread::from(thread); |
905 | WorkerThread::set_current(worker_thread); |
906 | let registry = &*worker_thread.registry; |
907 | let index = worker_thread.index; |
908 | |
909 | // let registry know we are ready to do work |
910 | Latch::set(®istry.thread_infos[index].primed); |
911 | |
912 | // Worker threads should not panic. If they do, just abort, as the |
913 | // internal state of the threadpool is corrupted. Note that if |
914 | // **user code** panics, we should catch that and redirect. |
915 | let abort_guard = unwind::AbortIfPanic; |
916 | |
917 | // Inform a user callback that we started a thread. |
918 | if let Some(ref handler) = registry.start_handler { |
919 | registry.catch_unwind(|| handler(index)); |
920 | } |
921 | |
922 | worker_thread.wait_until_out_of_work(); |
923 | |
924 | // Normal termination, do not abort. |
925 | mem::forget(abort_guard); |
926 | |
927 | // Inform a user callback that we exited a thread. |
928 | if let Some(ref handler) = registry.exit_handler { |
929 | registry.catch_unwind(|| handler(index)); |
930 | // We're already exiting the thread, there's nothing else to do. |
931 | } |
932 | } |
933 | |
934 | /// If already in a worker-thread, just execute `op`. Otherwise, |
935 | /// execute `op` in the default thread-pool. Either way, block until |
936 | /// `op` completes and return its return value. If `op` panics, that |
937 | /// panic will be propagated as well. The second argument indicates |
938 | /// `true` if injection was performed, `false` if executed directly. |
939 | pub(super) fn in_worker<OP, R>(op: OP) -> R |
940 | where |
941 | OP: FnOnce(&WorkerThread, bool) -> R + Send, |
942 | R: Send, |
943 | { |
944 | unsafe { |
945 | let owner_thread = WorkerThread::current(); |
946 | if !owner_thread.is_null() { |
947 | // Perfectly valid to give them a `&T`: this is the |
948 | // current thread, so we know the data structure won't be |
949 | // invalidated until we return. |
950 | op(&*owner_thread, false) |
951 | } else { |
952 | global_registry().in_worker(op) |
953 | } |
954 | } |
955 | } |
956 | |
957 | /// [xorshift*] is a fast pseudorandom number generator which will |
958 | /// even tolerate weak seeding, as long as it's not zero. |
959 | /// |
960 | /// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift* |
961 | struct XorShift64Star { |
962 | state: Cell<u64>, |
963 | } |
964 | |
965 | impl XorShift64Star { |
966 | fn new() -> Self { |
967 | // Any non-zero seed will do -- this uses the hash of a global counter. |
968 | let mut seed = 0; |
969 | while seed == 0 { |
970 | let mut hasher = DefaultHasher::new(); |
971 | static COUNTER: AtomicUsize = AtomicUsize::new(0); |
972 | hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed)); |
973 | seed = hasher.finish(); |
974 | } |
975 | |
976 | XorShift64Star { |
977 | state: Cell::new(seed), |
978 | } |
979 | } |
980 | |
981 | fn next(&self) -> u64 { |
982 | let mut x = self.state.get(); |
983 | debug_assert_ne!(x, 0); |
984 | x ^= x >> 12; |
985 | x ^= x << 25; |
986 | x ^= x >> 27; |
987 | self.state.set(x); |
988 | x.wrapping_mul(0x2545_f491_4f6c_dd1d) |
989 | } |
990 | |
991 | /// Return a value from `0..n`. |
992 | fn next_usize(&self, n: usize) -> usize { |
993 | (self.next() % n as u64) as usize |
994 | } |
995 | } |
996 | |