1use axum::{
6 extract::{
7 State,
8 ws::{Message, WebSocket, WebSocketUpgrade},
9 },
10 response::IntoResponse,
11};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
/// Prints `message` to stderr, prefixed with `[spikard-ws]`, but only when the environment
/// variable `SPIKARD_WS_TRACE` is set to exactly `"1"`; otherwise it is a no-op.
///
/// The variable is looked up on every call, so tracing can be switched on or off while the
/// process is running.
fn trace_ws(message: &str) {
    let enabled = matches!(std::env::var("SPIKARD_WS_TRACE").as_deref(), Ok("1"));
    if !enabled {
        return;
    }
    eprintln!("[spikard-ws] {message}");
}
21
22pub trait WebSocketHandler: Send + Sync {
58 fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
70
71 fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
76 async {}
77 }
78
79 fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
84 async {}
85 }
86}
87
88pub struct WebSocketState<H: WebSocketHandler> {
94 handler: Arc<H>,
96 handler_factory: Arc<dyn Fn() -> Result<Arc<H>, String> + Send + Sync>,
98 message_schema: Option<Arc<jsonschema::Validator>>,
100 response_schema: Option<Arc<jsonschema::Validator>>,
102}
103
104impl<H: WebSocketHandler> std::fmt::Debug for WebSocketState<H> {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("WebSocketState")
107 .field("message_schema", &self.message_schema.is_some())
108 .field("response_schema", &self.response_schema.is_some())
109 .finish()
110 }
111}
112
113impl<H: WebSocketHandler> Clone for WebSocketState<H> {
114 fn clone(&self) -> Self {
115 Self {
116 handler: Arc::clone(&self.handler),
117 handler_factory: Arc::clone(&self.handler_factory),
118 message_schema: self.message_schema.clone(),
119 response_schema: self.response_schema.clone(),
120 }
121 }
122}
123
124impl<H: WebSocketHandler + 'static> WebSocketState<H> {
125 pub fn new(handler: H) -> Self {
139 let handler = Arc::new(handler);
140 Self {
141 handler_factory: Arc::new({
142 let handler = Arc::clone(&handler);
143 move || Ok(Arc::clone(&handler))
144 }),
145 handler,
146 message_schema: None,
147 response_schema: None,
148 }
149 }
150
151 pub fn with_schemas(
186 handler: H,
187 message_schema: Option<serde_json::Value>,
188 response_schema: Option<serde_json::Value>,
189 ) -> Result<Self, String> {
190 let message_validator = if let Some(schema) = message_schema {
191 Some(Arc::new(
192 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
193 ))
194 } else {
195 None
196 };
197
198 let response_validator = if let Some(schema) = response_schema {
199 Some(Arc::new(
200 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
201 ))
202 } else {
203 None
204 };
205
206 let handler = Arc::new(handler);
207 Ok(Self {
208 handler_factory: Arc::new({
209 let handler = Arc::clone(&handler);
210 move || Ok(Arc::clone(&handler))
211 }),
212 handler,
213 message_schema: message_validator,
214 response_schema: response_validator,
215 })
216 }
217
218 pub fn with_factory<F>(
222 factory: F,
223 message_schema: Option<serde_json::Value>,
224 response_schema: Option<serde_json::Value>,
225 ) -> Result<Self, String>
226 where
227 F: Fn() -> Result<H, String> + Send + Sync + 'static,
228 {
229 let message_validator = if let Some(schema) = message_schema {
230 Some(Arc::new(
231 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
232 ))
233 } else {
234 None
235 };
236
237 let response_validator = if let Some(schema) = response_schema {
238 Some(Arc::new(
239 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
240 ))
241 } else {
242 None
243 };
244
245 let factory = Arc::new(factory);
246 let handler = factory()
247 .map(Arc::new)
248 .map_err(|e| format!("Failed to build WebSocket handler: {}", e))?;
249
250 Ok(Self {
251 handler_factory: Arc::new({
252 let factory = Arc::clone(&factory);
253 move || factory().map(Arc::new)
254 }),
255 handler,
256 message_schema: message_validator,
257 response_schema: response_validator,
258 })
259 }
260
261 pub async fn on_connect(&self) {
263 self.handler.on_connect().await;
264 }
265
266 pub async fn on_disconnect(&self) {
268 self.handler.on_disconnect().await;
269 }
270
271 pub async fn handle_message_validated(&self, message: Value) -> Result<Option<Value>, String> {
273 if let Some(validator) = &self.message_schema
274 && !validator.is_valid(&message)
275 {
276 return Err("Message validation failed".to_string());
277 }
278
279 let response = self.handler.handle_message(message).await;
280 if let Some(ref value) = response
281 && let Some(validator) = &self.response_schema
282 && !validator.is_valid(value)
283 {
284 return Ok(None);
285 }
286
287 Ok(response)
288 }
289}
290
291pub async fn websocket_handler<H: WebSocketHandler + 'static>(
314 ws: WebSocketUpgrade,
315 State(state): State<WebSocketState<H>>,
316) -> impl IntoResponse {
317 ws.on_upgrade(move |socket| handle_socket(socket, state))
318}
319
320async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
322 info!("WebSocket client connected");
323 trace_ws("socket:connected");
324
325 let handler = match (state.handler_factory)() {
326 Ok(handler) => handler,
327 Err(err) => {
328 error!("Failed to create WebSocket handler: {}", err);
329 trace_ws("socket:handler-factory:error");
330 return;
331 }
332 };
333
334 handler.on_connect().await;
335 trace_ws("socket:on_connect:done");
336
337 while let Some(msg) = socket.recv().await {
338 match msg {
339 Ok(Message::Text(text)) => {
340 debug!("Received text message: {}", text);
341 trace_ws(&format!("recv:text len={}", text.len()));
342
343 match serde_json::from_str::<Value>(&text) {
344 Ok(json_msg) => {
345 trace_ws("recv:text:json-ok");
346 if let Some(validator) = &state.message_schema
347 && !validator.is_valid(&json_msg)
348 {
349 error!("Message validation failed");
350 trace_ws("recv:text:validation-failed");
351 let error_response = serde_json::json!({
352 "error": "Message validation failed"
353 });
354 if let Ok(error_text) = serde_json::to_string(&error_response) {
355 trace_ws(&format!("send:validation-error len={}", error_text.len()));
356 let _ = socket.send(Message::Text(error_text.into())).await;
357 }
358 continue;
359 }
360
361 if let Some(response) = handler.handle_message(json_msg).await {
362 trace_ws("handler:response:some");
363 if let Some(validator) = &state.response_schema
364 && !validator.is_valid(&response)
365 {
366 error!("Response validation failed");
367 trace_ws("send:response:validation-failed");
368 continue;
369 }
370
371 let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
372 let response_len = response_text.len();
373
374 if let Err(e) = socket.send(Message::Text(response_text.into())).await {
375 error!("Failed to send response: {}", e);
376 trace_ws("send:response:error");
377 break;
378 }
379 trace_ws(&format!("send:response len={}", response_len));
380 } else {
381 trace_ws("handler:response:none");
382 }
383 }
384 Err(e) => {
385 warn!("Failed to parse JSON message: {}", e);
386 trace_ws("recv:text:json-error");
387 let error_msg = serde_json::json!({
388 "type": "error",
389 "message": "Invalid JSON"
390 });
391 let error_text = serde_json::to_string(&error_msg).unwrap_or_else(|_| "{}".to_string());
392 trace_ws(&format!("send:json-error len={}", error_text.len()));
393 let _ = socket.send(Message::Text(error_text.into())).await;
394 }
395 }
396 }
397 Ok(Message::Binary(data)) => {
398 debug!("Received binary message: {} bytes", data.len());
399 trace_ws(&format!("recv:binary len={}", data.len()));
400 if let Err(e) = socket.send(Message::Binary(data)).await {
401 error!("Failed to send binary response: {}", e);
402 trace_ws("send:binary:error");
403 break;
404 }
405 trace_ws("send:binary:ok");
406 }
407 Ok(Message::Ping(data)) => {
408 debug!("Received ping");
409 trace_ws(&format!("recv:ping len={}", data.len()));
410 if let Err(e) = socket.send(Message::Pong(data)).await {
411 error!("Failed to send pong: {}", e);
412 trace_ws("send:pong:error");
413 break;
414 }
415 trace_ws("send:pong:ok");
416 }
417 Ok(Message::Pong(_)) => {
418 debug!("Received pong");
419 trace_ws("recv:pong");
420 }
421 Ok(Message::Close(close_frame)) => {
422 let code: u16 = close_frame.as_ref().map(|f| f.code).unwrap_or(1005);
423 let reason = close_frame.as_ref().map(|f| f.reason.as_str()).unwrap_or("");
424 info!("Client closed connection: code={} reason={:?}", code, reason);
425 trace_ws(&format!("recv:close code={} reason={:?}", code, reason));
426 break;
427 }
428 Err(e) => {
429 error!("WebSocket error: {}", e);
430 trace_ws(&format!("recv:error {}", e));
431 break;
432 }
433 }
434 }
435
436 handler.on_disconnect().await;
437 trace_ws("socket:on_disconnect:done");
438 info!("WebSocket client disconnected");
439}
440
441#[cfg(test)]
442mod tests {
443 use super::*;
444 use std::sync::Mutex;
445 use std::sync::atomic::{AtomicUsize, Ordering};
446
447 #[derive(Debug)]
448 struct EchoHandler;
449
450 impl WebSocketHandler for EchoHandler {
451 async fn handle_message(&self, message: Value) -> Option<Value> {
452 Some(message)
453 }
454 }
455
456 #[derive(Debug)]
457 struct TrackingHandler {
458 connect_count: Arc<AtomicUsize>,
459 disconnect_count: Arc<AtomicUsize>,
460 message_count: Arc<AtomicUsize>,
461 messages: Arc<Mutex<Vec<Value>>>,
462 }
463
464 impl TrackingHandler {
465 fn new() -> Self {
466 Self {
467 connect_count: Arc::new(AtomicUsize::new(0)),
468 disconnect_count: Arc::new(AtomicUsize::new(0)),
469 message_count: Arc::new(AtomicUsize::new(0)),
470 messages: Arc::new(Mutex::new(Vec::new())),
471 }
472 }
473 }
474
475 impl WebSocketHandler for TrackingHandler {
476 async fn handle_message(&self, message: Value) -> Option<Value> {
477 self.message_count.fetch_add(1, Ordering::SeqCst);
478 self.messages.lock().unwrap().push(message.clone());
479 Some(message)
480 }
481
482 async fn on_connect(&self) {
483 self.connect_count.fetch_add(1, Ordering::SeqCst);
484 }
485
486 async fn on_disconnect(&self) {
487 self.disconnect_count.fetch_add(1, Ordering::SeqCst);
488 }
489 }
490
491 #[derive(Debug)]
492 struct SelectiveHandler;
493
494 impl WebSocketHandler for SelectiveHandler {
495 async fn handle_message(&self, message: Value) -> Option<Value> {
496 if message.get("respond").is_some_and(|v| v.as_bool().unwrap_or(false)) {
497 Some(serde_json::json!({"response": "acknowledged"}))
498 } else {
499 None
500 }
501 }
502 }
503
504 #[derive(Debug)]
505 struct TransformHandler;
506
507 impl WebSocketHandler for TransformHandler {
508 async fn handle_message(&self, message: Value) -> Option<Value> {
509 message.as_object().map_or(None, |obj| {
510 let mut resp = obj.clone();
511 resp.insert("processed".to_string(), Value::Bool(true));
512 Some(Value::Object(resp))
513 })
514 }
515 }
516
517 #[test]
518 fn test_websocket_state_creation() {
519 let handler: EchoHandler = EchoHandler;
520 let state: WebSocketState<EchoHandler> = WebSocketState::new(handler);
521 let cloned: WebSocketState<EchoHandler> = state.clone();
522 assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
523 }
524
525 #[test]
526 fn test_websocket_state_with_valid_schema() {
527 let handler: EchoHandler = EchoHandler;
528 let schema: serde_json::Value = serde_json::json!({
529 "type": "object",
530 "properties": {
531 "type": {"type": "string"}
532 }
533 });
534
535 let result: Result<WebSocketState<EchoHandler>, String> =
536 WebSocketState::with_schemas(handler, Some(schema), None);
537 assert!(result.is_ok());
538 }
539
540 #[test]
541 fn test_websocket_state_with_invalid_schema() {
542 let handler: EchoHandler = EchoHandler;
543 let invalid_schema: serde_json::Value = serde_json::json!({
544 "type": "not_a_real_type",
545 "invalid": "schema"
546 });
547
548 let result: Result<WebSocketState<EchoHandler>, String> =
549 WebSocketState::with_schemas(handler, Some(invalid_schema), None);
550 assert!(result.is_err());
551 if let Err(error_msg) = result {
552 assert!(error_msg.contains("Invalid message schema"));
553 }
554 }
555
556 #[test]
557 fn test_websocket_state_with_both_schemas() {
558 let handler: EchoHandler = EchoHandler;
559 let message_schema: serde_json::Value = serde_json::json!({
560 "type": "object",
561 "properties": {"action": {"type": "string"}}
562 });
563 let response_schema: serde_json::Value = serde_json::json!({
564 "type": "object",
565 "properties": {"result": {"type": "string"}}
566 });
567
568 let result: Result<WebSocketState<EchoHandler>, String> =
569 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema));
570 assert!(result.is_ok());
571 let state: WebSocketState<EchoHandler> = result.unwrap();
572 assert!(state.message_schema.is_some());
573 assert!(state.response_schema.is_some());
574 }
575
576 #[test]
577 fn test_websocket_state_cloning_preserves_schemas() {
578 let handler: EchoHandler = EchoHandler;
579 let schema: serde_json::Value = serde_json::json!({
580 "type": "object",
581 "properties": {"id": {"type": "integer"}}
582 });
583
584 let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(handler, Some(schema), None).unwrap();
585 let cloned: WebSocketState<EchoHandler> = state.clone();
586
587 assert!(cloned.message_schema.is_some());
588 assert!(cloned.response_schema.is_none());
589 assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
590 }
591
592 #[tokio::test]
593 async fn test_tracking_handler_lifecycle() {
594 let handler: TrackingHandler = TrackingHandler::new();
595 handler.on_connect().await;
596 assert_eq!(handler.connect_count.load(Ordering::SeqCst), 1);
597
598 let msg: Value = serde_json::json!({"test": "data"});
599 let _response: Option<Value> = handler.handle_message(msg).await;
600 assert_eq!(handler.message_count.load(Ordering::SeqCst), 1);
601
602 handler.on_disconnect().await;
603 assert_eq!(handler.disconnect_count.load(Ordering::SeqCst), 1);
604 }
605
606 #[tokio::test]
607 async fn test_selective_handler_responds_conditionally() {
608 let handler: SelectiveHandler = SelectiveHandler;
609
610 let respond_msg: Value = serde_json::json!({"respond": true});
611 let response1: Option<Value> = handler.handle_message(respond_msg).await;
612 assert!(response1.is_some());
613 assert_eq!(response1.unwrap(), serde_json::json!({"response": "acknowledged"}));
614
615 let no_respond_msg: Value = serde_json::json!({"respond": false});
616 let response2: Option<Value> = handler.handle_message(no_respond_msg).await;
617 assert!(response2.is_none());
618 }
619
620 #[tokio::test]
621 async fn test_transform_handler_modifies_message() {
622 let handler: TransformHandler = TransformHandler;
623 let original: Value = serde_json::json!({"name": "test"});
624 let transformed: Option<Value> = handler.handle_message(original).await;
625
626 assert!(transformed.is_some());
627 let resp: Value = transformed.unwrap();
628 assert_eq!(resp.get("name").unwrap(), "test");
629 assert_eq!(resp.get("processed").unwrap(), true);
630 }
631
632 #[tokio::test]
633 async fn test_echo_handler_preserves_json_types() {
634 let handler: EchoHandler = EchoHandler;
635
636 let messages: Vec<Value> = vec![
637 serde_json::json!({"string": "value"}),
638 serde_json::json!({"number": 42}),
639 serde_json::json!({"float": 3.14}),
640 serde_json::json!({"bool": true}),
641 serde_json::json!({"null": null}),
642 serde_json::json!({"array": [1, 2, 3]}),
643 ];
644
645 for msg in messages {
646 let response: Option<Value> = handler.handle_message(msg.clone()).await;
647 assert!(response.is_some());
648 assert_eq!(response.unwrap(), msg);
649 }
650 }
651
652 #[tokio::test]
653 async fn test_tracking_handler_accumulates_messages() {
654 let handler: TrackingHandler = TrackingHandler::new();
655
656 let messages: Vec<Value> = vec![
657 serde_json::json!({"id": 1}),
658 serde_json::json!({"id": 2}),
659 serde_json::json!({"id": 3}),
660 ];
661
662 for msg in messages {
663 let _: Option<Value> = handler.handle_message(msg).await;
664 }
665
666 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
667 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
668 assert_eq!(stored.len(), 3);
669 assert_eq!(stored[0].get("id").unwrap(), 1);
670 assert_eq!(stored[1].get("id").unwrap(), 2);
671 assert_eq!(stored[2].get("id").unwrap(), 3);
672 }
673
674 #[tokio::test]
675 async fn test_echo_handler_with_nested_json() {
676 let handler: EchoHandler = EchoHandler;
677 let nested: Value = serde_json::json!({
678 "level1": {
679 "level2": {
680 "level3": {
681 "value": "deeply nested"
682 }
683 }
684 }
685 });
686
687 let response: Option<Value> = handler.handle_message(nested.clone()).await;
688 assert!(response.is_some());
689 assert_eq!(response.unwrap(), nested);
690 }
691
692 #[tokio::test]
693 async fn test_echo_handler_with_large_array() {
694 let handler: EchoHandler = EchoHandler;
695 let large_array: Value = serde_json::json!({
696 "items": (0..1000).collect::<Vec<i32>>()
697 });
698
699 let response: Option<Value> = handler.handle_message(large_array.clone()).await;
700 assert!(response.is_some());
701 assert_eq!(response.unwrap(), large_array);
702 }
703
704 #[tokio::test]
705 async fn test_echo_handler_with_unicode() {
706 let handler: EchoHandler = EchoHandler;
707 let unicode_msg: Value = serde_json::json!({
708 "emoji": "🚀",
709 "chinese": "你好",
710 "arabic": "مرحبا",
711 "mixed": "Hello 世界 🌍"
712 });
713
714 let response: Option<Value> = handler.handle_message(unicode_msg.clone()).await;
715 assert!(response.is_some());
716 assert_eq!(response.unwrap(), unicode_msg);
717 }
718
719 #[test]
720 fn test_websocket_state_schemas_are_independent() {
721 let handler: EchoHandler = EchoHandler;
722 let message_schema: serde_json::Value = serde_json::json!({"type": "object"});
723 let response_schema: serde_json::Value = serde_json::json!({"type": "array"});
724
725 let state: WebSocketState<EchoHandler> =
726 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
727
728 let cloned: WebSocketState<EchoHandler> = state.clone();
729
730 assert!(state.message_schema.is_some());
731 assert!(state.response_schema.is_some());
732 assert!(cloned.message_schema.is_some());
733 assert!(cloned.response_schema.is_some());
734 }
735
736 #[test]
737 fn test_message_schema_validation_with_required_field() {
738 let handler: EchoHandler = EchoHandler;
739 let message_schema: serde_json::Value = serde_json::json!({
740 "type": "object",
741 "properties": {"type": {"type": "string"}},
742 "required": ["type"]
743 });
744
745 let state: WebSocketState<EchoHandler> =
746 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
747
748 assert!(state.message_schema.is_some());
749 assert!(state.response_schema.is_none());
750
751 let valid_msg: Value = serde_json::json!({"type": "test"});
752 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
753 assert!(validator.is_valid(&valid_msg));
754
755 let invalid_msg: Value = serde_json::json!({"other": "field"});
756 assert!(!validator.is_valid(&invalid_msg));
757 }
758
759 #[test]
760 fn test_response_schema_validation_with_required_field() {
761 let handler: EchoHandler = EchoHandler;
762 let response_schema: serde_json::Value = serde_json::json!({
763 "type": "object",
764 "properties": {"status": {"type": "string"}},
765 "required": ["status"]
766 });
767
768 let state: WebSocketState<EchoHandler> =
769 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
770
771 assert!(state.message_schema.is_none());
772 assert!(state.response_schema.is_some());
773
774 let valid_response: Value = serde_json::json!({"status": "ok"});
775 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
776 assert!(validator.is_valid(&valid_response));
777
778 let invalid_response: Value = serde_json::json!({"other": "field"});
779 assert!(!validator.is_valid(&invalid_response));
780 }
781
782 #[test]
783 fn test_invalid_message_schema_returns_error() {
784 let handler: EchoHandler = EchoHandler;
785 let invalid_schema: serde_json::Value = serde_json::json!({
786 "type": "invalid_type_value",
787 "properties": {}
788 });
789
790 let result: Result<WebSocketState<EchoHandler>, String> =
791 WebSocketState::with_schemas(handler, Some(invalid_schema), None);
792
793 assert!(result.is_err());
794 match result {
795 Err(error_msg) => assert!(error_msg.contains("Invalid message schema")),
796 Ok(_) => panic!("Expected error but got Ok"),
797 }
798 }
799
800 #[test]
801 fn test_invalid_response_schema_returns_error() {
802 let handler: EchoHandler = EchoHandler;
803 let invalid_schema: serde_json::Value = serde_json::json!({
804 "type": "definitely_not_valid"
805 });
806
807 let result: Result<WebSocketState<EchoHandler>, String> =
808 WebSocketState::with_schemas(handler, None, Some(invalid_schema));
809
810 assert!(result.is_err());
811 match result {
812 Err(error_msg) => assert!(error_msg.contains("Invalid response schema")),
813 Ok(_) => panic!("Expected error but got Ok"),
814 }
815 }
816
817 #[tokio::test]
818 async fn test_handler_returning_none_response() {
819 let handler: SelectiveHandler = SelectiveHandler;
820
821 let no_response_msg: Value = serde_json::json!({"respond": false});
822 let result: Option<Value> = handler.handle_message(no_response_msg).await;
823
824 assert!(result.is_none());
825 }
826
827 #[tokio::test]
828 async fn test_handler_with_complex_schema_validation() {
829 let handler: EchoHandler = EchoHandler;
830 let message_schema: serde_json::Value = serde_json::json!({
831 "type": "object",
832 "properties": {
833 "user": {
834 "type": "object",
835 "properties": {
836 "id": {"type": "integer"},
837 "name": {"type": "string"}
838 },
839 "required": ["id", "name"]
840 },
841 "action": {"type": "string"}
842 },
843 "required": ["user", "action"]
844 });
845
846 let state: WebSocketState<EchoHandler> =
847 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
848
849 let valid_msg: Value = serde_json::json!({
850 "user": {"id": 123, "name": "Alice"},
851 "action": "create"
852 });
853 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
854 assert!(validator.is_valid(&valid_msg));
855
856 let invalid_msg: Value = serde_json::json!({
857 "user": {"id": "not_an_int", "name": "Bob"},
858 "action": "create"
859 });
860 assert!(!validator.is_valid(&invalid_msg));
861 }
862
863 #[tokio::test]
864 async fn test_tracking_handler_with_multiple_message_types() {
865 let handler: TrackingHandler = TrackingHandler::new();
866
867 let messages: Vec<Value> = vec![
868 serde_json::json!({"type": "text", "content": "hello"}),
869 serde_json::json!({"type": "image", "url": "http://example.com/image.png"}),
870 serde_json::json!({"type": "video", "duration": 120}),
871 ];
872
873 for msg in messages {
874 let _: Option<Value> = handler.handle_message(msg).await;
875 }
876
877 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
878 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
879 assert_eq!(stored.len(), 3);
880 assert_eq!(stored[0].get("type").unwrap(), "text");
881 assert_eq!(stored[1].get("type").unwrap(), "image");
882 assert_eq!(stored[2].get("type").unwrap(), "video");
883 }
884
885 #[tokio::test]
886 async fn test_selective_handler_with_explicit_false() {
887 let handler: SelectiveHandler = SelectiveHandler;
888
889 let msg: Value = serde_json::json!({"respond": false, "data": "test"});
890 let response: Option<Value> = handler.handle_message(msg).await;
891
892 assert!(response.is_none());
893 }
894
895 #[tokio::test]
896 async fn test_selective_handler_without_respond_field() {
897 let handler: SelectiveHandler = SelectiveHandler;
898
899 let msg: Value = serde_json::json!({"data": "test"});
900 let response: Option<Value> = handler.handle_message(msg).await;
901
902 assert!(response.is_none());
903 }
904
905 #[tokio::test]
906 async fn test_transform_handler_with_empty_object() {
907 let handler: TransformHandler = TransformHandler;
908 let original: Value = serde_json::json!({});
909 let transformed: Option<Value> = handler.handle_message(original).await;
910
911 assert!(transformed.is_some());
912 let resp: Value = transformed.unwrap();
913 assert_eq!(resp.get("processed").unwrap(), true);
914 assert_eq!(resp.as_object().unwrap().len(), 1);
915 }
916
917 #[tokio::test]
918 async fn test_transform_handler_preserves_all_fields() {
919 let handler: TransformHandler = TransformHandler;
920 let original: Value = serde_json::json!({
921 "field1": "value1",
922 "field2": 42,
923 "field3": true,
924 "nested": {"key": "value"}
925 });
926 let transformed: Option<Value> = handler.handle_message(original.clone()).await;
927
928 assert!(transformed.is_some());
929 let resp: Value = transformed.unwrap();
930 assert_eq!(resp.get("field1").unwrap(), "value1");
931 assert_eq!(resp.get("field2").unwrap(), 42);
932 assert_eq!(resp.get("field3").unwrap(), true);
933 assert_eq!(resp.get("nested").unwrap(), &serde_json::json!({"key": "value"}));
934 assert_eq!(resp.get("processed").unwrap(), true);
935 }
936
937 #[tokio::test]
938 async fn test_transform_handler_with_non_object_input() {
939 let handler: TransformHandler = TransformHandler;
940
941 let array: Value = serde_json::json!([1, 2, 3]);
942 let response1: Option<Value> = handler.handle_message(array).await;
943 assert!(response1.is_none());
944
945 let string: Value = serde_json::json!("not an object");
946 let response2: Option<Value> = handler.handle_message(string).await;
947 assert!(response2.is_none());
948
949 let number: Value = serde_json::json!(42);
950 let response3: Option<Value> = handler.handle_message(number).await;
951 assert!(response3.is_none());
952 }
953
954 #[test]
956 fn test_message_schema_rejects_wrong_type() {
957 let handler: EchoHandler = EchoHandler;
958 let message_schema: serde_json::Value = serde_json::json!({
959 "type": "object",
960 "properties": {"id": {"type": "integer"}},
961 "required": ["id"]
962 });
963
964 let state: WebSocketState<EchoHandler> =
965 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
966
967 let invalid_msg: Value = serde_json::json!({"id": "not_an_integer"});
968 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
969 assert!(!validator.is_valid(&invalid_msg));
970 }
971
972 #[test]
974 fn test_response_schema_rejects_invalid_type() {
975 let handler: EchoHandler = EchoHandler;
976 let response_schema: serde_json::Value = serde_json::json!({
977 "type": "object",
978 "properties": {"count": {"type": "integer"}},
979 "required": ["count"]
980 });
981
982 let state: WebSocketState<EchoHandler> =
983 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
984
985 let invalid_response: Value = serde_json::json!([1, 2, 3]);
986 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
987 assert!(!validator.is_valid(&invalid_response));
988 }
989
990 #[test]
992 fn test_message_missing_multiple_required_fields() {
993 let handler: EchoHandler = EchoHandler;
994 let message_schema: serde_json::Value = serde_json::json!({
995 "type": "object",
996 "properties": {
997 "user_id": {"type": "integer"},
998 "action": {"type": "string"},
999 "timestamp": {"type": "string"}
1000 },
1001 "required": ["user_id", "action", "timestamp"]
1002 });
1003
1004 let state: WebSocketState<EchoHandler> =
1005 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1006
1007 let invalid_msg: Value = serde_json::json!({"other": "value"});
1008 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1009 assert!(!validator.is_valid(&invalid_msg));
1010
1011 let partial_msg: Value = serde_json::json!({"user_id": 123});
1012 assert!(!validator.is_valid(&partial_msg));
1013 }
1014
1015 #[test]
1017 fn test_deeply_nested_schema_validation_failure() {
1018 let handler: EchoHandler = EchoHandler;
1019 let message_schema: serde_json::Value = serde_json::json!({
1020 "type": "object",
1021 "properties": {
1022 "metadata": {
1023 "type": "object",
1024 "properties": {
1025 "request": {
1026 "type": "object",
1027 "properties": {
1028 "id": {"type": "string"}
1029 },
1030 "required": ["id"]
1031 }
1032 },
1033 "required": ["request"]
1034 }
1035 },
1036 "required": ["metadata"]
1037 });
1038
1039 let state: WebSocketState<EchoHandler> =
1040 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1041
1042 let invalid_msg: Value = serde_json::json!({
1043 "metadata": {
1044 "request": {}
1045 }
1046 });
1047 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1048 assert!(!validator.is_valid(&invalid_msg));
1049 }
1050
1051 #[test]
1053 fn test_array_property_type_validation() {
1054 let handler: EchoHandler = EchoHandler;
1055 let message_schema: serde_json::Value = serde_json::json!({
1056 "type": "object",
1057 "properties": {
1058 "ids": {
1059 "type": "array",
1060 "items": {"type": "integer"}
1061 }
1062 }
1063 });
1064
1065 let state: WebSocketState<EchoHandler> =
1066 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1067
1068 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1069
1070 let valid_msg: Value = serde_json::json!({"ids": [1, 2, 3]});
1071 assert!(validator.is_valid(&valid_msg));
1072
1073 let invalid_msg: Value = serde_json::json!({"ids": [1, "two", 3]});
1074 assert!(!validator.is_valid(&invalid_msg));
1075
1076 let invalid_msg2: Value = serde_json::json!({"ids": "not_an_array"});
1077 assert!(!validator.is_valid(&invalid_msg2));
1078 }
1079
1080 #[test]
1082 fn test_enum_property_validation() {
1083 let handler: EchoHandler = EchoHandler;
1084 let message_schema: serde_json::Value = serde_json::json!({
1085 "type": "object",
1086 "properties": {
1087 "status": {
1088 "type": "string",
1089 "enum": ["pending", "active", "completed"]
1090 }
1091 },
1092 "required": ["status"]
1093 });
1094
1095 let state: WebSocketState<EchoHandler> =
1096 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1097
1098 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1099
1100 let valid_msg: Value = serde_json::json!({"status": "active"});
1101 assert!(validator.is_valid(&valid_msg));
1102
1103 let invalid_msg: Value = serde_json::json!({"status": "unknown"});
1104 assert!(!validator.is_valid(&invalid_msg));
1105 }
1106
1107 #[test]
1109 fn test_number_range_validation() {
1110 let handler: EchoHandler = EchoHandler;
1111 let message_schema: serde_json::Value = serde_json::json!({
1112 "type": "object",
1113 "properties": {
1114 "age": {
1115 "type": "integer",
1116 "minimum": 0,
1117 "maximum": 150
1118 }
1119 },
1120 "required": ["age"]
1121 });
1122
1123 let state: WebSocketState<EchoHandler> =
1124 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1125
1126 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1127
1128 let valid_msg: Value = serde_json::json!({"age": 25});
1129 assert!(validator.is_valid(&valid_msg));
1130
1131 let invalid_msg: Value = serde_json::json!({"age": -1});
1132 assert!(!validator.is_valid(&invalid_msg));
1133
1134 let invalid_msg2: Value = serde_json::json!({"age": 200});
1135 assert!(!validator.is_valid(&invalid_msg2));
1136 }
1137
/// `minLength` / `maxLength` on the required `username` field reject names
/// that are too short or too long, while a name within bounds is accepted.
#[test]
fn test_string_length_validation() {
    use serde_json::json;

    let schema = json!({
        "type": "object",
        "properties": {
            "username": {
                "type": "string",
                "minLength": 3,
                "maxLength": 20
            }
        },
        "required": ["username"]
    });
    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&json!({"username": "alice"})));

    // One name below `minLength`, one well above `maxLength`.
    let rejected = [
        json!({"username": "ab"}),
        json!({"username": "this_is_a_very_long_username_over_twenty_characters"}),
    ];
    for message in &rejected {
        assert!(!validator.is_valid(message));
    }
}
1169
/// A `pattern` on the required `email` field accepts a well-formed address and
/// rejects strings that do not match the regular expression.
#[test]
fn test_pattern_validation() {
    use serde_json::json;

    let schema = json!({
        "type": "object",
        "properties": {
            "email": {
                "type": "string",
                "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
            }
        },
        "required": ["email"]
    });
    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    let accepted = json!({"email": "user@example.com"});
    assert!(validator.is_valid(&accepted));

    // No dot-separated TLD after the domain, then no `@` at all.
    for bad_address in ["user@example", "userexample.com"] {
        assert!(!validator.is_valid(&json!({ "email": bad_address })));
    }
}
1199
/// With `additionalProperties: false`, a message containing any key not
/// declared under `properties` is rejected.
#[test]
fn test_additional_properties_validation() {
    let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(
        EchoHandler,
        Some(serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        })),
        None,
    )
    .unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    let only_declared = serde_json::json!({"name": "Alice"});
    let with_extra_key = serde_json::json!({"name": "Bob", "age": 30});

    assert!(validator.is_valid(&only_declared));
    assert!(!validator.is_valid(&with_extra_key));
}
1223
/// `oneOf` accepts a message that matches exactly one branch. It rejects
/// messages that match no branch and messages that match more than one.
#[test]
fn test_one_of_constraint() {
    let handler: EchoHandler = EchoHandler;
    let message_schema: serde_json::Value = serde_json::json!({
        "type": "object",
        "oneOf": [
            {
                "properties": {"type": {"const": "text"}},
                "required": ["type"]
            },
            {
                "properties": {"type": {"const": "number"}},
                "required": ["type"]
            }
        ]
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    let valid_msg: Value = serde_json::json!({"type": "text"});
    assert!(validator.is_valid(&valid_msg));

    // The second branch must be reachable on its own as well.
    assert!(validator.is_valid(&serde_json::json!({"type": "number"})));

    let invalid_msg: Value = serde_json::json!({"type": "unknown"});
    assert!(!validator.is_valid(&invalid_msg));

    // Both branches require `type`, so an empty object matches neither.
    assert!(!validator.is_valid(&serde_json::json!({})));

    // The const-based branches above can never match at the same time, so
    // they cannot distinguish `oneOf` from `anyOf`. These overlapping
    // branches check that matching more than one branch is rejected.
    let overlapping_schema: serde_json::Value = serde_json::json!({
        "type": "object",
        "oneOf": [
            {"required": ["a"]},
            {"required": ["b"]}
        ]
    });
    let overlapping_state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(EchoHandler, Some(overlapping_schema), None).unwrap();
    let overlapping: &jsonschema::Validator = overlapping_state.message_schema.as_ref().unwrap();

    assert!(overlapping.is_valid(&serde_json::json!({"a": 1})));
    assert!(overlapping.is_valid(&serde_json::json!({"b": 2})));
    assert!(!overlapping.is_valid(&serde_json::json!({"a": 1, "b": 2})));
}
1253
/// An explicit `anyOf` on the required `value` field accepts a value that
/// matches at least one branch (string or integer). It rejects values that
/// match none of the branches.
///
/// The schema uses the `anyOf` keyword rather than a `"type": [...]` array so
/// that this test actually exercises `anyOf` evaluation, as its name claims.
#[test]
fn test_any_of_constraint() {
    let handler: EchoHandler = EchoHandler;
    let message_schema: serde_json::Value = serde_json::json!({
        "type": "object",
        "properties": {
            "value": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "integer"}
                ]
            }
        },
        "required": ["value"]
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    let msg1: Value = serde_json::json!({"value": "text"});
    assert!(validator.is_valid(&msg1));

    let msg2: Value = serde_json::json!({"value": 42});
    assert!(validator.is_valid(&msg2));

    let invalid_msg: Value = serde_json::json!({"value": true});
    assert!(!validator.is_valid(&invalid_msg));

    // A non-integer number and null each match neither branch.
    assert!(!validator.is_valid(&serde_json::json!({"value": 4.2})));
    assert!(!validator.is_valid(&serde_json::json!({"value": null})));
}
1280
/// A response schema that combines a nested `required` with an array
/// `minItems` accepts a fully-populated response. It rejects an empty
/// `items` list and a response that omits `data`.
#[test]
fn test_response_schema_with_multiple_constraints() {
    use serde_json::json;

    let response_schema = json!({
        "type": "object",
        "properties": {
            "success": {"type": "boolean"},
            "data": {
                "type": "object",
                "properties": {
                    "items": {
                        "type": "array",
                        "items": {"type": "object"},
                        "minItems": 1
                    }
                },
                "required": ["items"]
            }
        },
        "required": ["success", "data"]
    });
    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(EchoHandler, None, Some(response_schema)).unwrap();
    let validator = state.response_schema.as_deref().unwrap();

    // (response, whether the validator should accept it)
    let cases = [
        (json!({"success": true, "data": {"items": [{"id": 1}]}}), true),
        (json!({"success": true, "data": {"items": []}}), false),
        (json!({"success": true}), false),
    ];
    for (response, should_pass) in &cases {
        assert_eq!(validator.is_valid(response), *should_pass);
    }
}
1330
/// A `["string", "null"]` type union lets an optional field be explicitly null
/// or omitted entirely. A required field typed only as `"string"` rejects
/// null.
#[test]
fn test_null_value_validation() {
    use serde_json::json;

    let schema = json!({
        "type": "object",
        "properties": {
            "optional_field": {"type": ["string", "null"]},
            "required_field": {"type": "string"}
        },
        "required": ["required_field"]
    });
    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    let explicit_null = json!({
        "optional_field": null,
        "required_field": "value"
    });
    let omitted_optional = json!({"required_field": "value"});
    let null_required = json!({"required_field": null});

    assert!(validator.is_valid(&explicit_null));
    assert!(validator.is_valid(&omitted_optional));
    assert!(!validator.is_valid(&null_required));
}
1361
/// A `default` keyword does not relax validation. Omitting the field is still
/// accepted because it is not `required`, but a value that is present must
/// still satisfy the field's declared type.
#[test]
fn test_schema_with_defaults_still_validates() {
    let handler: EchoHandler = EchoHandler;
    let message_schema: serde_json::Value = serde_json::json!({
        "type": "object",
        "properties": {
            "status": {
                "type": "string",
                "default": "pending"
            }
        }
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    let msg: Value = serde_json::json!({});
    assert!(validator.is_valid(&msg));

    assert!(validator.is_valid(&serde_json::json!({"status": "done"})));

    // `default` is an annotation; the `"type": "string"` constraint still
    // applies to a value that is present.
    assert!(!validator.is_valid(&serde_json::json!({"status": 42})));
}
1384
/// The message and response validators are each built from their own schema.
/// Each accepts a payload of its own shape and rejects an unrelated payload.
#[test]
fn test_both_schemas_validate_independently() {
    use serde_json::json;

    let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(
        EchoHandler,
        Some(json!({
            "type": "object",
            "properties": {"action": {"type": "string"}},
            "required": ["action"]
        })),
        Some(json!({
            "type": "object",
            "properties": {"result": {"type": "string"}},
            "required": ["result"]
        })),
    )
    .unwrap();

    let incoming = state.message_schema.as_deref().unwrap();
    let outgoing = state.response_schema.as_deref().unwrap();

    let action_payload = json!({"action": "test"});
    let result_payload = json!({"result": "ok"});
    let unrelated_payload = json!({"data": "oops"});

    assert!(incoming.is_valid(&action_payload));
    assert!(!outgoing.is_valid(&unrelated_payload));

    assert!(!incoming.is_valid(&unrelated_payload));
    assert!(outgoing.is_valid(&result_payload));
}
1418
/// Array-item validation covers every element of a large payload. An
/// all-integer array of 10,000 elements is accepted, and a single non-integer
/// appended at the very end is still detected.
#[test]
fn test_validation_with_large_payload() {
    let handler: EchoHandler = EchoHandler;
    let message_schema: serde_json::Value = serde_json::json!({
        "type": "object",
        "properties": {
            "items": {
                "type": "array",
                "items": {"type": "integer"}
            }
        },
        "required": ["items"]
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    let items: Vec<i64> = (0..10_000).collect();

    // Same integers plus one trailing string, so the only violation sits at the
    // last index. A validator that stopped early would miss it.
    let mut bad_items: Vec<Value> = items.iter().copied().map(Value::from).collect();
    bad_items.push(Value::from("not an integer"));

    let large_msg: Value = serde_json::json!({"items": items});
    assert!(validator.is_valid(&large_msg));

    let invalid_large_msg: Value = serde_json::json!({"items": bad_items});
    assert!(!validator.is_valid(&invalid_large_msg));
}
1447
/// Despite its name, this test covers an `allOf` conjunction of two object
/// subschemas, not mutual exclusion. A message must satisfy both: `a` is a
/// required string and `b` is a required integer. Every subschema must be
/// enforced, so missing either key, or giving either one the wrong type, is
/// rejected.
#[test]
fn test_mutually_exclusive_schema_properties() {
    let handler: EchoHandler = EchoHandler;

    let message_schema: serde_json::Value = serde_json::json!({
        "allOf": [
            {
                "type": "object",
                "properties": {"a": {"type": "string"}},
                "required": ["a"]
            },
            {
                "type": "object",
                "properties": {"b": {"type": "integer"}},
                "required": ["b"]
            }
        ]
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    let valid_msg: Value = serde_json::json!({"a": "text", "b": 42});
    assert!(validator.is_valid(&valid_msg));

    // Fails only the second subschema (missing `b`).
    let invalid_msg: Value = serde_json::json!({"a": "text"});
    assert!(!validator.is_valid(&invalid_msg));

    // Fails only the first subschema (missing `a`). This catches a validator
    // that evaluates just the last `allOf` branch.
    assert!(!validator.is_valid(&serde_json::json!({"b": 42})));

    // Both keys present, but each subschema's type constraint is still enforced.
    assert!(!validator.is_valid(&serde_json::json!({"a": "text", "b": "42"})));
    assert!(!validator.is_valid(&serde_json::json!({"a": 1, "b": 42})));
}
1479}