1 | #[cfg (feature = "http2" )] |
2 | use std::future::Future; |
3 | use std::marker::Unpin; |
4 | #[cfg (feature = "http2" )] |
5 | use std::pin::Pin; |
6 | use std::task::{Context, Poll}; |
7 | |
8 | use futures_util::FutureExt; |
9 | use tokio::sync::{mpsc, oneshot}; |
10 | |
11 | pub(crate) type RetryPromise<T, U> = oneshot::Receiver<Result<U, (crate::Error, Option<T>)>>; |
12 | pub(crate) type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>; |
13 | |
14 | pub(crate) fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) { |
15 | let (tx: UnboundedSender>, rx: UnboundedReceiver>) = mpsc::unbounded_channel(); |
16 | let (giver: Giver, taker: Taker) = want::new(); |
17 | let tx: Sender = Sender { |
18 | buffered_once: false, |
19 | giver, |
20 | inner: tx, |
21 | }; |
22 | let rx: Receiver = Receiver { inner: rx, taker }; |
23 | (tx, rx) |
24 | } |
25 | |
/// A bounded sender of requests and callbacks for when responses are ready.
///
/// While the inner sender is unbounded, the Giver is used to determine
/// if the Receiver is ready for another request.
pub(crate) struct Sender<T, U> {
    /// One message is always allowed, even if the Receiver hasn't asked
    /// for it yet. This boolean keeps track of whether we've sent one
    /// without notice. Once set it is never reset, so this allowance
    /// exists only once for the lifetime of the channel.
    buffered_once: bool,
    /// The Giver helps watch that the Receiver side has been polled
    /// when the queue is empty. This helps us know when a request and
    /// response have been fully processed, and a connection is ready
    /// for more.
    giver: want::Giver,
    /// Actually bounded by the Giver, plus `buffered_once`.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
43 | |
/// An unbounded version of [`Sender`], produced by `Sender::unbound`.
///
/// Cannot poll the Giver, but can still use it to determine if the Receiver
/// has been dropped. However, this version can be cloned. Sends through it
/// are never gated on the Receiver asking for more.
#[cfg(feature = "http2")]
pub(crate) struct UnboundedSender<T, U> {
    /// Only used for `is_closed` / `is_ready` (cancellation checks), since
    /// mpsc::UnboundedSender cannot be checked.
    giver: want::SharedGiver,
    /// Request queue shared with the Receiver.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
54 | |
55 | impl<T, U> Sender<T, U> { |
56 | pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> { |
57 | self.giver |
58 | .poll_want(cx) |
59 | .map_err(|_| crate::Error::new_closed()) |
60 | } |
61 | |
62 | pub(crate) fn is_ready(&self) -> bool { |
63 | self.giver.is_wanting() |
64 | } |
65 | |
66 | pub(crate) fn is_closed(&self) -> bool { |
67 | self.giver.is_canceled() |
68 | } |
69 | |
70 | fn can_send(&mut self) -> bool { |
71 | if self.giver.give() || !self.buffered_once { |
72 | // If the receiver is ready *now*, then of course we can send. |
73 | // |
74 | // If the receiver isn't ready yet, but we don't have anything |
75 | // in the channel yet, then allow one message. |
76 | self.buffered_once = true; |
77 | true |
78 | } else { |
79 | false |
80 | } |
81 | } |
82 | |
83 | pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> { |
84 | if !self.can_send() { |
85 | return Err(val); |
86 | } |
87 | let (tx, rx) = oneshot::channel(); |
88 | self.inner |
89 | .send(Envelope(Some((val, Callback::Retry(Some(tx)))))) |
90 | .map(move |_| rx) |
91 | .map_err(|mut e| (e.0).0.take().expect("envelope not dropped" ).0) |
92 | } |
93 | |
94 | pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> { |
95 | if !self.can_send() { |
96 | return Err(val); |
97 | } |
98 | let (tx, rx) = oneshot::channel(); |
99 | self.inner |
100 | .send(Envelope(Some((val, Callback::NoRetry(Some(tx)))))) |
101 | .map(move |_| rx) |
102 | .map_err(|mut e| (e.0).0.take().expect("envelope not dropped" ).0) |
103 | } |
104 | |
105 | #[cfg (feature = "http2" )] |
106 | pub(crate) fn unbound(self) -> UnboundedSender<T, U> { |
107 | UnboundedSender { |
108 | giver: self.giver.shared(), |
109 | inner: self.inner, |
110 | } |
111 | } |
112 | } |
113 | |
#[cfg(feature = "http2")]
impl<T, U> UnboundedSender<T, U> {
    /// An unbounded sender never waits on want, so it is ready for as long
    /// as the Receiver has not been canceled.
    pub(crate) fn is_ready(&self) -> bool {
        !self.is_closed()
    }

