1use std::fs::File;
2use std::io::BufReader;
3use std::net::SocketAddr;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{FutureExt, SinkExt, StreamExt};
9use log::{debug, error, info, warn};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11use tokio::net::{TcpListener, TcpStream};
12use tokio::sync::mpsc;
13use tokio_rustls::TlsAcceptor;
14use tokio_tungstenite::tungstenite::Message;
15use tokio_tungstenite::accept_async;
16
/// Simplified WebSocket message exchanged between the server and a
/// [`ServerHandler`].
///
/// Converts to and from `tungstenite::Message` through the `From`
/// implementations in this file.
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// A UTF-8 text payload.
    Text(String),
    /// A raw binary payload.
    Binary(Vec<u8>),
    /// A close request; it converts to a close frame that carries no close
    /// reason.
    Close,
}
27
28impl From<Message> for WsMessage {
29 fn from(msg: Message) -> Self {
30 match msg {
31 Message::Text(text) => WsMessage::Text(text),
32 Message::Binary(data) => WsMessage::Binary(data),
33 Message::Close(_) => WsMessage::Close,
34 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {
35 WsMessage::Text("".to_string())
37 }
38 }
39 }
40}
41
42impl From<WsMessage> for Message {
43 fn from(msg: WsMessage) -> Self {
44 match msg {
45 WsMessage::Text(text) => Message::Text(text),
46 WsMessage::Binary(data) => Message::Binary(data),
47 WsMessage::Close => Message::Close(None),
48 }
49 }
50}
51
/// Identifier the server assigns to a connection.
///
/// `WsServer::start` builds it as `"client-"` followed by the value of
/// `uuid_simple()`. It is hashable and comparable, so it can serve as a map
/// key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClientId(pub String);
55
/// Settings consumed by `WsServer`.
#[derive(Debug, Clone)]
pub struct WsServerConfig {
    /// Socket address the listener binds to, e.g. `"127.0.0.1:9000"`.
    pub addr: String,
    /// PEM file holding the server certificate chain.
    pub cert_path: PathBuf,
    /// PEM file holding the server private key (PKCS#8, PKCS#1 or SEC1).
    pub key_path: PathBuf,
    /// PEM file holding the CA certificates used to verify client
    /// certificates when `client_cert_required` is `true`.
    pub ca_cert_path: PathBuf,
    /// Connection-count threshold checked by `WsServer::start`.
    pub max_connections: usize,
    /// Timeout, in seconds, applied separately to the TLS handshake and to
    /// the WebSocket handshake.
    pub connection_timeout: u64,
    /// When `true`, the TLS configuration is built with a client-certificate
    /// verifier; when `false`, it is built with no client authentication.
    pub client_cert_required: bool,
}
74
75impl Default for WsServerConfig {
76 fn default() -> Self {
77 Self {
78 addr: "127.0.0.1:9000".to_string(),
79 cert_path: PathBuf::from("./crate_cert/a_cert.pem"),
80 key_path: PathBuf::from("./crate_cert/a_key.pem"),
81 ca_cert_path: PathBuf::from("./crate_cert/ca_cert.pem"),
82 max_connections: 1000,
83 connection_timeout: 30,
84 client_cert_required: true,
85 }
86 }
87}
88
/// Callbacks invoked by `WsServer` as connections progress.
///
/// Implementations are shared across connection tasks behind an `Arc`, hence
/// the `Send + Sync + 'static` bounds. All callbacks are synchronous and are
/// called directly from the connection's async task.
pub trait ServerHandler: Send + Sync + 'static {
    /// Called once both the TLS and WebSocket handshakes have succeeded and
    /// the client has been added to the server's registry.
    fn on_connect(&self, client_id: ClientId, addr: SocketAddr);

    /// Called after the client has been removed from the server's registry.
    fn on_disconnect(&self, client_id: ClientId);

    /// Called for each message received from a client; close frames end the
    /// connection and are not forwarded.
    ///
    /// Returning `Some(reply)` queues `reply` for delivery to the same
    /// client; returning `None` sends nothing.
    fn on_message(&self, client_id: ClientId, message: WsMessage) -> Option<WsMessage>;

    /// Called when sending to or receiving from a client fails.
    ///
    /// The call sites in this file always pass `Some(client_id)`.
    fn on_error(&self, client_id: Option<ClientId>, error: String);
}
107
/// Registry entry for one connected client.
struct ClientConnection {
    /// Identifier assigned when the connection was accepted.
    client_id: ClientId,
    /// Producer side of the client's outbound queue; messages sent here are
    /// written to the WebSocket by that connection's send task.
    tx: mpsc::Sender<WsMessage>,
}
113
/// TLS-secured WebSocket server that dispatches connection events to a
/// [`ServerHandler`].
pub struct WsServer {
    /// Settings supplied at construction time.
    config: WsServerConfig,
    /// User callbacks, shared with every connection task.
    handler: Arc<dyn ServerHandler>,
    /// TLS acceptor built from the configured certificate and key.
    tls_acceptor: TlsAcceptor,
    /// Registry of currently connected clients, shared with connection tasks.
    clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
}
121
122impl WsServer {
123 pub fn new(config: WsServerConfig, handler: impl ServerHandler) -> Result<Self, String> {
134 let tls_acceptor = Self::create_tls_acceptor(&config)
135 .map_err(|e| format!("Failed to create TLS acceptor: {}", e))?;
136
137 Ok(Self {
138 config,
139 handler: Arc::new(handler),
140 tls_acceptor,
141 clients: Arc::new(tokio::sync::Mutex::new(Vec::new())),
142 })
143 }
144
145 pub async fn start(&self) -> Result<(), String> {
151 let listener = TcpListener::bind(&self.config.addr)
152 .await
153 .map_err(|e| format!("Failed to bind to address {}: {}", self.config.addr, e))?;
154
155 info!("WebSocket server started on {}", self.config.addr);
156
157 loop {
158 match listener.accept().await {
159 Ok((stream, addr)) => {
160 debug!("New TCP connection from: {}", addr);
161
162 let acceptor = self.tls_acceptor.clone();
164 let handler = self.handler.clone();
165 let clients = self.clients.clone();
166 let connection_timeout = Duration::from_secs(self.config.connection_timeout);
167
168 let client_id = ClientId(format!("client-{}", uuid_simple()));
170 let client_id_clone = client_id.clone();
171
172 tokio::spawn(async move {
173 if let Err(e) = Self::handle_connection(
174 stream,
175 addr,
176 acceptor,
177 handler,
178 clients,
179 client_id_clone,
180 connection_timeout
181 ).await {
182 error!("Connection error for {}: {}", addr, e);
183 }
184 });
185 }
186 Err(e) => {
187 error!("Failed to accept connection: {}", e);
188 }
189 }
190
191 let client_count = self.clients.lock().await.len();
193 if client_count >= self.config.max_connections {
194 warn!("Maximum connections reached: {}", client_count);
195
196 tokio::time::sleep(Duration::from_millis(100)).await;
198 }
199 }
200 }
201
202 pub async fn broadcast(&self, message: WsMessage) -> Result<usize, String> {
212 let clients = self.clients.lock().await;
213 let mut sent_count = 0;
214
215 for client in clients.iter() {
216 if client.tx.send(message.clone()).await.is_ok() {
217 sent_count += 1;
218 }
219 }
220
221 Ok(sent_count)
222 }
223
224 pub async fn send_to_client(&self, client_id: &ClientId, message: WsMessage) -> Result<(), String> {
235 let clients = self.clients.lock().await;
236
237 for client in clients.iter() {
238 if client.client_id == *client_id {
239 return client.tx.send(message)
240 .await
241 .map_err(|_| format!("Failed to send message to client {}", client_id.0));
242 }
243 }
244
245 Err(format!("Client not found: {}", client_id.0))
246 }
247
248 pub async fn client_count(&self) -> usize {
254 self.clients.lock().await.len()
255 }
256
257 pub async fn client_list(&self) -> Vec<ClientId> {
263 let clients = self.clients.lock().await;
264 clients.iter().map(|c| c.client_id.clone()).collect()
265 }
266
267 async fn handle_connection(
269 stream: TcpStream,
270 addr: SocketAddr,
271 acceptor: TlsAcceptor,
272 handler: Arc<dyn ServerHandler>,
273 clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
274 client_id: ClientId,
275 connection_timeout: Duration,
276 ) -> Result<(), String> {
277 let tls_handshake = tokio::time::timeout(
279 connection_timeout,
280 acceptor.accept(stream),
281 ).await
282 .map_err(|_| format!("TLS handshake timed out after {} seconds", connection_timeout.as_secs()))?
283 .map_err(|e| format!("TLS handshake failed: {}", e))?;
284
285 debug!("TLS handshake successful for {}", addr);
286
287 let ws_stream = tokio::time::timeout(
289 connection_timeout,
290 accept_async(tls_handshake),
291 ).await
292 .map_err(|_| format!("WebSocket handshake timed out after {} seconds", connection_timeout.as_secs()))?
293 .map_err(|e| format!("WebSocket handshake failed: {}", e))?;
294
295 debug!("WebSocket handshake successful for {}", addr);
296
297 let (tx, mut rx) = mpsc::channel::<WsMessage>(100);
299
300 {
302 let mut clients_lock = clients.lock().await;
303 clients_lock.push(ClientConnection {
304 client_id: client_id.clone(),
305 tx: tx.clone(),
306 });
307
308 info!("Client connected: {} from {}", client_id.0, addr);
309 }
310
311 handler.on_connect(client_id.clone(), addr);
313
314 let (ws_sender, ws_receiver) = ws_stream.split();
316
317 let mut send_task = {
319 let mut ws_sender = ws_sender;
320 let client_id_for_send = client_id.clone();
321 let handler_for_send = handler.clone();
322
323 async move {
324 while let Some(msg) = rx.recv().await {
325 match ws_sender.send(msg.into()).await {
326 Ok(_) => {
327 debug!("Message sent to client {}", client_id_for_send.0);
328 }
329 Err(e) => {
330 let error_msg = format!("Failed to send message: {}", e);
331 handler_for_send.on_error(Some(client_id_for_send.clone()), error_msg);
332 break;
333 }
334 }
335 }
336
337 let _ = ws_sender.close().await;
339
340 debug!("Send task completed for client {}", client_id_for_send.0);
341 }.boxed()
342 };
343
344 let mut receive_task = {
346 let mut ws_receiver = ws_receiver;
347 let handler_for_recv = handler.clone();
348 let client_id_for_recv = client_id.clone();
349 let tx_for_recv = tx.clone();
350
351 async move {
352 while let Some(result) = ws_receiver.next().await {
353 match result {
354 Ok(msg) => {
355 if msg.is_close() {
356 debug!("Client {} requested close", client_id_for_recv.0);
357 break;
358 }
359
360 let ws_msg = WsMessage::from(msg);
362
363 if let Some(response) = handler_for_recv.on_message(client_id_for_recv.clone(), ws_msg) {
365 if tx_for_recv.send(response).await.is_err() {
367 break;
368 }
369 }
370 }
371 Err(e) => {
372 let error_msg = format!("Error receiving message: {}", e);
373 handler_for_recv.on_error(Some(client_id_for_recv.clone()), error_msg);
374 break;
375 }
376 }
377 }
378
379 debug!("Receive task completed for client {}", client_id_for_recv.0);
380 }.boxed()
381 };
382
383 tokio::select! {
385 _ = &mut send_task => {},
386 _ = &mut receive_task => {},
387 }
388
389 Self::remove_client(clients, client_id.clone()).await;
391 handler.on_disconnect(client_id.clone());
392
393 info!("Client disconnected: {} from {}", client_id.0, addr);
394 Ok(())
395 }
396
397 async fn remove_client(
399 clients: Arc<tokio::sync::Mutex<Vec<ClientConnection>>>,
400 client_id: ClientId,
401 ) {
402 let mut clients_lock = clients.lock().await;
403 if let Some(pos) = clients_lock.iter().position(|c| c.client_id == client_id) {
404 clients_lock.remove(pos);
405 }
406 }
407
408 fn create_tls_acceptor(config: &WsServerConfig) -> Result<TlsAcceptor, Box<dyn std::error::Error>> {
410 info!("Loading certificates and keys...");
412 let certs = load_certs(&config.cert_path)?;
413 let key = load_private_key(&config.key_path)?;
414 let ca_certs = load_certs(&config.ca_cert_path)?;
415
416 let mut root_cert_store = rustls::RootCertStore::empty();
418 for cert in ca_certs {
419 root_cert_store.add(cert)?;
420 }
421
422 let server_config = if config.client_cert_required {
424 let client_verifier = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
426 .build()?;
427
428 rustls::ServerConfig::builder()
430 .with_client_cert_verifier(client_verifier)
431 .with_single_cert(certs, key)?
432 } else {
433 rustls::ServerConfig::builder()
435 .with_no_client_auth()
436 .with_single_cert(certs, key)?
437 };
438
439 Ok(TlsAcceptor::from(Arc::new(server_config)))
441 }
442}
443
444fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error>> {
454 let file = File::open(path)?;
455 let mut reader = BufReader::new(file);
456 let mut certs = Vec::new();
457
458 for cert_result in rustls_pemfile::certs(&mut reader) {
459 let cert = cert_result?;
460 certs.push(cert);
461 }
462
463 if certs.is_empty() {
464 return Err(format!("No certificates found in {}", path.display()).into());
465 }
466
467 Ok(certs)
468}
469
470fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error>> {
480 let file = File::open(path)?;
481 let mut reader = BufReader::new(file);
482
483 let mut pkcs8_keys = Vec::new();
485 for key_result in rustls_pemfile::pkcs8_private_keys(&mut reader) {
486 pkcs8_keys.push(key_result?);
487 }
488
489 if !pkcs8_keys.is_empty() {
490 return Ok(PrivateKeyDer::Pkcs8(pkcs8_keys.remove(0)));
491 }
492
493 reader = BufReader::new(File::open(path)?);
495
496 let mut rsa_keys = Vec::new();
498 for key_result in rustls_pemfile::rsa_private_keys(&mut reader) {
499 rsa_keys.push(key_result?);
500 }
501
502 if !rsa_keys.is_empty() {
503 return Ok(PrivateKeyDer::Pkcs1(rsa_keys.remove(0)));
504 }
505
506 reader = BufReader::new(File::open(path)?);
508
509 let mut ec_keys = Vec::new();
511 for key_result in rustls_pemfile::ec_private_keys(&mut reader) {
512 ec_keys.push(key_result?);
513 }
514
515 if !ec_keys.is_empty() {
516 return Ok(PrivateKeyDer::Sec1(ec_keys.remove(0)));
517 }
518
519 Err(format!("No private keys found in {}", path.display()).into())
520}
521
/// Generates a process-unique identifier string.
///
/// The result has the form `<secs>-<nanos>-<seq>`, all lowercase hex:
/// seconds since the Unix epoch, the sub-second nanoseconds zero-padded to
/// eight digits, and a counter that increases on every call. The counter
/// guarantees that two calls within the same clock tick never collide, and
/// the delimiters keep the fields unambiguous. If the system clock is set
/// before the Unix epoch the time fields are zero.
///
/// This is not an RFC 4122 UUID and is not suitable as a secret.
fn uuid_simple() -> String {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Shared by every caller in the process; only uniqueness matters, so
    // relaxed ordering is sufficient.
    static NEXT_SEQUENCE: AtomicUsize = AtomicUsize::new(0);

    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    let sequence = NEXT_SEQUENCE.fetch_add(1, Ordering::Relaxed);

    format!("{:x}-{:08x}-{:x}", now.as_secs(), now.subsec_nanos(), sequence)
}