use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LatchRef, LockLatch, SpinLatch};
use crate::log::Event::*;
use crate::log::Logger;
use crate::sleep::Sleep;
use crate::tlv::Tlv;
use crate::unwind;
use crate::{
    AcquireThreadHandler, DeadlockHandler, ErrorKind, ExitHandler, PanicHandler,
    ReleaseThreadHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, Yield,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::Hasher;
use std::io;
use std::mem;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Once};
use std::thread;
use std::usize;

/// Thread builder used for customization via
/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler).
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until the
    /// thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

/// Generalized trait for spawning a thread in the `Registry`.
///
/// This trait is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameters, and then
    /// call `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

/// Spawns a thread in the "normal" way with `std::thread::Builder`.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

/// Spawns a thread with a user's custom callback.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}
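
// `CustomSpawn` is what `ThreadPoolBuilder::spawn_handler` wraps a user
// closure in. A minimal sketch of driving it from the public API -- it
// essentially re-implements `DefaultSpawn` by hand (illustrative only;
// only `spawn_handler` and the `ThreadBuilder` methods above are assumed):
//
//     let pool = ThreadPoolBuilder::new()
//         .num_threads(2)
//         .spawn_handler(|thread| {
//             let mut b = std::thread::Builder::new();
//             if let Some(name) = thread.name() {
//                 b = b.name(name.to_owned());
//             }
//             if let Some(size) = thread.stack_size() {
//                 b = b.stack_size(size);
//             }
//             // `run()` must eventually be called, or the pool never starts.
//             b.spawn(|| thread.run())?;
//             Ok(())
//         })
//         .build();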

pub struct Registry {
    logger: Logger,
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    pub(crate) deadlock_handler: Option<Box<DeadlockHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,
    pub(crate) acquire_thread_handler: Option<Box<AcquireThreadHandler>>,
    pub(crate) release_thread_handler: Option<Box<ReleaseThreadHandler>>,

    // When this latch reaches 0, it means that all work on this
    // registry must be complete. This is ensured in the following ways:
    //
    // - if this is the global registry, there is a ref-count that never
    //   gets released.
    // - if this is a user-created thread-pool, then so long as the thread-pool
    //   exists, it holds a reference.
    // - when we inject a "blocking job" into the registry with `ThreadPool::install()`,
    //   no adjustment is needed; the `ThreadPool` holds the reference, and since we won't
    //   return until the blocking job is complete, that ref will continue to be held.
    // - when `join()` or `scope()` is invoked, similarly, no adjustments are needed.
    //   These are always owned by some other job (e.g., one injected by `ThreadPool::install()`)
    //   and that job will keep the pool alive.
    terminate_count: AtomicUsize,
}

/// ////////////////////////////////////////////////////////////////////////
/// Initialization

static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default
/// configuration.
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) })
        .expect("The global thread pool has not been initialized.")
}

/// Starts the worker threads (if that has not already happened) with
/// the given builder.
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

/// Starts the worker threads (if that has not already happened)
/// by creating a registry with the given callback.
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(
        ErrorKind::GlobalPoolAlreadyInitialized,
    ));

    THE_REGISTRY_SET.call_once(|| {
        result = registry()
            .map(|registry| unsafe { &*THE_REGISTRY.get_or_insert(registry) })
    });

    result
}
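
// For orientation, a sketch of how user code reaches `set_global_registry`
// (assuming the crate is consumed as `rayon_core`; the entry point itself
// lives outside this file):
//
//     // First call wins and builds the global registry...
//     rayon_core::ThreadPoolBuilder::new()
//         .num_threads(4)
//         .build_global()
//         .unwrap();
//     // ...any later call sees `GlobalPoolAlreadyInitialized`.
//     assert!(rayon_core::ThreadPoolBuilder::new().build_global().is_err());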

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at all, we can fall back to
    // using the current thread alone. This is crude, and probably won't work for non-blocking
    // calls like `spawn` or `broadcast_spawn`, but a lot of stuff does work fine.
    //
    // Notably, this allows current WebAssembly targets to work even though their threading support
    // is stubbed out, and we won't have to change anything if they do add real threading.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new()
            .num_threads(1)
            .spawn_handler(|thread| {
                // Rather than starting a new thread, we're just taking over the current thread
                // *without* running the main loop, so we can still return from here.
                // The WorkerThread is leaked, but we never shutdown the global pool anyway.
                let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
                let registry = &*worker_thread.registry;
                let index = worker_thread.index;

                unsafe {
                    WorkerThread::set_current(worker_thread);

                    // let registry know we are ready to do work
                    Latch::set(&registry.thread_infos[index].primed);
                }

                Ok(())
            });

        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}
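
// `Terminator` is a drop guard: in `Registry::new` below it is armed right
// after the `Arc` is built, so an early `return Err(...)` (or a panic) runs
// `Drop` and terminates any threads that were already spawned, while the
// success path disarms it. The pattern, in miniature:
//
//     let t1000 = Terminator(&registry); // armed: drop => registry.terminate()
//     /* fallible spawning; any early exit runs the guard */
//     mem::forget(t1000);                // success: disarm without running Drop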

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually support.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let logger = Logger::new(n_threads);
        let registry = Arc::new(Registry {
            logger: logger.clone(),
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(logger, n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            deadlock_handler: builder.take_deadlock_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
            acquire_thread_handler: builder.take_acquire_thread_handler(),
            release_thread_handler: builder.take_release_thread_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };
            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry)
    }

    pub fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }

    /// Returns the number of threads in the current registry. This
    /// is better than `Registry::current().num_threads()` because it
    /// avoids incrementing the `Arc`.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it's part of this `Registry`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.logger.log(event)
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if that handler itself panics, then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

    /// Waits for the worker threads to get up and running. This is
    /// meant to be used for benchmarking purposes, primarily, so that
    /// you can get more consistent numbers by having everything
    /// "ready to go".
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing
    /// -- so we can check that termination actually works.
    pub(super) fn wait_until_stopped(&self) {
        self.release_thread();
        for info in &self.thread_infos {
            info.stopped.wait();
        }
        self.acquire_thread();
    }

    pub(crate) fn acquire_thread(&self) {
        if let Some(ref acquire_thread_handler) = self.acquire_thread_handler {
            acquire_thread_handler();
        }
    }

    pub(crate) fn release_thread(&self) {
        if let Some(ref release_thread_handler) = self.release_thread_handler {
            release_thread_handler();
        }
    }

    /// ////////////////////////////////////////////////////////////////////////
    /// MAIN LOOP
    ///
    /// So long as all of the worker threads are hanging out in their
    /// top-level loop, there is no work to be done.

    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the
    /// deque. Else, it will inject from the outside (which is slower).
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(job_ref);
            }
        }
    }

    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do. Use this if you know that
    /// you are not on a worker of this registry.
    pub(super) fn inject(&self, injected_job: JobRef) {
        self.log(|| JobsInjected { count: 1 });

        // It should not be possible for `state.terminate` to be true
        // here. It is only set to true when the user creates (and
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        self.injected_jobs.push(injected_job);
        self.sleep.new_injected_jobs(usize::MAX, 1, queue_was_empty);
    }

    pub(crate) fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self, worker_index: usize) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => {
                    self.log(|| JobUninjected {
                        worker: worker_index,
                    });
                    return Some(job);
                }
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally,
    /// before it tries to steal other work.
    ///
    /// **Panics** if not given exactly as many jobs as there are threads.
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        self.log(|| JobBroadcast {
            count: self.num_threads(),
        });
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true
            // here. It is only set to true when the user creates (and
            // drops) a `ThreadPool`; and, in that case, they cannot be
            // calling `inject_broadcast()` later, since they dropped their
            // `ThreadPool`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }
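
    // A sketch of how this surfaces in the public API (the `broadcast`
    // entry point lives elsewhere in the crate; the `ctx` methods shown
    // are from `BroadcastContext` and are assumed here for illustration):
    //
    //     rayon::broadcast(|ctx| {
    //         // runs exactly once on every worker thread of the pool
    //         println!("worker {} of {}", ctx.index(), ctx.num_threads());
    //     });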

    /// If already in a worker-thread of this registry, just execute `op`.
    /// Otherwise, inject `op` in this thread-pool. Either way, block until `op`
    /// completes and return its return value. If `op` panics, that panic will
    /// be propagated as well. The second argument indicates `true` if injection
    /// was performed, `false` if executed directly.
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                Tlv::null(),
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                LatchRef::new(l),
            );
            self.inject(job.as_job_ref());
            self.release_thread();
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
            self.acquire_thread();

            // flush accumulated logs as we exit the thread
            self.logger.log(|| Flush);

            job.into_result()
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while waiting for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            Tlv::null(),
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(job.as_job_ref());
        current_thread.wait_until(&job.latch);
        job.into_result()
    }
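
    // `in_worker` is the engine behind blocking entry points such as
    // `ThreadPool::install`. A sketch of the three dispatch cases above
    // (illustrative; `pool` is a hypothetical user-built pool):
    //
    //     let pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();
    //     // Called from a plain thread: `in_worker_cold` injects the job and
    //     // blocks on a thread-local `LockLatch` until a worker runs it.
    //     assert_eq!(pool.install(|| 40 + 2), 42);
    //     // Called from a worker of this same pool: `op` runs inline.
    //     // Called from a worker of *another* pool: `in_worker_cross` injects
    //     // and keeps that worker busy on its own queue via `wait_until`.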

    /// Increments the terminate counter. This increment should be
    /// balanced by a call to `terminate`, which will decrement. This
    /// is used when spawning asynchronous work, which needs to
    /// prevent the registry from terminating so long as it is active.
    ///
    /// Note that blocking functions such as `join` and `scope` do not
    /// need to concern themselves with this fn; their context is
    /// responsible for ensuring the current thread-pool will not
    /// terminate until they return.
    ///
    /// The global thread-pool always has an outstanding reference
    /// (the initial one). Custom thread-pools have one outstanding
    /// reference that is dropped when the `ThreadPool` is dropped:
    /// since installing the thread-pool blocks until any joins/scopes
    /// complete, this ensures that joins/scopes are covered.
    ///
    /// The exception is `::spawn()`, which can create a job outside
    /// of any blocking scope. In that case, the job itself holds a
    /// terminate count and is responsible for invoking `terminate()`
    /// when finished.
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(
            previous != std::usize::MAX,
            "overflow in registry ref count"
        );
    }
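
    // Roughly how `spawn()` balances this counter (a paraphrased sketch,
    // not the exact code in the crate's spawn module):
    //
    //     registry.increment_terminate_count(); // keep the pool alive
    //     /* build a heap-allocated job that wraps the user closure and,
    //        in a drop-safe way, ends by calling: */
    //     registry.terminate();                 // balance the increment
    //     /* then hand the job off: */
    //     registry.inject_or_push(job_ref);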

    /// Signals that the thread-pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { CountLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been "set".
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

/// Mark a Rayon worker thread as blocked. This triggers the deadlock handler
/// if no other worker thread is active.
#[inline]
pub fn mark_blocked() {
    let worker_thread = WorkerThread::current();
    assert!(!worker_thread.is_null());
    unsafe {
        let registry = &(*worker_thread).registry;
        registry.sleep.mark_blocked(&registry.deadlock_handler)
    }
}

/// Mark a previously blocked Rayon worker thread as unblocked.
#[inline]
pub fn mark_unblocked(registry: &Registry) {
    registry.sleep.mark_unblocked()
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method on the
    /// `Registry`, once the registry's main "terminate" counter
    /// reaches zero.
    ///
    /// NB. We use a `CountLatch` here because it has no lifetimes and is
    /// meant for async use, but the count never gets higher than one.
    terminate: CountLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: CountLatch::new(),
            stealer,
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////
/// WorkerThread identifiers

pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" half of the worker's broadcast deque
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    pub(crate) index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    pub(crate) registry: Arc<Registry>,
}

// This is a bit sketchy, but basically: the WorkerThread is
// allocated on the stack of the worker on entry and stored into this
// thread local variable. So it will remain valid at least until the
// worker is fully unwound. Using an unsafe pointer avoids the need
// for a RefCell<T> etc.
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets the `WorkerThread` index for the current thread; returns
    /// NULL if this is not a worker thread. This pointer is valid
    /// anywhere on the current thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `self` as the worker thread index for the current thread.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.registry.logger.log(event)
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        self.log(|| JobPushed { worker: self.index });
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry
            .sleep
            .new_internal_jobs(self.index, 1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means
    /// popping from the top of the stack, though if we are configured
    /// for breadth-first execution, it would mean dequeuing from the
    /// bottom.
    #[inline]
    pub(super) fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            self.log(|| JobPopped { worker: self.index });
            return popped_job;
        }

        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    pub(super) fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // the code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*
        let abort_guard = unwind::AbortIfPanic;

        let mut idle_state = self.registry.sleep.start_looking(self.index, latch);
        while !latch.probe() {
            if let Some(job) = self.find_work() {
                self.registry.sleep.work_found(idle_state);
                self.execute(job);
                idle_state = self.registry.sleep.start_looking(self.index, latch);
            } else {
                self.registry
                    .sleep
                    .no_work_found(&mut idle_state, latch, &self)
            }
        }

        // If we were sleepy, we are not anymore. We "found work" --
        // whatever the surrounding thread was doing before it had to
        // wait.
        self.registry.sleep.work_found(idle_state);

        self.log(|| ThreadSawLatchSet {
            worker: self.index,
            latch_addr: latch.addr(),
        });
        mem::forget(abort_guard); // successful execution, do not abort
    }

    fn find_work(&self) -> Option<JobRef> {
        // Try to find some work to do. We give preference first
        // to things in our local deque, then in other workers'
        // deques, and finally to injected jobs from the
        // outside. The idea is to finish what we started before
        // we take on something new.
        self.take_local_job()
            .or_else(|| self.steal())
            .or_else(|| self.registry.pop_injected_job(self.index))
    }

    pub(super) fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    pub(super) fn yield_local(&self) -> Yield {
        match self.take_local_job() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }

    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.
    fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => {
                            self.log(|| JobStolen {
                                worker: self.index,
                                victim: victim_index,
                            });
                            Some(job)
                        }
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////

unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    WorkerThread::set_current(worker_thread);
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // let registry know we are ready to do work
    Latch::set(&registry.thread_infos[index].primed);

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the threadpool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    let my_terminate_latch = &registry.thread_infos[index].terminate;
    worker_thread.log(|| ThreadStart {
        worker: index,
        terminate_addr: my_terminate_latch.as_core_latch().addr(),
    });
    registry.acquire_thread();
    worker_thread.wait_until(my_terminate_latch);

    // Should not be any work left in our queue.
    debug_assert!(worker_thread.take_local_job().is_none());

    // let registry know we are done
    Latch::set(&registry.thread_infos[index].stopped);

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    worker_thread.log(|| ThreadTerminate { worker: index });

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
        // We're already exiting the thread, there's nothing else to do.
    }

    registry.release_thread();
}

/// If already in a worker-thread, just execute `op`. Otherwise,
/// execute `op` in the default thread-pool. Either way, block until
/// `op` completes and return its return value. If `op` panics, that
/// panic will be propagated as well. The second argument indicates
/// `true` if injection was performed, `false` if executed directly.
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star {
            state: Cell::new(seed),
        }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}
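
// The recurrence above is Marsaglia's xorshift64*. A standalone sketch of the
// same generator, plus a note on `next_usize`: the `%` mapping has a slight
// modulo bias whenever `n` does not divide 2^64, which is harmless for
// randomized victim selection in `steal`:
//
//     fn xorshift64star(state: &mut u64) -> u64 {
//         debug_assert_ne!(*state, 0); // a zero state is a fixed point
//         *state ^= *state >> 12;
//         *state ^= *state << 25;
//         *state ^= *state >> 27;
//         state.wrapping_mul(0x2545_f491_4f6c_dd1d)
//     }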