1use std::collections::{HashMap, HashSet};
16use std::time::{Duration, Instant};
17
18use actix_web::{HttpMessage, HttpRequest, HttpResponse, web};
19use parking_lot::RwLock;
20use serde_json::{Map, Value};
21use uuid::Uuid;
22
23use haystack_core::codecs::json::v3 as json_v3;
24use haystack_core::data::HDict;
25use haystack_core::graph::SharedGraph;
26
27use crate::state::AppState;
28
/// Upper bound on the total number of watches a `WatchManager` will hold.
const MAX_WATCHES: usize = 100;
/// Upper bound on the number of entity IDs tracked by a single watch.
const MAX_ENTITY_IDS_PER_WATCH: usize = 10_000;
/// Upper bound on the number of watches owned by one username.
const MAX_WATCHES_PER_USER: usize = 20;

/// Period between server-initiated WebSocket pings.
const PING_INTERVAL: Duration = Duration::from_secs(30);

/// At a ping tick, a connection that is still awaiting a pong and has been
/// silent for longer than this is closed.
const PONG_TIMEOUT: Duration = Duration::from_secs(10);

/// Capacity of the per-connection outbound message channel.
const CHANNEL_CAPACITY: usize = 64;

/// Number of consecutive failed response sends after which the connection is
/// closed as a slow client.
const MAX_SEND_FAILURES: u32 = 3;

