1 | //! Wayland socket manipulation |
2 | |
3 | use std::collections::VecDeque; |
4 | use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult}; |
5 | use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}; |
6 | use std::os::unix::net::UnixStream; |
7 | use std::slice; |
8 | |
9 | use rustix::io::retry_on_intr; |
10 | use rustix::net::{ |
11 | recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, |
12 | SendAncillaryBuffer, SendAncillaryMessage, SendFlags, |
13 | }; |
14 | |
15 | use crate::protocol::{ArgumentType, Message}; |
16 | |
17 | use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError}; |
18 | |
/// Upper bound on how many file descriptors travel in one socket message.
///
/// [`Socket::rcv_msg`] reserves ancillary-data space for exactly this many.
pub const MAX_FDS_OUT: usize = 28;
/// Upper bound on how many bytes travel in one socket message.
pub const MAX_BYTES_OUT: usize = 4096;
23 | |
24 | /* |
25 | * Socket |
26 | */ |
27 | |
/// A wayland socket: a thin wrapper around a [`UnixStream`] that exchanges raw
/// socket messages, optionally carrying file descriptors.
#[derive(Debug)]
pub struct Socket {
    stream: UnixStream,
}
33 | |
34 | impl Socket { |
35 | /// Send a single message to the socket |
36 | /// |
37 | /// A single socket message can contain several wayland messages |
38 | /// |
39 | /// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes` |
40 | /// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving |
41 | /// end may lose some data. |
42 | pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> { |
43 | #[cfg (not(target_os = "macos" ))] |
44 | let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL; |
45 | #[cfg (target_os = "macos" )] |
46 | let flags = SendFlags::DONTWAIT; |
47 | |
48 | if !fds.is_empty() { |
49 | let iov = [IoSlice::new(bytes)]; |
50 | let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(fds.len()))]; |
51 | let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space); |
52 | let fds = |
53 | unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) }; |
54 | cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds)); |
55 | Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?) |
56 | } else { |
57 | Ok(retry_on_intr(|| send(self, bytes, flags))?) |
58 | } |
59 | } |
60 | |
61 | /// Receive a single message from the socket |
62 | /// |
63 | /// Return the number of bytes received and the number of Fds received. |
64 | /// |
65 | /// Errors with `WouldBlock` is no message is available. |
66 | /// |
67 | /// A single socket message can contain several wayland messages. |
68 | /// |
69 | /// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds` |
70 | /// slice `MAX_FDS_OUT` long, otherwise some data of the received message may |
71 | /// be lost. |
72 | pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> { |
73 | #[cfg (not(target_os = "macos" ))] |
74 | let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC; |
75 | #[cfg (target_os = "macos" )] |
76 | let flags = RecvFlags::DONTWAIT; |
77 | |
78 | let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))]; |
79 | let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space); |
80 | let mut iov = [IoSliceMut::new(buffer)]; |
81 | let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?; |
82 | |
83 | let received_fds = cmsg_buffer |
84 | .drain() |
85 | .filter_map(|cmsg| match cmsg { |
86 | RecvAncillaryMessage::ScmRights(fds) => Some(fds), |
87 | _ => None, |
88 | }) |
89 | .flatten(); |
90 | fds.extend(received_fds); |
91 | #[cfg (target_os = "macos" )] |
92 | for fd in fds.iter() { |
93 | if let Ok(flags) = rustix::io::fcntl_getfd(fd) { |
94 | let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC); |
95 | } |
96 | } |
97 | Ok(msg.bytes) |
98 | } |
99 | } |
100 | |
101 | impl From<UnixStream> for Socket { |
102 | fn from(stream: UnixStream) -> Self { |
103 | // macOS doesn't have MSG_NOSIGNAL, but has SO_NOSIGPIPE instead |
104 | #[cfg (target_os = "macos" )] |
105 | let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true); |
106 | Self { stream } |
107 | } |
108 | } |
109 | |
110 | impl AsFd for Socket { |
111 | fn as_fd(&self) -> BorrowedFd<'_> { |
112 | self.stream.as_fd() |
113 | } |
114 | } |
115 | |
116 | impl AsRawFd for Socket { |
117 | fn as_raw_fd(&self) -> RawFd { |
118 | self.stream.as_raw_fd() |
119 | } |
120 | } |
121 | |
122 | /* |
123 | * BufferedSocket |
124 | */ |
125 | |
126 | /// An adapter around a raw Socket that directly handles buffering and |
127 | /// conversion from/to wayland messages |
128 | #[derive (Debug)] |
129 | pub struct BufferedSocket { |
130 | socket: Socket, |
131 | in_data: Buffer<u8>, |
132 | in_fds: VecDeque<OwnedFd>, |
133 | out_data: Buffer<u8>, |
134 | out_fds: Vec<OwnedFd>, |
135 | } |
136 | |
137 | impl BufferedSocket { |
138 | /// Wrap a Socket into a Buffered Socket |
139 | pub fn new(socket: Socket) -> Self { |
140 | Self { |
141 | socket, |
142 | in_data: Buffer::new(2 * MAX_BYTES_OUT), // Incoming buffers are twice as big in order to be |
143 | in_fds: VecDeque::new(), // able to store leftover data if needed |
144 | out_data: Buffer::new(MAX_BYTES_OUT), |
145 | out_fds: Vec::new(), |
146 | } |
147 | } |
148 | |
149 | /// Flush the contents of the outgoing buffer into the socket |
150 | pub fn flush(&mut self) -> IoResult<()> { |
151 | let written = { |
152 | let bytes = self.out_data.get_contents(); |
153 | if bytes.is_empty() { |
154 | return Ok(()); |
155 | } |
156 | self.socket.send_msg(bytes, &self.out_fds)? |
157 | }; |
158 | self.out_data.offset(written); |
159 | self.out_data.move_to_front(); |
160 | self.out_fds.clear(); |
161 | Ok(()) |
162 | } |
163 | |
164 | // internal method |
165 | // |
166 | // attempts to write a message in the internal out buffers, |
167 | // returns true if successful |
168 | // |
169 | // if false is returned, it means there is not enough space |
170 | // in the buffer |
171 | fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> { |
172 | match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) { |
173 | Ok(bytes_out) => { |
174 | self.out_data.advance(bytes_out); |
175 | Ok(true) |
176 | } |
177 | Err(MessageWriteError::BufferTooSmall) => Ok(false), |
178 | Err(MessageWriteError::DupFdFailed(e)) => Err(e), |
179 | } |
180 | } |
181 | |
182 | /// Write a message to the outgoing buffer |
183 | /// |
184 | /// This method may flush the internal buffer if necessary (if it is full). |
185 | /// |
186 | /// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)` |
187 | /// will be returned. |
188 | pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> { |
189 | if !self.attempt_write_message(msg)? { |
190 | // the attempt failed, there is not enough space in the buffer |
191 | // we need to flush it |
192 | if let Err(e) = self.flush() { |
193 | if e.kind() != ErrorKind::WouldBlock { |
194 | return Err(e); |
195 | } |
196 | } |
197 | if !self.attempt_write_message(msg)? { |
198 | // If this fails again, this means the message is too big |
199 | // to be transmitted at all |
200 | return Err(rustix::io::Errno::TOOBIG.into()); |
201 | } |
202 | } |
203 | Ok(()) |
204 | } |
205 | |
206 | /// Try to fill the incoming buffers of this socket, to prepare |
207 | /// a new round of parsing. |
208 | pub fn fill_incoming_buffers(&mut self) -> IoResult<()> { |
209 | // reorganize the buffers |
210 | self.in_data.move_to_front(); |
211 | // receive a message |
212 | let in_bytes = { |
213 | let bytes = self.in_data.get_writable_storage(); |
214 | self.socket.rcv_msg(bytes, &mut self.in_fds)? |
215 | }; |
216 | if in_bytes == 0 { |
217 | // the other end of the socket was closed |
218 | return Err(rustix::io::Errno::PIPE.into()); |
219 | } |
220 | // advance the storage |
221 | self.in_data.advance(in_bytes); |
222 | Ok(()) |
223 | } |
224 | |
225 | /// Read and deserialize a single message from the incoming buffers socket |
226 | /// |
227 | /// This method requires one closure that given an object id and an opcode, |
228 | /// must provide the signature of the associated request/event, in the form of |
229 | /// a `&'static [ArgumentType]`. |
230 | pub fn read_one_message<F>( |
231 | &mut self, |
232 | mut signature: F, |
233 | ) -> Result<Message<u32, OwnedFd>, MessageParseError> |
234 | where |
235 | F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>, |
236 | { |
237 | let (msg, read_data) = { |
238 | let data = self.in_data.get_contents(); |
239 | if data.len() < 2 * 4 { |
240 | return Err(MessageParseError::MissingData); |
241 | } |
242 | let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]); |
243 | let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]); |
244 | let opcode = (word_2 & 0x0000_FFFF) as u16; |
245 | if let Some(sig) = signature(object_id, opcode) { |
246 | match parse_message(data, sig, &mut self.in_fds) { |
247 | Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()), |
248 | Err(e) => return Err(e), |
249 | } |
250 | } else { |
251 | // no signature found ? |
252 | return Err(MessageParseError::Malformed); |
253 | } |
254 | }; |
255 | |
256 | self.in_data.offset(read_data); |
257 | |
258 | Ok(msg) |
259 | } |
260 | } |
261 | |
262 | impl AsRawFd for BufferedSocket { |
263 | fn as_raw_fd(&self) -> RawFd { |
264 | self.socket.as_raw_fd() |
265 | } |
266 | } |
267 | |
268 | impl AsFd for BufferedSocket { |
269 | fn as_fd(&self) -> BorrowedFd<'_> { |
270 | self.socket.as_fd() |
271 | } |
272 | } |
273 | |
274 | /* |
275 | * Buffer |
276 | */ |
/// A fixed-capacity buffer with a read cursor.
///
/// Invariant: `offset <= occupied <= storage.len()`. The range `offset..occupied` holds
/// data that has been written but not yet consumed; `occupied..` is free space.
#[derive(Debug)]
struct Buffer<T> {
    storage: Vec<T>,
    occupied: usize,
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Create an empty buffer with room for `size` elements.
    fn new(size: usize) -> Self {
        Self { storage: vec![T::default(); size], occupied: 0, offset: 0 }
    }

