1use std::collections::VecDeque;
4use std::env;
5use std::fmt;
6use std::io;
7use std::num::NonZeroU32;
8use std::os::fd::{AsRawFd, FromRawFd, RawFd};
9use std::os::unix::net::UnixStream;
10use std::path::PathBuf;
11
12use crate::debug_message::DebugMessage;
13use crate::global::BindError;
14use crate::global::GlobalExt;
15use crate::global::VersionBounds;
16use crate::object::{Object, ObjectManager, Proxy};
17use crate::protocol::wl_registry::GlobalArgs;
18use crate::protocol::*;
19use crate::EventCtx;
20
21use wayrs_core::transport::{BufferedSocket, PeekHeaderError, RecvMessageError, SendMessageError};
22use wayrs_core::{ArgType, ArgValue, Interface, IoMode, Message, MessageBuffersPool, ObjectId};
23
24#[cfg(feature = "tokio")]
25use tokio::io::unix::AsyncFd;
26
/// Error returned by [`Connection::connect`].
#[derive(Debug)]
pub enum ConnectError {
    /// `$WAYLAND_SOCKET` was not usable and at least one of `$XDG_RUNTIME_DIR` /
    /// `$WAYLAND_DISPLAY` is unset.
    NotEnoughEnvVars,
    /// An I/O error occurred while establishing or using the connection.
    Io(io::Error),
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The I/O variant is displayed exactly like the wrapped error.
        if let ConnectError::Io(inner) = self {
            return fmt::Display::fmt(inner, f);
        }
        f.write_str("both $XDG_RUNTIME_DIR and $WAYLAND_DISPLAY must be set")
    }
}

impl std::error::Error for ConnectError {}