    /// Returns `true` once the Receiver side has been canceled.
    pub(crate) fn is_closed(&self) -> bool {
        self.giver.is_canceled()
    }

    /// Queues `val` with a retryable callback, without consulting want.
    ///
    /// Returns `Err(val)` if the queue has been closed.
    pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
        let (callback_tx, promise) = oneshot::channel();
        let envelope = Envelope(Some((val, Callback::Retry(Some(callback_tx)))));
        match self.inner.send(envelope) {
            Ok(()) => Ok(promise),
            Err(mut rejected) => {
                let (val, _callback) = (rejected.0).0.take().expect("envelope not dropped");
                Err(val)
            }
        }
    }

    /// Queues `val` with a non-retryable callback, without consulting want.
    ///
    /// Returns `Err(val)` if the queue has been closed.
    #[cfg(all(feature = "backports", feature = "http2"))]
    pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> {
        let (callback_tx, promise) = oneshot::channel();
        let envelope = Envelope(Some((val, Callback::NoRetry(Some(callback_tx)))));
        match self.inner.send(envelope) {
            Ok(()) => Ok(promise),
            Err(mut rejected) => {
                let (val, _callback) = (rejected.0).0.take().expect("envelope not dropped");
                Err(val)
            }
        }
    }
}
141 | |
// Implemented by hand rather than derived so that cloning does not require
// `T: Clone` or `U: Clone`; only the handles are cloned.
#[cfg(feature = "http2")]
impl<T, U> Clone for UnboundedSender<T, U> {
    fn clone(&self) -> Self {
        let giver = self.giver.clone();
        let inner = self.inner.clone();
        Self { giver, inner }
    }
}
151 | |
/// The receiving half created by [`channel`]: yields queued requests together
/// with the callbacks used to answer them.
pub(crate) struct Receiver<T, U> {
    /// Queue of envelopes pushed by the sender side.
    inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
    /// Tells the sender's Giver when another request is wanted (on an empty
    /// poll) or that this side has been canceled.
    taker: want::Taker,
}
156 | |
157 | impl<T, U> Receiver<T, U> { |
158 | pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<(T, Callback<T, U>)>> { |
159 | match self.inner.poll_recv(cx) { |
160 | Poll::Ready(item) => { |
161 | Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped" ))) |
162 | } |
163 | Poll::Pending => { |
164 | self.taker.want(); |
165 | Poll::Pending |
166 | } |
167 | } |
168 | } |
169 | |
170 | #[cfg (feature = "http1" )] |
171 | pub(crate) fn close(&mut self) { |
172 | self.taker.cancel(); |
173 | self.inner.close(); |
174 | } |
175 | |
176 | #[cfg (feature = "http1" )] |
177 | pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> { |
178 | match self.inner.recv().now_or_never() { |
179 | Some(Some(mut env)) => env.0.take(), |
180 | _ => None, |
181 | } |
182 | } |
183 | } |
184 | |
/// Dropping the Receiver cancels its Taker, so the sender's Giver reports the
/// channel as canceled (`is_closed`, and an error from `poll_ready`).
impl<T, U> Drop for Receiver<T, U> {
    fn drop(&mut self) {
        // Notify the giver about the closure first, before dropping
        // the mpsc::Receiver.
        self.taker.cancel();
    }
}
192 | |
193 | struct Envelope<T, U>(Option<(T, Callback<T, U>)>); |
194 | |
195 | impl<T, U> Drop for Envelope<T, U> { |
196 | fn drop(&mut self) { |
197 | if let Some((val: T, cb: Callback)) = self.0.take() { |
198 | cb.send(val:Err(( |
199 | crate::Error::new_canceled().with(cause:"connection closed" ), |
200 | Some(val), |
201 | ))); |
202 | } |
203 | } |
204 | } |
205 | |
206 | pub(crate) enum Callback<T, U> { |
207 | Retry(Option<oneshot::Sender<Result<U, (crate::Error, Option<T>)>>>), |
208 | NoRetry(Option<oneshot::Sender<Result<U, crate::Error>>>), |
209 | } |
210 | |
211 | impl<T, U> Drop for Callback<T, U> { |
212 | fn drop(&mut self) { |
213 | // FIXME(nox): What errors do we want here? |
214 | let error: Error = crate::Error::new_user_dispatch_gone().with(cause:if std::thread::panicking() { |
215 | "user code panicked" |
216 | } else { |
217 | "runtime dropped the dispatch task" |
218 | }); |
219 | |
220 | match self { |
221 | Callback::Retry(tx: &mut Option>>) => { |
222 | if let Some(tx: Sender>) = tx.take() { |
223 | let _ = tx.send(Err((error, None))); |
224 | } |
225 | } |
226 | Callback::NoRetry(tx: &mut Option>>) => { |
227 | if let Some(tx: Sender>) = tx.take() { |
228 | let _ = tx.send(Err(error)); |
229 | } |
230 | } |
231 | } |
232 | } |
233 | } |
234 | |
235 | impl<T, U> Callback<T, U> { |
236 | #[cfg (feature = "http2" )] |
237 | pub(crate) fn is_canceled(&self) -> bool { |
238 | match *self { |
239 | Callback::Retry(Some(ref tx)) => tx.is_closed(), |
240 | Callback::NoRetry(Some(ref tx)) => tx.is_closed(), |
241 | _ => unreachable!(), |
242 | } |
243 | } |
244 | |
245 | pub(crate) fn poll_canceled(&mut self, cx: &mut Context<'_>) -> Poll<()> { |
246 | match *self { |
247 | Callback::Retry(Some(ref mut tx)) => tx.poll_closed(cx), |
248 | Callback::NoRetry(Some(ref mut tx)) => tx.poll_closed(cx), |
249 | _ => unreachable!(), |
250 | } |
251 | } |
252 | |
253 | pub(crate) fn send(mut self, val: Result<U, (crate::Error, Option<T>)>) { |
254 | match self { |
255 | Callback::Retry(ref mut tx) => { |
256 | let _ = tx.take().unwrap().send(val); |
257 | } |
258 | Callback::NoRetry(ref mut tx) => { |
259 | let _ = tx.take().unwrap().send(val.map_err(|e| e.0)); |
260 | } |
261 | } |
262 | } |
263 | |
264 | #[cfg (feature = "http2" )] |
265 | pub(crate) async fn send_when( |
266 | self, |
267 | mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin, |
268 | ) { |
269 | use futures_util::future; |
270 | use tracing::trace; |
271 | |
272 | let mut cb = Some(self); |
273 | |
274 | // "select" on this callback being canceled, and the future completing |
275 | future::poll_fn(move |cx| { |
276 | match Pin::new(&mut when).poll(cx) { |
277 | Poll::Ready(Ok(res)) => { |
278 | cb.take().expect("polled after complete" ).send(Ok(res)); |
279 | Poll::Ready(()) |
280 | } |
281 | Poll::Pending => { |
282 | // check if the callback is canceled |
283 | ready!(cb.as_mut().unwrap().poll_canceled(cx)); |
284 | trace!("send_when canceled" ); |
285 | Poll::Ready(()) |
286 | } |
287 | Poll::Ready(Err(err)) => { |
288 | cb.take().expect("polled after complete" ).send(Err(err)); |
289 | Poll::Ready(()) |
290 | } |
291 | } |
292 | }) |
293 | .await |
294 | } |
295 | } |
296 | |
#[cfg(test)]
mod tests {
    // The `test` crate (needed for `#[bench]`) is only available on nightly.
    #[cfg(feature = "nightly")]
    extern crate test;

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{channel, Callback, Receiver};

