spikard_http/
websocket.rs

1//! WebSocket support for Spikard
2//!
3//! Provides WebSocket connection handling with message validation and routing.
4
5use axum::{
6    extract::{
7        State,
8        ws::{Message, WebSocket, WebSocketUpgrade},
9    },
10    response::IntoResponse,
11};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
/// Emit a low-level WebSocket trace line to stderr when `SPIKARD_WS_TRACE=1`.
///
/// The environment variable is read once, on first use, and cached for the rest of the
/// process. This function is called several times per message on the connection hot path.
/// Calling `std::env::var` each time would look up and allocate a `String` on every call.
fn trace_ws(message: &str) {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

    let enabled =
        *ENABLED.get_or_init(|| std::env::var("SPIKARD_WS_TRACE").is_ok_and(|value| value == "1"));
    if enabled {
        eprintln!("[spikard-ws] {message}");
    }
}
21
22/// WebSocket message handler trait
23///
24/// Implement this trait to create custom WebSocket message handlers for your application.
25/// The handler processes JSON messages received from WebSocket clients and can optionally
26/// send responses back.
27///
28/// # Implementing the Trait
29///
30/// You must implement the `handle_message` method. The `on_connect` and `on_disconnect`
31/// methods are optional and provide lifecycle hooks.
32///
33/// # Example
34///
35/// ```ignore
36/// use spikard_http::websocket::WebSocketHandler;
37/// use serde_json::{json, Value};
38///
39/// struct EchoHandler;
40///
41/// #[async_trait]
42/// impl WebSocketHandler for EchoHandler {
43///     async fn handle_message(&self, message: Value) -> Option<Value> {
44///         // Echo the message back to the client
45///         Some(message)
46///     }
47///
48///     async fn on_connect(&self) {
49///         println!("Client connected");
50///     }
51///
52///     async fn on_disconnect(&self) {
53///         println!("Client disconnected");
54///     }
55/// }
56/// ```
57pub trait WebSocketHandler: Send + Sync {
58    /// Handle incoming WebSocket message
59    ///
60    /// Called whenever a text message is received from a WebSocket client.
61    /// Messages are automatically parsed as JSON.
62    ///
63    /// # Arguments
64    /// * `message` - JSON value received from the client
65    ///
66    /// # Returns
67    /// * `Some(value)` - JSON value to send back to the client
68    /// * `None` - No response to send
69    fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
70
71    /// Called when a client connects to the WebSocket
72    ///
73    /// Optional lifecycle hook invoked when a new WebSocket connection is established.
74    /// Default implementation does nothing.
75    fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
76        async {}
77    }
78
79    /// Called when a client disconnects from the WebSocket
80    ///
81    /// Optional lifecycle hook invoked when a WebSocket connection is closed
82    /// (either by the client or due to an error). Default implementation does nothing.
83    fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
84        async {}
85    }
86}
87
88/// WebSocket state shared across connections
89///
90/// Contains the message handler and optional JSON schemas for validating
91/// incoming and outgoing messages. This state is shared among all connections
92/// to the same WebSocket endpoint.
93#[derive(Debug)]
94pub struct WebSocketState<H: WebSocketHandler> {
95    /// The message handler implementation
96    handler: Arc<H>,
97    /// Optional JSON Schema for validating incoming messages
98    message_schema: Option<Arc<jsonschema::Validator>>,
99    /// Optional JSON Schema for validating outgoing responses
100    response_schema: Option<Arc<jsonschema::Validator>>,
101}
102
103impl<H: WebSocketHandler> Clone for WebSocketState<H> {
104    fn clone(&self) -> Self {
105        Self {
106            handler: Arc::clone(&self.handler),
107            message_schema: self.message_schema.clone(),
108            response_schema: self.response_schema.clone(),
109        }
110    }
111}
112
113impl<H: WebSocketHandler + 'static> WebSocketState<H> {
114    /// Create new WebSocket state with a handler
115    ///
116    /// Creates a new state without message or response validation schemas.
117    /// Messages and responses are not validated.
118    ///
119    /// # Arguments
120    /// * `handler` - The message handler implementation
121    ///
122    /// # Example
123    ///
124    /// ```ignore
125    /// let state = WebSocketState::new(MyHandler);
126    /// ```
127    pub fn new(handler: H) -> Self {
128        Self {
129            handler: Arc::new(handler),
130            message_schema: None,
131            response_schema: None,
132        }
133    }
134
135    /// Create new WebSocket state with a handler and optional validation schemas
136    ///
137    /// Creates a new state with optional JSON schemas for validating incoming messages
138    /// and outgoing responses. If a schema is provided and validation fails, the message
139    /// or response is rejected.
140    ///
141    /// # Arguments
142    /// * `handler` - The message handler implementation
143    /// * `message_schema` - Optional JSON schema for validating client messages
144    /// * `response_schema` - Optional JSON schema for validating handler responses
145    ///
146    /// # Returns
147    /// * `Ok(state)` - Successfully created state
148    /// * `Err(msg)` - Invalid schema provided
149    ///
150    /// # Example
151    ///
152    /// ```ignore
153    /// use serde_json::json;
154    ///
155    /// let message_schema = json!({
156    ///     "type": "object",
157    ///     "properties": {
158    ///         "type": {"type": "string"},
159    ///         "data": {"type": "string"}
160    ///     }
161    /// });
162    ///
163    /// let state = WebSocketState::with_schemas(
164    ///     MyHandler,
165    ///     Some(message_schema),
166    ///     None,
167    /// )?;
168    /// ```
169    pub fn with_schemas(
170        handler: H,
171        message_schema: Option<serde_json::Value>,
172        response_schema: Option<serde_json::Value>,
173    ) -> Result<Self, String> {
174        let message_validator = if let Some(schema) = message_schema {
175            Some(Arc::new(
176                jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
177            ))
178        } else {
179            None
180        };
181
182        let response_validator = if let Some(schema) = response_schema {
183            Some(Arc::new(
184                jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
185            ))
186        } else {
187            None
188        };
189
190        Ok(Self {
191            handler: Arc::new(handler),
192            message_schema: message_validator,
193            response_schema: response_validator,
194        })
195    }
196}
197
198/// WebSocket upgrade handler
199///
200/// This is the main entry point for WebSocket connections. Use this as an Axum route
201/// handler by passing it to an Axum router's `.route()` method with `get()`.
202///
203/// # Arguments
204/// * `ws` - WebSocket upgrade from Axum
205/// * `State(state)` - Application state containing the handler and optional schemas
206///
207/// # Returns
208/// An Axum response that upgrades the connection to WebSocket
209///
210/// # Example
211///
212/// ```ignore
213/// use axum::{Router, routing::get, extract::State};
214///
215/// let state = WebSocketState::new(MyHandler);
216/// let router = Router::new()
217///     .route("/ws", get(websocket_handler::<MyHandler>))
218///     .with_state(state);
219/// ```
220pub async fn websocket_handler<H: WebSocketHandler + 'static>(
221    ws: WebSocketUpgrade,
222    State(state): State<WebSocketState<H>>,
223) -> impl IntoResponse {
224    ws.on_upgrade(move |socket| handle_socket(socket, state))
225}
226
227/// Handle an individual WebSocket connection
228async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
229    info!("WebSocket client connected");
230    trace_ws("socket:connected");
231
232    state.handler.on_connect().await;
233    trace_ws("socket:on_connect:done");
234
235    while let Some(msg) = socket.recv().await {
236        match msg {
237            Ok(Message::Text(text)) => {
238                debug!("Received text message: {}", text);
239                trace_ws(&format!("recv:text len={}", text.len()));
240
241                match serde_json::from_str::<Value>(&text) {
242                    Ok(json_msg) => {
243                        trace_ws("recv:text:json-ok");
244                        if let Some(validator) = &state.message_schema
245                            && !validator.is_valid(&json_msg)
246                        {
247                            error!("Message validation failed");
248                            trace_ws("recv:text:validation-failed");
249                            let error_response = serde_json::json!({
250                                "error": "Message validation failed"
251                            });
252                            if let Ok(error_text) = serde_json::to_string(&error_response) {
253                                trace_ws(&format!("send:validation-error len={}", error_text.len()));
254                                let _ = socket.send(Message::Text(error_text.into())).await;
255                            }
256                            continue;
257                        }
258
259                        if let Some(response) = state.handler.handle_message(json_msg).await {
260                            trace_ws("handler:response:some");
261                            if let Some(validator) = &state.response_schema
262                                && !validator.is_valid(&response)
263                            {
264                                error!("Response validation failed");
265                                trace_ws("send:response:validation-failed");
266                                continue;
267                            }
268
269                            let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
270                            let response_len = response_text.len();
271
272                            if let Err(e) = socket.send(Message::Text(response_text.into())).await {
273                                error!("Failed to send response: {}", e);
274                                trace_ws("send:response:error");
275                                break;
276                            }
277                            trace_ws(&format!("send:response len={}", response_len));
278                        } else {
279                            trace_ws("handler:response:none");
280                        }
281                    }
282                    Err(e) => {
283                        warn!("Failed to parse JSON message: {}", e);
284                        trace_ws("recv:text:json-error");
285                        let error_msg = serde_json::json!({
286                            "type": "error",
287                            "message": "Invalid JSON"
288                        });
289                        let error_text = serde_json::to_string(&error_msg).unwrap_or_else(|_| "{}".to_string());
290                        trace_ws(&format!("send:json-error len={}", error_text.len()));
291                        let _ = socket.send(Message::Text(error_text.into())).await;
292                    }
293                }
294            }
295            Ok(Message::Binary(data)) => {
296                debug!("Received binary message: {} bytes", data.len());
297                trace_ws(&format!("recv:binary len={}", data.len()));
298                if let Err(e) = socket.send(Message::Binary(data)).await {
299                    error!("Failed to send binary response: {}", e);
300                    trace_ws("send:binary:error");
301                    break;
302                }
303                trace_ws("send:binary:ok");
304            }
305            Ok(Message::Ping(data)) => {
306                debug!("Received ping");
307                trace_ws(&format!("recv:ping len={}", data.len()));
308                if let Err(e) = socket.send(Message::Pong(data)).await {
309                    error!("Failed to send pong: {}", e);
310                    trace_ws("send:pong:error");
311                    break;
312                }
313                trace_ws("send:pong:ok");
314            }
315            Ok(Message::Pong(_)) => {
316                debug!("Received pong");
317                trace_ws("recv:pong");
318            }
319            Ok(Message::Close(_)) => {
320                info!("Client closed connection");
321                trace_ws("recv:close");
322                break;
323            }
324            Err(e) => {
325                error!("WebSocket error: {}", e);
326                trace_ws(&format!("recv:error {}", e));
327                break;
328            }
329        }
330    }
331
332    state.handler.on_disconnect().await;
333    trace_ws("socket:on_disconnect:done");
334    info!("WebSocket client disconnected");
335}
336
337#[cfg(test)]
338mod tests {
339    use super::*;
340    use std::sync::Mutex;
341    use std::sync::atomic::{AtomicUsize, Ordering};
342
343    #[derive(Debug)]
344    struct EchoHandler;
345
346    impl WebSocketHandler for EchoHandler {
347        async fn handle_message(&self, message: Value) -> Option<Value> {
348            Some(message)
349        }
350    }
351
352    #[derive(Debug)]
353    struct TrackingHandler {
354        connect_count: Arc<AtomicUsize>,
355        disconnect_count: Arc<AtomicUsize>,
356        message_count: Arc<AtomicUsize>,
357        messages: Arc<Mutex<Vec<Value>>>,
358    }
359
360    impl TrackingHandler {
361        fn new() -> Self {
362            Self {
363                connect_count: Arc::new(AtomicUsize::new(0)),
364                disconnect_count: Arc::new(AtomicUsize::new(0)),
365                message_count: Arc::new(AtomicUsize::new(0)),
366                messages: Arc::new(Mutex::new(Vec::new())),
367            }
368        }
369    }
370
371    impl WebSocketHandler for TrackingHandler {
372        async fn handle_message(&self, message: Value) -> Option<Value> {
373            self.message_count.fetch_add(1, Ordering::SeqCst);
374            self.messages.lock().unwrap().push(message.clone());
375            Some(message)
376        }
377
378        async fn on_connect(&self) {
379            self.connect_count.fetch_add(1, Ordering::SeqCst);
380        }
381
382        async fn on_disconnect(&self) {
383            self.disconnect_count.fetch_add(1, Ordering::SeqCst);
384        }
385    }
386
387    #[derive(Debug)]
388    struct SelectiveHandler;
389
390    impl WebSocketHandler for SelectiveHandler {
391        async fn handle_message(&self, message: Value) -> Option<Value> {
392            if message.get("respond").is_some_and(|v| v.as_bool().unwrap_or(false)) {
393                Some(serde_json::json!({"response": "acknowledged"}))
394            } else {
395                None
396            }
397        }
398    }
399
400    #[derive(Debug)]
401    struct TransformHandler;
402
403    impl WebSocketHandler for TransformHandler {
404        async fn handle_message(&self, message: Value) -> Option<Value> {
405            message.as_object().map_or(None, |obj| {
406                let mut resp = obj.clone();
407                resp.insert("processed".to_string(), Value::Bool(true));
408                Some(Value::Object(resp))
409            })
410        }
411    }
412
413    #[test]
414    fn test_websocket_state_creation() {
415        let handler: EchoHandler = EchoHandler;
416        let state: WebSocketState<EchoHandler> = WebSocketState::new(handler);
417        let cloned: WebSocketState<EchoHandler> = state.clone();
418        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
419    }
420
421    #[test]
422    fn test_websocket_state_with_valid_schema() {
423        let handler: EchoHandler = EchoHandler;
424        let schema: serde_json::Value = serde_json::json!({
425            "type": "object",
426            "properties": {
427                "type": {"type": "string"}
428            }
429        });
430
431        let result: Result<WebSocketState<EchoHandler>, String> =
432            WebSocketState::with_schemas(handler, Some(schema), None);
433        assert!(result.is_ok());
434    }
435
436    #[test]
437    fn test_websocket_state_with_invalid_schema() {
438        let handler: EchoHandler = EchoHandler;
439        let invalid_schema: serde_json::Value = serde_json::json!({
440            "type": "not_a_real_type",
441            "invalid": "schema"
442        });
443
444        let result: Result<WebSocketState<EchoHandler>, String> =
445            WebSocketState::with_schemas(handler, Some(invalid_schema), None);
446        assert!(result.is_err());
447        if let Err(error_msg) = result {
448            assert!(error_msg.contains("Invalid message schema"));
449        }
450    }
451
452    #[test]
453    fn test_websocket_state_with_both_schemas() {
454        let handler: EchoHandler = EchoHandler;
455        let message_schema: serde_json::Value = serde_json::json!({
456            "type": "object",
457            "properties": {"action": {"type": "string"}}
458        });
459        let response_schema: serde_json::Value = serde_json::json!({
460            "type": "object",
461            "properties": {"result": {"type": "string"}}
462        });
463
464        let result: Result<WebSocketState<EchoHandler>, String> =
465            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema));
466        assert!(result.is_ok());
467        let state: WebSocketState<EchoHandler> = result.unwrap();
468        assert!(state.message_schema.is_some());
469        assert!(state.response_schema.is_some());
470    }
471
472    #[test]
473    fn test_websocket_state_cloning_preserves_schemas() {
474        let handler: EchoHandler = EchoHandler;
475        let schema: serde_json::Value = serde_json::json!({
476            "type": "object",
477            "properties": {"id": {"type": "integer"}}
478        });
479
480        let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(handler, Some(schema), None).unwrap();
481        let cloned: WebSocketState<EchoHandler> = state.clone();
482
483        assert!(cloned.message_schema.is_some());
484        assert!(cloned.response_schema.is_none());
485        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
486    }
487
488    #[tokio::test]
489    async fn test_tracking_handler_lifecycle() {
490        let handler: TrackingHandler = TrackingHandler::new();
491        handler.on_connect().await;
492        assert_eq!(handler.connect_count.load(Ordering::SeqCst), 1);
493
494        let msg: Value = serde_json::json!({"test": "data"});
495        let _response: Option<Value> = handler.handle_message(msg).await;
496        assert_eq!(handler.message_count.load(Ordering::SeqCst), 1);
497
498        handler.on_disconnect().await;
499        assert_eq!(handler.disconnect_count.load(Ordering::SeqCst), 1);
500    }
501
502    #[tokio::test]
503    async fn test_selective_handler_responds_conditionally() {
504        let handler: SelectiveHandler = SelectiveHandler;
505
506        let respond_msg: Value = serde_json::json!({"respond": true});
507        let response1: Option<Value> = handler.handle_message(respond_msg).await;
508        assert!(response1.is_some());
509        assert_eq!(response1.unwrap(), serde_json::json!({"response": "acknowledged"}));
510
511        let no_respond_msg: Value = serde_json::json!({"respond": false});
512        let response2: Option<Value> = handler.handle_message(no_respond_msg).await;
513        assert!(response2.is_none());
514    }
515
516    #[tokio::test]
517    async fn test_transform_handler_modifies_message() {
518        let handler: TransformHandler = TransformHandler;
519        let original: Value = serde_json::json!({"name": "test"});
520        let transformed: Option<Value> = handler.handle_message(original).await;
521
522        assert!(transformed.is_some());
523        let resp: Value = transformed.unwrap();
524        assert_eq!(resp.get("name").unwrap(), "test");
525        assert_eq!(resp.get("processed").unwrap(), true);
526    }
527
528    #[tokio::test]
529    async fn test_echo_handler_preserves_json_types() {
530        let handler: EchoHandler = EchoHandler;
531
532        let messages: Vec<Value> = vec![
533            serde_json::json!({"string": "value"}),
534            serde_json::json!({"number": 42}),
535            serde_json::json!({"float": 3.14}),
536            serde_json::json!({"bool": true}),
537            serde_json::json!({"null": null}),
538            serde_json::json!({"array": [1, 2, 3]}),
539        ];
540
541        for msg in messages {
542            let response: Option<Value> = handler.handle_message(msg.clone()).await;
543            assert!(response.is_some());
544            assert_eq!(response.unwrap(), msg);
545        }
546    }
547
548    #[tokio::test]
549    async fn test_tracking_handler_accumulates_messages() {
550        let handler: TrackingHandler = TrackingHandler::new();
551
552        let messages: Vec<Value> = vec![
553            serde_json::json!({"id": 1}),
554            serde_json::json!({"id": 2}),
555            serde_json::json!({"id": 3}),
556        ];
557
558        for msg in messages {
559            let _: Option<Value> = handler.handle_message(msg).await;
560        }
561
562        assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
563        let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
564        assert_eq!(stored.len(), 3);
565        assert_eq!(stored[0].get("id").unwrap(), 1);
566        assert_eq!(stored[1].get("id").unwrap(), 2);
567        assert_eq!(stored[2].get("id").unwrap(), 3);
568    }
569
570    #[tokio::test]
571    async fn test_echo_handler_with_nested_json() {
572        let handler: EchoHandler = EchoHandler;
573        let nested: Value = serde_json::json!({
574            "level1": {
575                "level2": {
576                    "level3": {
577                        "value": "deeply nested"
578                    }
579                }
580            }
581        });
582
583        let response: Option<Value> = handler.handle_message(nested.clone()).await;
584        assert!(response.is_some());
585        assert_eq!(response.unwrap(), nested);
586    }
587
588    #[tokio::test]
589    async fn test_echo_handler_with_large_array() {
590        let handler: EchoHandler = EchoHandler;
591        let large_array: Value = serde_json::json!({
592            "items": (0..1000).collect::<Vec<i32>>()
593        });
594
595        let response: Option<Value> = handler.handle_message(large_array.clone()).await;
596        assert!(response.is_some());
597        assert_eq!(response.unwrap(), large_array);
598    }
599
600    #[tokio::test]
601    async fn test_echo_handler_with_unicode() {
602        let handler: EchoHandler = EchoHandler;
603        let unicode_msg: Value = serde_json::json!({
604            "emoji": "🚀",
605            "chinese": "你好",
606            "arabic": "مرحبا",
607            "mixed": "Hello 世界 🌍"
608        });
609
610        let response: Option<Value> = handler.handle_message(unicode_msg.clone()).await;
611        assert!(response.is_some());
612        assert_eq!(response.unwrap(), unicode_msg);
613    }
614
615    #[test]
616    fn test_websocket_state_schemas_are_independent() {
617        let handler: EchoHandler = EchoHandler;
618        let message_schema: serde_json::Value = serde_json::json!({"type": "object"});
619        let response_schema: serde_json::Value = serde_json::json!({"type": "array"});
620
621        let state: WebSocketState<EchoHandler> =
622            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
623
624        let cloned: WebSocketState<EchoHandler> = state.clone();
625
626        assert!(state.message_schema.is_some());
627        assert!(state.response_schema.is_some());
628        assert!(cloned.message_schema.is_some());
629        assert!(cloned.response_schema.is_some());
630    }
631
632    #[test]
633    fn test_message_schema_validation_with_required_field() {
634        let handler: EchoHandler = EchoHandler;
635        let message_schema: serde_json::Value = serde_json::json!({
636            "type": "object",
637            "properties": {"type": {"type": "string"}},
638            "required": ["type"]
639        });
640
641        let state: WebSocketState<EchoHandler> =
642            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
643
644        assert!(state.message_schema.is_some());
645        assert!(state.response_schema.is_none());
646
647        let valid_msg: Value = serde_json::json!({"type": "test"});
648        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
649        assert!(validator.is_valid(&valid_msg));
650
651        let invalid_msg: Value = serde_json::json!({"other": "field"});
652        assert!(!validator.is_valid(&invalid_msg));
653    }
654
655    #[test]
656    fn test_response_schema_validation_with_required_field() {
657        let handler: EchoHandler = EchoHandler;
658        let response_schema: serde_json::Value = serde_json::json!({
659            "type": "object",
660            "properties": {"status": {"type": "string"}},
661            "required": ["status"]
662        });
663
664        let state: WebSocketState<EchoHandler> =
665            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
666
667        assert!(state.message_schema.is_none());
668        assert!(state.response_schema.is_some());
669
670        let valid_response: Value = serde_json::json!({"status": "ok"});
671        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
672        assert!(validator.is_valid(&valid_response));
673
674        let invalid_response: Value = serde_json::json!({"other": "field"});
675        assert!(!validator.is_valid(&invalid_response));
676    }
677
678    #[test]
679    fn test_invalid_message_schema_returns_error() {
680        let handler: EchoHandler = EchoHandler;
681        let invalid_schema: serde_json::Value = serde_json::json!({
682            "type": "invalid_type_value",
683            "properties": {}
684        });
685
686        let result: Result<WebSocketState<EchoHandler>, String> =
687            WebSocketState::with_schemas(handler, Some(invalid_schema), None);
688
689        assert!(result.is_err());
690        match result {
691            Err(error_msg) => assert!(error_msg.contains("Invalid message schema")),
692            Ok(_) => panic!("Expected error but got Ok"),
693        }
694    }
695
696    #[test]
697    fn test_invalid_response_schema_returns_error() {
698        let handler: EchoHandler = EchoHandler;
699        let invalid_schema: serde_json::Value = serde_json::json!({
700            "type": "definitely_not_valid"
701        });
702
703        let result: Result<WebSocketState<EchoHandler>, String> =
704            WebSocketState::with_schemas(handler, None, Some(invalid_schema));
705
706        assert!(result.is_err());
707        match result {
708            Err(error_msg) => assert!(error_msg.contains("Invalid response schema")),
709            Ok(_) => panic!("Expected error but got Ok"),
710        }
711    }
712
713    #[tokio::test]
714    async fn test_handler_returning_none_response() {
715        let handler: SelectiveHandler = SelectiveHandler;
716
717        let no_response_msg: Value = serde_json::json!({"respond": false});
718        let result: Option<Value> = handler.handle_message(no_response_msg).await;
719
720        assert!(result.is_none());
721    }
722
723    #[tokio::test]
724    async fn test_handler_with_complex_schema_validation() {
725        let handler: EchoHandler = EchoHandler;
726        let message_schema: serde_json::Value = serde_json::json!({
727            "type": "object",
728            "properties": {
729                "user": {
730                    "type": "object",
731                    "properties": {
732                        "id": {"type": "integer"},
733                        "name": {"type": "string"}
734                    },
735                    "required": ["id", "name"]
736                },
737                "action": {"type": "string"}
738            },
739            "required": ["user", "action"]
740        });
741
742        let state: WebSocketState<EchoHandler> =
743            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
744
745        let valid_msg: Value = serde_json::json!({
746            "user": {"id": 123, "name": "Alice"},
747            "action": "create"
748        });
749        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
750        assert!(validator.is_valid(&valid_msg));
751
752        let invalid_msg: Value = serde_json::json!({
753            "user": {"id": "not_an_int", "name": "Bob"},
754            "action": "create"
755        });
756        assert!(!validator.is_valid(&invalid_msg));
757    }
758
759    #[tokio::test]
760    async fn test_tracking_handler_with_multiple_message_types() {
761        let handler: TrackingHandler = TrackingHandler::new();
762
763        let messages: Vec<Value> = vec![
764            serde_json::json!({"type": "text", "content": "hello"}),
765            serde_json::json!({"type": "image", "url": "http://example.com/image.png"}),
766            serde_json::json!({"type": "video", "duration": 120}),
767        ];
768
769        for msg in messages {
770            let _: Option<Value> = handler.handle_message(msg).await;
771        }
772
773        assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
774        let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
775        assert_eq!(stored.len(), 3);
776        assert_eq!(stored[0].get("type").unwrap(), "text");
777        assert_eq!(stored[1].get("type").unwrap(), "image");
778        assert_eq!(stored[2].get("type").unwrap(), "video");
779    }
780
781    #[tokio::test]
782    async fn test_selective_handler_with_explicit_false() {
783        let handler: SelectiveHandler = SelectiveHandler;
784
785        let msg: Value = serde_json::json!({"respond": false, "data": "test"});
786        let response: Option<Value> = handler.handle_message(msg).await;
787
788        assert!(response.is_none());
789    }
790
791    #[tokio::test]
792    async fn test_selective_handler_without_respond_field() {
793        let handler: SelectiveHandler = SelectiveHandler;
794
795        let msg: Value = serde_json::json!({"data": "test"});
796        let response: Option<Value> = handler.handle_message(msg).await;
797
798        assert!(response.is_none());
799    }
800
801    #[tokio::test]
802    async fn test_transform_handler_with_empty_object() {
803        let handler: TransformHandler = TransformHandler;
804        let original: Value = serde_json::json!({});
805        let transformed: Option<Value> = handler.handle_message(original).await;
806
807        assert!(transformed.is_some());
808        let resp: Value = transformed.unwrap();
809        assert_eq!(resp.get("processed").unwrap(), true);
810        assert_eq!(resp.as_object().unwrap().len(), 1);
811    }
812
813    #[tokio::test]
814    async fn test_transform_handler_preserves_all_fields() {
815        let handler: TransformHandler = TransformHandler;
816        let original: Value = serde_json::json!({
817            "field1": "value1",
818            "field2": 42,
819            "field3": true,
820            "nested": {"key": "value"}
821        });
822        let transformed: Option<Value> = handler.handle_message(original.clone()).await;
823
824        assert!(transformed.is_some());
825        let resp: Value = transformed.unwrap();
826        assert_eq!(resp.get("field1").unwrap(), "value1");
827        assert_eq!(resp.get("field2").unwrap(), 42);
828        assert_eq!(resp.get("field3").unwrap(), true);
829        assert_eq!(resp.get("nested").unwrap(), &serde_json::json!({"key": "value"}));
830        assert_eq!(resp.get("processed").unwrap(), true);
831    }
832
833    #[tokio::test]
834    async fn test_transform_handler_with_non_object_input() {
835        let handler: TransformHandler = TransformHandler;
836
837        let array: Value = serde_json::json!([1, 2, 3]);
838        let response1: Option<Value> = handler.handle_message(array).await;
839        assert!(response1.is_none());
840
841        let string: Value = serde_json::json!("not an object");
842        let response2: Option<Value> = handler.handle_message(string).await;
843        assert!(response2.is_none());
844
845        let number: Value = serde_json::json!(42);
846        let response3: Option<Value> = handler.handle_message(number).await;
847        assert!(response3.is_none());
848    }
849
850    /// Test message validation failure with schema constraint
851    #[test]
852    fn test_message_schema_rejects_wrong_type() {
853        let handler: EchoHandler = EchoHandler;
854        let message_schema: serde_json::Value = serde_json::json!({
855            "type": "object",
856            "properties": {"id": {"type": "integer"}},
857            "required": ["id"]
858        });
859
860        let state: WebSocketState<EchoHandler> =
861            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
862
863        let invalid_msg: Value = serde_json::json!({"id": "not_an_integer"});
864        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
865        assert!(!validator.is_valid(&invalid_msg));
866    }
867
868    /// Test response schema validation failure
869    #[test]
870    fn test_response_schema_rejects_invalid_type() {
871        let handler: EchoHandler = EchoHandler;
872        let response_schema: serde_json::Value = serde_json::json!({
873            "type": "object",
874            "properties": {"count": {"type": "integer"}},
875            "required": ["count"]
876        });
877
878        let state: WebSocketState<EchoHandler> =
879            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
880
881        let invalid_response: Value = serde_json::json!([1, 2, 3]);
882        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
883        assert!(!validator.is_valid(&invalid_response));
884    }
885
886    /// Test message with multiple required fields missing
887    #[test]
888    fn test_message_missing_multiple_required_fields() {
889        let handler: EchoHandler = EchoHandler;
890        let message_schema: serde_json::Value = serde_json::json!({
891            "type": "object",
892            "properties": {
893                "user_id": {"type": "integer"},
894                "action": {"type": "string"},
895                "timestamp": {"type": "string"}
896            },
897            "required": ["user_id", "action", "timestamp"]
898        });
899
900        let state: WebSocketState<EchoHandler> =
901            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
902
903        let invalid_msg: Value = serde_json::json!({"other": "value"});
904        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
905        assert!(!validator.is_valid(&invalid_msg));
906
907        let partial_msg: Value = serde_json::json!({"user_id": 123});
908        assert!(!validator.is_valid(&partial_msg));
909    }
910
911    /// Test deeply nested schema validation with required nested properties
912    #[test]
913    fn test_deeply_nested_schema_validation_failure() {
914        let handler: EchoHandler = EchoHandler;
915        let message_schema: serde_json::Value = serde_json::json!({
916            "type": "object",
917            "properties": {
918                "metadata": {
919                    "type": "object",
920                    "properties": {
921                        "request": {
922                            "type": "object",
923                            "properties": {
924                                "id": {"type": "string"}
925                            },
926                            "required": ["id"]
927                        }
928                    },
929                    "required": ["request"]
930                }
931            },
932            "required": ["metadata"]
933        });
934
935        let state: WebSocketState<EchoHandler> =
936            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
937
938        let invalid_msg: Value = serde_json::json!({
939            "metadata": {
940                "request": {}
941            }
942        });
943        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
944        assert!(!validator.is_valid(&invalid_msg));
945    }
946
947    /// Test array property validation with items constraint
948    #[test]
949    fn test_array_property_type_validation() {
950        let handler: EchoHandler = EchoHandler;
951        let message_schema: serde_json::Value = serde_json::json!({
952            "type": "object",
953            "properties": {
954                "ids": {
955                    "type": "array",
956                    "items": {"type": "integer"}
957                }
958            }
959        });
960
961        let state: WebSocketState<EchoHandler> =
962            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
963
964        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
965
966        let valid_msg: Value = serde_json::json!({"ids": [1, 2, 3]});
967        assert!(validator.is_valid(&valid_msg));
968
969        let invalid_msg: Value = serde_json::json!({"ids": [1, "two", 3]});
970        assert!(!validator.is_valid(&invalid_msg));
971
972        let invalid_msg2: Value = serde_json::json!({"ids": "not_an_array"});
973        assert!(!validator.is_valid(&invalid_msg2));
974    }
975
976    /// Test enum/const property validation
977    #[test]
978    fn test_enum_property_validation() {
979        let handler: EchoHandler = EchoHandler;
980        let message_schema: serde_json::Value = serde_json::json!({
981            "type": "object",
982            "properties": {
983                "status": {
984                    "type": "string",
985                    "enum": ["pending", "active", "completed"]
986                }
987            },
988            "required": ["status"]
989        });
990
991        let state: WebSocketState<EchoHandler> =
992            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
993
994        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
995
996        let valid_msg: Value = serde_json::json!({"status": "active"});
997        assert!(validator.is_valid(&valid_msg));
998
999        let invalid_msg: Value = serde_json::json!({"status": "unknown"});
1000        assert!(!validator.is_valid(&invalid_msg));
1001    }
1002
1003    /// Test minimum/maximum constraints on numbers
1004    #[test]
1005    fn test_number_range_validation() {
1006        let handler: EchoHandler = EchoHandler;
1007        let message_schema: serde_json::Value = serde_json::json!({
1008            "type": "object",
1009            "properties": {
1010                "age": {
1011                    "type": "integer",
1012                    "minimum": 0,
1013                    "maximum": 150
1014                }
1015            },
1016            "required": ["age"]
1017        });
1018
1019        let state: WebSocketState<EchoHandler> =
1020            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1021
1022        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1023
1024        let valid_msg: Value = serde_json::json!({"age": 25});
1025        assert!(validator.is_valid(&valid_msg));
1026
1027        let invalid_msg: Value = serde_json::json!({"age": -1});
1028        assert!(!validator.is_valid(&invalid_msg));
1029
1030        let invalid_msg2: Value = serde_json::json!({"age": 200});
1031        assert!(!validator.is_valid(&invalid_msg2));
1032    }
1033
1034    /// Test string length constraints
1035    #[test]
1036    fn test_string_length_validation() {
1037        let handler: EchoHandler = EchoHandler;
1038        let message_schema: serde_json::Value = serde_json::json!({
1039            "type": "object",
1040            "properties": {
1041                "username": {
1042                    "type": "string",
1043                    "minLength": 3,
1044                    "maxLength": 20
1045                }
1046            },
1047            "required": ["username"]
1048        });
1049
1050        let state: WebSocketState<EchoHandler> =
1051            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1052
1053        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1054
1055        let valid_msg: Value = serde_json::json!({"username": "alice"});
1056        assert!(validator.is_valid(&valid_msg));
1057
1058        let invalid_msg: Value = serde_json::json!({"username": "ab"});
1059        assert!(!validator.is_valid(&invalid_msg));
1060
1061        let invalid_msg2: Value =
1062            serde_json::json!({"username": "this_is_a_very_long_username_over_twenty_characters"});
1063        assert!(!validator.is_valid(&invalid_msg2));
1064    }
1065
1066    /// Test pattern (regex) validation
1067    #[test]
1068    fn test_pattern_validation() {
1069        let handler: EchoHandler = EchoHandler;
1070        let message_schema: serde_json::Value = serde_json::json!({
1071            "type": "object",
1072            "properties": {
1073                "email": {
1074                    "type": "string",
1075                    "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
1076                }
1077            },
1078            "required": ["email"]
1079        });
1080
1081        let state: WebSocketState<EchoHandler> =
1082            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1083
1084        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1085
1086        let valid_msg: Value = serde_json::json!({"email": "user@example.com"});
1087        assert!(validator.is_valid(&valid_msg));
1088
1089        let invalid_msg: Value = serde_json::json!({"email": "user@example"});
1090        assert!(!validator.is_valid(&invalid_msg));
1091
1092        let invalid_msg2: Value = serde_json::json!({"email": "userexample.com"});
1093        assert!(!validator.is_valid(&invalid_msg2));
1094    }
1095
1096    /// Test additionalProperties constraint
1097    #[test]
1098    fn test_additional_properties_validation() {
1099        let handler: EchoHandler = EchoHandler;
1100        let message_schema: serde_json::Value = serde_json::json!({
1101            "type": "object",
1102            "properties": {
1103                "name": {"type": "string"}
1104            },
1105            "additionalProperties": false
1106        });
1107
1108        let state: WebSocketState<EchoHandler> =
1109            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1110
1111        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1112
1113        let valid_msg: Value = serde_json::json!({"name": "Alice"});
1114        assert!(validator.is_valid(&valid_msg));
1115
1116        let invalid_msg: Value = serde_json::json!({"name": "Bob", "age": 30});
1117        assert!(!validator.is_valid(&invalid_msg));
1118    }
1119
1120    /// Test oneOf constraint (mutually exclusive properties)
1121    #[test]
1122    fn test_one_of_constraint() {
1123        let handler: EchoHandler = EchoHandler;
1124        let message_schema: serde_json::Value = serde_json::json!({
1125            "type": "object",
1126            "oneOf": [
1127                {
1128                    "properties": {"type": {"const": "text"}},
1129                    "required": ["type"]
1130                },
1131                {
1132                    "properties": {"type": {"const": "number"}},
1133                    "required": ["type"]
1134                }
1135            ]
1136        });
1137
1138        let state: WebSocketState<EchoHandler> =
1139            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1140
1141        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1142
1143        let valid_msg: Value = serde_json::json!({"type": "text"});
1144        assert!(validator.is_valid(&valid_msg));
1145
1146        let invalid_msg: Value = serde_json::json!({"type": "unknown"});
1147        assert!(!validator.is_valid(&invalid_msg));
1148    }
1149
1150    /// Test anyOf constraint (at least one match)
1151    #[test]
1152    fn test_any_of_constraint() {
1153        let handler: EchoHandler = EchoHandler;
1154        let message_schema: serde_json::Value = serde_json::json!({
1155            "type": "object",
1156            "properties": {
1157                "value": {"type": ["string", "integer"]}
1158            },
1159            "required": ["value"]
1160        });
1161
1162        let state: WebSocketState<EchoHandler> =
1163            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1164
1165        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1166
1167        let msg1: Value = serde_json::json!({"value": "text"});
1168        assert!(validator.is_valid(&msg1));
1169
1170        let msg2: Value = serde_json::json!({"value": 42});
1171        assert!(validator.is_valid(&msg2));
1172
1173        let invalid_msg: Value = serde_json::json!({"value": true});
1174        assert!(!validator.is_valid(&invalid_msg));
1175    }
1176
1177    /// Test response validation with complex constraints
1178    #[test]
1179    fn test_response_schema_with_multiple_constraints() {
1180        let handler: EchoHandler = EchoHandler;
1181        let response_schema: serde_json::Value = serde_json::json!({
1182            "type": "object",
1183            "properties": {
1184                "success": {"type": "boolean"},
1185                "data": {
1186                    "type": "object",
1187                    "properties": {
1188                        "items": {
1189                            "type": "array",
1190                            "items": {"type": "object"},
1191                            "minItems": 1
1192                        }
1193                    },
1194                    "required": ["items"]
1195                }
1196            },
1197            "required": ["success", "data"]
1198        });
1199
1200        let state: WebSocketState<EchoHandler> =
1201            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
1202
1203        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1204
1205        let valid_response: Value = serde_json::json!({
1206            "success": true,
1207            "data": {
1208                "items": [{"id": 1}]
1209            }
1210        });
1211        assert!(validator.is_valid(&valid_response));
1212
1213        let invalid_response: Value = serde_json::json!({
1214            "success": true,
1215            "data": {
1216                "items": []
1217            }
1218        });
1219        assert!(!validator.is_valid(&invalid_response));
1220
1221        let invalid_response2: Value = serde_json::json!({
1222            "success": true
1223        });
1224        assert!(!validator.is_valid(&invalid_response2));
1225    }
1226
1227    /// Test null type validation
1228    #[test]
1229    fn test_null_value_validation() {
1230        let handler: EchoHandler = EchoHandler;
1231        let message_schema: serde_json::Value = serde_json::json!({
1232            "type": "object",
1233            "properties": {
1234                "optional_field": {"type": ["string", "null"]},
1235                "required_field": {"type": "string"}
1236            },
1237            "required": ["required_field"]
1238        });
1239
1240        let state: WebSocketState<EchoHandler> =
1241            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1242
1243        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1244
1245        let msg1: Value = serde_json::json!({
1246            "optional_field": null,
1247            "required_field": "value"
1248        });
1249        assert!(validator.is_valid(&msg1));
1250
1251        let msg2: Value = serde_json::json!({"required_field": "value"});
1252        assert!(validator.is_valid(&msg2));
1253
1254        let invalid_msg: Value = serde_json::json!({"required_field": null});
1255        assert!(!validator.is_valid(&invalid_msg));
1256    }
1257
1258    /// Test schema with default values (they don't change validation)
1259    #[test]
1260    fn test_schema_with_defaults_still_validates() {
1261        let handler: EchoHandler = EchoHandler;
1262        let message_schema: serde_json::Value = serde_json::json!({
1263            "type": "object",
1264            "properties": {
1265                "status": {
1266                    "type": "string",
1267                    "default": "pending"
1268                }
1269            }
1270        });
1271
1272        let state: WebSocketState<EchoHandler> =
1273            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1274
1275        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1276
1277        let msg: Value = serde_json::json!({});
1278        assert!(validator.is_valid(&msg));
1279    }
1280
1281    /// Test both message and response schema validation together
1282    #[test]
1283    fn test_both_schemas_validate_independently() {
1284        let handler: EchoHandler = EchoHandler;
1285        let message_schema: serde_json::Value = serde_json::json!({
1286            "type": "object",
1287            "properties": {"action": {"type": "string"}},
1288            "required": ["action"]
1289        });
1290        let response_schema: serde_json::Value = serde_json::json!({
1291            "type": "object",
1292            "properties": {"result": {"type": "string"}},
1293            "required": ["result"]
1294        });
1295
1296        let state: WebSocketState<EchoHandler> =
1297            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
1298
1299        let msg_validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1300        let resp_validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1301
1302        let valid_msg: Value = serde_json::json!({"action": "test"});
1303        let invalid_response: Value = serde_json::json!({"data": "oops"});
1304
1305        assert!(msg_validator.is_valid(&valid_msg));
1306        assert!(!resp_validator.is_valid(&invalid_response));
1307
1308        let invalid_msg: Value = serde_json::json!({"data": "oops"});
1309        let valid_response: Value = serde_json::json!({"result": "ok"});
1310
1311        assert!(!msg_validator.is_valid(&invalid_msg));
1312        assert!(resp_validator.is_valid(&valid_response));
1313    }
1314
1315    /// Test validation with very long/large payload
1316    #[test]
1317    fn test_validation_with_large_payload() {
1318        let handler: EchoHandler = EchoHandler;
1319        let message_schema: serde_json::Value = serde_json::json!({
1320            "type": "object",
1321            "properties": {
1322                "items": {
1323                    "type": "array",
1324                    "items": {"type": "integer"}
1325                }
1326            },
1327            "required": ["items"]
1328        });
1329
1330        let state: WebSocketState<EchoHandler> =
1331            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1332
1333        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1334
1335        let mut items = Vec::new();
1336        for i in 0..10_000 {
1337            items.push(i);
1338        }
1339        let large_msg: Value = serde_json::json!({"items": items});
1340
1341        assert!(validator.is_valid(&large_msg));
1342    }
1343
1344    /// Test validation error doesn't panic with invalid schema combinations
1345    #[test]
1346    fn test_mutually_exclusive_schema_properties() {
1347        let handler: EchoHandler = EchoHandler;
1348
1349        let message_schema: serde_json::Value = serde_json::json!({
1350            "allOf": [
1351                {
1352                    "type": "object",
1353                    "properties": {"a": {"type": "string"}},
1354                    "required": ["a"]
1355                },
1356                {
1357                    "type": "object",
1358                    "properties": {"b": {"type": "integer"}},
1359                    "required": ["b"]
1360                }
1361            ]
1362        });
1363
1364        let state: WebSocketState<EchoHandler> =
1365            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1366
1367        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1368
1369        let valid_msg: Value = serde_json::json!({"a": "text", "b": 42});
1370        assert!(validator.is_valid(&valid_msg));
1371
1372        let invalid_msg: Value = serde_json::json!({"a": "text"});
1373        assert!(!validator.is_valid(&invalid_msg));
1374    }
1375}