torch_web/
websocket.rs

1//! WebSocket support for real-time applications
2
3use crate::{Request, Response};
4use std::sync::Arc;
5
6#[cfg(feature = "websocket")]
7use std::collections::HashMap;
8
9#[cfg(feature = "websocket")]
10use {
11    tokio_tungstenite::{accept_async, tungstenite::Message},
12    futures_util::{SinkExt, StreamExt},
13    tokio::sync::{RwLock, broadcast},
14    sha1::{Sha1, Digest},
15    base64::{Engine as _, engine::general_purpose},
16};
17
/// Registry of live WebSocket clients, keyed by a caller-chosen client id.
///
/// Every registered client is backed by a `tokio::sync::broadcast` channel: the
/// manager keeps the sending half, and `register` hands the receiving half back
/// so the task owning the socket can forward queued messages to the peer.
///
/// Without the `websocket` feature every operation is a stub that reports the
/// feature as disabled (or zero connections).
pub struct WebSocketManager {
    #[cfg(feature = "websocket")]
    connections: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
    #[cfg(not(feature = "websocket"))]
    _phantom: std::marker::PhantomData<()>,
}

impl Default for WebSocketManager {
    fn default() -> Self {
        Self::new()
    }
}

impl WebSocketManager {
    /// Per-client channel capacity used by `register`; once a client falls this
    /// many messages behind, tokio's broadcast channel starts dropping the oldest.
    #[cfg(feature = "websocket")]
    const CLIENT_CHANNEL_CAPACITY: usize = 100;

