1//! Synchronization primitive allowing multiple threads to synchronize the
2//! beginning of some computation.
3//!
4//! Implementation adapted from the 'Barrier' type of the standard library. See:
5//! <https://doc.rust-lang.org/std/sync/struct.Barrier.html>
6//!
7//! Copyright 2014 The Rust Project Developers. See the COPYRIGHT
8//! file at the top-level directory of this distribution and at
9//! <http://rust-lang.org/COPYRIGHT>.
10//!
11//! Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
12//! <http://www.apache.org/licenses/LICENSE-2.0>> or the MIT license
13//! <LICENSE-MIT or <http://opensource.org/licenses/MIT>>, at your
14//! option. This file may not be copied, modified, or distributed
15//! except according to those terms.
16
17use crate::{mutex::Mutex, RelaxStrategy, Spin};
18
19/// A primitive that synchronizes the execution of multiple threads.
20///
21/// # Example
22///
23/// ```
24/// use spin;
25/// use std::sync::Arc;
26/// use std::thread;
27///
28/// let mut handles = Vec::with_capacity(10);
29/// let barrier = Arc::new(spin::Barrier::new(10));
30/// for _ in 0..10 {
31/// let c = barrier.clone();
32/// // The same messages will be printed together.
33/// // You will NOT see any interleaving.
34/// handles.push(thread::spawn(move|| {
35/// println!("before wait");
36/// c.wait();
37/// println!("after wait");
38/// }));
39/// }
40/// // Wait for other threads to finish.
41/// for handle in handles {
42/// handle.join().unwrap();
43/// }
44/// ```
45pub struct Barrier<R = Spin> {
46 lock: Mutex<BarrierState, R>,
47 num_threads: usize,
48}
49
// Bookkeeping shared by all threads using one `Barrier`. It lives inside the
// barrier's mutex, so it is only read or written while that lock is held.
struct BarrierState {
    // How many threads have arrived in the round currently in progress.
    count: usize,
    // Round number; advanced (with wraparound) whenever a round completes, so
    // that waiting threads can tell that their round is over.
    generation_id: usize,
}
55
/// A `BarrierWaitResult` is returned by [`wait`] when all threads in the [`Barrier`]
/// have rendezvoused.
///
/// The wrapped flag records whether the receiving thread was the leader of
/// the round; query it with [`is_leader`]. The type implements `Debug` so a
/// result can be logged or used in assertion messages.
///
/// [`wait`]: struct.Barrier.html#method.wait
/// [`Barrier`]: struct.Barrier.html
/// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
///
/// # Examples
///
/// ```
/// use spin;
///
/// let barrier = spin::Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// // With a single participant the caller always completes the round.
/// assert!(barrier_wait_result.is_leader());
/// ```
#[derive(Debug)]
pub struct BarrierWaitResult(bool);
71
72impl<R: RelaxStrategy> Barrier<R> {
73 /// Blocks the current thread until all threads have rendezvoused here.
74 ///
75 /// Barriers are re-usable after all threads have rendezvoused once, and can
76 /// be used continuously.
77 ///
78 /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
79 /// returns `true` from [`is_leader`] when returning from this function, and
80 /// all other threads will receive a result that will return `false` from
81 /// [`is_leader`].
82 ///
83 /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
84 /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
85 ///
86 /// # Examples
87 ///
88 /// ```
89 /// use spin;
90 /// use std::sync::Arc;
91 /// use std::thread;
92 ///
93 /// let mut handles = Vec::with_capacity(10);
94 /// let barrier = Arc::new(spin::Barrier::new(10));
95 /// for _ in 0..10 {
96 /// let c = barrier.clone();
97 /// // The same messages will be printed together.
98 /// // You will NOT see any interleaving.
99 /// handles.push(thread::spawn(move|| {
100 /// println!("before wait");
101 /// c.wait();
102 /// println!("after wait");
103 /// }));
104 /// }
105 /// // Wait for other threads to finish.
106 /// for handle in handles {
107 /// handle.join().unwrap();
108 /// }
109 /// ```
110 pub fn wait(&self) -> BarrierWaitResult {
111 let mut lock = self.lock.lock();
112 lock.count += 1;
113
114 if lock.count < self.num_threads {
115 // not the leader
116 let local_gen = lock.generation_id;
117
118 while local_gen == lock.generation_id && lock.count < self.num_threads {
119 drop(lock);
120 R::relax();
121 lock = self.lock.lock();
122 }
123 BarrierWaitResult(false)
124 } else {
125 // this thread is the leader,
126 // and is responsible for incrementing the generation
127 lock.count = 0;
128 lock.generation_id = lock.generation_id.wrapping_add(1);
129 BarrierWaitResult(true)
130 }
131 }
132}
133
134impl<R> Barrier<R> {
135 /// Creates a new barrier that can block a given number of threads.
136 ///
137 /// A barrier will block `n`-1 threads which call [`wait`] and then wake up
138 /// all threads at once when the `n`th thread calls [`wait`]. A Barrier created
139 /// with n = 0 will behave identically to one created with n = 1.
140 ///
141 /// [`wait`]: #method.wait
142 ///
143 /// # Examples
144 ///
145 /// ```
146 /// use spin;
147 ///
148 /// let barrier = spin::Barrier::new(10);
149 /// ```
150 pub const fn new(n: usize) -> Self {
151 Self {
152 lock: Mutex::new(BarrierState {
153 count: 0,
154 generation_id: 0,
155 }),
156 num_threads: n,
157 }
158 }
159}
160
161impl BarrierWaitResult {
162 /// Returns whether this thread from [`wait`] is the "leader thread".
163 ///
164 /// Only one thread will have `true` returned from their result, all other
165 /// threads will have `false` returned.
166 ///
167 /// [`wait`]: struct.Barrier.html#method.wait
168 ///
169 /// # Examples
170 ///
171 /// ```
172 /// use spin;
173 ///
174 /// let barrier = spin::Barrier::new(1);
175 /// let barrier_wait_result = barrier.wait();
176 /// println!("{:?}", barrier_wait_result.is_leader());
177 /// ```
178 pub fn is_leader(&self) -> bool {
179 self.0
180 }
181}
182
#[cfg(test)]
mod tests {
    use std::prelude::v1::*;

