x11rb_async/rust_connection/
shared_state.rs1use event_listener::Event;
4use futures_lite::future;
5use std::convert::Infallible;
6use std::io;
7use std::mem;
8use std::sync::{
9 atomic::{AtomicBool, Ordering},
10 Arc, Mutex as StdMutex, MutexGuard as StdMutexGuard,
11};
12use x11rb::errors::ConnectionError;
13use x11rb_protocol::connection::Connection as ProtoConnection;
14use x11rb_protocol::packet_reader::PacketReader as ProtoPacketReader;
15use x11rb_protocol::RawFdContainer;
16
17use super::Stream;
18
19#[derive(Debug)]
21pub(super) struct SharedState<S> {
22 inner: StdMutex<ProtoConnection>,
26
27 pub(super) stream: S,
29
30 new_input: Event,
32
33 driver_dropped: AtomicBool,
35}
36
37impl<S: Stream> SharedState<S> {
38 pub(super) fn new(stream: S) -> Self {
39 Self {
40 inner: Default::default(),
41 stream,
42 new_input: Event::new(),
43 driver_dropped: AtomicBool::new(false),
44 }
45 }
46
47 pub(super) fn lock_connection(&self) -> StdMutexGuard<'_, ProtoConnection> {
49 self.inner.lock().unwrap()
50 }
51
52 pub(super) async fn wait_for_incoming<R, F>(&self, mut get_reply: F) -> Result<R, io::Error>
58 where
59 F: FnMut(&mut ProtoConnection) -> Option<R>,
60 {
61 loop {
62 if let Some(reply) = get_reply(&mut self.lock_connection()) {
64 return Ok(reply);
65 }
66
67 let listener = self.new_input.listen();
69
70 if let Some(reply) = get_reply(&mut self.lock_connection()) {
72 return Ok(reply);
73 }
74
75 if self.driver_dropped.load(Ordering::SeqCst) {
78 return Err(io::Error::new(
79 io::ErrorKind::Other,
80 "Driving future was dropped",
81 ));
82 }
83
84 listener.await;
86 }
87 }
88
89 pub(super) async fn drive(
91 &self,
92 _break_on_drop: BreakOnDrop<S>,
93 ) -> Result<Infallible, ConnectionError> {
94 let mut packet_reader = PacketReader {
95 read_buffer: vec![0; 4096].into_boxed_slice(),
96 inner: ProtoPacketReader::new(),
97 };
98 let mut fds = vec![];
99 let mut packets = vec![];
100
101 loop {
102 for _ in 0..50 {
103 packet_reader.try_read_packets(&self.stream, &mut packets, &mut fds)?;
105 let packet_count = packets.len();
106
107 {
109 let mut inner = self.inner.lock().unwrap();
110 inner.enqueue_fds(mem::take(&mut fds));
111 packets
112 .drain(..)
113 .for_each(|packet| inner.enqueue_packet(packet));
114 }
115
116 if packet_count > 0 {
117 let _num_notified = self.new_input.notify_additional(usize::MAX);
119 } else {
120 self.stream.readable().await?;
122 }
123 }
124
125 future::yield_now().await;
127 }
128 }
129}
130
131#[derive(Debug)]
132struct PacketReader {
133 read_buffer: Box<[u8]>,
135
136 inner: ProtoPacketReader,
138}
139
140impl PacketReader {
141 fn try_read_packets(
143 &mut self,
144 stream: &impl Stream,
145 out_packets: &mut Vec<Vec<u8>>,
146 fd_storage: &mut Vec<RawFdContainer>,
147 ) -> io::Result<()> {
148 let original_length = out_packets.len();
149 loop {
150 if self.inner.remaining_capacity() >= self.read_buffer.len() {
153 tracing::trace!(
154 "Trying to read large packet with {} bytes remaining",
155 self.inner.remaining_capacity()
156 );
157 match stream.read(self.inner.buffer(), fd_storage) {
158 Ok(0) => {
159 tracing::error!("Large read returned zero");
160 return Err(io::Error::new(
161 io::ErrorKind::UnexpectedEof,
162 "The X11 server closed the connection",
163 ));
164 }
165 Ok(n) => {
166 tracing::trace!("Read {} bytes directly into large packet", n);
167 if let Some(packet) = self.inner.advance(n) {
168 out_packets.push(packet);
169 }
170 }
171 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
172 Err(e) => return Err(e),
173 }
174 } else {
175 let nread = match stream.read(&mut self.read_buffer, fd_storage) {
177 Ok(0) => {
178 tracing::error!("Buffer read returned zero");
179 return Err(io::Error::new(
180 io::ErrorKind::UnexpectedEof,
181 "The X11 server closed the connection",
182 ));
183 }
184 Ok(n) => n,
185 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
186 Err(e) => return Err(e),
187 };
188 tracing::trace!("Read {} bytes into read buffer", nread);
189
190 let mut src = &self.read_buffer[..nread];
192 while !src.is_empty() {
193 let dest = self.inner.buffer();
194 let amt_to_read = std::cmp::min(src.len(), dest.len());
195
196 dest[..amt_to_read].copy_from_slice(&src[..amt_to_read]);
197
198 src = &src[amt_to_read..];
200
201 if let Some(packet) = self.inner.advance(amt_to_read) {
203 out_packets.push(packet);
204 }
205 }
206 }
207 }
208 tracing::trace!(
209 "Read {} complete packet(s)",
210 out_packets.len() - original_length
211 );
212
213 Ok(())
214 }
215}
216
217#[derive(Debug)]
218pub(super) struct BreakOnDrop<S>(pub(super) Arc<SharedState<S>>);
219
220impl<S> Drop for BreakOnDrop<S> {
221 fn drop(&mut self) {
222 self.0.driver_dropped.store(true, Ordering::SeqCst);
224
225 let _num_notified = self.0.new_input.notify_additional(usize::MAX);
227 }
228}