1 | #[cfg (feature = "http2" )] |
2 | use std::future::Future; |
3 | |
4 | use futures_util::FutureExt; |
5 | use tokio::sync::{mpsc, oneshot}; |
6 | |
7 | #[cfg (feature = "http2" )] |
8 | use crate::common::Pin; |
9 | use crate::common::{task, Poll}; |
10 | |
/// Receiving half of the oneshot used for retryable sends.
///
/// On failure the error is paired with an `Option<T>` slot through which the
/// original message may be handed back to the caller.
pub(crate) type RetryPromise<T, U> = oneshot::Receiver<Result<U, (crate::Error, Option<T>)>>;
/// Receiving half of the oneshot used for non-retryable sends; a failure
/// carries only the error.
pub(crate) type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>;
13 | |
14 | pub(crate) fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) { |
15 | let (tx: UnboundedSender>, rx: UnboundedReceiver>) = mpsc::unbounded_channel(); |
16 | let (giver: Giver, taker: Taker) = want::new(); |
17 | let tx: Sender = Sender { |
18 | buffered_once: false, |
19 | giver, |
20 | inner: tx, |
21 | }; |
22 | let rx: Receiver = Receiver { inner: rx, taker }; |
23 | (tx, rx) |
24 | } |
25 | |
/// A bounded sender of requests and callbacks for when responses are ready.
///
/// While the inner sender is unbounded, the Giver is used to determine
/// if the Receiver is ready for another request.
pub(crate) struct Sender<T, U> {
    /// One message is always allowed, even if the Receiver hasn't asked
    /// for it yet. This boolean keeps track of whether we've sent one
    /// without notice.
    buffered_once: bool,
    /// The Giver helps watch that the Receiver side has been polled
    /// when the queue is empty. This helps us know when a request and
    /// response have been fully processed, and a connection is ready
    /// for more.
    giver: want::Giver,
    /// Actually bounded by the Giver, plus `buffered_once`.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
43 | |
/// An unbounded version.
///
/// Cannot poll the Giver, but can still use it to determine if the Receiver
/// has been dropped. However, this version can be cloned.
#[cfg (feature = "http2" )]
pub(crate) struct UnboundedSender<T, U> {
    /// Only used for `is_closed`, since mpsc::UnboundedSender cannot be checked.
    giver: want::SharedGiver,
    /// Queue of envelopes; sends on this half are not gated by the Giver.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
54 | |
55 | impl<T, U> Sender<T, U> { |
56 | pub(crate) fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> { |
57 | self.giver |
58 | .poll_want(cx) |
59 | .map_err(|_| crate::Error::new_closed()) |
60 | } |
61 | |
62 | pub(crate) fn is_ready(&self) -> bool { |
63 | self.giver.is_wanting() |
64 | } |
65 | |
66 | pub(crate) fn is_closed(&self) -> bool { |
67 | self.giver.is_canceled() |
68 | } |
69 | |
70 | fn can_send(&mut self) -> bool { |
71 | if self.giver.give() || !self.buffered_once { |
72 | // If the receiver is ready *now*, then of course we can send. |
73 | // |
74 | // If the receiver isn't ready yet, but we don't have anything |
75 | // in the channel yet, then allow one message. |
76 | self.buffered_once = true; |
77 | true |
78 | } else { |
79 | false |
80 | } |
81 | } |
82 | |
83 | pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> { |
84 | if !self.can_send() { |
85 | return Err(val); |
86 | } |
87 | let (tx, rx) = oneshot::channel(); |
88 | self.inner |
89 | .send(Envelope(Some((val, Callback::Retry(Some(tx)))))) |
90 | .map(move |_| rx) |
91 | .map_err(|mut e| (e.0).0.take().expect("envelope not dropped" ).0) |
92 | } |
93 | |
94 | pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> { |
95 | if !self.can_send() { |
96 | return Err(val); |
97 | } |
98 | let (tx, rx) = oneshot::channel(); |
99 | self.inner |
100 | .send(Envelope(Some((val, Callback::NoRetry(Some(tx)))))) |
101 | .map(move |_| rx) |
102 | .map_err(|mut e| (e.0).0.take().expect("envelope not dropped" ).0) |
103 | } |
104 | |
105 | #[cfg (feature = "http2" )] |
106 | pub(crate) fn unbound(self) -> UnboundedSender<T, U> { |
107 | UnboundedSender { |
108 | giver: self.giver.shared(), |
109 | inner: self.inner, |
110 | } |
111 | } |
112 | } |
113 | |
#[cfg(feature = "http2")]
impl<T, U> UnboundedSender<T, U> {
    /// Reports readiness; this variant is ready whenever the shared Giver
    /// has not been canceled.
    pub(crate) fn is_ready(&self) -> bool {
        !self.giver.is_canceled()
    }

