1use std::{
2 ffi::CString,
3 os::unix::io::OwnedFd,
4 os::unix::{io::RawFd, net::UnixStream},
5 sync::{Arc, Mutex, Weak},
6};
7
8use crate::{
9 protocol::{same_interface, Interface, Message, ObjectInfo, ANONYMOUS_INTERFACE},
10 types::server::{DisconnectReason, GlobalInfo, InvalidId},
11};
12
13use super::{
14 client::ClientStore, registry::Registry, ClientData, ClientId, Credentials, GlobalHandler,
15 InnerClientId, InnerGlobalId, InnerObjectId, ObjectData, ObjectId,
16};
17
18pub(crate) type PendingDestructor<D> = (Arc<dyn ObjectData<D>>, InnerClientId, InnerObjectId);
19
20#[derive(Debug)]
21pub struct State<D: 'static> {
22 pub(crate) clients: ClientStore<D>,
23 pub(crate) registry: Registry<D>,
24 pub(crate) pending_destructors: Vec<PendingDestructor<D>>,
25 pub(crate) poll_fd: OwnedFd,
26}
27
28impl<D> State<D> {
29 pub(crate) fn new(poll_fd: OwnedFd) -> Self {
30 let debug =
31 matches!(std::env::var_os("WAYLAND_DEBUG"), Some(str) if str == "1" || str == "server");
32 Self {
33 clients: ClientStore::new(debug),
34 registry: Registry::new(),
35 pending_destructors: Vec::new(),
36 poll_fd,
37 }
38 }
39
40 pub(crate) fn cleanup<'a>(&mut self) -> impl FnOnce(&super::Handle, &mut D) + 'a {
41 let dead_clients = self.clients.cleanup(&mut self.pending_destructors);
42 self.registry.cleanup(&dead_clients);
43 // return a closure that will do the cleanup once invoked
44 let pending_destructors = std::mem::take(&mut self.pending_destructors);
45 move |handle, data| {
46 for (object_data, client_id, object_id) in pending_destructors {
47 object_data.clone().destroyed(
48 handle,
49 data,
50 ClientId { id: client_id },
51 ObjectId { id: object_id },
52 );
53 }
54 }
55 }
56
57 pub(crate) fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
58 if let Some(ClientId { id: client }) = client {
59 match self.clients.get_client_mut(client) {
60 Ok(client) => client.flush(),
61 Err(InvalidId) => Ok(()),
62 }
63 } else {
64 for client in self.clients.clients_mut() {
65 let _ = client.flush();
66 }
67 Ok(())
68 }
69 }
70}
71
72#[derive(Clone)]
73pub struct InnerHandle {
74 pub(crate) state: Arc<Mutex<dyn ErasedState + Send>>,
75}
76
77impl std::fmt::Debug for InnerHandle {
78 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 fmt.debug_struct(name:"InnerHandle[rs]").finish_non_exhaustive()
80 }
81}
82
83#[derive(Clone)]
84pub struct WeakInnerHandle {
85 pub(crate) state: Weak<Mutex<dyn ErasedState + Send>>,
86}
87
88impl std::fmt::Debug for WeakInnerHandle {
89 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 fmt.debug_struct(name:"WeakInnerHandle[rs]").finish_non_exhaustive()
91 }
92}
93
94impl WeakInnerHandle {
95 pub fn upgrade(&self) -> Option<InnerHandle> {
96 self.state.upgrade().map(|state: Arc>| InnerHandle { state })
97 }
98}
99
100impl InnerHandle {
101 pub fn downgrade(&self) -> WeakInnerHandle {
102 WeakInnerHandle { state: Arc::downgrade(&self.state) }
103 }
104
105 pub fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
106 self.state.lock().unwrap().object_info(id)
107 }
108
109 pub fn insert_client(
110 &self,
111 stream: UnixStream,
112 data: Arc<dyn ClientData>,
113 ) -> std::io::Result<InnerClientId> {
114 self.state.lock().unwrap().insert_client(stream, data)
115 }
116
117 pub fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
118 self.state.lock().unwrap().get_client(id)
119 }
120
121 pub fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
122 self.state.lock().unwrap().get_client_data(id)
123 }
124
125 pub fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
126 self.state.lock().unwrap().get_client_credentials(id)
127 }
128
129 pub fn with_all_clients(&self, mut f: impl FnMut(ClientId)) {
130 self.state.lock().unwrap().with_all_clients(&mut f)
131 }
132
133 pub fn with_all_objects_for(
134 &self,
135 client_id: InnerClientId,
136 mut f: impl FnMut(ObjectId),
137 ) -> Result<(), InvalidId> {
138 self.state.lock().unwrap().with_all_objects_for(client_id, &mut f)
139 }
140
141 pub fn object_for_protocol_id(
142 &self,
143 client_id: InnerClientId,
144 interface: &'static Interface,
145 protocol_id: u32,
146 ) -> Result<ObjectId, InvalidId> {
147 self.state.lock().unwrap().object_for_protocol_id(client_id, interface, protocol_id)
148 }
149
150 pub fn create_object<D: 'static>(
151 &self,
152 client_id: InnerClientId,
153 interface: &'static Interface,
154 version: u32,
155 data: Arc<dyn ObjectData<D>>,
156 ) -> Result<ObjectId, InvalidId> {
157 let mut state = self.state.lock().unwrap();
158 let state = (&mut *state as &mut dyn ErasedState)
159 .downcast_mut::<State<D>>()
160 .expect("Wrong type parameter passed to Handle::create_object().");
161 let client = state.clients.get_client_mut(client_id)?;
162 Ok(ObjectId { id: client.create_object(interface, version, data) })
163 }
164
165 pub fn null_id() -> ObjectId {
166 ObjectId {
167 id: InnerObjectId {
168 id: 0,
169 serial: 0,
170 client_id: InnerClientId { id: 0, serial: 0 },
171 interface: &ANONYMOUS_INTERFACE,
172 },
173 }
174 }
175
176 pub fn send_event(&self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
177 self.state.lock().unwrap().send_event(msg)
178 }
179
180 pub fn get_object_data<D: 'static>(
181 &self,
182 id: InnerObjectId,
183 ) -> Result<Arc<dyn ObjectData<D>>, InvalidId> {
184 let mut state = self.state.lock().unwrap();
185 let state = (&mut *state as &mut dyn ErasedState)
186 .downcast_mut::<State<D>>()
187 .expect("Wrong type parameter passed to Handle::get_object_data().");
188 state.clients.get_client(id.client_id.clone())?.get_object_data(id)
189 }
190
191 pub fn get_object_data_any(
192 &self,
193 id: InnerObjectId,
194 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
195 self.state.lock().unwrap().get_object_data_any(id)
196 }
197
198 pub fn set_object_data<D: 'static>(
199 &self,
200 id: InnerObjectId,
201 data: Arc<dyn ObjectData<D>>,
202 ) -> Result<(), InvalidId> {
203 let mut state = self.state.lock().unwrap();
204 let state = (&mut *state as &mut dyn ErasedState)
205 .downcast_mut::<State<D>>()
206 .expect("Wrong type parameter passed to Handle::set_object_data().");
207 state.clients.get_client_mut(id.client_id.clone())?.set_object_data(id, data)
208 }
209
210 pub fn post_error(&self, object_id: InnerObjectId, error_code: u32, message: CString) {
211 self.state.lock().unwrap().post_error(object_id, error_code, message)
212 }
213
214 pub fn kill_client(&self, client_id: InnerClientId, reason: DisconnectReason) {
215 self.state.lock().unwrap().kill_client(client_id, reason)
216 }
217
218 pub fn create_global<D: 'static>(
219 &self,
220 interface: &'static Interface,
221 version: u32,
222 handler: Arc<dyn GlobalHandler<D>>,
223 ) -> InnerGlobalId {
224 let mut state = self.state.lock().unwrap();
225 let state = (&mut *state as &mut dyn ErasedState)
226 .downcast_mut::<State<D>>()
227 .expect("Wrong type parameter passed to Handle::create_global().");
228 state.registry.create_global(interface, version, handler, &mut state.clients)
229 }
230
231 pub fn disable_global<D: 'static>(&self, id: InnerGlobalId) {
232 let mut state = self.state.lock().unwrap();
233 let state = (&mut *state as &mut dyn ErasedState)
234 .downcast_mut::<State<D>>()
235 .expect("Wrong type parameter passed to Handle::create_global().");
236
237 state.registry.disable_global(id, &mut state.clients)
238 }
239
240 pub fn remove_global<D: 'static>(&self, id: InnerGlobalId) {
241 let mut state = self.state.lock().unwrap();
242 let state = (&mut *state as &mut dyn ErasedState)
243 .downcast_mut::<State<D>>()
244 .expect("Wrong type parameter passed to Handle::create_global().");
245
246 state.registry.remove_global(id, &mut state.clients)
247 }
248
249 pub fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
250 self.state.lock().unwrap().global_info(id)
251 }
252
253 pub fn get_global_handler<D: 'static>(
254 &self,
255 id: InnerGlobalId,
256 ) -> Result<Arc<dyn GlobalHandler<D>>, InvalidId> {
257 let mut state = self.state.lock().unwrap();
258 let state = (&mut *state as &mut dyn ErasedState)
259 .downcast_mut::<State<D>>()
260 .expect("Wrong type parameter passed to Handle::get_global_handler().");
261 state.registry.get_handler(id)
262 }
263
264 pub fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
265 self.state.lock().unwrap().flush(client)
266 }
267}
268
269pub(crate) trait ErasedState: downcast_rs::Downcast {
270 fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId>;
271 fn insert_client(
272 &mut self,
273 stream: UnixStream,
274 data: Arc<dyn ClientData>,
275 ) -> std::io::Result<InnerClientId>;
276 fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId>;
277 fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId>;
278 fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId>;
279 fn with_all_clients(&self, f: &mut dyn FnMut(ClientId));
280 fn with_all_objects_for(
281 &self,
282 client_id: InnerClientId,
283 f: &mut dyn FnMut(ObjectId),
284 ) -> Result<(), InvalidId>;
285 fn object_for_protocol_id(
286 &self,
287 client_id: InnerClientId,
288 interface: &'static Interface,
289 protocol_id: u32,
290 ) -> Result<ObjectId, InvalidId>;
291 fn get_object_data_any(
292 &self,
293 id: InnerObjectId,
294 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId>;
295 fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId>;
296 fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString);
297 fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason);
298 fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId>;
299 fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()>;
300}
301
302downcast_rs::impl_downcast!(ErasedState);
303
304impl<D> ErasedState for State<D> {
305 fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
306 self.clients.get_client(id.client_id.clone())?.object_info(id)
307 }
308
309 fn insert_client(
310 &mut self,
311 stream: UnixStream,
312 data: Arc<dyn ClientData>,
313 ) -> std::io::Result<InnerClientId> {
314 let id = self.clients.create_client(stream, data);
315 let client = self.clients.get_client(id.clone()).unwrap();
316
317 // register the client to the internal epoll
318 #[cfg(any(target_os = "linux", target_os = "android"))]
319 let ret = {
320 use rustix::event::epoll;
321 epoll::add(
322 &self.poll_fd,
323 client,
324 epoll::EventData::new_u64(id.as_u64()),
325 epoll::EventFlags::IN,
326 )
327 };
328
329 #[cfg(any(
330 target_os = "dragonfly",
331 target_os = "freebsd",
332 target_os = "netbsd",
333 target_os = "openbsd",
334 target_os = "macos"
335 ))]
336 let ret = {
337 use rustix::event::kqueue::*;
338 use std::os::unix::io::{AsFd, AsRawFd};
339
340 let evt = Event::new(
341 EventFilter::Read(client.as_fd().as_raw_fd()),
342 EventFlags::ADD | EventFlags::RECEIPT,
343 id.as_u64() as isize,
344 );
345
346 let mut events = Vec::new();
347 unsafe { kevent(&self.poll_fd, &[evt], &mut events, None).map(|_| ()) }
348 };
349
350 match ret {
351 Ok(()) => Ok(id),
352 Err(e) => {
353 self.kill_client(id, DisconnectReason::ConnectionClosed);
354 Err(e.into())
355 }
356 }
357 }
358
359 fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
360 if self.clients.get_client(id.client_id.clone()).is_ok() {
361 Ok(ClientId { id: id.client_id })
362 } else {
363 Err(InvalidId)
364 }
365 }
366
367 fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
368 let client = self.clients.get_client(id)?;
369 Ok(client.data.clone())
370 }
371
372 fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
373 let client = self.clients.get_client(id)?;
374 Ok(client.get_credentials())
375 }
376
377 fn with_all_clients(&self, f: &mut dyn FnMut(ClientId)) {
378 for client in self.clients.all_clients_id() {
379 f(client)
380 }
381 }
382
383 fn with_all_objects_for(
384 &self,
385 client_id: InnerClientId,
386 f: &mut dyn FnMut(ObjectId),
387 ) -> Result<(), InvalidId> {
388 let client = self.clients.get_client(client_id)?;
389 for object in client.all_objects() {
390 f(object)
391 }
392 Ok(())
393 }
394
395 fn object_for_protocol_id(
396 &self,
397 client_id: InnerClientId,
398 interface: &'static Interface,
399 protocol_id: u32,
400 ) -> Result<ObjectId, InvalidId> {
401 let client = self.clients.get_client(client_id)?;
402 let object = client.object_for_protocol_id(protocol_id)?;
403 if same_interface(interface, object.interface) {
404 Ok(ObjectId { id: object })
405 } else {
406 Err(InvalidId)
407 }
408 }
409
410 fn get_object_data_any(
411 &self,
412 id: InnerObjectId,
413 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
414 self.clients
415 .get_client(id.client_id.clone())?
416 .get_object_data(id)
417 .map(|arc| arc.into_any_arc())
418 }
419
420 fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
421 self.clients
422 .get_client_mut(msg.sender_id.id.client_id.clone())?
423 .send_event(msg, Some(&mut self.pending_destructors))
424 }
425
426 fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString) {
427 if let Ok(client) = self.clients.get_client_mut(object_id.client_id.clone()) {
428 client.post_error(object_id, error_code, message)
429 }
430 }
431
432 fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason) {
433 if let Ok(client) = self.clients.get_client_mut(client_id) {
434 client.kill(reason)
435 }
436 }
437 fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
438 self.registry.get_info(id)
439 }
440
441 fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
442 self.flush(client)
443 }
444}
445