1use crate::{Request, Response};
205use std::sync::Arc;
206
207#[cfg(feature = "websocket")]
208use std::collections::HashMap;
209
210#[cfg(feature = "websocket")]
211use {
212 tokio_tungstenite::{accept_async, tungstenite::Message},
213 futures_util::{SinkExt, StreamExt},
214 tokio::sync::{RwLock, broadcast},
215 sha1::{Sha1, Digest},
216 base64::{Engine as _, engine::general_purpose},
217};
218
/// Registry of WebSocket clients, keyed by client id.
///
/// Each entry stores the sending half of a tokio `broadcast` channel; text pushed
/// through `broadcast` or `send_to` is handed to that channel.
///
/// NOTE(review): nothing in this file inserts into the registry, so `broadcast`
/// reports 0 and `send_to` is a no-op unless connections are registered
/// elsewhere — confirm.
///
/// Without the `websocket` feature this type is an inert placeholder whose
/// methods report that the feature is disabled.
pub struct WebSocketManager {
    #[cfg(feature = "websocket")]
    connections: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
    #[cfg(not(feature = "websocket"))]
    _phantom: std::marker::PhantomData<()>,
}

impl WebSocketManager {
    /// Creates a manager with an empty connection registry.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "websocket")]
            connections: Arc::new(RwLock::new(HashMap::new())),
            #[cfg(not(feature = "websocket"))]
            _phantom: std::marker::PhantomData,
        }
    }

    /// Sends `message` to every registered connection.
    ///
    /// Returns the number of channels that accepted the message. A channel whose
    /// `send` fails is skipped rather than aborting the broadcast.
    ///
    /// # Errors
    ///
    /// Never returns `Err` with the feature enabled; the `Result` mirrors the
    /// feature-disabled stub.
    #[cfg(feature = "websocket")]
    pub async fn broadcast(&self, message: &str) -> Result<usize, Box<dyn std::error::Error>> {
        let connections = self.connections.read().await;
        let mut sent_count = 0;

        for sender in connections.values() {
            // A failed send (e.g. no live receivers on that channel) is counted
            // as "not delivered" instead of being propagated.
            if sender.send(message.to_owned()).is_ok() {
                sent_count += 1;
            }
        }

        Ok(sent_count)
    }

    /// Sends `message` to the single connection registered as `client_id`.
    ///
    /// An unknown `client_id` is treated as success: nothing is sent and `Ok(())`
    /// is returned.
    ///
    /// # Errors
    ///
    /// Propagates the channel's send error for a registered client.
    #[cfg(feature = "websocket")]
    pub async fn send_to(&self, client_id: &str, message: &str) -> Result<(), Box<dyn std::error::Error>> {
        let connections = self.connections.read().await;
        if let Some(sender) = connections.get(client_id) {
            sender.send(message.to_owned())?;
        }
        Ok(())
    }

    /// Number of connections currently in the registry.
    #[cfg(feature = "websocket")]
    pub async fn connection_count(&self) -> usize {
        self.connections.read().await.len()
    }

    /// Feature-disabled stub: always fails.
    #[cfg(not(feature = "websocket"))]
    pub async fn broadcast(&self, _message: &str) -> Result<usize, Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Feature-disabled stub: always fails.
    #[cfg(not(feature = "websocket"))]
    pub async fn send_to(&self, _client_id: &str, _message: &str) -> Result<(), Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Feature-disabled stub: there are never any connections.
    #[cfg(not(feature = "websocket"))]
    pub async fn connection_count(&self) -> usize {
        0
    }
}

