1use rustix::fd::{AsFd, BorrowedFd};
2use std::io::{IoSlice, Result};
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd, RawFd};
6#[cfg(unix)]
7use std::os::unix::net::UnixStream;
8#[cfg(windows)]
9use std::os::windows::io::{
10 AsRawSocket, AsSocket, BorrowedSocket, IntoRawSocket, OwnedSocket, RawSocket,
11};
12
13use crate::utils::RawFdContainer;
14use x11rb_protocol::parse_display::ConnectAddress;
15use x11rb_protocol::xauth::Family;
16
/// Selects which readiness condition(s) a call to `Stream::poll` waits for.
#[derive(Debug, Clone, Copy)]
pub enum PollMode {
    /// Wait for readability: there is pending data that can be read.
    Readable,

    /// Wait for writability: some data could be successfully written.
    Writable,

    /// Interested in both readability and writability.
    ReadAndWritable,
}

impl PollMode {
    /// Returns `true` when this mode asks for readability
    /// (`Readable` or `ReadAndWritable`).
    pub fn readable(self) -> bool {
        matches!(self, Self::Readable | Self::ReadAndWritable)
    }

    /// Returns `true` when this mode asks for writability
    /// (`Writable` or `ReadAndWritable`).
    pub fn writable(self) -> bool {
        matches!(self, Self::Writable | Self::ReadAndWritable)
    }
}
47
48/// A trait used to implement the raw communication with the X11 server.
49///
50/// None of the functions of this trait shall return [`std::io::ErrorKind::Interrupted`].
51/// If a system call fails with this error, the implementation should try again.
52pub trait Stream {
53 /// Waits for level-triggered read and/or write events on the stream.
54 ///
55 /// This function does not return what caused it to complete the poll.
56 /// Instead, callers should try to read or write and check for
57 /// [`std::io::ErrorKind::WouldBlock`].
58 ///
59 /// This function is allowed to spuriously return even if the stream
60 /// is neither readable nor writable. However, it shall not do it
61 /// continuously, which would cause a 100% CPU usage.
62 ///
63 /// # Multithreading
64 ///
65 /// If `Self` is `Send + Sync` and `poll` is used concurrently from more than
66 /// one thread, all threads should wake when the stream becomes readable (when
67 /// `read` is `true`) or writable (when `write` is `true`).
68 fn poll(&self, mode: PollMode) -> Result<()>;
69
70 /// Read some bytes and FDs from this reader without blocking, returning how many bytes
71 /// were read.
72 ///
73 /// This function works like [`std::io::Read::read`], but also supports the reception of file
74 /// descriptors. Any received file descriptors are appended to the given `fd_storage`.
75 /// Whereas implementation of [`std::io::Read::read`] are allowed to block or not to block,
76 /// this method shall never block and return `ErrorKind::WouldBlock` if needed.
77 ///
78 /// This function does not guarantee that all file descriptors were sent together with the data
79 /// with which they are received. However, file descriptors may not be received later than the
80 /// data that was sent at the same time. Instead, file descriptors may only be received
81 /// earlier.
82 ///
83 /// # Multithreading
84 ///
85 /// If `Self` is `Send + Sync` and `read` is used concurrently from more than one thread:
86 ///
87 /// * Both the data and the file descriptors shall be read in order, but possibly
88 /// interleaved across threads.
89 /// * Neither the data nor the file descriptors shall be duplicated.
90 /// * The returned value shall always be the actual number of bytes read into `buf`.
91 fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize>;
92
93 /// Write a buffer and some FDs into this writer without blocking, returning how many
94 /// bytes were written.
95 ///
96 /// This function works like [`std::io::Write::write`], but also supports sending file
97 /// descriptors. The `fds` argument contains the file descriptors to send. The order of file
98 /// descriptors is maintained. Whereas implementation of [`std::io::Write::write`] are
99 /// allowed to block or not to block, this function must never block and return
100 /// `ErrorKind::WouldBlock` if needed.
101 ///
102 /// This function does not guarantee that all file descriptors are sent together with the data.
103 /// Any file descriptors that were sent are removed from the beginning of the given `Vec`.
104 ///
105 /// There is no guarantee that the given file descriptors are received together with the given
106 /// data. File descriptors might be received earlier than their corresponding data. It is not
107 /// allowed for file descriptors to be received later than the bytes that were sent at the same
108 /// time.
109 ///
110 /// # Multithreading
111 ///
112 /// If `Self` is `Send + Sync` and `write` is used concurrently from more than one thread:
113 ///
114 /// * Both the data and the file descriptors shall be written in order, but possibly
115 /// interleaved across threads.
116 /// * Neither the data nor the file descriptors shall be duplicated.
117 /// * The returned value shall always be the actual number of bytes written from `buf`.
118 fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize>;
119
120 /// Like `write`, except that it writes from a slice of buffers. Like `write`, this
121 /// method must never block.
122 ///
123 /// This method must behave as a call to `write` with the buffers concatenated would.
124 ///
125 /// The default implementation calls `write` with the first nonempty buffer provided.
126 ///
127 /// # Multithreading
128 ///
129 /// Same as `write`.
130 fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
131 for buf in bufs {
132 if !buf.is_empty() {
133 return self.write(buf, fds);
134 }
135 }
136 Ok(0)
137 }
138}
139
140/// A wrapper around a `TcpStream` or `UnixStream`.
141///
142/// Use by default in `RustConnection` as stream.
143#[derive(Debug)]
144pub struct DefaultStream {
145 inner: DefaultStreamInner,
146}
147
148#[cfg(unix)]
149type DefaultStreamInner = RawFdContainer;
150
151#[cfg(not(unix))]
152type DefaultStreamInner = TcpStream;
153
154/// The address of a peer in a format suitable for xauth.
155///
156/// These values can be directly given to [`x11rb_protocol::xauth::get_auth`].
157type PeerAddr = (Family, Vec<u8>);
158
159impl DefaultStream {
160 /// Try to connect to the X11 server described by the given arguments.
161 pub fn connect(addr: &ConnectAddress<'_>) -> Result<(Self, PeerAddr)> {
162 match addr {
163 ConnectAddress::Hostname(host, port) => {
164 // connect over TCP
165 let stream = TcpStream::connect((*host, *port))?;
166 Self::from_tcp_stream(stream)
167 }
168 #[cfg(unix)]
169 ConnectAddress::Socket(path) => {
170 // Try abstract unix socket first. If that fails, fall back to normal unix socket
171 #[cfg(any(target_os = "linux", target_os = "android"))]
172 if let Ok(stream) = connect_abstract_unix_stream(path.as_bytes()) {
173 // TODO: Does it make sense to add a constructor similar to from_unix_stream()?
174 // If this is done: Move the set_nonblocking() from
175 // connect_abstract_unix_stream() to that new function.
176 let stream = DefaultStream { inner: stream };
177 return Ok((stream, peer_addr::local()));
178 }
179
180 // connect over Unix domain socket
181 let stream = UnixStream::connect(path)?;
182 Self::from_unix_stream(stream)
183 }
184 #[cfg(not(unix))]
185 ConnectAddress::Socket(_) => {
186 // Unix domain sockets are not supported on Windows
187 Err(std::io::Error::new(
188 std::io::ErrorKind::Other,
189 "Unix domain sockets are not supported on Windows",
190 ))
191 }
192 _ => Err(std::io::Error::new(
193 std::io::ErrorKind::Other,
194 "The given address family is not implemented",
195 )),
196 }
197 }
198
199 /// Creates a new `Stream` from an already connected `TcpStream`.
200 ///
201 /// The stream will be set in non-blocking mode.
202 ///
203 /// This returns the peer address in a format suitable for [`x11rb_protocol::xauth::get_auth`].
204 pub fn from_tcp_stream(stream: TcpStream) -> Result<(Self, PeerAddr)> {
205 let peer_addr = peer_addr::tcp(&stream.peer_addr()?);
206 stream.set_nonblocking(true)?;
207 let result = Self {
208 inner: stream.into(),
209 };
210 Ok((result, peer_addr))
211 }
212
213 /// Creates a new `Stream` from an already connected `UnixStream`.
214 ///
215 /// The stream will be set in non-blocking mode.
216 ///
217 /// This returns the peer address in a format suitable for [`x11rb_protocol::xauth::get_auth`].
218 #[cfg(unix)]
219 pub fn from_unix_stream(stream: UnixStream) -> Result<(Self, PeerAddr)> {
220 stream.set_nonblocking(true)?;
221 let result = Self {
222 inner: stream.into(),
223 };
224 Ok((result, peer_addr::local()))
225 }
226
227 fn as_fd(&self) -> BorrowedFd<'_> {
228 self.inner.as_fd()
229 }
230}
231
232#[cfg(unix)]
233impl AsRawFd for DefaultStream {
234 fn as_raw_fd(&self) -> RawFd {
235 self.inner.as_raw_fd()
236 }
237}
238
239#[cfg(unix)]
240impl AsFd for DefaultStream {
241 fn as_fd(&self) -> BorrowedFd<'_> {
242 self.inner.as_fd()
243 }
244}
245
246#[cfg(unix)]
247impl IntoRawFd for DefaultStream {
248 fn into_raw_fd(self) -> RawFd {
249 self.inner.into_raw_fd()
250 }
251}
252
253#[cfg(unix)]
254impl From<DefaultStream> for OwnedFd {
255 fn from(stream: DefaultStream) -> Self {
256 stream.inner
257 }
258}
259
#[cfg(windows)]
impl AsRawSocket for DefaultStream {
    /// Returns the raw handle of the wrapped socket; ownership stays with `self`.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}

#[cfg(windows)]
impl AsSocket for DefaultStream {
    /// Borrows the wrapped socket for the lifetime of `self`.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.inner)
    }
}