impl From<io::Error> for ConnectError {
    fn from(err: io::Error) -> Self {
        ConnectError::Io(err)
    }
}
54
55pub struct Connection<D> {
62 #[cfg(feature = "tokio")]
63 async_fd: Option<AsyncFd<RawFd>>,
64
65 socket: BufferedSocket<UnixStream>,
66 msg_buffers_pool: MessageBuffersPool,
67
68 object_mgr: ObjectManager<D>,
69
70 event_queue: VecDeque<QueuedEvent>,
71 requests_queue: VecDeque<Message>,
72 break_dispatch: bool,
73
74 registry: WlRegistry,
75 globals: Vec<GlobalArgs>,
76
77 registry_cbs: Option<Vec<RegistryCb<D>>>,
79
80 debug: bool,
81}
82
83enum QueuedEvent {
84 DeleteId(ObjectId),
85 RegistryEvent(wl_registry::Event),
86 Message(Message),
87}
88
89pub(crate) type GenericCallback<D> =
90 Box<dyn FnMut(&mut Connection<D>, &mut D, Object, Message) + Send>;
91
92type RegistryCb<D> = Box<dyn FnMut(&mut Connection<D>, &mut D, &wl_registry::Event) + Send>;
93
94impl<D> AsRawFd for Connection<D> {
95 fn as_raw_fd(&self) -> RawFd {
96 self.socket.as_raw_fd()
97 }
98}
99
100impl<D> Connection<D> {
101 pub fn connect() -> Result<Self, ConnectError> {
118 if let Some(fd) = env::var("WAYLAND_SOCKET")
119 .ok()
120 .and_then(|fd| fd.parse::<RawFd>().ok())
121 {
122 let stream = unsafe { UnixStream::from_raw_fd(fd) };
123 return Ok(Self::connect_with_unix_stream(stream));
124 }
125
126 let runtime_dir = env::var_os("XDG_RUNTIME_DIR").ok_or(ConnectError::NotEnoughEnvVars)?;
127 let wayland_disp = env::var_os("WAYLAND_DISPLAY").ok_or(ConnectError::NotEnoughEnvVars)?;
128
129 let mut path = PathBuf::new();
130 path.push(runtime_dir);
131 path.push(wayland_disp);
132
133 Ok(Self::connect_with_unix_stream(UnixStream::connect(path)?))
134 }
135
136 fn connect_with_unix_stream(stream: UnixStream) -> Self {
137 let mut this = Self {
138 #[cfg(feature = "tokio")]
139 async_fd: None,
140
141 socket: BufferedSocket::from(stream),
142 msg_buffers_pool: MessageBuffersPool::default(),
143
144 object_mgr: ObjectManager::new(),
145
146 event_queue: VecDeque::with_capacity(32),
147 requests_queue: VecDeque::with_capacity(32),
148 break_dispatch: false,
149
150 registry: WlRegistry::new(ObjectId::MAX_CLIENT, 1), globals: Vec::new(),
152 registry_cbs: Some(Vec::new()),
153
154 debug: std::env::var_os("WAYLAND_DEBUG").is_some(),
155 };
156
157 this.registry = WlDisplay::INSTANCE.get_registry(&mut this);
158
159 this
160 }
161
162 #[deprecated = "use blocking_roundtrip() + bind_singleton() instead"]
166 pub fn connect_and_collect_globals() -> Result<(Self, Vec<GlobalArgs>), ConnectError> {
167 let mut this = Self::connect()?;
168 this.blocking_roundtrip()?;
169 let globals = this.globals.clone();
170 this.event_queue.clear();
171 Ok((this, globals))
172 }
173
174 #[cfg(feature = "tokio")]
176 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
177 #[deprecated = "use async_roundtrip() + bind_singleton() instead"]
178 pub async fn async_connect_and_collect_globals() -> Result<(Self, Vec<GlobalArgs>), ConnectError>
179 {
180 let mut this = Self::connect()?;
181 this.async_roundtrip().await?;
182 let globals = this.globals.clone();
183 this.event_queue.clear();
184 Ok((this, globals))
185 }
186
187 #[must_use]
192 pub fn registry(&self) -> WlRegistry {
193 self.registry
194 }
195
196 #[must_use]
203 pub fn globals(&self) -> &[GlobalArgs] {
204 &self.globals
205 }
206
207 pub fn bind_singleton<P: Proxy>(
220 &mut self,
221 version: impl VersionBounds,
222 ) -> Result<P, BindError> {
223 assert!(version.upper() <= P::INTERFACE.version);
224
225 let i = self
226 .globals
227 .iter()
228 .position(|g| g.is::<P>())
229 .ok_or(BindError::GlobalNotFound(P::INTERFACE.name))?;
230
231 if self.globals[i].version < version.lower() {
232 return Err(BindError::UnsupportedVersion {
233 actual: self.globals[i].version,
234 min: version.lower(),
235 });
236 }
237
238 let name = self.globals[i].name;
239 let version = u32::min(version.upper(), self.globals[i].version);
240 Ok(self.registry.bind(self, name, version))
241 }
242
243 pub fn bind_singleton_with_cb<P: Proxy, F: FnMut(EventCtx<D, P>) + Send + 'static>(
245 &mut self,
246 version: impl VersionBounds,
247 cb: F,
248 ) -> Result<P, BindError> {
249 assert!(version.upper() <= P::INTERFACE.version);
250
251 let i = self
252 .globals
253 .iter()
254 .position(|g| g.is::<P>())
255 .ok_or(BindError::GlobalNotFound(P::INTERFACE.name))?;
256
257 if self.globals[i].version < version.lower() {
258 return Err(BindError::UnsupportedVersion {
259 actual: self.globals[i].version,
260 min: version.lower(),
261 });
262 }
263
264 let name = self.globals[i].name;
265 let version = u32::min(version.upper(), self.globals[i].version);
266 Ok(self.registry.bind_with_cb(self, name, version, cb))
267 }
268
269 pub fn add_registry_cb<
278 F: FnMut(&mut Connection<D>, &mut D, &wl_registry::Event) + Send + 'static,
279 >(
280 &mut self,
281 cb: F,
282 ) {
283 self.registry_cbs
284 .as_mut()
285 .expect("add_registry_cb called from registry callback")
286 .push(Box::new(cb));
287 }
288
289 pub fn set_callback_for<P: Proxy, F: FnMut(EventCtx<D, P>) + Send + 'static>(
302 &mut self,
303 proxy: P,
304 cb: F,
305 ) {
306 assert_ne!(
307 P::INTERFACE,
308 WlRegistry::INTERFACE,
309 "attempt to set a callback for wl_registry"
310 );
311
312 let obj = self
313 .object_mgr
314 .get_object_mut(proxy.id())
315 .expect("attempt to set a callback for non-existing object");
316
317 assert_eq!(obj.object, proxy.id(), "object mismatch");
318 assert!(obj.is_alive, "attempt to set a callback for dead object");
319
320 obj.cb = Some(Self::make_generic_cb(cb));
321 }
322
323 #[must_use]
327 #[deprecated = "this function is error-prone and best avoided"]
328 pub fn clear_callbacks<D2>(self) -> Connection<D2> {
329 Connection {
330 #[cfg(feature = "tokio")]
331 async_fd: self.async_fd,
332 socket: self.socket,
333 msg_buffers_pool: self.msg_buffers_pool,
334 object_mgr: self.object_mgr.clear_callbacks(),
335 event_queue: self.event_queue,
336 requests_queue: self.requests_queue,
337 break_dispatch: self.break_dispatch,
338 registry: self.registry,
339 globals: self.globals,
340 registry_cbs: Some(Vec::new()),
341 debug: self.debug,
342 }
343 }
344
345 pub fn blocking_roundtrip(&mut self) -> io::Result<()> {
350 let sync_cb = WlDisplay::INSTANCE.sync(self);
351 self.flush(IoMode::Blocking)?;
352
353 loop {
354 match self.recv_event(IoMode::Blocking)? {
355 QueuedEvent::Message(m) if m.header.object_id == sync_cb => break,
356 other => self.event_queue.push_back(other),
357 }
358 }
359
360 Ok(())
361 }
362
363 #[cfg(feature = "tokio")]
365 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
366 pub async fn async_roundtrip(&mut self) -> io::Result<()> {
367 let sync_cb = WlDisplay::INSTANCE.sync(self);
368 self.async_flush().await?;
369
370 loop {
371 match self.async_recv_event().await? {
372 QueuedEvent::Message(m) if m.header.object_id == sync_cb => break,
373 other => self.event_queue.push_back(other),
374 }
375 }
376
377 Ok(())
378 }
379
380 #[doc(hidden)]
381 pub fn alloc_msg_args(&mut self) -> Vec<ArgValue> {
382 self.msg_buffers_pool.get_args()
383 }
384
385 #[doc(hidden)]
386 pub fn send_request(&mut self, iface: &'static Interface, request: Message) {
387 let obj = self
388 .object_mgr
389 .get_object_mut(request.header.object_id)
390 .expect("attempt to send request for non-existing object");
391 assert!(obj.is_alive, "attempt to send request for dead object");
392
393 if self.debug {
394 eprintln!(
395 "[wayrs] -> {:?}",
396 DebugMessage::new(&request, false, obj.object)
397 );
398 }
399
400 if iface.requests[request.header.opcode as usize].is_destructor {
402 obj.is_alive = false;
403 }
404
405 self.requests_queue.push_back(request);
407 }
408
409 fn recv_event(&mut self, mode: IoMode) -> io::Result<QueuedEvent> {
410 let header = self
411 .socket
412 .peek_message_header(mode)
413 .map_err(|err| match err {
414 PeekHeaderError::Io(io) => io,
415 other => io::Error::new(io::ErrorKind::InvalidData, other),
416 })?;
417
418 let obj = self
419 .object_mgr
420 .get_object_mut(header.object_id)
421 .expect("received event for non-existing object");
422 let object = obj.object;
423 let signature = object
424 .interface
425 .events
426 .get(header.opcode as usize)
427 .expect("incorrect opcode")
428 .signature;
429
430 let event = self
431 .socket
432 .recv_message(header, signature, &mut self.msg_buffers_pool, mode)
433 .map_err(|err| match err {
434 RecvMessageError::Io(io) => io,
435 other => io::Error::new(io::ErrorKind::InvalidData, other),
436 })?;
437 if self.debug {
438 eprintln!("[wayrs] {:?}", DebugMessage::new(&event, true, object));
439 }
440
441 if event.header.object_id == ObjectId::DISPLAY {
442 match WlDisplay::parse_event(event, 1, &mut self.msg_buffers_pool).unwrap() {
443 wl_display::Event::Error(err) => {
444 return Err(io::Error::other(format!(
446 "Error in object {} (code({})): {}",
447 err.object_id.0,
448 err.code,
449 err.message.to_string_lossy(),
450 )));
451 }
452 wl_display::Event::DeleteId(id) => {
453 return Ok(QueuedEvent::DeleteId(ObjectId(
454 NonZeroU32::new(id).ok_or_else(|| {
455 io::Error::new(
456 io::ErrorKind::InvalidData,
457 "wl_display.delete_id with null id",
458 )
459 })?,
460 )));
461 }
462 };
463 }
464
465 if event.header.object_id == self.registry {
466 let event = WlRegistry::parse_event(event, 1, &mut self.msg_buffers_pool).unwrap();
467 match &event {
468 wl_registry::Event::Global(global) => {
469 self.globals.push(global.clone());
470 }
471 wl_registry::Event::GlobalRemove(name) => {
472 if let Some(i) = self.globals.iter().position(|g| g.name == *name) {
473 self.globals.swap_remove(i);
474 }
475 }
476 }
477 return Ok(QueuedEvent::RegistryEvent(event));
478 }
479
480 let signature = object
482 .interface
483 .events
484 .get(header.opcode as usize)
485 .expect("incorrect opcode")
486 .signature;
487 for (arg, arg_ty) in event.args.iter().zip(signature) {
488 match arg {
489 ArgValue::NewId(id) => {
490 let ArgType::NewId(interface) = arg_ty else {
491 unreachable!()
492 };
493 self.object_mgr.register_server_object(Object {
494 id: *id,
495 interface,
496 version: object.version,
497 });
498 }
499 ArgValue::AnyNewId(_, _, _) => unimplemented!(),
500 _ => (),
501 }
502 }
503
504 Ok(QueuedEvent::Message(event))
505 }
506
507 #[cfg(feature = "tokio")]
508 async fn async_recv_event(&mut self) -> io::Result<QueuedEvent> {
509 let mut async_fd = match self.async_fd.take() {
510 Some(fd) => fd,
511 None => AsyncFd::new(self.as_raw_fd())?,
512 };
513
514 loop {
515 let mut fd_guard = async_fd.readable_mut().await?;
516 match self.recv_event(IoMode::NonBlocking) {
517 Err(e) if e.kind() == io::ErrorKind::WouldBlock => fd_guard.clear_ready(),
518 result => {
519 self.async_fd = Some(async_fd);
520 return result;
521 }
522 }
523 }
524 }
525
526 pub fn recv_events(&mut self, mut mode: IoMode) -> io::Result<()> {
537 let mut at_least_one = false;
538
539 loop {
540 let msg = match self.recv_event(mode) {
541 Ok(msg) => msg,
542 Err(e) if e.kind() == io::ErrorKind::WouldBlock && at_least_one => return Ok(()),
543 Err(e) => return Err(e),
544 };
545
546 at_least_one = true;
547 mode = IoMode::NonBlocking;
548 self.event_queue.push_back(msg);
549 }
550 }
551
552 #[cfg(feature = "tokio")]
554 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
555 pub async fn async_recv_events(&mut self) -> io::Result<()> {
556 let msg = self.async_recv_event().await?;
557 self.event_queue.push_back(msg);
558
559 loop {
560 match self.recv_event(IoMode::NonBlocking) {
561 Ok(msg) => self.event_queue.push_back(msg),
562 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(()),
563 Err(e) => return Err(e),
564 };
565 }
566 }
567
568 pub fn flush(&mut self, mode: IoMode) -> io::Result<()> {
570 while let Some(msg) = self.requests_queue.pop_front() {
572 if let Err(SendMessageError { msg, err }) =
573 self.socket
574 .write_message(msg, &mut self.msg_buffers_pool, mode)
575 {
576 self.requests_queue.push_front(msg);
577 return Err(err);
578 }
579 }
580
581 self.socket.flush(mode)
583 }
584
585 #[cfg(feature = "tokio")]
587 #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
588 pub async fn async_flush(&mut self) -> io::Result<()> {
589 match self.flush(IoMode::NonBlocking) {
591 Err(e) if e.kind() == io::ErrorKind::WouldBlock => (),
592 result => return result,
593 }
594
595 let mut async_fd = match self.async_fd.take() {
596 Some(fd) => fd,
597 None => AsyncFd::new(self.as_raw_fd())?,
598 };
599
600 loop {
601 let mut fd_guard = async_fd.writable_mut().await?;
602 match self.flush(IoMode::NonBlocking) {
603 Err(e) if e.kind() == io::ErrorKind::WouldBlock => fd_guard.clear_ready(),
604 result => {
605 self.async_fd = Some(async_fd);
606 return result;
607 }
608 }
609 }
610 }
611
612 pub fn dispatch_events(&mut self, state: &mut D) {
618 self.break_dispatch = false;
619
620 while let Some(event) = self.event_queue.pop_front() {
621 match event {
622 QueuedEvent::DeleteId(id) => self.object_mgr.delete_client_object(id),
623 QueuedEvent::RegistryEvent(event) => {
624 let mut registry_cbs = self
625 .registry_cbs
626 .take()
627 .expect("dispatch_events called from registry callback");
628
629 for cb in &mut registry_cbs {
630 cb(self, state, &event);
631 }
632
633 self.registry_cbs = Some(registry_cbs);
634
635 if self.break_dispatch {
636 break;
637 }
638 }
639 QueuedEvent::Message(event) => {
640 let object = match self.object_mgr.get_object_mut(event.header.object_id) {
641 Some(obj) if obj.is_alive => obj,
642 _ => continue, };
644
645 let mut object_cb = object.cb.take();
647 let object = object.object;
648 let opcode = event.header.opcode;
649
650 if let Some(cb) = &mut object_cb {
651 cb(self, state, object, event);
652 }
653
654 let object = self.object_mgr.get_object_mut(object.id).unwrap();
655
656 if object.object.interface.events[opcode as usize].is_destructor {
658 object.is_alive = false;
659 }
660
661 if object.is_alive && object.cb.is_none() {
663 object.cb = object_cb;
664 }
665
666 if self.break_dispatch {
667 break;
668 }
669 }
670 }
671 }
672 }
673
674 pub fn break_dispatch_loop(&mut self) {
679 self.break_dispatch = true;
680 }
681
682 #[doc(hidden)]
684 pub fn allocate_new_object<P: Proxy>(&mut self, version: u32) -> P {
685 let id = self
686 .object_mgr
687 .alloc_client_object(P::INTERFACE, version)
688 .object
689 .id;
690 P::new(id, version)
691 }
692
693 #[doc(hidden)]
696 pub fn allocate_new_object_with_cb<P: Proxy, F: FnMut(EventCtx<D, P>) + Send + 'static>(
697 &mut self,
698 version: u32,
699 cb: F,
700 ) -> P {
701 let state = self.object_mgr.alloc_client_object(P::INTERFACE, version);
702 state.cb = Some(Self::make_generic_cb(cb));
703 P::new(state.object.id, version)
704 }
705
706 fn make_generic_cb<P: Proxy, F: FnMut(EventCtx<D, P>) + Send + 'static>(
707 mut cb: F,
708 ) -> GenericCallback<D> {
709 Box::new(move |conn, state, object, event| {
711 let proxy: P = object.try_into().unwrap();
712 let event = P::parse_event(event, object.version, &mut conn.msg_buffers_pool).unwrap();
713 let ctx = EventCtx {
714 conn,
715 state,
716 proxy,
717 event,
718 };
719 cb(ctx);
720 })
721 }
722}
723
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T: Send`.
    fn require_send<T: Send>() {}

    /// `Connection` must stay `Send`; every stored callback type is boxed with `+ Send`.
    #[test]
    fn send() {
        require_send::<Connection<()>>();
    }
}