1use crate::{Request, Response};
4use std::sync::Arc;
5
6#[cfg(feature = "websocket")]
7use std::collections::HashMap;
8
9#[cfg(feature = "websocket")]
10use {
11 tokio_tungstenite::{accept_async, tungstenite::Message},
12 futures_util::{SinkExt, StreamExt},
13 tokio::sync::{RwLock, broadcast},
14 sha1::{Sha1, Digest},
15 base64::{Engine as _, engine::general_purpose},
16};
17
/// Registry of per-client broadcast channels used to push text messages to
/// WebSocket clients, keyed by a caller-chosen client id.
///
/// Without the `websocket` feature this is an empty placeholder whose
/// operations report that the feature is disabled.
pub struct WebSocketManager {
    /// Outgoing-message sender per client id.
    #[cfg(feature = "websocket")]
    connections: Arc<RwLock<HashMap<String, broadcast::Sender<String>>>>,
    #[cfg(not(feature = "websocket"))]
    _phantom: std::marker::PhantomData<()>,
}

impl WebSocketManager {
    /// Creates a manager with no registered clients.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "websocket")]
            connections: Arc::new(RwLock::new(HashMap::new())),
            #[cfg(not(feature = "websocket"))]
            _phantom: std::marker::PhantomData,
        }
    }

    /// Registers `client_id` with a fresh channel holding up to `capacity`
    /// messages and returns that channel's receiver.
    ///
    /// Registering an id that already exists replaces (and drops) its previous
    /// sender, so receivers from the earlier registration see the channel close.
    ///
    /// # Panics
    /// Panics if `capacity` is zero (a `tokio::sync::broadcast::channel` precondition).
    #[cfg(feature = "websocket")]
    pub async fn register(
        &self,
        client_id: impl Into<String>,
        capacity: usize,
    ) -> broadcast::Receiver<String> {
        let (sender, receiver) = broadcast::channel(capacity);
        self.connections.write().await.insert(client_id.into(), sender);
        receiver
    }

    /// Removes `client_id`, returning `true` if it was registered.
    #[cfg(feature = "websocket")]
    pub async fn unregister(&self, client_id: &str) -> bool {
        self.connections.write().await.remove(client_id).is_some()
    }

    /// Sends `message` to every registered client.
    ///
    /// Returns how many clients' channels accepted the message. A client whose
    /// receivers have all been dropped is skipped rather than reported as an
    /// error, so this currently never returns `Err`.
    #[cfg(feature = "websocket")]
    pub async fn broadcast(&self, message: &str) -> Result<usize, Box<dyn std::error::Error>> {
        let connections = self.connections.read().await;
        let mut sent_count = 0;
        for sender in connections.values() {
            if sender.send(message.to_string()).is_ok() {
                sent_count += 1;
            }
        }
        Ok(sent_count)
    }

    /// Sends `message` to the client registered as `client_id`.
    ///
    /// An unknown `client_id` is silently ignored and yields `Ok(())`.
    ///
    /// # Errors
    /// Returns the channel's send error when the client is registered but all
    /// of its receivers have been dropped.
    #[cfg(feature = "websocket")]
    pub async fn send_to(&self, client_id: &str, message: &str) -> Result<(), Box<dyn std::error::Error>> {
        let connections = self.connections.read().await;
        if let Some(sender) = connections.get(client_id) {
            sender.send(message.to_string())?;
        }
        Ok(())
    }

    /// Returns the number of registered clients.
    #[cfg(feature = "websocket")]
    pub async fn connection_count(&self) -> usize {
        self.connections.read().await.len()
    }

    /// Always fails: the `websocket` feature is disabled.
    #[cfg(not(feature = "websocket"))]
    pub async fn broadcast(&self, _message: &str) -> Result<usize, Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Always fails: the `websocket` feature is disabled.
    #[cfg(not(feature = "websocket"))]
    pub async fn send_to(&self, _client_id: &str, _message: &str) -> Result<(), Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Always `0`: the `websocket` feature is disabled.
    #[cfg(not(feature = "websocket"))]
    pub async fn connection_count(&self) -> usize {
        0
    }
}