impl Default for WebSocketManager {
    /// Equivalent to [`WebSocketManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
283
284pub async fn websocket_upgrade(req: Request) -> Response {
286 #[cfg(feature = "websocket")]
287 {
288 if !is_websocket_upgrade_request(&req) {
290 return Response::bad_request().body("Not a valid WebSocket upgrade request");
291 }
292
293 let websocket_key = match req.header("sec-websocket-key") {
295 Some(key) => key,
296 None => return Response::bad_request().body("Missing Sec-WebSocket-Key header"),
297 };
298
299 let accept_key = generate_websocket_accept_key(websocket_key);
301
302 Response::with_status(http::StatusCode::SWITCHING_PROTOCOLS)
304 .header("Upgrade", "websocket")
305 .header("Connection", "Upgrade")
306 .header("Sec-WebSocket-Accept", &accept_key)
307 .header("Sec-WebSocket-Version", "13")
308 .body("")
309 }
310
311 #[cfg(not(feature = "websocket"))]
312 {
313 let _ = req; Response::with_status(http::StatusCode::NOT_IMPLEMENTED)
315 .body("WebSocket support not enabled")
316 }
317}
318
/// Reports whether `req` carries the header set of a WebSocket opening
/// handshake (RFC 6455 §4.2.1):
///
/// * `Upgrade` lists the token `websocket`,
/// * `Connection` lists the token `Upgrade`,
/// * `Sec-WebSocket-Version` is `13`,
/// * `Sec-WebSocket-Key` is present and not blank.
///
/// `Upgrade` and `Connection` are treated as comma-separated token lists and
/// compared ASCII case-insensitively with surrounding whitespace ignored. So
/// `Connection: keep-alive, Upgrade` is accepted, while a token that merely
/// contains the substring `upgrade` is not.
///
/// The HTTP method and the decoded length of the key are not checked here.
#[cfg(feature = "websocket")]
pub fn is_websocket_upgrade_request(req: &Request) -> bool {
    let version_is_13 = req
        .header("sec-websocket-version")
        .map_or(false, |version| version.trim() == "13");
    let has_key = req
        .header("sec-websocket-key")
        .map_or(false, |key| !key.trim().is_empty());

    header_has_token(req.header("upgrade"), "websocket")
        && header_has_token(req.header("connection"), "upgrade")
        && version_is_13
        && has_key
}

/// Returns `true` when the comma-separated header `value` contains `token`.
///
/// Tokens are compared ASCII case-insensitively after trimming whitespace.
/// A missing header never matches.
#[cfg(feature = "websocket")]
fn header_has_token(value: Option<&str>, token: &str) -> bool {
    value.map_or(false, |list| {
        list.split(',')
            .any(|item| item.trim().eq_ignore_ascii_case(token))
    })
}
332
/// Computes the `Sec-WebSocket-Accept` value for a client's `Sec-WebSocket-Key`.
///
/// The value is base64(SHA-1(key + GUID)), as specified in RFC 6455 §4.2.2.
///
/// Surrounding whitespace on the key is ignored. It is never part of the
/// base64 nonce, and hashing it would produce an accept value the client
/// rejects.
#[cfg(feature = "websocket")]
fn generate_websocket_accept_key(websocket_key: &str) -> String {
    // Fixed GUID that RFC 6455 appends to the client key before hashing.
    const WEBSOCKET_MAGIC_STRING: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Feeding both parts to the hasher in turn is equivalent to hashing their
    // concatenation, without building an intermediate String.
    let mut hasher = Sha1::new();
    hasher.update(websocket_key.trim().as_bytes());
    hasher.update(WEBSOCKET_MAGIC_STRING.as_bytes());
    let hash = hasher.finalize();

    general_purpose::STANDARD.encode(&hash)
}
349
/// Runs the server side of the WebSocket handshake on `stream` and hands the
/// resulting connection to `handler`.
///
/// # Errors
///
/// Returns the handshake error from `accept_async`, or whatever error the
/// handler's future resolves to.
#[cfg(feature = "websocket")]
pub async fn handle_websocket_connection<F, Fut>(
    stream: tokio::net::TcpStream,
    handler: F,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    F: FnOnce(WebSocketConnection) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    let upgraded = accept_async(stream).await.map(WebSocketConnection::new)?;
    handler(upgraded).await
}
367
/// An accepted WebSocket connection over a plain TCP stream.
#[cfg(feature = "websocket")]
pub struct WebSocketConnection {
    stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
}

#[cfg(feature = "websocket")]
impl WebSocketConnection {
    /// Wraps a stream whose handshake has already completed.
    fn new(stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Self {
        Self { stream }
    }

    /// Sends `text` as one text message.
    ///
    /// # Errors
    ///
    /// Propagates the underlying stream's send error.
    pub async fn send_text(&mut self, text: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let outgoing = Message::Text(text.to_owned());
        self.stream.send(outgoing).await.map_err(Into::into)
    }

    /// Sends a copy of `data` as one binary message.
    ///
    /// # Errors
    ///
    /// Propagates the underlying stream's send error.
    pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let outgoing = Message::Binary(data.to_vec());
        self.stream.send(outgoing).await.map_err(Into::into)
    }

    /// Waits for the next message from the peer.
    ///
    /// Returns `Ok(None)` once the stream is exhausted.
    ///
    /// # Errors
    ///
    /// Returns the stream's read error.
    pub async fn receive(&mut self) -> Result<Option<WebSocketMessage>, Box<dyn std::error::Error + Send + Sync>> {
        let incoming = self.stream.next().await.transpose()?;
        Ok(incoming.map(WebSocketMessage::from_tungstenite))
    }

