Skip to main content

wl_proxy/
state.rs

1//! The proxy state.
2
3use {
4    crate::{
5        acceptor::{Acceptor, AcceptorError},
6        baseline::Baseline,
7        client::Client,
8        endpoint::{Endpoint, EndpointError},
9        handler::HandlerHolder,
10        object::{Object, ObjectCoreApi, ObjectErrorKind, ObjectPrivate},
11        poll::{self, PollError, PollEvent, Poller},
12        protocols::wayland::wl_display::WlDisplay,
13        trans::{FlushResult, TransError},
14        utils::{
15            env::{WAYLAND_DISPLAY, WAYLAND_SOCKET, XDG_RUNTIME_DIR},
16            stack::Stack,
17            stash::Stash,
18        },
19    },
20    error_reporter::Report,
21    run_on_drop::on_drop,
22    std::{
23        cell::{Cell, RefCell},
24        collections::HashMap,
25        io::{self, pipe},
26        os::fd::{AsFd, AsRawFd, OwnedFd},
27        rc::{Rc, Weak},
28        sync::{
29            Arc,
30            atomic::{AtomicBool, Ordering::Acquire},
31        },
32        time::Duration,
33    },
34    thiserror::Error,
35    uapi::c,
36};
37pub use {
38    builder::StateBuilder,
39    destructor::{Destructor, RemoteDestructor},
40};
41
42mod builder;
43mod destructor;
44#[cfg(test)]
45mod tests;
46
/// An error emitted by a [`State`].
///
/// This is an opaque wrapper around a private error kind. Because the wrapper is
/// transparent, its message and its source chain are those of the wrapped kind.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct StateError(#[from] StateErrorKind);
51
52#[derive(Debug, Error)]
53enum StateErrorKind {
54    #[error("the state has already been destroyed")]
55    Destroyed,
56    #[error("the state has been destroyed by a remote destructor")]
57    RemoteDestroyed,
58    #[error("cannot perform recursive call into the state")]
59    RecursiveCall,
60    #[error("the server hung up the connection")]
61    ServerHangup,
62    #[error("could not write to the server socket")]
63    WriteToServer(#[source] EndpointError),
64    #[error("could not dispatch server events")]
65    DispatchEvents(#[source] EndpointError),
66    #[error("could not create a socket pair")]
67    Socketpair(#[source] io::Error),
68    #[error(transparent)]
69    CreateAcceptor(AcceptorError),
70    #[error("could not accept a new connection")]
71    AcceptConnection(AcceptorError),
72    #[error("could not create a pipe")]
73    CreatePipe(#[source] io::Error),
74    #[error("could not read {} environment variable", WAYLAND_DISPLAY)]
75    WaylandDisplay,
76    #[error("the display name is empty")]
77    WaylandDisplayEmpty,
78    #[error("{} is not set", XDG_RUNTIME_DIR)]
79    XrdNotSet,
80    #[error("the socket path is too long")]
81    SocketPathTooLong,
82    #[error("could not create a socket")]
83    CreateSocket(#[source] io::Error),
84    #[error("could not connect to {0}")]
85    Connect(String, #[source] io::Error),
86    #[error("{} does not contain a valid number", WAYLAND_SOCKET)]
87    WaylandSocketNotNumber,
88    #[error("F_GETFD failed on {}", WAYLAND_SOCKET)]
89    WaylandSocketGetFd(#[source] io::Error),
90    #[error("F_SETFD failed on {}", WAYLAND_SOCKET)]
91    WaylandSocketSetFd(#[source] io::Error),
92    #[error(transparent)]
93    PollError(PollError),
94    #[error("Could not create an eventfd")]
95    CreateEventfd(#[source] io::Error),
96}
97
/// The proxy state.
///
/// This type represents a connection to a server and any number of clients connected to
/// this proxy.
///
/// This type can be constructed by using a [`StateBuilder`].
///
/// # Example
///
/// ```
/// # use std::rc::Rc;
/// # use wl_proxy::baseline::Baseline;
/// # use wl_proxy::client::{Client, ClientHandler};
/// # use wl_proxy::protocols::wayland::wl_display::{WlDisplay, WlDisplayHandler};
/// # use wl_proxy::protocols::wayland::wl_registry::WlRegistry;
/// # use wl_proxy::state::{State, StateBuilder, StateHandler};
/// # fn f() {
/// let state = State::builder(Baseline::ALL_OF_THEM).build().unwrap();
/// state.set_handler(StateHandlerImpl);
/// let acceptor = state.create_acceptor(1000).unwrap();
/// eprintln!("{}", acceptor.display());
/// loop {
///     state.dispatch_blocking().unwrap();
/// }
///
/// struct StateHandlerImpl;
///
/// impl StateHandler for StateHandlerImpl {
///     fn new_client(&mut self, client: &Rc<Client>) {
///         eprintln!("Client connected");
///         client.set_handler(ClientHandlerImpl);
///         client.display().set_handler(DisplayHandler);
///     }
/// }
///
/// struct ClientHandlerImpl;
///
/// impl ClientHandler for ClientHandlerImpl {
///     fn disconnected(self: Box<Self>) {
///         eprintln!("Client disconnected");
///     }
/// }
///
/// struct DisplayHandler;
///
/// impl WlDisplayHandler for DisplayHandler {
///     fn handle_get_registry(&mut self, slf: &Rc<WlDisplay>, registry: &Rc<WlRegistry>) {
///         eprintln!("get_registry called");
///         let _ = slf.send_get_registry(registry);
///     }
/// }
/// # }
/// ```
pub struct State {
    pub(crate) baseline: Baseline,
    /// Registers file descriptors and reports readiness events for them.
    poller: Poller,
    /// The id that will be assigned to the next pollable.
    next_pollable_id: Cell<u64>,
    /// The endpoint of the server connection, if there is one.
    pub(crate) server: Option<Rc<Endpoint>>,
    /// Whether the state has been destroyed.
    pub(crate) destroyed: Cell<bool>,
    /// The handler that receives the events of this state.
    handler: HandlerHolder<dyn StateHandler>,
    /// Everything that is registered with the poller, keyed by pollable id.
    pollables: RefCell<HashMap<u64, Pollable>>,
    // The stacks below are work queues. Each `has_*` flag is set when an element is
    // pushed onto its stack and cleared after the stack has been drained.
    /// Acceptors that reported an event and should be accepted from.
    acceptable_acceptors: Stack<Rc<Acceptor>>,
    has_acceptable_acceptors: Cell<bool>,
    /// Clients that should be disconnected.
    clients_to_kill: Stack<Rc<Client>>,
    has_clients_to_kill: Cell<bool>,
    /// Endpoints whose incoming messages should be read.
    readable_endpoints: Stack<EndpointWithClient>,
    has_readable_endpoints: Cell<bool>,
    /// Endpoints whose outgoing messages should be flushed.
    flushable_endpoints: Stack<EndpointWithClient>,
    has_flushable_endpoints: Cell<bool>,
    /// Endpoints whose desired poll interest must be applied to the poller.
    interest_update_endpoints: Stack<Rc<Endpoint>>,
    has_interest_update_endpoints: Cell<bool>,
    /// Acceptors whose readable interest must be (re-)applied to the poller.
    interest_update_acceptors: Stack<Rc<Acceptor>>,
    has_interest_update_acceptors: Cell<bool>,
    pub(crate) all_objects: RefCell<HashMap<u64, Weak<dyn Object>>>,
    pub(crate) next_object_id: Cell<u64>,
    #[cfg(feature = "logging")]
    pub(crate) log: bool,
    #[cfg(feature = "logging")]
    pub(crate) log_prefix: String,
    /// The buffered destination of [`State::log`].
    #[cfg(feature = "logging")]
    log_writer: RefCell<io::BufWriter<uapi::Fd>>,
    /// Set while a [`HandlerLock`] has been handed out. Used to reject recursive calls.
    global_lock_held: Cell<bool>,
    pub(crate) object_stash: Stash<Rc<dyn Object>>,
    /// The default forward-to-client setting, see
    /// [`State::set_default_forward_to_client`].
    pub(crate) forward_to_client: Cell<bool>,
    /// The default forward-to-server setting, see
    /// [`State::set_default_forward_to_server`].
    pub(crate) forward_to_server: Cell<bool>,
    /// The eventfd that is written to when an endpoint requests to be unsuspended.
    unsuspend_fd: OwnedFd,
    /// Endpoints that have requested to be unsuspended.
    unsuspend_requests: Stack<EndpointWithClient>,
    /// Set when the poller reports the unsuspend event, cleared once the requests have
    /// been processed.
    has_unsuspend_requests: Cell<bool>,
    /// Set after `unsuspend_fd` has been written to, cleared when the poller reports
    /// the unsuspend event.
    unsuspend_triggered: Cell<bool>,
}
187
188/// A handler for events emitted by a [`State`].
189pub trait StateHandler: 'static {
190    /// A new client has connected.
191    ///
192    /// This event is not emitted if the connection is created explicitly via
193    /// [`State::connect`] or [`State::add_client`].
194    fn new_client(&mut self, client: &Rc<Client>) {
195        let _ = client;
196    }
197
198    /// The server has sent a wl_display.error event.
199    ///
200    /// Such errors are fatal.
201    ///
202    /// The object can be `None` if the error is sent on an object that has already been
203    /// deleted.
204    fn display_error(
205        self: Box<Self>,
206        object: Option<&Rc<dyn Object>>,
207        server_id: u32,
208        error: u32,
209        msg: &str,
210    ) {
211        let _ = object;
212        let _ = server_id;
213        let _ = error;
214        let _ = msg;
215    }
216}
217
/// An entry of the pollables map, identifying what a poll event belongs to.
enum Pollable {
    /// A connection to a client or to the server.
    Endpoint(EndpointWithClient),
    /// A listening socket from which new client connections are accepted.
    Acceptor(Rc<Acceptor>),
    /// A file descriptor that signals a destructor event. If the flag is set when the
    /// event arrives, dispatching fails because the state has been destroyed remotely.
    Destructor(OwnedFd, Arc<AtomicBool>),
    /// The event that signals pending unsuspend requests.
    Unsuspend,
}
224
/// An endpoint together with the client that owns it.
#[derive(Clone)]
struct EndpointWithClient {
    /// The endpoint.
    endpoint: Rc<Endpoint>,
    /// The client this endpoint belongs to. `None` is treated as the server endpoint
    /// by the read and write paths.
    client: Option<Rc<Client>>,
}
230
/// A token showing that the global handler lock of a [`State`] is held.
///
/// It is created by `State::acquire_handler_lock` after the lock flag has been set.
// NOTE(review): releasing the flag presumably happens in a `Drop` impl that is not part
// of this section — confirm.
pub(crate) struct HandlerLock<'a> {
    /// The state whose lock is held.
    state: &'a State,
}
234
235impl State {
236    pub(crate) fn remove_endpoint(&self, endpoint: &Endpoint) {
237        self.pollables.borrow_mut().remove(&endpoint.id);
238        self.poller.unregister(endpoint.socket.as_fd());
239        endpoint.unregistered.set(true);
240    }
241
242    fn acquire_handler_lock(&self) -> Result<HandlerLock<'_>, StateErrorKind> {
243        if self.global_lock_held.replace(true) {
244            return Err(StateErrorKind::RecursiveCall);
245        }
246        Ok(HandlerLock { state: self })
247    }
248
249    fn flush_locked(&self, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
250        let mut did_work = false;
251        did_work |= self.perform_writes(lock)?;
252        did_work |= self.kill_clients();
253        self.update_interests()?;
254        Ok(did_work)
255    }
256
257    pub(crate) fn handle_delete_id(&self, server: &Endpoint, id: u32) {
258        let object = server.objects.borrow_mut().remove(&id).unwrap();
259        let core = object.core();
260        core.server_obj_id.take();
261        server.idl.release(id);
262        if let Err((e, object)) = object.delete_id() {
263            log::warn!(
264                "Could not handle a wl_display.delete_id message: {}",
265                Report::new(e),
266            );
267            let _ = object.core().try_delete_id();
268        }
269    }
270
271    fn perform_writes(&self, _: &HandlerLock<'_>) -> Result<bool, StateError> {
272        if !self.has_flushable_endpoints.get() {
273            return Ok(false);
274        }
275        while let Some(ewc) = self.flushable_endpoints.pop() {
276            let res = match ewc.endpoint.flush() {
277                Ok(r) => r,
278                Err(e) => {
279                    let is_closed = matches!(e, EndpointError::Flush(TransError::Closed));
280                    if let Some(client) = &ewc.client {
281                        if !is_closed {
282                            log::warn!(
283                                "Could not write to client#{}: {}",
284                                client.endpoint.id,
285                                Report::new(e),
286                            );
287                        }
288                        self.add_client_to_kill(client);
289                    } else {
290                        if is_closed {
291                            return Err(StateErrorKind::ServerHangup.into());
292                        }
293                        return Err(StateErrorKind::WriteToServer(e).into());
294                    }
295                    continue;
296                }
297            };
298            match res {
299                FlushResult::Done => {
300                    ewc.endpoint.flush_queued.set(false);
301                    self.change_interest(&ewc.endpoint, |i| i & !poll::WRITABLE);
302                }
303                FlushResult::Blocked => {
304                    self.change_interest(&ewc.endpoint, |i| i | poll::WRITABLE);
305                }
306            }
307        }
308        self.has_flushable_endpoints.set(false);
309        Ok(true)
310    }
311
312    fn unsuspend_endpoints(self: &Rc<Self>, _lock: &HandlerLock<'_>) -> Result<(), StateError> {
313        if !self.has_unsuspend_requests.get() {
314            return Ok(());
315        }
316        self.check_destroyed()?;
317        while let Some(ewc) = self.unsuspend_requests.pop() {
318            ewc.endpoint.unsuspend_queued.set(false);
319            if ewc.endpoint.desired_suspended.get() {
320                continue;
321            }
322            ewc.endpoint.suspended.set(false);
323            self.readable_endpoints.push(ewc);
324            self.has_readable_endpoints.set(true);
325        }
326        self.has_unsuspend_requests.set(false);
327        Ok(())
328    }
329
330    fn accept_connections(self: &Rc<Self>, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
331        if !self.has_acceptable_acceptors.get() {
332            return Ok(false);
333        }
334        self.check_destroyed()?;
335        while let Some(acceptor) = self.acceptable_acceptors.pop() {
336            self.interest_update_acceptors.push(acceptor.clone());
337            self.has_interest_update_acceptors.set(true);
338            const MAX_ACCEPT_PER_ITERATION: usize = 10;
339            for _ in 0..MAX_ACCEPT_PER_ITERATION {
340                let socket = acceptor
341                    .accept()
342                    .map_err(StateErrorKind::AcceptConnection)?;
343                let Some(socket) = socket else {
344                    break;
345                };
346                self.create_client(Some(lock), &Rc::new(socket))?;
347            }
348        }
349        self.has_acceptable_acceptors.set(false);
350        Ok(true)
351    }
352
353    fn read_messages(&self, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
354        if !self.has_readable_endpoints.get() {
355            return Ok(false);
356        }
357        while let Some(ewc) = self.readable_endpoints.pop() {
358            let res = ewc.endpoint.read_messages(lock, ewc.client.as_ref());
359            if let Err(e) = res {
360                if let Some(client) = &ewc.client {
361                    log::error!("Could not handle client message: {}", Report::new(e));
362                    self.add_client_to_kill(client);
363                } else {
364                    if let EndpointError::HandleMessage(msg) = &e
365                        && let ObjectErrorKind::ServerError(object, server_id, error, msg) =
366                            &msg.source.0
367                        && let Some(handler) = self.handler.borrow_mut().take()
368                    {
369                        handler.display_error(object.as_ref(), *server_id, *error, &msg.0)
370                    }
371                    return Err(StateErrorKind::DispatchEvents(e).into());
372                }
373            }
374            if !ewc.endpoint.suspended.get() {
375                self.change_interest(&ewc.endpoint, |i| i | poll::READABLE);
376            }
377        }
378        self.has_readable_endpoints.set(false);
379        Ok(true)
380    }
381
382    pub(crate) fn set_endpoint_suspended(
383        &self,
384        endpoint: &Rc<Endpoint>,
385        client: Option<&Rc<Client>>,
386        suspended: bool,
387    ) {
388        if self.destroyed.get() {
389            return;
390        }
391        if suspended {
392            endpoint.suspended.set(true);
393            endpoint.desired_suspended.set(true);
394            return;
395        }
396        endpoint.desired_suspended.set(false);
397        if endpoint.unsuspend_queued.get() {
398            return;
399        }
400        if !self.unsuspend_triggered.get() {
401            if let Err(e) = uapi::eventfd_write(self.unsuspend_fd.as_raw_fd(), 1) {
402                log::error!(
403                    "Could not write to eventfd: {}",
404                    Report::new(io::Error::from(e)),
405                );
406                self.destroy();
407                return;
408            }
409            self.unsuspend_triggered.set(true);
410        }
411        self.unsuspend_requests.push(EndpointWithClient {
412            endpoint: endpoint.clone(),
413            client: client.cloned(),
414        });
415        endpoint.unsuspend_queued.set(true);
416    }
417
418    fn change_interest(&self, endpoint: &Rc<Endpoint>, f: impl FnOnce(u32) -> u32) {
419        if self.destroyed.get() {
420            return;
421        }
422        let old = endpoint.desired_interest.get();
423        let new = f(old);
424        endpoint.desired_interest.set(new);
425        if old != new
426            && endpoint.current_interest.get() != new
427            && !endpoint.interest_update_queued.replace(true)
428        {
429            self.interest_update_endpoints.push(endpoint.clone());
430            self.has_interest_update_endpoints.set(true);
431        }
432    }
433
434    pub(crate) fn add_flushable_endpoint(
435        &self,
436        endpoint: &Rc<Endpoint>,
437        client: Option<&Rc<Client>>,
438    ) {
439        if self.destroyed.get() {
440            return;
441        }
442        self.flushable_endpoints.push(EndpointWithClient {
443            endpoint: endpoint.clone(),
444            client: client.cloned(),
445        });
446        self.has_flushable_endpoints.set(true);
447    }
448
    /// Waits for poll events and sorts them into the work queues of the state.
    ///
    /// `timeout` is passed to the poller for the first read. After the first batch of
    /// events, the timeout is set to `0`, so the remaining iterations only collect
    /// events that are already available. The function returns once a read yields no
    /// events.
    ///
    /// # Errors
    ///
    /// Fails if the state is destroyed, if the poller fails, if the server endpoint
    /// reports an error condition, or if a remote destructor requested destruction.
    fn wait_for_work(&self, _: &HandlerLock<'_>, mut timeout: c::c_int) -> Result<(), StateError> {
        self.check_destroyed()?;
        let mut events = [PollEvent::default(); poll::MAX_EVENTS];
        // The map stays mutably borrowed for the whole function. Nothing called below
        // may borrow it again.
        let pollables = &mut *self.pollables.borrow_mut();
        loop {
            let n = self
                .poller
                .read_events(timeout, &mut events)
                .map_err(StateErrorKind::PollError)?;
            if n == 0 {
                return Ok(());
            }
            // Do not block again after the first batch.
            timeout = 0;
            for event in &events[0..n] {
                let id = event.u64;
                // Events for ids that are no longer in the map are ignored.
                let Some(pollable) = pollables.get(&id) else {
                    continue;
                };
                match pollable {
                    Pollable::Endpoint(ewc) => {
                        let events = event.events;
                        if events & poll::ERROR != 0 {
                            if let Some(client) = &ewc.client {
                                self.add_client_to_kill(client);
                            } else {
                                // The server endpoint has no client.
                                return Err(StateErrorKind::ServerHangup.into());
                            }
                            continue;
                        }
                        // NOTE(review): presumably the poller disarms the descriptor
                        // after reporting an event, hence the applied interest is reset
                        // — confirm against the poller implementation.
                        ewc.endpoint.current_interest.set(0);
                        // The reported events are handled below, so they are removed
                        // from the desired interest until they are requested again.
                        self.change_interest(&ewc.endpoint, |i| i & !events);
                        if events & poll::READABLE != 0 {
                            self.readable_endpoints.push(ewc.clone());
                            self.has_readable_endpoints.set(true);
                        }
                        if events & poll::WRITABLE != 0 {
                            self.flushable_endpoints.push(ewc.clone());
                            self.has_flushable_endpoints.set(true);
                        }
                    }
                    Pollable::Acceptor(a) => {
                        self.acceptable_acceptors.push(a.clone());
                        self.has_acceptable_acceptors.set(true);
                    }
                    Pollable::Destructor(fd, destroy) => {
                        // Read the flag before the entry is removed from the map.
                        let destroy = destroy.load(Acquire);
                        self.poller.unregister(fd.as_fd());
                        pollables.remove(&id);
                        if destroy {
                            return Err(StateErrorKind::RemoteDestroyed.into());
                        }
                    }
                    Pollable::Unsuspend => {
                        self.has_unsuspend_requests.set(true);
                        // Allow the next unsuspend request to write to the eventfd again.
                        self.unsuspend_triggered.set(false);
                    }
                }
            }
        }
    }
509
510    fn add_client_to_kill(&self, client: &Rc<Client>) {
511        self.clients_to_kill.push(client.clone());
512        self.has_clients_to_kill.set(true);
513    }
514
515    fn kill_clients(&self) -> bool {
516        if !self.has_clients_to_kill.get() {
517            return false;
518        }
519        while let Some(client) = self.clients_to_kill.pop() {
520            if let Some(handler) = client.handler.borrow_mut().take() {
521                handler.disconnected();
522            }
523            client.disconnect();
524        }
525        self.has_clients_to_kill.set(false);
526        true
527    }
528
529    fn create_pollable_id(&self) -> u64 {
530        let id = self.next_pollable_id.get();
531        self.next_pollable_id.set(id + 1);
532        id
533    }
534
535    fn update_interests(&self) -> Result<(), StateError> {
536        if self.has_interest_update_endpoints.get() {
537            while let Some(endpoint) = self.interest_update_endpoints.pop() {
538                endpoint.interest_update_queued.set(false);
539                let desired = endpoint.desired_interest.get();
540                if desired == endpoint.current_interest.get() {
541                    continue;
542                }
543                if endpoint.unregistered.get() {
544                    continue;
545                }
546                self.poller
547                    .update_interests(endpoint.id, endpoint.socket.as_fd(), desired)
548                    .map_err(StateErrorKind::PollError)?;
549                endpoint.current_interest.set(desired);
550            }
551            self.has_interest_update_endpoints.set(false);
552        }
553        if self.has_interest_update_acceptors.get() {
554            while let Some(acceptor) = self.interest_update_acceptors.pop() {
555                self.poller
556                    .update_interests(acceptor.id, acceptor.socket.as_fd(), poll::READABLE)
557                    .map_err(StateErrorKind::PollError)?;
558            }
559            self.has_interest_update_acceptors.set(false);
560        }
561        Ok(())
562    }
563
564    fn check_destroyed(&self) -> Result<(), StateError> {
565        if self.destroyed.get() {
566            return Err(StateErrorKind::Destroyed.into());
567        }
568        Ok(())
569    }
570
571    #[cfg(feature = "logging")]
572    #[cold]
573    pub(crate) fn log(&self, args: std::fmt::Arguments<'_>) {
574        use std::io::Write;
575        let writer = &mut *self.log_writer.borrow_mut();
576        let _ = writer.write_fmt(args);
577        let _ = writer.flush();
578    }
579}
580
581/// These functions can be used to create a new state.
582impl State {
    /// Creates a new [`StateBuilder`].
    ///
    /// The baseline is passed on unchanged to [`StateBuilder::new`].
    pub fn builder(baseline: Baseline) -> StateBuilder {
        StateBuilder::new(baseline)
    }
587}
588
589/// These functions can be used to dispatch and flush messages.
590impl State {
    /// Performs a blocking dispatch.
    ///
    /// This is a shorthand for `self.dispatch(None)`. See [`Self::dispatch`] for the
    /// meaning of the return value and for the reentrancy restrictions.
    pub fn dispatch_blocking(self: &Rc<Self>) -> Result<bool, StateError> {
        self.dispatch(None)
    }
597
    /// Performs a non-blocking dispatch.
    ///
    /// This is a shorthand for `self.dispatch(Some(Duration::from_secs(0)))`. See
    /// [`Self::dispatch`] for the meaning of the return value and for the reentrancy
    /// restrictions.
    pub fn dispatch_available(self: &Rc<Self>) -> Result<bool, StateError> {
        self.dispatch(Some(Duration::from_secs(0)))
    }
604
605    /// Performs a dispatch.
606    ///
607    /// The timeout determines how long this function will wait for new events. If the
608    /// timeout is `None`, then it will wait indefinitely. If the timeout is `0`, then
609    /// it will only process currently available events.
610    ///
611    /// If the timeout is not `0`, then outgoing messages will be flushed before waiting.
612    ///
613    /// Outgoing messages will be flushed immediately before this function returns.
614    ///
615    /// The return value indicates if any work was performed.
616    ///
617    /// This function is not reentrant. It should not be called from within a callback.
618    /// Trying to do so will cause it to return an error immediately and the state will
619    /// be otherwise unchanged.
620    pub fn dispatch(self: &Rc<Self>, timeout: Option<Duration>) -> Result<bool, StateError> {
621        let mut did_work = false;
622        let lock = self.acquire_handler_lock()?;
623        let timeout = timeout
624            .and_then(|t| t.as_millis().try_into().ok())
625            .unwrap_or(-1);
626        let destroy_on_error = on_drop(|| self.destroy());
627        if timeout != 0 {
628            did_work |= self.flush_locked(&lock)?;
629        }
630        self.wait_for_work(&lock, timeout)?;
631        self.unsuspend_endpoints(&lock)?;
632        did_work |= self.accept_connections(&lock)?;
633        did_work |= self.read_messages(&lock)?;
634        did_work |= self.flush_locked(&lock)?;
635        destroy_on_error.forget();
636        Ok(did_work)
637    }
638
639    /// Suspends or unsuspends dispatching messages from the server.
640    ///
641    /// See also [`Client::set_suspended`].
642    pub fn set_suspended(&self, suspended: bool) {
643        if let Some(endpoint) = &self.server {
644            self.set_endpoint_suspended(endpoint, None, suspended);
645        }
646    }
647}
648
649impl State {
    /// Returns a file descriptor that can be used with epoll or similar.
    ///
    /// If this file descriptor becomes readable, the state should be dispatched.
    /// [`Self::before_poll`] should be used before going to sleep.
    ///
    /// This function always returns the same file descriptor. It is the file
    /// descriptor of the poller of this state.
    pub fn poll_fd(&self) -> &Rc<OwnedFd> {
        self.poller.fd()
    }
659
660    /// Prepares the state for an external poll operation.
661    ///
662    /// If [`Self::poll_fd`] is used, this function should be called immediately before
663    /// going to sleep. Otherwise, outgoing messages might not be flushed.
664    ///
665    /// ```
666    /// # use std::os::fd::OwnedFd;
667    /// # use std::rc::Rc;
668    /// # use wl_proxy::state::State;
669    /// # fn poll(fd: &OwnedFd) { }
670    /// # fn f(state: &Rc<State>) {
671    /// loop {
672    ///     state.before_poll().unwrap();
673    ///     poll(state.poll_fd());
674    ///     state.dispatch_available().unwrap();
675    /// }
676    /// # }
677    /// ```
678    pub fn before_poll(&self) -> Result<(), StateError> {
679        let lock = self.acquire_handler_lock()?;
680        let destroy_on_error = on_drop(|| self.destroy());
681        self.flush_locked(&lock)?;
682        destroy_on_error.forget();
683        Ok(())
684    }
685}
686
687/// These functions can be used to manipulate objects.
688impl State {
    /// Creates a new object.
    ///
    /// The new object is not associated with a client ID or a server ID. It can become
    /// associated with a client ID by sending an event with a `new_id` parameter. It can
    /// become associated with a server ID by sending a request with a `new_id` parameter.
    ///
    /// The object can only be associated with one client at a time. The association with
    /// a client is removed when the object is used in a destructor event.
    ///
    /// This function does not enforce that the version is less than or equal to the
    /// maximum version supported by this crate. Using a version that exceeds the maximum
    /// supported version can cause a protocol error if the client sends a request that is
    /// not available in the maximum supported protocol version or if the server sends an
    /// event that is not available in the maximum supported protocol version.
    pub fn create_object<P>(self: &Rc<Self>, version: u32) -> Rc<P>
    where
        P: Object,
    {
        P::new(self, version)
    }
709
710    /// Returns a wl_display object.
711    pub fn display(self: &Rc<Self>) -> Rc<WlDisplay> {
712        let display = WlDisplay::new(self, 1);
713        if self.server.is_some() {
714            display.core().server_obj_id.set(Some(1));
715        }
716        display
717    }
718
    /// Changes the default forward-to-client setting.
    ///
    /// This affects objects created after this call. Objects that already exist keep
    /// their setting. See [`ObjectCoreApi::set_forward_to_client`].
    pub fn set_default_forward_to_client(&self, enabled: bool) {
        self.forward_to_client.set(enabled);
    }
726
    /// Changes the default forward-to-server setting.
    ///
    /// This affects objects created after this call. Objects that already exist keep
    /// their setting. See [`ObjectCoreApi::set_forward_to_server`].
    pub fn set_default_forward_to_server(&self, enabled: bool) {
        self.forward_to_server.set(enabled);
    }
734}
735
736/// These functions can be used to manage sockets associated with this state.
737impl State {
738    /// Creates a new connection to this proxy.
739    ///
740    /// The returned file descriptor is the client end of the connection and can be used
741    /// with a function such as `wl_display_connect_to_fd` or with the `WAYLAND_SOCKET`
742    /// environment variable.
743    ///
744    /// The [`StateHandler::new_client`] callback will not be invoked.
745    pub fn connect(self: &Rc<Self>) -> Result<(Rc<Client>, OwnedFd), StateError> {
746        let (server_fd, client_fd) = uapi::socketpair(
747            c::AF_UNIX,
748            c::SOCK_STREAM | c::SOCK_NONBLOCK | c::SOCK_CLOEXEC,
749            0,
750        )
751        .map_err(|e| StateErrorKind::Socketpair(e.into()))?;
752        let client = self.create_client(None, &Rc::new(server_fd.into()))?;
753        Ok((client, client_fd.into()))
754    }
755
    /// Creates a new connection to this proxy from an existing socket.
    ///
    /// The file descriptor should be the server end of the connection. It can be created
    /// with a function such as `socketpair` or by accepting a connection from a
    /// file-system socket.
    ///
    /// The [`StateHandler::new_client`] callback will not be invoked.
    ///
    /// # Errors
    ///
    /// Fails if the state is destroyed or if the socket cannot be registered for
    /// polling.
    pub fn add_client(self: &Rc<Self>, socket: &Rc<OwnedFd>) -> Result<Rc<Client>, StateError> {
        self.create_client(None, socket)
    }
766
767    /// Creates a new file-system acceptor and starts listening for connections.
768    ///
769    /// See [`Acceptor::new`] for the meaning of the `max_tries` parameter.
770    ///
771    /// Calling [`State::dispatch`] will automatically accept connections from this
772    /// acceptor. The [`StateHandler::new_client`] callback will be invoked when this
773    /// happens.
774    pub fn create_acceptor(&self, max_tries: u32) -> Result<Rc<Acceptor>, StateError> {
775        self.check_destroyed()?;
776        let id = self.create_pollable_id();
777        let acceptor =
778            Acceptor::create(id, max_tries, true).map_err(StateErrorKind::CreateAcceptor)?;
779        self.poller
780            .register(id, acceptor.socket.as_fd())
781            .map_err(StateErrorKind::PollError)?;
782        self.update_interests()?;
783        self.interest_update_acceptors.push(acceptor.clone());
784        self.has_interest_update_acceptors.set(true);
785        self.pollables
786            .borrow_mut()
787            .insert(id, Pollable::Acceptor(acceptor.clone()));
788        Ok(acceptor)
789    }
790
    /// Creates a client for the given socket and starts polling it.
    ///
    /// The socket is registered with the poller and readable interest is applied
    /// immediately. The client gets its own wl_display object with client ID 1.
    ///
    /// If `lock` is `Some`, the call originates from a dispatch and the
    /// [`StateHandler::new_client`] callback is invoked. Otherwise it is not.
    ///
    /// # Errors
    ///
    /// Fails if the state is destroyed or if the poller operations fail.
    fn create_client(
        self: &Rc<Self>,
        lock: Option<&HandlerLock<'_>>,
        socket: &Rc<OwnedFd>,
    ) -> Result<Rc<Client>, StateError> {
        self.check_destroyed()?;
        let id = self.create_pollable_id();
        self.poller
            .register(id, socket.as_fd())
            .map_err(StateErrorKind::PollError)?;
        let endpoint = Endpoint::new(id, socket);
        // Queue readable interest and apply it right away.
        self.change_interest(&endpoint, |i| i | poll::READABLE);
        self.update_interests()?;
        let client = Rc::new(Client {
            state: self.clone(),
            endpoint: endpoint.clone(),
            display: self.display(),
            destroyed: Cell::new(false),
            handler: Default::default(),
        });
        // NOTE(review): the unwrap assumes that assigning ID 1 on a fresh client cannot
        // fail — confirm against `set_client_id`.
        client
            .display
            .core()
            .set_client_id(&client, 1, client.display.clone())
            .unwrap();
        self.pollables.borrow_mut().insert(
            id,
            Pollable::Endpoint(EndpointWithClient {
                endpoint,
                client: Some(client.clone()),
            }),
        );
        // Only connections accepted during a dispatch are announced to the handler.
        if lock.is_some()
            && let Some(handler) = &mut *self.handler.borrow_mut()
        {
            handler.new_client(&client);
        }
        Ok(client)
    }
830}
831
832/// These functions can be used to manipulate the [`StateHandler`] of this state.
833///
834/// These functions can be called at any time, even from within a handler callback. In
835/// that case, the handler is replaced as soon as the callback returns.
836impl State {
837    /// Unsets the handler.
838    pub fn unset_handler(&self) {
839        self.handler.set(None);
840    }
841
842    /// Sets a new handler.
843    pub fn set_handler(&self, handler: impl StateHandler) {
844        self.set_boxed_handler(Box::new(handler))
845    }
846
847    /// Sets a new, already boxed handler.
848    pub fn set_boxed_handler(&self, handler: Box<dyn StateHandler>) {
849        if self.destroyed.get() {
850            return;
851        }
852        self.handler.set(Some(handler));
853    }
854}
855
856/// These functions can be used to check the state status and to destroy the state.
857impl State {
858    /// Returns whether this state is not destroyed.
859    ///
860    /// This is the same as `!self.is_destroyed()`.
861    pub fn is_not_destroyed(&self) -> bool {
862        !self.is_destroyed()
863    }
864
865    /// Returns whether the state is destroyed.
866    ///
867    /// If the state is destroyed, most functions that can return an error will return an
868    /// error saying that the state is already destroyed.
869    ///
870    /// This function or [`Self::is_not_destroyed`] should be used before dispatching the
871    /// state.
872    ///
873    /// # Example
874    ///
875    /// ```
876    /// # use std::rc::Rc;
877    /// # use error_reporter::Report;
878    /// # use wl_proxy::state::State;
879    /// #
880    /// # fn f(state: &Rc<State>) {
881    /// while state.is_not_destroyed() {
882    ///     if let Err(e) = state.dispatch_blocking() {
883    ///         log::error!("Could not dispatch the state: {}", Report::new(e));
884    ///     }
885    /// }
886    /// # }
887    /// ```
888    pub fn is_destroyed(&self) -> bool {
889        self.destroyed.get()
890    }
891
892    /// Destroys this state.
893    ///
894    /// This function unsets all handlers and destroys all clients. You should drop the
895    /// state after calling this function.
896    pub fn destroy(&self) {
897        if self.destroyed.replace(true) {
898            return;
899        }
900        let objects = &mut *self.object_stash.borrow();
901        for pollable in self.pollables.borrow().values() {
902            let fd = match pollable {
903                Pollable::Endpoint(ewc) => {
904                    if let Some(c) = &ewc.client {
905                        c.destroyed.set(true);
906                    }
907                    objects.extend(ewc.endpoint.objects.borrow_mut().drain().map(|v| v.1));
908                    &ewc.endpoint.socket
909                }
910                Pollable::Acceptor(a) => &a.socket,
911                Pollable::Destructor(fd, _) => fd,
912                Pollable::Unsuspend => &self.unsuspend_fd,
913            };
914            self.poller.unregister(fd.as_fd());
915        }
916        objects.clear();
917        for object in self.all_objects.borrow().values() {
918            if let Some(object) = object.upgrade() {
919                objects.push(object);
920            }
921        }
922        for object in objects {
923            object.unset_handler();
924            object.core().client.take();
925        }
926        self.handler.set(None);
927        self.pollables.borrow_mut().clear();
928        self.acceptable_acceptors.take();
929        self.clients_to_kill.take();
930        self.readable_endpoints.take();
931        self.flushable_endpoints.take();
932        self.interest_update_endpoints.take();
933        self.interest_update_acceptors.take();
934        self.unsuspend_requests.take();
935        self.all_objects.borrow_mut().clear();
936        // Ensure that the poll fd stays permanently readable.
937        let _ = self.create_remote_destructor();
938    }
939
940    /// Creates a RAII destructor for this state.
941    ///
942    /// Dropping the destructor will automatically call [`State::destroy`] unless you
943    /// first call [`Destructor::disable`].
944    ///
945    /// State objects contain reference cycles that must be cleared manually to release
946    /// the associated resources. Dropping the [`State`] is usually not sufficient to do
947    /// this. Instead, [`State::destroy`] must be called manually. This function can be
948    /// used to accomplish this in an application that otherwise relies on RAII semantics.
949    ///
950    /// Ensure that the destructor is itself not part of a reference cycle.
951    pub fn create_destructor(self: &Rc<Self>) -> Destructor {
952        Destructor {
953            state: self.clone(),
954            enabled: Cell::new(true),
955        }
956    }
957
958    /// Creates a `Sync+Send` RAII destructor for this state.
959    ///
960    /// This function is similar to [`State::create_destructor`] but the returned
961    /// destructor implements `Sync+Send`. This destructor can therefore be used to
962    /// destroy states running in a different thread.
963    pub fn create_remote_destructor(&self) -> Result<RemoteDestructor, StateError> {
964        let (r, w) = pipe().map_err(StateErrorKind::CreatePipe)?;
965        let r: OwnedFd = r.into();
966        let id = self.create_pollable_id();
967        self.poller
968            .register(id, r.as_fd())
969            .map_err(StateErrorKind::PollError)?;
970        let destroy = Arc::new(AtomicBool::new(false));
971        self.pollables
972            .borrow_mut()
973            .insert(id, Pollable::Destructor(r, destroy.clone()));
974        Ok(RemoteDestructor {
975            destroy,
976            _fd: w.into(),
977            enabled: AtomicBool::new(true),
978        })
979    }
980}
981
982impl StateError {
983    /// Returns whether this error was emitted because the state is already destroyed.
984    ///
985    /// This can be used to determine the severity of emitted log messages.
986    pub fn is_destroyed(&self) -> bool {
987        matches!(self.0, StateErrorKind::Destroyed)
988    }
989}
990
991impl Drop for HandlerLock<'_> {
992    fn drop(&mut self) {
993        self.state.global_lock_held.set(false);
994    }
995}