1#[cfg(not(feature = "tokio"))]
2use async_io::Async;
3#[cfg(not(feature = "tokio"))]
4use futures_core::ready;
5#[cfg(unix)]
6use std::io::{IoSlice, IoSliceMut};
7#[cfg(feature = "tokio")]
8use std::pin::Pin;
9use std::{
10 io,
11 task::{Context, Poll},
12};
13#[cfg(not(feature = "tokio"))]
14use std::{
15 io::{Read, Write},
16 net::TcpStream,
17};
18
19#[cfg(all(windows, not(feature = "tokio")))]
20use uds_windows::UnixStream;
21
22#[cfg(unix)]
23use nix::{
24 cmsg_space,
25 sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, UnixAddr},
26};
27#[cfg(unix)]
28use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
29
30#[cfg(all(unix, not(feature = "tokio")))]
31use std::os::unix::net::UnixStream;
32
33#[cfg(unix)]
34use crate::{utils::FDS_MAX, OwnedFd};
35
36#[cfg(unix)]
37fn fd_recvmsg(fd: RawFd, buffer: &mut [u8]) -> io::Result<(usize, Vec<OwnedFd>)> {
38 let mut iov = [IoSliceMut::new(buffer)];
39 let mut cmsgspace = cmsg_space!([RawFd; FDS_MAX]);
40
41 let msg = recvmsg::<UnixAddr>(fd, &mut iov, Some(&mut cmsgspace), MsgFlags::empty())?;
42 if msg.bytes == 0 {
43 return Err(io::Error::new(
44 io::ErrorKind::BrokenPipe,
45 "failed to read from socket",
46 ));
47 }
48 let mut fds = vec![];
49 for cmsg in msg.cmsgs() {
50 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
51 if let ControlMessageOwned::ScmCreds(_) = cmsg {
52 continue;
53 }
54 if let ControlMessageOwned::ScmRights(fd) = cmsg {
55 fds.extend(fd.iter().map(|&f| unsafe { OwnedFd::from_raw_fd(f) }));
56 } else {
57 return Err(io::Error::new(
58 io::ErrorKind::InvalidData,
59 "unexpected CMSG kind",
60 ));
61 }
62 }
63 Ok((msg.bytes, fds))
64}
65
66#[cfg(unix)]
67fn fd_sendmsg(fd: RawFd, buffer: &[u8], fds: &[RawFd]) -> io::Result<usize> {
68 let cmsg: Vec> = if !fds.is_empty() {
69 vec![ControlMessage::ScmRights(fds)]
70 } else {
71 vec![]
72 };
73 let iov: [IoSlice<'_>; 1] = [IoSlice::new(buf:buffer)];
74 match sendmsg::<UnixAddr>(fd, &iov, &cmsg, flags:MsgFlags::empty(), addr:None) {
75 // can it really happen?
76 Ok(0) => Err(io::Error::new(
77 kind:io::ErrorKind::WriteZero,
78 error:"failed to write to buffer",
79 )),
80 Ok(n: usize) => Ok(n),
81 Err(e: Errno) => Err(e.into()),
82 }
83}
84
85#[cfg(unix)]
86fn get_unix_pid(fd: &impl AsRawFd) -> io::Result<Option<u32>> {
87 #[cfg(any(target_os = "android", target_os = "linux"))]
88 {
89 use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
90
91 let fd = fd.as_raw_fd();
92 getsockopt(fd, PeerCredentials)
93 .map(|creds| Some(creds.pid() as _))
94 .map_err(|e| e.into())
95 }
96
97 #[cfg(any(
98 target_os = "macos",
99 target_os = "ios",
100 target_os = "freebsd",
101 target_os = "dragonfly",
102 target_os = "openbsd",
103 target_os = "netbsd"
104 ))]
105 {
106 let _ = fd;
107 // FIXME
108 Ok(None)
109 }
110}
111
112#[cfg(unix)]
113fn get_unix_uid(fd: &impl AsRawFd) -> io::Result<Option<u32>> {
114 let fd = fd.as_raw_fd();
115
116 #[cfg(any(target_os = "android", target_os = "linux"))]
117 {
118 use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};
119
120 getsockopt(fd, PeerCredentials)
121 .map(|creds| Some(creds.uid()))
122 .map_err(|e| e.into())
123 }
124
125 #[cfg(any(
126 target_os = "macos",
127 target_os = "ios",
128 target_os = "freebsd",
129 target_os = "dragonfly",
130 target_os = "openbsd",
131 target_os = "netbsd"
132 ))]
133 {
134 nix::unistd::getpeereid(fd)
135 .map(|(uid, _)| Some(uid.into()))
136 .map_err(|e| e.into())
137 }
138}
139
/// Send a single zero byte as a separate message carrying an `SCM_CREDS` control message.
///
/// On success, returns the number of bytes sent.
///
/// # Errors
///
/// Returns the `sendmsg` failure converted into an [`io::Error`].
#[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
fn send_zero_byte(fd: &impl AsRawFd) -> io::Result<usize> {
    let payload = [IoSlice::new(b"\0")];
    let sent = sendmsg::<()>(
        fd.as_raw_fd(),
        &payload,
        &[ControlMessage::ScmCreds],
        MsgFlags::empty(),
        None,
    )?;
    Ok(sent)
}
154
/// Return type of [`Socket::poll_recvmsg`] on unix targets: the number of bytes read together
/// with any file descriptors received alongside them.
#[cfg(unix)]
type PollRecvmsg = Poll<io::Result<(usize, Vec<OwnedFd>)>>;

