Skip to main content

vnc/client/
connection.rs

1use futures::TryStreamExt;
2use tokio_stream::wrappers::ReceiverStream;
3
4use std::{future::Future, sync::Arc, vec};
5use tokio::{
6    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
7    sync::{
8        mpsc::{
9            channel,
10            error::{TryRecvError, TrySendError},
11            Receiver, Sender,
12        },
13        oneshot, Mutex,
14    },
15};
16use tokio_util::compat::*;
17use tracing::*;
18
19use crate::{codec, PixelFormat, Rect, VncEncoding, VncError, VncEvent, X11Event};
20const CHANNEL_SIZE: usize = 4096;
21
22#[cfg(not(target_arch = "wasm32"))]
23use tokio::spawn;
24#[cfg(target_arch = "wasm32")]
25use wasm_bindgen_futures::spawn_local as spawn;
26
27use super::messages::{ClientMsg, ServerMsg};
28
29struct ImageRect {
30    rect: Rect,
31    encoding: VncEncoding,
32}
33
34impl From<[u8; 12]> for ImageRect {
35    fn from(buf: [u8; 12]) -> Self {
36        Self {
37            rect: Rect {
38                x: ((buf[0] as u16) << 8) | buf[1] as u16,
39                y: ((buf[2] as u16) << 8) | buf[3] as u16,
40                width: ((buf[4] as u16) << 8) | buf[5] as u16,
41                height: ((buf[6] as u16) << 8) | buf[7] as u16,
42            },
43            encoding: (((buf[8] as u32) << 24)
44                | ((buf[9] as u32) << 16)
45                | ((buf[10] as u32) << 8)
46                | (buf[11] as u32))
47                .into(),
48        }
49    }
50}
51
52impl ImageRect {
53    async fn read<S>(reader: &mut S) -> Result<Self, VncError>
54    where
55        S: AsyncRead + Unpin,
56    {
57        let mut rect_buf = [0_u8; 12];
58        reader.read_exact(&mut rect_buf).await?;
59        Ok(rect_buf.into())
60    }
61}
62
63struct VncInner {
64    name: String,
65    screen: (u16, u16),
66    input_ch: Sender<ClientMsg>,
67    output_ch: Receiver<VncEvent>,
68    decoding_stop: Option<oneshot::Sender<()>>,
69    net_conn_stop: Option<oneshot::Sender<()>>,
70    closed: bool,
71}
72
73/// The instance of a connected vnc client
74///
75impl VncInner {
76    async fn new<S>(
77        mut stream: S,
78        shared: bool,
79        mut pixel_format: Option<PixelFormat>,
80        encodings: Vec<VncEncoding>,
81    ) -> Result<Self, VncError>
82    where
83        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
84    {
85        let (conn_ch_tx, conn_ch_rx) = channel(CHANNEL_SIZE);
86        let (input_ch_tx, input_ch_rx) = channel(CHANNEL_SIZE);
87        let (output_ch_tx, output_ch_rx) = channel(CHANNEL_SIZE);
88        let (decoding_stop_tx, decoding_stop_rx) = oneshot::channel();
89        let (net_conn_stop_tx, net_conn_stop_rx) = oneshot::channel();
90
91        trace!("client init msg");
92        send_client_init(&mut stream, shared).await?;
93
94        trace!("server init msg");
95        let (name, (width, height)) =
96            read_server_init(&mut stream, &mut pixel_format, &|e| async {
97                output_ch_tx.send(e).await?;
98                Ok(())
99            })
100            .await?;
101
102        trace!("client encodings: {:?}", encodings);
103        send_client_encoding(&mut stream, encodings).await?;
104
105        trace!("Require the first frame");
106        input_ch_tx
107            .send(ClientMsg::FramebufferUpdateRequest(
108                Rect {
109                    x: 0,
110                    y: 0,
111                    width,
112                    height,
113                },
114                0,
115            ))
116            .await?;
117
118        // start the decoding thread
119        spawn(async move {
120            trace!("Decoding thread starts");
121            let mut conn_ch_rx = {
122                let conn_ch_rx = ReceiverStream::new(conn_ch_rx).into_async_read();
123                FuturesAsyncReadCompatExt::compat(conn_ch_rx)
124            };
125
126            let output_func = |e| async {
127                output_ch_tx.send(e).await?;
128                Ok(())
129            };
130
131            let pf = pixel_format.as_ref().unwrap();
132            if let Err(e) =
133                asycn_vnc_read_loop(&mut conn_ch_rx, pf, &output_func, decoding_stop_rx).await
134            {
135                if let VncError::IoError(e) = e {
136                    if let std::io::ErrorKind::UnexpectedEof = e.kind() {
137                        // this should be a normal case when the network connection disconnects
138                        // and we just send an EOF over the inner bridge between the process thread and the decode thread
139                        // do nothing here
140                    } else {
141                        error!("Error occurs during the decoding {:?}", e);
142                        let _ = output_func(VncEvent::Error(e.to_string())).await;
143                    }
144                } else {
145                    error!("Error occurs during the decoding {:?}", e);
146                    let _ = output_func(VncEvent::Error(e.to_string())).await;
147                }
148            }
149            trace!("Decoding thread stops");
150        });
151
152        // start the traffic process thread
153        spawn(async move {
154            trace!("Net Connection thread starts");
155            let _ =
156                async_connection_process_loop(stream, input_ch_rx, conn_ch_tx, net_conn_stop_rx)
157                    .await;
158            trace!("Net Connection thread stops");
159        });
160
161        info!("VNC Client {name} starts");
162        Ok(Self {
163            name,
164            screen: (width, height),
165            input_ch: input_ch_tx,
166            output_ch: output_ch_rx,
167            decoding_stop: Some(decoding_stop_tx),
168            net_conn_stop: Some(net_conn_stop_tx),
169            closed: false,
170        })
171    }
172
173    async fn input(&mut self, event: X11Event) -> Result<(), VncError> {
174        if self.closed {
175            Err(VncError::ClientNotRunning)
176        } else {
177            let msg = match event {
178                X11Event::Refresh => ClientMsg::FramebufferUpdateRequest(
179                    Rect {
180                        x: 0,
181                        y: 0,
182                        width: self.screen.0,
183                        height: self.screen.1,
184                    },
185                    1,
186                ),
187                X11Event::FullRefresh => ClientMsg::FramebufferUpdateRequest(
188                    Rect {
189                        x: 0,
190                        y: 0,
191                        width: self.screen.0,
192                        height: self.screen.1,
193                    },
194                    0, // non-incremental: server sends entire framebuffer
195                ),
196                X11Event::KeyEvent(key) => ClientMsg::KeyEvent(key.keycode, key.down),
197                X11Event::PointerEvent(mouse) => {
198                    ClientMsg::PointerEvent(mouse.position_x, mouse.position_y, mouse.bottons)
199                }
200                X11Event::CopyText(text) => ClientMsg::ClientCutText(text),
201            };
202            self.input_ch.send(msg).await?;
203            Ok(())
204        }
205    }
206
207    async fn recv_event(&mut self) -> Result<VncEvent, VncError> {
208        if self.closed {
209            Err(VncError::ClientNotRunning)
210        } else {
211            match self.output_ch.recv().await {
212                Some(e) => Ok(e),
213                None => {
214                    self.closed = true;
215                    Err(VncError::ClientNotRunning)
216                }
217            }
218        }
219    }
220
221    async fn poll_event(&mut self) -> Result<Option<VncEvent>, VncError> {
222        if self.closed {
223            Err(VncError::ClientNotRunning)
224        } else {
225            match self.output_ch.try_recv() {
226                Err(TryRecvError::Disconnected) => {
227                    self.closed = true;
228                    Err(VncError::ClientNotRunning)
229                }
230                Err(TryRecvError::Empty) => Ok(None),
231                Ok(e) => Ok(Some(e)),
232            }
233            // Ok(self.output_ch.recv().await)
234        }
235    }
236
237    /// Stop the VNC engine and release resources
238    ///
239    fn close(&mut self) -> Result<(), VncError> {
240        if self.net_conn_stop.is_some() {
241            let net_conn_stop: oneshot::Sender<()> = self.net_conn_stop.take().unwrap();
242            let _ = net_conn_stop.send(());
243        }
244        if self.decoding_stop.is_some() {
245            let decoding_stop = self.decoding_stop.take().unwrap();
246            let _ = decoding_stop.send(());
247        }
248        self.closed = true;
249        Ok(())
250    }
251}
252
253impl Drop for VncInner {
254    fn drop(&mut self) {
255        info!("VNC Client {} stops", self.name);
256        let _ = self.close();
257    }
258}
259
260pub struct VncClient {
261    inner: Arc<Mutex<VncInner>>,
262}
263
264impl VncClient {
265    pub(super) async fn new<S>(
266        stream: S,
267        shared: bool,
268        pixel_format: Option<PixelFormat>,
269        encodings: Vec<VncEncoding>,
270    ) -> Result<Self, VncError>
271    where
272        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
273    {
274        Ok(Self {
275            inner: Arc::new(Mutex::new(
276                VncInner::new(stream, shared, pixel_format, encodings).await?,
277            )),
278        })
279    }
280
281    /// Input a `X11Event` from the frontend
282    ///
283    pub async fn input(&self, event: X11Event) -> Result<(), VncError> {
284        self.inner.lock().await.input(event).await
285    }
286
287    /// Receive a `VncEvent` from the engine
288    /// This function will block until a `VncEvent` is received
289    ///
290    pub async fn recv_event(&self) -> Result<VncEvent, VncError> {
291        self.inner.lock().await.recv_event().await
292    }
293
294    /// polling `VncEvent` from the engine and give it to the client
295    ///
296    pub async fn poll_event(&self) -> Result<Option<VncEvent>, VncError> {
297        self.inner.lock().await.poll_event().await
298    }
299
300    /// Stop the VNC engine and release resources
301    ///
302    pub async fn close(&self) -> Result<(), VncError> {
303        self.inner.lock().await.close()
304    }
305}
306
307impl Clone for VncClient {
308    fn clone(&self) -> Self {
309        Self {
310            inner: self.inner.clone(),
311        }
312    }
313}
314
315async fn send_client_init<S>(stream: &mut S, shared: bool) -> Result<(), VncError>
316where
317    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
318{
319    trace!("Send shared flag: {}", shared);
320    stream.write_u8(shared as u8).await?;
321    Ok(())
322}
323
324async fn read_server_init<S, F, Fut>(
325    stream: &mut S,
326    pf: &mut Option<PixelFormat>,
327    output_func: &F,
328) -> Result<(String, (u16, u16)), VncError>
329where
330    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
331    F: Fn(VncEvent) -> Fut,
332    Fut: Future<Output = Result<(), VncError>>,
333{
334    // +--------------+--------------+------------------------------+
335    // | No. of bytes | Type [Value] | Description                  |
336    // +--------------+--------------+------------------------------+
337    // | 2            | U16          | framebuffer-width in pixels  |
338    // | 2            | U16          | framebuffer-height in pixels |
339    // | 16           | PIXEL_FORMAT | server-pixel-format          |
340    // | 4            | U32          | name-length                  |
341    // | name-length  | U8 array     | name-string                  |
342    // +--------------+--------------+------------------------------+
343
344    let screen_width = stream.read_u16().await?;
345    let screen_height = stream.read_u16().await?;
346    let mut send_our_pf = false;
347
348    output_func(VncEvent::SetResolution(
349        (screen_width, screen_height).into(),
350    ))
351    .await?;
352
353    let pixel_format = PixelFormat::read(stream).await?;
354    if pf.is_none() {
355        output_func(VncEvent::SetPixelFormat(pixel_format)).await?;
356        let _ = pf.insert(pixel_format);
357    } else {
358        send_our_pf = true;
359    }
360
361    let name_len = stream.read_u32().await?;
362    let mut name_buf = vec![0_u8; name_len as usize];
363    stream.read_exact(&mut name_buf).await?;
364    let name = String::from_utf8_lossy(&name_buf).into_owned();
365
366    if send_our_pf {
367        trace!("Send customized pixel format {:#?}", pf);
368        ClientMsg::SetPixelFormat(*pf.as_ref().unwrap())
369            .write(stream)
370            .await?;
371    }
372    Ok((name, (screen_width, screen_height)))
373}
374
375async fn send_client_encoding<S>(
376    stream: &mut S,
377    encodings: Vec<VncEncoding>,
378) -> Result<(), VncError>
379where
380    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
381{
382    ClientMsg::SetEncodings(encodings).write(stream).await?;
383    Ok(())
384}
385
386async fn asycn_vnc_read_loop<S, F, Fut>(
387    stream: &mut S,
388    pf: &PixelFormat,
389    output_func: &F,
390    mut stop_ch: oneshot::Receiver<()>,
391) -> Result<(), VncError>
392where
393    S: AsyncRead + Unpin,
394    F: Fn(VncEvent) -> Fut,
395    Fut: Future<Output = Result<(), VncError>>,
396{
397    let mut raw_decoder = codec::RawDecoder::new();
398    let mut zrle_decoder = codec::ZrleDecoder::new();
399    let mut tight_decoder = codec::TightDecoder::new();
400    let mut trle_decoder = codec::TrleDecoder::new();
401    let mut cursor = codec::CursorDecoder::new();
402
403    // main decoding loop
404    while let Err(oneshot::error::TryRecvError::Empty) = stop_ch.try_recv() {
405        let server_msg = ServerMsg::read(stream).await?;
406        trace!("Server message got: {:?}", server_msg);
407        match server_msg {
408            ServerMsg::FramebufferUpdate(rect_num) => {
409                for _ in 0..rect_num {
410                    let rect = ImageRect::read(stream).await?;
411                    // trace!("Encoding: {:?}", rect.encoding);
412
413                    match rect.encoding {
414                        VncEncoding::Raw => {
415                            raw_decoder
416                                .decode(pf, &rect.rect, stream, output_func)
417                                .await?;
418                        }
419                        VncEncoding::CopyRect => {
420                            let source_x = stream.read_u16().await?;
421                            let source_y = stream.read_u16().await?;
422                            let mut src_rect = rect.rect;
423                            src_rect.x = source_x;
424                            src_rect.y = source_y;
425                            output_func(VncEvent::Copy(rect.rect, src_rect)).await?;
426                        }
427                        VncEncoding::Tight => {
428                            tight_decoder
429                                .decode(pf, &rect.rect, stream, output_func)
430                                .await?;
431                        }
432                        VncEncoding::Trle => {
433                            trle_decoder
434                                .decode(pf, &rect.rect, stream, output_func)
435                                .await?;
436                        }
437                        VncEncoding::Zrle => {
438                            zrle_decoder
439                                .decode(pf, &rect.rect, stream, output_func)
440                                .await?;
441                        }
442                        VncEncoding::CursorPseudo => {
443                            cursor.decode(pf, &rect.rect, stream, output_func).await?;
444                        }
445                        VncEncoding::DesktopSizePseudo => {
446                            output_func(VncEvent::SetResolution(
447                                (rect.rect.width, rect.rect.height).into(),
448                            ))
449                            .await?;
450                        }
451                        VncEncoding::LastRectPseudo => {
452                            break;
453                        }
454                    }
455                }
456            }
457            // SetColorMapEntries,
458            ServerMsg::Bell => {
459                output_func(VncEvent::Bell).await?;
460            }
461            ServerMsg::ServerCutText(text) => {
462                output_func(VncEvent::Text(text)).await?;
463            }
464        }
465    }
466    Ok(())
467}
468
469async fn async_connection_process_loop<S>(
470    mut stream: S,
471    mut input_ch: Receiver<ClientMsg>,
472    conn_ch: Sender<std::io::Result<Vec<u8>>>,
473    mut stop_ch: oneshot::Receiver<()>,
474) -> Result<(), VncError>
475where
476    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
477{
478    let mut buffer = [0; 65535];
479    let mut pending = 0;
480
481    // main traffic loop
482    loop {
483        if pending > 0 {
484            match conn_ch.try_send(Ok(buffer[0..pending].to_owned())) {
485                Err(TrySendError::Full(_message)) => (),
486                Err(TrySendError::Closed(_message)) => break,
487                Ok(()) => pending = 0,
488            }
489        }
490
491        tokio::select! {
492            _ = &mut stop_ch => break,
493            result = stream.read(&mut buffer), if pending == 0 => {
494                match result {
495                    Ok(nread) => {
496                        if nread > 0 {
497                            match conn_ch.try_send(Ok(buffer[0..nread].to_owned())) {
498                                Err(TrySendError::Full(_message)) => pending = nread,
499                                Err(TrySendError::Closed(_message)) => break,
500                                Ok(()) => ()
501                            }
502                        } else {
503                            // According to the tokio's Doc
504                            // https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html
505                            // if nread == 0, then EOF is reached
506                            trace!("Net Connection EOF detected");
507                            break;
508                        }
509                    }
510                    Err(e) => {
511                        error!("{}", e.to_string());
512                        break;
513                    }
514                }
515            }
516            Some(msg) = input_ch.recv() => {
517                msg.write(&mut stream).await?;
518            }
519        }
520    }
521
522    // notify the decoding thread
523    let _ = conn_ch
524        .send(Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)))
525        .await;
526
527    Ok(())
528}