impl Default for WebSocketManager {
    fn default() -> Self {
        Self::new()
    }
}
82
83pub async fn websocket_upgrade(req: Request) -> Response {
85 #[cfg(feature = "websocket")]
86 {
87 if !is_websocket_upgrade_request(&req) {
89 return Response::bad_request().body("Not a valid WebSocket upgrade request");
90 }
91
92 let websocket_key = match req.header("sec-websocket-key") {
94 Some(key) => key,
95 None => return Response::bad_request().body("Missing Sec-WebSocket-Key header"),
96 };
97
98 let accept_key = generate_websocket_accept_key(websocket_key);
100
101 Response::with_status(http::StatusCode::SWITCHING_PROTOCOLS)
103 .header("Upgrade", "websocket")
104 .header("Connection", "Upgrade")
105 .header("Sec-WebSocket-Accept", &accept_key)
106 .header("Sec-WebSocket-Version", "13")
107 .body("")
108 }
109
110 #[cfg(not(feature = "websocket"))]
111 {
112 let _ = req; Response::with_status(http::StatusCode::NOT_IMPLEMENTED)
114 .body("WebSocket support not enabled")
115 }
116}
117
/// Returns `true` when `req` carries the header fields required for a
/// WebSocket opening handshake (RFC 6455 §4.2.1):
///
/// * `Upgrade` lists the token `websocket`;
/// * `Connection` lists the token `upgrade`;
/// * `Sec-WebSocket-Version` is `13`;
/// * `Sec-WebSocket-Key` is present.
///
/// Token comparisons are ASCII case-insensitive and whitespace-tolerant.
#[cfg(feature = "websocket")]
pub fn is_websocket_upgrade_request(req: &Request) -> bool {
    // `Upgrade` and `Connection` are comma-separated token lists
    // (e.g. `Connection: keep-alive, Upgrade`). Match whole tokens rather than
    // substrings or the raw header text.
    fn has_token(value: Option<&str>, token: &str) -> bool {
        value
            .into_iter()
            .flat_map(|v| v.split(','))
            .any(|t| t.trim().eq_ignore_ascii_case(token))
    }

    has_token(req.header("upgrade"), "websocket")
        && has_token(req.header("connection"), "upgrade")
        && req.header("sec-websocket-version").map(str::trim) == Some("13")
        && req.header("sec-websocket-key").is_some()
}
131
/// Computes the `Sec-WebSocket-Accept` value for a client's
/// `Sec-WebSocket-Key`: base64(SHA-1(key ++ GUID)), as in RFC 6455 §4.2.2.
#[cfg(feature = "websocket")]
fn generate_websocket_accept_key(websocket_key: &str) -> String {
    /// Fixed GUID that RFC 6455 appends to the client key.
    const WEBSOCKET_MAGIC_STRING: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Feeding the two parts separately hashes the same bytes as hashing their
    // concatenation, without building an intermediate String.
    let mut sha1 = Sha1::new();
    sha1.update(websocket_key.as_bytes());
    sha1.update(WEBSOCKET_MAGIC_STRING.as_bytes());
    let digest = sha1.finalize();

    general_purpose::STANDARD.encode(&digest)
}
148
/// Performs the server-side WebSocket handshake on `stream`, then hands the
/// resulting connection to `handler` and returns whatever it returns.
///
/// # Errors
/// Returns the handshake error if the handshake fails; otherwise returns the
/// handler's result.
#[cfg(feature = "websocket")]
pub async fn handle_websocket_connection<F, Fut>(
    stream: tokio::net::TcpStream,
    handler: F,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    F: FnOnce(WebSocketConnection) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send,
{
    let connection = accept_async(stream).await.map(WebSocketConnection::new)?;
    handler(connection).await
}
166
/// An accepted WebSocket connection over a plain TCP stream.
#[cfg(feature = "websocket")]
pub struct WebSocketConnection {
    stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
}

#[cfg(feature = "websocket")]
impl WebSocketConnection {
    /// Wraps an already-handshaken stream.
    fn new(stream: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Self {
        WebSocketConnection { stream }
    }

    /// Sends `text` as a single text message.
    pub async fn send_text(&mut self, text: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let frame = Message::Text(text.to_string());
        Ok(self.stream.send(frame).await?)
    }

    /// Sends a copy of `data` as a single binary message.
    pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let frame = Message::Binary(data.to_vec());
        Ok(self.stream.send(frame).await?)
    }

    /// Waits for the next message from the peer.
    ///
    /// Returns `Ok(None)` once the stream has ended.
    pub async fn receive(&mut self) -> Result<Option<WebSocketMessage>, Box<dyn std::error::Error + Send + Sync>> {
        let next = self.stream.next().await;
        // Option<Result<..>> -> Result<Option<..>>: a read error surfaces as Err,
        // end of stream as Ok(None).
        Ok(next.transpose()?.map(WebSocketMessage::from_tungstenite))
    }

    /// Sends a Close message without a status code. It only sends; it does not
    /// read the peer's reply.
    pub async fn close(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(self.stream.send(Message::Close(None)).await?)
    }
}
206
/// A message received from a WebSocket peer.
#[cfg(feature = "websocket")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebSocketMessage {
    /// Text payload.
    Text(String),
    /// Binary payload.
    Binary(Vec<u8>),
    /// Ping control message with its payload.
    Ping(Vec<u8>),
    /// Pong control message with its payload.
    Pong(Vec<u8>),
    /// Close message; any close code and reason are discarded.
    Close,
}

