1 | use crate::job::{ArcJob, StackJob}; |
2 | use crate::latch::{CountLatch, LatchRef}; |
3 | use crate::registry::{Registry, WorkerThread}; |
4 | use std::fmt; |
5 | use std::marker::PhantomData; |
6 | use std::sync::Arc; |
7 | |
8 | mod test; |
9 | |
10 | /// Executes `op` within every thread in the current threadpool. If this is |
11 | /// called from a non-Rayon thread, it will execute in the global threadpool. |
12 | /// Any attempts to use `join`, `scope`, or parallel iterators will then operate |
13 | /// within that threadpool. When the call has completed on each thread, returns |
14 | /// a vector containing all of their return values. |
15 | /// |
16 | /// For more information, see the [`ThreadPool::broadcast()`][m] method. |
17 | /// |
18 | /// [m]: struct.ThreadPool.html#method.broadcast |
19 | pub fn broadcast<OP, R>(op: OP) -> Vec<R> |
20 | where |
21 | OP: Fn(BroadcastContext<'_>) -> R + Sync, |
22 | R: Send, |
23 | { |
24 | // We assert that current registry has not terminated. |
25 | unsafe { broadcast_in(op, &Registry::current()) } |
26 | } |
27 | |
28 | /// Spawns an asynchronous task on every thread in this thread-pool. This task |
29 | /// will run in the implicit, global scope, which means that it may outlast the |
30 | /// current stack frame -- therefore, it cannot capture any references onto the |
31 | /// stack (you will likely need a `move` closure). |
32 | /// |
33 | /// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method. |
34 | /// |
35 | /// [m]: struct.ThreadPool.html#method.spawn_broadcast |
36 | pub fn spawn_broadcast<OP>(op: OP) |
37 | where |
38 | OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, |
39 | { |
40 | // We assert that current registry has not terminated. |
41 | unsafe { spawn_broadcast_in(op, &Registry::current()) } |
42 | } |
43 | |
44 | /// Provides context to a closure called by `broadcast`. |
45 | pub struct BroadcastContext<'a> { |
46 | worker: &'a WorkerThread, |
47 | |
48 | /// Make sure to prevent auto-traits like `Send` and `Sync`. |
49 | _marker: PhantomData<&'a mut dyn Fn()>, |
50 | } |
51 | |
52 | impl<'a> BroadcastContext<'a> { |
53 | pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R { |
54 | let worker_thread = WorkerThread::current(); |
55 | assert!(!worker_thread.is_null()); |
56 | f(BroadcastContext { |
57 | worker: unsafe { &*worker_thread }, |
58 | _marker: PhantomData, |
59 | }) |
60 | } |
61 | |
62 | /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`). |
63 | #[inline ] |
64 | pub fn index(&self) -> usize { |
65 | self.worker.index() |
66 | } |
67 | |
68 | /// The number of threads receiving the broadcast in the thread pool. |
69 | /// |
70 | /// # Future compatibility note |
71 | /// |
72 | /// Future versions of Rayon might vary the number of threads over time, but |
73 | /// this method will always return the number of threads which are actually |
74 | /// receiving your particular `broadcast` call. |
75 | #[inline ] |
76 | pub fn num_threads(&self) -> usize { |
77 | self.worker.registry().num_threads() |
78 | } |
79 | } |
80 | |
81 | impl<'a> fmt::Debug for BroadcastContext<'a> { |
82 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { |
83 | fmt.debug_struct("BroadcastContext" ) |
84 | .field("index" , &self.index()) |
85 | .field("num_threads" , &self.num_threads()) |
86 | .field("pool_id" , &self.worker.registry().id()) |
87 | .finish() |
88 | } |
89 | } |
90 | |
91 | /// Execute `op` on every thread in the pool. It will be executed on each |
92 | /// thread when they have nothing else to do locally, before they try to |
93 | /// steal work from other threads. This function will not return until all |
94 | /// threads have completed the `op`. |
95 | /// |
96 | /// Unsafe because `registry` must not yet have terminated. |
97 | pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R> |
98 | where |
99 | OP: Fn(BroadcastContext<'_>) -> R + Sync, |
100 | R: Send, |
101 | { |
102 | let f = move |injected: bool| { |
103 | debug_assert!(injected); |
104 | BroadcastContext::with(&op) |
105 | }; |
106 | |
107 | let n_threads = registry.num_threads(); |
108 | let current_thread = WorkerThread::current().as_ref(); |
109 | let latch = CountLatch::with_count(n_threads, current_thread); |
110 | let jobs: Vec<_> = (0..n_threads) |
111 | .map(|_| StackJob::new(&f, LatchRef::new(&latch))) |
112 | .collect(); |
113 | let job_refs = jobs.iter().map(|job| job.as_job_ref()); |
114 | |
115 | registry.inject_broadcast(job_refs); |
116 | |
117 | // Wait for all jobs to complete, then collect the results, maybe propagating a panic. |
118 | latch.wait(current_thread); |
119 | jobs.into_iter().map(|job| job.into_result()).collect() |
120 | } |
121 | |
122 | /// Execute `op` on every thread in the pool. It will be executed on each |
123 | /// thread when they have nothing else to do locally, before they try to |
124 | /// steal work from other threads. This function returns immediately after |
125 | /// injecting the jobs. |
126 | /// |
127 | /// Unsafe because `registry` must not yet have terminated. |
128 | pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>) |
129 | where |
130 | OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, |
131 | { |
132 | let job = ArcJob::new({ |
133 | let registry = Arc::clone(registry); |
134 | move || { |
135 | registry.catch_unwind(|| BroadcastContext::with(&op)); |
136 | registry.terminate(); // (*) permit registry to terminate now |
137 | } |
138 | }); |
139 | |
140 | let n_threads = registry.num_threads(); |
141 | let job_refs = (0..n_threads).map(|_| { |
142 | // Ensure that registry cannot terminate until this job has executed |
143 | // on each thread. This ref is decremented at the (*) above. |
144 | registry.increment_terminate_count(); |
145 | |
146 | ArcJob::as_static_job_ref(&job) |
147 | }); |
148 | |
149 | registry.inject_broadcast(job_refs); |
150 | } |
151 | |