1 | use futures_core::{ready, Stream}; |
2 | use std::fmt; |
3 | use std::pin::Pin; |
4 | use std::sync::Arc; |
5 | use std::task::{Context, Poll}; |
6 | use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; |
7 | |
8 | use super::ReusableBoxFuture; |
9 | |
10 | /// A wrapper around [`Semaphore`] that provides a `poll_acquire` method. |
11 | /// |
12 | /// [`Semaphore`]: tokio::sync::Semaphore |
13 | pub struct PollSemaphore { |
14 | semaphore: Arc<Semaphore>, |
15 | permit_fut: Option<( |
16 | u32, // The number of permits requested. |
17 | ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>, |
18 | )>, |
19 | } |
20 | |
21 | impl PollSemaphore { |
22 | /// Create a new `PollSemaphore`. |
23 | pub fn new(semaphore: Arc<Semaphore>) -> Self { |
24 | Self { |
25 | semaphore, |
26 | permit_fut: None, |
27 | } |
28 | } |
29 | |
30 | /// Closes the semaphore. |
31 | pub fn close(&self) { |
32 | self.semaphore.close(); |
33 | } |
34 | |
35 | /// Obtain a clone of the inner semaphore. |
36 | pub fn clone_inner(&self) -> Arc<Semaphore> { |
37 | self.semaphore.clone() |
38 | } |
39 | |
40 | /// Get back the inner semaphore. |
41 | pub fn into_inner(self) -> Arc<Semaphore> { |
42 | self.semaphore |
43 | } |
44 | |
45 | /// Poll to acquire a permit from the semaphore. |
46 | /// |
47 | /// This can return the following values: |
48 | /// |
49 | /// - `Poll::Pending` if a permit is not currently available. |
50 | /// - `Poll::Ready(Some(permit))` if a permit was acquired. |
51 | /// - `Poll::Ready(None)` if the semaphore has been closed. |
52 | /// |
53 | /// When this method returns `Poll::Pending`, the current task is scheduled |
54 | /// to receive a wakeup when a permit becomes available, or when the |
55 | /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only |
56 | /// the `Waker` from the `Context` passed to the most recent call is |
57 | /// scheduled to receive a wakeup. |
58 | pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { |
59 | self.poll_acquire_many(cx, 1) |
60 | } |
61 | |
62 | /// Poll to acquire many permits from the semaphore. |
63 | /// |
64 | /// This can return the following values: |
65 | /// |
66 | /// - `Poll::Pending` if a permit is not currently available. |
67 | /// - `Poll::Ready(Some(permit))` if a permit was acquired. |
68 | /// - `Poll::Ready(None)` if the semaphore has been closed. |
69 | /// |
70 | /// When this method returns `Poll::Pending`, the current task is scheduled |
71 | /// to receive a wakeup when the permits become available, or when the |
72 | /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only |
73 | /// the `Waker` from the `Context` passed to the most recent call is |
74 | /// scheduled to receive a wakeup. |
75 | pub fn poll_acquire_many( |
76 | &mut self, |
77 | cx: &mut Context<'_>, |
78 | permits: u32, |
79 | ) -> Poll<Option<OwnedSemaphorePermit>> { |
80 | let permit_future = match self.permit_fut.as_mut() { |
81 | Some((prev_permits, fut)) if *prev_permits == permits => fut, |
82 | Some((old_permits, fut_box)) => { |
83 | // We're requesting a different number of permits, so replace the future |
84 | // and record the new amount. |
85 | let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); |
86 | fut_box.set(fut); |
87 | *old_permits = permits; |
88 | fut_box |
89 | } |
90 | None => { |
91 | // avoid allocations completely if we can grab a permit immediately |
92 | match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) { |
93 | Ok(permit) => return Poll::Ready(Some(permit)), |
94 | Err(TryAcquireError::Closed) => return Poll::Ready(None), |
95 | Err(TryAcquireError::NoPermits) => {} |
96 | } |
97 | |
98 | let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); |
99 | &mut self |
100 | .permit_fut |
101 | .get_or_insert((permits, ReusableBoxFuture::new(next_fut))) |
102 | .1 |
103 | } |
104 | }; |
105 | |
106 | let result = ready!(permit_future.poll(cx)); |
107 | |
108 | // Assume we'll request the same amount of permits in a subsequent call. |
109 | let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); |
110 | permit_future.set(next_fut); |
111 | |
112 | match result { |
113 | Ok(permit) => Poll::Ready(Some(permit)), |
114 | Err(_closed) => { |
115 | self.permit_fut = None; |
116 | Poll::Ready(None) |
117 | } |
118 | } |
119 | } |
120 | |
121 | /// Returns the current number of available permits. |
122 | /// |
123 | /// This is equivalent to the [`Semaphore::available_permits`] method on the |
124 | /// `tokio::sync::Semaphore` type. |
125 | /// |
126 | /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits |
127 | pub fn available_permits(&self) -> usize { |
128 | self.semaphore.available_permits() |
129 | } |
130 | |
131 | /// Adds `n` new permits to the semaphore. |
132 | /// |
133 | /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function |
134 | /// will panic if the limit is exceeded. |
135 | /// |
136 | /// This is equivalent to the [`Semaphore::add_permits`] method on the |
137 | /// `tokio::sync::Semaphore` type. |
138 | /// |
139 | /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits |
140 | pub fn add_permits(&self, n: usize) { |
141 | self.semaphore.add_permits(n); |
142 | } |
143 | } |
144 | |
145 | impl Stream for PollSemaphore { |
146 | type Item = OwnedSemaphorePermit; |
147 | |
148 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> { |
149 | Pin::into_inner(self).poll_acquire(cx) |
150 | } |
151 | } |
152 | |
153 | impl Clone for PollSemaphore { |
154 | fn clone(&self) -> PollSemaphore { |
155 | PollSemaphore::new(self.clone_inner()) |
156 | } |
157 | } |
158 | |
159 | impl fmt::Debug for PollSemaphore { |
160 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
161 | f.debug_struct("PollSemaphore" ) |
162 | .field("semaphore" , &self.semaphore) |
163 | .finish() |
164 | } |
165 | } |
166 | |
167 | impl AsRef<Semaphore> for PollSemaphore { |
168 | fn as_ref(&self) -> &Semaphore { |
169 | &self.semaphore |
170 | } |
171 | } |
172 | |