1#![warn(missing_docs)]
9
10use async_trait::async_trait;
11use futures_util::{SinkExt, StreamExt};
12use serde::{Serialize, de::DeserializeOwned};
13use std::{collections::HashMap, fmt, net::SocketAddr, sync::Arc, time::Duration};
14use tokio::sync::{RwLock, broadcast, mpsc};
15use tokio_tungstenite::tungstenite::protocol::Message as WsMessage;
16
/// Errors produced by the WebSocket server, client and the connection/room managers.
#[derive(Debug)]
pub enum WebSocketError {
    /// Binding the listener or completing the WebSocket handshake failed.
    ConnectionFailed(String),
    /// The WebSocket connection has been closed.
    ConnectionClosed,
    /// An outgoing message could not be queued or written.
    SendFailed(String),
    /// An incoming message could not be read.
    ReceiveFailed(String),
    /// A value could not be serialized (e.g. to JSON) for sending.
    SerializationFailed(String),
    /// A received message could not be deserialized into the requested type.
    DeserializationFailed(String),
    /// The named room does not exist.
    RoomNotFound(String),
    /// No connection is registered under the given id.
    ConnectionNotFound(String),
    /// The connection limit (the carried value) has been reached.
    MaxConnectionsExceeded(u32),
    /// An operation did not complete in time.
    Timeout(String),
    /// Any other internal failure.
    Internal(String),
}

impl fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant renders as "<label>" or "<label>: <detail>".
        let (label, detail): (&str, Option<&dyn fmt::Display>) = match self {
            Self::ConnectionFailed(msg) => ("WebSocket connection failed", Some(msg)),
            Self::ConnectionClosed => ("WebSocket connection closed", None),
            Self::SendFailed(msg) => ("Failed to send message", Some(msg)),
            Self::ReceiveFailed(msg) => ("Failed to receive message", Some(msg)),
            Self::SerializationFailed(msg) => ("Serialization failed", Some(msg)),
            Self::DeserializationFailed(msg) => ("Deserialization failed", Some(msg)),
            Self::RoomNotFound(room) => ("Room not found", Some(room)),
            Self::ConnectionNotFound(id) => ("Connection not found", Some(id)),
            Self::MaxConnectionsExceeded(max) => ("Maximum connections exceeded", Some(max)),
            Self::Timeout(msg) => ("Operation timeout", Some(msg)),
            Self::Internal(msg) => ("WebSocket internal error", Some(msg)),
        };
        match detail {
            Some(detail) => write!(f, "{}: {}", label, detail),
            None => f.write_str(label),
        }
    }
}

impl std::error::Error for WebSocketError {}

/// Result alias used by every fallible operation in this module.
pub type WebSocketResult<T> = Result<T, WebSocketError>;

/// Identifier of a single client connection.
pub type ConnectionId = String;

/// Identifier of a room (a named group of connections).
pub type RoomId = String;
82
/// Transport-independent WebSocket message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// UTF-8 text payload.
    Text(String),
    /// Raw binary payload.
    Binary(Vec<u8>),
    /// Ping control frame (no payload is carried).
    Ping,
    /// Pong control frame (no payload is carried).
    Pong,
    /// Close frame.
    Close,
}

impl Message {
    /// Builds a [`Message::Text`] from anything convertible into a `String`.
    pub fn text(content: impl Into<String>) -> Self {
        let content: String = content.into();
        Self::Text(content)
    }

    /// Builds a [`Message::Binary`] from anything convertible into a `Vec<u8>`.
    pub fn binary(data: impl Into<Vec<u8>>) -> Self {
        let bytes: Vec<u8> = data.into();
        Self::Binary(bytes)
    }

    /// Returns `true` for [`Message::Text`].
    pub fn is_text(&self) -> bool {
        self.as_text().is_some()
    }

    /// Returns `true` for [`Message::Binary`].
    pub fn is_binary(&self) -> bool {
        self.as_binary().is_some()
    }