    /// Sends a close message without a close code or reason.
    ///
    /// # Errors
    ///
    /// Propagates the underlying stream's send error.
    pub async fn close(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        self.stream.send(Message::Close(None)).await.map_err(Into::into)
    }
}
407
/// A message read from a WebSocket peer, independent of the tungstenite types.
#[cfg(feature = "websocket")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketMessage {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A ping, with its payload.
    Ping(Vec<u8>),
    /// A pong, with its payload.
    Pong(Vec<u8>),
    /// A close message; any close code and reason are discarded.
    Close,
}

#[cfg(feature = "websocket")]
impl WebSocketMessage {
    /// Converts a tungstenite message into this crate's representation.
    ///
    /// `Message::Frame` is also mapped to `Close`.
    /// NOTE(review): tungstenite is understood not to yield raw frames when
    /// reading messages, so that arm should be unreachable from
    /// `WebSocketConnection::receive` — confirm against the pinned version.
    fn from_tungstenite(msg: Message) -> Self {
        match msg {
            Message::Text(text) => WebSocketMessage::Text(text),
            Message::Binary(data) => WebSocketMessage::Binary(data),
            Message::Ping(data) => WebSocketMessage::Ping(data),
            Message::Pong(data) => WebSocketMessage::Pong(data),
            Message::Close(_) => WebSocketMessage::Close,
            Message::Frame(_) => WebSocketMessage::Close,
        }
    }

    /// `true` for [`WebSocketMessage::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, WebSocketMessage::Text(_))
    }

    /// `true` for [`WebSocketMessage::Binary`].
    pub fn is_binary(&self) -> bool {
        matches!(self, WebSocketMessage::Binary(_))
    }

    /// Borrows the text payload, or `None` for any other variant.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            WebSocketMessage::Text(text) => Some(text),
            _ => None,
        }
    }

    /// Borrows the binary payload, or `None` for any other variant (including
    /// `Ping`/`Pong`).
    pub fn as_binary(&self) -> Option<&[u8]> {
        match self {
            WebSocketMessage::Binary(data) => Some(data),
            _ => None,
        }
    }
}
457
/// A chat room that keeps a bounded history of messages and relays each new
/// message through its own [`WebSocketManager`].
///
/// NOTE(review): the manager is private and nothing in this file registers
/// connections with it, so relayed messages currently reach no client unless
/// that happens elsewhere — confirm.
///
/// Without the `websocket` feature this type is an inert placeholder.
pub struct ChatRoom {
    #[cfg(feature = "websocket")]
    manager: WebSocketManager,
    /// Formatted `"user: message"` lines, oldest first.
    #[cfg(feature = "websocket")]
    message_history: Arc<RwLock<Vec<String>>>,
    #[cfg(not(feature = "websocket"))]
    _phantom: std::marker::PhantomData<()>,
}

impl ChatRoom {
    /// Maximum number of messages kept in the history; the oldest are evicted first.
    #[cfg(feature = "websocket")]
    const MAX_HISTORY: usize = 100;

    /// Creates an empty chat room.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "websocket")]
            manager: WebSocketManager::new(),
            #[cfg(feature = "websocket")]
            message_history: Arc::new(RwLock::new(Vec::new())),
            #[cfg(not(feature = "websocket"))]
            _phantom: std::marker::PhantomData,
        }
    }

    /// Records `"{user}: {message}"` in the history and broadcasts it.
    ///
    /// The history is capped at `MAX_HISTORY` entries; the oldest entries are
    /// dropped when the cap is exceeded.
    ///
    /// # Errors
    ///
    /// Propagates any error from the manager's `broadcast`. The message is
    /// already in the history by then.
    #[cfg(feature = "websocket")]
    pub async fn send_message(&self, user: &str, message: &str) -> Result<(), Box<dyn std::error::Error>> {
        let formatted_message = format!("{}: {}", user, message);

        {
            // Scoped so the write guard is released before broadcasting.
            let mut history = self.message_history.write().await;
            history.push(formatted_message.clone());

            let overflow = history.len().saturating_sub(Self::MAX_HISTORY);
            history.drain(..overflow);
        }

        self.manager.broadcast(&formatted_message).await?;
        Ok(())
    }

    /// Returns a snapshot copy of the message history, oldest first.
    #[cfg(feature = "websocket")]
    pub async fn get_history(&self) -> Vec<String> {
        self.message_history.read().await.clone()
    }

    /// Feature-disabled stub: always fails.
    #[cfg(not(feature = "websocket"))]
    pub async fn send_message(&self, _user: &str, _message: &str) -> Result<(), Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Feature-disabled stub: the history is always empty.
    #[cfg(not(feature = "websocket"))]
    pub async fn get_history(&self) -> Vec<String> {
        Vec::new()
    }
}