    /// Minimal request payload used by these tests.
    #[derive(Debug)]
    struct Custom(i32);

    /// Lets tests `.await` a `Receiver` directly. Each poll delegates to
    /// `poll_recv`, so a `Pending` poll also signals want to the sender.
    impl<T, U> Future for Receiver<T, U> {
        type Output = Option<(T, Callback<T, U>)>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_recv(cx)
        }
    }

    /// Helper to check if the future is ready after polling once.
    ///
    /// Always resolves on its first poll: `Some(())` if the wrapped future was
    /// ready (its output is discarded), `None` if it was still pending.
    struct PollOnce<'a, F>(&'a mut F);

    impl<F, T> Future for PollOnce<'_, F>
    where
        F: Future<Output = T> + Unpin,
    {
        type Output = Option<()>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match Pin::new(&mut self.0).poll(cx) {
                Poll::Ready(_) => Poll::Ready(Some(())),
                Poll::Pending => Poll::Ready(None),
            }
        }
    }

    /// A request still queued when the Receiver is dropped must resolve its
    /// promise with a `Canceled` error that hands the request back.
    #[tokio::test]
    async fn drop_receiver_sends_cancel_errors() {
        let _ = pretty_env_logger::try_init();

        let (mut tx, mut rx) = channel::<Custom, ()>();

        // Poll once so the Receiver registers want before sending (the first
        // try_send would also be let through by the one-message allowance).
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let promise = tx.try_send(Custom(43)).unwrap();
        // The queued envelope is discarded along with the Receiver.
        drop(rx);

        let fulfilled = promise.await;
        let err = fulfilled
            .expect("fulfilled")
            .expect_err("promise should error");
        match (err.0.kind(), err.1) {
            (&crate::error::Kind::Canceled, Some(_)) => (),
            e => panic!("expected Error::Cancel(_), found {:?}", e),
        }
    }

    /// The bounded sender only accepts a message when the Receiver has asked
    /// for one, apart from a single up-front buffered message.
    #[tokio::test]
    async fn sender_checks_for_want_on_send() {
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // one is allowed to buffer, second is rejected
        let _ = tx.try_send(Custom(1)).expect("1 buffered");
        tx.try_send(Custom(2)).expect_err("2 not ready");

        // This poll returns message 1 (Ready), so it does not signal want.
        assert!(PollOnce(&mut rx).await.is_some(), "rx once");

        // Even though 1 has been popped, only 1 could be buffered for the
        // lifetime of the channel.
        tx.try_send(Custom(2)).expect_err("2 still not ready");

        // An empty (Pending) poll signals want, which unblocks the next send.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let _ = tx.try_send(Custom(2)).expect("2 ready");
    }

    /// The unbounded sender ignores want and only fails once the Receiver
    /// is gone.
    #[cfg(feature = "http2")]
    #[test]
    fn unbounded_sender_doesnt_bound_on_want() {
        let (tx, rx) = channel::<Custom, ()>();
        let mut tx = tx.unbound();

        let _ = tx.try_send(Custom(1)).unwrap();
        let _ = tx.try_send(Custom(2)).unwrap();
        let _ = tx.try_send(Custom(3)).unwrap();

        drop(rx);

        let _ = tx.try_send(Custom(4)).unwrap_err();
    }

    /// Benchmarks one send followed by draining the Receiver until it is empty.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_throughput(b: &mut test::Bencher) {
        use crate::{Body, Request, Response};

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>();

        b.iter(move || {
            let _ = tx.send(Request::default()).unwrap();
            rt.block_on(async {
                loop {
                    let poll_once = PollOnce(&mut rx);
                    let opt = poll_once.await;
                    if opt.is_none() {
                        break;
                    }
                }
            });
        })
    }

    /// Benchmarks polling an empty Receiver (the "want" signalling path).
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_not_ready(b: &mut test::Bencher) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        // `_tx` (not `_`) keeps the sender alive so the queue stays open.
        let (_tx, mut rx) = channel::<i32, ()>();
        b.iter(move || {
            rt.block_on(async {
                let poll_once = PollOnce(&mut rx);
                assert!(poll_once.await.is_none());
            });
        })
    }

    /// Benchmarks repeated cancellation of the Taker.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_cancel(b: &mut test::Bencher) {
        let (_tx, mut rx) = channel::<i32, ()>();

        b.iter(move || {
            rx.taker.cancel();
        })
    }
}
443 | |