    /// Borrows the text payload, or `None` for any other kind of message.
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(content) = self {
            Some(content.as_str())
        } else {
            None
        }
    }

    /// Borrows the binary payload, or `None` for any other kind of message.
    pub fn as_binary(&self) -> Option<&[u8]> {
        if let Self::Binary(bytes) = self {
            Some(bytes.as_slice())
        } else {
            None
        }
    }
}
135
136impl From<WsMessage> for Message {
137 fn from(msg: WsMessage) -> Self {
138 match msg {
139 WsMessage::Text(s) => Message::Text(s.to_string()),
140 WsMessage::Binary(data) => Message::Binary(data.to_vec()),
141 WsMessage::Ping(_) => Message::Ping,
142 WsMessage::Pong(_) => Message::Pong,
143 WsMessage::Close(_) => Message::Close,
144 _ => Message::Close,
145 }
146 }
147}
148
149impl From<Message> for WsMessage {
150 fn from(msg: Message) -> Self {
151 match msg {
152 Message::Text(s) => WsMessage::Text(s.into()),
153 Message::Binary(data) => WsMessage::Binary(data.into()),
154 Message::Ping => WsMessage::Ping(Vec::new().into()),
155 Message::Pong => WsMessage::Pong(Vec::new().into()),
156 Message::Close => WsMessage::Close(None),
157 }
158 }
159}
160
161#[derive(Debug, Clone)]
163pub struct Connection {
164 pub id: ConnectionId,
166 pub addr: SocketAddr,
168 pub connected_at: std::time::Instant,
170 pub metadata: HashMap<String, String>,
172 pub rooms: Vec<RoomId>,
174}
175
176impl Connection {
177 pub fn new(id: ConnectionId, addr: SocketAddr) -> Self {
179 Self { id, addr, connected_at: std::time::Instant::now(), metadata: HashMap::new(), rooms: Vec::new() }
180 }
181
182 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
184 self.metadata.insert(key.into(), value.into());
185 self
186 }
187
188 pub fn duration(&self) -> Duration {
190 self.connected_at.elapsed()
191 }
192}
193
194pub struct ConnectionManager {
196 connections: Arc<RwLock<HashMap<ConnectionId, Connection>>>,
197 max_connections: u32,
198}
199
200impl ConnectionManager {
201 pub fn new(max_connections: u32) -> Self {
203 Self { connections: Arc::new(RwLock::new(HashMap::new())), max_connections }
204 }
205
206 pub async fn add(&self, connection: Connection) -> WebSocketResult<()> {
208 let mut connections = self.connections.write().await;
209 if connections.len() >= self.max_connections as usize {
210 return Err(WebSocketError::MaxConnectionsExceeded(self.max_connections));
211 }
212 connections.insert(connection.id.clone(), connection);
213 Ok(())
214 }
215
216 pub async fn remove(&self, id: &str) -> Option<Connection> {
218 let mut connections = self.connections.write().await;
219 connections.remove(id)
220 }
221
222 pub async fn get(&self, id: &str) -> Option<Connection> {
224 let connections = self.connections.read().await;
225 connections.get(id).cloned()
226 }
227
228 pub async fn exists(&self, id: &str) -> bool {
230 let connections = self.connections.read().await;
231 connections.contains_key(id)
232 }
233
234 pub async fn count(&self) -> usize {
236 let connections = self.connections.read().await;
237 connections.len()
238 }
239
240 pub async fn all_ids(&self) -> Vec<ConnectionId> {
242 let connections = self.connections.read().await;
243 connections.keys().cloned().collect()
244 }
245
246 pub async fn join_room(&self, id: &str, room: &str) -> WebSocketResult<()> {
248 let mut connections = self.connections.write().await;
249 if let Some(conn) = connections.get_mut(id) {
250 if !conn.rooms.contains(&room.to_string()) {
251 conn.rooms.push(room.to_string());
252 }
253 return Ok(());
254 }
255 Err(WebSocketError::ConnectionNotFound(id.to_string()))
256 }
257
258 pub async fn leave_room(&self, id: &str, room: &str) -> WebSocketResult<()> {
260 let mut connections = self.connections.write().await;
261 if let Some(conn) = connections.get_mut(id) {
262 conn.rooms.retain(|r| r != room);
263 return Ok(());
264 }
265 Err(WebSocketError::ConnectionNotFound(id.to_string()))
266 }
267}
268
269pub struct RoomManager {
271 rooms: Arc<RwLock<HashMap<RoomId, Vec<ConnectionId>>>>,
272}
273
274impl RoomManager {
275 pub fn new() -> Self {
277 Self { rooms: Arc::new(RwLock::new(HashMap::new())) }
278 }
279
280 pub async fn create_room(&self, room_id: &str) {
282 let mut rooms = self.rooms.write().await;
283 rooms.entry(room_id.to_string()).or_insert_with(Vec::new);
284 }
285
286 pub async fn delete_room(&self, room_id: &str) -> Option<Vec<ConnectionId>> {
288 let mut rooms = self.rooms.write().await;
289 rooms.remove(room_id)
290 }
291
292 pub async fn join(&self, room_id: &str, connection_id: &str) {
294 let mut rooms = self.rooms.write().await;
295 let room = rooms.entry(room_id.to_string()).or_insert_with(Vec::new);
296 if !room.contains(&connection_id.to_string()) {
297 room.push(connection_id.to_string());
298 }
299 }
300
301 pub async fn leave(&self, room_id: &str, connection_id: &str) {
303 let mut rooms = self.rooms.write().await;
304 if let Some(room) = rooms.get_mut(room_id) {
305 room.retain(|id| id != connection_id);
306 if room.is_empty() {
307 rooms.remove(room_id);
308 }
309 }
310 }
311
312 pub async fn get_members(&self, room_id: &str) -> Vec<ConnectionId> {
314 let rooms = self.rooms.read().await;
315 rooms.get(room_id).cloned().unwrap_or_default()
316 }
317
318 pub async fn room_exists(&self, room_id: &str) -> bool {
320 let rooms = self.rooms.read().await;
321 rooms.contains_key(room_id)
322 }
323
324 pub async fn room_count(&self) -> usize {
326 let rooms = self.rooms.read().await;
327 rooms.len()
328 }
329
330 pub async fn member_count(&self, room_id: &str) -> usize {
332 let rooms = self.rooms.read().await;
333 rooms.get(room_id).map(|r| r.len()).unwrap_or(0)
334 }
335
336 pub async fn broadcast(&self, room_id: &str, sender: &Sender, message: &Message) -> WebSocketResult<Vec<ConnectionId>> {
338 let members = self.get_members(room_id).await;
339 let mut sent_to = Vec::new();
340 for conn_id in &members {
341 if sender.send_to(conn_id, message.clone()).await.is_ok() {
342 sent_to.push(conn_id.clone());
343 }
344 }
345 Ok(sent_to)
346 }
347}
348
349impl Default for RoomManager {
350 fn default() -> Self {
351 Self::new()
352 }
353}
354
355#[derive(Clone)]
357pub struct Sender {
358 senders: Arc<RwLock<HashMap<ConnectionId, mpsc::UnboundedSender<Message>>>>,
359}
360
361impl Sender {
362 pub fn new() -> Self {
364 Self { senders: Arc::new(RwLock::new(HashMap::new())) }
365 }
366
367 pub async fn register(&self, connection_id: ConnectionId, sender: mpsc::UnboundedSender<Message>) {
369 let mut senders = self.senders.write().await;
370 senders.insert(connection_id, sender);
371 }
372
373 pub async fn unregister(&self, connection_id: &str) {
375 let mut senders = self.senders.write().await;
376 senders.remove(connection_id);
377 }
378
379 pub async fn send_to(&self, connection_id: &str, message: Message) -> WebSocketResult<()> {
381 let senders = self.senders.read().await;
382 if let Some(sender) = senders.get(connection_id) {
383 sender.send(message).map_err(|e| WebSocketError::SendFailed(e.to_string()))?;
384 return Ok(());
385 }
386 Err(WebSocketError::ConnectionNotFound(connection_id.to_string()))
387 }
388
389 pub async fn broadcast(&self, message: Message) -> WebSocketResult<usize> {
391 let senders = self.senders.read().await;
392 let mut count = 0;
393 for sender in senders.values() {
394 if sender.send(message.clone()).is_ok() {
395 count += 1;
396 }
397 }
398 Ok(count)
399 }
400
401 pub async fn count(&self) -> usize {
403 let senders = self.senders.read().await;
404 senders.len()
405 }
406}
407
408impl Default for Sender {
409 fn default() -> Self {
410 Self::new()
411 }
412}
413
/// Settings for a `WebSocketServer`, built with chained setters.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host name or IP the listener binds to; default `"0.0.0.0"`.
    pub host: String,
    /// TCP port to listen on; default `8080`.
    pub port: u16,
    /// Connection limit for the server; default `1000`.
    pub max_connections: u32,
    /// Heartbeat period; default 30 s.
    /// NOTE(review): not read anywhere in the visible server code.
    pub heartbeat_interval: Duration,
    /// Connection timeout; default 60 s.
    /// NOTE(review): not read anywhere in the visible server code.
    pub connection_timeout: Duration,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port: 8080,
            max_connections: 1000,
            heartbeat_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(60),
        }
    }
}