    /// Create a manager with no registered clients.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "websocket")]
            connections: Arc::new(RwLock::new(HashMap::new())),
            #[cfg(not(feature = "websocket"))]
            _phantom: std::marker::PhantomData,
        }
    }

    /// Register `client_id` and return the receiver its connection task should drain.
    ///
    /// Registering an id that is already present replaces the previous entry, so
    /// receivers from the earlier registration no longer see new messages.
    #[cfg(feature = "websocket")]
    pub async fn register(&self, client_id: &str) -> broadcast::Receiver<String> {
        let (sender, receiver) = broadcast::channel(Self::CLIENT_CHANNEL_CAPACITY);
        self.connections
            .write()
            .await
            .insert(client_id.to_string(), sender);
        receiver
    }

    /// Remove `client_id` from the registry; returns `true` if it was registered.
    #[cfg(feature = "websocket")]
    pub async fn unregister(&self, client_id: &str) -> bool {
        self.connections.write().await.remove(client_id).is_some()
    }

    /// Broadcast `message` to every registered client.
    ///
    /// Returns how many clients accepted the message. Clients whose receiver has
    /// been dropped are skipped rather than reported as an error.
    #[cfg(feature = "websocket")]
    pub async fn broadcast(&self, message: &str) -> Result<usize, Box<dyn std::error::Error>> {
        let connections = self.connections.read().await;
        let mut sent_count = 0;

        for sender in connections.values() {
            if sender.send(message.to_string()).is_ok() {
                sent_count += 1;
            }
        }

        Ok(sent_count)
    }

    /// Send `message` to a single client.
    ///
    /// An unknown `client_id` is ignored and yields `Ok(())`.
    ///
    /// # Errors
    /// Fails when the client is registered but its receiver has been dropped.
    #[cfg(feature = "websocket")]
    pub async fn send_to(&self, client_id: &str, message: &str) -> Result<(), Box<dyn std::error::Error>> {
        let connections = self.connections.read().await;
        if let Some(sender) = connections.get(client_id) {
            sender.send(message.to_string())?;
        }
        Ok(())
    }

    /// Number of currently registered clients.
    #[cfg(feature = "websocket")]
    pub async fn connection_count(&self) -> usize {
        self.connections.read().await.len()
    }

    /// Stub: always fails because the `websocket` feature is disabled.
    #[cfg(not(feature = "websocket"))]
    pub async fn broadcast(&self, _message: &str) -> Result<usize, Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Stub: always fails because the `websocket` feature is disabled.
    #[cfg(not(feature = "websocket"))]
    pub async fn send_to(&self, _client_id: &str, _message: &str) -> Result<(), Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Stub: there are never any connections without the `websocket` feature.
    #[cfg(not(feature = "websocket"))]
    pub async fn connection_count(&self) -> usize {
        0
    }
}
82
83/// WebSocket upgrade handler
84pub async fn websocket_upgrade(req: Request) -> Response {
85    #[cfg(feature = "websocket")]
86    {
87        // 1. Validate the WebSocket headers
88        if !is_websocket_upgrade_request(&req) {
89            return Response::bad_request().body("Not a valid WebSocket upgrade request");
90        }
91
92        // 2. Get the WebSocket key
93        let websocket_key = match req.header("sec-websocket-key") {
94            Some(key) => key,
95            None => return Response::bad_request().body("Missing Sec-WebSocket-Key header"),
96        };
97
98        // 3. Generate the accept key
99        let accept_key = generate_websocket_accept_key(websocket_key);
100
101        // 4. Return the upgrade response
102        Response::with_status(http::StatusCode::SWITCHING_PROTOCOLS)
103            .header("Upgrade", "websocket")
104            .header("Connection", "Upgrade")
105            .header("Sec-WebSocket-Accept", &accept_key)
106            .header("Sec-WebSocket-Version", "13")
107            .body("")
108    }
109
110    #[cfg(not(feature = "websocket"))]
111    {
112        let _ = req; // Suppress unused variable warning
113        Response::with_status(http::StatusCode::NOT_IMPLEMENTED)
114            .body("WebSocket support not enabled")
115    }
116}
117
/// Return `true` if `req` carries every header of a version-13 WebSocket
/// opening handshake (RFC 6455 §4.2.1).
///
/// `Upgrade` and `Connection` are comma-separated token lists, so they are
/// matched token-by-token and case-insensitively: `Connection: keep-alive, Upgrade`
/// is accepted, while a token that merely contains the substring "upgrade"
/// is not. The `Sec-WebSocket-Key` must be present and non-blank.
#[cfg(feature = "websocket")]
pub fn is_websocket_upgrade_request(req: &Request) -> bool {
    let has_token = |name: &str, token: &str| {
        req.header(name)
            .map_or(false, |value| header_has_token(value, token))
    };
    let version_ok = req
        .header("sec-websocket-version")
        .map_or(false, |version| version.trim() == "13");
    let key_ok = req
        .header("sec-websocket-key")
        .map_or(false, |key| !key.trim().is_empty());

    has_token("upgrade", "websocket") && has_token("connection", "upgrade") && version_ok && key_ok
}

/// Return `true` if the comma-separated header `value` lists `token`,
/// comparing ASCII case-insensitively and ignoring surrounding whitespace.
// Deliberately not feature-gated so it can be unit-tested on its own; with the
// feature off nothing calls it, so the dead-code lint is silenced in that case.
#[cfg_attr(not(feature = "websocket"), allow(dead_code))]
fn header_has_token(value: &str, token: &str) -> bool {
    value
        .split(',')
        .any(|item| item.trim().eq_ignore_ascii_case(token))
}
131
/// Derive the `Sec-WebSocket-Accept` value for a client's `Sec-WebSocket-Key`.
///
/// Per RFC 6455 the server appends a fixed GUID to the key, hashes the result
/// with SHA-1 and returns the standard (padded) base64 encoding of the digest.
#[cfg(feature = "websocket")]
fn generate_websocket_accept_key(websocket_key: &str) -> String {
    const HANDSHAKE_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Feeding the two parts one after another hashes exactly the same bytes as
    // hashing their concatenation, without building an intermediate String.
    let mut sha1 = Sha1::new();
    sha1.update(websocket_key.as_bytes());
    sha1.update(HANDSHAKE_GUID.as_bytes());
    let digest = sha1.finalize();

    general_purpose::STANDARD.encode(&digest)
}
148
/// Complete the server side of the WebSocket handshake on `stream`, then run `handler`.
///
/// A handshake failure from `accept_async` is returned without invoking the
/// handler; otherwise the handler's own result is returned unchanged.
#[cfg(feature = "websocket")]
pub async fn handle_websocket_connection<F, Fut>(
    stream: tokio::net::TcpStream,
    handler: F,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    F: FnOnce(WebSocketConnection) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    match accept_async(stream).await {
        Ok(ws_stream) => handler(WebSocketConnection::new(ws_stream)).await,
        Err(handshake_error) => Err(handshake_error.into()),
    }
}
166
/// Thin wrapper over an accepted `tokio_tungstenite` stream exposing text and
/// binary sends, message receipt and closing.
#[cfg(feature = "websocket")]
pub struct WebSocketConnection {
    stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
}

