1 | #![warn (rust_2018_idioms)] |
2 | #![cfg (all(feature = "full" , not(target_os = "wasi" )))] // Wasi doesn't support bind |
3 | |
4 | use tokio::net::{TcpListener, TcpStream}; |
5 | use tokio::sync::{mpsc, oneshot}; |
6 | use tokio_test::assert_ok; |
7 | |
8 | use std::io; |
9 | use std::net::{IpAddr, SocketAddr}; |
10 | |
/// Generates one `#[tokio::test]` per `(name, bind_target)` pair.
///
/// Each generated test:
/// - binds a `TcpListener` to `bind_target`;
/// - accepts a single connection on a spawned task;
/// - connects to the listener's local address from the test body;
/// - checks that the client and server sockets agree on each other's addresses.
macro_rules! test_accept {
    ($(($ident:ident, $target:expr),)*) => {
        $(
            #[tokio::test]
            async fn $ident() {
                let listener = assert_ok!(TcpListener::bind($target).await);
                let addr = listener.local_addr().unwrap();

                let (tx, rx) = oneshot::channel();

                // Accept on a spawned task and hand the server-side socket back
                // to the test body over the oneshot channel.
                tokio::spawn(async move {
                    let (socket, _) = assert_ok!(listener.accept().await);
                    assert_ok!(tx.send(socket));
                });

                let cli = assert_ok!(TcpStream::connect(&addr).await);
                let srv = assert_ok!(rx.await);

                // The pairing must hold in both directions: each end's local
                // address is the other end's peer address.
                assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap());
                assert_eq!(cli.peer_addr().unwrap(), srv.local_addr().unwrap());
            }
        )*
    }
}
34 | |
35 | test_accept! { |
36 | (ip_str, "127.0.0.1:0" ), |
37 | (host_str, "localhost:0" ), |
38 | (socket_addr, "127.0.0.1:0" .parse::<SocketAddr>().unwrap()), |
39 | (str_port_tuple, ("127.0.0.1" , 0)), |
40 | (ip_port_tuple, ("127.0.0.1" .parse::<IpAddr>().unwrap(), 0)), |
41 | } |
42 | |
43 | use std::pin::Pin; |
44 | use std::sync::{ |
45 | atomic::{AtomicUsize, Ordering::SeqCst}, |
46 | Arc, |
47 | }; |
48 | use std::task::{Context, Poll}; |
49 | use tokio_stream::{Stream, StreamExt}; |
50 | |
51 | struct TrackPolls<'a> { |
52 | npolls: Arc<AtomicUsize>, |
53 | listener: &'a mut TcpListener, |
54 | } |
55 | |
56 | impl<'a> Stream for TrackPolls<'a> { |
57 | type Item = io::Result<(TcpStream, SocketAddr)>; |
58 | |
59 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
60 | self.npolls.fetch_add(1, SeqCst); |
61 | self.listener.poll_accept(cx).map(Some) |
62 | } |
63 | } |
64 | |
65 | #[tokio::test ] |
66 | async fn no_extra_poll() { |
67 | let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0" ).await); |
68 | let addr = listener.local_addr().unwrap(); |
69 | |
70 | let (tx, rx) = oneshot::channel(); |
71 | let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel(); |
72 | |
73 | tokio::spawn(async move { |
74 | let mut incoming = TrackPolls { |
75 | npolls: Arc::new(AtomicUsize::new(0)), |
76 | listener: &mut listener, |
77 | }; |
78 | assert_ok!(tx.send(Arc::clone(&incoming.npolls))); |
79 | while incoming.next().await.is_some() { |
80 | accepted_tx.send(()).unwrap(); |
81 | } |
82 | }); |
83 | |
84 | let npolls = assert_ok!(rx.await); |
85 | tokio::task::yield_now().await; |
86 | |
87 | // should have been polled exactly once: the initial poll |
88 | assert_eq!(npolls.load(SeqCst), 1); |
89 | |
90 | let _ = assert_ok!(TcpStream::connect(&addr).await); |
91 | accepted_rx.recv().await.unwrap(); |
92 | |
93 | // should have been polled twice more: once to yield Some(), then once to yield Pending |
94 | assert_eq!(npolls.load(SeqCst), 1 + 2); |
95 | } |
96 | |
97 | #[tokio::test ] |
98 | async fn accept_many() { |
99 | use futures::future::poll_fn; |
100 | use std::future::Future; |
101 | use std::sync::atomic::AtomicBool; |
102 | |
103 | const N: usize = 50; |
104 | |
105 | let listener = assert_ok!(TcpListener::bind("127.0.0.1:0" ).await); |
106 | let listener = Arc::new(listener); |
107 | let addr = listener.local_addr().unwrap(); |
108 | let connected = Arc::new(AtomicBool::new(false)); |
109 | |
110 | let (pending_tx, mut pending_rx) = mpsc::unbounded_channel(); |
111 | let (notified_tx, mut notified_rx) = mpsc::unbounded_channel(); |
112 | |
113 | for _ in 0..N { |
114 | let listener = listener.clone(); |
115 | let connected = connected.clone(); |
116 | let pending_tx = pending_tx.clone(); |
117 | let notified_tx = notified_tx.clone(); |
118 | |
119 | tokio::spawn(async move { |
120 | let accept = listener.accept(); |
121 | tokio::pin!(accept); |
122 | |
123 | let mut polled = false; |
124 | |
125 | poll_fn(|cx| { |
126 | if !polled { |
127 | polled = true; |
128 | assert!(Pin::new(&mut accept).poll(cx).is_pending()); |
129 | pending_tx.send(()).unwrap(); |
130 | Poll::Pending |
131 | } else if connected.load(SeqCst) { |
132 | notified_tx.send(()).unwrap(); |
133 | Poll::Ready(()) |
134 | } else { |
135 | Poll::Pending |
136 | } |
137 | }) |
138 | .await; |
139 | |
140 | pending_tx.send(()).unwrap(); |
141 | }); |
142 | } |
143 | |
144 | // Wait for all tasks to have polled at least once |
145 | for _ in 0..N { |
146 | pending_rx.recv().await.unwrap(); |
147 | } |
148 | |
149 | // Establish a TCP connection |
150 | connected.store(true, SeqCst); |
151 | let _sock = TcpStream::connect(addr).await.unwrap(); |
152 | |
153 | // Wait for all notifications |
154 | for _ in 0..N { |
155 | notified_rx.recv().await.unwrap(); |
156 | } |
157 | } |
158 | |