impl ServerConfig {
    /// Same as [`ServerConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the bind host.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Sets the listening port.
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the connection limit.
    pub fn max_connections(mut self, max: u32) -> Self {
        self.max_connections = max;
        self
    }

    /// Sets the heartbeat period.
    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Sets the connection timeout. Every other field already had a setter;
    /// this one completes the builder.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }
}
471
472#[async_trait]
474pub trait ClientHandler: Send + Sync {
475 async fn on_connect(&self, connection: &Connection) -> WebSocketResult<()>;
477
478 async fn on_message(&self, connection: &Connection, message: Message) -> WebSocketResult<()>;
480
481 async fn on_disconnect(&self, connection: &Connection);
483}
484
485pub struct DefaultClientHandler;
487
488#[async_trait]
489impl ClientHandler for DefaultClientHandler {
490 async fn on_connect(&self, _connection: &Connection) -> WebSocketResult<()> {
491 Ok(())
492 }
493
494 async fn on_message(&self, _connection: &Connection, _message: Message) -> WebSocketResult<()> {
495 Ok(())
496 }
497
498 async fn on_disconnect(&self, _connection: &Connection) {}
499}
500
501pub struct WebSocketServer {
503 config: ServerConfig,
504 connection_manager: Arc<ConnectionManager>,
505 room_manager: Arc<RoomManager>,
506 sender: Sender,
507 shutdown_tx: broadcast::Sender<()>,
508}
509
510impl WebSocketServer {
511 pub fn new(config: ServerConfig) -> Self {
513 let (shutdown_tx, _) = broadcast::channel(1);
514 Self {
515 config,
516 connection_manager: Arc::new(ConnectionManager::new(1000)),
517 room_manager: Arc::new(RoomManager::new()),
518 sender: Sender::new(),
519 shutdown_tx,
520 }
521 }
522
523 pub fn connection_manager(&self) -> &Arc<ConnectionManager> {
525 &self.connection_manager
526 }
527
528 pub fn room_manager(&self) -> &Arc<RoomManager> {
530 &self.room_manager
531 }
532
533 pub fn sender(&self) -> &Sender {
535 &self.sender
536 }
537
538 pub fn config(&self) -> &ServerConfig {
540 &self.config
541 }
542
543 pub async fn start<H: ClientHandler + 'static>(&self, handler: H) -> WebSocketResult<()> {
545 let addr = format!("{}:{}", self.config.host, self.config.port);
546 let listener =
547 tokio::net::TcpListener::bind(&addr).await.map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
548
549 tracing::info!("WebSocket server listening on {}", addr);
550
551 let mut shutdown_rx = self.shutdown_tx.subscribe();
552 let handler = Arc::new(handler);
553
554 loop {
555 tokio::select! {
556 accept_result = listener.accept() => {
557 match accept_result {
558 Ok((stream, addr)) => {
559 let connection_manager = self.connection_manager.clone();
560 let room_manager = self.room_manager.clone();
561 let sender = self.sender.clone();
562 let handler = handler.clone();
563 let config = self.config.clone();
564
565 tokio::spawn(async move {
566 if let Err(e) = Self::handle_connection(
567 stream,
568 addr,
569 connection_manager,
570 room_manager,
571 sender,
572 handler,
573 config,
574 ).await {
575 tracing::error!("Connection error: {}", e);
576 }
577 });
578 }
579 Err(e) => {
580 tracing::error!("Accept error: {}", e);
581 }
582 }
583 }
584 _ = shutdown_rx.recv() => {
585 tracing::info!("WebSocket server shutting down");
586 break;
587 }
588 }
589 }
590
591 Ok(())
592 }
593
594 async fn handle_connection<H: ClientHandler>(
595 stream: tokio::net::TcpStream,
596 addr: SocketAddr,
597 connection_manager: Arc<ConnectionManager>,
598 room_manager: Arc<RoomManager>,
599 sender: Sender,
600 handler: Arc<H>,
601 config: ServerConfig,
602 ) -> WebSocketResult<()> {
603 let ws_stream =
604 tokio_tungstenite::accept_async(stream).await.map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
605
606 let connection_id = uuid::Uuid::new_v4().to_string();
607 let connection = Connection::new(connection_id.clone(), addr);
608
609 if connection_manager.add(connection.clone()).await.is_err() {
610 return Err(WebSocketError::MaxConnectionsExceeded(config.max_connections));
611 }
612
613 handler.on_connect(&connection).await?;
614 tracing::info!("Client connected: {} from {}", connection_id, addr);
615
616 let (ws_sender, mut ws_receiver) = ws_stream.split();
617 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
618
619 sender.register(connection_id.clone(), tx).await;
620
621 let send_task = async move {
622 let mut ws_sender = ws_sender;
623 while let Some(msg) = rx.recv().await {
624 if ws_sender.send(msg.into()).await.is_err() {
625 break;
626 }
627 }
628 let _ = ws_sender.close().await;
629 };
630
631 let connection_manager_clone = connection_manager.clone();
632 let room_manager_clone = room_manager.clone();
633 let sender_clone = sender.clone();
634 let connection_id_clone = connection_id.clone();
635 let connection_clone = connection.clone();
636 let handler_clone = handler.clone();
637 let recv_task = async move {
638 while let Some(msg_result) = ws_receiver.next().await {
639 match msg_result {
640 Ok(ws_msg) => {
641 let msg: Message = ws_msg.into();
642 if matches!(msg, Message::Close) {
643 break;
644 }
645 if handler_clone.on_message(&connection_clone, msg).await.is_err() {
646 break;
647 }
648 }
649 Err(_) => break,
650 }
651 }
652 };
653
654 tokio::select! {
655 _ = send_task => {},
656 _ = recv_task => {},
657 }
658
659 for room_id in &connection.rooms {
660 room_manager_clone.leave(room_id, &connection_id_clone).await;
661 }
662
663 connection_manager_clone.remove(&connection_id_clone).await;
664 sender_clone.unregister(&connection_id_clone).await;
665 handler.on_disconnect(&connection).await;
666
667 tracing::info!("Client disconnected: {}", connection_id);
668
669 Ok(())
670 }
671
672 pub fn shutdown(&self) {
674 let _ = self.shutdown_tx.send(());
675 }
676
677 pub async fn broadcast(&self, message: Message) -> WebSocketResult<usize> {
679 self.sender.broadcast(message).await
680 }
681
682 pub async fn broadcast_to_room(&self, room_id: &str, message: Message) -> WebSocketResult<Vec<ConnectionId>> {
684 self.room_manager.broadcast(room_id, &self.sender, &message).await
685 }
686}
687
/// Settings for a `WebSocketClient`, built with chained setters.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Server URL; default `ws://127.0.0.1:8080`.
    pub url: String,
    /// Delay between connection attempts; default 5 s.
    pub reconnect_interval: Duration,
    /// Heartbeat period; default 30 s.
    /// NOTE(review): not read anywhere in the visible client code.
    pub heartbeat_interval: Duration,
    /// Intended timeout for establishing a connection; default 10 s.
    pub connection_timeout: Duration,
    /// Limit on reconnect attempts before the client gives up; `0` (the
    /// default) means no limit.
    pub max_reconnect_attempts: u32,
}