#[cfg(feature = "websocket")]
impl WebSocketConnection {
    fn new(stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Self {
        WebSocketConnection { stream }
    }

    /// Send `text` as one text message.
    pub async fn send_text(&mut self, text: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let outgoing = Message::Text(text.to_owned());
        self.stream.send(outgoing).await.map_err(Into::into)
    }

    /// Send `data` as one binary message.
    pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let outgoing = Message::Binary(data.to_owned());
        self.stream.send(outgoing).await.map_err(Into::into)
    }

    /// Wait for the next message from the peer.
    ///
    /// Yields `Ok(None)` once the underlying stream has ended.
    pub async fn receive(&mut self) -> Result<Option<WebSocketMessage>, Box<dyn std::error::Error + Send + Sync>> {
        let incoming = self.stream.next().await.transpose()?;
        Ok(incoming.map(WebSocketMessage::from_tungstenite))
    }

    /// Send a close frame carrying no status code or reason.
    pub async fn close(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        self.stream.send(Message::Close(None)).await.map_err(Into::into)
    }
}
206
/// A message received from a WebSocket peer.
///
/// Close frames are reported as [`WebSocketMessage::Close`] without their
/// status code or reason.
#[cfg(feature = "websocket")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketMessage {
    /// Text payload.
    Text(String),
    /// Binary payload.
    Binary(Vec<u8>),
    /// Payload of a ping control frame.
    Ping(Vec<u8>),
    /// Payload of a pong control frame.
    Pong(Vec<u8>),
    /// A close frame (raw tungstenite frames are also mapped here).
    Close,
}

#[cfg(feature = "websocket")]
impl WebSocketMessage {
    /// Convert a tungstenite message into this crate's representation.
    fn from_tungstenite(msg: Message) -> Self {
        match msg {
            Message::Text(text) => WebSocketMessage::Text(text),
            Message::Binary(data) => WebSocketMessage::Binary(data),
            Message::Ping(data) => WebSocketMessage::Ping(data),
            Message::Pong(data) => WebSocketMessage::Pong(data),
            Message::Close(_) => WebSocketMessage::Close,
            // tungstenite documents that `Message::Frame` is not produced while
            // reading; it is mapped to `Close` only to keep the match exhaustive.
            Message::Frame(_) => WebSocketMessage::Close,
        }
    }

    /// Check if this is a text message.
    pub fn is_text(&self) -> bool {
        matches!(self, WebSocketMessage::Text(_))
    }

    /// Check if this is a binary message.
    pub fn is_binary(&self) -> bool {
        matches!(self, WebSocketMessage::Binary(_))
    }

    /// Borrow the text content if this is a text message.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            WebSocketMessage::Text(text) => Some(text),
            _ => None,
        }
    }

    /// Borrow the binary content if this is a binary message.
    pub fn as_binary(&self) -> Option<&[u8]> {
        match self {
            WebSocketMessage::Binary(data) => Some(data),
            _ => None,
        }
    }
}
256
/// Example real-time chat room: keeps a bounded history of formatted
/// `"user: message"` lines and broadcasts each new line through a
/// [`WebSocketManager`].
pub struct ChatRoom {
    #[cfg(feature = "websocket")]
    manager: WebSocketManager,
    #[cfg(feature = "websocket")]
    message_history: Arc<RwLock<std::collections::VecDeque<String>>>,
    #[cfg(not(feature = "websocket"))]
    _phantom: std::marker::PhantomData<()>,
}