#[cfg(feature = "websocket")]
impl WebSocketMessage {
    /// Converts a tungstenite message, dropping close-frame details.
    fn from_tungstenite(msg: Message) -> Self {
        match msg {
            Message::Text(text) => WebSocketMessage::Text(text),
            Message::Binary(data) => WebSocketMessage::Binary(data),
            Message::Ping(data) => WebSocketMessage::Ping(data),
            Message::Pong(data) => WebSocketMessage::Pong(data),
            Message::Close(_) => WebSocketMessage::Close,
            // NOTE(review): tungstenite documents raw frames as not being
            // returned by reads, so this arm is presumably unreachable; it is
            // mapped to Close to keep the match exhaustive.
            Message::Frame(_) => WebSocketMessage::Close,
        }
    }

    /// Returns `true` for a text message.
    pub fn is_text(&self) -> bool {
        matches!(self, WebSocketMessage::Text(_))
    }

    /// Returns `true` for a binary message.
    pub fn is_binary(&self) -> bool {
        matches!(self, WebSocketMessage::Binary(_))
    }

    /// Borrows the payload of a text message, or `None` for any other kind.
    pub fn as_text(&self) -> Option<&str> {
        match self {
            WebSocketMessage::Text(text) => Some(text.as_str()),
            _ => None,
        }
    }

    /// Borrows the payload of a binary message, or `None` for any other kind.
    pub fn as_binary(&self) -> Option<&[u8]> {
        match self {
            WebSocketMessage::Binary(data) => Some(data.as_slice()),
            _ => None,
        }
    }
}
256
/// Chat room that keeps a bounded in-memory history of formatted messages and
/// relays each new message through a [`WebSocketManager`].
///
/// Without the `websocket` feature this is an empty placeholder.
pub struct ChatRoom {
    #[cfg(feature = "websocket")]
    manager: WebSocketManager,
    /// Most recent messages, oldest first.
    #[cfg(feature = "websocket")]
    message_history: Arc<RwLock<Vec<String>>>,
    #[cfg(not(feature = "websocket"))]
    _phantom: std::marker::PhantomData<()>,
}