#[cfg(windows)]
impl IntoRawSocket for DefaultStream {
    /// Gives up ownership of the socket; the caller becomes responsible for closing it.
    fn into_raw_socket(self) -> RawSocket {
        IntoRawSocket::into_raw_socket(self.inner)
    }
}

#[cfg(windows)]
impl From<DefaultStream> for OwnedSocket {
    /// Unwraps the stream into the owned socket it holds.
    fn from(stream: DefaultStream) -> Self {
        let DefaultStream { inner } = stream;
        Self::from(inner)
    }
}
287
288#[cfg(unix)]
289fn do_write(
290 stream: &DefaultStream,
291 bufs: &[IoSlice<'_>],
292 fds: &mut Vec<RawFdContainer>,
293) -> Result<usize> {
294 use rustix::io::Errno;
295 use rustix::net::{sendmsg, SendAncillaryBuffer, SendAncillaryMessage, SendFlags};
296
297 fn sendmsg_wrapper(
298 fd: BorrowedFd<'_>,
299 iov: &[IoSlice<'_>],
300 cmsgs: &mut SendAncillaryBuffer<'_, '_, '_>,
301 flags: SendFlags,
302 ) -> Result<usize> {
303 loop {
304 match sendmsg(fd, iov, cmsgs, flags) {
305 Ok(n) => return Ok(n),
306 // try again
307 Err(Errno::INTR) => {}
308 Err(e) => return Err(e.into()),
309 }
310 }
311 }
312
313 let fd = stream.as_fd();
314
315 let res = if !fds.is_empty() {
316 let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>();
317 let rights = SendAncillaryMessage::ScmRights(&fds);
318
319 let mut cmsg_space = vec![0u8; rights.size()];
320 let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
321 assert!(cmsg_buffer.push(rights));
322
323 sendmsg_wrapper(fd, bufs, &mut cmsg_buffer, SendFlags::empty())?
324 } else {
325 sendmsg_wrapper(fd, bufs, &mut Default::default(), SendFlags::empty())?
326 };
327
328 // We successfully sent all FDs
329 fds.clear();
330
331 Ok(res)
332}
333
/// [`Stream`] implementation for [`DefaultStream`].
///
/// On Unix, reads and writes go through `recvmsg`/`sendmsg` so that file descriptors can be
/// passed. On other platforms the `Read`/`Write` impls of `&TcpStream` are used; received FDs
/// are never produced and a non-empty `fds` on write is rejected with an error.
impl Stream for DefaultStream {
    /// Blocks until the socket reports any of the events selected by `mode`.
    ///
    /// Error conditions reported by `poll` (e.g. `POLLERR`) are not inspected here; they
    /// surface on the next read or write instead.
    fn poll(&self, mode: PollMode) -> Result<()> {
        use rustix::event::{poll, PollFd, PollFlags};
        use rustix::io::Errno;

        // Translate the requested mode into poll(2) event flags.
        let mut poll_flags = PollFlags::empty();
        if mode.readable() {
            poll_flags |= PollFlags::IN;
        }
        if mode.writable() {
            poll_flags |= PollFlags::OUT;
        }
        let fd = self.as_fd();
        let mut poll_fds = [PollFd::from_borrowed_fd(fd, poll_flags)];
        loop {
            // A timeout of -1 means "wait indefinitely".
            match poll(&mut poll_fds, -1) {
                Ok(_) => break,
                // Interrupted by a signal: retry, as the `Stream` contract requires.
                Err(Errno::INTR) => {}
                Err(e) => return Err(e.into()),
            }
        }
        // Let the errors (POLLERR) be handled when trying to read or write.
        Ok(())
    }

    /// Reads bytes into `buf`; on Unix, also appends any FDs received with them to
    /// `fd_storage`.
    fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            use rustix::io::Errno;
            use rustix::net::{recvmsg, RecvAncillaryBuffer, RecvAncillaryMessage};
            use std::io::IoSliceMut;

            // 1024 bytes on the stack should be enough for more file descriptors than the X server will ever
            // send, as well as the header for the ancillary data. If you can find a case where this can
            // overflow with an actual production X11 server, I'll buy you a steak dinner.
            // NOTE(review): `msg.flags` is not inspected below, so truncated control data
            // would go unnoticed.
            let mut cmsg = [0u8; 1024];
            let mut iov = [IoSliceMut::new(buf)];
            let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg);

            let fd = self.as_fd();
            let msg = loop {
                // `recvmsg::flags()` resolves to this file's `recvmsg` helper module (the
                // imported `recvmsg` is a function); it supplies the platform's CLOEXEC flags.
                match recvmsg(fd, &mut iov, &mut cmsg_buffer, recvmsg::flags()) {
                    Ok(msg) => break msg,
                    // try again
                    Err(Errno::INTR) => {}
                    Err(e) => return Err(e.into()),
                }
            };

            // Collect the FDs of every SCM_RIGHTS control message; other messages are ignored.
            let fds_received = cmsg_buffer
                .drain()
                .filter_map(|cmsg| match cmsg {
                    RecvAncillaryMessage::ScmRights(r) => Some(r),
                    _ => None,
                })
                .flatten();

            // Depending on the platform, `after_recvmsg` may set CLOEXEC on each FD and record
            // a failure in `cloexec_error`. The error is only checked after `extend` has
            // consumed the iterator, so every received FD ends up in `fd_storage` either way.
            let mut cloexec_error = Ok(());
            fd_storage.extend(recvmsg::after_recvmsg(fds_received, &mut cloexec_error));
            cloexec_error?;

            Ok(msg.bytes)
        }
        #[cfg(not(unix))]
        {
            use std::io::Read;
            // No FDs are read, so nothing needs to be done with fd_storage
            let _ = fd_storage;
            loop {
                // Use `impl Read for &TcpStream` to avoid needing a mutable `TcpStream`.
                match (&mut &self.inner).read(buf) {
                    Ok(n) => return Ok(n),
                    // try again
                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }

    /// Writes `buf`, sending `fds` alongside on Unix; errors on other platforms if `fds` is
    /// non-empty.
    fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            do_write(self, &[IoSlice::new(buf)], fds)
        }
        #[cfg(not(unix))]
        {
            use std::io::{Error, ErrorKind, Write};
            if !fds.is_empty() {
                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
            }
            loop {
                // Use `impl Write for &TcpStream` to avoid needing a mutable `TcpStream`.
                match (&mut &self.inner).write(buf) {
                    Ok(n) => return Ok(n),
                    // try again
                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }

    /// Vectored counterpart of `write`; on Unix all `bufs` go out in a single `sendmsg`.
    fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            do_write(self, bufs, fds)
        }
        #[cfg(not(unix))]
        {
            use std::io::{Error, ErrorKind, Write};
            if !fds.is_empty() {
                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
            }
            loop {
                // Use `impl Write for &TcpStream` to avoid needing a mutable `TcpStream`.
                match (&mut &self.inner).write_vectored(bufs) {
                    Ok(n) => return Ok(n),
                    // try again
                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }
}
460
461#[cfg(any(target_os = "linux", target_os = "android"))]
462fn connect_abstract_unix_stream(
463 path: &[u8],
464) -> std::result::Result<RawFdContainer, rustix::io::Errno> {
465 use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
466 use rustix::net::{
467 connect_unix, socket_with, AddressFamily, SocketAddrUnix, SocketFlags, SocketType,
468 };
469
470 let socket: OwnedFd = socket_with(
471 domain:AddressFamily::UNIX,
472 type_:SocketType::STREAM,
473 flags:SocketFlags::CLOEXEC,
474 protocol:None,
475 )?;
476
477 connect_unix(&socket, &SocketAddrUnix::new_abstract_name(path)?)?;
478
479 // Make the FD non-blocking
480 fcntl_setfl(&socket, flags:fcntl_getfl(&socket)? | OFlags::NONBLOCK)?;
481
482 Ok(socket)
483}
484
485/// Helper code to make sure that received FDs are marked as CLOEXEC
486#[cfg(any(
487 target_os = "android",
488 target_os = "dragonfly",
489 target_os = "freebsd",
490 target_os = "linux",
491 target_os = "netbsd",
492 target_os = "openbsd"
493))]
494mod recvmsg {
495 use super::RawFdContainer;
496 use rustix::net::RecvFlags;
497
498 pub(crate) fn flags() -> RecvFlags {
499 RecvFlags::CMSG_CLOEXEC
500 }
501
502 pub(crate) fn after_recvmsg<'a>(
503 fds: impl Iterator<Item = RawFdContainer> + 'a,
504 _cloexec_error: &'a mut Result<(), rustix::io::Errno>,
505 ) -> impl Iterator<Item = RawFdContainer> + 'a {
506 fds
507 }
508}
509
/// Helper code to make sure that received FDs are marked as CLOEXEC.
///
/// On these platforms `recvmsg` is called without special flags and CLOEXEC is set on each
/// received FD afterwards with `fcntl`.
#[cfg(all(
    unix,
    not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    ))
))]
mod recvmsg {
    use super::RawFdContainer;
    use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
    use rustix::net::RecvFlags;

