torch_web/
websocket.rs

1//! # WebSocket Support for Real-Time Applications
2//!
3//! This module provides comprehensive WebSocket support for building real-time
4//! applications with Torch. It includes connection management, message broadcasting,
5//! room-based messaging, and automatic connection handling.
6//!
7//! ## Features
8//!
9//! - **Real-time Communication**: Bidirectional communication between client and server
10//! - **Connection Management**: Automatic connection tracking and cleanup
11//! - **Message Broadcasting**: Send messages to all connected clients
12//! - **Room Support**: Group clients into rooms for targeted messaging
13//! - **JSON Messaging**: Automatic JSON serialization/deserialization
14//! - **Ping/Pong**: Built-in connection health monitoring
15//! - **Error Handling**: Robust error handling and reconnection support
16//! - **Scalable**: Designed for high-concurrency applications
17//!
18//! **Note**: This module requires the `websocket` feature to be enabled.
19//!
20//! ## Quick Start
21//!
22//! ### Basic WebSocket Server
23//!
24//! ```rust
25//! use torch_web::{App, websocket::*};
26//!
27//! let ws_manager = WebSocketManager::new();
28//!
29//! let app = App::new()
30//!     .with_state(ws_manager.clone())
31//!
32//!     // WebSocket endpoint
33//!     .websocket("/ws", |mut connection| async move {
34//!         println!("New WebSocket connection: {}", connection.id());
35//!
36//!         while let Some(message) = connection.receive().await? {
37//!             match message {
38//!                 WebSocketMessage::Text(text) => {
39//!                     println!("Received: {}", text);
40//!                     // Echo the message back
41//!                     connection.send_text(&format!("Echo: {}", text)).await?;
42//!                 }
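//!                 WebSocketMessage::Binary(data) => {
//!                     println!("Received {} bytes", data.len());
//!                     // Echo the payload back (send_binary takes a byte slice)
//!                     connection.send_binary(&data).await?;
//!                 }
//!                 WebSocketMessage::Close => {
//!                     println!("Connection closed");
//!                     break;
//!                 }
//!                 // Ping/Pong frames need no application-level handling in this example
//!                 WebSocketMessage::Ping(_) | WebSocketMessage::Pong(_) => {}
//!             }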
52//!         }
53//!
54//!         Ok(())
55//!     })
56//!
57//!     // HTTP endpoint to broadcast messages
58//!     .post("/broadcast", |State(ws): State<WebSocketManager>, Json(msg): Json<BroadcastMessage>| async move {
59//!         let count = ws.broadcast(&msg.text).await?;
60//!         Response::ok().json(&serde_json::json!({
61//!             "sent_to": count,
62//!             "message": msg.text
63//!         }))
64//!     });
65//! ```
66//!
67//! ### Chat Application Example
68//!
69//! ```rust
70//! use torch_web::{App, websocket::*, extractors::*};
71//! use serde::{Deserialize, Serialize};
72//!
73//! #[derive(Deserialize, Serialize)]
74//! struct ChatMessage {
75//!     user: String,
76//!     message: String,
77//!     timestamp: String,
78//! }
79//!
80//! let ws_manager = WebSocketManager::new();
81//!
82//! let app = App::new()
83//!     .with_state(ws_manager.clone())
84//!
85//!     // Chat WebSocket endpoint
86//!     .websocket("/chat", |mut connection| async move {
87//!         // Join the general chat room
88//!         connection.join_room("general").await?;
89//!
90//!         while let Some(message) = connection.receive().await? {
91//!             if let WebSocketMessage::Text(text) = message {
92//!                 // Parse incoming chat message
93//!                 if let Ok(chat_msg) = serde_json::from_str::<ChatMessage>(&text) {
94//!                     // Broadcast to all users in the room
95//!                     connection.broadcast_to_room("general", &text).await?;
96//!                 }
97//!             }
98//!         }
99//!
100//!         Ok(())
101//!     })
102//!
103//!     // REST endpoint to send messages
104//!     .post("/chat/send", |State(ws): State<WebSocketManager>, Json(msg): Json<ChatMessage>| async move {
105//!         let message_json = serde_json::to_string(&msg)?;
106//!         let count = ws.broadcast_to_room("general", &message_json).await?;
107//!         Response::ok().json(&serde_json::json!({"sent_to": count}))
108//!     });
109//! ```
110//!
111//! ### Real-Time Dashboard
112//!
113//! ```rust
114//! use torch_web::{App, websocket::*, extractors::*};
115//! use tokio::time::{interval, Duration};
116//!
117//! let ws_manager = WebSocketManager::new();
118//!
119//! // Background task to send periodic updates
120//! let ws_clone = ws_manager.clone();
121//! tokio::spawn(async move {
122//!     let mut interval = interval(Duration::from_secs(5));
123//!
124//!     loop {
125//!         interval.tick().await;
126//!
127//!         let stats = get_system_stats().await;
128//!         let message = serde_json::to_string(&stats).unwrap();
129//!
130//!         if let Err(e) = ws_clone.broadcast_to_room("dashboard", &message).await {
131//!             eprintln!("Failed to broadcast stats: {}", e);
132//!         }
133//!     }
134//! });
135//!
136//! let app = App::new()
137//!     .with_state(ws_manager)
138//!
139//!     // Dashboard WebSocket
140//!     .websocket("/dashboard", |mut connection| async move {
141//!         connection.join_room("dashboard").await?;
142//!
143//!         // Send initial data
144//!         let initial_stats = get_system_stats().await;
145//!         connection.send_json(&initial_stats).await?;
146//!
147//!         // Keep connection alive and handle incoming messages
148//!         while let Some(_message) = connection.receive().await? {
149//!             // Handle client requests for specific data
150//!         }
151//!
152//!         Ok(())
153//!     });
154//! ```
155//!
156//! ### Gaming/Multiplayer Example
157//!
158//! ```rust
159//! use torch_web::{App, websocket::*};
160//! use serde::{Deserialize, Serialize};
161//!
162//! #[derive(Deserialize, Serialize)]
163//! struct GameAction {
164//!     action_type: String,
165//!     player_id: String,
166//!     data: serde_json::Value,
167//! }
168//!
169//! let ws_manager = WebSocketManager::new();
170//!
171//! let app = App::new()
172//!     .websocket("/game/:room_id", |mut connection, Path(room_id): Path<String>| async move {
173//!         // Join the specific game room
174//!         connection.join_room(&room_id).await?;
175//!
176//!         // Notify other players
177//!         let join_message = serde_json::json!({
178//!             "type": "player_joined",
179//!             "player_id": connection.id()
180//!         });
181//!         connection.broadcast_to_room(&room_id, &join_message.to_string()).await?;
182//!
183//!         while let Some(message) = connection.receive().await? {
184//!             if let WebSocketMessage::Text(text) = message {
185//!                 if let Ok(action) = serde_json::from_str::<GameAction>(&text) {
186//!                     // Process game action and broadcast to other players
187//!                     process_game_action(&action).await;
188//!                     connection.broadcast_to_room(&room_id, &text).await?;
189//!                 }
190//!             }
191//!         }
192//!
193//!         // Notify other players when leaving
194//!         let leave_message = serde_json::json!({
195//!             "type": "player_left",
196//!             "player_id": connection.id()
197//!         });
198//!         connection.broadcast_to_room(&room_id, &leave_message.to_string()).await?;
199//!
200//!         Ok(())
201//!     });
202//! ```
203
204use crate::{Request, Response};
205use std::sync::Arc;
206
207#[cfg(feature = "websocket")]
208use std::collections::HashMap;
209
210#[cfg(feature = "websocket")]
211use {
212    tokio_tungstenite::{accept_async, tungstenite::Message},
213    futures_util::{SinkExt, StreamExt},
214    tokio::sync::{RwLock, broadcast},
215    sha1::{Sha1, Digest},
216    base64::{Engine as _, engine::general_purpose},
217};
218
/// WebSocket connection manager.
///
/// With the `websocket` feature enabled, the manager holds a shared map from
/// client id to the `broadcast::Sender` used to push text messages to that
/// client. Without the feature it is an empty placeholder so that code
/// depending on the type still compiles.
///
/// Cloning a manager is cheap: with the feature enabled every clone shares the
/// same underlying connection map (the map lives behind an `Arc`), so a handle
/// can be given to application state and to background tasks at the same time.
#[derive(Clone, Default)]
pub struct WebSocketManager {
    #[cfg(feature = "websocket")]
    connections: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
    #[cfg(not(feature = "websocket"))]
    _phantom: std::marker::PhantomData<()>,
}

impl WebSocketManager {
    /// Creates a manager with no registered connections.
    pub fn new() -> Self {
        Self::default()
    }

    /// Broadcast a message to all connected clients.
    ///
    /// Returns the number of registered senders whose `send` call succeeded.
    /// A sender whose `send` fails is skipped and not counted; it does not
    /// abort the broadcast.
    #[cfg(feature = "websocket")]
    pub async fn broadcast(&self, message: &str) -> Result<usize, Box<dyn std::error::Error>> {
        let connections = self.connections.read().await;
        let mut sent_count = 0;

        for sender in connections.values() {
            if sender.send(message.to_string()).is_ok() {
                sent_count += 1;
            }
        }

        Ok(sent_count)
    }

    /// Send a message to a specific client.
    ///
    /// If `client_id` is not registered the call is a no-op and returns
    /// `Ok(())`. An error is returned only when the client's sender rejects
    /// the message.
    #[cfg(feature = "websocket")]
    pub async fn send_to(&self, client_id: &str, message: &str) -> Result<(), Box<dyn std::error::Error>> {
        let connections = self.connections.read().await;
        if let Some(sender) = connections.get(client_id) {
            sender.send(message.to_string())?;
        }
        Ok(())
    }

    /// Get the number of connected clients (entries in the connection map).
    #[cfg(feature = "websocket")]
    pub async fn connection_count(&self) -> usize {
        self.connections.read().await.len()
    }

    /// Feature-disabled stand-in: always fails because nothing can be sent.
    #[cfg(not(feature = "websocket"))]
    pub async fn broadcast(&self, _message: &str) -> Result<usize, Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Feature-disabled stand-in: always fails because nothing can be sent.
    #[cfg(not(feature = "websocket"))]
    pub async fn send_to(&self, _client_id: &str, _message: &str) -> Result<(), Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Feature-disabled stand-in: there are never any connections.
    #[cfg(not(feature = "websocket"))]
    pub async fn connection_count(&self) -> usize {
        0
    }
}
283
284/// WebSocket upgrade handler
285pub async fn websocket_upgrade(req: Request) -> Response {
286    #[cfg(feature = "websocket")]
287    {
288        // 1. Validate the WebSocket headers
289        if !is_websocket_upgrade_request(&req) {
290            return Response::bad_request().body("Not a valid WebSocket upgrade request");
291        }
292
293        // 2. Get the WebSocket key
294        let websocket_key = match req.header("sec-websocket-key") {
295            Some(key) => key,
296            None => return Response::bad_request().body("Missing Sec-WebSocket-Key header"),
297        };
298
299        // 3. Generate the accept key
300        let accept_key = generate_websocket_accept_key(websocket_key);
301
302        // 4. Return the upgrade response
303        Response::with_status(http::StatusCode::SWITCHING_PROTOCOLS)
304            .header("Upgrade", "websocket")
305            .header("Connection", "Upgrade")
306            .header("Sec-WebSocket-Accept", &accept_key)
307            .header("Sec-WebSocket-Version", "13")
308            .body("")
309    }
310
311    #[cfg(not(feature = "websocket"))]
312    {
313        let _ = req; // Suppress unused variable warning
314        Response::with_status(http::StatusCode::NOT_IMPLEMENTED)
315            .body("WebSocket support not enabled")
316    }
317}
318
/// Reports whether `req` carries every header needed for a WebSocket upgrade.
///
/// All four conditions must hold:
/// - `Upgrade` equals `websocket` (compared case-insensitively),
/// - `Connection` contains `upgrade` (compared case-insensitively),
/// - `Sec-WebSocket-Version` is exactly `13`,
/// - `Sec-WebSocket-Key` is present.
#[cfg(feature = "websocket")]
pub fn is_websocket_upgrade_request(req: &Request) -> bool {
    let upgrade_is_websocket = req
        .header("upgrade")
        .map(|value| value.to_lowercase())
        .as_deref()
        == Some("websocket");

    // `Connection` may list several tokens, so a substring match is used.
    let connection_has_upgrade = req
        .header("connection")
        .map_or(false, |value| value.to_lowercase().contains("upgrade"));

    let version_is_13 = req.header("sec-websocket-version") == Some("13");
    let has_key = req.header("sec-websocket-key").is_some();

    upgrade_is_websocket && connection_has_upgrade && version_is_13 && has_key
}
332
/// Derives the `Sec-WebSocket-Accept` value for a client's `Sec-WebSocket-Key`.
///
/// The result is the base64 encoding (standard alphabet) of the SHA-1 digest
/// of the key immediately followed by the fixed RFC 6455 GUID.
#[cfg(feature = "websocket")]
fn generate_websocket_accept_key(websocket_key: &str) -> String {
    // Fixed GUID that RFC 6455 appends to the client key before hashing.
    const HANDSHAKE_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Feeding the two parts one after the other hashes the same bytes as
    // hashing their concatenation.
    let mut sha1 = Sha1::new();
    sha1.update(websocket_key.as_bytes());
    sha1.update(HANDSHAKE_GUID.as_bytes());
    let digest = sha1.finalize();

    general_purpose::STANDARD.encode(digest)
}
349
/// Runs `handler` on a WebSocket connection established over `stream`.
///
/// The stream is first passed to `accept_async`; if that fails the error is
/// returned and `handler` is never invoked. Otherwise the resulting stream is
/// wrapped in a [`WebSocketConnection`] and handed to `handler`, whose result
/// becomes the result of this function.
#[cfg(feature = "websocket")]
pub async fn handle_websocket_connection<F, Fut>(
    stream: tokio::net::TcpStream,
    handler: F,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    F: FnOnce(WebSocketConnection) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    let connection = accept_async(stream)
        .await
        .map(WebSocketConnection::new)?;

    handler(connection).await
}
367
/// Thin wrapper around an accepted `tokio_tungstenite` WebSocket stream.
///
/// Exposes a small send/receive/close API in terms of [`WebSocketMessage`]
/// so handlers do not deal with tungstenite's `Message` type directly.
#[cfg(feature = "websocket")]
pub struct WebSocketConnection {
    stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
}

#[cfg(feature = "websocket")]
impl WebSocketConnection {
    /// Wraps an already-accepted WebSocket stream.
    fn new(stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Self {
        Self { stream }
    }

    /// Sends `text` to the peer as a text message.
    pub async fn send_text(&mut self, text: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let frame = Message::Text(text.to_string());
        self.stream.send(frame).await.map_err(Into::into)
    }

    /// Sends a copy of `data` to the peer as a binary message.
    pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let frame = Message::Binary(data.to_vec());
        self.stream.send(frame).await.map_err(Into::into)
    }

    /// Waits for the next incoming message.
    ///
    /// Returns `Ok(None)` once the stream has ended, and `Err` if the
    /// underlying stream yields an error.
    pub async fn receive(&mut self) -> Result<Option<WebSocketMessage>, Box<dyn std::error::Error + Send + Sync>> {
        let incoming = self.stream.next().await.transpose()?;
        Ok(incoming.map(WebSocketMessage::from_tungstenite))
    }

    /// Sends a close message (without a close frame payload) to the peer.
    pub async fn close(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        self.stream
            .send(Message::Close(None))
            .await
            .map_err(Into::into)
    }
}
407
/// WebSocket message types.
///
/// A simplified view of tungstenite's `Message` that handlers match on.
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so messages can be
/// logged, duplicated for fan-out, and compared in tests.
#[cfg(feature = "websocket")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketMessage {
    /// A UTF-8 text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A ping message with its payload.
    Ping(Vec<u8>),
    /// A pong message with its payload.
    Pong(Vec<u8>),
    /// The peer closed (or is closing) the connection.
    Close,
}

#[cfg(feature = "websocket")]
impl WebSocketMessage {
    /// Converts a tungstenite `Message` into the simplified representation.
    ///
    /// Any close payload is discarded, and raw `Frame` messages are reported
    /// as `Close`.
    fn from_tungstenite(msg: Message) -> Self {
        match msg {
            Message::Text(text) => WebSocketMessage::Text(text),
            Message::Binary(data) => WebSocketMessage::Binary(data),
            Message::Ping(data) => WebSocketMessage::Ping(data),
            Message::Pong(data) => WebSocketMessage::Pong(data),
            Message::Close(_) => WebSocketMessage::Close,
            Message::Frame(_) => WebSocketMessage::Close, // Treat raw frames as close
        }
    }

    /// Check if this is a text message.
    pub fn is_text(&self) -> bool {
        matches!(self, WebSocketMessage::Text(_))
    }

    /// Check if this is a binary message.
    pub fn is_binary(&self) -> bool {
        matches!(self, WebSocketMessage::Binary(_))
    }

    /// Get text content if this is a text message, `None` otherwise.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            WebSocketMessage::Text(text) => Some(text.as_str()),
            _ => None,
        }
    }

    /// Get binary content if this is a binary message, `None` otherwise.
    pub fn as_binary(&self) -> Option<&[u8]> {
        match self {
            WebSocketMessage::Binary(data) => Some(data.as_slice()),
            _ => None,
        }
    }
}
457
/// Real-time chat room example.
///
/// With the `websocket` feature enabled the room owns its own
/// [`WebSocketManager`] and keeps a bounded history of the most recent
/// formatted messages. Without the feature it is an empty placeholder whose
/// methods report that WebSocket support is disabled.
pub struct ChatRoom {
    #[cfg(feature = "websocket")]
    manager: WebSocketManager,
    // A deque so the oldest entry can be dropped from the front in O(1).
    #[cfg(feature = "websocket")]
    message_history: Arc<RwLock<std::collections::VecDeque<String>>>,
    #[cfg(not(feature = "websocket"))]
    _phantom: std::marker::PhantomData<()>,
}

impl ChatRoom {
    /// Maximum number of formatted messages retained in the history.
    #[cfg(feature = "websocket")]
    const MAX_HISTORY: usize = 100;

    /// Creates an empty chat room.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "websocket")]
            manager: WebSocketManager::new(),
            #[cfg(feature = "websocket")]
            message_history: Arc::new(RwLock::new(std::collections::VecDeque::with_capacity(
                Self::MAX_HISTORY,
            ))),
            #[cfg(not(feature = "websocket"))]
            _phantom: std::marker::PhantomData,
        }
    }

    /// Records `"{user}: {message}"` in the history and broadcasts it.
    ///
    /// The history keeps only the latest `MAX_HISTORY` (100) entries; the
    /// oldest entry is evicted when the limit is reached. An error from the
    /// manager's `broadcast` is propagated to the caller.
    #[cfg(feature = "websocket")]
    pub async fn send_message(&self, user: &str, message: &str) -> Result<(), Box<dyn std::error::Error>> {
        let formatted_message = format!("{}: {}", user, message);

        // Update the history in its own scope so the write lock is released
        // before broadcasting.
        {
            let mut history = self.message_history.write().await;
            if history.len() >= Self::MAX_HISTORY {
                history.pop_front();
            }
            history.push_back(formatted_message.clone());
        }

        // Broadcast to all clients
        self.manager.broadcast(&formatted_message).await?;
        Ok(())
    }

    /// Returns a snapshot of the stored messages, oldest first.
    #[cfg(feature = "websocket")]
    pub async fn get_history(&self) -> Vec<String> {
        self.message_history.read().await.iter().cloned().collect()
    }

    /// Feature-disabled stand-in: always fails because nothing can be sent.
    #[cfg(not(feature = "websocket"))]
    pub async fn send_message(&self, _user: &str, _message: &str) -> Result<(), Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Feature-disabled stand-in: the history is always empty.
    #[cfg(not(feature = "websocket"))]
    pub async fn get_history(&self) -> Vec<String> {
        Vec::new()
    }
}

impl Default for ChatRoom {
    /// Equivalent to [`ChatRoom::new`].
    fn default() -> Self {
        Self::new()
    }
}
515
516/// Server-Sent Events (SSE) support for real-time updates
517pub struct SSEStream {
518    #[cfg(feature = "websocket")]
519    sender: broadcast::Sender<String>,
520    #[cfg(not(feature = "websocket"))]
521    _phantom: std::marker::PhantomData<()>,
522}
523
524impl SSEStream {
525    pub fn new() -> Self {
526        Self {
527            #[cfg(feature = "websocket")]
528            sender: broadcast::channel(1000).0,
529            #[cfg(not(feature = "websocket"))]
530            _phantom: std::marker::PhantomData,
531        }
532    }
533
534    /// Send an event to all SSE clients
535    #[cfg(feature = "websocket")]
536    pub fn send_event(&self, event_type: &str, data: &str) -> Result<(), Box<dyn std::error::Error>> {
537        let sse_message = format!("event: {}\ndata: {}\n\n", event_type, data);
538        self.sender.send(sse_message)?;
539        Ok(())
540    }
541
542    /// Create an SSE response with proper streaming setup
543    pub fn create_response(&self) -> Response {
544        #[cfg(feature = "websocket")]
545        {
546            // Create SSE response with proper headers
547            let mut response = Response::ok()
548                .header("Content-Type", "text/event-stream")
549                .header("Cache-Control", "no-cache")
550                .header("Connection", "keep-alive")
551                .header("Access-Control-Allow-Origin", "*")
552                .header("Access-Control-Allow-Headers", "Cache-Control");
553
554            // Send initial connection event
555            let initial_data = "event: connected\ndata: SSE stream established\nid: 0\n\n";
556            response = response.body(initial_data);
557
558            response
559        }
560
561        #[cfg(not(feature = "websocket"))]
562        {
563            Response::with_status(http::StatusCode::NOT_IMPLEMENTED)
564                .body("SSE support not enabled")
565        }
566    }
567
568    #[cfg(not(feature = "websocket"))]
569    pub fn send_event(&self, _event_type: &str, _data: &str) -> Result<(), Box<dyn std::error::Error>> {
570        Err("WebSocket feature not enabled".into())
571    }
572}
573
574/// WebSocket middleware for automatic connection management
575pub struct WebSocketMiddleware {
576    #[cfg(feature = "websocket")]
577    manager: Arc<WebSocketManager>,
578    #[cfg(not(feature = "websocket"))]
579    _phantom: std::marker::PhantomData<()>,
580}
581
582impl WebSocketMiddleware {
583    pub fn new(_manager: Arc<WebSocketManager>) -> Self {
584        Self {
585            #[cfg(feature = "websocket")]
586            manager: _manager,
587            #[cfg(not(feature = "websocket"))]
588            _phantom: std::marker::PhantomData,
589        }
590    }
591}
592
593impl crate::middleware::Middleware for WebSocketMiddleware {
594    fn call(
595        &self,
596        req: Request,
597        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
598    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
599        #[cfg(feature = "websocket")]
600        {
601            let _manager = self.manager.clone();
602            Box::pin(async move {
603                // Check if this is a WebSocket upgrade request
604                if req.header("upgrade").map(|h| h.to_lowercase()) == Some("websocket".to_string()) {
605                    // Handle WebSocket upgrade
606                    websocket_upgrade(req).await
607                } else {
608                    // Regular HTTP request
609                    next(req).await
610                }
611            })
612        }
613        
614        #[cfg(not(feature = "websocket"))]
615        {
616            Box::pin(async move {
617                next(req).await
618            })
619        }
620    }
621}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created manager reports zero connections.
    #[tokio::test]
    async fn test_websocket_manager() {
        let connected = WebSocketManager::new().connection_count().await;
        assert_eq!(connected, 0);
    }

    /// A freshly created chat room has no stored messages.
    #[tokio::test]
    async fn test_chat_room() {
        let room = ChatRoom::new();
        assert!(room.get_history().await.is_empty());
    }

    /// The SSE response is an event stream when the feature is on, and
    /// `501 Not Implemented` when it is off.
    #[test]
    fn test_sse_stream() {
        let response = SSEStream::new().create_response();

        #[cfg(feature = "websocket")]
        assert_eq!(response.headers().get("content-type").unwrap(), "text/event-stream");

        #[cfg(not(feature = "websocket"))]
        assert_eq!(response.status_code(), http::StatusCode::NOT_IMPLEMENTED);
    }
}