impl Default for ChatRoom {
    fn default() -> Self {
        Self::new()
    }
}

impl ChatRoom {
    /// Maximum number of messages retained in the history.
    #[cfg(feature = "websocket")]
    const MAX_HISTORY: usize = 100;

    /// Create an empty room.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "websocket")]
            manager: WebSocketManager::new(),
            #[cfg(feature = "websocket")]
            message_history: Arc::new(RwLock::new(std::collections::VecDeque::new())),
            #[cfg(not(feature = "websocket"))]
            _phantom: std::marker::PhantomData,
        }
    }

    /// Record `message` from `user` in the history and broadcast it to all clients.
    ///
    /// # Errors
    /// Propagates any error returned by the manager's `broadcast`.
    #[cfg(feature = "websocket")]
    pub async fn send_message(&self, user: &str, message: &str) -> Result<(), Box<dyn std::error::Error>> {
        let formatted_message = format!("{}: {}", user, message);

        // Scope the write guard so it is released before broadcasting.
        {
            let mut history = self.message_history.write().await;
            history.push_back(formatted_message.clone());

            // Evict the oldest entries; pop_front is O(1), unlike Vec::remove(0).
            while history.len() > Self::MAX_HISTORY {
                history.pop_front();
            }
        }

        self.manager.broadcast(&formatted_message).await?;
        Ok(())
    }

    /// Snapshot of the retained history, oldest message first.
    #[cfg(feature = "websocket")]
    pub async fn get_history(&self) -> Vec<String> {
        self.message_history.read().await.iter().cloned().collect()
    }

    /// Stub: always fails because the `websocket` feature is disabled.
    #[cfg(not(feature = "websocket"))]
    pub async fn send_message(&self, _user: &str, _message: &str) -> Result<(), Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Stub: the history is always empty without the `websocket` feature.
    #[cfg(not(feature = "websocket"))]
    pub async fn get_history(&self) -> Vec<String> {
        Vec::new()
    }
}
314
315/// Server-Sent Events (SSE) support for real-time updates
316pub struct SSEStream {
317    #[cfg(feature = "websocket")]
318    sender: broadcast::Sender<String>,
319    #[cfg(not(feature = "websocket"))]
320    _phantom: std::marker::PhantomData<()>,
321}
322
323impl SSEStream {
324    pub fn new() -> Self {
325        Self {
326            #[cfg(feature = "websocket")]
327            sender: broadcast::channel(1000).0,
328            #[cfg(not(feature = "websocket"))]
329            _phantom: std::marker::PhantomData,
330        }
331    }
332
333    /// Send an event to all SSE clients
334    #[cfg(feature = "websocket")]
335    pub fn send_event(&self, event_type: &str, data: &str) -> Result<(), Box<dyn std::error::Error>> {
336        let sse_message = format!("event: {}\ndata: {}\n\n", event_type, data);
337        self.sender.send(sse_message)?;
338        Ok(())
339    }
340
341    /// Create an SSE response with proper streaming setup
342    pub fn create_response(&self) -> Response {
343        #[cfg(feature = "websocket")]
344        {
345            // Create SSE response with proper headers
346            let mut response = Response::ok()
347                .header("Content-Type", "text/event-stream")
348                .header("Cache-Control", "no-cache")
349                .header("Connection", "keep-alive")
350                .header("Access-Control-Allow-Origin", "*")
351                .header("Access-Control-Allow-Headers", "Cache-Control");
352
353            // Send initial connection event
354            let initial_data = "event: connected\ndata: SSE stream established\nid: 0\n\n";
355            response = response.body(initial_data);
356
357            response
358        }
359
360        #[cfg(not(feature = "websocket"))]
361        {
362            Response::with_status(http::StatusCode::NOT_IMPLEMENTED)
363                .body("SSE support not enabled")
364        }
365    }
366
367    #[cfg(not(feature = "websocket"))]
368    pub fn send_event(&self, _event_type: &str, _data: &str) -> Result<(), Box<dyn std::error::Error>> {
369        Err("WebSocket feature not enabled".into())
370    }
371}
372
373/// WebSocket middleware for automatic connection management
374pub struct WebSocketMiddleware {
375    #[cfg(feature = "websocket")]
376    manager: Arc<WebSocketManager>,
377    #[cfg(not(feature = "websocket"))]
378    _phantom: std::marker::PhantomData<()>,
379}
380
381impl WebSocketMiddleware {
382    pub fn new(_manager: Arc<WebSocketManager>) -> Self {
383        Self {
384            #[cfg(feature = "websocket")]
385            manager: _manager,
386            #[cfg(not(feature = "websocket"))]
387            _phantom: std::marker::PhantomData,
388        }
389    }
390}
391
392impl crate::middleware::Middleware for WebSocketMiddleware {
393    fn call(
394        &self,
395        req: Request,
396        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
397    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
398        #[cfg(feature = "websocket")]
399        {
400            let _manager = self.manager.clone();
401            Box::pin(async move {
402                // Check if this is a WebSocket upgrade request
403                if req.header("upgrade").map(|h| h.to_lowercase()) == Some("websocket".to_string()) {
404                    // Handle WebSocket upgrade
405                    websocket_upgrade(req).await
406                } else {
407                    // Regular HTTP request
408                    next(req).await
409                }
410            })
411        }
412        
413        #[cfg(not(feature = "websocket"))]
414        {
415            Box::pin(async move {
416                next(req).await
417            })
418        }
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_websocket_manager() {
        let manager = WebSocketManager::new();
        assert_eq!(manager.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_chat_room() {
        let chat = ChatRoom::new();
        let history = chat.get_history().await;
        assert!(history.is_empty());
    }

    #[test]
    fn test_sse_stream() {
        let sse = SSEStream::new();
        let response = sse.create_response();

        #[cfg(feature = "websocket")]
        {
            assert_eq!(response.headers().get("content-type").unwrap(), "text/event-stream");
        }

        #[cfg(not(feature = "websocket"))]
        {
            assert_eq!(response.status_code(), http::StatusCode::NOT_IMPLEMENTED);
        }
    }

    /// Worked example from RFC 6455 §1.3.
    #[cfg(feature = "websocket")]
    #[test]
    fn test_accept_key_matches_rfc_example() {
        assert_eq!(
            generate_websocket_accept_key("dGhlIHNhbXBsZSBub25jZQ=="),
            "s3pPLMBiTxK0oGYt+xOo/YbK+xo="
        );
    }

    #[cfg(feature = "websocket")]
    #[test]
    fn test_websocket_message_accessors() {
        let text = WebSocketMessage::Text("hello".to_string());
        assert!(text.is_text());
        assert!(!text.is_binary());
        assert_eq!(text.as_text(), Some("hello"));
        assert_eq!(text.as_binary(), None);

        let binary = WebSocketMessage::Binary(vec![1, 2, 3]);
        assert!(binary.is_binary());
        assert!(!binary.is_text());
        assert_eq!(binary.as_binary(), Some(&[1u8, 2, 3][..]));
        assert_eq!(binary.as_text(), None);
    }

    #[cfg(not(feature = "websocket"))]
    #[tokio::test]
    async fn test_stubs_report_disabled_feature() {
        assert!(WebSocketManager::new().broadcast("hi").await.is_err());
        assert!(WebSocketManager::new().send_to("a", "hi").await.is_err());
        assert!(ChatRoom::new().send_message("alice", "hi").await.is_err());
        assert!(SSEStream::new().send_event("ping", "x").is_err());
    }
}