1 | use rustix::fd::{AsFd, BorrowedFd}; |
2 | use std::io::{IoSlice, Result}; |
3 | use std::net::TcpStream; |
4 | #[cfg (unix)] |
5 | use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd, RawFd}; |
6 | #[cfg (unix)] |
7 | use std::os::unix::net::UnixStream; |
8 | #[cfg (windows)] |
9 | use std::os::windows::io::{ |
10 | AsRawSocket, AsSocket, BorrowedSocket, IntoRawSocket, OwnedSocket, RawSocket, |
11 | }; |
12 | |
13 | use crate::utils::RawFdContainer; |
14 | use x11rb_protocol::parse_display::ConnectAddress; |
15 | use x11rb_protocol::xauth::Family; |
16 | |
/// The kind of operation that one wants to poll for.
///
/// Passed to [`Stream::poll`]; use [`PollMode::readable`] and [`PollMode::writable`] to
/// decompose it into the individual readiness conditions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PollMode {
    /// Check if the stream is readable, i.e. there is pending data to be read.
    Readable,

    /// Check if the stream is writable, i.e. some data could be successfully written to it.
    Writable,

    /// Check for both readability and writability.
    ReadAndWritable,
}
29 | |
30 | impl PollMode { |
31 | /// Does this poll mode include readability? |
32 | pub fn readable(self) -> bool { |
33 | match self { |
34 | PollMode::Readable | PollMode::ReadAndWritable => true, |
35 | PollMode::Writable => false, |
36 | } |
37 | } |
38 | |
39 | /// Does this poll mode include writability? |
40 | pub fn writable(self) -> bool { |
41 | match self { |
42 | PollMode::Writable | PollMode::ReadAndWritable => true, |
43 | PollMode::Readable => false, |
44 | } |
45 | } |
46 | } |
47 | |
48 | /// A trait used to implement the raw communication with the X11 server. |
49 | /// |
50 | /// None of the functions of this trait shall return [`std::io::ErrorKind::Interrupted`]. |
51 | /// If a system call fails with this error, the implementation should try again. |
52 | pub trait Stream { |
53 | /// Waits for level-triggered read and/or write events on the stream. |
54 | /// |
55 | /// This function does not return what caused it to complete the poll. |
56 | /// Instead, callers should try to read or write and check for |
57 | /// [`std::io::ErrorKind::WouldBlock`]. |
58 | /// |
59 | /// This function is allowed to spuriously return even if the stream |
60 | /// is neither readable nor writable. However, it shall not do it |
61 | /// continuously, which would cause a 100% CPU usage. |
62 | /// |
63 | /// # Multithreading |
64 | /// |
65 | /// If `Self` is `Send + Sync` and `poll` is used concurrently from more than |
66 | /// one thread, all threads should wake when the stream becomes readable (when |
67 | /// `read` is `true`) or writable (when `write` is `true`). |
68 | fn poll(&self, mode: PollMode) -> Result<()>; |
69 | |
70 | /// Read some bytes and FDs from this reader without blocking, returning how many bytes |
71 | /// were read. |
72 | /// |
73 | /// This function works like [`std::io::Read::read`], but also supports the reception of file |
74 | /// descriptors. Any received file descriptors are appended to the given `fd_storage`. |
75 | /// Whereas implementation of [`std::io::Read::read`] are allowed to block or not to block, |
76 | /// this method shall never block and return `ErrorKind::WouldBlock` if needed. |
77 | /// |
78 | /// This function does not guarantee that all file descriptors were sent together with the data |
79 | /// with which they are received. However, file descriptors may not be received later than the |
80 | /// data that was sent at the same time. Instead, file descriptors may only be received |
81 | /// earlier. |
82 | /// |
83 | /// # Multithreading |
84 | /// |
85 | /// If `Self` is `Send + Sync` and `read` is used concurrently from more than one thread: |
86 | /// |
87 | /// * Both the data and the file descriptors shall be read in order, but possibly |
88 | /// interleaved across threads. |
89 | /// * Neither the data nor the file descriptors shall be duplicated. |
90 | /// * The returned value shall always be the actual number of bytes read into `buf`. |
91 | fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize>; |
92 | |
93 | /// Write a buffer and some FDs into this writer without blocking, returning how many |
94 | /// bytes were written. |
95 | /// |
96 | /// This function works like [`std::io::Write::write`], but also supports sending file |
97 | /// descriptors. The `fds` argument contains the file descriptors to send. The order of file |
98 | /// descriptors is maintained. Whereas implementation of [`std::io::Write::write`] are |
99 | /// allowed to block or not to block, this function must never block and return |
100 | /// `ErrorKind::WouldBlock` if needed. |
101 | /// |
102 | /// This function does not guarantee that all file descriptors are sent together with the data. |
103 | /// Any file descriptors that were sent are removed from the beginning of the given `Vec`. |
104 | /// |
105 | /// There is no guarantee that the given file descriptors are received together with the given |
106 | /// data. File descriptors might be received earlier than their corresponding data. It is not |
107 | /// allowed for file descriptors to be received later than the bytes that were sent at the same |
108 | /// time. |
109 | /// |
110 | /// # Multithreading |
111 | /// |
112 | /// If `Self` is `Send + Sync` and `write` is used concurrently from more than one thread: |
113 | /// |
114 | /// * Both the data and the file descriptors shall be written in order, but possibly |
115 | /// interleaved across threads. |
116 | /// * Neither the data nor the file descriptors shall be duplicated. |
117 | /// * The returned value shall always be the actual number of bytes written from `buf`. |
118 | fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize>; |
119 | |
120 | /// Like `write`, except that it writes from a slice of buffers. Like `write`, this |
121 | /// method must never block. |
122 | /// |
123 | /// This method must behave as a call to `write` with the buffers concatenated would. |
124 | /// |
125 | /// The default implementation calls `write` with the first nonempty buffer provided. |
126 | /// |
127 | /// # Multithreading |
128 | /// |
129 | /// Same as `write`. |
130 | fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> { |
131 | for buf in bufs { |
132 | if !buf.is_empty() { |
133 | return self.write(buf, fds); |
134 | } |
135 | } |
136 | Ok(0) |
137 | } |
138 | } |
139 | |
140 | /// A wrapper around a `TcpStream` or `UnixStream`. |
141 | /// |
142 | /// Use by default in `RustConnection` as stream. |
143 | #[derive (Debug)] |
144 | pub struct DefaultStream { |
145 | inner: DefaultStreamInner, |
146 | } |
147 | |
148 | #[cfg (unix)] |
149 | type DefaultStreamInner = RawFdContainer; |
150 | |
151 | #[cfg (not(unix))] |
152 | type DefaultStreamInner = TcpStream; |
153 | |
/// The address of a peer in a format suitable for xauth.
///
/// The first element is the xauth address family, the second the raw address bytes (the
/// hostname for `Family::LOCAL`, or the IP address octets for remote TCP peers).
///
/// These values can be directly given to [`x11rb_protocol::xauth::get_auth`].
type PeerAddr = (Family, Vec<u8>);
158 | |
159 | impl DefaultStream { |
160 | /// Try to connect to the X11 server described by the given arguments. |
161 | pub fn connect(addr: &ConnectAddress<'_>) -> Result<(Self, PeerAddr)> { |
162 | match addr { |
163 | ConnectAddress::Hostname(host, port) => { |
164 | // connect over TCP |
165 | let stream = TcpStream::connect((*host, *port))?; |
166 | Self::from_tcp_stream(stream) |
167 | } |
168 | #[cfg (unix)] |
169 | ConnectAddress::Socket(path) => { |
170 | // Try abstract unix socket first. If that fails, fall back to normal unix socket |
171 | #[cfg (any(target_os = "linux" , target_os = "android" ))] |
172 | if let Ok(stream) = connect_abstract_unix_stream(path.as_bytes()) { |
173 | // TODO: Does it make sense to add a constructor similar to from_unix_stream()? |
174 | // If this is done: Move the set_nonblocking() from |
175 | // connect_abstract_unix_stream() to that new function. |
176 | let stream = DefaultStream { inner: stream }; |
177 | return Ok((stream, peer_addr::local())); |
178 | } |
179 | |
180 | // connect over Unix domain socket |
181 | let stream = UnixStream::connect(path)?; |
182 | Self::from_unix_stream(stream) |
183 | } |
184 | #[cfg (not(unix))] |
185 | ConnectAddress::Socket(_) => { |
186 | // Unix domain sockets are not supported on Windows |
187 | Err(std::io::Error::new( |
188 | std::io::ErrorKind::Other, |
189 | "Unix domain sockets are not supported on Windows" , |
190 | )) |
191 | } |
192 | _ => Err(std::io::Error::new( |
193 | std::io::ErrorKind::Other, |
194 | "The given address family is not implemented" , |
195 | )), |
196 | } |
197 | } |
198 | |
199 | /// Creates a new `Stream` from an already connected `TcpStream`. |
200 | /// |
201 | /// The stream will be set in non-blocking mode. |
202 | /// |
203 | /// This returns the peer address in a format suitable for [`x11rb_protocol::xauth::get_auth`]. |
204 | pub fn from_tcp_stream(stream: TcpStream) -> Result<(Self, PeerAddr)> { |
205 | let peer_addr = peer_addr::tcp(&stream.peer_addr()?); |
206 | stream.set_nonblocking(true)?; |
207 | let result = Self { |
208 | inner: stream.into(), |
209 | }; |
210 | Ok((result, peer_addr)) |
211 | } |
212 | |
213 | /// Creates a new `Stream` from an already connected `UnixStream`. |
214 | /// |
215 | /// The stream will be set in non-blocking mode. |
216 | /// |
217 | /// This returns the peer address in a format suitable for [`x11rb_protocol::xauth::get_auth`]. |
218 | #[cfg (unix)] |
219 | pub fn from_unix_stream(stream: UnixStream) -> Result<(Self, PeerAddr)> { |
220 | stream.set_nonblocking(true)?; |
221 | let result = Self { |
222 | inner: stream.into(), |
223 | }; |
224 | Ok((result, peer_addr::local())) |
225 | } |
226 | |
227 | fn as_fd(&self) -> BorrowedFd<'_> { |
228 | self.inner.as_fd() |
229 | } |
230 | } |
231 | |
232 | #[cfg (unix)] |
233 | impl AsRawFd for DefaultStream { |
234 | fn as_raw_fd(&self) -> RawFd { |
235 | self.inner.as_raw_fd() |
236 | } |
237 | } |
238 | |
239 | #[cfg (unix)] |
240 | impl AsFd for DefaultStream { |
241 | fn as_fd(&self) -> BorrowedFd<'_> { |
242 | self.inner.as_fd() |
243 | } |
244 | } |
245 | |
246 | #[cfg (unix)] |
247 | impl IntoRawFd for DefaultStream { |
248 | fn into_raw_fd(self) -> RawFd { |
249 | self.inner.into_raw_fd() |
250 | } |
251 | } |
252 | |
253 | #[cfg (unix)] |
254 | impl From<DefaultStream> for OwnedFd { |
255 | fn from(stream: DefaultStream) -> Self { |
256 | stream.inner |
257 | } |
258 | } |
259 | |
#[cfg(windows)]
impl AsRawSocket for DefaultStream {
    /// Returns the raw handle of the wrapped `TcpStream`; ownership stays with the stream.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}
266 | |
#[cfg(windows)]
impl AsSocket for DefaultStream {
    /// Borrows the wrapped `TcpStream`'s socket for the lifetime of `&self`.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.inner)
    }
}
273 | |
#[cfg(windows)]
impl IntoRawSocket for DefaultStream {
    /// Releases the wrapped socket handle; the caller becomes responsible for closing it.
    fn into_raw_socket(self) -> RawSocket {
        IntoRawSocket::into_raw_socket(self.inner)
    }
}
280 | |
#[cfg(windows)]
impl From<DefaultStream> for OwnedSocket {
    /// Unwraps the stream, converting its `TcpStream` into an owned socket handle.
    fn from(stream: DefaultStream) -> Self {
        OwnedSocket::from(stream.inner)
    }
}
287 | |
288 | #[cfg (unix)] |
289 | fn do_write( |
290 | stream: &DefaultStream, |
291 | bufs: &[IoSlice<'_>], |
292 | fds: &mut Vec<RawFdContainer>, |
293 | ) -> Result<usize> { |
294 | use rustix::io::Errno; |
295 | use rustix::net::{sendmsg, SendAncillaryBuffer, SendAncillaryMessage, SendFlags}; |
296 | |
297 | fn sendmsg_wrapper( |
298 | fd: BorrowedFd<'_>, |
299 | iov: &[IoSlice<'_>], |
300 | cmsgs: &mut SendAncillaryBuffer<'_, '_, '_>, |
301 | flags: SendFlags, |
302 | ) -> Result<usize> { |
303 | loop { |
304 | match sendmsg(fd, iov, cmsgs, flags) { |
305 | Ok(n) => return Ok(n), |
306 | // try again |
307 | Err(Errno::INTR) => {} |
308 | Err(e) => return Err(e.into()), |
309 | } |
310 | } |
311 | } |
312 | |
313 | let fd = stream.as_fd(); |
314 | |
315 | let res = if !fds.is_empty() { |
316 | let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>(); |
317 | let rights = SendAncillaryMessage::ScmRights(&fds); |
318 | |
319 | let mut cmsg_space = vec![0u8; rights.size()]; |
320 | let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space); |
321 | assert!(cmsg_buffer.push(rights)); |
322 | |
323 | sendmsg_wrapper(fd, bufs, &mut cmsg_buffer, SendFlags::empty())? |
324 | } else { |
325 | sendmsg_wrapper(fd, bufs, &mut Default::default(), SendFlags::empty())? |
326 | }; |
327 | |
328 | // We successfully sent all FDs |
329 | fds.clear(); |
330 | |
331 | Ok(res) |
332 | } |
333 | |
/// [`Stream`] implementation for the default transport.
///
/// On Unix, data is exchanged with `sendmsg`/`recvmsg` so that file descriptors can travel as
/// `SCM_RIGHTS` ancillary data. On other platforms only bytes are transferred: no FDs are ever
/// received, and any attempt to send FDs fails with an error. Every operation retries when
/// interrupted (`EINTR` / `ErrorKind::Interrupted`), as the `Stream` contract requires.
impl Stream for DefaultStream {
    /// Blocks, without a timeout, until the socket reports the readiness requested by `mode`.
    fn poll(&self, mode: PollMode) -> Result<()> {
        use rustix::event::{poll, PollFd, PollFlags};
        use rustix::io::Errno;

        // Translate the requested mode into poll(2) event flags.
        let mut poll_flags = PollFlags::empty();
        if mode.readable() {
            poll_flags |= PollFlags::IN;
        }
        if mode.writable() {
            poll_flags |= PollFlags::OUT;
        }
        let fd = self.as_fd();
        let mut poll_fds = [PollFd::from_borrowed_fd(fd, poll_flags)];
        loop {
            // A timeout of -1 makes poll(2) wait indefinitely; EINTR restarts the wait.
            match poll(&mut poll_fds, -1) {
                Ok(_) => break,
                Err(Errno::INTR) => {}
                Err(e) => return Err(e.into()),
            }
        }
        // Let the errors (POLLERR) be handled when trying to read or write.
        Ok(())
    }

    /// Reads bytes into `buf` without blocking; on Unix any received FDs are appended to
    /// `fd_storage`.
    fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            use rustix::io::Errno;
            use rustix::net::{recvmsg, RecvAncillaryBuffer, RecvAncillaryMessage};
            use std::io::IoSliceMut;

            // 1024 bytes on the stack should be enough for more file descriptors than the X server will ever
            // send, as well as the header for the ancillary data. If you can find a case where this can
            // overflow with an actual production X11 server, I'll buy you a steak dinner.
            // NOTE(review): ancillary-data truncation (MSG_CTRUNC) is not checked here, so FDs
            // beyond this buffer would be lost silently.
            let mut cmsg = [0u8; 1024];
            let mut iov = [IoSliceMut::new(buf)];
            let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg);

            let fd = self.as_fd();
            // `recvmsg::flags()` (from the `recvmsg` helper module in this file) requests
            // CMSG_CLOEXEC on platforms that support it and no flags elsewhere.
            let msg = loop {
                match recvmsg(fd, &mut iov, &mut cmsg_buffer, recvmsg::flags()) {
                    Ok(msg) => break msg,
                    // try again
                    Err(Errno::INTR) => {}
                    Err(e) => return Err(e.into()),
                }
            };

            // Keep only SCM_RIGHTS messages and flatten them into one stream of received FDs.
            let fds_received = cmsg_buffer
                .drain()
                .filter_map(|cmsg| match cmsg {
                    RecvAncillaryMessage::ScmRights(r) => Some(r),
                    _ => None,
                })
                .flatten();

            // On platforms without CMSG_CLOEXEC, `after_recvmsg` sets FD_CLOEXEC on each FD and
            // stores any failure in `cloexec_error`. The FDs are appended to `fd_storage` even
            // then; the error is only reported afterwards.
            // NOTE(review): in that error case the `msg.bytes` already read into `buf` are not
            // reported to the caller — confirm this data loss is acceptable.
            let mut cloexec_error = Ok(());
            fd_storage.extend(recvmsg::after_recvmsg(fds_received, &mut cloexec_error));
            cloexec_error?;

            Ok(msg.bytes)
        }
        #[cfg(not(unix))]
        {
            use std::io::Read;
            // No FDs are read, so nothing needs to be done with fd_storage
            let _ = fd_storage;
            loop {
                // Use `impl Read for &TcpStream` to avoid needing a mutable `TcpStream`.
                match (&mut &self.inner).read(buf) {
                    Ok(n) => return Ok(n),
                    // try again
                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }

    /// Writes `buf` (plus `fds` on Unix) without blocking.
    fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            // A single buffer is sent as a one-element vectored write.
            do_write(self, &[IoSlice::new(buf)], fds)
        }
        #[cfg(not(unix))]
        {
            use std::io::{Error, ErrorKind, Write};
            // Without Unix sockets there is no way to pass FDs.
            if !fds.is_empty() {
                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
            }
            loop {
                // Use `impl Write for &TcpStream` to avoid needing a mutable `TcpStream`.
                match (&mut &self.inner).write(buf) {
                    Ok(n) => return Ok(n),
                    // try again
                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }

    /// Writes all of `bufs` in one call where the OS allows it (plus `fds` on Unix), without
    /// blocking.
    fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            do_write(self, bufs, fds)
        }
        #[cfg(not(unix))]
        {
            use std::io::{Error, ErrorKind, Write};
            // Without Unix sockets there is no way to pass FDs.
            if !fds.is_empty() {
                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
            }
            loop {
                // Use `impl Write for &TcpStream` to avoid needing a mutable `TcpStream`.
                match (&mut &self.inner).write_vectored(bufs) {
                    Ok(n) => return Ok(n),
                    // try again
                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }
}
460 | |
461 | #[cfg (any(target_os = "linux" , target_os = "android" ))] |
462 | fn connect_abstract_unix_stream( |
463 | path: &[u8], |
464 | ) -> std::result::Result<RawFdContainer, rustix::io::Errno> { |
465 | use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags}; |
466 | use rustix::net::{ |
467 | connect_unix, socket_with, AddressFamily, SocketAddrUnix, SocketFlags, SocketType, |
468 | }; |
469 | |
470 | let socket: OwnedFd = socket_with( |
471 | domain:AddressFamily::UNIX, |
472 | type_:SocketType::STREAM, |
473 | flags:SocketFlags::CLOEXEC, |
474 | protocol:None, |
475 | )?; |
476 | |
477 | connect_unix(&socket, &SocketAddrUnix::new_abstract_name(path)?)?; |
478 | |
479 | // Make the FD non-blocking |
480 | fcntl_setfl(&socket, flags:fcntl_getfl(&socket)? | OFlags::NONBLOCK)?; |
481 | |
482 | Ok(socket) |
483 | } |
484 | |
485 | /// Helper code to make sure that received FDs are marked as CLOEXEC |
486 | #[cfg (any( |
487 | target_os = "android" , |
488 | target_os = "dragonfly" , |
489 | target_os = "freebsd" , |
490 | target_os = "linux" , |
491 | target_os = "netbsd" , |
492 | target_os = "openbsd" |
493 | ))] |
494 | mod recvmsg { |
495 | use super::RawFdContainer; |
496 | use rustix::net::RecvFlags; |
497 | |
498 | pub(crate) fn flags() -> RecvFlags { |
499 | RecvFlags::CMSG_CLOEXEC |
500 | } |
501 | |
502 | pub(crate) fn after_recvmsg<'a>( |
503 | fds: impl Iterator<Item = RawFdContainer> + 'a, |
504 | _cloexec_error: &'a mut Result<(), rustix::io::Errno>, |
505 | ) -> impl Iterator<Item = RawFdContainer> + 'a { |
506 | fds |
507 | } |
508 | } |
509 | |
/// Helper code to make sure that received FDs are marked as CLOEXEC.
///
/// On the remaining Unix platforms `recvmsg` cannot set close-on-exec atomically, so it is set
/// with `fcntl` on each FD after it has been received.
#[cfg(all(
    unix,
    not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    ))
))]
mod recvmsg {
    use super::RawFdContainer;
    use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
    use rustix::net::RecvFlags;

