xcb_rust_connection/
connection.rs

1use alloc::collections::VecDeque;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4use alloc::{format, vec};
5use core::time::Duration;
6use rusl::error::Errno;
7use rusl::platform::{PollEvents, PollFd, TimeSpec};
8use rusl::string::unix_str::{UnixStr, UnixString};
9
10use smallmap::{Map, Set};
11use tiny_std::fs::read;
12use tiny_std::io::{Read, Write};
13use tiny_std::net::UnixStream;
14use tiny_std::time::MonotonicInstant;
15use tiny_std::unix::fd::{AsRawFd, OwnedFd, RawFd};
16
17use xcb_rust_protocol::con::{SocketIo, XcbBuffer, XcbBuffers, XcbState};
18use xcb_rust_protocol::connection::bigreq::enable;
19use xcb_rust_protocol::connection::xproto::{
20    change_property, get_input_focus, list_extensions, query_extension,
21};
22use xcb_rust_protocol::cookie::VoidCookie;
23use xcb_rust_protocol::proto::find_extension;
24use xcb_rust_protocol::proto::xc_misc::GetXIDRangeReply;
25use xcb_rust_protocol::proto::xproto::{Atom, ListExtensionsReply, PropModeEnum, Setup, Window};
26use xcb_rust_protocol::util::{
27    parse_error, ExtensionInfoProvider, ExtensionInformation, VariableLengthFromBytes,
28    XcbErrorHandler,
29};
30use xcb_rust_protocol::{Error, XcbConnection, XcbEnv};
31
32use crate::helpers::basic_info_provider::BasicExtensionInfoProvider;
33use crate::helpers::connect::{get_setup_length, parse_setup, setup_request, Connect};
34use crate::helpers::id_allocator::IdAllocator;
35use crate::helpers::parse_display::ParsedDisplay;
36use crate::helpers::xauth::Family;
37use crate::{ConnectError, ConnectionError};
38
39#[derive(Debug)]
40pub struct XcbEventState {
41    setup: Setup,
42    seq_count: SeqCount,
43    event_cache: VecDeque<Vec<u8>>,
44    reply_cache: Map<u16, Vec<u8>>,
45    keep_seqs: Set<u16>,
46    id_allocator: IdAllocator,
47    max_request_length: usize,
48    pub extensions: BasicExtensionInfoProvider,
49}
50
51impl XcbEventState {
52    #[must_use]
53    pub fn new(setup: Setup) -> Self {
54        Self {
55            max_request_length: setup.maximum_request_length as usize, // It's the length in 32bit words
56            id_allocator: IdAllocator::new(setup.resource_id_base, setup.resource_id_mask).unwrap(),
57            setup,
58            seq_count: SeqCount::new(),
59            event_cache: VecDeque::new(),
60            reply_cache: Map::new(),
61            keep_seqs: Set::new(),
62            extensions: BasicExtensionInfoProvider::default(),
63        }
64    }
65
66    #[inline]
67    pub(crate) fn extension_information(&self, name: &'static str) -> Option<ExtensionInformation> {
68        self.extensions.get_by_name(name)
69    }
70
71    #[cfg(feature = "debug")]
72    pub fn clear_cache<IO: SocketIo>(&mut self, io: &mut IO) -> Result<(), ConnectionError> {
73        if self.keep_seqs.is_empty() && self.reply_cache.is_empty() {
74            return Ok(());
75        }
76        if !self.keep_seqs.is_empty() {
77            let _ = get_input_focus(io, self, false)?.reply(io, self)?;
78        }
79        for (seq, _) in self.keep_seqs.iter() {
80            crate::debug!("Dropped voidcookie {seq}");
81        }
82        for (seq, reply) in self.reply_cache.iter() {
83            if reply[0] == ERROR {
84                let err = parse_error(reply, &self.extensions)?;
85                crate::debug!("Dropped error on seq {seq}! {:?}", err);
86            } else {
87                crate::debug!("Dropped reply on seq {seq}!");
88            }
89        }
90        crate::debug!("Panicking because of leak!");
91        panic!("Leaked replies;")
92    }
93}
94
/// Bookkeeping for the 16-bit request sequence numbers of one connection.
#[derive(Copy, Clone, Debug)]
struct SeqCount {
    /// Sequence number that the next request will be given.
    cur: u16,
    /// Sequence number most recently passed to `record_seen`.
    seen: u16,
}