impl ChatRoom {
    /// Maximum number of messages kept in the history.
    #[cfg(feature = "websocket")]
    const MAX_HISTORY: usize = 100;

    /// Creates an empty room.
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "websocket")]
            manager: WebSocketManager::new(),
            #[cfg(feature = "websocket")]
            message_history: Arc::new(RwLock::new(Vec::new())),
            #[cfg(not(feature = "websocket"))]
            _phantom: std::marker::PhantomData,
        }
    }

    /// Returns the manager used to relay messages.
    ///
    /// Callers need this to inspect or populate the room's connections.
    #[cfg(feature = "websocket")]
    pub fn manager(&self) -> &WebSocketManager {
        &self.manager
    }

    /// Records `"{user}: {message}"` in the history and broadcasts it.
    ///
    /// The oldest entries are evicted once the history exceeds
    /// `MAX_HISTORY` (100) messages.
    ///
    /// # Errors
    /// Propagates any error returned by the manager's `broadcast`.
    #[cfg(feature = "websocket")]
    pub async fn send_message(&self, user: &str, message: &str) -> Result<(), Box<dyn std::error::Error>> {
        let formatted_message = format!("{}: {}", user, message);

        {
            // Scoped so the write lock is released before broadcasting.
            let mut history = self.message_history.write().await;
            history.push(formatted_message.clone());
            if history.len() > Self::MAX_HISTORY {
                let overflow = history.len() - Self::MAX_HISTORY;
                history.drain(..overflow);
            }
        }

        self.manager.broadcast(&formatted_message).await?;
        Ok(())
    }

    /// Returns a snapshot of the history, oldest message first.
    #[cfg(feature = "websocket")]
    pub async fn get_history(&self) -> Vec<String> {
        self.message_history.read().await.clone()
    }

    /// Always fails: the `websocket` feature is disabled.
    #[cfg(not(feature = "websocket"))]
    pub async fn send_message(&self, _user: &str, _message: &str) -> Result<(), Box<dyn std::error::Error>> {
        Err("WebSocket feature not enabled".into())
    }

    /// Always empty: the `websocket` feature is disabled.
    #[cfg(not(feature = "websocket"))]
    pub async fn get_history(&self) -> Vec<String> {
        Vec::new()
    }
}

