wl_proxy/state/
builder.rs

1use {
2    crate::{
3        baseline::Baseline,
4        endpoint::Endpoint,
5        object::{Object, ObjectPrivate},
6        poll::{self, Poller},
7        protocols::wayland::wl_display::WlDisplay,
8        state::{EndpointWithClient, Pollable, State, StateError, StateErrorKind},
9        utils::env::{WAYLAND_DISPLAY, WAYLAND_SOCKET, WL_PROXY_DEBUG, XDG_RUNTIME_DIR},
10    },
11    std::{
12        cell::{Cell, RefCell},
13        collections::HashMap,
14        env::{remove_var, var, var_os},
15        os::{
16            fd::{AsFd, FromRawFd, OwnedFd},
17            unix::ffi::OsStrExt,
18        },
19        rc::Rc,
20        str::FromStr,
21    },
22    uapi::c::{self, sockaddr_un},
23};
24
25/// A builder for a [`State`].
26///
27/// This type can be constructed with [`State::builder`].
28pub struct StateBuilder {
29    baseline: Baseline,
30    server: Option<Server>,
31    log: bool,
32    log_prefix: String,
33}
34
/// The server connection explicitly requested on a [`StateBuilder`].
enum Server {
    /// Build the state without any server connection.
    None,
    /// Connect to the compositor with this display name. Relative names are resolved
    /// against `XDG_RUNTIME_DIR`.
    DisplayName(String),
    /// Use this file descriptor as the server socket.
    Fd(Rc<OwnedFd>),
}
40
41impl StateBuilder {
42    pub(super) fn new(baseline: Baseline) -> Self {
43        Self {
44            baseline,
45            server: Default::default(),
46            log: var(WL_PROXY_DEBUG).as_deref() == Ok("1"),
47            log_prefix: Default::default(),
48        }
49    }
50
51    /// Builds the state.
52    ///
53    /// The server to connect to is chosen as follows:
54    ///
55    /// - If [`Self::with_server_fd`] was used, that FD is used.
56    /// - Otherwise, if [`Self::with_server_display_name`] was used, that display name is
57    ///   used.
58    /// - Otherwise, if the `WAYLAND_SOCKET` environment variable is set, that FD is used.
59    /// - Otherwise, the display name from the `WAYLAND_DISPLAY` environment variable is
60    ///   used.
61    pub fn build(self) -> Result<Rc<State>, StateError> {
62        let server_fd = 'fd: {
63            let display_name = match self.server {
64                None => None,
65                Some(Server::None) => break 'fd None,
66                Some(Server::Fd(fd)) => break 'fd Some(fd),
67                Some(Server::DisplayName(n)) => Some(n),
68            };
69            if display_name.is_none()
70                && let Some(wayland_socket) = var_os(WAYLAND_SOCKET)
71            {
72                let fd = str::from_utf8(wayland_socket.as_bytes())
73                    .ok()
74                    .and_then(|s| i32::from_str(s).ok())
75                    .ok_or(StateErrorKind::WaylandSocketNotNumber)?;
76                let flags = uapi::fcntl_getfd(fd)
77                    .map_err(|e| StateErrorKind::WaylandSocketGetFd(e.into()))?;
78                uapi::fcntl_setfd(fd, flags | c::FD_CLOEXEC)
79                    .map_err(|e| StateErrorKind::WaylandSocketSetFd(e.into()))?;
80                // SAFETY: This is unsound.
81                let fd = unsafe {
82                    remove_var(WAYLAND_SOCKET);
83                    Rc::new(OwnedFd::from_raw_fd(fd))
84                };
85                break 'fd Some(fd);
86            }
87            let mut name = match display_name {
88                Some(n) => n,
89                _ => var(WAYLAND_DISPLAY)
90                    .ok()
91                    .ok_or(StateErrorKind::WaylandDisplay)?,
92            };
93            if name.is_empty() {
94                return Err(StateErrorKind::WaylandDisplayEmpty.into());
95            }
96            if !name.starts_with("/") {
97                let Ok(xrd) = var(XDG_RUNTIME_DIR) else {
98                    return Err(StateErrorKind::XrdNotSet.into());
99                };
100                name = format!("{xrd}/{name}");
101            }
102            let mut addr = sockaddr_un {
103                sun_family: c::AF_UNIX as _,
104                sun_path: [0; 108],
105            };
106            if name.len() > addr.sun_path.len() - 1 {
107                return Err(StateErrorKind::SocketPathTooLong.into());
108            }
109            let sun_path = uapi::as_bytes_mut(&mut addr.sun_path[..]);
110            sun_path[..name.len()].copy_from_slice(name.as_bytes());
111            sun_path[name.len()] = 0;
112            let socket = uapi::socket(c::AF_UNIX, c::SOCK_STREAM | c::SOCK_CLOEXEC, 0)
113                .map_err(|e| StateErrorKind::CreateSocket(e.into()))?;
114            uapi::connect(socket.raw(), &addr)
115                .map_err(|e| StateErrorKind::Connect(name.to_string(), e.into()))?;
116            Some(Rc::new(socket.into()))
117        };
118        const SERVER_ENDPOINT_ID: u64 = 0;
119        let mut endpoints = HashMap::new();
120        let mut server = None;
121        if let Some(server_fd) = &server_fd {
122            let s = Endpoint::new(SERVER_ENDPOINT_ID, server_fd);
123            s.idl.acquire();
124            s.idl.acquire();
125            endpoints.insert(
126                SERVER_ENDPOINT_ID,
127                Pollable::Endpoint(EndpointWithClient {
128                    endpoint: s.clone(),
129                    client: None,
130                }),
131            );
132            server = Some(s);
133        }
134        let poller = Poller::new().map_err(StateErrorKind::PollError)?;
135        #[cfg(feature = "logging")]
136        let log_prefix = {
137            use {crate::utils::env::WL_PROXY_PREFIX, isnt::std_1::string::IsntStringExt};
138            let mut log_prefix = String::new();
139            if let Ok(prefix) = var(WL_PROXY_PREFIX) {
140                log_prefix = prefix;
141            }
142            if self.log_prefix.is_not_empty() {
143                if log_prefix.is_not_empty() {
144                    log_prefix.push_str(" ");
145                }
146                log_prefix.push_str(&self.log_prefix);
147            }
148            if log_prefix.is_not_empty() {
149                log_prefix = format!("{{{}}} ", log_prefix);
150            }
151            log_prefix
152        };
153        let state = Rc::new(State {
154            baseline: self.baseline,
155            poller,
156            next_pollable_id: Cell::new(SERVER_ENDPOINT_ID + 1),
157            server,
158            destroyed: Default::default(),
159            handler: Default::default(),
160            pollables: RefCell::new(endpoints),
161            acceptable_acceptors: Default::default(),
162            has_acceptable_acceptors: Default::default(),
163            clients_to_kill: Default::default(),
164            has_clients_to_kill: Default::default(),
165            readable_endpoints: Default::default(),
166            has_readable_endpoints: Default::default(),
167            flushable_endpoints: Default::default(),
168            has_flushable_endpoints: Default::default(),
169            interest_update_endpoints: Default::default(),
170            has_interest_update_endpoints: Default::default(),
171            interest_update_acceptors: Default::default(),
172            has_interest_update_acceptors: Default::default(),
173            all_objects: Default::default(),
174            next_object_id: Default::default(),
175            #[cfg(feature = "logging")]
176            log: self.log,
177            #[cfg(feature = "logging")]
178            log_prefix,
179            #[cfg(feature = "logging")]
180            log_writer: RefCell::new(std::io::BufWriter::with_capacity(
181                1024,
182                uapi::Fd::new(c::STDERR_FILENO),
183            )),
184            global_lock_held: Default::default(),
185            object_stash: Default::default(),
186            forward_to_client: Cell::new(true),
187            forward_to_server: Cell::new(true),
188        });
189        if let Some(server) = &state.server {
190            state.change_interest(server, |i| i | poll::READABLE);
191            state
192                .poller
193                .register(0, server.socket.as_fd())
194                .map_err(StateErrorKind::PollError)?;
195            let display = WlDisplay::new(&state, 1);
196            display
197                .core()
198                .set_server_id_unchecked(1, display.clone())
199                .unwrap();
200        }
201        Ok(state)
202    }
203
204    /// Constructs a state without a server.
205    pub fn without_server(mut self) -> Self {
206        self.server = Some(Server::None);
207        self
208    }
209
210    /// Sets the server file descriptor to connect to.
211    pub fn with_server_fd(mut self, fd: &Rc<OwnedFd>) -> Self {
212        self.server = Some(Server::Fd(fd.clone()));
213        self
214    }
215
216    /// Sets the server display name to connect to.
217    pub fn with_server_display_name(mut self, name: &str) -> Self {
218        self.server = Some(Server::DisplayName(name.to_owned()));
219        self
220    }
221
222    /// Enables or disables logging.
223    ///
224    /// If this function is not used, then logging is enabled if and only if the
225    /// `WL_PROXY_DEBUG` environment variable is set to `1`.
226    pub fn with_logging(mut self, log: bool) -> Self {
227        self.log = log;
228        self
229    }
230
231    /// Sets a log prefix for messages emitted by this state.
232    pub fn with_log_prefix(mut self, prefix: &str) -> Self {
233        self.log_prefix = prefix.to_string();
234        self
235    }
236}