1use {
2 crate::{
3 baseline::Baseline,
4 endpoint::Endpoint,
5 object::{Object, ObjectPrivate},
6 poll::{self, Poller},
7 protocols::wayland::wl_display::WlDisplay,
8 state::{EndpointWithClient, Pollable, State, StateError, StateErrorKind},
9 utils::env::{WAYLAND_DISPLAY, WAYLAND_SOCKET, WL_PROXY_DEBUG, XDG_RUNTIME_DIR},
10 },
11 std::{
12 cell::{Cell, RefCell},
13 collections::HashMap,
14 env::{remove_var, var, var_os},
15 os::{
16 fd::{AsFd, FromRawFd, OwnedFd},
17 unix::ffi::OsStrExt,
18 },
19 rc::Rc,
20 str::FromStr,
21 },
22 uapi::c::{self, sockaddr_un},
23};
24
25pub struct StateBuilder {
29 baseline: Baseline,
30 server: Option<Server>,
31 log: bool,
32 log_prefix: String,
33}
34
/// How the builder should obtain the connection to the upstream server.
enum Server {
    /// Do not connect to any server; the state is built without a server
    /// endpoint.
    None,
    /// Use this already-open file descriptor as the server connection.
    Fd(Rc<OwnedFd>),
    /// Connect to the display socket with this name. Names that do not start
    /// with `/` are resolved relative to the runtime directory.
    DisplayName(String),
}
40
41impl StateBuilder {
42 pub(super) fn new(baseline: Baseline) -> Self {
43 Self {
44 baseline,
45 server: Default::default(),
46 log: var(WL_PROXY_DEBUG).as_deref() == Ok("1"),
47 log_prefix: Default::default(),
48 }
49 }
50
51 pub fn build(self) -> Result<Rc<State>, StateError> {
62 let server_fd = 'fd: {
63 let display_name = match self.server {
64 None => None,
65 Some(Server::None) => break 'fd None,
66 Some(Server::Fd(fd)) => break 'fd Some(fd),
67 Some(Server::DisplayName(n)) => Some(n),
68 };
69 if display_name.is_none()
70 && let Some(wayland_socket) = var_os(WAYLAND_SOCKET)
71 {
72 let fd = str::from_utf8(wayland_socket.as_bytes())
73 .ok()
74 .and_then(|s| i32::from_str(s).ok())
75 .ok_or(StateErrorKind::WaylandSocketNotNumber)?;
76 let flags = uapi::fcntl_getfd(fd)
77 .map_err(|e| StateErrorKind::WaylandSocketGetFd(e.into()))?;
78 uapi::fcntl_setfd(fd, flags | c::FD_CLOEXEC)
79 .map_err(|e| StateErrorKind::WaylandSocketSetFd(e.into()))?;
80 let fd = unsafe {
82 remove_var(WAYLAND_SOCKET);
83 Rc::new(OwnedFd::from_raw_fd(fd))
84 };
85 break 'fd Some(fd);
86 }
87 let mut name = match display_name {
88 Some(n) => n,
89 _ => var(WAYLAND_DISPLAY)
90 .ok()
91 .ok_or(StateErrorKind::WaylandDisplay)?,
92 };
93 if name.is_empty() {
94 return Err(StateErrorKind::WaylandDisplayEmpty.into());
95 }
96 if !name.starts_with("/") {
97 let Ok(xrd) = var(XDG_RUNTIME_DIR) else {
98 return Err(StateErrorKind::XrdNotSet.into());
99 };
100 name = format!("{xrd}/{name}");
101 }
102 let mut addr = sockaddr_un {
103 sun_family: c::AF_UNIX as _,
104 sun_path: [0; 108],
105 };
106 if name.len() > addr.sun_path.len() - 1 {
107 return Err(StateErrorKind::SocketPathTooLong.into());
108 }
109 let sun_path = uapi::as_bytes_mut(&mut addr.sun_path[..]);
110 sun_path[..name.len()].copy_from_slice(name.as_bytes());
111 sun_path[name.len()] = 0;
112 let socket = uapi::socket(c::AF_UNIX, c::SOCK_STREAM | c::SOCK_CLOEXEC, 0)
113 .map_err(|e| StateErrorKind::CreateSocket(e.into()))?;
114 uapi::connect(socket.raw(), &addr)
115 .map_err(|e| StateErrorKind::Connect(name.to_string(), e.into()))?;
116 Some(Rc::new(socket.into()))
117 };
118 const SERVER_ENDPOINT_ID: u64 = 0;
119 let mut endpoints = HashMap::new();
120 let mut server = None;
121 if let Some(server_fd) = &server_fd {
122 let s = Endpoint::new(SERVER_ENDPOINT_ID, server_fd);
123 s.idl.acquire();
124 s.idl.acquire();
125 endpoints.insert(
126 SERVER_ENDPOINT_ID,
127 Pollable::Endpoint(EndpointWithClient {
128 endpoint: s.clone(),
129 client: None,
130 }),
131 );
132 server = Some(s);
133 }
134 let poller = Poller::new().map_err(StateErrorKind::PollError)?;
135 #[cfg(feature = "logging")]
136 let log_prefix = {
137 use {crate::utils::env::WL_PROXY_PREFIX, isnt::std_1::string::IsntStringExt};
138 let mut log_prefix = String::new();
139 if let Ok(prefix) = var(WL_PROXY_PREFIX) {
140 log_prefix = prefix;
141 }
142 if self.log_prefix.is_not_empty() {
143 if log_prefix.is_not_empty() {
144 log_prefix.push_str(" ");
145 }
146 log_prefix.push_str(&self.log_prefix);
147 }
148 if log_prefix.is_not_empty() {
149 log_prefix = format!("{{{}}} ", log_prefix);
150 }
151 log_prefix
152 };
153 let state = Rc::new(State {
154 baseline: self.baseline,
155 poller,
156 next_pollable_id: Cell::new(SERVER_ENDPOINT_ID + 1),
157 server,
158 destroyed: Default::default(),
159 handler: Default::default(),
160 pollables: RefCell::new(endpoints),
161 acceptable_acceptors: Default::default(),
162 has_acceptable_acceptors: Default::default(),
163 clients_to_kill: Default::default(),
164 has_clients_to_kill: Default::default(),
165 readable_endpoints: Default::default(),
166 has_readable_endpoints: Default::default(),
167 flushable_endpoints: Default::default(),
168 has_flushable_endpoints: Default::default(),
169 interest_update_endpoints: Default::default(),
170 has_interest_update_endpoints: Default::default(),
171 interest_update_acceptors: Default::default(),
172 has_interest_update_acceptors: Default::default(),
173 all_objects: Default::default(),
174 next_object_id: Default::default(),
175 #[cfg(feature = "logging")]
176 log: self.log,
177 #[cfg(feature = "logging")]
178 log_prefix,
179 #[cfg(feature = "logging")]
180 log_writer: RefCell::new(std::io::BufWriter::with_capacity(
181 1024,
182 uapi::Fd::new(c::STDERR_FILENO),
183 )),
184 global_lock_held: Default::default(),
185 object_stash: Default::default(),
186 forward_to_client: Cell::new(true),
187 forward_to_server: Cell::new(true),
188 });
189 if let Some(server) = &state.server {
190 state.change_interest(server, |i| i | poll::READABLE);
191 state
192 .poller
193 .register(0, server.socket.as_fd())
194 .map_err(StateErrorKind::PollError)?;
195 let display = WlDisplay::new(&state, 1);
196 display
197 .core()
198 .set_server_id_unchecked(1, display.clone())
199 .unwrap();
200 }
201 Ok(state)
202 }
203
204 pub fn without_server(mut self) -> Self {
206 self.server = Some(Server::None);
207 self
208 }
209
210 pub fn with_server_fd(mut self, fd: &Rc<OwnedFd>) -> Self {
212 self.server = Some(Server::Fd(fd.clone()));
213 self
214 }
215
216 pub fn with_server_display_name(mut self, name: &str) -> Self {
218 self.server = Some(Server::DisplayName(name.to_owned()));
219 self
220 }
221
222 pub fn with_logging(mut self, log: bool) -> Self {
227 self.log = log;
228 self
229 }
230
231 pub fn with_log_prefix(mut self, prefix: &str) -> Self {
233 self.log_prefix = prefix.to_string();
234 self
235 }
236}