1//! Wayland socket manipulation
2
use std::collections::VecDeque;
use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;

use rustix::io::retry_on_intr;
use rustix::net::{
    recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags,
    SendAncillaryBuffer, SendAncillaryMessage, SendFlags,
};

use crate::protocol::{Argument, ArgumentType, Message};

use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError};
18
/// Maximum number of FD that can be sent in a single socket message
///
/// This is also the number of descriptors the receiving side reserves ancillary
/// space for in `Socket::rcv_msg`, so sending more in one socket message risks
/// losing some of them.
pub const MAX_FDS_OUT: usize = 28;
/// Maximum number of bytes that can be sent in a single socket message
///
/// `BufferedSocket` sizes its outgoing buffer to this value and its incoming
/// buffer to twice this value.
pub const MAX_BYTES_OUT: usize = 4096;
23
24/*
25 * Socket
26 */
27
/// A wayland socket
///
/// Thin wrapper around a connected `UnixStream` that sends and receives raw
/// socket messages (bytes plus `SCM_RIGHTS` file descriptors) without blocking.
#[derive(Debug)]
pub struct Socket {
    /// The underlying connected unix stream.
    stream: UnixStream,
}
33
34impl Socket {
35 /// Send a single message to the socket
36 ///
37 /// A single socket message can contain several wayland messages
38 ///
39 /// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
40 /// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
41 /// end may lose some data.
42 pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
43 #[cfg(not(target_os = "macos"))]
44 let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
45 #[cfg(target_os = "macos")]
46 let flags = SendFlags::DONTWAIT;
47
48 if !fds.is_empty() {
49 let iov = [IoSlice::new(bytes)];
50 let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(fds.len()))];
51 let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
52 let fds =
53 unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
54 cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
55 Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?)
56 } else {
57 Ok(retry_on_intr(|| send(self, bytes, flags))?)
58 }
59 }
60
61 /// Receive a single message from the socket
62 ///
63 /// Return the number of bytes received and the number of Fds received.
64 ///
65 /// Errors with `WouldBlock` is no message is available.
66 ///
67 /// A single socket message can contain several wayland messages.
68 ///
69 /// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
70 /// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
71 /// be lost.
72 pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
73 #[cfg(not(target_os = "macos"))]
74 let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC;
75 #[cfg(target_os = "macos")]
76 let flags = RecvFlags::DONTWAIT;
77
78 let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
79 let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
80 let mut iov = [IoSliceMut::new(buffer)];
81 let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?;
82
83 let received_fds = cmsg_buffer
84 .drain()
85 .filter_map(|cmsg| match cmsg {
86 RecvAncillaryMessage::ScmRights(fds) => Some(fds),
87 _ => None,
88 })
89 .flatten();
90 fds.extend(received_fds);
91 #[cfg(target_os = "macos")]
92 for fd in fds.iter() {
93 if let Ok(flags) = rustix::io::fcntl_getfd(fd) {
94 let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC);
95 }
96 }
97 Ok(msg.bytes)
98 }
99}
100
101impl From<UnixStream> for Socket {
102 fn from(stream: UnixStream) -> Self {
103 // macOS doesn't have MSG_NOSIGNAL, but has SO_NOSIGPIPE instead
104 #[cfg(target_os = "macos")]
105 let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true);
106 Self { stream }
107 }
108}
109
110impl AsFd for Socket {
111 fn as_fd(&self) -> BorrowedFd<'_> {
112 self.stream.as_fd()
113 }
114}
115
116impl AsRawFd for Socket {
117 fn as_raw_fd(&self) -> RawFd {
118 self.stream.as_raw_fd()
119 }
120}
121
122/*
123 * BufferedSocket
124 */
125
/// An adapter around a raw Socket that directly handles buffering and
/// conversion from/to wayland messages
#[derive(Debug)]
pub struct BufferedSocket {
    /// The wrapped socket.
    socket: Socket,
    /// Received bytes that have not been parsed into messages yet.
    in_data: Buffer<u8>,
    /// Received file descriptors waiting to be consumed by message parsing.
    in_fds: VecDeque<OwnedFd>,
    /// Serialized messages waiting to be flushed to the socket.
    out_data: Buffer<u8>,
    /// File descriptors to attach to the next flush.
    out_fds: Vec<OwnedFd>,
}
136
137impl BufferedSocket {
138 /// Wrap a Socket into a Buffered Socket
139 pub fn new(socket: Socket) -> Self {
140 Self {
141 socket,
142 in_data: Buffer::new(2 * MAX_BYTES_OUT), // Incoming buffers are twice as big in order to be
143 in_fds: VecDeque::new(), // able to store leftover data if needed
144 out_data: Buffer::new(MAX_BYTES_OUT),
145 out_fds: Vec::new(),
146 }
147 }
148
149 /// Flush the contents of the outgoing buffer into the socket
150 pub fn flush(&mut self) -> IoResult<()> {
151 let written = {
152 let bytes = self.out_data.get_contents();
153 if bytes.is_empty() {
154 return Ok(());
155 }
156 self.socket.send_msg(bytes, &self.out_fds)?
157 };
158 self.out_data.offset(written);
159 self.out_data.move_to_front();
160 self.out_fds.clear();
161 Ok(())
162 }
163
164 // internal method
165 //
166 // attempts to write a message in the internal out buffers,
167 // returns true if successful
168 //
169 // if false is returned, it means there is not enough space
170 // in the buffer
171 fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
172 match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
173 Ok(bytes_out) => {
174 self.out_data.advance(bytes_out);
175 Ok(true)
176 }
177 Err(MessageWriteError::BufferTooSmall) => Ok(false),
178 Err(MessageWriteError::DupFdFailed(e)) => Err(e),
179 }
180 }
181
182 /// Write a message to the outgoing buffer
183 ///
184 /// This method may flush the internal buffer if necessary (if it is full).
185 ///
186 /// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)`
187 /// will be returned.
188 pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> {
189 if !self.attempt_write_message(msg)? {
190 // the attempt failed, there is not enough space in the buffer
191 // we need to flush it
192 if let Err(e) = self.flush() {
193 if e.kind() != ErrorKind::WouldBlock {
194 return Err(e);
195 }
196 }
197 if !self.attempt_write_message(msg)? {
198 // If this fails again, this means the message is too big
199 // to be transmitted at all
200 return Err(rustix::io::Errno::TOOBIG.into());
201 }
202 }
203 Ok(())
204 }
205
206 /// Try to fill the incoming buffers of this socket, to prepare
207 /// a new round of parsing.
208 pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
209 // reorganize the buffers
210 self.in_data.move_to_front();
211 // receive a message
212 let in_bytes = {
213 let bytes = self.in_data.get_writable_storage();
214 self.socket.rcv_msg(bytes, &mut self.in_fds)?
215 };
216 if in_bytes == 0 {
217 // the other end of the socket was closed
218 return Err(rustix::io::Errno::PIPE.into());
219 }
220 // advance the storage
221 self.in_data.advance(in_bytes);
222 Ok(())
223 }
224
225 /// Read and deserialize a single message from the incoming buffers socket
226 ///
227 /// This method requires one closure that given an object id and an opcode,
228 /// must provide the signature of the associated request/event, in the form of
229 /// a `&'static [ArgumentType]`.
230 pub fn read_one_message<F>(
231 &mut self,
232 mut signature: F,
233 ) -> Result<Message<u32, OwnedFd>, MessageParseError>
234 where
235 F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
236 {
237 let (msg, read_data) = {
238 let data = self.in_data.get_contents();
239 if data.len() < 2 * 4 {
240 return Err(MessageParseError::MissingData);
241 }
242 let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]);
243 let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
244 let opcode = (word_2 & 0x0000_FFFF) as u16;
245 if let Some(sig) = signature(object_id, opcode) {
246 match parse_message(data, sig, &mut self.in_fds) {
247 Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
248 Err(e) => return Err(e),
249 }
250 } else {
251 // no signature found ?
252 return Err(MessageParseError::Malformed);
253 }
254 };
255
256 self.in_data.offset(read_data);
257
258 Ok(msg)
259 }
260}
261
262impl AsRawFd for BufferedSocket {
263 fn as_raw_fd(&self) -> RawFd {
264 self.socket.as_raw_fd()
265 }
266}
267
268impl AsFd for BufferedSocket {
269 fn as_fd(&self) -> BorrowedFd<'_> {
270 self.socket.as_fd()
271 }
272}
273
274/*
275 * Buffer
276 */
/// Fixed-capacity linear buffer with a read cursor (`offset`) and a write
/// cursor (`occupied`); `storage[offset..occupied]` holds the unread contents.
#[derive(Debug)]
struct Buffer<T: Copy> {
    /// Backing storage; its length is the capacity and never changes.
    storage: Vec<T>,
    /// Index one past the last written element.
    occupied: usize,
    /// Index of the first element not consumed yet.
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Create a buffer with room for `size` elements, all set to `T::default()`.
    fn new(size: usize) -> Self {
        let storage = vec![T::default(); size];
        Buffer { storage, occupied: 0, offset: 0 }
    }

