1use futures_core::{ready, Stream};
2use std::fmt;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
7
8use super::ReusableBoxFuture;
9
10/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
11///
12/// [`Semaphore`]: tokio::sync::Semaphore
13pub struct PollSemaphore {
14 semaphore: Arc<Semaphore>,
15 permit_fut: Option<(
16 u32, // The number of permits requested.
17 ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>,
18 )>,
19}
20
21impl PollSemaphore {
22 /// Create a new `PollSemaphore`.
23 pub fn new(semaphore: Arc<Semaphore>) -> Self {
24 Self {
25 semaphore,
26 permit_fut: None,
27 }
28 }
29
30 /// Closes the semaphore.
31 pub fn close(&self) {
32 self.semaphore.close();
33 }
34
35 /// Obtain a clone of the inner semaphore.
36 pub fn clone_inner(&self) -> Arc<Semaphore> {
37 self.semaphore.clone()
38 }
39
40 /// Get back the inner semaphore.
41 pub fn into_inner(self) -> Arc<Semaphore> {
42 self.semaphore
43 }
44
45 /// Poll to acquire a permit from the semaphore.
46 ///
47 /// This can return the following values:
48 ///
49 /// - `Poll::Pending` if a permit is not currently available.
50 /// - `Poll::Ready(Some(permit))` if a permit was acquired.
51 /// - `Poll::Ready(None)` if the semaphore has been closed.
52 ///
53 /// When this method returns `Poll::Pending`, the current task is scheduled
54 /// to receive a wakeup when a permit becomes available, or when the
55 /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
56 /// the `Waker` from the `Context` passed to the most recent call is
57 /// scheduled to receive a wakeup.
58 pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
59 self.poll_acquire_many(cx, 1)
60 }
61
62 /// Poll to acquire many permits from the semaphore.
63 ///
64 /// This can return the following values:
65 ///
66 /// - `Poll::Pending` if a permit is not currently available.
67 /// - `Poll::Ready(Some(permit))` if a permit was acquired.
68 /// - `Poll::Ready(None)` if the semaphore has been closed.
69 ///
70 /// When this method returns `Poll::Pending`, the current task is scheduled
71 /// to receive a wakeup when the permits become available, or when the
72 /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
73 /// the `Waker` from the `Context` passed to the most recent call is
74 /// scheduled to receive a wakeup.
75 pub fn poll_acquire_many(
76 &mut self,
77 cx: &mut Context<'_>,
78 permits: u32,
79 ) -> Poll<Option<OwnedSemaphorePermit>> {
80 let permit_future = match self.permit_fut.as_mut() {
81 Some((prev_permits, fut)) if *prev_permits == permits => fut,
82 Some((old_permits, fut_box)) => {
83 // We're requesting a different number of permits, so replace the future
84 // and record the new amount.
85 let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
86 fut_box.set(fut);
87 *old_permits = permits;
88 fut_box
89 }
90 None => {
91 // avoid allocations completely if we can grab a permit immediately
92 match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) {
93 Ok(permit) => return Poll::Ready(Some(permit)),
94 Err(TryAcquireError::Closed) => return Poll::Ready(None),
95 Err(TryAcquireError::NoPermits) => {}
96 }
97
98 let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
99 &mut self
100 .permit_fut
101 .get_or_insert((permits, ReusableBoxFuture::new(next_fut)))
102 .1
103 }
104 };
105
106 let result = ready!(permit_future.poll(cx));
107
108 // Assume we'll request the same amount of permits in a subsequent call.
109 let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
110 permit_future.set(next_fut);
111
112 match result {
113 Ok(permit) => Poll::Ready(Some(permit)),
114 Err(_closed) => {
115 self.permit_fut = None;
116 Poll::Ready(None)
117 }
118 }
119 }
120
121 /// Returns the current number of available permits.
122 ///
123 /// This is equivalent to the [`Semaphore::available_permits`] method on the
124 /// `tokio::sync::Semaphore` type.
125 ///
126 /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits
127 pub fn available_permits(&self) -> usize {
128 self.semaphore.available_permits()
129 }
130
131 /// Adds `n` new permits to the semaphore.
132 ///
133 /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function
134 /// will panic if the limit is exceeded.
135 ///
136 /// This is equivalent to the [`Semaphore::add_permits`] method on the
137 /// `tokio::sync::Semaphore` type.
138 ///
139 /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits
140 pub fn add_permits(&self, n: usize) {
141 self.semaphore.add_permits(n);
142 }
143}
144
145impl Stream for PollSemaphore {
146 type Item = OwnedSemaphorePermit;
147
148 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
149 Pin::into_inner(self).poll_acquire(cx)
150 }
151}
152
153impl Clone for PollSemaphore {
154 fn clone(&self) -> PollSemaphore {
155 PollSemaphore::new(self.clone_inner())
156 }
157}
158
159impl fmt::Debug for PollSemaphore {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 f.debug_struct("PollSemaphore")
162 .field("semaphore", &self.semaphore)
163 .finish()
164 }
165}
166
167impl AsRef<Semaphore> for PollSemaphore {
168 fn as_ref(&self) -> &Semaphore {
169 &self.semaphore
170 }
171}
172