    use std::sync::mpsc::{channel, TryRecvError};
    use std::sync::Arc;
    use std::thread;

    type Barrier = super::Barrier;

    // Runs one complete round on `barrier` with `n` participants: `n - 1`
    // spawned workers plus the calling thread, which arrives last. Checks that
    // no worker gets through early and that the round has exactly one leader.
    fn use_barrier(n: usize, barrier: Arc<Barrier>) {
        let (tx, rx) = channel();
        let worker_count = n - 1;

        let workers: Vec<_> = (0..worker_count)
            .map(|_| {
                let barrier = Arc::clone(&barrier);
                let tx = tx.clone();
                thread::spawn(move || {
                    let is_leader = barrier.wait().is_leader();
                    tx.send(is_leader).unwrap();
                })
            })
            .collect();

        // The round still lacks this thread's arrival, so no worker can have
        // got past `wait` and the channel must be empty.
        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));

        // Arrive as the final participant, which releases the workers.
        let own_leadership = usize::from(barrier.wait().is_leader());

        // Each worker now reports once; count how many claim leadership.
        let worker_leaders = (0..worker_count)
            .filter(|_| rx.recv().unwrap())
            .count();

        assert_eq!(own_leadership + worker_leaders, 1);

        for worker in workers {
            worker.join().unwrap();
        }
    }

    #[test]
    fn test_barrier() {
        const N: usize = 10;

        let barrier = Arc::new(Barrier::new(N));

        // Two consecutive rounds on the same instance show that it is reusable.
        for _round in 0..2 {
            use_barrier(N, Arc::clone(&barrier));
        }
    }
}
240