    /// Mark `bytes` more elements of the writable storage as written.
    fn advance(&mut self, bytes: usize) {
        self.occupied += bytes;
    }

    /// Mark `bytes` more elements of the contents as consumed.
    fn offset(&mut self, bytes: usize) {
        self.offset += bytes;
    }

    /// Forget all contents.
    ///
    /// Only the cursors are reset; the storage itself is left as is and will
    /// simply be overwritten.
    // Not called from this module, kept for completeness.
    #[allow(unused)]
    fn clear(&mut self) {
        self.offset = 0;
        self.occupied = 0;
    }

    /// The written but not yet consumed elements.
    fn get_contents(&self) -> &[T] {
        let (start, end) = (self.offset, self.occupied);
        &self.storage[start..end]
    }

    /// The free tail of the storage, after everything written so far.
    fn get_writable_storage(&mut self) -> &mut [T] {
        let start = self.occupied;
        &mut self.storage[start..]
    }

    /// Shift the unread contents to the start of the storage, so that the
    /// whole remaining capacity becomes writable again.
    fn move_to_front(&mut self) {
        let (start, end) = (self.offset, self.occupied);
        if start < end {
            self.storage.copy_within(start..end, 0);
        }
        self.occupied = end - start;
        self.offset = 0;
    }
}
329
// Round-trip tests: messages written on one end of a socket pair are read back
// on the other end and compared with the originals.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

    use std::ffi::CString;
    use std::os::unix::io::BorrowedFd;
    use std::os::unix::prelude::IntoRawFd;

    use smallvec::smallvec;

    // Returns true if both descriptors refer to the same file (same device and
    // inode), regardless of the descriptor numbers.
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let stat1 = rustix::fs::fstat(a).unwrap();
        let stat2 = rustix::fs::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    // check if two messages are equal
    //
    // if arguments contain FDs, check that the fd point to
    // the same file, rather than are the same number.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (Argument::Fd(fd1), Argument::Fd(fd2)) = (arg1, arg2) {
                let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    // Single message with every non-fd argument type.
    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();

        // NOTE: `into_raw_fd` gives up ownership, so any received descriptors
        // are never closed; acceptable in a short-lived test.
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    // Single message carrying two file descriptors.
    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    // Several messages sent in one flush, with descriptors spread across them.
    #[test]
    fn write_read_cycle_multiple() {
        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        // Indexed by opcode.
        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        server.fill_incoming_buffers().unwrap();

        // Read until the buffer runs dry (first parse error ends the loop).
        let mut recv_msgs = Vec::new();
        while let Ok(message) = server.read_one_message(|sender_id, opcode| {
            if sender_id == 42 {
                Some(SIGNATURES[opcode as usize])
            } else {
                None
            }
        }) {
            recv_msgs.push(message);
        }
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.into_iter().zip(recv_msgs.into_iter()) {
            assert_eq_msgs(&msg1.map_fd(|fd| fd.as_raw_fd()), &msg2.map_fd(IntoRawFd::into_raw_fd));
        }
    }

    // Strings whose length (with NUL) is a multiple of 4 need no padding; make
    // sure parsing still lines up with the following argument.
    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];

        server.fill_incoming_buffers().unwrap();

        let ret_msg =
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 2 && opcode == 0 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                })
                .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }
}
553