1use crate::job::{ArcJob, StackJob};
2use crate::latch::{CountLatch, LatchRef};
3use crate::registry::{Registry, WorkerThread};
4use std::fmt;
5use std::marker::PhantomData;
6use std::sync::Arc;
7
8mod test;
9
10/// Executes `op` within every thread in the current threadpool. If this is
11/// called from a non-Rayon thread, it will execute in the global threadpool.
12/// Any attempts to use `join`, `scope`, or parallel iterators will then operate
13/// within that threadpool. When the call has completed on each thread, returns
14/// a vector containing all of their return values.
15///
16/// For more information, see the [`ThreadPool::broadcast()`][m] method.
17///
18/// [m]: struct.ThreadPool.html#method.broadcast
19pub fn broadcast<OP, R>(op: OP) -> Vec<R>
20where
21 OP: Fn(BroadcastContext<'_>) -> R + Sync,
22 R: Send,
23{
24 // We assert that current registry has not terminated.
25 unsafe { broadcast_in(op, &Registry::current()) }
26}
27
28/// Spawns an asynchronous task on every thread in this thread-pool. This task
29/// will run in the implicit, global scope, which means that it may outlast the
30/// current stack frame -- therefore, it cannot capture any references onto the
31/// stack (you will likely need a `move` closure).
32///
33/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
34///
35/// [m]: struct.ThreadPool.html#method.spawn_broadcast
36pub fn spawn_broadcast<OP>(op: OP)
37where
38 OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
39{
40 // We assert that current registry has not terminated.
41 unsafe { spawn_broadcast_in(op, &Registry::current()) }
42}
43
44/// Provides context to a closure called by `broadcast`.
45pub struct BroadcastContext<'a> {
46 worker: &'a WorkerThread,
47
48 /// Make sure to prevent auto-traits like `Send` and `Sync`.
49 _marker: PhantomData<&'a mut dyn Fn()>,
50}
51
52impl<'a> BroadcastContext<'a> {
53 pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
54 let worker_thread = WorkerThread::current();
55 assert!(!worker_thread.is_null());
56 f(BroadcastContext {
57 worker: unsafe { &*worker_thread },
58 _marker: PhantomData,
59 })
60 }
61
62 /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
63 #[inline]
64 pub fn index(&self) -> usize {
65 self.worker.index()
66 }
67
68 /// The number of threads receiving the broadcast in the thread pool.
69 ///
70 /// # Future compatibility note
71 ///
72 /// Future versions of Rayon might vary the number of threads over time, but
73 /// this method will always return the number of threads which are actually
74 /// receiving your particular `broadcast` call.
75 #[inline]
76 pub fn num_threads(&self) -> usize {
77 self.worker.registry().num_threads()
78 }
79}
80
81impl<'a> fmt::Debug for BroadcastContext<'a> {
82 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
83 fmt.debug_struct("BroadcastContext")
84 .field("index", &self.index())
85 .field("num_threads", &self.num_threads())
86 .field("pool_id", &self.worker.registry().id())
87 .finish()
88 }
89}
90
91/// Execute `op` on every thread in the pool. It will be executed on each
92/// thread when they have nothing else to do locally, before they try to
93/// steal work from other threads. This function will not return until all
94/// threads have completed the `op`.
95///
96/// Unsafe because `registry` must not yet have terminated.
97pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
98where
99 OP: Fn(BroadcastContext<'_>) -> R + Sync,
100 R: Send,
101{
102 let f = move |injected: bool| {
103 debug_assert!(injected);
104 BroadcastContext::with(&op)
105 };
106
107 let n_threads = registry.num_threads();
108 let current_thread = WorkerThread::current().as_ref();
109 let latch = CountLatch::with_count(n_threads, current_thread);
110 let jobs: Vec<_> = (0..n_threads)
111 .map(|_| StackJob::new(&f, LatchRef::new(&latch)))
112 .collect();
113 let job_refs = jobs.iter().map(|job| job.as_job_ref());
114
115 registry.inject_broadcast(job_refs);
116
117 // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
118 latch.wait(current_thread);
119 jobs.into_iter().map(|job| job.into_result()).collect()
120}
121
122/// Execute `op` on every thread in the pool. It will be executed on each
123/// thread when they have nothing else to do locally, before they try to
124/// steal work from other threads. This function returns immediately after
125/// injecting the jobs.
126///
127/// Unsafe because `registry` must not yet have terminated.
128pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
129where
130 OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
131{
132 let job = ArcJob::new({
133 let registry = Arc::clone(registry);
134 move || {
135 registry.catch_unwind(|| BroadcastContext::with(&op));
136 registry.terminate(); // (*) permit registry to terminate now
137 }
138 });
139
140 let n_threads = registry.num_threads();
141 let job_refs = (0..n_threads).map(|_| {
142 // Ensure that registry cannot terminate until this job has executed
143 // on each thread. This ref is decremented at the (*) above.
144 registry.increment_terminate_count();
145
146 ArcJob::as_static_job_ref(&job)
147 });
148
149 registry.inject_broadcast(job_refs);
150}
151