impl SeqCount {
    /// Starts counting at sequence number 1.
    fn new() -> Self {
        Self { cur: 1, seen: 1 }
    }

    /// Returns `true` when `seq` lies strictly before the last recorded sequence number.
    ///
    /// Sequence numbers wrap at 16 bits, so the comparison is done on the wrapped distance:
    /// `seq` counts as seen when it is between 1 and `0x7fff` steps behind `seen`. A plain
    /// `seq < seen` would claim that a request issued just after a wrap-around had been seen
    /// while `seen` was still just below `u16::MAX`, which lets callers skip a needed sync.
    /// With the wrapped distance the only inaccuracy left is answering `false` for a sequence
    /// that is more than half the number space behind, which merely costs an extra sync.
    #[inline]
    fn sequence_has_been_seen(self, seq: u16) -> bool {
        let behind = self.seen.wrapping_sub(seq);
        behind != 0 && behind < 0x8000
    }

    /// Returns the current sequence number and advances the counter, wrapping at `u16::MAX`.
    #[inline]
    fn get_and_increment(&mut self) -> u16 {
        let last = self.cur;
        self.cur = self.cur.wrapping_add(1);
        last
    }

    /// Records `seq` as the latest sequence number received from the server.
    ///
    /// Events are sequential so this shouldn't be callable out of order messing with
    /// the `sequence_has_been_seen` logic.
    #[inline]
    fn record_seen(&mut self, seq: u16) {
        self.seen = seq;
    }
}
128
129pub fn find_socket_path(
130    dpy_name: Option<&str>,
131) -> Result<(UnixString, ParsedDisplay), ConnectError> {
132    let parsed_display = crate::helpers::parse_display::parse_display(dpy_name)
133        .ok_or(ConnectError::DisplayParsingError)?;
134    let screen: usize = parsed_display.screen.into();
135    if let Some(path) = parsed_display.connect_instruction() {
136        Ok((path, parsed_display))
137    } else {
138        Err(ConnectError::DisplayParsingError)
139    }
140}
141
142pub fn setup<IO: SocketIo>(
143    io: &mut IO,
144    xcb_env: XcbEnv,
145    dpy: ParsedDisplay,
146) -> Result<XcbEventState, ConnectError> {
147    let family = Family::LOCAL;
148    let host = tiny_std::unix::host_name::host_name().unwrap_or_else(|_| "localhost".to_string());
149    let setup_req = setup_request(xcb_env, family, host.as_bytes(), dpy.display)?;
150    io.use_write_buffer(|buf| {
151        buf[..setup_req.len()].copy_from_slice(&setup_req);
152        Ok::<usize, ConnectError>(setup_req.len())
153    })?;
154    let mut read_bytes = 0;
155    while read_bytes < 8 {
156        io.use_read_buffer(|buf| {
157            read_bytes = buf.len();
158            Ok::<usize, ConnectError>(0)
159        })?;
160        if read_bytes < 8 {
161            io.block_for_more_data().unwrap();
162        }
163    }
164    let mut required_length = 0;
165    io.use_read_buffer(|buf| {
166        required_length = get_setup_length(buf);
167        Ok::<usize, ConnectError>(0)
168    })?;
169    while read_bytes < required_length {
170        io.use_read_buffer(|buf| {
171            read_bytes = buf.len();
172            Ok::<usize, ConnectError>(0)
173        })?;
174        if read_bytes < required_length {
175            io.block_for_more_data().unwrap();
176        }
177    }
178    let mut setup = None;
179    io.use_read_buffer(|buf| {
180        setup = Some(parse_setup(buf)?);
181        Ok::<usize, ConnectError>(required_length)
182    })?;
183    let setup = setup.unwrap();
184
185    // resolve the setup
186    crate::debug!("Setup completed.");
187
188    // Check that we got a valid screen number
189    if dpy.screen >= setup.roots.len() as u16 {
190        return Err(ConnectError::InvalidScreen);
191    }
192    let mut state = XcbEventState::new(setup);
193    init_extensions(io, &mut state).map_err(|e| {
194        crate::debug!("Error init exts {e}");
195        ConnectError::UnknownError
196    })?;
197    check_for_big_req(io, &mut state).map_err(|e| {
198        crate::debug!("Error check big_req {e}");
199        ConnectError::UnknownError
200    })?;
201    Ok(state)
202}
203
204// Preload all extensions immediately
205fn init_extensions<IO: SocketIo>(
206    io: &mut IO,
207    state: &mut XcbEventState,
208) -> Result<(), ConnectionError> {
209    let listed = list_extensions(io, state, false)?;
210    let r = state.block_for_reply(io, listed.seq)?;
211    let (reply, offset) = ListExtensionsReply::from_bytes(&r)?;
212    let mut extensions = vec![];
213    for name in reply.names {
214        let cookie = query_extension(io, state, &name.name, false)?;
215        extensions.push((name.name, cookie));
216    }
217    crate::debug!("Pushed all {} ext requests", extensions.len());
218    for (name, cookie) in extensions {
219        let response = cookie.reply(io, state)?;
220        let name = String::from_utf8(name).map_err(|e| {
221            crate::debug!("Failed string convert {e}");
222            ConnectionError::UnsupportedExtension(format!("Failed to convert extension name {e}"))
223        })?;
224        if let Some(ext) = find_extension(&name) {
225            crate::debug!("Registered extension: '{ext}'");
226            if response.present == 1 {
227                state.extensions.extensions.push((
228                    ext,
229                    ExtensionInformation {
230                        major_opcode: response.major_opcode,
231                        first_event: response.first_event,
232                        first_error: response.first_error,
233                    },
234                ));
235            }
236        }
237    }
238    Ok(())
239}
240
241fn check_for_big_req<IO: SocketIo>(
242    io: &mut IO,
243    state: &mut XcbEventState,
244) -> Result<(), ConnectionError> {
245    if state
246        .extension_information(xcb_rust_protocol::proto::bigreq::EXTENSION_NAME)
247        .is_some()
248    {
249        let reply = enable(io, state, false)?.reply(io, state)?;
250        state.max_request_length = reply.maximum_request_length as usize;
251        crate::debug!(
252            "Got max_request_length = {} words from bigreq",
253            state.max_request_length
254        );
255    }
256
257    Ok(())
258}
259pub fn change_property8<IO: SocketIo, XS: XcbState>(
260    io: &mut IO,
261    state: &mut XS,
262    mode: PropModeEnum,
263    window: Window,
264    property: Atom,
265    type_: Atom,
266    data: &[u8],
267    forget: bool,
268) -> Result<VoidCookie, ConnectionError> {
269    Ok(change_property(
270        io,
271        state,
272        mode,
273        window,
274        property,
275        type_,
276        8,
277        data.len().try_into().expect("`data` has too many elements"),
278        data,
279        forget,
280    )?)
281}
282
283/// Change a property on a window with format 16.
284pub fn change_property16<IO: SocketIo, XS: XcbState>(
285    io: &mut IO,
286    state: &mut XS,
287    mode: PropModeEnum,
288    window: Window,
289    property: Atom,
290    type_: Atom,
291    data: &[u16],
292    forget: bool,
293) -> Result<VoidCookie, ConnectionError> {
294    let mut data_u8 = Vec::with_capacity(data.len() * 2);
295    for item in data {
296        data_u8.extend(item.to_ne_bytes());
297    }
298    Ok(change_property(
299        io,
300        state,
301        mode,
302        window,
303        property,
304        type_,
305        16,
306        data.len().try_into().expect("`data` has too many elements"),
307        &data_u8,
308        forget,
309    )?)
310}
311
312/// Change a property on a window with format 32.
313pub fn change_property32<IO: SocketIo, XS: XcbState>(
314    io: &mut IO,
315    state: &mut XS,
316    mode: PropModeEnum,
317    window: Window,
318    property: Atom,
319    type_: Atom,
320    data: &[u32],
321    forget: bool,
322) -> Result<VoidCookie, ConnectionError> {
323    let mut data_u8 = Vec::with_capacity(data.len() * 4);
324    for item in data {
325        data_u8.extend(item.to_ne_bytes());
326    }
327    Ok(change_property(
328        io,
329        state,
330        mode,
331        window,
332        property,
333        type_,
334        32,
335        data.len().try_into().expect("`data` has too many elements"),
336        &data_u8,
337        forget,
338    )?)
339}
340
341pub fn try_drain<IO: SocketIo>(
342    io: &mut IO,
343    state: &mut XcbEventState,
344) -> Result<Vec<Vec<u8>>, ConnectionError> {
345    let mut events = vec![];
346    while let Some(next) = state.event_cache.pop_front() {
347        events.push(next);
348    }
349    for rr in do_drain(io) {
350        match rr {
351            ReadResult::Event(e) => {
352                events.push(e);
353            }
354            ReadResult::Reply(got_seq, buf) => {
355                if state.keep_seqs.remove(&got_seq).is_some() {
356                    state.reply_cache.insert(got_seq, buf);
357                }
358                state.seq_count.record_seen(got_seq);
359            }
360            ReadResult::Error(got_seq, buf) => {
361                crate::debug!("Got err {:?}", parse_error(&buf, &state.extensions)?);
362                if state.keep_seqs.remove(&got_seq).is_some() {
363                    state.reply_cache.insert(got_seq, buf);
364                }
365                state.seq_count.record_seen(got_seq);
366            }
367        }
368    }
369
370    Ok(events)
371}
372
373impl XcbState for XcbEventState {
374    #[inline]
375    fn major_opcode(&self, extension_name: &'static str) -> Option<u8> {
376        self.extension_information(extension_name)
377            .map(|ei| ei.major_opcode)
378    }
379
380    #[inline]
381    fn next_seq(&mut self) -> u16 {
382        self.seq_count.get_and_increment()
383    }
384
385    #[inline]
386    fn keep_and_return_next_seq(&mut self) -> u16 {
387        let next = self.seq_count.get_and_increment();
388        self.keep_seqs.insert(next, ());
389        next
390    }
391
392    #[inline]
393    fn max_request_size(&self) -> usize {
394        self.max_request_length
395    }
396
397    #[inline]
398    fn setup(&self) -> &Setup {
399        &self.setup
400    }
401
402    #[inline]
403    fn generate_id<IO: SocketIo>(&mut self, io: &mut IO) -> Result<u32, Error> {
404        if let Some(id) = self.id_allocator.generate_id() {
405            Ok(id)
406        } else if self
407            .extension_information(xcb_rust_protocol::proto::xc_misc::EXTENSION_NAME)
408            .is_none()
409        {
410            // IDs are exhausted and XC-MISC is not available
411            Err(Error::Connection("Ids exhausted and xc-misc not available"))
412        } else {
413            let range = xcb_rust_protocol::connection::xc_misc::get_x_i_d_range(io, self, false)?
414                .reply(io, self)?;
415
416            self.id_allocator
417                .update_xid_range(&range)
418                .map_err(|_| Error::Connection("Ids exhausted on server"))?;
419            self.id_allocator
420                .generate_id()
421                .ok_or(Error::Connection("Ids exhausted"))
422        }
423    }
424
425    #[inline]
426    fn block_for_reply<IO: SocketIo>(&mut self, io: &mut IO, seq: u16) -> Result<Vec<u8>, Error> {
427        if let Some(reply) = self.reply_cache.remove(&seq) {
428            Ok(reply)
429        } else {
430            let mut target = None;
431            self.keep_seqs.remove(&seq);
432            while target.is_none() {
433                for rr in do_drain(io) {
434                    match rr {
435                        ReadResult::Event(e) => {
436                            self.event_cache.push_back(e);
437                        }
438                        ReadResult::Reply(got_seq, buf) => {
439                            if got_seq == seq {
440                                target = Some(buf);
441                            } else if self.keep_seqs.remove(&got_seq).is_some() {
442                                self.reply_cache.insert(got_seq, buf);
443                            }
444                            self.seq_count.record_seen(got_seq);
445                        }
446                        ReadResult::Error(got_seq, buf) => {
447                            crate::debug!("Got err {:?}", parse_error(&buf, &self.extensions)?);
448                            if got_seq == seq {
449                                target = Some(buf);
450                            } else if self.keep_seqs.remove(&got_seq).is_some() {
451                                self.reply_cache.insert(got_seq, buf);
452                            }
453                            self.seq_count.record_seen(got_seq);
454                        }
455                    }
456                }
457                if target.is_some() {
458                    continue;
459                }
460                crate::debug!("No drain in current buffer, try read.");
461                for rr in read_next(io).map_err(|e| {
462                    crate::debug!("Failed to read next {e}");
463                    Error::Connection("Failed to read next")
464                })? {
465                    match rr {
466                        ReadResult::Event(e) => {
467                            self.event_cache.push_back(e);
468                        }
469                        ReadResult::Reply(got_seq, buf) => {
470                            if got_seq == seq {
471                                target = Some(buf);
472                            } else if self.keep_seqs.remove(&got_seq).is_some() {
473                                self.reply_cache.insert(got_seq, buf);
474                            }
475                            self.seq_count.record_seen(got_seq);
476                        }
477                        ReadResult::Error(got_seq, buf) => {
478                            crate::debug!("Got err {:?}", parse_error(&buf, &self.extensions)?);
479                            if got_seq == seq {
480                                target = Some(buf);
481                            } else if self.keep_seqs.remove(&got_seq).is_some() {
482                                self.reply_cache.insert(got_seq, buf);
483                            }
484                            self.seq_count.record_seen(got_seq);
485                        }
486                    }
487                }
488            }
489            Ok(unsafe { target.unwrap_unchecked() })
490        }
491    }
492
493    #[inline]
494    fn block_check_err<IO: SocketIo>(&mut self, io: &mut IO, seq: u16) -> Result<(), Error> {
495        if !self.seq_count.sequence_has_been_seen(seq) {
496            get_input_focus(io, self, false)?.reply(io, self)?;
497        }
498        if let Some(err) = self.reply_cache.remove(&seq) {
499            Err(Error::XcbError(parse_error(&err, &self.extensions)?))
500        } else {
501            self.keep_seqs.remove(&seq);
502            Ok(())
503        }
504    }
505
506    #[inline]
507    fn forget(&mut self, seq: u16) {
508        self.keep_seqs.remove(&seq);
509        self.reply_cache.remove(&seq);
510    }
511}
512
513#[inline]
514fn read_next<IO: SocketIo>(io: &mut IO) -> Result<Vec<ReadResult>, ConnectionError> {
515    io.block_for_more_data().map_err(ConnectionError::Io)?;
516    Ok(do_drain(io))
517}
518
519#[inline]
520fn do_drain<IO: SocketIo>(io: &mut IO) -> Vec<ReadResult> {
521    let mut read_results = vec![];
522    io.use_read_buffer(|read_buf| {
523        let mut offset = 0;
524        while let Some((new_offset, rr)) = drain_next(read_buf, offset) {
525            read_results.push(rr);
526            offset = new_offset;
527        }
528        Ok::<usize, ()>(offset)
529    });
530
531    read_results
532}
533
534#[allow(clippy::match_on_vec_items)]
535#[inline]
536fn drain_next(in_buffer: &[u8], offset: usize) -> Option<(usize, ReadResult)> {
537    let has_length_field = match in_buffer.get(offset) {
538        Some(&REPLY) => true,
539        Some(x) if x & 0x7f == xcb_rust_protocol::proto::xproto::GE_GENERIC_EVENT => true,
540        _ => false,
541    };
542    let additional_length = if has_length_field {
543        if let Some(length_field) = in_buffer.get(offset + 4..offset + 8) {
544            let length_field = u32::from_ne_bytes(length_field.try_into().unwrap());
545            let length_field = usize::try_from(length_field).unwrap();
546            debug_assert!(length_field <= usize::MAX / 4);
547            4 * length_field
548        } else {
549            0
550        }
551    } else {
552        0
553    };
554    // All packets are at least 32 bytes
555    let packet_length = 32 + additional_length;
556    if in_buffer[offset..].len() < packet_length {
557        // Need more data
558        None
559    } else {
560        // Got at least one full packet
561        let end_at = offset + packet_length;
562        let slice = &in_buffer[offset..end_at];
563        let read_result = match in_buffer[offset] {
564            ERROR => ReadResult::Error(
565                parse_seq(&in_buffer[offset..]),
566                in_buffer[offset..end_at].to_vec(),
567            ),
568            REPLY => ReadResult::Reply(
569                parse_seq(&in_buffer[offset..]),
570                in_buffer[offset..end_at].to_vec(),
571            ),
572            _ => ReadResult::Event(in_buffer[offset..end_at].to_vec()),
573        };
574        Some((end_at, read_result))
575    }
576}
577
578const ERROR: u8 = 0;
579const REPLY: u8 = 1;
580
581enum ReadResult {
582    Event(Vec<u8>),
583    Reply(u16, Vec<u8>),
584    Error(u16, Vec<u8>),
585}
586
587#[inline]
588fn parse_seq(raw_reply: &[u8]) -> u16 {
589    // The seq is at the same byte offset for both replies and errors
590    u16::from_ne_bytes(raw_reply[2..4].try_into().unwrap())
591}