/// Return type of [`Socket::poll_recvmsg`] on non-unix targets: only the number of bytes read.
#[cfg(not(unix))]
type PollRecvmsg = Poll<io::Result<usize>>;
160
/// Trait representing some transport layer over which the DBus protocol can be used.
///
/// The crate provides implementations for `async_io` and `tokio`'s `UnixStream` wrappers if you
/// enable the corresponding crate features (`async_io` is enabled by default).
///
/// You can implement it manually to integrate with other runtimes or other dbus transports. Feel
/// free to submit pull requests to add support for more runtimes to zbus itself so rust's orphan
/// rules don't force the use of a wrapper struct (and to avoid duplicating the work across many
/// projects).
pub trait Socket: std::fmt::Debug + Send + Sync {
    /// Whether the socket supports passing file descriptors.
    ///
    /// The default implementation returns `true`.
    fn can_pass_unix_fd(&self) -> bool {
        true
    }

    /// Attempt to receive a message from the socket.
    ///
    /// On success, returns the number of bytes read as well as a `Vec` containing
    /// any associated file descriptors. On non-unix targets only the number of bytes read is
    /// returned.
    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg;

    /// Attempt to send a message on the socket.
    ///
    /// On success, returns the number of bytes written. There may be a partial write, in
    /// which case the caller is responsible for sending the remaining data by calling this
    /// method again until everything is written or it returns an error of kind `WouldBlock`.
    ///
    /// If at least one byte has been written, then all the provided file descriptors will
    /// have been sent as well, and should not be provided again in subsequent calls.
    ///
    /// If the underlying transport does not support transmitting file descriptors, this
    /// will return `Err(ErrorKind::InvalidInput)`.
    ///
    /// The `fds` parameter only exists on unix targets.
    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buffer: &[u8],
        #[cfg(unix)] fds: &[RawFd],
    ) -> Poll<io::Result<usize>>;

    /// Close the socket.
    ///
    /// After this call, it is valid for all reading and writing operations to fail.
    fn close(&self) -> io::Result<()>;

    /// Return the peer PID, if any.
    ///
    /// The default implementation returns `Ok(None)`.
    fn peer_pid(&self) -> io::Result<Option<u32>> {
        Ok(None)
    }

    /// Return the peer process SID, if any.
    ///
    /// The default implementation returns `None`.
    #[cfg(windows)]
    fn peer_sid(&self) -> Option<String> {
        None
    }

    /// Return the User ID, if any.
    ///
    /// The default implementation returns `Ok(None)`.
    #[cfg(unix)]
    fn uid(&self) -> io::Result<Option<u32>> {
        Ok(None)
    }

    /// The dbus daemon on `freebsd` and `dragonfly` currently requires sending the zero byte
    /// as a separate message with SCM_CREDS, as part of the `EXTERNAL` authentication on unix
    /// sockets. This method is used by the authentication machinery in zbus to send this
    /// zero byte. Socket implementations based on unix sockets should implement this method.
    ///
    /// The default implementation sends nothing and returns `Ok(None)`.
    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
    fn send_zero_byte(&self) -> io::Result<Option<usize>> {
        Ok(None)
    }
}
231
232impl Socket for Box<dyn Socket> {
233 fn can_pass_unix_fd(&self) -> bool {
234 (**self).can_pass_unix_fd()
235 }
236
237 fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
238 (**self).poll_recvmsg(cx, buf)
239 }
240
241 fn poll_sendmsg(
242 &mut self,
243 cx: &mut Context<'_>,
244 buffer: &[u8],
245 #[cfg(unix)] fds: &[RawFd],
246 ) -> Poll<io::Result<usize>> {
247 (**self).poll_sendmsg(
248 cx,
249 buffer,
250 #[cfg(unix)]
251 fds,
252 )
253 }
254
255 fn close(&self) -> io::Result<()> {
256 (**self).close()
257 }
258
259 fn peer_pid(&self) -> io::Result<Option<u32>> {
260 (**self).peer_pid()
261 }
262
263 #[cfg(windows)]
264 fn peer_sid(&self) -> Option<String> {
265 (&**self).peer_sid()
266 }
267
268 #[cfg(unix)]
269 fn uid(&self) -> io::Result<Option<u32>> {
270 (**self).uid()
271 }
272
273 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
274 fn send_zero_byte(&self) -> io::Result<Option<usize>> {
275 (**self).send_zero_byte()
276 }
277}
278
279#[cfg(all(unix, not(feature = "tokio")))]
280impl Socket for Async<UnixStream> {
281 fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
282 let (len, fds) = loop {
283 match fd_recvmsg(self.as_raw_fd(), buf) {
284 Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
285 Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.poll_readable(cx) {
286 Poll::Pending => return Poll::Pending,
287 Poll::Ready(res) => res?,
288 },
289 v => break v?,
290 }
291 };
292 Poll::Ready(Ok((len, fds)))
293 }
294
295 fn poll_sendmsg(
296 &mut self,
297 cx: &mut Context<'_>,
298 buffer: &[u8],
299 #[cfg(unix)] fds: &[RawFd],
300 ) -> Poll<io::Result<usize>> {
301 loop {
302 match fd_sendmsg(
303 self.as_raw_fd(),
304 buffer,
305 #[cfg(unix)]
306 fds,
307 ) {
308 Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
309 Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.poll_writable(cx) {
310 Poll::Pending => return Poll::Pending,
311 Poll::Ready(res) => res?,
312 },
313 v => return Poll::Ready(v),
314 }
315 }
316 }
317
318 fn close(&self) -> io::Result<()> {
319 self.get_ref().shutdown(std::net::Shutdown::Both)
320 }
321
322 fn peer_pid(&self) -> io::Result<Option<u32>> {
323 get_unix_pid(self)
324 }
325
326 #[cfg(unix)]
327 fn uid(&self) -> io::Result<Option<u32>> {
328 get_unix_uid(self)
329 }
330
331 #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
332 fn send_zero_byte(&self) -> io::Result<Option<usize>> {
333 send_zero_byte(self).map(Some)
334 }
335}
336
/// [`Socket`] implementation for tokio's unix domain stream, built on the
/// `fd_recvmsg`/`fd_sendmsg` helpers so file descriptors can be passed along with the data.
#[cfg(all(unix, feature = "tokio"))]
impl Socket for tokio::net::UnixStream {
    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        loop {
            let attempt = self.try_io(tokio::io::Interest::READABLE, || {
                fd_recvmsg(self.as_raw_fd(), buf)
            });
            match attempt {
                Ok(received) => return Poll::Ready(Ok(received)),
                Err(e) => match e.kind() {
                    // Interrupted call: retry immediately.
                    io::ErrorKind::Interrupted => {}
                    // Not readable yet: wait for read readiness, then retry.
                    io::ErrorKind::WouldBlock => match self.poll_read_ready(cx) {
                        Poll::Ready(readiness) => readiness?,
                        Poll::Pending => return Poll::Pending,
                    },
                    _ => return Poll::Ready(Err(e)),
                },
            }
        }
    }

    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buffer: &[u8],
        #[cfg(unix)] fds: &[RawFd],
    ) -> Poll<io::Result<usize>> {
        loop {
            let attempt = self.try_io(tokio::io::Interest::WRITABLE, || {
                fd_sendmsg(self.as_raw_fd(), buffer, fds)
            });
            match attempt {
                Ok(written) => return Poll::Ready(Ok(written)),
                Err(e) => match e.kind() {
                    // Interrupted call: retry immediately.
                    io::ErrorKind::Interrupted => {}
                    // Not writable yet: wait for write readiness, then retry.
                    io::ErrorKind::WouldBlock => match self.poll_write_ready(cx) {
                        Poll::Ready(readiness) => readiness?,
                        Poll::Pending => return Poll::Pending,
                    },
                    _ => return Poll::Ready(Err(e)),
                },
            }
        }
    }

    fn close(&self) -> io::Result<()> {
        // FIXME: This should call `tokio::net::UnixStream::poll_shutdown` but this method is not
        // async-friendly. At the next API break, we should fix this.
        Ok(())
    }

    fn peer_pid(&self) -> io::Result<Option<u32>> {
        get_unix_pid(self)
    }

    #[cfg(unix)]
    fn uid(&self) -> io::Result<Option<u32>> {
        get_unix_uid(self)
    }

    #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))]
    fn send_zero_byte(&self) -> io::Result<Option<usize>> {
        let sent = send_zero_byte(self)?;
        Ok(Some(sent))
    }
}
401
/// [`Socket`] implementation for an `async-io` wrapped unix domain stream on Windows.
///
/// File descriptor passing is not available here, so data is moved with plain `read`/`write`
/// calls on the underlying stream.
#[cfg(all(windows, not(feature = "tokio")))]
impl Socket for Async<UnixStream> {
    fn can_pass_unix_fd(&self) -> bool {
        false
    }

    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        loop {
            let outcome = self.get_mut().read(buf);
            match outcome {
                // Nothing to read yet: wait for readiness below, then retry.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                done => return Poll::Ready(done),
            }
            ready!(self.poll_readable(cx))?;
        }
    }

    fn poll_sendmsg(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        loop {
            let outcome = self.get_mut().write(buf);
            match outcome {
                // Not writable yet: wait for readiness below, then retry.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                done => return Poll::Ready(done),
            }
            ready!(self.poll_writable(cx))?;
        }
    }

    fn close(&self) -> io::Result<()> {
        self.get_ref().shutdown(std::net::Shutdown::Both)
    }

    /// Resolve the peer's SID from its PID; any failure along the way yields `None`.
    fn peer_sid(&self) -> Option<String> {
        use crate::win32::ProcessToken;

        if let Ok(Some(pid)) = self.peer_pid() {
            // A PID of 0 is passed to `ProcessToken::open` as `None`.
            if let Ok(process_token) =
                ProcessToken::open(if pid != 0 { Some(pid as _) } else { None })
            {
                return process_token.sid().ok();
            }
        }

        None
    }

    fn peer_pid(&self) -> io::Result<Option<u32>> {
        use crate::win32::unix_stream_get_peer_pid;

        Ok(Some(unix_stream_get_peer_pid(&self.get_ref())? as _))
    }
}
473
474#[cfg(not(feature = "tokio"))]
475impl Socket for Async<TcpStream> {
476 fn can_pass_unix_fd(&self) -> bool {
477 false
478 }
479
480 fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
481 #[cfg(unix)]
482 let fds = vec![];
483
484 loop {
485 match (*self).get_mut().read(buf) {
486 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
487 Err(e) => return Poll::Ready(Err(e)),
488 Ok(len) => {
489 #[cfg(unix)]
490 let ret = (len, fds);
491 #[cfg(not(unix))]
492 let ret = len;
493 return Poll::Ready(Ok(ret));
494 }
495 }
496 ready!(self.poll_readable(cx))?;
497 }
498 }
499
500 fn poll_sendmsg(
501 &mut self,
502 cx: &mut Context<'_>,
503 buf: &[u8],
504 #[cfg(unix)] fds: &[RawFd],
505 ) -> Poll<io::Result<usize>> {
506 #[cfg(unix)]
507 if !fds.is_empty() {
508 return Poll::Ready(Err(io::Error::new(
509 io::ErrorKind::InvalidInput,
510 "fds cannot be sent with a tcp stream",
511 )));
512 }
513
514 loop {
515 match (*self).get_mut().write(buf) {
516 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
517 res => return Poll::Ready(res),
518 }
519 ready!(self.poll_writable(cx))?;
520 }
521 }
522
523 fn close(&self) -> io::Result<()> {
524 self.get_ref().shutdown(std::net::Shutdown::Both)
525 }
526
527 #[cfg(windows)]
528 fn peer_sid(&self) -> Option<String> {
529 use crate::win32::{tcp_stream_get_peer_pid, ProcessToken};
530
531 if let Ok(pid) = tcp_stream_get_peer_pid(&self.get_ref()) {
532 if let Ok(process_token) = ProcessToken::open(if pid != 0 { Some(pid) } else { None }) {
533 return process_token.sid().ok();
534 }
535 }
536
537 None
538 }
539}
540
/// [`Socket`] implementation for tokio's TCP stream.
///
/// TCP cannot carry file descriptors: receiving always yields an empty descriptor list and
/// sending a non-empty one fails with `InvalidInput`.
#[cfg(feature = "tokio")]
impl Socket for tokio::net::TcpStream {
    fn can_pass_unix_fd(&self) -> bool {
        false
    }

    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        use tokio::io::{AsyncRead, ReadBuf};

