use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch};
use crate::sleep::Sleep;
use crate::sync::Mutex;
use crate::unwind;
use crate::{
    ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder,
    Yield,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::Hasher;
use std::io;
use std::mem;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Once};
use std::thread;
use std::usize;

/// Thread builder used for customization via
/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler).
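///
/// For illustration, a minimal sketch of a custom spawn handler that
/// receives a `ThreadBuilder` (the handler body is an assumption for the
/// example, not the default behavior):
///
/// ```ignore
/// let pool = rayon::ThreadPoolBuilder::new()
///     .spawn_handler(|thread| {
///         // Hand the builder off to a plain std thread; `run()` enters the
///         // worker's main loop and only returns when the pool shuts down.
///         std::thread::Builder::new()
///             .name(thread.name().unwrap_or("rayon worker").to_owned())
///             .spawn(move || thread.run())?;
///         Ok(())
///     })
///     .build()?;
/// ```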
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until the
    /// thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

/// Generalized trait for spawning a thread in the `Registry`.
///
/// This trait is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameters, and then
    /// call `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

/// Spawns a thread in the "normal" way with `std::thread::Builder`.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

/// Spawns a thread with a user's custom callback.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

pub(super) struct Registry {
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,

    // When this counter reaches 0, it means that all work on this
    // registry must be complete. This is ensured in the following ways:
    //
    // - if this is the global registry, there is a ref-count that never
    //   gets released.
    // - if this is a user-created thread-pool, then so long as the thread-pool
    //   exists, it holds a reference.
    // - when we inject a "blocking job" into the registry with `ThreadPool::install()`,
    //   no adjustment is needed; the `ThreadPool` holds the reference, and since we won't
    //   return until the blocking job is complete, that ref will continue to be held.
    // - when `join()` or `scope()` is invoked, similarly, no adjustments are needed.
    //   These are always owned by some other job (e.g., one injected by `ThreadPool::install()`)
    //   and that job will keep the pool alive.
    terminate_count: AtomicUsize,
}

/// ////////////////////////////////////////////////////////////////////////
/// Initialization

static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default
/// configuration.
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) })
        .expect("The global thread pool has not been initialized.")
}

/// Starts the worker threads (if that has not already happened) with
/// the given builder.
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

/// Starts the worker threads (if that has not already happened)
/// by creating a registry with the given callback.
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(
        ErrorKind::GlobalPoolAlreadyInitialized,
    ));

    THE_REGISTRY_SET.call_once(|| {
        result = registry()
            .map(|registry| unsafe { &*THE_REGISTRY.get_or_insert(registry) })
    });

    result
}

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at all, we can fall back to
    // using the current thread alone. This is crude, and probably won't work for non-blocking
    // calls like `spawn` or `broadcast_spawn`, but a lot of stuff does work fine.
    //
    // Notably, this allows current WebAssembly targets to work even though their threading support
    // is stubbed out, and we won't have to change anything if they do add real threading.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new().num_threads(1).use_current_thread();
        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually support.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let registry = Arc::new(Registry {
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };

            if index == 0 && builder.use_current_thread {
                if !WorkerThread::current().is_null() {
                    return Err(ThreadPoolBuildError::new(
                        ErrorKind::CurrentThreadAlreadyInPool,
                    ));
                }
                // Rather than starting a new thread, we're just taking over the current thread
                // *without* running the main loop, so we can still return from here.
                // The WorkerThread is leaked, but we never shut down the global pool anyway.
                let worker_thread = Box::into_raw(Box::new(WorkerThread::from(thread)));

                unsafe {
                    WorkerThread::set_current(worker_thread);
                    Latch::set(&registry.thread_infos[index].primed);
                }
                continue;
            }

            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry)
    }

    pub(super) fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }

    /// Returns the number of threads in the current registry. This
    /// is better than `Registry::current().num_threads()` because it
    /// avoids incrementing the `Arc`.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it's part of this `Registry`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if that handler itself panics, then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

    /// Waits for the worker threads to get up and running. This is
    /// meant to be used for benchmarking purposes, primarily, so that
    /// you can get more consistent numbers by having everything
    /// "ready to go".
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing
    /// -- so we can check that termination actually works.
    #[cfg(test)]
    pub(super) fn wait_until_stopped(&self) {
        for info in &self.thread_infos {
            info.stopped.wait();
        }
    }

    /// ////////////////////////////////////////////////////////////////////////
    /// MAIN LOOP
    ///
    /// So long as all of the worker threads are hanging out in their
    /// top-level loop, there is no work to be done.

    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the
    /// deque. Else, it will inject from the outside (which is slower).
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(job_ref);
            }
        }
    }

    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do. Use this if you know that
    /// you are not on a worker of this registry.
    pub(super) fn inject(&self, injected_job: JobRef) {
        // It should not be possible for `state.terminate` to be true
        // here. It is only set to true when the user creates (and
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        self.injected_jobs.push(injected_job);
        self.sleep.new_injected_jobs(1, queue_was_empty);
    }

    fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally,
    /// before it tries to steal other work.
    ///
    /// **Panics** if not given exactly as many jobs as there are threads.
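    ///
    /// For illustration, a hedged sketch of that contract as seen from a
    /// caller (`make_job_ref()` is a hypothetical helper, not part of this
    /// module):
    ///
    /// ```ignore
    /// let n = registry.num_threads();
    /// // Build exactly one JobRef per worker; any other count panics.
    /// registry.inject_broadcast((0..n).map(|_| make_job_ref()));
    /// ```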
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true
            // here. It is only set to true when the user creates (and
            // drops) a `ThreadPool`; and, in that case, they cannot be
            // calling `inject_broadcast()` later, since they dropped their
            // `ThreadPool`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }

    /// If already in a worker-thread of this registry, just execute `op`.
    /// Otherwise, inject `op` in this thread-pool. Either way, block until `op`
    /// completes and return its return value. If `op` panics, that panic will
    /// be propagated as well. The second argument indicates `true` if injection
    /// was performed, `false` if executed directly.
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                LatchRef::new(l),
            );
            self.inject(job.as_job_ref());
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.

            job.into_result()
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while waiting for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(job.as_job_ref());
        current_thread.wait_until(&job.latch);
        job.into_result()
    }

    /// Increments the terminate counter. This increment should be
    /// balanced by a call to `terminate`, which will decrement. This
    /// is used when spawning asynchronous work, which needs to
    /// prevent the registry from terminating so long as it is active.
    ///
    /// Note that blocking functions such as `join` and `scope` do not
    /// need to concern themselves with this fn; their context is
    /// responsible for ensuring the current thread-pool will not
    /// terminate until they return.
    ///
    /// The global thread-pool always has an outstanding reference
    /// (the initial one). Custom thread-pools have one outstanding
    /// reference that is dropped when the `ThreadPool` is dropped:
    /// since installing the thread-pool blocks until any joins/scopes
    /// complete, this ensures that joins/scopes are covered.
    ///
    /// The exception is `::spawn()`, which can create a job outside
    /// of any blocking scope. In that case, the job itself holds a
    /// terminate count and is responsible for invoking `terminate()`
    /// when finished.
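    ///
    /// A minimal sketch of the balanced pattern described above (the
    /// surrounding spawn machinery is elided):
    ///
    /// ```ignore
    /// registry.increment_terminate_count();
    /// // ... run the detached job, and once it has finished executing:
    /// registry.terminate(); // balances the increment above
    /// ```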
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(
            previous != std::usize::MAX,
            "overflow in registry ref count"
        );
    }

    /// Signals that the thread-pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { OnceLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been "set".
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method on the
    /// `Registry`, once the registry's main "terminate" counter
    /// reaches zero.
    terminate: OnceLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: OnceLatch::new(),
            stealer,
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////
/// WorkerThread identifiers

pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" half of the worker's broadcast deque
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    registry: Arc<Registry>,
}

// This is a bit sketchy, but basically: the WorkerThread is
// allocated on the stack of the worker on entry and stored into this
// thread local variable. So it will remain valid at least until the
// worker is fully unwound. Using an unsafe pointer avoids the need
// for a RefCell<T> etc.
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets the `WorkerThread` for the current thread; returns null
    /// if this is not a worker thread. The returned pointer is valid
    /// anywhere on the current thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `thread` as the `WorkerThread` for the current thread.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry.sleep.new_internal_jobs(1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means
    /// popping from the top of the stack, though if we are configured
    /// for breadth-first execution, it would mean dequeuing from the
    /// bottom.
    #[inline]
    pub(super) fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            return popped_job;
        }

        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // the code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*
        let abort_guard = unwind::AbortIfPanic;

        'outer: while !latch.probe() {
            // Check for local work *before* we start marking ourselves idle,
            // especially to avoid modifying shared sleep state.
            if let Some(job) = self.take_local_job() {
                self.execute(job);
                continue;
            }

            let mut idle_state = self.registry.sleep.start_looking(self.index);
            while !latch.probe() {
                if let Some(job) = self.find_work() {
                    self.registry.sleep.work_found();
                    self.execute(job);
                    // The job might have injected local work, so go back to the outer loop.
                    continue 'outer;
                } else {
                    self.registry
                        .sleep
                        .no_work_found(&mut idle_state, latch, || self.has_injected_job())
                }
            }

            // If we were sleepy, we are not anymore. We "found work" --
            // whatever the surrounding thread was doing before it had to wait.
            self.registry.sleep.work_found();
            break;
        }

        mem::forget(abort_guard); // successful execution, do not abort
    }

    unsafe fn wait_until_out_of_work(&self) {
        debug_assert_eq!(self as *const _, WorkerThread::current());
        let registry = &*self.registry;
        let index = self.index;

        self.wait_until(&registry.thread_infos[index].terminate);

        // Should not be any work left in our queue.
        debug_assert!(self.take_local_job().is_none());

        // Let registry know we are done
        Latch::set(&registry.thread_infos[index].stopped);
    }

    fn find_work(&self) -> Option<JobRef> {
        // Try to find some work to do. We give preference first
        // to things in our local deque, then in other workers'
        // deques, and finally to injected jobs from the
        // outside. The idea is to finish what we started before
        // we take on something new.
        self.take_local_job()
            .or_else(|| self.steal())
            .or_else(|| self.registry.pop_injected_job())
    }

    pub(super) fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    pub(super) fn yield_local(&self) -> Yield {
        match self.take_local_job() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }

    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.
    fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => Some(job),
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////

unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    WorkerThread::set_current(worker_thread);
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // let registry know we are ready to do work
    Latch::set(&registry.thread_infos[index].primed);

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the threadpool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    worker_thread.wait_until_out_of_work();

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
        // We're already exiting the thread, there's nothing else to do.
    }
}

/// If already in a worker-thread, just execute `op`. Otherwise,
/// execute `op` in the default thread-pool. Either way, block until
/// `op` completes and return its return value. If `op` panics, that
/// panic will be propagated as well. The second argument indicates
/// `true` if injection was performed, `false` if executed directly.
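///
/// For illustration, a sketch of a call site (the closure body here is
/// an arbitrary example, not code from this crate):
///
/// ```ignore
/// let value = in_worker(|worker_thread, injected| {
///     // `injected` tells us whether we had to enter the pool from outside.
///     worker_thread.index() + usize::from(injected)
/// });
/// ```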
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*
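///
/// A sketch of how this file uses it for victim selection (mirrors
/// `WorkerThread::steal` above):
///
/// ```ignore
/// let rng = XorShift64Star::new();
/// let start = rng.next_usize(num_threads); // random first victim to probe
/// ```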
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star {
            state: Cell::new(seed),
        }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
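    /// Uses a simple modulo reduction, so there is a slight bias for
    /// large `n`; that is acceptable for randomized victim selection.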
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}