    /// No extra `recvmsg` flags; CLOEXEC is applied by `after_recvmsg` instead.
    pub(crate) fn flags() -> RecvFlags {
        RecvFlags::empty()
    }

    /// Sets CLOEXEC on each FD as it is pulled from the iterator.
    ///
    /// Every FD is still yielded even if marking it fails; the most recent failure is stored
    /// in `cloexec_error` for the caller to report.
    pub(crate) fn after_recvmsg<'a>(
        fds: impl Iterator<Item = RawFdContainer> + 'a,
        cloexec_error: &'a mut rustix::io::Result<()>,
    ) -> impl Iterator<Item = RawFdContainer> + 'a {
        fds.inspect(move |fd| {
            let marked =
                fcntl_getfd(fd).and_then(|current| fcntl_setfd(fd, current | FdFlags::CLOEXEC));
            if let Err(errno) = marked {
                *cloexec_error = Err(errno);
            }
        })
    }
}
545
546mod peer_addr {
547 use super::{Family, PeerAddr};
548 use std::net::{Ipv4Addr, SocketAddr};
549
550 // Get xauth information representing a local connection
551 pub(super) fn local() -> PeerAddr {
552 let hostname = crate::hostname()
553 .to_str()
554 .map_or_else(Vec::new, |s| s.as_bytes().to_vec());
555 (Family::LOCAL, hostname)
556 }
557
558 // Get xauth information representing a TCP connection to the given address
559 pub(super) fn tcp(addr: &SocketAddr) -> PeerAddr {
560 let ip = match addr {
561 SocketAddr::V4(addr) => *addr.ip(),
562 SocketAddr::V6(addr) => {
563 let ip = addr.ip();
564 if ip.is_loopback() {
565 // This is a local connection.
566 // Use LOCALHOST to cause a fall-through in the code below.
567 Ipv4Addr::LOCALHOST
568 } else if let Some(ip) = ip.to_ipv4() {
569 // Let the ipv4 code below handle this
570 ip
571 } else {
572 // Okay, this is really a v6 address
573 return (Family::INTERNET6, ip.octets().to_vec());
574 }
575 }
576 };
577
578 // Handle the v4 address
579 if ip.is_loopback() {
580 local()
581 } else {
582 (Family::INTERNET, ip.octets().to_vec())
583 }
584 }
585}
586