    /// Reports whether the shared Giver has been canceled.
    pub(crate) fn is_closed(&self) -> bool {
        self.giver.is_canceled()
    }

    /// Queues `val` with a retry-capable callback, without consulting the
    /// Giver first.
    ///
    /// Returns the promise for the response, or gives `val` back when the
    /// underlying queue refuses the envelope.
    pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
        let (tx, rx) = oneshot::channel();
        self.inner
            .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
            .map(move |_| rx)
            // The refused envelope still owns the message; take it back out.
            .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
    }
}
132 | |
#[cfg(feature = "http2")]
impl<T, U> Clone for UnboundedSender<T, U> {
    /// Produces another handle by cloning both the shared Giver and the
    /// queue sender.
    fn clone(&self) -> Self {
        let giver = self.giver.clone();
        let inner = self.inner.clone();
        Self { giver, inner }
    }
}
142 | |
/// Receiving half of the dispatch channel.
pub(crate) struct Receiver<T, U> {
    /// Queue of envelopes pushed by the sending half.
    inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
    /// Used to signal `want` when the queue is empty and `cancel` on
    /// close/drop.
    taker: want::Taker,
}
147 | |
148 | impl<T, U> Receiver<T, U> { |
149 | pub(crate) fn poll_recv( |
150 | &mut self, |
151 | cx: &mut task::Context<'_>, |
152 | ) -> Poll<Option<(T, Callback<T, U>)>> { |
153 | match self.inner.poll_recv(cx) { |
154 | Poll::Ready(item) => { |
155 | Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped" ))) |
156 | } |
157 | Poll::Pending => { |
158 | self.taker.want(); |
159 | Poll::Pending |
160 | } |
161 | } |
162 | } |
163 | |
164 | #[cfg (feature = "http1" )] |
165 | pub(crate) fn close(&mut self) { |
166 | self.taker.cancel(); |
167 | self.inner.close(); |
168 | } |
169 | |
170 | #[cfg (feature = "http1" )] |
171 | pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> { |
172 | match self.inner.recv().now_or_never() { |
173 | Some(Some(mut env)) => env.0.take(), |
174 | _ => None, |
175 | } |
176 | } |
177 | } |
178 | |
/// Cancels the Taker when the receiver goes away.
impl<T, U> Drop for Receiver<T, U> {
    fn drop(&mut self) {
        // Notify the giver about the closure first, before dropping
        // the mpsc::Receiver.
        self.taker.cancel();
    }
}
186 | |
187 | struct Envelope<T, U>(Option<(T, Callback<T, U>)>); |
188 | |
189 | impl<T, U> Drop for Envelope<T, U> { |
190 | fn drop(&mut self) { |
191 | if let Some((val: T, cb: Callback)) = self.0.take() { |
192 | cb.send(val:Err(( |
193 | crate::Error::new_canceled().with(cause:"connection closed" ), |
194 | Some(val), |
195 | ))); |
196 | } |
197 | } |
198 | } |
199 | |
/// The sending half of a response promise.
///
/// Each variant wraps its oneshot sender in an `Option` so the sender can be
/// taken out through a mutable reference.
pub(crate) enum Callback<T, U> {
    /// Errors are delivered together with an `Option<T>` slot for the
    /// original message.
    Retry(Option<oneshot::Sender<Result<U, (crate::Error, Option<T>)>>>),
    /// Errors are delivered on their own.
    NoRetry(Option<oneshot::Sender<Result<U, crate::Error>>>),
}
204 | |
205 | impl<T, U> Drop for Callback<T, U> { |
206 | fn drop(&mut self) { |
207 | // FIXME(nox): What errors do we want here? |
208 | let error: Error = crate::Error::new_user_dispatch_gone().with(cause:if std::thread::panicking() { |
209 | "user code panicked" |
210 | } else { |
211 | "runtime dropped the dispatch task" |
212 | }); |
213 | |
214 | match self { |
215 | Callback::Retry(tx: &mut Option>>) => { |
216 | if let Some(tx: Sender>) = tx.take() { |
217 | let _ = tx.send(Err((error, None))); |
218 | } |
219 | } |
220 | Callback::NoRetry(tx: &mut Option>>) => { |
221 | if let Some(tx: Sender>) = tx.take() { |
222 | let _ = tx.send(Err(error)); |
223 | } |
224 | } |
225 | } |
226 | } |
227 | } |
228 | |
229 | impl<T, U> Callback<T, U> { |
230 | #[cfg (feature = "http2" )] |
231 | pub(crate) fn is_canceled(&self) -> bool { |
232 | match *self { |
233 | Callback::Retry(Some(ref tx)) => tx.is_closed(), |
234 | Callback::NoRetry(Some(ref tx)) => tx.is_closed(), |
235 | _ => unreachable!(), |
236 | } |
237 | } |
238 | |
239 | pub(crate) fn poll_canceled(&mut self, cx: &mut task::Context<'_>) -> Poll<()> { |
240 | match *self { |
241 | Callback::Retry(Some(ref mut tx)) => tx.poll_closed(cx), |
242 | Callback::NoRetry(Some(ref mut tx)) => tx.poll_closed(cx), |
243 | _ => unreachable!(), |
244 | } |
245 | } |
246 | |
247 | pub(crate) fn send(mut self, val: Result<U, (crate::Error, Option<T>)>) { |
248 | match self { |
249 | Callback::Retry(ref mut tx) => { |
250 | let _ = tx.take().unwrap().send(val); |
251 | } |
252 | Callback::NoRetry(ref mut tx) => { |
253 | let _ = tx.take().unwrap().send(val.map_err(|e| e.0)); |
254 | } |
255 | } |
256 | } |
257 | |
258 | #[cfg (feature = "http2" )] |
259 | pub(crate) async fn send_when( |
260 | self, |
261 | mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin, |
262 | ) { |
263 | use futures_util::future; |
264 | use tracing::trace; |
265 | |
266 | let mut cb = Some(self); |
267 | |
268 | // "select" on this callback being canceled, and the future completing |
269 | future::poll_fn(move |cx| { |
270 | match Pin::new(&mut when).poll(cx) { |
271 | Poll::Ready(Ok(res)) => { |
272 | cb.take().expect("polled after complete" ).send(Ok(res)); |
273 | Poll::Ready(()) |
274 | } |
275 | Poll::Pending => { |
276 | // check if the callback is canceled |
277 | ready!(cb.as_mut().unwrap().poll_canceled(cx)); |
278 | trace!("send_when canceled" ); |
279 | Poll::Ready(()) |
280 | } |
281 | Poll::Ready(Err(err)) => { |
282 | cb.take().expect("polled after complete" ).send(Err(err)); |
283 | Poll::Ready(()) |
284 | } |
285 | } |
286 | }) |
287 | .await |
288 | } |
289 | } |
290 | |
#[cfg (test)]
mod tests {
    #[cfg (feature = "nightly" )]
    extern crate test;

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{channel, Callback, Receiver};

