1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
use futures::StreamExt;
use serde::{de::DeserializeOwned, Serialize};

use crate::{
    common::{self},
    WebsocketConnection,
};
use common::should_terminate;
use log::{debug, error, trace};
use std::fmt::Debug;
use tokio::sync::mpsc::Sender;
use tokio::{net::TcpStream, select, signal::ctrl_c, sync::mpsc};
use tungstenite::error::Error;
use tungstenite::Message;

/// Accepts a WebSocket handshake on `stream` and spawns a background task that
/// pumps messages between the socket and a pair of bounded channels.
///
/// Each message read from the socket is handed to `server_process_request`,
/// which forwards decoded requests into the returned `Receiver<Request>`.
/// Every value sent on the returned `Sender<Response>` is written to the
/// socket with `common::send_message`, except that a response for which
/// `is_close` returns `true` is not sent; `common::send_close` is called
/// instead.
///
/// The background task stops when:
/// - the socket stream ends;
/// - a `ConnectionClosed`, `AlreadyClosed` or `Protocol` error is read;
/// - `should_terminate` accepts an incoming message;
/// - the response channel is closed (all senders dropped), after calling
///   `common::send_close`.
///
/// Other read errors are logged and reading continues. On Ctrl-C,
/// `common::send_close` is called and the loop keeps running.
///
/// # Errors
///
/// Returns an error if the peer address of `stream` cannot be read or if the
/// WebSocket handshake fails.
pub async fn spawn_server<
    Request: DeserializeOwned + Send + Sync + 'static,
    Response: Serialize + Debug + Send + Sync + 'static,
    F: Fn(&Response) -> bool + Send + Sync + 'static,
>(
    stream: TcpStream,
    is_close: F,
) -> anyhow::Result<(mpsc::Receiver<Request>, mpsc::Sender<Response>)> {
    // The address itself is unused; the call only makes a socket without a
    // peer fail before the handshake is attempted.
    let _addr = stream.peer_addr()?;
    let mut websocket = async_tungstenite::tokio::accept_async(stream).await?;

    let (mut tx_request, rx_request) = mpsc::channel::<Request>(4);
    let (tx_response, mut rx_response) = mpsc::channel::<Response>(4);

    tokio::spawn(async move {
        loop {
            select!(
                request = websocket.next() => {
                    // `None` means the socket stream has ended.
                    let request = match request {
                        Some(request) => request,
                        None => break,
                    };

                    trace!("request received: {:?}", &request);

                    let request = match request {
                        Ok(message) => message,
                        Err(Error::ConnectionClosed)
                        | Err(Error::AlreadyClosed)
                        | Err(Error::Protocol(_)) => break,
                        Err(e) => {
                            // Any other read error is logged and reading continues.
                            error!("request error: {}", e);
                            continue;
                        }
                    };

                    if should_terminate(&request) {
                        break;
                    }

                    server_process_request(&mut websocket, request, &mut tx_request).await;
                },
                response = rx_response.recv() => {
                    let response = match response {
                        Some(response) => response,
                        None => {
                            // Every `Sender<Response>` is gone: close the socket and stop.
                            common::send_close(&mut websocket).await;
                            break;
                        }
                    };

                    if is_close(&response) {
                        // The close marker is not sent itself. The loop keeps running,
                        // presumably so the peer's close reply is still read — confirm.
                        common::send_close(&mut websocket).await;
                        continue;
                    }

                    debug!("send message: {:?}", &response);
                    common::send_message(&mut websocket, response).await;
                },
                _ = ctrl_c() => {
                    // Intentionally no `break`: the loop keeps running after the close.
                    common::send_close(&mut websocket).await;
                },
            );
        }

        debug!("server loop terminated");
    });

    Ok((rx_request, tx_response))
}

async fn server_process_request<Request: DeserializeOwned>(
    _websocket: &mut WebsocketConnection,
    response: tungstenite::Message,
    target: &mut Sender<Request>,
) {
    if let Message::Close(_) = response {
        return;
    }

    let decoded = bincode::deserialize(response.into_data().as_slice());

    if let Err(e) = decoded {
        error!("failed to decode response: {}", e);
        return;
    }

    if let Err(e) = target.send(decoded.unwrap()).await {
        error!("failed to queue response: {}", e);
    }
}