stateroom_server/
lib.rs

1use crate::server::Event;
2use axum::{
3    extract::{ws::WebSocket, State, WebSocketUpgrade},
4    routing::get,
5    Router,
6};
7use server::ServerState;
8use stateroom::StateroomServiceFactory;
9use std::{
10    net::{IpAddr, SocketAddr},
11    sync::Arc,
12    time::Duration,
13};
14use tokio::{net::TcpListener, select};
15use tower_http::services::ServeDir;
16
17mod server;
18
/// IP address the server listens on when none is configured (the IPv4 "any" address).
const DEFAULT_IP: &str = "0.0.0.0";
20
/// Configuration for a stateroom WebSocket server.
///
/// Construct one with `Server::new()` (or `Default`), adjust it with the `with_*` builder
/// methods, then start it with `serve` or `serve_async`.
#[derive(Debug, Clone)]
pub struct Server {
    /// The duration of time between server-initiated WebSocket heartbeats.
    ///
    /// Defaults to 30 seconds.
    ///
    /// NOTE(review): not read by `serve_async`/`handle_socket` in this file — confirm
    /// heartbeats are actually sent before relying on this setting.
    pub heartbeat_interval: Duration,

    /// The maximum time a connection may go without client heartbeats before it is dropped.
    ///
    /// Defaults to 5 minutes.
    ///
    /// NOTE(review): not read by `serve_async`/`handle_socket` in this file — confirm the
    /// timeout is actually enforced before relying on this setting.
    pub heartbeat_timeout: Duration,

    /// The TCP port to run the server on. Defaults to 8080.
    pub port: u16,

    /// The IP address to listen on, as an IPv4 or IPv6 literal. Defaults to `0.0.0.0`.
    pub ip: String,

    /// A local filesystem path to serve static files from, or `None` (default) to serve none.
    pub static_path: Option<String>,

    /// A local filesystem path whose files are served under `/client`, or `None` (default).
    pub client_path: Option<String>,
}
45
46impl Default for Server {
47    fn default() -> Self {
48        Server {
49            heartbeat_interval: Duration::from_secs(30),
50            heartbeat_timeout: Duration::from_secs(300),
51            port: 8080,
52            ip: DEFAULT_IP.to_string(),
53            static_path: None,
54            client_path: None,
55        }
56    }
57}
58
59impl Server {
60    #[must_use]
61    pub fn new() -> Self {
62        Server::default()
63    }
64
65    #[must_use]
66    pub fn with_static_path(mut self, static_path: Option<String>) -> Self {
67        self.static_path = static_path;
68        self
69    }
70
71    #[must_use]
72    pub fn with_client_path(mut self, client_path: Option<String>) -> Self {
73        self.client_path = client_path;
74        self
75    }
76
77    #[must_use]
78    pub fn with_heartbeat_interval(mut self, duration_seconds: u64) -> Self {
79        self.heartbeat_interval = Duration::from_secs(duration_seconds);
80        self
81    }
82
83    #[must_use]
84    pub fn with_heartbeat_timeout(mut self, duration_seconds: u64) -> Self {
85        self.heartbeat_timeout = Duration::from_secs(duration_seconds);
86        self
87    }
88
89    #[must_use]
90    pub fn with_port(mut self, port: u16) -> Self {
91        self.port = port;
92        self
93    }
94
95    #[must_use]
96    pub fn with_ip(mut self, ip: String) -> Self {
97        self.ip = ip;
98        self
99    }
100
101    /// Start a server given a [StateroomService].
102    ///
103    /// This function blocks until the server is terminated. While it is running, the following
104    /// endpoints are available:
105    /// - `/` (GET): return HTTP 200 if the server is running (useful as a baseline status check)
106    /// - `/ws` (GET): initiate a WebSocket connection to the stateroom service.
107    pub async fn serve_async(self, factory: impl StateroomServiceFactory) -> std::io::Result<()> {
108        let server_state = Arc::new(ServerState::new(factory));
109
110        let mut app = Router::new()
111            .route("/ws", get(serve_websocket))
112            .with_state(server_state);
113
114        if let Some(static_path) = self.static_path {
115            app = app.nest_service("/", ServeDir::new(static_path));
116        }
117
118        if let Some(client_path) = self.client_path {
119            app = app.nest_service("/client", ServeDir::new(client_path));
120        }
121
122        let ip = self.ip.parse::<IpAddr>().unwrap();
123        let addr = SocketAddr::new(ip, self.port);
124        let listener = TcpListener::bind(&addr).await?;
125        axum::serve(listener, app).await?;
126
127        Ok(())
128    }
129
130    /// Start a server given a [StateroomService].
131    ///
132    /// This function blocks until the server is terminated. While it is running, the following
133    /// endpoints are available:
134    /// - `/` (GET): return HTTP 200 if the server is running (useful as a baseline status check)
135    /// - `/ws` (GET): initiate a WebSocket connection to the stateroom service.
136    pub fn serve(self, factory: impl StateroomServiceFactory) -> std::io::Result<()> {
137        tokio::runtime::Builder::new_multi_thread()
138            .enable_all()
139            .build()
140            .unwrap()
141            .block_on(async { self.serve_async(factory).await })
142    }
143}
144
145pub async fn serve_websocket(
146    ws: WebSocketUpgrade,
147    State(state): State<Arc<ServerState>>,
148) -> axum::response::Response {
149    ws.on_upgrade(move |socket| handle_socket(socket, state))
150}
151
152async fn handle_socket(mut socket: WebSocket, state: Arc<ServerState>) {
153    let (send, mut recv, client_id) = state.connect();
154
155    loop {
156        select! {
157            msg = recv.recv() => {
158                match msg {
159                    Some(msg) => socket.send(msg).await.unwrap(),
160                    None => break,
161                }
162            },
163            msg = socket.recv() => {
164                match msg {
165                    Some(Ok(msg)) => send.send(Event::Message { client: client_id, message: msg }).await.unwrap(),
166                    Some(Err(_)) => todo!("Error receiving message from client."),
167                    None => break,
168                }
169            }
170        }
171    }
172
173    state.remove(&client_id);
174}