1#![cfg(feature = "full")]
2#![warn(rust_2018_idioms)]
3#![cfg(unix)]
4
5use std::io;
6use std::task::Poll;
7
8use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
9use tokio::net::{UnixListener, UnixStream};
10use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task};
11
12use futures::future::{poll_fn, try_join};
13
14#[tokio::test]
15async fn accept_read_write() -> std::io::Result<()> {
16 let dir = tempfile::Builder::new()
17 .prefix("tokio-uds-tests")
18 .tempdir()
19 .unwrap();
20 let sock_path = dir.path().join("connect.sock");
21
22 let listener = UnixListener::bind(&sock_path)?;
23
24 let accept = listener.accept();
25 let connect = UnixStream::connect(&sock_path);
26 let ((mut server, _), mut client) = try_join(accept, connect).await?;
27
28 // Write to the client.
29 client.write_all(b"hello").await?;
30 drop(client);
31
32 // Read from the server.
33 let mut buf = vec![];
34 server.read_to_end(&mut buf).await?;
35 assert_eq!(&buf, b"hello");
36 let len = server.read(&mut buf).await?;
37 assert_eq!(len, 0);
38 Ok(())
39}
40
41#[tokio::test]
42async fn shutdown() -> std::io::Result<()> {
43 let dir = tempfile::Builder::new()
44 .prefix("tokio-uds-tests")
45 .tempdir()
46 .unwrap();
47 let sock_path = dir.path().join("connect.sock");
48
49 let listener = UnixListener::bind(&sock_path)?;
50
51 let accept = listener.accept();
52 let connect = UnixStream::connect(&sock_path);
53 let ((mut server, _), mut client) = try_join(accept, connect).await?;
54
55 // Shut down the client
56 AsyncWriteExt::shutdown(&mut client).await?;
57 // Read from the server should return 0 to indicate the channel has been closed.
58 let mut buf = [0u8; 1];
59 let n = server.read(&mut buf).await?;
60 assert_eq!(n, 0);
61 Ok(())
62}
63
64#[tokio::test]
65async fn try_read_write() -> std::io::Result<()> {
66 let msg = b"hello world";
67
68 let dir = tempfile::tempdir()?;
69 let bind_path = dir.path().join("bind.sock");
70
71 // Create listener
72 let listener = UnixListener::bind(&bind_path)?;
73
74 // Create socket pair
75 let client = UnixStream::connect(&bind_path).await?;
76
77 let (server, _) = listener.accept().await?;
78 let mut written = msg.to_vec();
79
80 // Track the server receiving data
81 let mut readable = task::spawn(server.readable());
82 assert_pending!(readable.poll());
83
84 // Write data.
85 client.writable().await?;
86 assert_eq!(msg.len(), client.try_write(msg)?);
87
88 // The task should be notified
89 while !readable.is_woken() {
90 tokio::task::yield_now().await;
91 }
92
93 // Fill the write buffer using non-vectored I/O
94 loop {
95 // Still ready
96 let mut writable = task::spawn(client.writable());
97 assert_ready_ok!(writable.poll());
98
99 match client.try_write(msg) {
100 Ok(n) => written.extend(&msg[..n]),
101 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
102 break;
103 }
104 Err(e) => panic!("error = {:?}", e),
105 }
106 }
107
108 {
109 // Write buffer full
110 let mut writable = task::spawn(client.writable());
111 assert_pending!(writable.poll());
112
113 // Drain the socket from the server end using non-vectored I/O
114 let mut read = vec![0; written.len()];
115 let mut i = 0;
116
117 while i < read.len() {
118 server.readable().await?;
119
120 match server.try_read(&mut read[i..]) {
121 Ok(n) => i += n,
122 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
123 Err(e) => panic!("error = {:?}", e),
124 }
125 }
126
127 assert_eq!(read, written);
128 }
129
130 written.clear();
131 client.writable().await.unwrap();
132
133 // Fill the write buffer using vectored I/O
134 let msg_bufs: Vec<_> = msg.chunks(3).map(io::IoSlice::new).collect();
135 loop {
136 // Still ready
137 let mut writable = task::spawn(client.writable());
138 assert_ready_ok!(writable.poll());
139
140 match client.try_write_vectored(&msg_bufs) {
141 Ok(n) => written.extend(&msg[..n]),
142 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
143 break;
144 }
145 Err(e) => panic!("error = {:?}", e),
146 }
147 }
148
149 {
150 // Write buffer full
151 let mut writable = task::spawn(client.writable());
152 assert_pending!(writable.poll());
153
154 // Drain the socket from the server end using vectored I/O
155 let mut read = vec![0; written.len()];
156 let mut i = 0;
157
158 while i < read.len() {
159 server.readable().await?;
160
161 let mut bufs: Vec<_> = read[i..]
162 .chunks_mut(0x10000)
163 .map(io::IoSliceMut::new)
164 .collect();
165 match server.try_read_vectored(&mut bufs) {
166 Ok(n) => i += n,
167 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
168 Err(e) => panic!("error = {:?}", e),
169 }
170 }
171
172 assert_eq!(read, written);
173 }
174
175 // Now, we listen for shutdown
176 drop(client);
177
178 loop {
179 let ready = server.ready(Interest::READABLE).await?;
180
181 if ready.is_read_closed() {
182 break;
183 } else {
184 tokio::task::yield_now().await;
185 }
186 }
187
188 Ok(())
189}
190
191async fn create_pair() -> (UnixStream, UnixStream) {
192 let dir = assert_ok!(tempfile::tempdir());
193 let bind_path = dir.path().join("bind.sock");
194
195 let listener = assert_ok!(UnixListener::bind(&bind_path));
196
197 let accept = listener.accept();
198 let connect = UnixStream::connect(&bind_path);
199 let ((server, _), client) = assert_ok!(try_join(accept, connect).await);
200
201 (client, server)
202}
203
/// Awaits `$stream.poll_read_ready` until it resolves and asserts the result
/// is `Ok`.
macro_rules! assert_readable_by_polling {
    ($stream:expr) => {{
        let readiness = poll_fn(|ctx| $stream.poll_read_ready(ctx)).await;
        assert_ok!(readiness);
    }};
}
209
/// Polls `$stream.poll_read_ready` exactly once with the current task's
/// context and asserts it returned `Poll::Pending`.
macro_rules! assert_not_readable_by_polling {
    ($stream:expr) => {
        poll_fn(|ctx| {
            let first_poll = $stream.poll_read_ready(ctx);
            assert_pending!(first_poll);
            // Complete immediately so the readiness is polled only once.
            Poll::Ready(())
        })
        .await;
    };
}
219
/// Awaits `$stream.poll_write_ready` until it resolves and asserts the result
/// is `Ok`.
macro_rules! assert_writable_by_polling {
    ($stream:expr) => {{
        let readiness = poll_fn(|ctx| $stream.poll_write_ready(ctx)).await;
        assert_ok!(readiness);
    }};
}
225
/// Polls `$stream.poll_write_ready` exactly once with the current task's
/// context and asserts it returned `Poll::Pending`.
macro_rules! assert_not_writable_by_polling {
    ($stream:expr) => {
        poll_fn(|ctx| {
            let first_poll = $stream.poll_write_ready(ctx);
            assert_pending!(first_poll);
            // Complete immediately so the readiness is polled only once.
            Poll::Ready(())
        })
        .await;
    };
}
235
236#[tokio::test]
237async fn poll_read_ready() {
238 let (mut client, mut server) = create_pair().await;
239
240 // Initial state - not readable.
241 assert_not_readable_by_polling!(server);
242
243 // There is data in the buffer - readable.
244 assert_ok!(client.write_all(b"ping").await);
245 assert_readable_by_polling!(server);
246
247 // Readable until calls to `poll_read` return `Poll::Pending`.
248 let mut buf = [0u8; 4];
249 assert_ok!(server.read_exact(&mut buf).await);
250 assert_readable_by_polling!(server);
251 read_until_pending(&mut server);
252 assert_not_readable_by_polling!(server);
253
254 // Detect the client disconnect.
255 drop(client);
256 assert_readable_by_polling!(server);
257}
258
259#[tokio::test]
260async fn poll_write_ready() {
261 let (mut client, server) = create_pair().await;
262
263 // Initial state - writable.
264 assert_writable_by_polling!(client);
265
266 // No space to write - not writable.
267 write_until_pending(&mut client);
268 assert_not_writable_by_polling!(client);
269
270 // Detect the server disconnect.
271 drop(server);
272 assert_writable_by_polling!(client);
273}
274
275fn read_until_pending(stream: &mut UnixStream) {
276 let mut buf = vec![0u8; 1024 * 1024];
277 loop {
278 match stream.try_read(&mut buf) {
279 Ok(_) => (),
280 Err(err) => {
281 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
282 break;
283 }
284 }
285 }
286}
287
288fn write_until_pending(stream: &mut UnixStream) {
289 let buf = vec![0u8; 1024 * 1024];
290 loop {
291 match stream.try_write(&buf) {
292 Ok(_) => (),
293 Err(err) => {
294 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
295 break;
296 }
297 }
298 }
299}
300
301#[tokio::test]
302async fn try_read_buf() -> std::io::Result<()> {
303 let msg = b"hello world";
304
305 let dir = tempfile::tempdir()?;
306 let bind_path = dir.path().join("bind.sock");
307
308 // Create listener
309 let listener = UnixListener::bind(&bind_path)?;
310
311 // Create socket pair
312 let client = UnixStream::connect(&bind_path).await?;
313
314 let (server, _) = listener.accept().await?;
315 let mut written = msg.to_vec();
316
317 // Track the server receiving data
318 let mut readable = task::spawn(server.readable());
319 assert_pending!(readable.poll());
320
321 // Write data.
322 client.writable().await?;
323 assert_eq!(msg.len(), client.try_write(msg)?);
324
325 // The task should be notified
326 while !readable.is_woken() {
327 tokio::task::yield_now().await;
328 }
329
330 // Fill the write buffer
331 loop {
332 // Still ready
333 let mut writable = task::spawn(client.writable());
334 assert_ready_ok!(writable.poll());
335
336 match client.try_write(msg) {
337 Ok(n) => written.extend(&msg[..n]),
338 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
339 break;
340 }
341 Err(e) => panic!("error = {:?}", e),
342 }
343 }
344
345 {
346 // Write buffer full
347 let mut writable = task::spawn(client.writable());
348 assert_pending!(writable.poll());
349
350 // Drain the socket from the server end
351 let mut read = Vec::with_capacity(written.len());
352 let mut i = 0;
353
354 while i < read.capacity() {
355 server.readable().await?;
356
357 match server.try_read_buf(&mut read) {
358 Ok(n) => i += n,
359 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
360 Err(e) => panic!("error = {:?}", e),
361 }
362 }
363
364 assert_eq!(read, written);
365 }
366
367 // Now, we listen for shutdown
368 drop(client);
369
370 loop {
371 let ready = server.ready(Interest::READABLE).await?;
372
373 if ready.is_read_closed() {
374 break;
375 } else {
376 tokio::task::yield_now().await;
377 }
378 }
379
380 Ok(())
381}
382
383// https://github.com/tokio-rs/tokio/issues/3879
384#[tokio::test]
385#[cfg(not(target_os = "macos"))]
386async fn epollhup() -> io::Result<()> {
387 let dir = tempfile::Builder::new()
388 .prefix("tokio-uds-tests")
389 .tempdir()
390 .unwrap();
391 let sock_path = dir.path().join("connect.sock");
392
393 let listener = UnixListener::bind(&sock_path)?;
394 let connect = UnixStream::connect(&sock_path);
395 tokio::pin!(connect);
396
397 // Poll `connect` once.
398 poll_fn(|cx| {
399 use std::future::Future;
400
401 assert_pending!(connect.as_mut().poll(cx));
402 Poll::Ready(())
403 })
404 .await;
405
406 drop(listener);
407
408 let err = connect.await.unwrap_err();
409 assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
410 Ok(())
411}
412