    /// Flags to pass to `recvmsg`: none, since close-on-exec is applied afterwards.
    pub(crate) fn flags() -> RecvFlags {
        RecvFlags::empty()
    }

    /// Lazily sets `FD_CLOEXEC` on every FD as it is yielded.
    ///
    /// Every FD is yielded even if updating its flags fails; such a failure is stored in
    /// `cloexec_error` (a later failure overwrites an earlier one).
    pub(crate) fn after_recvmsg<'a>(
        fds: impl Iterator<Item = RawFdContainer> + 'a,
        cloexec_error: &'a mut rustix::io::Result<()>,
    ) -> impl Iterator<Item = RawFdContainer> + 'a {
        fds.inspect(move |fd| {
            let outcome =
                fcntl_getfd(fd).and_then(|current| fcntl_setfd(fd, current | FdFlags::CLOEXEC));
            if let Err(errno) = outcome {
                *cloexec_error = Err(errno);
            }
        })
    }
}
545 | |
546 | mod peer_addr { |
547 | use super::{Family, PeerAddr}; |
548 | use std::net::{Ipv4Addr, SocketAddr}; |
549 | |
550 | // Get xauth information representing a local connection |
551 | pub(super) fn local() -> PeerAddr { |
552 | let hostname = crate::hostname() |
553 | .to_str() |
554 | .map_or_else(Vec::new, |s| s.as_bytes().to_vec()); |
555 | (Family::LOCAL, hostname) |
556 | } |
557 | |
558 | // Get xauth information representing a TCP connection to the given address |
559 | pub(super) fn tcp(addr: &SocketAddr) -> PeerAddr { |
560 | let ip = match addr { |
561 | SocketAddr::V4(addr) => *addr.ip(), |
562 | SocketAddr::V6(addr) => { |
563 | let ip = addr.ip(); |
564 | if ip.is_loopback() { |
565 | // This is a local connection. |
566 | // Use LOCALHOST to cause a fall-through in the code below. |
567 | Ipv4Addr::LOCALHOST |
568 | } else if let Some(ip) = ip.to_ipv4() { |
569 | // Let the ipv4 code below handle this |
570 | ip |
571 | } else { |
572 | // Okay, this is really a v6 address |
573 | return (Family::INTERNET6, ip.octets().to_vec()); |
574 | } |
575 | } |
576 | }; |
577 | |
578 | // Handle the v4 address |
579 | if ip.is_loopback() { |
580 | local() |
581 | } else { |
582 | (Family::INTERNET, ip.octets().to_vec()) |
583 | } |
584 | } |
585 | } |
586 | |