1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
//! The state of the connection that is shared with the reading future

use event_listener::Event;
use futures_lite::future;
use std::convert::Infallible;
use std::io;
use std::mem;
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc, Mutex as StdMutex, MutexGuard as StdMutexGuard,
};
use x11rb::errors::ConnectionError;
use x11rb_protocol::connection::Connection as ProtoConnection;
use x11rb_protocol::packet_reader::PacketReader as ProtoPacketReader;
use x11rb_protocol::RawFdContainer;

use super::Stream;

/// State shared between the `RustConnection` and the future polling for new packets.
///
/// The driving future (`drive`) reads packets from `stream` and enqueues them into `inner`;
/// `wait_for_incoming` waits on `new_input` until the packet it needs has been enqueued.
#[derive(Debug)]
pub(super) struct SharedState<S> {
    /// The underlying connection manager.
    ///
    /// This is never held across an `.await` point, so it's fine to use a standard library mutex.
    inner: StdMutex<ProtoConnection>,

    /// The stream for communicating with the X11 server.
    pub(super) stream: S,

    /// Event that is notified when new packets were enqueued into `inner`, and when the
    /// driving future was dropped.
    new_input: Event,

    /// Flag that indicates that the future for `drive()` was dropped and we no longer read input.
    driver_dropped: AtomicBool,
}

impl<S: Stream> SharedState<S> {
    /// Create the shared state for a connection that communicates over `stream`.
    ///
    /// The inner connection starts in its default state and the driver is not yet marked as
    /// dropped.
    pub(super) fn new(stream: S) -> Self {
        Self {
            inner: Default::default(),
            stream,
            new_input: Event::new(),
            driver_dropped: AtomicBool::new(false),
        }
    }

    /// Lock the inner connection and return a mutex guard for it.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned (another thread panicked while holding it).
    pub(super) fn lock_connection(&self) -> StdMutexGuard<'_, ProtoConnection> {
        self.inner.lock().unwrap()
    }

    /// Wait for an incoming packet.
    ///
    /// The given function `get_reply` should check whether the needed packet was already
    /// received and put into the inner connection. It should return `None` if nothing is present
    /// yet and new incoming X11 packets should be awaited. It may be called several times.
    ///
    /// # Errors
    ///
    /// Returns an error if the driving future was dropped before `get_reply` produced a value.
    pub(super) async fn wait_for_incoming<R, F>(&self, mut get_reply: F) -> Result<R, io::Error>
    where
        F: FnMut(&mut ProtoConnection) -> Option<R>,
    {
        loop {
            // See if we can find the reply in the connection.
            if let Some(reply) = get_reply(&mut self.lock_connection()) {
                return Ok(reply);
            }

            // Register a listener for the reply.
            let listener = self.new_input.listen();

            // Maybe a packet was delivered while we were registering the listener.
            if let Some(reply) = get_reply(&mut self.lock_connection()) {
                return Ok(reply);
            }

            // Maybe the future from drive() was dropped?
            // We only check this down here and not before the listener since this is unlikely.
            // Because the listener is already registered, a drop that happens after this check
            // still wakes us up through the notification sent by `BreakOnDrop`.
            if self.driver_dropped.load(Ordering::SeqCst) {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "Driving future was dropped",
                ));
            }

            // Wait for the next packet.
            listener.await;
        }
    }

    /// Read incoming packets from the stream and put them into the inner connection.
    ///
    /// This future never completes successfully (its `Ok` type is `Infallible`); it only
    /// returns when an error occurs. `_break_on_drop` is dropped together with this future,
    /// which marks the driver as dropped and wakes everybody waiting in `wait_for_incoming`.
    ///
    /// # Errors
    ///
    /// Returns an error if reading from the stream or waiting for it to become readable fails,
    /// e.g. when the X11 server closed the connection. Packets that were completely received
    /// before the failure are still enqueued and announced before the error is returned.
    pub(super) async fn drive(
        &self,
        _break_on_drop: BreakOnDrop<S>,
    ) -> Result<Infallible, ConnectionError> {
        let mut packet_reader = PacketReader {
            read_buffer: vec![0; 4096].into_boxed_slice(),
            inner: ProtoPacketReader::new(),
        };
        let mut fds = vec![];
        let mut packets = vec![];

        loop {
            for _ in 0..50 {
                // Try to read packets from the stream. The error (if any) is held back until
                // everything that was fully received has been handed to the connection, so
                // that waiters can still find replies sent right before the failure.
                let read_result =
                    packet_reader.try_read_packets(&self.stream, &mut packets, &mut fds);
                let packet_count = packets.len();

                // Now, actually enqueue the packets. The guard is dropped at the end of this
                // block, so the lock is never held across an `.await`.
                {
                    let mut inner = self.lock_connection();
                    inner.enqueue_fds(mem::take(&mut fds));
                    for packet in packets.drain(..) {
                        inner.enqueue_packet(packet);
                    }
                }

                if packet_count > 0 {
                    // Notify any listeners that there is new data.
                    self.new_input.notify_additional(usize::MAX);
                }

                // Everything received so far is enqueued; now report a read failure, if any.
                read_result?;

                if packet_count == 0 {
                    // Wait for more data.
                    self.stream.readable().await?;
                }
            }

            // In the case of a large influx of packets, don't starve other tasks.
            future::yield_now().await;
        }
    }
}

