1use futures_core::{ready, Stream};
2use std::fmt;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
7
8use super::ReusableBoxFuture;
9
10/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
11///
12/// [`Semaphore`]: tokio::sync::Semaphore
13pub struct PollSemaphore {
14 semaphore: Arc<Semaphore>,
15 permit_fut: Option<ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>>,
16}
17
18impl PollSemaphore {
19 /// Create a new `PollSemaphore`.
20 pub fn new(semaphore: Arc<Semaphore>) -> Self {
21 Self {
22 semaphore,
23 permit_fut: None,
24 }
25 }
26
27 /// Closes the semaphore.
28 pub fn close(&self) {
29 self.semaphore.close()
30 }
31
32 /// Obtain a clone of the inner semaphore.
33 pub fn clone_inner(&self) -> Arc<Semaphore> {
34 self.semaphore.clone()
35 }
36
37 /// Get back the inner semaphore.
38 pub fn into_inner(self) -> Arc<Semaphore> {
39 self.semaphore
40 }
41
42 /// Poll to acquire a permit from the semaphore.
43 ///
44 /// This can return the following values:
45 ///
46 /// - `Poll::Pending` if a permit is not currently available.
47 /// - `Poll::Ready(Some(permit))` if a permit was acquired.
48 /// - `Poll::Ready(None)` if the semaphore has been closed.
49 ///
50 /// When this method returns `Poll::Pending`, the current task is scheduled
51 /// to receive a wakeup when a permit becomes available, or when the
52 /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
53 /// the `Waker` from the `Context` passed to the most recent call is
54 /// scheduled to receive a wakeup.
55 pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
56 let permit_future = match self.permit_fut.as_mut() {
57 Some(fut) => fut,
58 None => {
59 // avoid allocations completely if we can grab a permit immediately
60 match Arc::clone(&self.semaphore).try_acquire_owned() {
61 Ok(permit) => return Poll::Ready(Some(permit)),
62 Err(TryAcquireError::Closed) => return Poll::Ready(None),
63 Err(TryAcquireError::NoPermits) => {}
64 }
65
66 let next_fut = Arc::clone(&self.semaphore).acquire_owned();
67 self.permit_fut
68 .get_or_insert(ReusableBoxFuture::new(next_fut))
69 }
70 };
71
72 let result = ready!(permit_future.poll(cx));
73
74 let next_fut = Arc::clone(&self.semaphore).acquire_owned();
75 permit_future.set(next_fut);
76
77 match result {
78 Ok(permit) => Poll::Ready(Some(permit)),
79 Err(_closed) => {
80 self.permit_fut = None;
81 Poll::Ready(None)
82 }
83 }
84 }
85
86 /// Returns the current number of available permits.
87 ///
88 /// This is equivalent to the [`Semaphore::available_permits`] method on the
89 /// `tokio::sync::Semaphore` type.
90 ///
91 /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits
92 pub fn available_permits(&self) -> usize {
93 self.semaphore.available_permits()
94 }
95
96 /// Adds `n` new permits to the semaphore.
97 ///
98 /// The maximum number of permits is `usize::MAX >> 3`, and this function
99 /// will panic if the limit is exceeded.
100 ///
101 /// This is equivalent to the [`Semaphore::add_permits`] method on the
102 /// `tokio::sync::Semaphore` type.
103 ///
104 /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits
105 pub fn add_permits(&self, n: usize) {
106 self.semaphore.add_permits(n);
107 }
108}
109
110impl Stream for PollSemaphore {
111 type Item = OwnedSemaphorePermit;
112
113 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
114 Pin::into_inner(self).poll_acquire(cx)
115 }
116}
117
118impl Clone for PollSemaphore {
119 fn clone(&self) -> PollSemaphore {
120 PollSemaphore::new(self.clone_inner())
121 }
122}
123
124impl fmt::Debug for PollSemaphore {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f&mut DebugStruct<'_, '_>.debug_struct("PollSemaphore")
127 .field(name:"semaphore", &self.semaphore)
128 .finish()
129 }
130}
131
132impl AsRef<Semaphore> for PollSemaphore {
133 fn as_ref(&self) -> &Semaphore {
134 &*self.semaphore
135 }
136}
137