impl Default for ChatRoom {
    fn default() -> Self {
        Self::new()
    }
}
314
315pub struct SSEStream {
317 #[cfg(feature = "websocket")]
318 sender: broadcast::Sender<String>,
319 #[cfg(not(feature = "websocket"))]
320 _phantom: std::marker::PhantomData<()>,
321}
322
323impl SSEStream {
324 pub fn new() -> Self {
325 Self {
326 #[cfg(feature = "websocket")]
327 sender: broadcast::channel(1000).0,
328 #[cfg(not(feature = "websocket"))]
329 _phantom: std::marker::PhantomData,
330 }
331 }
332
333 #[cfg(feature = "websocket")]
335 pub fn send_event(&self, event_type: &str, data: &str) -> Result<(), Box<dyn std::error::Error>> {
336 let sse_message = format!("event: {}\ndata: {}\n\n", event_type, data);
337 self.sender.send(sse_message)?;
338 Ok(())
339 }
340
341 pub fn create_response(&self) -> Response {
343 #[cfg(feature = "websocket")]
344 {
345 let mut response = Response::ok()
347 .header("Content-Type", "text/event-stream")
348 .header("Cache-Control", "no-cache")
349 .header("Connection", "keep-alive")
350 .header("Access-Control-Allow-Origin", "*")
351 .header("Access-Control-Allow-Headers", "Cache-Control");
352
353 let initial_data = "event: connected\ndata: SSE stream established\nid: 0\n\n";
355 response = response.body(initial_data);
356
357 response
358 }
359
360 #[cfg(not(feature = "websocket"))]
361 {
362 Response::with_status(http::StatusCode::NOT_IMPLEMENTED)
363 .body("SSE support not enabled")
364 }
365 }
366
367 #[cfg(not(feature = "websocket"))]
368 pub fn send_event(&self, _event_type: &str, _data: &str) -> Result<(), Box<dyn std::error::Error>> {
369 Err("WebSocket feature not enabled".into())
370 }
371}
372
373pub struct WebSocketMiddleware {
375 #[cfg(feature = "websocket")]
376 manager: Arc<WebSocketManager>,
377 #[cfg(not(feature = "websocket"))]
378 _phantom: std::marker::PhantomData<()>,
379}
380
381impl WebSocketMiddleware {
382 pub fn new(_manager: Arc<WebSocketManager>) -> Self {
383 Self {
384 #[cfg(feature = "websocket")]
385 manager: _manager,
386 #[cfg(not(feature = "websocket"))]
387 _phantom: std::marker::PhantomData,
388 }
389 }
390}
391
392impl crate::middleware::Middleware for WebSocketMiddleware {
393 fn call(
394 &self,
395 req: Request,
396 next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
397 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
398 #[cfg(feature = "websocket")]
399 {
400 let _manager = self.manager.clone();
401 Box::pin(async move {
402 if req.header("upgrade").map(|h| h.to_lowercase()) == Some("websocket".to_string()) {
404 websocket_upgrade(req).await
406 } else {
407 next(req).await
409 }
410 })
411 }
412
413 #[cfg(not(feature = "websocket"))]
414 {
415 Box::pin(async move {
416 next(req).await
417 })
418 }
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_websocket_manager() {
        let manager = WebSocketManager::new();
        assert_eq!(manager.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_chat_room() {
        let chat = ChatRoom::new();
        let history = chat.get_history().await;
        assert!(history.is_empty());
    }

    #[test]
    fn test_sse_stream() {
        let sse = SSEStream::new();
        let response = sse.create_response();

        #[cfg(feature = "websocket")]
        {
            assert_eq!(response.headers().get("content-type").unwrap(), "text/event-stream");
        }

        #[cfg(not(feature = "websocket"))]
        {
            assert_eq!(response.status_code(), http::StatusCode::NOT_IMPLEMENTED);
        }
    }

    /// Worked handshake example from RFC 6455 §1.3.
    #[cfg(feature = "websocket")]
    #[test]
    fn test_accept_key_matches_rfc_example() {
        assert_eq!(
            generate_websocket_accept_key("dGhlIHNhbXBsZSBub25jZQ=="),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[cfg(feature = "websocket")]
    #[tokio::test]
    async fn test_broadcast_without_connections_reaches_nobody() {
        let manager = WebSocketManager::new();
        assert_eq!(manager.broadcast("hello").await.unwrap(), 0);
        // Unknown client ids are ignored rather than reported.
        manager.send_to("missing", "hello").await.unwrap();
    }

    #[cfg(feature = "websocket")]
    #[tokio::test]
    async fn test_chat_history_is_capped() {
        let chat = ChatRoom::new();
        for i in 0..105 {
            chat.send_message("user", &i.to_string()).await.unwrap();
        }
        let history = chat.get_history().await;
        assert_eq!(history.len(), 100);
        assert_eq!(history.first().map(String::as_str), Some("user: 5"));
        assert_eq!(history.last().map(String::as_str), Some("user: 104"));
    }

    #[cfg(not(feature = "websocket"))]
    #[tokio::test]
    async fn test_disabled_feature_reports_errors() {
        assert!(WebSocketManager::new().broadcast("hello").await.is_err());
        assert!(ChatRoom::new().send_message("user", "hi").await.is_err());
        assert!(SSEStream::new().send_event("tick", "1").is_err());
    }
}