/// Turns the bytes read from the stream into complete packets.
///
/// Small reads go through `read_buffer` and are then copied into `inner`; when the packet being
/// assembled needs at least `read_buffer.len()` more bytes, reads go straight into `inner`'s
/// buffer instead.
#[derive(Debug)]
struct PacketReader {
    /// The read buffer to store incoming bytes in (`drive` allocates 4096 bytes for it).
    read_buffer: Box<[u8]>,

    /// The inner reader that breaks these bytes into packets.
    inner: ProtoPacketReader,
}

impl PacketReader {
    /// Try to read packets from the stream.
    ///
    /// Keeps reading until the stream reports `WouldBlock`, appending every completed packet to
    /// `out_packets` and handing `fd_storage` to the stream for received file descriptors.
    /// Reads that fail with `Interrupted` are retried.
    ///
    /// # Errors
    ///
    /// Returns an `UnexpectedEof` error if a read returns zero bytes (the server closed the
    /// connection), or any other I/O error reported by the stream. Packets completed before the
    /// error was hit are left in `out_packets`.
    fn try_read_packets(
        &mut self,
        stream: &impl Stream,
        out_packets: &mut Vec<Vec<u8>>,
        fd_storage: &mut Vec<RawFdContainer>,
    ) -> io::Result<()> {
        let original_length = out_packets.len();
        loop {
            // If the necessary packet size is larger than our buffer, just fill straight
            // into the packet reader's buffer and skip the extra copy.
            if self.inner.remaining_capacity() >= self.read_buffer.len() {
                tracing::trace!(
                    "Trying to read large packet with {} bytes remaining",
                    self.inner.remaining_capacity()
                );
                match stream.read(self.inner.buffer(), fd_storage) {
                    Ok(0) => {
                        tracing::error!("Large read returned zero");
                        return Err(Self::connection_closed());
                    }
                    Ok(n) => {
                        tracing::trace!("Read {} bytes directly into large packet", n);
                        if let Some(packet) = self.inner.advance(n) {
                            out_packets.push(packet);
                        }
                    }
                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
                    // A signal interrupted the read before any data was transferred; retry.
                    Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
                    Err(e) => return Err(e),
                }
            } else {
                // read into our buffer
                let nread = match stream.read(&mut self.read_buffer, fd_storage) {
                    Ok(0) => {
                        tracing::error!("Buffer read returned zero");
                        return Err(Self::connection_closed());
                    }
                    Ok(n) => n,
                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break,
                    // A signal interrupted the read before any data was transferred; retry.
                    Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
                    Err(e) => return Err(e),
                };
                tracing::trace!("Read {} bytes into read buffer", nread);

                // Feed the data into the packet reader, possibly completing several packets.
                let mut src = &self.read_buffer[..nread];
                while !src.is_empty() {
                    let dest = self.inner.buffer();
                    let amt_to_read = std::cmp::min(src.len(), dest.len());

                    dest[..amt_to_read].copy_from_slice(&src[..amt_to_read]);

                    // reborrow src
                    src = &src[amt_to_read..];

                    // advance by the given amount
                    if let Some(packet) = self.inner.advance(amt_to_read) {
                        out_packets.push(packet);
                    }
                }
            }
        }
        tracing::trace!(
            "Read {} complete packet(s)",
            out_packets.len() - original_length
        );

        Ok(())
    }

    /// Build the error reported when a read returns zero bytes, i.e. the connection was closed.
    fn connection_closed() -> io::Error {
        io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "The X11 server closed the connection",
        )
    }
}

/// Guard passed to `SharedState::drive`; when it is dropped, the driver is marked as dropped
/// and everyone waiting on the shared state's `new_input` event is woken up.
#[derive(Debug)]
pub(super) struct BreakOnDrop<S>(pub(super) Arc<SharedState<S>>);

impl<S> Drop for BreakOnDrop<S> {
    /// Mark the driver as dropped and wake every listener registered on `new_input`.
    fn drop(&mut self) {
        // Mark the connection as broken. The flag is set before notifying so that woken
        // waiters observe it when they re-check it.
        self.0.driver_dropped.store(true, Ordering::SeqCst);

        // Wake up everyone that might be waiting
        self.0.new_input.notify_additional(usize::MAX);
    }
}