1use {
4 crate::{
5 acceptor::{Acceptor, AcceptorError},
6 baseline::Baseline,
7 client::Client,
8 endpoint::{Endpoint, EndpointError},
9 handler::HandlerHolder,
10 object::{Object, ObjectCoreApi, ObjectPrivate},
11 poll::{self, PollError, PollEvent, Poller},
12 protocols::wayland::wl_display::WlDisplay,
13 trans::{FlushResult, TransError},
14 utils::{
15 env::{WAYLAND_DISPLAY, WAYLAND_SOCKET, XDG_RUNTIME_DIR},
16 stack::Stack,
17 stash::Stash,
18 },
19 },
20 error_reporter::Report,
21 run_on_drop::on_drop,
22 std::{
23 cell::{Cell, RefCell},
24 collections::HashMap,
25 io::{self, pipe},
26 os::fd::{AsFd, OwnedFd},
27 rc::{Rc, Weak},
28 sync::{
29 Arc,
30 atomic::{AtomicBool, Ordering::Acquire},
31 },
32 time::Duration,
33 },
34 thiserror::Error,
35 uapi::c,
36};
37pub use {
38 builder::StateBuilder,
39 destructor::{Destructor, RemoteDestructor},
40};
41
42mod builder;
43mod destructor;
44#[cfg(test)]
45mod tests;
46
/// An error returned by [`State`] operations.
///
/// The concrete cause is private; its `Display` output and `source` are forwarded
/// transparently. Use [`StateError::is_destroyed`] to check whether the operation failed
/// because the state had already been destroyed.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct StateError(#[from] StateErrorKind);
51
52#[derive(Debug, Error)]
53enum StateErrorKind {
54 #[error("the state has already been destroyed")]
55 Destroyed,
56 #[error("the state has been destroyed by a remote destructor")]
57 RemoteDestroyed,
58 #[error("cannot perform recursive call into the state")]
59 RecursiveCall,
60 #[error("the server hung up the connection")]
61 ServerHangup,
62 #[error("could not write to the server socket")]
63 WriteToServer(#[source] EndpointError),
64 #[error("could not dispatch server events")]
65 DispatchEvents(#[source] EndpointError),
66 #[error("could not create a socket pair")]
67 Socketpair(#[source] io::Error),
68 #[error(transparent)]
69 CreateAcceptor(AcceptorError),
70 #[error("could not accept a new connection")]
71 AcceptConnection(AcceptorError),
72 #[error("could not create a pipe")]
73 CreatePipe(#[source] io::Error),
74 #[error("could not read {} environment variable", WAYLAND_DISPLAY)]
75 WaylandDisplay,
76 #[error("the display name is empty")]
77 WaylandDisplayEmpty,
78 #[error("{} is not set", XDG_RUNTIME_DIR)]
79 XrdNotSet,
80 #[error("the socket path is too long")]
81 SocketPathTooLong,
82 #[error("could not create a socket")]
83 CreateSocket(#[source] io::Error),
84 #[error("could not connect to {0}")]
85 Connect(String, #[source] io::Error),
86 #[error("{} does not contain a valid number", WAYLAND_SOCKET)]
87 WaylandSocketNotNumber,
88 #[error("F_GETFD failed on {}", WAYLAND_SOCKET)]
89 WaylandSocketGetFd(#[source] io::Error),
90 #[error("F_SETFD failed on {}", WAYLAND_SOCKET)]
91 WaylandSocketSetFd(#[source] io::Error),
92 #[error(transparent)]
93 PollError(PollError),
94}
95
/// The central proxy state.
///
/// It holds the poller, every registered pollable (client and server endpoints, acceptors,
/// remote destructors), and the queues of pending work that `dispatch` processes.
pub struct State {
    /// Baseline passed to `State::builder` when this state was created.
    pub(crate) baseline: Baseline,
    /// Poller that all sockets and destructor pipes are registered with.
    poller: Poller,
    /// Next id handed out by `create_pollable_id`; used as the key in `pollables`.
    next_pollable_id: Cell<u64>,
    /// The server endpoint, if this state has one.
    pub(crate) server: Option<Rc<Endpoint>>,
    /// Set once by `destroy`; most operations become no-ops or fail afterwards.
    pub(crate) destroyed: Cell<bool>,
    /// State-level handler; receives `new_client` callbacks for accepted clients.
    handler: HandlerHolder<dyn StateHandler>,
    /// Everything registered with the poller, keyed by pollable id.
    pollables: RefCell<HashMap<u64, Pollable>>,
    /// Acceptors reported readable by `wait_for_work`, processed by `accept_connections`.
    acceptable_acceptors: Stack<Rc<Acceptor>>,
    /// Cheap check that lets `accept_connections` skip the stack when nothing was queued.
    has_acceptable_acceptors: Cell<bool>,
    /// Clients to disconnect in `kill_clients`.
    clients_to_kill: Stack<Rc<Client>>,
    /// Cheap check that lets `kill_clients` skip the stack when nothing was queued.
    has_clients_to_kill: Cell<bool>,
    /// Endpoints reported readable, processed by `read_messages`.
    readable_endpoints: Stack<EndpointWithClient>,
    /// Cheap check that lets `read_messages` skip the stack when nothing was queued.
    has_readable_endpoints: Cell<bool>,
    /// Endpoints with outgoing data to flush, processed by `perform_writes`.
    flushable_endpoints: Stack<EndpointWithClient>,
    /// Cheap check that lets `perform_writes` skip the stack when nothing was queued.
    has_flushable_endpoints: Cell<bool>,
    /// Endpoints whose desired poll interest must be pushed to the poller.
    interest_update_endpoints: Stack<Rc<Endpoint>>,
    /// Cheap check that lets `update_interests` skip the endpoint stack.
    has_interest_update_endpoints: Cell<bool>,
    /// Acceptors whose READABLE interest must be (re-)registered with the poller.
    interest_update_acceptors: Stack<Rc<Acceptor>>,
    /// Cheap check that lets `update_interests` skip the acceptor stack.
    has_interest_update_acceptors: Cell<bool>,
    /// Weak references to all objects; `destroy` upgrades them to tear them down.
    pub(crate) all_objects: RefCell<HashMap<u64, Weak<dyn Object>>>,
    /// Object id counter. NOTE(review): not used in this file; presumably the source of the
    /// keys in `all_objects` — confirm.
    pub(crate) next_object_id: Cell<u64>,
    /// Whether message logging is enabled (only with the `logging` feature).
    #[cfg(feature = "logging")]
    pub(crate) log: bool,
    /// Prefix for log lines (only with the `logging` feature).
    #[cfg(feature = "logging")]
    pub(crate) log_prefix: String,
    /// Buffered sink used by `State::log`.
    #[cfg(feature = "logging")]
    log_writer: RefCell<io::BufWriter<uapi::Fd>>,
    /// Whether a `HandlerLock` is currently alive; guards against recursive dispatch.
    global_lock_held: Cell<bool>,
    /// Reusable scratch storage for object references (used by `destroy`).
    pub(crate) object_stash: Stash<Rc<dyn Object>>,
    /// Default set by `set_default_forward_to_client`.
    pub(crate) forward_to_client: Cell<bool>,
    /// Default set by `set_default_forward_to_server`.
    pub(crate) forward_to_server: Cell<bool>,
}
181
182pub trait StateHandler: 'static {
184 fn new_client(&mut self, client: &Rc<Client>) {
189 let _ = client;
190 }
191}
192
/// Something registered with the state's poller, looked up by pollable id on each event.
enum Pollable {
    /// A client or server connection; `client` is `None` for the server.
    Endpoint(EndpointWithClient),
    /// A listening socket whose readiness means connections can be accepted.
    Acceptor(Rc<Acceptor>),
    /// Read end of a remote-destructor pipe plus the flag shared with the
    /// `RemoteDestructor` that requests destruction when set.
    Destructor(OwnedFd, Arc<AtomicBool>),
}
198
/// An endpoint together with the client it belongs to.
///
/// `client` is `None` for the server endpoint; errors on such an endpoint are returned to
/// the caller instead of just disconnecting a client.
#[derive(Clone)]
struct EndpointWithClient {
    endpoint: Rc<Endpoint>,
    client: Option<Rc<Client>>,
}
204
/// Guard proving that the caller holds the state's global handler lock.
///
/// Created by `State::acquire_handler_lock`; dropping it clears `global_lock_held`.
pub(crate) struct HandlerLock<'a> {
    state: &'a State,
}
208
209impl State {
210 pub(crate) fn remove_endpoint(&self, endpoint: &Endpoint) {
211 self.pollables.borrow_mut().remove(&endpoint.id);
212 self.poller.unregister(endpoint.socket.as_fd());
213 endpoint.unregistered.set(true);
214 }
215
216 fn acquire_handler_lock(&self) -> Result<HandlerLock<'_>, StateErrorKind> {
217 if self.global_lock_held.replace(true) {
218 return Err(StateErrorKind::RecursiveCall);
219 }
220 Ok(HandlerLock { state: self })
221 }
222
223 fn flush_locked(&self, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
224 let mut did_work = false;
225 did_work |= self.perform_writes(lock)?;
226 did_work |= self.kill_clients();
227 self.update_interests()?;
228 Ok(did_work)
229 }
230
231 pub(crate) fn handle_delete_id(&self, server: &Endpoint, id: u32) {
232 let object = server.objects.borrow_mut().remove(&id).unwrap();
233 let core = object.core();
234 core.server_obj_id.take();
235 server.idl.release(id);
236 if let Err((e, object)) = object.delete_id() {
237 log::warn!(
238 "Could not handle a wl_display.delete_id message: {}",
239 Report::new(e),
240 );
241 let _ = object.core().try_delete_id();
242 }
243 }
244
245 fn perform_writes(&self, _: &HandlerLock<'_>) -> Result<bool, StateError> {
246 if !self.has_flushable_endpoints.get() {
247 return Ok(false);
248 }
249 while let Some(ewc) = self.flushable_endpoints.pop() {
250 let res = match ewc.endpoint.flush() {
251 Ok(r) => r,
252 Err(e) => {
253 let is_closed = matches!(e, EndpointError::Flush(TransError::Closed));
254 if let Some(client) = &ewc.client {
255 if !is_closed {
256 log::warn!(
257 "Could not write to client#{}: {}",
258 client.endpoint.id,
259 Report::new(e),
260 );
261 }
262 self.add_client_to_kill(client);
263 } else {
264 if is_closed {
265 return Err(StateErrorKind::ServerHangup.into());
266 }
267 return Err(StateErrorKind::WriteToServer(e).into());
268 }
269 continue;
270 }
271 };
272 match res {
273 FlushResult::Done => {
274 ewc.endpoint.flush_queued.set(false);
275 self.change_interest(&ewc.endpoint, |i| i & !poll::WRITABLE);
276 }
277 FlushResult::Blocked => {
278 self.change_interest(&ewc.endpoint, |i| i | poll::WRITABLE);
279 }
280 }
281 }
282 self.has_flushable_endpoints.set(false);
283 Ok(true)
284 }
285
286 fn accept_connections(self: &Rc<Self>, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
287 if !self.has_acceptable_acceptors.get() {
288 return Ok(false);
289 }
290 self.check_destroyed()?;
291 while let Some(acceptor) = self.acceptable_acceptors.pop() {
292 self.interest_update_acceptors.push(acceptor.clone());
293 self.has_interest_update_acceptors.set(true);
294 const MAX_ACCEPT_PER_ITERATION: usize = 10;
295 for _ in 0..MAX_ACCEPT_PER_ITERATION {
296 let socket = acceptor
297 .accept()
298 .map_err(StateErrorKind::AcceptConnection)?;
299 let Some(socket) = socket else {
300 break;
301 };
302 self.create_client(Some(lock), &Rc::new(socket))?;
303 }
304 }
305 self.has_acceptable_acceptors.set(false);
306 Ok(true)
307 }
308
309 fn read_messages(&self, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
310 if !self.has_readable_endpoints.get() {
311 return Ok(false);
312 }
313 while let Some(ewc) = self.readable_endpoints.pop() {
314 let res = ewc.endpoint.read_messages(lock, ewc.client.as_ref());
315 if let Err(e) = res {
316 if let Some(client) = &ewc.client {
317 log::error!("Could not handle client message: {}", Report::new(e));
318 self.add_client_to_kill(client);
319 } else {
320 return Err(StateErrorKind::DispatchEvents(e).into());
321 }
322 }
323 self.change_interest(&ewc.endpoint, |i| i | poll::READABLE);
324 }
325 self.has_readable_endpoints.set(false);
326 Ok(true)
327 }
328
329 fn change_interest(&self, endpoint: &Rc<Endpoint>, f: impl FnOnce(u32) -> u32) {
330 if self.destroyed.get() {
331 return;
332 }
333 let old = endpoint.desired_interest.get();
334 let new = f(old);
335 endpoint.desired_interest.set(new);
336 if old != new
337 && endpoint.current_interest.get() != new
338 && !endpoint.interest_update_queued.replace(true)
339 {
340 self.interest_update_endpoints.push(endpoint.clone());
341 self.has_interest_update_endpoints.set(true);
342 }
343 }
344
345 pub(crate) fn add_flushable_endpoint(
346 &self,
347 endpoint: &Rc<Endpoint>,
348 client: Option<&Rc<Client>>,
349 ) {
350 if self.destroyed.get() {
351 return;
352 }
353 self.flushable_endpoints.push(EndpointWithClient {
354 endpoint: endpoint.clone(),
355 client: client.cloned(),
356 });
357 self.has_flushable_endpoints.set(true);
358 }
359
360 fn wait_for_work(&self, _: &HandlerLock<'_>, mut timeout: c::c_int) -> Result<(), StateError> {
361 self.check_destroyed()?;
362 let mut events = [PollEvent::default(); poll::MAX_EVENTS];
363 let pollables = &mut *self.pollables.borrow_mut();
364 loop {
365 let n = self
366 .poller
367 .read_events(timeout, &mut events)
368 .map_err(StateErrorKind::PollError)?;
369 if n == 0 {
370 return Ok(());
371 }
372 timeout = 0;
373 for event in &events[0..n] {
374 let id = event.u64;
375 let Some(pollable) = pollables.get(&id) else {
376 continue;
377 };
378 match pollable {
379 Pollable::Endpoint(ewc) => {
380 let events = event.events;
381 if events & poll::ERROR != 0 {
382 if let Some(client) = &ewc.client {
383 self.add_client_to_kill(client);
384 } else {
385 return Err(StateErrorKind::ServerHangup.into());
386 }
387 continue;
388 }
389 ewc.endpoint.current_interest.set(0);
390 self.change_interest(&ewc.endpoint, |i| i & !events);
391 if events & poll::READABLE != 0 {
392 self.readable_endpoints.push(ewc.clone());
393 self.has_readable_endpoints.set(true);
394 }
395 if events & poll::WRITABLE != 0 {
396 self.flushable_endpoints.push(ewc.clone());
397 self.has_flushable_endpoints.set(true);
398 }
399 }
400 Pollable::Acceptor(a) => {
401 self.acceptable_acceptors.push(a.clone());
402 self.has_acceptable_acceptors.set(true);
403 }
404 Pollable::Destructor(fd, destroy) => {
405 let destroy = destroy.load(Acquire);
406 self.poller.unregister(fd.as_fd());
407 pollables.remove(&id);
408 if destroy {
409 return Err(StateErrorKind::RemoteDestroyed.into());
410 }
411 }
412 }
413 }
414 }
415 }
416
417 fn add_client_to_kill(&self, client: &Rc<Client>) {
418 self.clients_to_kill.push(client.clone());
419 self.has_clients_to_kill.set(true);
420 }
421
422 fn kill_clients(&self) -> bool {
423 if !self.has_clients_to_kill.get() {
424 return false;
425 }
426 while let Some(client) = self.clients_to_kill.pop() {
427 if let Some(handler) = client.handler.borrow_mut().take() {
428 handler.disconnected();
429 }
430 client.disconnect();
431 }
432 self.has_clients_to_kill.set(false);
433 true
434 }
435
436 fn create_pollable_id(&self) -> u64 {
437 let id = self.next_pollable_id.get();
438 self.next_pollable_id.set(id + 1);
439 id
440 }
441
442 fn update_interests(&self) -> Result<(), StateError> {
443 if self.has_interest_update_endpoints.get() {
444 while let Some(endpoint) = self.interest_update_endpoints.pop() {
445 endpoint.interest_update_queued.set(false);
446 let desired = endpoint.desired_interest.get();
447 if desired == endpoint.current_interest.get() {
448 continue;
449 }
450 if endpoint.unregistered.get() {
451 continue;
452 }
453 self.poller
454 .update_interests(endpoint.id, endpoint.socket.as_fd(), desired)
455 .map_err(StateErrorKind::PollError)?;
456 endpoint.current_interest.set(desired);
457 }
458 self.has_interest_update_endpoints.set(false);
459 }
460 if self.has_interest_update_acceptors.get() {
461 while let Some(acceptor) = self.interest_update_acceptors.pop() {
462 self.poller
463 .update_interests(acceptor.id, acceptor.socket.as_fd(), poll::READABLE)
464 .map_err(StateErrorKind::PollError)?;
465 }
466 self.has_interest_update_acceptors.set(false);
467 }
468 Ok(())
469 }
470
471 fn check_destroyed(&self) -> Result<(), StateError> {
472 if self.destroyed.get() {
473 return Err(StateErrorKind::Destroyed.into());
474 }
475 Ok(())
476 }
477
478 #[cfg(feature = "logging")]
479 #[cold]
480 pub(crate) fn log(&self, args: std::fmt::Arguments<'_>) {
481 use std::io::Write;
482 let writer = &mut *self.log_writer.borrow_mut();
483 let _ = writer.write_fmt(args);
484 let _ = writer.flush();
485 }
486}
487
impl State {
    /// Returns a [`StateBuilder`] that will construct a state from `baseline`.
    pub fn builder(baseline: Baseline) -> StateBuilder {
        StateBuilder::new(baseline)
    }
}
495
496impl State {
498 pub fn dispatch_blocking(self: &Rc<Self>) -> Result<bool, StateError> {
502 self.dispatch(None)
503 }
504
505 pub fn dispatch_available(self: &Rc<Self>) -> Result<bool, StateError> {
509 self.dispatch(Some(Duration::from_secs(0)))
510 }
511
512 pub fn dispatch(self: &Rc<Self>, timeout: Option<Duration>) -> Result<bool, StateError> {
528 let mut did_work = false;
529 let lock = self.acquire_handler_lock()?;
530 let timeout = timeout
531 .and_then(|t| t.as_millis().try_into().ok())
532 .unwrap_or(-1);
533 let destroy_on_error = on_drop(|| self.destroy());
534 if timeout != 0 {
535 did_work |= self.flush_locked(&lock)?;
536 }
537 self.wait_for_work(&lock, timeout)?;
538 did_work |= self.accept_connections(&lock)?;
539 did_work |= self.read_messages(&lock)?;
540 did_work |= self.flush_locked(&lock)?;
541 destroy_on_error.forget();
542 Ok(did_work)
543 }
544}
545
impl State {
    /// Returns the poller's file descriptor.
    ///
    /// NOTE(review): presumably meant for integration into an external event loop. Wait for
    /// it to become readable, then dispatch (after calling `before_poll`) — confirm against
    /// the poller documentation.
    pub fn poll_fd(&self) -> &Rc<OwnedFd> {
        self.poller.fd()
    }

    /// Flushes pending output, disconnects clients marked for killing, and applies pending
    /// interest changes, without waiting for or reading any events.
    ///
    /// # Errors
    ///
    /// Returns `RecursiveCall` if called while the state is already being dispatched; the
    /// state is left intact in that case. If flushing fails, the state is destroyed before
    /// the error is returned.
    pub fn before_poll(&self) -> Result<(), StateError> {
        let lock = self.acquire_handler_lock()?;
        // Destroys the state if this function is left through an error path.
        let destroy_on_error = on_drop(|| self.destroy());
        self.flush_locked(&lock)?;
        destroy_on_error.forget();
        Ok(())
    }
}
583
584impl State {
586 pub fn create_object<P>(self: &Rc<Self>, version: u32) -> Rc<P>
601 where
602 P: Object,
603 {
604 P::new(self, version)
605 }
606
607 pub fn display(self: &Rc<Self>) -> Rc<WlDisplay> {
609 let display = WlDisplay::new(self, 1);
610 if self.server.is_some() {
611 display.core().server_obj_id.set(Some(1));
612 }
613 display
614 }
615
616 pub fn set_default_forward_to_client(&self, enabled: bool) {
621 self.forward_to_client.set(enabled);
622 }
623
624 pub fn set_default_forward_to_server(&self, enabled: bool) {
629 self.forward_to_server.set(enabled);
630 }
631}
632
633impl State {
635 pub fn connect(self: &Rc<Self>) -> Result<(Rc<Client>, OwnedFd), StateError> {
643 let (server_fd, client_fd) = uapi::socketpair(
644 c::AF_UNIX,
645 c::SOCK_STREAM | c::SOCK_NONBLOCK | c::SOCK_CLOEXEC,
646 0,
647 )
648 .map_err(|e| StateErrorKind::Socketpair(e.into()))?;
649 let client = self.create_client(None, &Rc::new(server_fd.into()))?;
650 Ok((client, client_fd.into()))
651 }
652
653 pub fn add_client(self: &Rc<Self>, socket: &Rc<OwnedFd>) -> Result<Rc<Client>, StateError> {
661 self.create_client(None, socket)
662 }
663
664 pub fn create_acceptor(&self, max_tries: u32) -> Result<Rc<Acceptor>, StateError> {
672 self.check_destroyed()?;
673 let id = self.create_pollable_id();
674 let acceptor =
675 Acceptor::create(id, max_tries, true).map_err(StateErrorKind::CreateAcceptor)?;
676 self.poller
677 .register(id, acceptor.socket.as_fd())
678 .map_err(StateErrorKind::PollError)?;
679 self.update_interests()?;
680 self.interest_update_acceptors.push(acceptor.clone());
681 self.has_interest_update_acceptors.set(true);
682 self.pollables
683 .borrow_mut()
684 .insert(id, Pollable::Acceptor(acceptor.clone()));
685 Ok(acceptor)
686 }
687
688 fn create_client(
689 self: &Rc<Self>,
690 lock: Option<&HandlerLock<'_>>,
691 socket: &Rc<OwnedFd>,
692 ) -> Result<Rc<Client>, StateError> {
693 self.check_destroyed()?;
694 let id = self.create_pollable_id();
695 self.poller
696 .register(id, socket.as_fd())
697 .map_err(StateErrorKind::PollError)?;
698 let endpoint = Endpoint::new(id, socket);
699 self.change_interest(&endpoint, |i| i | poll::READABLE);
700 self.update_interests()?;
701 let client = Rc::new(Client {
702 state: self.clone(),
703 endpoint: endpoint.clone(),
704 display: self.display(),
705 destroyed: Cell::new(false),
706 handler: Default::default(),
707 });
708 client
709 .display
710 .core()
711 .set_client_id(&client, 1, client.display.clone())
712 .unwrap();
713 self.pollables.borrow_mut().insert(
714 id,
715 Pollable::Endpoint(EndpointWithClient {
716 endpoint,
717 client: Some(client.clone()),
718 }),
719 );
720 if lock.is_some()
721 && let Some(handler) = &mut *self.handler.borrow_mut()
722 {
723 handler.new_client(&client);
724 }
725 Ok(client)
726 }
727}
728
729impl State {
734 pub fn unset_handler(&self) {
736 self.handler.set(None);
737 }
738
739 pub fn set_handler(&self, handler: impl StateHandler) {
741 self.set_boxed_handler(Box::new(handler))
742 }
743
744 pub fn set_boxed_handler(&self, handler: Box<dyn StateHandler>) {
746 if self.destroyed.get() {
747 return;
748 }
749 self.handler.set(Some(handler));
750 }
751}
752
impl State {
    /// Returns `true` while the state has not been destroyed; the negation of
    /// [`State::is_destroyed`].
    pub fn is_not_destroyed(&self) -> bool {
        !self.is_destroyed()
    }

    /// Returns whether the state has been destroyed.
    ///
    /// This becomes true after [`State::destroy`]. `dispatch` and `before_poll` also call
    /// `destroy` when they fail.
    pub fn is_destroyed(&self) -> bool {
        self.destroyed.get()
    }

    /// Destroys the state; calling it again is a no-op.
    ///
    /// Teardown steps:
    /// - Every client is marked destroyed.
    /// - Every endpoint's object table is drained.
    /// - Every pollable is unregistered from the poller.
    /// - All live objects lose their handlers and their client reference.
    /// - The state handler, the pollable table, all work queues, and the object registry are
    ///   cleared.
    pub fn destroy(&self) {
        if self.destroyed.replace(true) {
            return;
        }
        // Scratch buffer from the object stash, holding strong references during teardown.
        let objects = &mut *self.object_stash.borrow();
        for pollable in self.pollables.borrow().values() {
            let fd = match pollable {
                Pollable::Endpoint(ewc) => {
                    if let Some(c) = &ewc.client {
                        c.destroyed.set(true);
                    }
                    objects.extend(ewc.endpoint.objects.borrow_mut().drain().map(|v| v.1));
                    &ewc.endpoint.socket
                }
                Pollable::Acceptor(a) => &a.socket,
                Pollable::Destructor(fd, _) => fd,
            };
            self.poller.unregister(fd.as_fd());
        }
        // Drop the references drained from the endpoint tables (after the `pollables`
        // borrow has ended), then collect every object that is still alive.
        objects.clear();
        for object in self.all_objects.borrow().values() {
            if let Some(object) = object.upgrade() {
                objects.push(object);
            }
        }
        for object in objects {
            object.unset_handler();
            object.core().client.take();
        }
        self.handler.set(None);
        self.pollables.borrow_mut().clear();
        self.acceptable_acceptors.take();
        self.clients_to_kill.take();
        self.readable_endpoints.take();
        self.flushable_endpoints.take();
        self.interest_update_endpoints.take();
        self.interest_update_acceptors.take();
        self.all_objects.borrow_mut().clear();
        // The returned `RemoteDestructor` is dropped immediately, closing the write end of a
        // pipe whose read end was just registered with the poller.
        // NOTE(review): presumably this makes the poll fd ready, so that an external loop
        // waiting on `poll_fd` wakes up and observes the destruction — confirm.
        let _ = self.create_remote_destructor();
    }

    /// Returns a [`Destructor`] for this state, created enabled.
    ///
    /// NOTE(review): its behavior is defined in the `destructor` module; presumably it
    /// destroys the state when dropped while enabled — confirm.
    pub fn create_destructor(self: &Rc<Self>) -> Destructor {
        Destructor {
            state: self.clone(),
            enabled: Cell::new(true),
        }
    }

    /// Creates a [`RemoteDestructor`] backed by a pipe.
    ///
    /// The read end is registered with the poller and stored as a `Destructor` pollable,
    /// together with a destroy flag shared with the returned value. The write end is kept in
    /// the returned value. When `wait_for_work` sees an event on the read end, it removes the
    /// pollable and, if the flag is set, fails with `RemoteDestroyed`.
    ///
    /// # Errors
    ///
    /// Fails if the pipe cannot be created or registered with the poller.
    pub fn create_remote_destructor(&self) -> Result<RemoteDestructor, StateError> {
        let (r, w) = pipe().map_err(StateErrorKind::CreatePipe)?;
        let r: OwnedFd = r.into();
        let id = self.create_pollable_id();
        self.poller
            .register(id, r.as_fd())
            .map_err(StateErrorKind::PollError)?;
        let destroy = Arc::new(AtomicBool::new(false));
        self.pollables
            .borrow_mut()
            .insert(id, Pollable::Destructor(r, destroy.clone()));
        Ok(RemoteDestructor {
            destroy,
            _fd: w.into(),
            enabled: AtomicBool::new(true),
        })
    }
}
876
877impl StateError {
878 pub fn is_destroyed(&self) -> bool {
882 matches!(self.0, StateErrorKind::Destroyed)
883 }
884}
885
886impl Drop for HandlerLock<'_> {
887 fn drop(&mut self) {
888 self.state.global_lock_held.set(false);
889 }
890}