1 | use crate::primitive::sync::{Arc, Condvar, Mutex}; |
2 | use std::fmt; |
3 | |
4 | /// Enables threads to synchronize the beginning or end of some computation. |
5 | /// |
6 | /// # Wait groups vs barriers |
7 | /// |
8 | /// `WaitGroup` is very similar to [`Barrier`], but there are a few differences: |
9 | /// |
10 | /// * [`Barrier`] needs to know the number of threads at construction, while `WaitGroup` is cloned to |
11 | /// register more threads. |
12 | /// |
13 | /// * A [`Barrier`] can be reused even after all threads have synchronized, while a `WaitGroup` |
14 | /// synchronizes threads only once. |
15 | /// |
16 | /// * All threads wait for others to reach the [`Barrier`]. With `WaitGroup`, each thread can choose |
17 | /// to either wait for other threads or to continue without blocking. |
18 | /// |
19 | /// # Examples |
20 | /// |
21 | /// ``` |
22 | /// use crossbeam_utils::sync::WaitGroup; |
23 | /// use std::thread; |
24 | /// |
25 | /// // Create a new wait group. |
26 | /// let wg = WaitGroup::new(); |
27 | /// |
28 | /// for _ in 0..4 { |
29 | /// // Create another reference to the wait group. |
30 | /// let wg = wg.clone(); |
31 | /// |
32 | /// thread::spawn(move || { |
33 | /// // Do some work. |
34 | /// |
35 | /// // Drop the reference to the wait group. |
36 | /// drop(wg); |
37 | /// }); |
38 | /// } |
39 | /// |
40 | /// // Block until all threads have finished their work. |
41 | /// wg.wait(); |
42 | /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371 |
43 | /// ``` |
44 | /// |
45 | /// [`Barrier`]: std::sync::Barrier |
46 | pub struct WaitGroup { |
47 | inner: Arc<Inner>, |
48 | } |
49 | |
/// State shared by every handle that belongs to one `WaitGroup`.
struct Inner {
    /// Notified with `notify_all` when `count` reaches zero.
    cvar: Condvar,
    /// Number of live `WaitGroup` handles pointing at this state.
    count: Mutex<usize>,
}
55 | |
56 | impl Default for WaitGroup { |
57 | fn default() -> Self { |
58 | Self { |
59 | inner: Arc::new(Inner { |
60 | cvar: Condvar::new(), |
61 | count: Mutex::new(1), |
62 | }), |
63 | } |
64 | } |
65 | } |
66 | |
67 | impl WaitGroup { |
68 | /// Creates a new wait group and returns the single reference to it. |
69 | /// |
70 | /// # Examples |
71 | /// |
72 | /// ``` |
73 | /// use crossbeam_utils::sync::WaitGroup; |
74 | /// |
75 | /// let wg = WaitGroup::new(); |
76 | /// ``` |
77 | pub fn new() -> Self { |
78 | Self::default() |
79 | } |
80 | |
81 | /// Drops this reference and waits until all other references are dropped. |
82 | /// |
83 | /// # Examples |
84 | /// |
85 | /// ``` |
86 | /// use crossbeam_utils::sync::WaitGroup; |
87 | /// use std::thread; |
88 | /// |
89 | /// let wg = WaitGroup::new(); |
90 | /// |
91 | /// thread::spawn({ |
92 | /// let wg = wg.clone(); |
93 | /// move || { |
94 | /// // Block until both threads have reached `wait()`. |
95 | /// wg.wait(); |
96 | /// } |
97 | /// }); |
98 | /// |
99 | /// // Block until both threads have reached `wait()`. |
100 | /// wg.wait(); |
101 | /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371 |
102 | /// ``` |
103 | pub fn wait(self) { |
104 | if *self.inner.count.lock().unwrap() == 1 { |
105 | return; |
106 | } |
107 | |
108 | let inner = self.inner.clone(); |
109 | drop(self); |
110 | |
111 | let mut count = inner.count.lock().unwrap(); |
112 | while *count > 0 { |
113 | count = inner.cvar.wait(count).unwrap(); |
114 | } |
115 | } |
116 | } |
117 | |
118 | impl Drop for WaitGroup { |
119 | fn drop(&mut self) { |
120 | let mut count = self.inner.count.lock().unwrap(); |
121 | *count -= 1; |
122 | |
123 | if *count == 0 { |
124 | self.inner.cvar.notify_all(); |
125 | } |
126 | } |
127 | } |
128 | |
129 | impl Clone for WaitGroup { |
130 | fn clone(&self) -> WaitGroup { |
131 | let mut count = self.inner.count.lock().unwrap(); |
132 | *count += 1; |
133 | |
134 | WaitGroup { |
135 | inner: self.inner.clone(), |
136 | } |
137 | } |
138 | } |
139 | |
140 | impl fmt::Debug for WaitGroup { |
141 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
142 | let count: &usize = &*self.inner.count.lock().unwrap(); |
143 | f.debug_struct("WaitGroup" ).field("count" , count).finish() |
144 | } |
145 | } |
146 | |