    /// Simple message payload used by the tests below.
    #[derive (Debug)]
    struct Custom(i32);

    /// Test-only adapter so a `Receiver` can be polled as a future; it
    /// forwards to `poll_recv`.
    impl<T, U> Future for Receiver<T, U> {
        type Output = Option<(T, Callback<T, U>)>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_recv(cx)
        }
    }

    /// Helper to check if the future is ready after polling once.
    struct PollOnce<'a, F>(&'a mut F);

    impl<F, T> Future for PollOnce<'_, F>
    where
        F: Future<Output = T> + Unpin,
    {
        type Output = Option<()>;

        // Always resolves on the first poll: `Some(())` if the inner future
        // was ready, `None` if it was still pending.
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match Pin::new(&mut self.0).poll(cx) {
                Poll::Ready(_) => Poll::Ready(Some(())),
                Poll::Pending => Poll::Ready(None),
            }
        }
    }

    /// Dropping the receiver with a message still queued must fulfil the
    /// promise with a canceled error that carries the message back.
    #[tokio::test]
    async fn drop_receiver_sends_cancel_errors() {
        let _ = pretty_env_logger::try_init();

        let (mut tx, mut rx) = channel::<Custom, ()>();

        // must poll once for try_send to succeed
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty" );

        let promise = tx.try_send(Custom(43)).unwrap();
        drop(rx);

        let fulfilled = promise.await;
        let err = fulfilled
            .expect("fulfilled" )
            .expect_err("promise should error" );
        match (err.0.kind(), err.1) {
            (&crate::error::Kind::Canceled, Some(_)) => (),
            e => panic!("expected Error::Cancel(_), found {:?}" , e),
        }
    }

    /// Exercises the one-message buffer allowance and the want-gating of
    /// subsequent sends.
    #[tokio::test]
    async fn sender_checks_for_want_on_send() {
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // one is allowed to buffer, second is rejected
        let _ = tx.try_send(Custom(1)).expect("1 buffered" );
        tx.try_send(Custom(2)).expect_err("2 not ready" );

        assert!(PollOnce(&mut rx).await.is_some(), "rx once" );

        // Even though 1 has been popped, only 1 could be buffered for the
        // lifetime of the channel.
        tx.try_send(Custom(2)).expect_err("2 still not ready" );

        assert!(PollOnce(&mut rx).await.is_none(), "rx empty" );

        let _ = tx.try_send(Custom(2)).expect("2 ready" );
    }

    /// The unbounded sender accepts several messages without the receiver
    /// being polled, and rejects sends once the receiver is dropped.
    #[cfg (feature = "http2" )]
    #[test ]
    fn unbounded_sender_doesnt_bound_on_want() {
        let (tx, rx) = channel::<Custom, ()>();
        let mut tx = tx.unbound();

        let _ = tx.try_send(Custom(1)).unwrap();
        let _ = tx.try_send(Custom(2)).unwrap();
        let _ = tx.try_send(Custom(3)).unwrap();

        drop(rx);

        let _ = tx.try_send(Custom(4)).unwrap_err();
    }

    /// Benchmark: send one request, then poll the receiver until it reports
    /// pending.
    #[cfg (feature = "nightly" )]
    #[bench ]
    fn giver_queue_throughput(b: &mut test::Bencher) {
        use crate::{Body, Request, Response};

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>();

        b.iter(move || {
            let _ = tx.send(Request::default()).unwrap();
            rt.block_on(async {
                loop {
                    let poll_once = PollOnce(&mut rx);
                    let opt = poll_once.await;
                    if opt.is_none() {
                        break;
                    }
                }
            });
        })
    }

    /// Benchmark: poll a receiver whose queue is empty.
    #[cfg (feature = "nightly" )]
    #[bench ]
    fn giver_queue_not_ready(b: &mut test::Bencher) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (_tx, mut rx) = channel::<i32, ()>();
        b.iter(move || {
            rt.block_on(async {
                let poll_once = PollOnce(&mut rx);
                assert!(poll_once.await.is_none());
            });
        })
    }

    /// Benchmark: repeatedly cancel the receiver's Taker.
    #[cfg (feature = "nightly" )]
    #[bench ]
    fn giver_queue_cancel(b: &mut test::Bencher) {
        let (_tx, mut rx) = channel::<i32, ()>();

        b.iter(move || {
            rx.taker.cancel();
        })
    }
}
437 | |