impl Default for ClientConfig {
    fn default() -> Self {
        Self {
            url: "ws://127.0.0.1:8080".to_string(),
            reconnect_interval: Duration::from_secs(5),
            heartbeat_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(10),
            max_reconnect_attempts: 0,
        }
    }
}

impl ClientConfig {
    /// Default settings pointed at `url`.
    pub fn new(url: impl Into<String>) -> Self {
        Self { url: url.into(), ..Self::default() }
    }

    /// Sets the delay between connection attempts.
    pub fn reconnect_interval(mut self, interval: Duration) -> Self {
        self.reconnect_interval = interval;
        self
    }

    /// Sets the heartbeat period.
    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Sets the reconnect limit (`0` = unlimited).
    pub fn max_reconnect_attempts(mut self, attempts: u32) -> Self {
        self.max_reconnect_attempts = attempts;
        self
    }

    /// Sets the connection timeout. Every other field already had a setter;
    /// this one completes the builder.
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }
}
739
740pub struct WebSocketClient {
742 config: ClientConfig,
743 sender: mpsc::UnboundedSender<Message>,
744 receiver: mpsc::UnboundedReceiver<Message>,
745}
746
747impl WebSocketClient {
748 pub fn new(config: ClientConfig) -> Self {
750 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<Message>();
751 let (incoming_tx, incoming_rx) = mpsc::unbounded_channel::<Message>();
752
753 let config_clone = config.clone();
754
755 tokio::spawn(async move {
756 let mut attempt = 0u32;
757 loop {
758 match tokio_tungstenite::connect_async(&config_clone.url).await {
759 Ok((ws_stream, _)) => {
760 tracing::info!("WebSocket client connected to {}", config_clone.url);
761 attempt = 0;
762
763 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
764
765 let send_task = async {
766 while let Some(msg) = outgoing_rx.recv().await {
767 if ws_sender.send(msg.into()).await.is_err() {
768 break;
769 }
770 }
771 };
772
773 let recv_task = async {
774 while let Some(msg_result) = ws_receiver.next().await {
775 match msg_result {
776 Ok(ws_msg) => {
777 let msg: Message = ws_msg.into();
778 if matches!(msg, Message::Close) {
779 break;
780 }
781 if incoming_tx.send(msg).is_err() {
782 break;
783 }
784 }
785 Err(_) => break,
786 }
787 }
788 };
789
790 tokio::select! {
791 _ = send_task => {},
792 _ = recv_task => {},
793 }
794
795 tracing::warn!("WebSocket client disconnected, attempting to reconnect...");
796 }
797 Err(e) => {
798 tracing::error!("WebSocket connection failed: {}", e);
799 }
800 }
801
802 attempt += 1;
803 if config_clone.max_reconnect_attempts > 0 && attempt >= config_clone.max_reconnect_attempts {
804 tracing::error!("Max reconnect attempts reached, giving up");
805 break;
806 }
807
808 tokio::time::sleep(config_clone.reconnect_interval).await;
809 }
810 });
811
812 Self { config, sender: outgoing_tx, receiver: incoming_rx }
813 }
814
815 pub async fn send(&self, message: Message) -> WebSocketResult<()> {
817 self.sender.send(message).map_err(|e| WebSocketError::SendFailed(e.to_string()))
818 }
819
820 pub async fn send_text(&self, text: impl Into<String>) -> WebSocketResult<()> {
822 self.send(Message::text(text)).await
823 }
824
825 pub async fn send_binary(&self, data: impl Into<Vec<u8>>) -> WebSocketResult<()> {
827 self.send(Message::binary(data)).await
828 }
829
830 pub async fn send_json<T: Serialize + ?Sized>(&self, value: &T) -> WebSocketResult<()> {
832 let json = serde_json::to_string(value).map_err(|e| WebSocketError::SerializationFailed(e.to_string()))?;
833 self.send_text(json).await
834 }
835
836 pub async fn receive(&mut self) -> Option<Message> {
838 self.receiver.recv().await
839 }
840
841 pub async fn receive_json<T: DeserializeOwned>(&mut self) -> WebSocketResult<Option<T>> {
843 match self.receive().await {
844 Some(msg) => {
845 let text =
846 msg.as_text().ok_or_else(|| WebSocketError::DeserializationFailed("Expected text message".into()))?;
847 let value: T = serde_json::from_str(text).map_err(|e| WebSocketError::DeserializationFailed(e.to_string()))?;
848 Ok(Some(value))
849 }
850 None => Ok(None),
851 }
852 }
853
854 pub fn config(&self) -> &ClientConfig {
856 &self.config
857 }
858
859 pub async fn close(&self) -> WebSocketResult<()> {
861 self.send(Message::Close).await
862 }
863}
864
865pub fn websocket_server(config: ServerConfig) -> WebSocketServer {
867 WebSocketServer::new(config)
868}
869
870pub fn websocket_client(config: ClientConfig) -> WebSocketClient {
872 WebSocketClient::new(config)
873}