    /// Advance the internal counter of occupied space, after `bytes` elements of the
    /// slice returned by `get_writable_storage` have been filled.
    fn advance(&mut self, bytes: usize) {
        debug_assert!(
            self.occupied + bytes <= self.storage.len(),
            "Buffer::advance past the end of the storage"
        );
        self.occupied += bytes;
    }

    /// Advance the read offset, consuming `bytes` elements of the occupied space.
    fn offset(&mut self, bytes: usize) {
        debug_assert!(
            self.offset + bytes <= self.occupied,
            "Buffer::offset past the occupied space"
        );
        self.offset += bytes;
    }

    /// Clears the contents of the buffer.
    ///
    /// This only resets the counters, allowing previous content to be overwritten.
    // Nothing in this file calls it; kept as part of the buffer's small API.
    #[allow(unused)]
    fn clear(&mut self) {
        self.occupied = 0;
        self.offset = 0;
    }

    /// Get the unread part of the occupied space.
    fn get_contents(&self) -> &[T] {
        &self.storage[self.offset..self.occupied]
    }

    /// Get mutable access to the unoccupied space at the end of the buffer.
    fn get_writable_storage(&mut self) -> &mut [T] {
        &mut self.storage[self.occupied..]
    }

    /// Move the unread contents to the front of the storage, so that the whole rest of
    /// the buffer becomes writable.
    fn move_to_front(&mut self) {
        if self.occupied > self.offset {
            self.storage.copy_within(self.offset..self.occupied, 0);
        }
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
329 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

    use std::ffi::CString;
    use std::os::unix::io::BorrowedFd;
    use std::os::unix::prelude::IntoRawFd;

    use smallvec::smallvec;

    /// Whether two descriptors refer to the same underlying file (same device and inode).
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let stat1 = rustix::fs::fstat(a).unwrap();
        let stat2 = rustix::fs::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    // check if two messages are equal
    //
    // if arguments contain FDs, check that the fd point to
    // the same file, rather than are the same number.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (Argument::Fd(fd1), Argument::Fd(fd2)) = (arg1, arg2) {
                let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        server.fill_incoming_buffers().unwrap();

        let ret_msg = server
            .read_one_message(|sender_id, opcode| {
                if sender_id == 42 && opcode == 7 {
                    Some(SIGNATURE)
                } else {
                    None
                }
            })
            .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        server.fill_incoming_buffers().unwrap();

        let ret_msg = server
            .read_one_message(|sender_id, opcode| {
                if sender_id == 42 && opcode == 7 {
                    Some(SIGNATURE)
                } else {
                    None
                }
            })
            .unwrap();
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_multiple() {
        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        server.fill_incoming_buffers().unwrap();

        // Parse until the buffer is drained; any other error is a test failure rather than
        // a silent early stop.
        let mut recv_msgs = Vec::new();
        loop {
            let parsed = server.read_one_message(|sender_id, opcode| {
                if sender_id == 42 {
                    SIGNATURES.get(opcode as usize).copied()
                } else {
                    None
                }
            });
            match parsed {
                Ok(message) => recv_msgs.push(message),
                Err(MessageParseError::MissingData) => break,
                Err(e) => panic!("unexpected parse error: {:?}", e),
            }
        }
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.into_iter().zip(recv_msgs.into_iter()) {
            assert_eq_msgs(&msg1.map_fd(|fd| fd.as_raw_fd()), &msg2.map_fd(IntoRawFd::into_raw_fd));
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };

        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        let mut client = BufferedSocket::new(Socket::from(client));
        let mut server = BufferedSocket::new(Socket::from(server));

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];

        server.fill_incoming_buffers().unwrap();

        let ret_msg = server
            .read_one_message(|sender_id, opcode| {
                if sender_id == 2 && opcode == 0 {
                    Some(SIGNATURE)
                } else {
                    None
                }
            })
            .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }
}
553 | |