impl Default for ChatRoom {
    /// Equivalent to [`ChatRoom::new`].
    fn default() -> Self {
        Self::new()
    }
}
515
516pub struct SSEStream {
518 #[cfg(feature = "websocket")]
519 sender: broadcast::Sender<String>,
520 #[cfg(not(feature = "websocket"))]
521 _phantom: std::marker::PhantomData<()>,
522}
523
524impl SSEStream {
525 pub fn new() -> Self {
526 Self {
527 #[cfg(feature = "websocket")]
528 sender: broadcast::channel(1000).0,
529 #[cfg(not(feature = "websocket"))]
530 _phantom: std::marker::PhantomData,
531 }
532 }
533
534 #[cfg(feature = "websocket")]
536 pub fn send_event(&self, event_type: &str, data: &str) -> Result<(), Box<dyn std::error::Error>> {
537 let sse_message = format!("event: {}\ndata: {}\n\n", event_type, data);
538 self.sender.send(sse_message)?;
539 Ok(())
540 }
541
542 pub fn create_response(&self) -> Response {
544 #[cfg(feature = "websocket")]
545 {
546 let mut response = Response::ok()
548 .header("Content-Type", "text/event-stream")
549 .header("Cache-Control", "no-cache")
550 .header("Connection", "keep-alive")
551 .header("Access-Control-Allow-Origin", "*")
552 .header("Access-Control-Allow-Headers", "Cache-Control");
553
554 let initial_data = "event: connected\ndata: SSE stream established\nid: 0\n\n";
556 response = response.body(initial_data);
557
558 response
559 }
560
561 #[cfg(not(feature = "websocket"))]
562 {
563 Response::with_status(http::StatusCode::NOT_IMPLEMENTED)
564 .body("SSE support not enabled")
565 }
566 }
567
568 #[cfg(not(feature = "websocket"))]
569 pub fn send_event(&self, _event_type: &str, _data: &str) -> Result<(), Box<dyn std::error::Error>> {
570 Err("WebSocket feature not enabled".into())
571 }
572}
573
574pub struct WebSocketMiddleware {
576 #[cfg(feature = "websocket")]
577 manager: Arc<WebSocketManager>,
578 #[cfg(not(feature = "websocket"))]
579 _phantom: std::marker::PhantomData<()>,
580}
581
582impl WebSocketMiddleware {
583 pub fn new(_manager: Arc<WebSocketManager>) -> Self {
584 Self {
585 #[cfg(feature = "websocket")]
586 manager: _manager,
587 #[cfg(not(feature = "websocket"))]
588 _phantom: std::marker::PhantomData,
589 }
590 }
591}
592
593impl crate::middleware::Middleware for WebSocketMiddleware {
594 fn call(
595 &self,
596 req: Request,
597 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
598 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
599 #[cfg(feature = "websocket")]
600 {
601 let _manager = self.manager.clone();
602 Box::pin(async move {
603 if req.header("upgrade").map(|h| h.to_lowercase()) == Some("websocket".to_string()) {
605 websocket_upgrade(req).await
607 } else {
608 next(req).await
610 }
611 })
612 }
613
614 #[cfg(not(feature = "websocket"))]
615 {
616 Box::pin(async move {
617 next(req).await
618 })
619 }
620 }
621}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created manager reports no connections.
    #[tokio::test]
    async fn test_websocket_manager() {
        let count = WebSocketManager::new().connection_count().await;
        assert_eq!(count, 0);
    }

    /// A freshly created chat room starts with an empty history.
    #[tokio::test]
    async fn test_chat_room() {
        let history = ChatRoom::new().get_history().await;
        assert!(history.is_empty());
    }

    /// With the feature on, the SSE response advertises `text/event-stream`.
    /// With it off, the response is `NOT_IMPLEMENTED`.
    #[test]
    fn test_sse_stream() {
        let response = SSEStream::new().create_response();

        #[cfg(feature = "websocket")]
        {
            assert_eq!(response.headers().get("content-type").unwrap(), "text/event-stream");
        }

        #[cfg(not(feature = "websocket"))]
        {
            assert_eq!(response.status_code(), http::StatusCode::NOT_IMPLEMENTED);
        }
    }
}