| 1 | #![warn (rust_2018_idioms)] |
| 2 | #![cfg (all(feature = "full" , not(target_os = "wasi" )))] // Wasi doesn't support bind |
| 3 | |
| 4 | use tokio::net::{TcpListener, TcpStream}; |
| 5 | use tokio::sync::{mpsc, oneshot}; |
| 6 | use tokio_test::assert_ok; |
| 7 | |
| 8 | use std::io; |
| 9 | use std::net::{IpAddr, SocketAddr}; |
| 10 | |
/// Generates one `#[tokio::test]` for every `(name, bind_target)` pair.
///
/// Each generated test:
/// 1. binds a `TcpListener` to `bind_target`,
/// 2. accepts a single connection on a spawned task and hands the accepted
///    socket back over a oneshot channel,
/// 3. connects a `TcpStream` to the listener's local address, and
/// 4. asserts that the client's local address equals the peer address seen by
///    the server-side socket.
///
/// Every pair must be followed by a comma (including the last one).
macro_rules! test_accept {
    ($(($ident:ident, $target:expr),)*) => {
        $(
            #[tokio::test]
            async fn $ident() {
                let server = assert_ok!(TcpListener::bind($target).await);
                let server_addr = server.local_addr().unwrap();

                // The accepting side lives on its own task and ships the
                // accepted socket back to the test body when it arrives.
                let (accepted_tx, accepted_rx) = oneshot::channel();
                tokio::spawn(async move {
                    let (accepted, _) = assert_ok!(server.accept().await);
                    assert_ok!(accepted_tx.send(accepted));
                });

                let client = assert_ok!(TcpStream::connect(&server_addr).await);
                let accepted = assert_ok!(accepted_rx.await);

                // Both ends must agree on the client's address.
                let client_local = client.local_addr().unwrap();
                let server_peer = accepted.peer_addr().unwrap();
                assert_eq!(client_local, server_peer);
            }
        )*
    }
}
| 34 | |
| 35 | test_accept! { |
| 36 | (ip_str, "127.0.0.1:0" ), |
| 37 | (host_str, "localhost:0" ), |
| 38 | (socket_addr, "127.0.0.1:0" .parse::<SocketAddr>().unwrap()), |
| 39 | (str_port_tuple, ("127.0.0.1" , 0)), |
| 40 | (ip_port_tuple, ("127.0.0.1" .parse::<IpAddr>().unwrap(), 0)), |
| 41 | } |
| 42 | |
| 43 | use std::pin::Pin; |
| 44 | use std::sync::{ |
| 45 | atomic::{AtomicUsize, Ordering::SeqCst}, |
| 46 | Arc, |
| 47 | }; |
| 48 | use std::task::{Context, Poll}; |
| 49 | use tokio_stream::{Stream, StreamExt}; |
| 50 | |
| 51 | struct TrackPolls<'a> { |
| 52 | npolls: Arc<AtomicUsize>, |
| 53 | listener: &'a mut TcpListener, |
| 54 | } |
| 55 | |
| 56 | impl<'a> Stream for TrackPolls<'a> { |
| 57 | type Item = io::Result<(TcpStream, SocketAddr)>; |
| 58 | |
| 59 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 60 | self.npolls.fetch_add(1, SeqCst); |
| 61 | self.listener.poll_accept(cx).map(Some) |
| 62 | } |
| 63 | } |
| 64 | |
| 65 | #[tokio::test ] |
| 66 | async fn no_extra_poll() { |
| 67 | let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0" ).await); |
| 68 | let addr = listener.local_addr().unwrap(); |
| 69 | |
| 70 | let (tx, rx) = oneshot::channel(); |
| 71 | let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel(); |
| 72 | |
| 73 | tokio::spawn(async move { |
| 74 | let mut incoming = TrackPolls { |
| 75 | npolls: Arc::new(AtomicUsize::new(0)), |
| 76 | listener: &mut listener, |
| 77 | }; |
| 78 | assert_ok!(tx.send(Arc::clone(&incoming.npolls))); |
| 79 | while incoming.next().await.is_some() { |
| 80 | accepted_tx.send(()).unwrap(); |
| 81 | } |
| 82 | }); |
| 83 | |
| 84 | let npolls = assert_ok!(rx.await); |
| 85 | tokio::task::yield_now().await; |
| 86 | |
| 87 | // should have been polled exactly once: the initial poll |
| 88 | assert_eq!(npolls.load(SeqCst), 1); |
| 89 | |
| 90 | let _ = assert_ok!(TcpStream::connect(&addr).await); |
| 91 | accepted_rx.recv().await.unwrap(); |
| 92 | |
| 93 | // should have been polled twice more: once to yield Some(), then once to yield Pending |
| 94 | assert_eq!(npolls.load(SeqCst), 1 + 2); |
| 95 | } |
| 96 | |
| 97 | #[tokio::test ] |
| 98 | async fn accept_many() { |
| 99 | use futures::future::poll_fn; |
| 100 | use std::future::Future; |
| 101 | use std::sync::atomic::AtomicBool; |
| 102 | |
| 103 | const N: usize = 50; |
| 104 | |
| 105 | let listener = assert_ok!(TcpListener::bind("127.0.0.1:0" ).await); |
| 106 | let listener = Arc::new(listener); |
| 107 | let addr = listener.local_addr().unwrap(); |
| 108 | let connected = Arc::new(AtomicBool::new(false)); |
| 109 | |
| 110 | let (pending_tx, mut pending_rx) = mpsc::unbounded_channel(); |
| 111 | let (notified_tx, mut notified_rx) = mpsc::unbounded_channel(); |
| 112 | |
| 113 | for _ in 0..N { |
| 114 | let listener = listener.clone(); |
| 115 | let connected = connected.clone(); |
| 116 | let pending_tx = pending_tx.clone(); |
| 117 | let notified_tx = notified_tx.clone(); |
| 118 | |
| 119 | tokio::spawn(async move { |
| 120 | let accept = listener.accept(); |
| 121 | tokio::pin!(accept); |
| 122 | |
| 123 | let mut polled = false; |
| 124 | |
| 125 | poll_fn(|cx| { |
| 126 | if !polled { |
| 127 | polled = true; |
| 128 | assert!(Pin::new(&mut accept).poll(cx).is_pending()); |
| 129 | pending_tx.send(()).unwrap(); |
| 130 | Poll::Pending |
| 131 | } else if connected.load(SeqCst) { |
| 132 | notified_tx.send(()).unwrap(); |
| 133 | Poll::Ready(()) |
| 134 | } else { |
| 135 | Poll::Pending |
| 136 | } |
| 137 | }) |
| 138 | .await; |
| 139 | |
| 140 | pending_tx.send(()).unwrap(); |
| 141 | }); |
| 142 | } |
| 143 | |
| 144 | // Wait for all tasks to have polled at least once |
| 145 | for _ in 0..N { |
| 146 | pending_rx.recv().await.unwrap(); |
| 147 | } |
| 148 | |
| 149 | // Establish a TCP connection |
| 150 | connected.store(true, SeqCst); |
| 151 | let _sock = TcpStream::connect(addr).await.unwrap(); |
| 152 | |
| 153 | // Wait for all notifications |
| 154 | for _ in 0..N { |
| 155 | notified_rx.recv().await.unwrap(); |
| 156 | } |
| 157 | } |
| 158 | |