        let mut read_buf = ReadBuf::new(buf);
        let polled = Pin::new(self).poll_read(cx, &mut read_buf);
        match polled {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {
                // The filled part of the buffer is exactly what this read produced.
                let len = read_buf.filled().len();
                #[cfg(unix)]
                let ret = (len, Vec::new());
                #[cfg(not(unix))]
                let ret = len;
                Poll::Ready(Ok(ret))
            }
        }
    }

    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buf: &[u8],
        #[cfg(unix)] fds: &[RawFd],
    ) -> Poll<io::Result<usize>> {
        use tokio::io::AsyncWrite;

        #[cfg(unix)]
        if !fds.is_empty() {
            let unsupported = io::Error::new(
                io::ErrorKind::InvalidInput,
                "fds cannot be sent with a tcp stream",
            );
            return Poll::Ready(Err(unsupported));
        }

        let stream = Pin::new(self);
        stream.poll_write(cx, buf)
    }

    fn close(&self) -> io::Result<()> {
        // FIXME: This should call `tokio::net::TcpStream::poll_shutdown` but this method is not
        // async-friendly. At the next API break, we should fix this.
        Ok(())
    }

    /// Resolve the peer's SID from the PID owning the peer address; any failure along the way
    /// yields `None`.
    #[cfg(windows)]
    fn peer_sid(&self) -> Option<String> {
        use crate::win32::{socket_addr_get_pid, ProcessToken};

        let peer_addr = self.peer_addr().ok()?;
        let pid = socket_addr_get_pid(&peer_addr).ok()?;
        // A PID of 0 is passed to `ProcessToken::open` as `None`.
        let requested = if pid != 0 { Some(pid) } else { None };
        let process_token = ProcessToken::open(requested).ok()?;
        process_token.sid().ok()
    }
}
605
/// [`Socket`] implementation for an `async-io` wrapped vsock stream.
///
/// File descriptors cannot be passed over this transport: receiving always yields an empty
/// descriptor list and sending a non-empty one fails with `InvalidInput`.
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
impl Socket for Async<vsock::VsockStream> {
    fn can_pass_unix_fd(&self) -> bool {
        false
    }

    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        loop {
            let outcome = self.get_mut().read(buf);
            match outcome {
                Ok(len) => {
                    // On unix the result also carries the (always empty) descriptor list.
                    #[cfg(unix)]
                    let ret = (len, Vec::new());
                    #[cfg(not(unix))]
                    let ret = len;
                    return Poll::Ready(Ok(ret));
                }
                // Nothing to read yet: wait for readiness, then retry.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    ready!(self.poll_readable(cx))?;
                }
                Err(e) => return Poll::Ready(Err(e)),
            }
        }
    }

    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buf: &[u8],
        #[cfg(unix)] fds: &[RawFd],
    ) -> Poll<io::Result<usize>> {
        #[cfg(unix)]
        if !fds.is_empty() {
            // The message names the actual transport (it used to say "tcp stream").
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "fds cannot be sent with a vsock stream",
            )));
        }

        loop {
            let outcome = self.get_mut().write(buf);
            match outcome {
                // Not writable yet: wait for readiness, then retry.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    ready!(self.poll_writable(cx))?;
                }
                done => return Poll::Ready(done),
            }
        }
    }

    fn close(&self) -> io::Result<()> {
        self.get_ref().shutdown(std::net::Shutdown::Both)
    }
}
659
/// [`Socket`] implementation for the `tokio-vsock` stream.
///
/// File descriptors cannot be passed over this transport: receiving always yields an empty
/// descriptor list and sending a non-empty one fails with `InvalidInput`.
#[cfg(feature = "tokio-vsock")]
impl Socket for tokio_vsock::VsockStream {
    fn can_pass_unix_fd(&self) -> bool {
        false
    }

    fn poll_recvmsg(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> PollRecvmsg {
        use tokio::io::{AsyncRead, ReadBuf};

        let mut read_buf = ReadBuf::new(buf);
        let polled = Pin::new(self).poll_read(cx, &mut read_buf);
        match polled {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {
                // The filled part of the buffer is exactly what this read produced.
                let len = read_buf.filled().len();
                #[cfg(unix)]
                let ret = (len, Vec::new());
                #[cfg(not(unix))]
                let ret = len;
                Poll::Ready(Ok(ret))
            }
        }
    }

    fn poll_sendmsg(
        &mut self,
        cx: &mut Context<'_>,
        buf: &[u8],
        #[cfg(unix)] fds: &[RawFd],
    ) -> Poll<io::Result<usize>> {
        use tokio::io::AsyncWrite;

        #[cfg(unix)]
        if !fds.is_empty() {
            // The message names the actual transport (it used to say "tcp stream").
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "fds cannot be sent with a vsock stream",
            )));
        }

        Pin::new(self).poll_write(cx, buf)
    }

    fn close(&self) -> io::Result<()> {
        self.shutdown(std::net::Shutdown::Both)
    }
}
704