1 | use core::fmt; |
2 | use core::mem; |
3 | use core::pin::Pin; |
4 | use core::sync::atomic::{AtomicUsize, Ordering}; |
5 | use core::task::Poll; |
6 | |
7 | use alloc::sync::Arc; |
8 | |
9 | use event_listener::{Event, EventListener}; |
10 | use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy}; |
11 | |
12 | /// A counter for limiting the number of concurrent operations. |
13 | #[derive (Debug)] |
14 | pub struct Semaphore { |
15 | count: AtomicUsize, |
16 | event: Event, |
17 | } |
18 | |
19 | impl Semaphore { |
20 | /// Creates a new semaphore with a limit of `n` concurrent operations. |
21 | /// |
22 | /// # Examples |
23 | /// |
24 | /// ``` |
25 | /// use async_lock::Semaphore; |
26 | /// |
27 | /// let s = Semaphore::new(5); |
28 | /// ``` |
29 | pub const fn new(n: usize) -> Semaphore { |
30 | Semaphore { |
31 | count: AtomicUsize::new(n), |
32 | event: Event::new(), |
33 | } |
34 | } |
35 | |
36 | /// Attempts to get a permit for a concurrent operation. |
37 | /// |
38 | /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, a |
39 | /// guard is returned that releases the mutex when dropped. |
40 | /// |
41 | /// # Examples |
42 | /// |
43 | /// ``` |
44 | /// use async_lock::Semaphore; |
45 | /// |
46 | /// let s = Semaphore::new(2); |
47 | /// |
48 | /// let g1 = s.try_acquire().unwrap(); |
49 | /// let g2 = s.try_acquire().unwrap(); |
50 | /// |
51 | /// assert!(s.try_acquire().is_none()); |
52 | /// drop(g2); |
53 | /// assert!(s.try_acquire().is_some()); |
54 | /// ``` |
55 | pub fn try_acquire(&self) -> Option<SemaphoreGuard<'_>> { |
56 | let mut count = self.count.load(Ordering::Acquire); |
57 | loop { |
58 | if count == 0 { |
59 | return None; |
60 | } |
61 | |
62 | match self.count.compare_exchange_weak( |
63 | count, |
64 | count - 1, |
65 | Ordering::AcqRel, |
66 | Ordering::Acquire, |
67 | ) { |
68 | Ok(_) => return Some(SemaphoreGuard(self)), |
69 | Err(c) => count = c, |
70 | } |
71 | } |
72 | } |
73 | |
74 | /// Waits for a permit for a concurrent operation. |
75 | /// |
76 | /// Returns a guard that releases the permit when dropped. |
77 | /// |
78 | /// # Examples |
79 | /// |
80 | /// ``` |
81 | /// # futures_lite::future::block_on(async { |
82 | /// use async_lock::Semaphore; |
83 | /// |
84 | /// let s = Semaphore::new(2); |
85 | /// let guard = s.acquire().await; |
86 | /// # }); |
87 | /// ``` |
88 | pub fn acquire(&self) -> Acquire<'_> { |
89 | Acquire::_new(AcquireInner { |
90 | semaphore: self, |
91 | listener: EventListener::new(), |
92 | }) |
93 | } |
94 | |
95 | /// Waits for a permit for a concurrent operation. |
96 | /// |
97 | /// Returns a guard that releases the permit when dropped. |
98 | /// |
99 | /// # Blocking |
100 | /// |
101 | /// Rather than using asynchronous waiting, like the [`acquire`][Semaphore::acquire] method, |
102 | /// this method will block the current thread until the permit is acquired. |
103 | /// |
104 | /// This method should not be used in an asynchronous context. It is intended to be |
105 | /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts. |
106 | /// Calling this method in an asynchronous context may result in a deadlock. |
107 | /// |
108 | /// # Examples |
109 | /// |
110 | /// ``` |
111 | /// use async_lock::Semaphore; |
112 | /// |
113 | /// let s = Semaphore::new(2); |
114 | /// let guard = s.acquire_blocking(); |
115 | /// ``` |
116 | #[cfg (all(feature = "std" , not(target_family = "wasm" )))] |
117 | #[inline ] |
118 | pub fn acquire_blocking(&self) -> SemaphoreGuard<'_> { |
119 | self.acquire().wait() |
120 | } |
121 | |
122 | /// Attempts to get an owned permit for a concurrent operation. |
123 | /// |
124 | /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, an |
125 | /// owned guard is returned that releases the mutex when dropped. |
126 | /// |
127 | /// # Examples |
128 | /// |
129 | /// ``` |
130 | /// use async_lock::Semaphore; |
131 | /// use std::sync::Arc; |
132 | /// |
133 | /// let s = Arc::new(Semaphore::new(2)); |
134 | /// |
135 | /// let g1 = s.try_acquire_arc().unwrap(); |
136 | /// let g2 = s.try_acquire_arc().unwrap(); |
137 | /// |
138 | /// assert!(s.try_acquire_arc().is_none()); |
139 | /// drop(g2); |
140 | /// assert!(s.try_acquire_arc().is_some()); |
141 | /// ``` |
142 | pub fn try_acquire_arc(self: &Arc<Self>) -> Option<SemaphoreGuardArc> { |
143 | let mut count = self.count.load(Ordering::Acquire); |
144 | loop { |
145 | if count == 0 { |
146 | return None; |
147 | } |
148 | |
149 | match self.count.compare_exchange_weak( |
150 | count, |
151 | count - 1, |
152 | Ordering::AcqRel, |
153 | Ordering::Acquire, |
154 | ) { |
155 | Ok(_) => return Some(SemaphoreGuardArc(Some(self.clone()))), |
156 | Err(c) => count = c, |
157 | } |
158 | } |
159 | } |
160 | |
161 | /// Waits for an owned permit for a concurrent operation. |
162 | /// |
163 | /// Returns a guard that releases the permit when dropped. |
164 | /// |
165 | /// # Examples |
166 | /// |
167 | /// ``` |
168 | /// # futures_lite::future::block_on(async { |
169 | /// use async_lock::Semaphore; |
170 | /// use std::sync::Arc; |
171 | /// |
172 | /// let s = Arc::new(Semaphore::new(2)); |
173 | /// let guard = s.acquire_arc().await; |
174 | /// # }); |
175 | /// ``` |
176 | pub fn acquire_arc(self: &Arc<Self>) -> AcquireArc { |
177 | AcquireArc::_new(AcquireArcInner { |
178 | semaphore: self.clone(), |
179 | listener: EventListener::new(), |
180 | }) |
181 | } |
182 | |
183 | /// Waits for an owned permit for a concurrent operation. |
184 | /// |
185 | /// Returns a guard that releases the permit when dropped. |
186 | /// |
187 | /// # Blocking |
188 | /// |
189 | /// Rather than using asynchronous waiting, like the [`acquire_arc`][Semaphore::acquire_arc] method, |
190 | /// this method will block the current thread until the permit is acquired. |
191 | /// |
192 | /// This method should not be used in an asynchronous context. It is intended to be |
193 | /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts. |
194 | /// Calling this method in an asynchronous context may result in a deadlock. |
195 | /// |
196 | /// # Examples |
197 | /// |
198 | /// ``` |
199 | /// use std::sync::Arc; |
200 | /// use async_lock::Semaphore; |
201 | /// |
202 | /// let s = Arc::new(Semaphore::new(2)); |
203 | /// let guard = s.acquire_arc_blocking(); |
204 | /// ``` |
205 | #[cfg (all(feature = "std" , not(target_family = "wasm" )))] |
206 | #[inline ] |
207 | pub fn acquire_arc_blocking(self: &Arc<Self>) -> SemaphoreGuardArc { |
208 | self.acquire_arc().wait() |
209 | } |
210 | |
211 | /// Adds `n` additional permits to the semaphore. |
212 | /// |
213 | /// # Examples |
214 | /// |
215 | /// ``` |
216 | /// use async_lock::Semaphore; |
217 | /// |
218 | /// # futures_lite::future::block_on(async { |
219 | /// let s = Semaphore::new(1); |
220 | /// |
221 | /// let _guard = s.acquire().await; |
222 | /// assert!(s.try_acquire().is_none()); |
223 | /// |
224 | /// s.add_permits(2); |
225 | /// |
226 | /// let _guard = s.acquire().await; |
227 | /// let _guard = s.acquire().await; |
228 | /// # }); |
229 | /// ``` |
230 | pub fn add_permits(&self, n: usize) { |
231 | self.count.fetch_add(n, Ordering::AcqRel); |
232 | self.event.notify(n); |
233 | } |
234 | } |
235 | |
236 | easy_wrapper! { |
237 | /// The future returned by [`Semaphore::acquire`]. |
238 | pub struct Acquire<'a>(AcquireInner<'a> => SemaphoreGuard<'a>); |
239 | #[cfg (all(feature = "std" , not(target_family = "wasm" )))] |
240 | pub(crate) wait(); |
241 | } |
242 | |
243 | pin_project_lite::pin_project! { |
244 | struct AcquireInner<'a> { |
245 | // The semaphore being acquired. |
246 | semaphore: &'a Semaphore, |
247 | |
248 | // The listener waiting on the semaphore. |
249 | #[pin] |
250 | listener: EventListener, |
251 | } |
252 | } |
253 | |
254 | impl fmt::Debug for Acquire<'_> { |
255 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
256 | f.write_str(data:"Acquire { .. }" ) |
257 | } |
258 | } |
259 | |
260 | impl<'a> EventListenerFuture for AcquireInner<'a> { |
261 | type Output = SemaphoreGuard<'a>; |
262 | |
263 | fn poll_with_strategy<'x, S: Strategy<'x>>( |
264 | self: Pin<&mut Self>, |
265 | strategy: &mut S, |
266 | cx: &mut S::Context, |
267 | ) -> Poll<Self::Output> { |
268 | let mut this = self.project(); |
269 | |
270 | loop { |
271 | match this.semaphore.try_acquire() { |
272 | Some(guard) => return Poll::Ready(guard), |
273 | None => { |
274 | // Wait on the listener. |
275 | if !this.listener.is_listening() { |
276 | this.listener.as_mut().listen(&this.semaphore.event); |
277 | } else { |
278 | ready!(strategy.poll(this.listener.as_mut(), cx)); |
279 | } |
280 | } |
281 | } |
282 | } |
283 | } |
284 | } |
285 | |
286 | easy_wrapper! { |
287 | /// The future returned by [`Semaphore::acquire_arc`]. |
288 | pub struct AcquireArc(AcquireArcInner => SemaphoreGuardArc); |
289 | #[cfg (all(feature = "std" , not(target_family = "wasm" )))] |
290 | pub(crate) wait(); |
291 | } |
292 | |
293 | pin_project_lite::pin_project! { |
294 | struct AcquireArcInner { |
295 | // The semaphore being acquired. |
296 | semaphore: Arc<Semaphore>, |
297 | |
298 | // The listener waiting on the semaphore. |
299 | #[pin] |
300 | listener: EventListener, |
301 | } |
302 | } |
303 | |
304 | impl fmt::Debug for AcquireArc { |
305 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
306 | f.write_str(data:"AcquireArc { .. }" ) |
307 | } |
308 | } |
309 | |
310 | impl EventListenerFuture for AcquireArcInner { |
311 | type Output = SemaphoreGuardArc; |
312 | |
313 | fn poll_with_strategy<'x, S: Strategy<'x>>( |
314 | self: Pin<&mut Self>, |
315 | strategy: &mut S, |
316 | cx: &mut S::Context, |
317 | ) -> Poll<Self::Output> { |
318 | let mut this = self.project(); |
319 | |
320 | loop { |
321 | match this.semaphore.try_acquire_arc() { |
322 | Some(guard) => return Poll::Ready(guard), |
323 | None => { |
324 | // Wait on the listener. |
325 | if !this.listener.is_listening() { |
326 | this.listener.as_mut().listen(&this.semaphore.event); |
327 | } else { |
328 | ready!(strategy.poll(this.listener.as_mut(), cx)); |
329 | } |
330 | } |
331 | } |
332 | } |
333 | } |
334 | } |
335 | |
336 | /// A guard that releases the acquired permit. |
337 | #[clippy::has_significant_drop] |
338 | #[derive (Debug)] |
339 | pub struct SemaphoreGuard<'a>(&'a Semaphore); |
340 | |
341 | impl SemaphoreGuard<'_> { |
342 | /// Drops the guard _without_ releasing the acquired permit. |
343 | #[inline ] |
344 | pub fn forget(self) { |
345 | mem::forget(self); |
346 | } |
347 | } |
348 | |
349 | impl Drop for SemaphoreGuard<'_> { |
350 | fn drop(&mut self) { |
351 | self.0.count.fetch_add(val:1, order:Ordering::AcqRel); |
352 | self.0.event.notify(1); |
353 | } |
354 | } |
355 | |
356 | /// An owned guard that releases the acquired permit. |
357 | #[clippy::has_significant_drop] |
358 | #[derive (Debug)] |
359 | pub struct SemaphoreGuardArc(Option<Arc<Semaphore>>); |
360 | |
361 | impl SemaphoreGuardArc { |
362 | /// Drops the guard _without_ releasing the acquired permit. |
363 | /// (Will still decrement the `Arc` reference count.) |
364 | #[inline ] |
365 | pub fn forget(mut self) { |
366 | // Drop the inner `Arc` in order to decrement the reference count. |
367 | // FIXME: get rid of the `Option` once RFC 3466 or equivalent becomes available. |
368 | drop(self.0.take()); |
369 | mem::forget(self); |
370 | } |
371 | } |
372 | |
373 | impl Drop for SemaphoreGuardArc { |
374 | fn drop(&mut self) { |
375 | let opt: Arc = self.0.take().unwrap(); |
376 | opt.count.fetch_add(val:1, order:Ordering::AcqRel); |
377 | opt.event.notify(1); |
378 | } |
379 | } |
380 | |