1use core::fmt;
2use core::mem;
3use core::pin::Pin;
4use core::sync::atomic::{AtomicUsize, Ordering};
5use core::task::Poll;
6
7use alloc::sync::Arc;
8
9use event_listener::{Event, EventListener};
10use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
11
/// A counter for limiting the number of concurrent operations.
///
/// Permits are taken with [`Semaphore::try_acquire`] / [`Semaphore::acquire`] (or their `_arc`
/// variants) and handed back when the returned guard is dropped.
#[derive(Debug)]
pub struct Semaphore {
    // Number of permits currently available; decremented on acquire, incremented when a
    // guard is dropped or `add_permits` is called.
    count: AtomicUsize,
    // Waiters listen on this; it is notified once per released permit and `n` times by
    // `add_permits(n)`.
    event: Event,
}
18
19impl Semaphore {
20 /// Creates a new semaphore with a limit of `n` concurrent operations.
21 ///
22 /// # Examples
23 ///
24 /// ```
25 /// use async_lock::Semaphore;
26 ///
27 /// let s = Semaphore::new(5);
28 /// ```
29 pub const fn new(n: usize) -> Semaphore {
30 Semaphore {
31 count: AtomicUsize::new(n),
32 event: Event::new(),
33 }
34 }
35
36 /// Attempts to get a permit for a concurrent operation.
37 ///
38 /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, a
39 /// guard is returned that releases the mutex when dropped.
40 ///
41 /// # Examples
42 ///
43 /// ```
44 /// use async_lock::Semaphore;
45 ///
46 /// let s = Semaphore::new(2);
47 ///
48 /// let g1 = s.try_acquire().unwrap();
49 /// let g2 = s.try_acquire().unwrap();
50 ///
51 /// assert!(s.try_acquire().is_none());
52 /// drop(g2);
53 /// assert!(s.try_acquire().is_some());
54 /// ```
55 pub fn try_acquire(&self) -> Option<SemaphoreGuard<'_>> {
56 let mut count = self.count.load(Ordering::Acquire);
57 loop {
58 if count == 0 {
59 return None;
60 }
61
62 match self.count.compare_exchange_weak(
63 count,
64 count - 1,
65 Ordering::AcqRel,
66 Ordering::Acquire,
67 ) {
68 Ok(_) => return Some(SemaphoreGuard(self)),
69 Err(c) => count = c,
70 }
71 }
72 }
73
74 /// Waits for a permit for a concurrent operation.
75 ///
76 /// Returns a guard that releases the permit when dropped.
77 ///
78 /// # Examples
79 ///
80 /// ```
81 /// # futures_lite::future::block_on(async {
82 /// use async_lock::Semaphore;
83 ///
84 /// let s = Semaphore::new(2);
85 /// let guard = s.acquire().await;
86 /// # });
87 /// ```
88 pub fn acquire(&self) -> Acquire<'_> {
89 Acquire::_new(AcquireInner {
90 semaphore: self,
91 listener: EventListener::new(),
92 })
93 }
94
95 /// Waits for a permit for a concurrent operation.
96 ///
97 /// Returns a guard that releases the permit when dropped.
98 ///
99 /// # Blocking
100 ///
101 /// Rather than using asynchronous waiting, like the [`acquire`][Semaphore::acquire] method,
102 /// this method will block the current thread until the permit is acquired.
103 ///
104 /// This method should not be used in an asynchronous context. It is intended to be
105 /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts.
106 /// Calling this method in an asynchronous context may result in a deadlock.
107 ///
108 /// # Examples
109 ///
110 /// ```
111 /// use async_lock::Semaphore;
112 ///
113 /// let s = Semaphore::new(2);
114 /// let guard = s.acquire_blocking();
115 /// ```
116 #[cfg(all(feature = "std", not(target_family = "wasm")))]
117 #[inline]
118 pub fn acquire_blocking(&self) -> SemaphoreGuard<'_> {
119 self.acquire().wait()
120 }
121
122 /// Attempts to get an owned permit for a concurrent operation.
123 ///
124 /// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, an
125 /// owned guard is returned that releases the mutex when dropped.
126 ///
127 /// # Examples
128 ///
129 /// ```
130 /// use async_lock::Semaphore;
131 /// use std::sync::Arc;
132 ///
133 /// let s = Arc::new(Semaphore::new(2));
134 ///
135 /// let g1 = s.try_acquire_arc().unwrap();
136 /// let g2 = s.try_acquire_arc().unwrap();
137 ///
138 /// assert!(s.try_acquire_arc().is_none());
139 /// drop(g2);
140 /// assert!(s.try_acquire_arc().is_some());
141 /// ```
142 pub fn try_acquire_arc(self: &Arc<Self>) -> Option<SemaphoreGuardArc> {
143 let mut count = self.count.load(Ordering::Acquire);
144 loop {
145 if count == 0 {
146 return None;
147 }
148
149 match self.count.compare_exchange_weak(
150 count,
151 count - 1,
152 Ordering::AcqRel,
153 Ordering::Acquire,
154 ) {
155 Ok(_) => return Some(SemaphoreGuardArc(Some(self.clone()))),
156 Err(c) => count = c,
157 }
158 }
159 }
160
161 /// Waits for an owned permit for a concurrent operation.
162 ///
163 /// Returns a guard that releases the permit when dropped.
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// # futures_lite::future::block_on(async {
169 /// use async_lock::Semaphore;
170 /// use std::sync::Arc;
171 ///
172 /// let s = Arc::new(Semaphore::new(2));
173 /// let guard = s.acquire_arc().await;
174 /// # });
175 /// ```
176 pub fn acquire_arc(self: &Arc<Self>) -> AcquireArc {
177 AcquireArc::_new(AcquireArcInner {
178 semaphore: self.clone(),
179 listener: EventListener::new(),
180 })
181 }
182
183 /// Waits for an owned permit for a concurrent operation.
184 ///
185 /// Returns a guard that releases the permit when dropped.
186 ///
187 /// # Blocking
188 ///
189 /// Rather than using asynchronous waiting, like the [`acquire_arc`][Semaphore::acquire_arc] method,
190 /// this method will block the current thread until the permit is acquired.
191 ///
192 /// This method should not be used in an asynchronous context. It is intended to be
193 /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts.
194 /// Calling this method in an asynchronous context may result in a deadlock.
195 ///
196 /// # Examples
197 ///
198 /// ```
199 /// use std::sync::Arc;
200 /// use async_lock::Semaphore;
201 ///
202 /// let s = Arc::new(Semaphore::new(2));
203 /// let guard = s.acquire_arc_blocking();
204 /// ```
205 #[cfg(all(feature = "std", not(target_family = "wasm")))]
206 #[inline]
207 pub fn acquire_arc_blocking(self: &Arc<Self>) -> SemaphoreGuardArc {
208 self.acquire_arc().wait()
209 }
210
211 /// Adds `n` additional permits to the semaphore.
212 ///
213 /// # Examples
214 ///
215 /// ```
216 /// use async_lock::Semaphore;
217 ///
218 /// # futures_lite::future::block_on(async {
219 /// let s = Semaphore::new(1);
220 ///
221 /// let _guard = s.acquire().await;
222 /// assert!(s.try_acquire().is_none());
223 ///
224 /// s.add_permits(2);
225 ///
226 /// let _guard = s.acquire().await;
227 /// let _guard = s.acquire().await;
228 /// # });
229 /// ```
230 pub fn add_permits(&self, n: usize) {
231 self.count.fetch_add(n, Ordering::AcqRel);
232 self.event.notify(n);
233 }
234}
235
// Declares the public `Acquire` future wrapping `AcquireInner` (constructed via
// `Acquire::_new` in `Semaphore::acquire`). With `std` on non-wasm targets it also exposes the
// crate-internal `wait()` used by `Semaphore::acquire_blocking`.
easy_wrapper! {
    /// The future returned by [`Semaphore::acquire`].
    ///
    /// Resolves to a [`SemaphoreGuard`] once a permit has been acquired.
    pub struct Acquire<'a>(AcquireInner<'a> => SemaphoreGuard<'a>);
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
242
// State behind the `Acquire` future. `pin_project!` provides `project()`, which yields the
// `#[pin]` listener as a pinned reference (hence `listener.as_mut()` in `poll_with_strategy`).
pin_project_lite::pin_project! {
    struct AcquireInner<'a> {
        // The semaphore being acquired.
        semaphore: &'a Semaphore,

        // The listener waiting on the semaphore; starts unregistered and is registered on
        // `semaphore.event` the first time a permit is unavailable.
        #[pin]
        listener: EventListener,
    }
}
253
254impl fmt::Debug for Acquire<'_> {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 f.write_str(data:"Acquire { .. }")
257 }
258}
259
260impl<'a> EventListenerFuture for AcquireInner<'a> {
261 type Output = SemaphoreGuard<'a>;
262
263 fn poll_with_strategy<'x, S: Strategy<'x>>(
264 self: Pin<&mut Self>,
265 strategy: &mut S,
266 cx: &mut S::Context,
267 ) -> Poll<Self::Output> {
268 let mut this = self.project();
269
270 loop {
271 match this.semaphore.try_acquire() {
272 Some(guard) => return Poll::Ready(guard),
273 None => {
274 // Wait on the listener.
275 if !this.listener.is_listening() {
276 this.listener.as_mut().listen(&this.semaphore.event);
277 } else {
278 ready!(strategy.poll(this.listener.as_mut(), cx));
279 }
280 }
281 }
282 }
283 }
284}
285
// Declares the public `AcquireArc` future wrapping `AcquireArcInner` (constructed via
// `AcquireArc::_new` in `Semaphore::acquire_arc`). With `std` on non-wasm targets it also
// exposes the crate-internal `wait()` used by `Semaphore::acquire_arc_blocking`.
easy_wrapper! {
    /// The future returned by [`Semaphore::acquire_arc`].
    ///
    /// Resolves to a [`SemaphoreGuardArc`] once a permit has been acquired.
    pub struct AcquireArc(AcquireArcInner => SemaphoreGuardArc);
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
292
// State behind the `AcquireArc` future. Same shape as `AcquireInner`, but it owns an `Arc`
// to the semaphore so the future (and its guard) is not tied to a borrow.
pin_project_lite::pin_project! {
    struct AcquireArcInner {
        // The semaphore being acquired.
        semaphore: Arc<Semaphore>,

        // The listener waiting on the semaphore; starts unregistered and is registered on
        // `semaphore.event` the first time a permit is unavailable.
        #[pin]
        listener: EventListener,
    }
}
303
304impl fmt::Debug for AcquireArc {
305 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
306 f.write_str(data:"AcquireArc { .. }")
307 }
308}
309
310impl EventListenerFuture for AcquireArcInner {
311 type Output = SemaphoreGuardArc;
312
313 fn poll_with_strategy<'x, S: Strategy<'x>>(
314 self: Pin<&mut Self>,
315 strategy: &mut S,
316 cx: &mut S::Context,
317 ) -> Poll<Self::Output> {
318 let mut this = self.project();
319
320 loop {
321 match this.semaphore.try_acquire_arc() {
322 Some(guard) => return Poll::Ready(guard),
323 None => {
324 // Wait on the listener.
325 if !this.listener.is_listening() {
326 this.listener.as_mut().listen(&this.semaphore.event);
327 } else {
328 ready!(strategy.poll(this.listener.as_mut(), cx));
329 }
330 }
331 }
332 }
333 }
334}
335
336/// A guard that releases the acquired permit.
337#[clippy::has_significant_drop]
338#[derive(Debug)]
339pub struct SemaphoreGuard<'a>(&'a Semaphore);
340
341impl SemaphoreGuard<'_> {
342 /// Drops the guard _without_ releasing the acquired permit.
343 #[inline]
344 pub fn forget(self) {
345 mem::forget(self);
346 }
347}
348
349impl Drop for SemaphoreGuard<'_> {
350 fn drop(&mut self) {
351 self.0.count.fetch_add(val:1, order:Ordering::AcqRel);
352 self.0.event.notify(1);
353 }
354}
355
356/// An owned guard that releases the acquired permit.
357#[clippy::has_significant_drop]
358#[derive(Debug)]
359pub struct SemaphoreGuardArc(Option<Arc<Semaphore>>);
360
361impl SemaphoreGuardArc {
362 /// Drops the guard _without_ releasing the acquired permit.
363 /// (Will still decrement the `Arc` reference count.)
364 #[inline]
365 pub fn forget(mut self) {
366 // Drop the inner `Arc` in order to decrement the reference count.
367 // FIXME: get rid of the `Option` once RFC 3466 or equivalent becomes available.
368 drop(self.0.take());
369 mem::forget(self);
370 }
371}
372
373impl Drop for SemaphoreGuardArc {
374 fn drop(&mut self) {
375 let opt: Arc = self.0.take().unwrap();
376 opt.count.fetch_add(val:1, order:Ordering::AcqRel);
377 opt.event.notify(1);
378 }
379}
380