/// Responses shorter than this many bytes are always sent as uncompressed text.
const COMPRESSION_THRESHOLD: usize = 512;
53
54#[derive(serde::Deserialize, Debug)]
60struct WsRequest {
61 op: String,
62 #[serde(rename = "reqId")]
63 req_id: Option<String>,
64 #[serde(rename = "watchDis")]
65 #[allow(dead_code)]
66 watch_dis: Option<String>,
67 #[serde(rename = "watchId")]
68 watch_id: Option<String>,
69 ids: Option<Vec<String>>,
70}
71
72#[derive(serde::Serialize, Debug)]
74struct WsResponse {
75 #[serde(rename = "reqId", skip_serializing_if = "Option::is_none")]
76 req_id: Option<String>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 error: Option<String>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 rows: Option<Vec<Value>>,
81 #[serde(rename = "watchId", skip_serializing_if = "Option::is_none")]
82 watch_id: Option<String>,
83}
84
85impl WsResponse {
86 fn error(req_id: Option<String>, msg: impl Into<String>) -> Self {
88 Self {
89 req_id,
90 error: Some(msg.into()),
91 rows: None,
92 watch_id: None,
93 }
94 }
95
96 fn ok(req_id: Option<String>, rows: Vec<Value>, watch_id: Option<String>) -> Self {
98 Self {
99 req_id,
100 error: None,
101 rows: Some(rows),
102 watch_id,
103 }
104 }
105}
106
107fn encode_entity(entity: &HDict) -> Value {
114 let mut m = Map::new();
115 let mut keys: Vec<&String> = entity.tags().keys().collect();
116 keys.sort();
117 for k in keys {
118 let v = &entity.tags()[k];
119 if let Ok(encoded) = json_v3::encode_kind(v) {
120 m.insert(k.clone(), encoded);
121 }
122 }
123 Value::Object(m)
124}
125
126fn handle_ws_request(req: &WsRequest, username: &str, state: &AppState) -> String {
134 let resp = match req.op.as_str() {
135 "watchSub" => handle_watch_sub(req, username, state),
136 "watchPoll" => handle_watch_poll(req, username, state),
137 "watchUnsub" => handle_watch_unsub(req, username, state),
138 other => WsResponse::error(req.req_id.clone(), format!("unknown op: {other}")),
139 };
140 serde_json::to_string(&resp).unwrap_or_else(|e| {
142 let fallback = WsResponse::error(req.req_id.clone(), format!("serialization error: {e}"));
143 serde_json::to_string(&fallback).unwrap()
144 })
145}
146
147fn handle_watch_sub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
150 let ids = match &req.ids {
151 Some(ids) if !ids.is_empty() => ids.clone(),
152 _ => {
153 return WsResponse::error(
154 req.req_id.clone(),
155 "watchSub requires non-empty 'ids' array",
156 );
157 }
158 };
159
160 let ids: Vec<String> = ids
162 .into_iter()
163 .map(|id| id.strip_prefix('@').unwrap_or(&id).to_string())
164 .collect();
165
166 let graph_version = state.graph.version();
167 let watch_id = match state
168 .watches
169 .subscribe(username, ids.clone(), graph_version)
170 {
171 Ok(wid) => wid,
172 Err(e) => return WsResponse::error(req.req_id.clone(), e),
173 };
174
175 let rows: Vec<Value> = ids
177 .iter()
178 .filter_map(|id| state.graph.get(id).map(|e| encode_entity(&e)))
179 .collect();
180
181 WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
182}
183
184fn handle_watch_poll(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
186 let watch_id = match &req.watch_id {
187 Some(wid) => wid.clone(),
188 None => {
189 return WsResponse::error(req.req_id.clone(), "watchPoll requires 'watchId'");
190 }
191 };
192
193 match state.watches.poll(&watch_id, username, &state.graph) {
194 Some(changed) => {
195 let rows: Vec<Value> = changed.iter().map(encode_entity).collect();
196 WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
197 }
198 None => WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}")),
199 }
200}
201
202fn handle_watch_unsub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
204 let watch_id = match &req.watch_id {
205 Some(wid) => wid.clone(),
206 None => {
207 return WsResponse::error(req.req_id.clone(), "watchUnsub requires 'watchId'");
208 }
209 };
210
211 if let Some(ids) = &req.ids
214 && !ids.is_empty()
215 {
216 let clean: Vec<String> = ids
217 .iter()
218 .map(|id| id.strip_prefix('@').unwrap_or(id).to_string())
219 .collect();
220 if !state.watches.remove_ids(&watch_id, username, &clean) {
221 return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
222 }
223 return WsResponse::ok(req.req_id.clone(), vec![], Some(watch_id));
224 }
225
226 if !state.watches.unsubscribe(&watch_id, username) {
227 return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
228 }
229 WsResponse::ok(req.req_id.clone(), vec![], None)
230}
231
/// A single watch subscription held by `WatchManager`.
struct Watch {
    /// IDs of the entities this watch tracks.
    entity_ids: HashSet<String>,
    /// Graph version recorded at subscribe time, then updated by each poll
    /// that observes a newer version.
    last_version: u64,
    /// Username that created the watch; poll, unsubscribe, add and remove
    /// operations are refused for any other username.
    owner: String,
}
241
/// Registry of watch subscriptions plus a cache of encoded entities.
pub struct WatchManager {
    /// Active watches keyed by watch ID (a UUID string generated on subscribe).
    watches: RwLock<HashMap<String, Watch>>,
    /// Encoded entities keyed by `(ref_val, graph_version)`.
    encode_cache: RwLock<HashMap<(String, u64), Value>>,
    /// Highest graph version seen by `encode_cached`; the cache is cleared
    /// whenever a newer version is seen.
    cache_version: RwLock<u64>,
}
255
256impl WatchManager {
257 pub fn new() -> Self {
259 Self {
260 watches: RwLock::new(HashMap::new()),
261 encode_cache: RwLock::new(HashMap::new()),
262 cache_version: RwLock::new(0),
263 }
264 }
265
266 pub fn subscribe(
270 &self,
271 username: &str,
272 ids: Vec<String>,
273 graph_version: u64,
274 ) -> Result<String, String> {
275 let mut watches = self.watches.write();
276 if watches.len() >= MAX_WATCHES {
277 return Err("maximum number of watches reached".to_string());
278 }
279 let user_count = watches.values().filter(|w| w.owner == username).count();
280 if user_count >= MAX_WATCHES_PER_USER {
281 return Err(format!(
282 "user '{}' has reached the maximum of {} watches",
283 username, MAX_WATCHES_PER_USER
284 ));
285 }
286 if ids.len() > MAX_ENTITY_IDS_PER_WATCH {
287 return Err(format!(
288 "too many entity IDs (max {})",
289 MAX_ENTITY_IDS_PER_WATCH
290 ));
291 }
292 let watch_id = Uuid::new_v4().to_string();
293 let watch = Watch {
294 entity_ids: ids.into_iter().collect(),
295 last_version: graph_version,
296 owner: username.to_string(),
297 };
298 watches.insert(watch_id.clone(), watch);
299 Ok(watch_id)
300 }
301
302 pub fn poll(&self, watch_id: &str, username: &str, graph: &SharedGraph) -> Option<Vec<HDict>> {
307 let (entity_ids, last_version) = {
311 let mut watches = self.watches.write();
312 let watch = watches.get_mut(watch_id)?;
313 if watch.owner != username {
314 return None;
315 }
316
317 let current_version = graph.version();
318 if current_version == watch.last_version {
319 return Some(Vec::new());
321 }
322
323 let ids = watch.entity_ids.clone();
324 let last = watch.last_version;
325 watch.last_version = current_version;
326 (ids, last)
327 }; let changes = graph.changes_since(last_version);
331 let changed_refs: HashSet<&str> = changes.iter().map(|d| d.ref_val.as_str()).collect();
332
333 Some(
334 entity_ids
335 .iter()
336 .filter(|id| changed_refs.contains(id.as_str()))
337 .filter_map(|id| graph.get(id))
338 .collect(),
339 )
340 }
341
342 pub fn unsubscribe(&self, watch_id: &str, username: &str) -> bool {
346 let mut watches = self.watches.write();
347 match watches.get(watch_id) {
348 Some(watch) if watch.owner == username => {
349 watches.remove(watch_id);
350 true
351 }
352 _ => false,
353 }
354 }
355
356 pub fn add_ids(&self, watch_id: &str, username: &str, ids: Vec<String>) -> bool {
361 let mut watches = self.watches.write();
362 if let Some(watch) = watches.get_mut(watch_id) {
363 if watch.owner != username {
364 return false;
365 }
366 if watch.entity_ids.len() + ids.len() > MAX_ENTITY_IDS_PER_WATCH {
367 return false;
368 }
369 watch.entity_ids.extend(ids);
370 true
371 } else {
372 false
373 }
374 }
375
376 pub fn remove_ids(&self, watch_id: &str, username: &str, ids: &[String]) -> bool {
381 let mut watches = self.watches.write();
382 if let Some(watch) = watches.get_mut(watch_id) {
383 if watch.owner != username {
384 return false;
385 }
386 for id in ids {
387 watch.entity_ids.remove(id);
388 }
389 true
390 } else {
391 false
392 }
393 }
394
395 pub fn get_ids(&self, watch_id: &str) -> Option<Vec<String>> {
399 let watches = self.watches.read();
400 watches
401 .get(watch_id)
402 .map(|w| w.entity_ids.iter().cloned().collect())
403 }
404
405 pub fn len(&self) -> usize {
407 self.watches.read().len()
408 }
409
410 pub fn is_empty(&self) -> bool {
412 self.watches.read().is_empty()
413 }
414
415 pub fn encode_cached(&self, ref_val: &str, graph_version: u64, entity: &HDict) -> Value {
418 {
420 let mut cv = self.cache_version.write();
421 if graph_version > *cv {
422 self.encode_cache.write().clear();
423 *cv = graph_version;
424 }
425 }
426
427 let key = (ref_val.to_string(), graph_version);
428 if let Some(cached) = self.encode_cache.read().get(&key) {
429 return cached.clone();
430 }
431
432 let encoded = encode_entity(entity);
433 self.encode_cache.write().insert(key, encoded.clone());
434 encoded
435 }
436
437 pub fn all_watched_ids(&self) -> HashSet<String> {
440 let watches = self.watches.read();
441 watches
442 .values()
443 .flat_map(|w| w.entity_ids.iter().cloned())
444 .collect()
445 }
446
447 pub fn watches_affected_by(
450 &self,
451 changed_refs: &HashSet<&str>,
452 ) -> Vec<(String, String, Vec<String>)> {
453 let watches = self.watches.read();
454 let mut affected = Vec::new();
455 for (wid, watch) in watches.iter() {
456 let matched: Vec<String> = watch
457 .entity_ids
458 .iter()
459 .filter(|id| changed_refs.contains(id.as_str()))
460 .cloned()
461 .collect();
462 if !matched.is_empty() {
463 affected.push((wid.clone(), watch.owner.clone(), matched));
464 }
465 }
466 affected
467 }
468}
469
470impl Default for WatchManager {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
476fn maybe_compress_response(text: &str) -> WsPayload {
483 if text.len() < COMPRESSION_THRESHOLD {
484 return WsPayload::Text(text.to_string());
485 }
486 let compressed = compress_deflate(text.as_bytes());
487 if compressed.len() < text.len() {
488 WsPayload::Binary(compressed)
489 } else {
490 WsPayload::Text(text.to_string())
491 }
492}
493
494fn compress_deflate(data: &[u8]) -> Vec<u8> {
495 use flate2::Compression;
496 use flate2::write::DeflateEncoder;
497 use std::io::Write;
498
499 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
500 let _ = encoder.write_all(data);
501 encoder.finish().unwrap_or_else(|_| data.to_vec())
502}
503
/// Maximum number of decompressed bytes read from one inbound binary frame (10 MiB).
const MAX_DECOMPRESSED_SIZE: u64 = 10 * 1024 * 1024;
506
507fn decompress_deflate(data: &[u8]) -> Result<String, std::io::Error> {
508 use flate2::read::DeflateDecoder;
509 use std::io::Read;
510
511 let decoder = DeflateDecoder::new(data);
512 let mut limited = decoder.take(MAX_DECOMPRESSED_SIZE);
513 let mut output = String::new();
514 limited.read_to_string(&mut output)?;
515 Ok(output)
516}
517
/// An outbound WebSocket frame queued for the writer task.
enum WsPayload {
    /// Sent as a text frame (uncompressed JSON).
    Text(String),
    /// Sent as a binary frame (DEFLATE-compressed JSON).
    Binary(Vec<u8>),
}
522
523pub async fn ws_handler(
540 req: HttpRequest,
541 stream: web::Payload,
542 state: web::Data<AppState>,
543) -> actix_web::Result<HttpResponse> {
544 let username = if state.auth.is_enabled() {
546 match req.extensions().get::<crate::auth::AuthUser>() {
547 Some(u) => u.username.clone(),
548 None => {
549 return Err(crate::error::HaystackError::new(
550 "authentication required for WebSocket connections",
551 actix_web::http::StatusCode::UNAUTHORIZED,
552 )
553 .into());
554 }
555 }
556 } else {
557 req.extensions()
558 .get::<crate::auth::AuthUser>()
559 .map(|u| u.username.clone())
560 .unwrap_or_else(|| "anonymous".to_string())
561 };
562
563 let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
564
565 actix_rt::spawn(async move {
566 use actix_ws::Message;
567 use tokio::sync::mpsc;
568
569 let (tx, mut rx) = mpsc::channel::<WsPayload>(CHANNEL_CAPACITY);
570
571 let mut session_clone = session.clone();
573 actix_rt::spawn(async move {
574 while let Some(payload) = rx.recv().await {
575 let result = match payload {
576 WsPayload::Text(text) => session_clone.text(text).await,
577 WsPayload::Binary(data) => session_clone.binary(data).await,
578 };
579 if result.is_err() {
580 break;
581 }
582 }
583 });
584
585 let mut last_activity = Instant::now();
587 let mut ping_interval = tokio::time::interval(PING_INTERVAL);
588 ping_interval.tick().await; let mut awaiting_pong = false;
590 let mut send_failures: u32 = 0;
591
592 let mut last_push_version = state.graph.version();
594
595 let mut push_interval = tokio::time::interval(Duration::from_millis(500));
597 push_interval.tick().await;
598
599 use futures_util::StreamExt as _;
600
601 loop {
602 tokio::select! {
603 msg = msg_stream.next() => {
605 let Some(Ok(msg)) = msg else { break };
606 last_activity = Instant::now();
607 awaiting_pong = false;
608
609 match msg {
610 Message::Text(text) => {
611 let response_text = match serde_json::from_str::<WsRequest>(&text) {
612 Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
613 Err(e) => {
614 let err = WsResponse::error(None, format!("invalid request: {e}"));
615 serde_json::to_string(&err).unwrap()
616 }
617 };
618 let payload = maybe_compress_response(&response_text);
619 if tx.try_send(payload).is_err() {
620 send_failures += 1;
621 if send_failures >= MAX_SEND_FAILURES {
622 log::warn!("closing slow WS client ({})", username);
623 break;
624 }
625 } else {
626 send_failures = 0;
627 }
628 }
629 Message::Binary(data) => {
630 if let Ok(text) = decompress_deflate(&data) {
632 let response_text = match serde_json::from_str::<WsRequest>(&text) {
633 Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
634 Err(e) => {
635 let err = WsResponse::error(None, format!("invalid request: {e}"));
636 serde_json::to_string(&err).unwrap()
637 }
638 };
639 let payload = maybe_compress_response(&response_text);
640 let _ = tx.try_send(payload);
641 }
642 }
643 Message::Ping(bytes) => {
644 let _ = session.pong(&bytes).await;
645 }
646 Message::Pong(_) => {
647 awaiting_pong = false;
648 }
649 Message::Close(_) => {
650 break;
651 }
652 _ => {}
653 }
654 }
655
656 _ = ping_interval.tick() => {
658 if awaiting_pong && last_activity.elapsed() > PONG_TIMEOUT {
659 log::info!("closing stale WS connection ({}): no pong", username);
660 break;
661 }
662 let _ = session.ping(b"").await;
663 awaiting_pong = true;
664 }
665
666 _ = push_interval.tick() => {
668 let current_version = state.graph.version();
669 if current_version > last_push_version {
670 let changes = state.graph.changes_since(last_push_version);
671 let changed_refs: HashSet<&str> =
672 changes.iter().map(|d| d.ref_val.as_str()).collect();
673
674 let affected = state.watches.watches_affected_by(&changed_refs);
675 for (watch_id, owner, changed_ids) in &affected {
676 if owner != &username {
677 continue;
678 }
679 let rows: Vec<Value> = changed_ids
680 .iter()
681 .filter_map(|id| {
682 let entity = state.graph.get(id)?;
683 Some(state.watches.encode_cached(id, current_version, &entity))
684 })
685 .collect();
686 if !rows.is_empty() {
687 let push_msg = serde_json::json!({
688 "type": "push",
689 "watchId": watch_id,
690 "rows": rows,
691 });
692 if let Ok(text) = serde_json::to_string(&push_msg) {
693 let payload = maybe_compress_response(&text);
694 let _ = tx.try_send(payload);
695 }
696 }
697 }
698 last_push_version = current_version;
699 }
700 }
701 }
702 }
703
704 let _ = session.close(None).await;
705 });
706
707 Ok(response)
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use haystack_core::graph::{EntityGraph, SharedGraph};
    use haystack_core::kinds::{HRef, Kind};

    /// Builds a graph containing one `site` entity with the given ID.
    fn make_graph_with_entity(id: &str) -> SharedGraph {
        let graph = SharedGraph::new(EntityGraph::new());
        let mut entity = HDict::new();
        entity.set("id", Kind::Ref(HRef::from_val(id)));
        entity.set("site", Kind::Marker);
        entity.set("dis", Kind::Str(format!("Site {id}")));
        graph.add(entity).unwrap();
        graph
    }

    // ---- WatchManager: subscribe / poll ----

    #[test]
    fn subscribe_returns_watch_id() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        assert!(!watch_id.is_empty());
    }

    #[test]
    fn poll_no_changes() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let version = graph.version();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into()], version)
            .unwrap();

        // Graph version is unchanged since subscribe, so nothing is reported.
        let changes = wm.poll(&watch_id, "admin", &graph).unwrap();
        assert!(changes.is_empty());
    }

    #[test]
    fn poll_with_changes() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let version = graph.version();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into()], version)
            .unwrap();

        // Modify the watched entity after subscribing.
        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        graph.update("site-1", changes).unwrap();

        let result = wm.poll(&watch_id, "admin", &graph).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn poll_unknown_watch() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        assert!(wm.poll("unknown", "admin", &graph).is_none());
    }

    #[test]
    fn poll_wrong_owner() {
        let graph = make_graph_with_entity("site-1");
        let wm = WatchManager::new();
        let version = graph.version();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into()], version)
            .unwrap();

        // A different user must not be able to poll the watch.
        assert!(wm.poll(&watch_id, "other-user", &graph).is_none());
    }

    // ---- WatchManager: unsubscribe ----

    #[test]
    fn unsubscribe_removes_watch() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        assert!(wm.unsubscribe(&watch_id, "admin"));
        // A second unsubscribe finds nothing to remove.
        assert!(!wm.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn unsubscribe_wrong_owner() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        // Refused for a non-owner ...
        assert!(!wm.unsubscribe(&watch_id, "other-user"));
        // ... and the watch is still there for the owner to remove.
        assert!(wm.unsubscribe(&watch_id, "admin"));
    }

    // ---- WatchManager: remove_ids ----

    #[test]
    fn remove_ids_selective() {
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe(
                "admin",
                vec!["site-1".into(), "site-2".into(), "site-3".into()],
                0,
            )
            .unwrap();

        assert!(wm.remove_ids(&watch_id, "admin", &["site-2".into()]));

        // Only the removed ID is gone.
        let remaining = wm.get_ids(&watch_id).unwrap();
        assert_eq!(remaining.len(), 2);
        assert!(remaining.contains(&"site-1".to_string()));
        assert!(remaining.contains(&"site-3".to_string()));
        assert!(!remaining.contains(&"site-2".to_string()));
    }

    #[test]
    fn remove_ids_nonexistent_watch() {
        let wm = WatchManager::new();
        assert!(!wm.remove_ids("no-such-watch", "admin", &["site-1".into()]));
    }

    #[test]
    fn remove_ids_wrong_owner() {
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        assert!(!wm.remove_ids(&watch_id, "other-user", &["site-1".into()]));

        // The refused call left the ID set untouched.
        let remaining = wm.get_ids(&watch_id).unwrap();
        assert_eq!(remaining.len(), 2);
    }

    #[test]
    fn remove_ids_leaves_watch_alive() {
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        assert!(wm.remove_ids(&watch_id, "admin", &["site-1".into(), "site-2".into()]));

        // Removing every ID empties the watch but does not delete it.
        let remaining = wm.get_ids(&watch_id).unwrap();
        assert!(remaining.is_empty());

        assert!(wm.unsubscribe(&watch_id, "admin"));
    }

    #[test]
    fn unsubscribe_full_removal() {
        let wm = WatchManager::new();
        let watch_id = wm
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        assert!(wm.unsubscribe(&watch_id, "admin"));

        // The watch is gone entirely.
        assert!(wm.get_ids(&watch_id).is_none());

        assert!(!wm.unsubscribe(&watch_id, "admin"));
    }

    // ---- WatchManager: add_ids / get_ids ----

    #[test]
    fn add_ids_ownership_check() {
        let wm = WatchManager::new();
        let watch_id = wm.subscribe("admin", vec!["site-1".into()], 0).unwrap();

        // Refused for a non-owner.
        assert!(!wm.add_ids(&watch_id, "other-user", vec!["site-2".into()]));

        // Accepted for the owner.
        assert!(wm.add_ids(&watch_id, "admin", vec!["site-2".into()]));

        let ids = wm.get_ids(&watch_id).unwrap();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"site-1".to_string()));
        assert!(ids.contains(&"site-2".to_string()));
    }

    #[test]
    fn get_ids_returns_none_for_unknown_watch() {
        let wm = WatchManager::new();
        assert!(wm.get_ids("nonexistent").is_none());
    }

    // ---- Wire format: WsRequest / WsResponse ----

    #[test]
    fn ws_request_deserialization() {
        let json = r#"{
            "op": "watchSub",
            "reqId": "abc-123",
            "watchDis": "my-watch",
            "ids": ["@ref1", "@ref2"]
        }"#;
        let req: WsRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.op, "watchSub");
        assert_eq!(req.req_id.as_deref(), Some("abc-123"));
        assert_eq!(req.watch_dis.as_deref(), Some("my-watch"));
        assert!(req.watch_id.is_none());
        let ids = req.ids.unwrap();
        assert_eq!(ids, vec!["@ref1", "@ref2"]);
    }

    #[test]
    fn ws_request_deserialization_minimal() {
        // Only `op` is mandatory; every other field defaults to `None`.
        let json = r#"{"op": "watchPoll", "watchId": "w-1"}"#;
        let req: WsRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.op, "watchPoll");
        assert!(req.req_id.is_none());
        assert!(req.watch_dis.is_none());
        assert_eq!(req.watch_id.as_deref(), Some("w-1"));
        assert!(req.ids.is_none());
    }

    #[test]
    fn ws_response_serialization() {
        let resp = WsResponse::ok(
            Some("r-1".into()),
            vec![serde_json::json!({"id": "r:site-1"})],
            Some("w-1".into()),
        );
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "r-1");
        assert_eq!(json["watchId"], "w-1");
        assert!(json["rows"].is_array());
        assert_eq!(json["rows"][0]["id"], "r:site-1");
        // A success response carries no `error` key at all.
        assert!(json.get("error").is_none());
    }

    #[test]
    fn ws_response_omits_none_fields() {
        let resp = WsResponse::ok(None, vec![], None);
        let json = serde_json::to_value(&resp).unwrap();
        // `None` fields are skipped rather than serialized as null.
        assert!(json.get("reqId").is_none());
        assert!(json.get("error").is_none());
        assert!(json.get("watchId").is_none());
        // `rows` is present even when empty.
        assert!(json["rows"].is_array());
    }

    #[test]
    fn ws_response_includes_req_id() {
        let resp = WsResponse::error(Some("req-42".into()), "something went wrong");
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "req-42");
        assert_eq!(json["error"], "something went wrong");
        // An error response carries neither rows nor a watch ID.
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }

    #[test]
    fn ws_error_response_format() {
        let resp = WsResponse::error(None, "bad request");
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["error"], "bad request");
        // With no request ID, only the `error` key is emitted.
        assert!(json.get("reqId").is_none());
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }
}