1use crate::server::Event;
2use axum::{
3 extract::{ws::WebSocket, State, WebSocketUpgrade},
4 routing::get,
5 Router,
6};
7use server::ServerState;
8use stateroom::StateroomServiceFactory;
9use std::{
10 net::{IpAddr, SocketAddr},
11 sync::Arc,
12 time::Duration,
13};
14use tokio::{net::TcpListener, select};
15use tower_http::services::ServeDir;
16
17mod server;
18
/// Address the server binds to when no other IP is configured (the IPv4
/// unspecified address, i.e. all IPv4 interfaces).
const DEFAULT_IP: &str = "0.0.0.0";

/// Configuration for the Stateroom HTTP/WebSocket server.
///
/// Build one with `Server::new()` / `Server::default()` and the `with_*`
/// builder methods, then start it with `serve` or `serve_async`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Server {
    /// Interval between heartbeats.
    ///
    /// NOTE(review): not read by `serve_async` in this module; confirm whether
    /// heartbeats are wired up elsewhere.
    pub heartbeat_interval: Duration,

    /// How long a connection may go without a heartbeat before it is
    /// considered dead.
    ///
    /// NOTE(review): not read by `serve_async` in this module; confirm whether
    /// this timeout is enforced elsewhere.
    pub heartbeat_timeout: Duration,

    /// TCP port to listen on.
    pub port: u16,

    /// IP address to bind to, as text. It must parse as an `IpAddr`.
    pub ip: String,

    /// Optional directory whose files are served for requests that match no
    /// other route.
    pub static_path: Option<String>,

    /// Optional directory whose files are served under `/client`.
    pub client_path: Option<String>,
}
45
46impl Default for Server {
47 fn default() -> Self {
48 Server {
49 heartbeat_interval: Duration::from_secs(30),
50 heartbeat_timeout: Duration::from_secs(300),
51 port: 8080,
52 ip: DEFAULT_IP.to_string(),
53 static_path: None,
54 client_path: None,
55 }
56 }
57}
58
59impl Server {
60 #[must_use]
61 pub fn new() -> Self {
62 Server::default()
63 }
64
65 #[must_use]
66 pub fn with_static_path(mut self, static_path: Option<String>) -> Self {
67 self.static_path = static_path;
68 self
69 }
70
71 #[must_use]
72 pub fn with_client_path(mut self, client_path: Option<String>) -> Self {
73 self.client_path = client_path;
74 self
75 }
76
77 #[must_use]
78 pub fn with_heartbeat_interval(mut self, duration_seconds: u64) -> Self {
79 self.heartbeat_interval = Duration::from_secs(duration_seconds);
80 self
81 }
82
83 #[must_use]
84 pub fn with_heartbeat_timeout(mut self, duration_seconds: u64) -> Self {
85 self.heartbeat_timeout = Duration::from_secs(duration_seconds);
86 self
87 }
88
89 #[must_use]
90 pub fn with_port(mut self, port: u16) -> Self {
91 self.port = port;
92 self
93 }
94
95 #[must_use]
96 pub fn with_ip(mut self, ip: String) -> Self {
97 self.ip = ip;
98 self
99 }
100
101 pub async fn serve_async(self, factory: impl StateroomServiceFactory) -> std::io::Result<()> {
108 let server_state = Arc::new(ServerState::new(factory));
109
110 let mut app = Router::new()
111 .route("/ws", get(serve_websocket))
112 .with_state(server_state);
113
114 if let Some(static_path) = self.static_path {
115 app = app.nest_service("/", ServeDir::new(static_path));
116 }
117
118 if let Some(client_path) = self.client_path {
119 app = app.nest_service("/client", ServeDir::new(client_path));
120 }
121
122 let ip = self.ip.parse::<IpAddr>().unwrap();
123 let addr = SocketAddr::new(ip, self.port);
124 let listener = TcpListener::bind(&addr).await?;
125 axum::serve(listener, app).await?;
126
127 Ok(())
128 }
129
130 pub fn serve(self, factory: impl StateroomServiceFactory) -> std::io::Result<()> {
137 tokio::runtime::Builder::new_multi_thread()
138 .enable_all()
139 .build()
140 .unwrap()
141 .block_on(async { self.serve_async(factory).await })
142 }
143}
144
145pub async fn serve_websocket(
146 ws: WebSocketUpgrade,
147 State(state): State<Arc<ServerState>>,
148) -> axum::response::Response {
149 ws.on_upgrade(move |socket| handle_socket(socket, state))
150}
151
152async fn handle_socket(mut socket: WebSocket, state: Arc<ServerState>) {
153 let (send, mut recv, client_id) = state.connect();
154
155 loop {
156 select! {
157 msg = recv.recv() => {
158 match msg {
159 Some(msg) => socket.send(msg).await.unwrap(),
160 None => break,
161 }
162 },
163 msg = socket.recv() => {
164 match msg {
165 Some(Ok(msg)) => send.send(Event::Message { client: client_id, message: msg }).await.unwrap(),
166 Some(Err(_)) => todo!("Error receiving message from client."),
167 None => break,
168 }
169 }
170 }
171 }
172
173 state.remove(&client_id);
174}