1use async_lock::{Mutex, MutexGuard, RwLock};
4use futures_lite::future;
5use tracing::Instrument as _;
6
7use std::convert::Infallible;
8use std::future::Future;
9use std::io;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use crate::connection::{Connection, Fut, RequestConnection};
14use crate::{Cookie, CookieWithFds, VoidCookie};
15
16use x11rb_protocol::connection::{Connection as ProtoConnection, PollReply, ReplyFdKind};
17use x11rb_protocol::id_allocator::IdAllocator;
18use x11rb_protocol::protocol::bigreq::EnableReply;
19use x11rb_protocol::protocol::xproto::{Setup, QUERY_EXTENSION_REQUEST};
20use x11rb_protocol::x11_utils::{ExtensionInformation, TryParse, TryParseFd, X11Error};
21use x11rb_protocol::xauth::get_auth;
22use x11rb_protocol::{DiscardMode, RawFdContainer, SequenceNumber};
23
24use x11rb::connection::{BufWithFds, ReplyOrError};
25use x11rb::errors::{ConnectError, ConnectionError, ParseError, ReplyOrIdError};
26
27mod extensions;
28mod nb_connect;
29mod shared_state;
30mod stream;
31mod write_buffer;
32
33pub use stream::{DefaultStream, Stream, StreamAdaptor, StreamBase};
34use write_buffer::{WriteBuffer, WriteBufferGuard};
35
36#[derive(Debug)]
38pub struct RustConnection<S = DefaultStream> {
39 shared: Arc<shared_state::SharedState<S>>,
41
42 write_buffer: WriteBuffer,
46
47 setup: Setup,
49
50 max_request_bytes: Mutex<MaxRequestBytes>,
52
53 id_allocator: Mutex<IdAllocator>,
55
56 extensions: RwLock<extensions::Extensions>,
58}
59
60#[derive(Debug, PartialEq, Eq)]
62enum MaxRequestBytes {
63 Unknown,
65
66 Known(usize),
68
69 Requested(Option<SequenceNumber>),
71}
72
73impl RustConnection {
74 pub async fn connect(
80 display_name: Option<&str>,
81 ) -> Result<
82 (
83 Self,
84 usize,
85 impl Future<Output = Result<Infallible, ConnectionError>> + Send,
86 ),
87 ConnectError,
88 > {
89 let addrs = x11rb_protocol::parse_display::parse_display(display_name)?;
91
92 let (stream, screen, (family, address)) = nb_connect::connect(&addrs).await?;
94
95 let stream = StreamAdaptor::new(stream)?;
97
98 let (auth_name, auth_data) = blocking::unblock(move || {
100 get_auth(family, &address, addrs.display)
101 .unwrap_or(None)
102 .unwrap_or_else(|| (Vec::new(), Vec::new()))
103 })
104 .await;
105 tracing::trace!("Picked authentication via auth mechanism {:?}", auth_name);
106
107 let (conn, drive) =
108 RustConnection::connect_to_stream_with_auth_info(stream, screen, auth_name, auth_data)
109 .await?;
110 Ok((conn, screen, drive))
111 }
112}
113
114impl<S: Stream + Send + Sync> RustConnection<S> {
115 pub async fn connect_to_stream(
121 stream: S,
122 screen: usize,
123 ) -> Result<
124 (
125 Self,
126 impl Future<Output = Result<Infallible, ConnectionError>> + Send,
127 ),
128 ConnectError,
129 > {
130 Self::connect_to_stream_with_auth_info(stream, screen, Vec::new(), Vec::new()).await
131 }
132
133 pub async fn connect_to_stream_with_auth_info(
139 stream: S,
140 screen: usize,
141 auth_name: Vec<u8>,
142 auth_data: Vec<u8>,
143 ) -> Result<
144 (
145 Self,
146 impl Future<Output = Result<Infallible, ConnectionError>> + Send,
147 ),
148 ConnectError,
149 > {
150 let (mut connect, setup_request) =
152 x11rb_protocol::connect::Connect::with_authorization(auth_name, auth_data);
153
154 let mut fds = Vec::new();
156 let mut nwritten = 0;
157
158 tracing::trace!(
159 "Writing connection setup with {} bytes",
160 setup_request.len()
161 );
162 while nwritten < setup_request.len() {
163 nwritten += write_with(&stream, |stream| {
164 match stream.write(&setup_request[nwritten..], &mut fds) {
165 Ok(0) => Err(io::ErrorKind::WriteZero.into()),
166 res => res,
167 }
168 })
169 .await?;
170 }
171
172 loop {
174 tracing::trace!(
175 "Reading connection setup with at least {} bytes remaining",
176 connect.buffer().len()
177 );
178 let adv = match stream.read(connect.buffer(), &mut fds) {
179 Err(e) if e.kind() == io::ErrorKind::WouldBlock => 0,
180 Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof).into()),
181 Ok(n) => n,
182 Err(e) => return Err(e.into()),
183 };
184 tracing::trace!("Read {} bytes", adv);
185
186 if connect.advance(adv) {
188 break;
189 }
190
191 stream.readable().await?;
193 }
194
195 let setup = connect.into_setup()?;
197
198 if setup.roots.len() <= screen {
200 return Err(ConnectError::InvalidScreen);
201 }
202
203 Self::for_connected_stream(stream, setup)
204 }
205
206 pub fn for_connected_stream(
212 stream: S,
213 setup: Setup,
214 ) -> Result<
215 (
216 Self,
217 impl Future<Output = Result<Infallible, ConnectionError>> + Send,
218 ),
219 ConnectError,
220 > {
221 let id_allocator = IdAllocator::new(setup.resource_id_base, setup.resource_id_mask)?;
222 let shared = Arc::new(shared_state::SharedState::new(stream));
223
224 let drive = {
226 let shared = shared.clone();
227 let break_on_drop = shared_state::BreakOnDrop(shared.clone());
231 async move { shared.drive(break_on_drop).await }
232 };
233
234 Ok((
235 RustConnection {
236 shared,
237 write_buffer: Default::default(),
238 setup,
239 max_request_bytes: Mutex::new(MaxRequestBytes::Unknown),
240 id_allocator: Mutex::new(id_allocator),
241 extensions: Default::default(),
242 },
243 drive,
244 ))
245 }
246
247 async fn send_request(
249 &self,
250 bufs: &[io::IoSlice<'_>],
251 mut fds: Vec<RawFdContainer>,
252 kind: ReplyFdKind,
253 ) -> Result<SequenceNumber, ConnectionError>
254 where
255 S: Send + Sync,
256 {
257 async {
258 {
259 const LEVEL: tracing::Level = tracing::Level::DEBUG;
260 if tracing::event_enabled!(LEVEL) {
261 let major_opcode = bufs[0][0];
264 if major_opcode == QUERY_EXTENSION_REQUEST {
265 tracing::event!(LEVEL, "Sending QueryExtension request");
266 } else {
267 let extensions = self.extensions.read().await;
268 tracing::event!(LEVEL, "Sending {} request", x11rb_protocol::protocol::get_request_name(&*extensions, major_opcode, bufs[0][1]));
269 }
270 }
271 }
272
273 let mut storage = Default::default();
275 let bufs = compute_length_field(self, bufs, &mut storage).await?;
276
277 let mut buffer = self.write_buffer.lock().await?;
279
280 loop {
281 let seq = {
282 let mut inner = self.shared.lock_connection();
283 inner.send_request(kind)
284 };
285
286 match seq {
288 Some(seq) => {
289 buffer = self.write_all_vectored(buffer, bufs, &mut fds).await?;
291 buffer.unlock();
292 return Ok(seq);
293 }
294
295 None => {
296 tracing::trace!("Syncing with the X11 server since there are too many outstanding void requests");
298 buffer = self.send_sync(buffer).await?;
299 }
300 }
301 }
302 }.instrument(tracing::debug_span!("send_request")).await
303 }
304
305 async fn send_sync<'a>(
307 &'a self,
308 buffer: WriteBufferGuard<'a>,
309 ) -> Result<WriteBufferGuard<'a>, ConnectionError> {
310 let length = 1u16.to_ne_bytes();
311 let request = [
312 x11rb_protocol::protocol::xproto::GET_INPUT_FOCUS_REQUEST,
313 0,
314 length[0],
315 length[1],
316 ];
317
318 {
320 let mut inner = self.shared.lock_connection();
321 let seq = inner
322 .send_request(ReplyFdKind::ReplyWithoutFDs)
323 .expect("This request should not be blocked by syncs");
324 inner.discard_reply(seq, DiscardMode::DiscardReplyAndError);
325 };
326
327 let iov = &[io::IoSlice::new(&request)];
329 let mut fds = Vec::new();
330 self.write_all_vectored(buffer, iov, &mut fds).await
331 }
332
333 async fn write_all_vectored<'a>(
335 &'a self,
336 mut write_buffer: WriteBufferGuard<'a>,
337 bufs: &[io::IoSlice<'_>],
338 fds: &mut Vec<RawFdContainer>,
339 ) -> Result<WriteBufferGuard<'a>, ConnectionError> {
340 write_buffer
341 .write_all_vectored(&self.shared.stream, bufs, fds)
342 .await?;
343 Ok(write_buffer)
344 }
345
346 async fn flush_impl<'a>(
348 &'a self,
349 mut buffer: WriteBufferGuard<'a>,
350 ) -> Result<WriteBufferGuard<'a>, ConnectionError> {
351 buffer.flush(&self.shared.stream).await?;
352 Ok(buffer)
353 }
354
355 async fn prefetch_len_impl(&self) -> Result<MutexGuard<'_, MaxRequestBytes>, ConnectionError>
357 where
358 S: Send + Sync,
359 {
360 let mut mrl = self.max_request_bytes.lock().await;
361
362 if *mrl == MaxRequestBytes::Unknown {
364 tracing::debug!("Prefetching maximum request length");
365 let cookie = crate::protocol::bigreq::enable(self)
366 .await
367 .map(|cookie| {
368 let seq = cookie.sequence_number();
369 std::mem::forget(cookie);
370 seq
371 })
372 .ok();
373
374 *mrl = MaxRequestBytes::Requested(cookie);
375 }
376
377 Ok(mrl)
378 }
379
380 async fn wait_for_reply_with_fds_impl(
382 &self,
383 sequence: SequenceNumber,
384 ) -> Result<ReplyOrError<BufWithFds<Vec<u8>>, Vec<u8>>, ConnectionError> {
385 self.flush_impl(self.write_buffer.lock().await?)
387 .await?
388 .unlock();
389
390 let get_reply = |inner: &mut ProtoConnection| {
391 if let Some(reply) = inner.poll_for_reply_or_error(sequence) {
392 if reply.0[0] == 0 {
393 tracing::trace!("Got an error");
394 Some(Ok(ReplyOrError::Error(reply.0)))
395 } else {
396 tracing::trace!("Got a reply");
397 Some(Ok(ReplyOrError::Reply(reply)))
398 }
399 } else {
400 None
401 }
402 };
403
404 self.shared.wait_for_incoming(get_reply).await?
405 }
406}
407
408impl<S: Stream + Send + Sync> RequestConnection for RustConnection<S> {
409 type Buf = Vec<u8>;
410
411 fn send_request_with_reply<'this, 'bufs, 'sl, 're, 'future, R>(
412 &'this self,
413 bufs: &'bufs [io::IoSlice<'sl>],
414 fds: Vec<RawFdContainer>,
415 ) -> Fut<'future, Cookie<'this, Self, R>, ConnectionError>
416 where
417 'this: 'future,
418 'bufs: 'future,
419 'sl: 'future,
420 're: 'future,
421 R: TryParse + Send + 're,
422 {
423 Box::pin(async move {
424 let seq = self
425 .send_request(bufs, fds, ReplyFdKind::ReplyWithoutFDs)
426 .await?;
427
428 Ok(Cookie::new(self, seq))
429 })
430 }
431
432 fn send_request_with_reply_with_fds<'this, 'bufs, 'sl, 're, 'future, R>(
433 &'this self,
434 bufs: &'bufs [io::IoSlice<'sl>],
435 fds: Vec<RawFdContainer>,
436 ) -> Fut<'future, CookieWithFds<'this, Self, R>, ConnectionError>
437 where
438 'this: 'future,
439 'bufs: 'future,
440 'sl: 'future,
441 're: 'future,
442 R: TryParseFd + Send + 're,
443 {
444 Box::pin(async move {
445 let seq = self
446 .send_request(bufs, fds, ReplyFdKind::ReplyWithFDs)
447 .await?;
448
449 Ok(CookieWithFds::new(self, seq))
450 })
451 }
452
453 fn send_request_without_reply<'this, 'bufs, 'sl, 'future>(
454 &'this self,
455 bufs: &'bufs [io::IoSlice<'sl>],
456 fds: Vec<RawFdContainer>,
457 ) -> Fut<'future, VoidCookie<'this, Self>, ConnectionError>
458 where
459 'this: 'future,
460 'bufs: 'future,
461 'sl: 'future,
462 {
463 Box::pin(async move {
464 let seq = self.send_request(bufs, fds, ReplyFdKind::NoReply).await?;
465
466 Ok(VoidCookie::new(self, seq))
467 })
468 }
469
470 fn discard_reply(
471 &self,
472 sequence: SequenceNumber,
473 _kind: x11rb::connection::RequestKind,
474 mode: DiscardMode,
475 ) {
476 tracing::debug!(
477 "Discarding reply to request {} in mode {:?}",
478 sequence,
479 mode
480 );
481 self.shared.lock_connection().discard_reply(sequence, mode)
482 }
483
484 fn prefetch_extension_information(&self, name: &'static str) -> Fut<'_, (), ConnectionError> {
485 Box::pin(async move {
486 let mut cache = self.extensions.write().await;
487 cache.prefetch(self, name).await
488 })
489 }
490
491 fn extension_information(
492 &self,
493 name: &'static str,
494 ) -> Fut<'_, Option<ExtensionInformation>, ConnectionError> {
495 Box::pin(async move {
496 let mut cache = self.extensions.write().await;
497 cache.information(self, name).await
498 })
499 }
500
501 fn wait_for_reply_or_raw_error(
502 &self,
503 sequence: SequenceNumber,
504 ) -> Fut<'_, ReplyOrError<Self::Buf>, ConnectionError> {
505 Box::pin(
506 async move {
507 match self.wait_for_reply_with_fds_impl(sequence).await? {
508 ReplyOrError::Reply((buf, _)) => Ok(ReplyOrError::Reply(buf)),
509 ReplyOrError::Error(buf) => Ok(ReplyOrError::Error(buf)),
510 }
511 }
512 .instrument(tracing::info_span!("wait_for_reply_or_raw_error", sequence)),
513 )
514 }
515
516 fn wait_for_reply(
517 &self,
518 sequence: SequenceNumber,
519 ) -> Fut<'_, Option<Self::Buf>, ConnectionError> {
520 Box::pin(
521 async move {
522 self.flush_impl(self.write_buffer.lock().await?)
524 .await?
525 .unlock();
526
527 let get_reply = |inner: &mut ProtoConnection| match inner.poll_for_reply(sequence) {
528 PollReply::TryAgain => None,
529 PollReply::Reply(reply) => Some(Ok(Some(reply))),
530 PollReply::NoReply => Some(Ok(None)),
531 };
532
533 self.shared.wait_for_incoming(get_reply).await?
535 }
536 .instrument(tracing::info_span!("wait_for_reply", sequence)),
537 )
538 }
539
540 fn wait_for_reply_with_fds_raw(
541 &self,
542 sequence: SequenceNumber,
543 ) -> Fut<'_, ReplyOrError<BufWithFds<Self::Buf>, Self::Buf>, ConnectionError> {
544 Box::pin(
545 self.wait_for_reply_with_fds_impl(sequence)
546 .instrument(tracing::info_span!("wait_for_reply_with_fds_raw", sequence)),
547 )
548 }
549
550 fn check_for_raw_error(
551 &self,
552 sequence: SequenceNumber,
553 ) -> Fut<'_, Option<Self::Buf>, ConnectionError> {
554 Box::pin(
555 async move {
556 let mut write_buffer = self.write_buffer.lock().await?;
557 if self
558 .shared
559 .lock_connection()
560 .prepare_check_for_reply_or_error(sequence)
561 {
562 tracing::trace!("Inserting sync with the X11 server");
563 write_buffer = self.send_sync(write_buffer).await?;
564
565 assert!(!self
566 .shared
567 .lock_connection()
568 .prepare_check_for_reply_or_error(sequence));
569 }
570
571 self.flush_impl(write_buffer).await?.unlock();
573
574 let get_result = |inner: &mut ProtoConnection| match inner
575 .poll_check_for_reply_or_error(sequence)
576 {
577 PollReply::TryAgain => None,
578 PollReply::NoReply => Some(Ok(None)),
579 PollReply::Reply(buffer) => Some(Ok(Some(buffer))),
580 };
581
582 self.shared.wait_for_incoming(get_result).await?
583 }
584 .instrument(tracing::info_span!("check_for_raw_error", sequence)),
585 )
586 }
587
588 fn prefetch_maximum_request_bytes(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
589 Box::pin(async move {
590 let _guard = self
591 .prefetch_len_impl()
592 .await
593 .expect("Failed to prefetch maximum request bytes");
594 })
595 }
596
597 fn maximum_request_bytes(&self) -> Pin<Box<dyn Future<Output = usize> + Send + '_>> {
598 Box::pin(
599 async move {
600 let mut mrl = self
601 .prefetch_len_impl()
602 .await
603 .expect("Failed to prefetch maximum request bytes");
604
605 match *mrl {
607 MaxRequestBytes::Known(len) => len,
608 MaxRequestBytes::Unknown => unreachable!("We are in the Some branch"),
609 MaxRequestBytes::Requested(cookie) => {
610 let cookie = match cookie {
611 Some(cookie) => cookie,
612 None => {
613 return usize::from(self.setup().maximum_request_length)
615 .saturating_mul(4);
616 }
617 };
618
619 let cookie = Cookie::<'_, _, EnableReply>::new(self, cookie);
621
622 let reply = cookie.reply().await.expect("Failed to get reply");
623
624 let total = reply
626 .maximum_request_length
627 .try_into()
628 .ok()
629 .and_then(|x: usize| x.checked_mul(4))
630 .unwrap_or(std::usize::MAX);
631
632 *mrl = MaxRequestBytes::Known(total);
633 tracing::debug!("Maximum request length is {} bytes", total);
634 total
635 }
636 }
637 }
638 .instrument(tracing::info_span!("maximum_request_bytes")),
639 )
640 }
641
642 fn parse_error(&self, error: &[u8]) -> Result<X11Error, ParseError> {
643 let extensions = future::block_on(self.extensions.read());
644 X11Error::try_parse(error, &*extensions)
645 }
646
647 fn parse_event(&self, event: &[u8]) -> Result<x11rb::protocol::Event, ParseError> {
648 let extensions = future::block_on(self.extensions.read());
649 x11rb::protocol::Event::parse(event, &*extensions)
650 }
651}
652
653impl<S: Stream + Send + Sync> Connection for RustConnection<S> {
654 fn wait_for_raw_event_with_sequence(
655 &self,
656 ) -> Fut<'_, x11rb_protocol::RawEventAndSeqNumber<Self::Buf>, ConnectionError> {
657 Box::pin(
658 async move {
659 let get_event = |inner: &mut ProtoConnection| inner.poll_for_event_with_sequence();
660
661 Ok(self.shared.wait_for_incoming(get_event).await?)
662 }
663 .instrument(tracing::info_span!("wait_for_raw_event_with_sequence")),
664 )
665 }
666
667 fn poll_for_raw_event_with_sequence(
668 &self,
669 ) -> Result<Option<x11rb_protocol::RawEventAndSeqNumber<Self::Buf>>, ConnectionError> {
670 Ok(self.shared.lock_connection().poll_for_event_with_sequence())
671 }
672
673 fn flush(&self) -> Fut<'_, (), ConnectionError> {
674 Box::pin(async move {
675 self.flush_impl(self.write_buffer.lock().await?)
676 .await?
677 .unlock();
678
679 Ok(())
680 })
681 }
682
683 fn setup(&self) -> &Setup {
684 &self.setup
685 }
686
687 fn generate_id(&self) -> Fut<'_, u32, ReplyOrIdError> {
688 Box::pin(
689 async move {
690 use crate::protocol::xc_misc;
691
692 let mut id_allocator = self.id_allocator.lock().await;
693
694 if let Some(id) = id_allocator.generate_id() {
696 return Ok(id);
697 }
698
699 if self
701 .extension_information(xc_misc::X11_EXTENSION_NAME)
702 .await?
703 .is_some()
704 {
705 tracing::info!("XIDs are exhausted; fetching free range via XC-MISC");
706
707 id_allocator
709 .update_xid_range(&xc_misc::get_xid_range(self).await?.reply().await?)?;
710
711 return id_allocator
713 .generate_id()
714 .ok_or(ReplyOrIdError::IdsExhausted);
715 } else {
716 tracing::error!("XIDs are exhausted and XC-MISC extension is not available");
717 }
718
719 Err(ReplyOrIdError::IdsExhausted)
721 }
722 .instrument(tracing::info_span!("generate_id")),
723 )
724 }
725}
726
727async fn compute_length_field<'b>(
729 conn: &impl RequestConnection,
730 request_buffers: &'b [io::IoSlice<'b>],
731 storage: &'b mut (Vec<io::IoSlice<'b>>, [u8; 8]),
732) -> Result<&'b [io::IoSlice<'b>], ConnectionError> {
733 let length: usize = request_buffers.iter().map(|buf| buf.len()).sum();
735 assert_eq!(
736 length % 4,
737 0,
738 "The length of X11 requests must be a multiple of 4, got {}",
739 length
740 );
741 let wire_length = length / 4;
742
743 let first_buf = &request_buffers[0];
744
745 if let Ok(wire_length) = u16::try_from(wire_length) {
747 let length_field = u16::from_ne_bytes([first_buf[2], first_buf[3]]);
749 assert_eq!(
750 wire_length, length_field,
751 "Length field contains incorrect value"
752 );
753 return Ok(request_buffers);
754 }
755
756 if length > conn.maximum_request_bytes().await {
758 return Err(ConnectionError::MaximumRequestLengthExceeded);
759 }
760
761 let wire_length: u32 = wire_length
763 .checked_add(1)
764 .ok_or(ConnectionError::MaximumRequestLengthExceeded)?
765 .try_into()
766 .expect("X11 request larger than 2^34 bytes?!?");
767 let wire_length = wire_length.to_ne_bytes();
768
769 storage.1.copy_from_slice(&[
773 first_buf[0],
775 first_buf[1],
776 0,
778 0,
779 wire_length[0],
781 wire_length[1],
782 wire_length[2],
783 wire_length[3],
784 ]);
785 storage.0.push(io::IoSlice::new(&storage.1));
786
787 storage.0.push(io::IoSlice::new(&first_buf[4..]));
789
790 storage.0.extend(
792 request_buffers[1..]
793 .iter()
794 .map(std::ops::Deref::deref)
795 .map(io::IoSlice::new),
796 );
797
798 Ok(&storage.0[..])
799}
800
801async fn write_with<'a, S: StreamBase<'a>, R, F>(stream: &'a S, mut f: F) -> Result<R, io::Error>
802where
803 F: FnMut(&'a S) -> Result<R, io::Error>,
804{
805 loop {
806 match f(stream) {
807 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
808 stream.writable().await?;
810 }
811
812 res => return res,
813 }
814 }
815}