1use {
2 crate::{
3 baseline::Baseline,
4 endpoint::Endpoint,
5 object::{Object, ObjectPrivate},
6 poll::{self, Poller},
7 protocols::wayland::wl_display::WlDisplay,
8 state::{EndpointWithClient, Pollable, State, StateError, StateErrorKind},
9 utils::env::{WAYLAND_DISPLAY, WAYLAND_SOCKET, WL_PROXY_DEBUG, XDG_RUNTIME_DIR},
10 },
11 linearize::Linearize,
12 std::{
13 cell::{Cell, RefCell},
14 collections::HashMap,
15 env::{remove_var, var, var_os},
16 os::{
17 fd::{AsFd, FromRawFd, OwnedFd},
18 unix::ffi::OsStrExt,
19 },
20 rc::Rc,
21 str::FromStr,
22 },
23 uapi::c::{self, sockaddr_un},
24};
25
26pub struct StateBuilder {
30 baseline: Baseline,
31 server: Option<Server>,
32 log: bool,
33 log_prefix: String,
34}
35
/// An explicitly configured choice of upstream server for [`StateBuilder`].
enum Server {
    /// Connect a new socket to this display name: either an absolute path or a name
    /// resolved relative to the runtime directory.
    DisplayName(String),
    /// Use this file descriptor as the server connection.
    Fd(Rc<OwnedFd>),
    /// Build the state without any server connection.
    None,
}
41
/// Pollable ids that are reserved statically rather than allocated at runtime.
///
/// Each variant is converted with `as u64` and used as a key in the state's pollable
/// map, so the variant order fixes the numeric ids and must not change casually.
/// Runtime-allocated pollable ids start at `StaticPollableIds::LENGTH` (provided by the
/// `Linearize` derive).
#[derive(Copy, Clone, Linearize)]
pub(crate) enum StaticPollableIds {
    /// The connection to the upstream server.
    Server,
    /// The eventfd stored as `unsuspend_fd` in the state, registered edge-triggered for
    /// readability.
    Unsuspend,
}
47
48impl StateBuilder {
49 pub(super) fn new(baseline: Baseline) -> Self {
50 Self {
51 baseline,
52 server: Default::default(),
53 log: var(WL_PROXY_DEBUG).as_deref() == Ok("1"),
54 log_prefix: Default::default(),
55 }
56 }
57
58 pub fn build(self) -> Result<Rc<State>, StateError> {
69 let server_fd = 'fd: {
70 let display_name = match self.server {
71 None => None,
72 Some(Server::None) => break 'fd None,
73 Some(Server::Fd(fd)) => break 'fd Some(fd),
74 Some(Server::DisplayName(n)) => Some(n),
75 };
76 if display_name.is_none()
77 && let Some(wayland_socket) = var_os(WAYLAND_SOCKET)
78 {
79 let fd = str::from_utf8(wayland_socket.as_bytes())
80 .ok()
81 .and_then(|s| i32::from_str(s).ok())
82 .ok_or(StateErrorKind::WaylandSocketNotNumber)?;
83 let flags = uapi::fcntl_getfd(fd)
84 .map_err(|e| StateErrorKind::WaylandSocketGetFd(e.into()))?;
85 uapi::fcntl_setfd(fd, flags | c::FD_CLOEXEC)
86 .map_err(|e| StateErrorKind::WaylandSocketSetFd(e.into()))?;
87 let fd = unsafe {
89 remove_var(WAYLAND_SOCKET);
90 Rc::new(OwnedFd::from_raw_fd(fd))
91 };
92 break 'fd Some(fd);
93 }
94 let mut name = match display_name {
95 Some(n) => n,
96 _ => var(WAYLAND_DISPLAY)
97 .ok()
98 .ok_or(StateErrorKind::WaylandDisplay)?,
99 };
100 if name.is_empty() {
101 return Err(StateErrorKind::WaylandDisplayEmpty.into());
102 }
103 if !name.starts_with("/") {
104 let Ok(xrd) = var(XDG_RUNTIME_DIR) else {
105 return Err(StateErrorKind::XrdNotSet.into());
106 };
107 name = format!("{xrd}/{name}");
108 }
109 let mut addr = sockaddr_un {
110 sun_family: c::AF_UNIX as _,
111 sun_path: [0; 108],
112 };
113 if name.len() > addr.sun_path.len() - 1 {
114 return Err(StateErrorKind::SocketPathTooLong.into());
115 }
116 let sun_path = uapi::as_bytes_mut(&mut addr.sun_path[..]);
117 sun_path[..name.len()].copy_from_slice(name.as_bytes());
118 sun_path[name.len()] = 0;
119 let socket = uapi::socket(c::AF_UNIX, c::SOCK_STREAM | c::SOCK_CLOEXEC, 0)
120 .map_err(|e| StateErrorKind::CreateSocket(e.into()))?;
121 uapi::connect(socket.raw(), &addr)
122 .map_err(|e| StateErrorKind::Connect(name.to_string(), e.into()))?;
123 Some(Rc::new(socket.into()))
124 };
125 let mut endpoints = HashMap::new();
126 let mut server = None;
127 if let Some(server_fd) = &server_fd {
128 let s = Endpoint::new(StaticPollableIds::Server as u64, server_fd);
129 s.idl.acquire();
130 s.idl.acquire();
131 endpoints.insert(
132 StaticPollableIds::Server as u64,
133 Pollable::Endpoint(EndpointWithClient {
134 endpoint: s.clone(),
135 client: None,
136 }),
137 );
138 server = Some(s);
139 }
140 let unsuspend_fd = uapi::eventfd(0, c::EFD_CLOEXEC | c::EFD_NONBLOCK)
141 .map(Into::into)
142 .map_err(|e| StateErrorKind::CreateEventfd(e.into()))?;
143 endpoints.insert(StaticPollableIds::Unsuspend as u64, Pollable::Unsuspend);
144 let poller = Poller::new().map_err(StateErrorKind::PollError)?;
145 #[cfg(feature = "logging")]
146 let log_prefix = {
147 use {crate::utils::env::WL_PROXY_PREFIX, isnt::std_1::string::IsntStringExt};
148 let mut log_prefix = String::new();
149 if let Ok(prefix) = var(WL_PROXY_PREFIX) {
150 log_prefix = prefix;
151 }
152 if self.log_prefix.is_not_empty() {
153 if log_prefix.is_not_empty() {
154 log_prefix.push_str(" ");
155 }
156 log_prefix.push_str(&self.log_prefix);
157 }
158 if log_prefix.is_not_empty() {
159 log_prefix = format!("{{{}}} ", log_prefix);
160 }
161 log_prefix
162 };
163 let state = Rc::new(State {
164 baseline: self.baseline,
165 poller,
166 next_pollable_id: Cell::new(StaticPollableIds::LENGTH as u64),
167 server,
168 destroyed: Default::default(),
169 handler: Default::default(),
170 pollables: RefCell::new(endpoints),
171 acceptable_acceptors: Default::default(),
172 has_acceptable_acceptors: Default::default(),
173 clients_to_kill: Default::default(),
174 has_clients_to_kill: Default::default(),
175 readable_endpoints: Default::default(),
176 has_readable_endpoints: Default::default(),
177 flushable_endpoints: Default::default(),
178 has_flushable_endpoints: Default::default(),
179 interest_update_endpoints: Default::default(),
180 has_interest_update_endpoints: Default::default(),
181 interest_update_acceptors: Default::default(),
182 has_interest_update_acceptors: Default::default(),
183 all_objects: Default::default(),
184 next_object_id: Cell::new(1),
185 #[cfg(feature = "logging")]
186 log: self.log,
187 #[cfg(feature = "logging")]
188 log_prefix,
189 #[cfg(feature = "logging")]
190 log_writer: RefCell::new(std::io::BufWriter::with_capacity(
191 1024,
192 uapi::Fd::new(c::STDERR_FILENO),
193 )),
194 global_lock_held: Default::default(),
195 object_stash: Default::default(),
196 forward_to_client: Cell::new(true),
197 forward_to_server: Cell::new(true),
198 unsuspend_fd,
199 unsuspend_requests: Default::default(),
200 has_unsuspend_requests: Default::default(),
201 unsuspend_triggered: Default::default(),
202 });
203 if let Some(server) = &state.server {
204 state.change_interest(server, |i| i | poll::READABLE);
205 state
206 .poller
207 .register(server.id, server.socket.as_fd())
208 .map_err(StateErrorKind::PollError)?;
209 let display = WlDisplay::new(&state, 1);
210 display
211 .core()
212 .set_server_id_unchecked(1, display.clone())
213 .unwrap();
214 }
215 state
216 .poller
217 .register_edge_triggered(
218 StaticPollableIds::Unsuspend as u64,
219 state.unsuspend_fd.as_fd(),
220 poll::READABLE,
221 )
222 .map_err(StateErrorKind::PollError)?;
223 Ok(state)
224 }
225
226 pub fn without_server(mut self) -> Self {
228 self.server = Some(Server::None);
229 self
230 }
231
232 pub fn with_server_fd(mut self, fd: &Rc<OwnedFd>) -> Self {
234 self.server = Some(Server::Fd(fd.clone()));
235 self
236 }
237
238 pub fn with_server_display_name(mut self, name: &str) -> Self {
240 self.server = Some(Server::DisplayName(name.to_owned()));
241 self
242 }
243
244 pub fn with_logging(mut self, log: bool) -> Self {
249 self.log = log;
250 self
251 }
252
253 pub fn with_log_prefix(mut self, prefix: &str) -> Self {
255 self.log_prefix = prefix.to_string();
256 self
257 }
258}