1 | #![cfg (feature = "full" )] |
2 | #![warn (rust_2018_idioms)] |
3 | #![cfg (unix)] |
4 | |
5 | use std::io; |
6 | use std::task::Poll; |
7 | |
8 | use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest}; |
9 | use tokio::net::{UnixListener, UnixStream}; |
10 | use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; |
11 | |
12 | use futures::future::{poll_fn, try_join}; |
13 | |
14 | #[tokio::test ] |
15 | async fn accept_read_write() -> std::io::Result<()> { |
16 | let dir = tempfile::Builder::new() |
17 | .prefix("tokio-uds-tests" ) |
18 | .tempdir() |
19 | .unwrap(); |
20 | let sock_path = dir.path().join("connect.sock" ); |
21 | |
22 | let listener = UnixListener::bind(&sock_path)?; |
23 | |
24 | let accept = listener.accept(); |
25 | let connect = UnixStream::connect(&sock_path); |
26 | let ((mut server, _), mut client) = try_join(accept, connect).await?; |
27 | |
28 | // Write to the client. |
29 | client.write_all(b"hello" ).await?; |
30 | drop(client); |
31 | |
32 | // Read from the server. |
33 | let mut buf = vec![]; |
34 | server.read_to_end(&mut buf).await?; |
35 | assert_eq!(&buf, b"hello" ); |
36 | let len = server.read(&mut buf).await?; |
37 | assert_eq!(len, 0); |
38 | Ok(()) |
39 | } |
40 | |
41 | #[tokio::test ] |
42 | async fn shutdown() -> std::io::Result<()> { |
43 | let dir = tempfile::Builder::new() |
44 | .prefix("tokio-uds-tests" ) |
45 | .tempdir() |
46 | .unwrap(); |
47 | let sock_path = dir.path().join("connect.sock" ); |
48 | |
49 | let listener = UnixListener::bind(&sock_path)?; |
50 | |
51 | let accept = listener.accept(); |
52 | let connect = UnixStream::connect(&sock_path); |
53 | let ((mut server, _), mut client) = try_join(accept, connect).await?; |
54 | |
55 | // Shut down the client |
56 | AsyncWriteExt::shutdown(&mut client).await?; |
57 | // Read from the server should return 0 to indicate the channel has been closed. |
58 | let mut buf = [0u8; 1]; |
59 | let n = server.read(&mut buf).await?; |
60 | assert_eq!(n, 0); |
61 | Ok(()) |
62 | } |
63 | |
64 | #[tokio::test ] |
65 | async fn try_read_write() -> std::io::Result<()> { |
66 | let msg = b"hello world" ; |
67 | |
68 | let dir = tempfile::tempdir()?; |
69 | let bind_path = dir.path().join("bind.sock" ); |
70 | |
71 | // Create listener |
72 | let listener = UnixListener::bind(&bind_path)?; |
73 | |
74 | // Create socket pair |
75 | let client = UnixStream::connect(&bind_path).await?; |
76 | |
77 | let (server, _) = listener.accept().await?; |
78 | let mut written = msg.to_vec(); |
79 | |
80 | // Track the server receiving data |
81 | let mut readable = task::spawn(server.readable()); |
82 | assert_pending!(readable.poll()); |
83 | |
84 | // Write data. |
85 | client.writable().await?; |
86 | assert_eq!(msg.len(), client.try_write(msg)?); |
87 | |
88 | // The task should be notified |
89 | while !readable.is_woken() { |
90 | tokio::task::yield_now().await; |
91 | } |
92 | |
93 | // Fill the write buffer using non-vectored I/O |
94 | loop { |
95 | // Still ready |
96 | let mut writable = task::spawn(client.writable()); |
97 | assert_ready_ok!(writable.poll()); |
98 | |
99 | match client.try_write(msg) { |
100 | Ok(n) => written.extend(&msg[..n]), |
101 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { |
102 | break; |
103 | } |
104 | Err(e) => panic!("error = {:?}" , e), |
105 | } |
106 | } |
107 | |
108 | { |
109 | // Write buffer full |
110 | let mut writable = task::spawn(client.writable()); |
111 | assert_pending!(writable.poll()); |
112 | |
113 | // Drain the socket from the server end using non-vectored I/O |
114 | let mut read = vec![0; written.len()]; |
115 | let mut i = 0; |
116 | |
117 | while i < read.len() { |
118 | server.readable().await?; |
119 | |
120 | match server.try_read(&mut read[i..]) { |
121 | Ok(n) => i += n, |
122 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, |
123 | Err(e) => panic!("error = {:?}" , e), |
124 | } |
125 | } |
126 | |
127 | assert_eq!(read, written); |
128 | } |
129 | |
130 | written.clear(); |
131 | client.writable().await.unwrap(); |
132 | |
133 | // Fill the write buffer using vectored I/O |
134 | let msg_bufs: Vec<_> = msg.chunks(3).map(io::IoSlice::new).collect(); |
135 | loop { |
136 | // Still ready |
137 | let mut writable = task::spawn(client.writable()); |
138 | assert_ready_ok!(writable.poll()); |
139 | |
140 | match client.try_write_vectored(&msg_bufs) { |
141 | Ok(n) => written.extend(&msg[..n]), |
142 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { |
143 | break; |
144 | } |
145 | Err(e) => panic!("error = {:?}" , e), |
146 | } |
147 | } |
148 | |
149 | { |
150 | // Write buffer full |
151 | let mut writable = task::spawn(client.writable()); |
152 | assert_pending!(writable.poll()); |
153 | |
154 | // Drain the socket from the server end using vectored I/O |
155 | let mut read = vec![0; written.len()]; |
156 | let mut i = 0; |
157 | |
158 | while i < read.len() { |
159 | server.readable().await?; |
160 | |
161 | let mut bufs: Vec<_> = read[i..] |
162 | .chunks_mut(0x10000) |
163 | .map(io::IoSliceMut::new) |
164 | .collect(); |
165 | match server.try_read_vectored(&mut bufs) { |
166 | Ok(n) => i += n, |
167 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, |
168 | Err(e) => panic!("error = {:?}" , e), |
169 | } |
170 | } |
171 | |
172 | assert_eq!(read, written); |
173 | } |
174 | |
175 | // Now, we listen for shutdown |
176 | drop(client); |
177 | |
178 | loop { |
179 | let ready = server.ready(Interest::READABLE).await?; |
180 | |
181 | if ready.is_read_closed() { |
182 | break; |
183 | } else { |
184 | tokio::task::yield_now().await; |
185 | } |
186 | } |
187 | |
188 | Ok(()) |
189 | } |
190 | |
191 | async fn create_pair() -> (UnixStream, UnixStream) { |
192 | let dir = assert_ok!(tempfile::tempdir()); |
193 | let bind_path = dir.path().join("bind.sock" ); |
194 | |
195 | let listener = assert_ok!(UnixListener::bind(&bind_path)); |
196 | |
197 | let accept = listener.accept(); |
198 | let connect = UnixStream::connect(&bind_path); |
199 | let ((server, _), client) = assert_ok!(try_join(accept, connect).await); |
200 | |
201 | (client, server) |
202 | } |
203 | |
/// Awaits `poll_read_ready` on the given stream and asserts that it resolves
/// to `Ok`.
macro_rules! assert_readable_by_polling {
    ($sock:expr) => {
        assert_ok!(poll_fn(|ctx| $sock.poll_read_ready(ctx)).await);
    };
}
209 | |
/// Polls `poll_read_ready` on the given stream exactly once and asserts that
/// the result is `Poll::Pending`.
macro_rules! assert_not_readable_by_polling {
    ($sock:expr) => {
        poll_fn(|ctx| {
            assert_pending!($sock.poll_read_ready(ctx));
            // Resolve immediately so that the readiness poll happens only once.
            Poll::Ready(())
        })
        .await;
    };
}
219 | |
/// Awaits `poll_write_ready` on the given stream and asserts that it resolves
/// to `Ok`.
macro_rules! assert_writable_by_polling {
    ($sock:expr) => {
        assert_ok!(poll_fn(|ctx| $sock.poll_write_ready(ctx)).await);
    };
}
225 | |
/// Polls `poll_write_ready` on the given stream exactly once and asserts that
/// the result is `Poll::Pending`.
macro_rules! assert_not_writable_by_polling {
    ($sock:expr) => {
        poll_fn(|ctx| {
            assert_pending!($sock.poll_write_ready(ctx));
            // Resolve immediately so that the readiness poll happens only once.
            Poll::Ready(())
        })
        .await;
    };
}
235 | |
236 | #[tokio::test ] |
237 | async fn poll_read_ready() { |
238 | let (mut client, mut server) = create_pair().await; |
239 | |
240 | // Initial state - not readable. |
241 | assert_not_readable_by_polling!(server); |
242 | |
243 | // There is data in the buffer - readable. |
244 | assert_ok!(client.write_all(b"ping" ).await); |
245 | assert_readable_by_polling!(server); |
246 | |
247 | // Readable until calls to `poll_read` return `Poll::Pending`. |
248 | let mut buf = [0u8; 4]; |
249 | assert_ok!(server.read_exact(&mut buf).await); |
250 | assert_readable_by_polling!(server); |
251 | read_until_pending(&mut server); |
252 | assert_not_readable_by_polling!(server); |
253 | |
254 | // Detect the client disconnect. |
255 | drop(client); |
256 | assert_readable_by_polling!(server); |
257 | } |
258 | |
259 | #[tokio::test ] |
260 | async fn poll_write_ready() { |
261 | let (mut client, server) = create_pair().await; |
262 | |
263 | // Initial state - writable. |
264 | assert_writable_by_polling!(client); |
265 | |
266 | // No space to write - not writable. |
267 | write_until_pending(&mut client); |
268 | assert_not_writable_by_polling!(client); |
269 | |
270 | // Detect the server disconnect. |
271 | drop(server); |
272 | assert_writable_by_polling!(client); |
273 | } |
274 | |
275 | fn read_until_pending(stream: &mut UnixStream) { |
276 | let mut buf = vec![0u8; 1024 * 1024]; |
277 | loop { |
278 | match stream.try_read(&mut buf) { |
279 | Ok(_) => (), |
280 | Err(err) => { |
281 | assert_eq!(err.kind(), io::ErrorKind::WouldBlock); |
282 | break; |
283 | } |
284 | } |
285 | } |
286 | } |
287 | |
288 | fn write_until_pending(stream: &mut UnixStream) { |
289 | let buf = vec![0u8; 1024 * 1024]; |
290 | loop { |
291 | match stream.try_write(&buf) { |
292 | Ok(_) => (), |
293 | Err(err) => { |
294 | assert_eq!(err.kind(), io::ErrorKind::WouldBlock); |
295 | break; |
296 | } |
297 | } |
298 | } |
299 | } |
300 | |
301 | #[tokio::test ] |
302 | async fn try_read_buf() -> std::io::Result<()> { |
303 | let msg = b"hello world" ; |
304 | |
305 | let dir = tempfile::tempdir()?; |
306 | let bind_path = dir.path().join("bind.sock" ); |
307 | |
308 | // Create listener |
309 | let listener = UnixListener::bind(&bind_path)?; |
310 | |
311 | // Create socket pair |
312 | let client = UnixStream::connect(&bind_path).await?; |
313 | |
314 | let (server, _) = listener.accept().await?; |
315 | let mut written = msg.to_vec(); |
316 | |
317 | // Track the server receiving data |
318 | let mut readable = task::spawn(server.readable()); |
319 | assert_pending!(readable.poll()); |
320 | |
321 | // Write data. |
322 | client.writable().await?; |
323 | assert_eq!(msg.len(), client.try_write(msg)?); |
324 | |
325 | // The task should be notified |
326 | while !readable.is_woken() { |
327 | tokio::task::yield_now().await; |
328 | } |
329 | |
330 | // Fill the write buffer |
331 | loop { |
332 | // Still ready |
333 | let mut writable = task::spawn(client.writable()); |
334 | assert_ready_ok!(writable.poll()); |
335 | |
336 | match client.try_write(msg) { |
337 | Ok(n) => written.extend(&msg[..n]), |
338 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { |
339 | break; |
340 | } |
341 | Err(e) => panic!("error = {:?}" , e), |
342 | } |
343 | } |
344 | |
345 | { |
346 | // Write buffer full |
347 | let mut writable = task::spawn(client.writable()); |
348 | assert_pending!(writable.poll()); |
349 | |
350 | // Drain the socket from the server end |
351 | let mut read = Vec::with_capacity(written.len()); |
352 | let mut i = 0; |
353 | |
354 | while i < read.capacity() { |
355 | server.readable().await?; |
356 | |
357 | match server.try_read_buf(&mut read) { |
358 | Ok(n) => i += n, |
359 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, |
360 | Err(e) => panic!("error = {:?}" , e), |
361 | } |
362 | } |
363 | |
364 | assert_eq!(read, written); |
365 | } |
366 | |
367 | // Now, we listen for shutdown |
368 | drop(client); |
369 | |
370 | loop { |
371 | let ready = server.ready(Interest::READABLE).await?; |
372 | |
373 | if ready.is_read_closed() { |
374 | break; |
375 | } else { |
376 | tokio::task::yield_now().await; |
377 | } |
378 | } |
379 | |
380 | Ok(()) |
381 | } |
382 | |
383 | // https://github.com/tokio-rs/tokio/issues/3879 |
384 | #[tokio::test ] |
385 | #[cfg (not(target_os = "macos" ))] |
386 | async fn epollhup() -> io::Result<()> { |
387 | let dir = tempfile::Builder::new() |
388 | .prefix("tokio-uds-tests" ) |
389 | .tempdir() |
390 | .unwrap(); |
391 | let sock_path = dir.path().join("connect.sock" ); |
392 | |
393 | let listener = UnixListener::bind(&sock_path)?; |
394 | let connect = UnixStream::connect(&sock_path); |
395 | tokio::pin!(connect); |
396 | |
397 | // Poll `connect` once. |
398 | poll_fn(|cx| { |
399 | use std::future::Future; |
400 | |
401 | assert_pending!(connect.as_mut().poll(cx)); |
402 | Poll::Ready(()) |
403 | }) |
404 | .await; |
405 | |
406 | drop(listener); |
407 | |
408 | let err = connect.await.unwrap_err(); |
409 | assert_eq!(err.kind(), io::ErrorKind::ConnectionReset); |
410 | Ok(()) |
411 | } |
412 | |