1use std::collections::{HashMap, HashSet};
16use std::time::{Duration, Instant};
17
18use actix_web::{HttpMessage, HttpRequest, HttpResponse, web};
19use parking_lot::RwLock;
20use serde_json::{Map, Value};
21use uuid::Uuid;
22
23use haystack_core::codecs::json::v3 as json_v3;
24use haystack_core::data::HDict;
25use haystack_core::graph::SharedGraph;
26
27use crate::state::AppState;
28
/// Upper bound on live watches across all users.
const MAX_WATCHES: usize = 100;
/// Upper bound on entity ids accepted for a single watch.
const MAX_ENTITY_IDS_PER_WATCH: usize = 1000;
/// Upper bound on the sum of watched ids over all of one user's watches.
const MAX_TOTAL_WATCHED_IDS: usize = 5000;
/// Upper bound on live watches owned by one user.
const MAX_WATCHES_PER_USER: usize = 20;

/// Entry count above which the encode cache evicts a quarter of its entries.
const MAX_ENCODE_CACHE_ENTRIES: usize = 50000;

/// Delay between server-initiated pings.
const PING_INTERVAL: Duration = Duration::from_secs(30);

/// Time since the last inbound frame after which an unanswered ping closes the connection.
const PONG_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Capacity of each connection's outbound payload queue.
const CHANNEL_CAPACITY: usize = 64;

/// Consecutive failed enqueue attempts tolerated before a client is dropped as too slow.
const MAX_SEND_FAILURES: u32 = 3;

/// Responses shorter than this many bytes are always sent uncompressed as text.
const COMPRESSION_THRESHOLD: usize = 512;
58
59#[derive(serde::Deserialize, Debug)]
65struct WsRequest {
66 op: String,
67 #[serde(rename = "reqId")]
68 req_id: Option<String>,
69 #[serde(rename = "watchDis")]
70 #[allow(dead_code)]
71 watch_dis: Option<String>,
72 #[serde(rename = "watchId")]
73 watch_id: Option<String>,
74 ids: Option<Vec<String>>,
75}
76
77#[derive(serde::Serialize, Debug)]
79struct WsResponse {
80 #[serde(rename = "reqId", skip_serializing_if = "Option::is_none")]
81 req_id: Option<String>,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 error: Option<String>,
84 #[serde(skip_serializing_if = "Option::is_none")]
85 rows: Option<Vec<Value>>,
86 #[serde(rename = "watchId", skip_serializing_if = "Option::is_none")]
87 watch_id: Option<String>,
88}
89
90impl WsResponse {
91 fn error(req_id: Option<String>, msg: impl Into<String>) -> Self {
93 Self {
94 req_id,
95 error: Some(msg.into()),
96 rows: None,
97 watch_id: None,
98 }
99 }
100
101 fn ok(req_id: Option<String>, rows: Vec<Value>, watch_id: Option<String>) -> Self {
103 Self {
104 req_id,
105 error: None,
106 rows: Some(rows),
107 watch_id,
108 }
109 }
110}
111
112fn encode_entity(entity: &HDict) -> Value {
119 let mut m = Map::new();
120 let mut keys: Vec<&String> = entity.tags().keys().collect();
121 keys.sort();
122 for k in keys {
123 let v = &entity.tags()[k];
124 if let Ok(encoded) = json_v3::encode_kind(v) {
125 m.insert(k.clone(), encoded);
126 }
127 }
128 Value::Object(m)
129}
130
131fn handle_ws_request(req: &WsRequest, username: &str, state: &AppState) -> String {
139 let resp = match req.op.as_str() {
140 "watchSub" => handle_watch_sub(req, username, state),
141 "watchPoll" => handle_watch_poll(req, username, state),
142 "watchUnsub" => handle_watch_unsub(req, username, state),
143 other => WsResponse::error(req.req_id.clone(), format!("unknown op: {other}")),
144 };
145 serde_json::to_string(&resp).unwrap_or_else(|e| {
147 let fallback = WsResponse::error(req.req_id.clone(), format!("serialization error: {e}"));
148 serde_json::to_string(&fallback).unwrap()
149 })
150}
151
152fn handle_watch_sub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
155 let ids = match &req.ids {
156 Some(ids) if !ids.is_empty() => ids.clone(),
157 _ => {
158 return WsResponse::error(
159 req.req_id.clone(),
160 "watchSub requires non-empty 'ids' array",
161 );
162 }
163 };
164
165 let ids: Vec<String> = ids
167 .into_iter()
168 .map(|id| id.strip_prefix('@').unwrap_or(&id).to_string())
169 .collect();
170
171 let graph_version = state.graph.version();
172 let watch_id = match state
173 .watches
174 .subscribe(username, ids.clone(), graph_version)
175 {
176 Ok(wid) => wid,
177 Err(e) => return WsResponse::error(req.req_id.clone(), e),
178 };
179
180 let rows: Vec<Value> = ids
182 .iter()
183 .filter_map(|id| state.graph.get(id).map(|e| encode_entity(&e)))
184 .collect();
185
186 WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
187}
188
189fn handle_watch_poll(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
191 let watch_id = match &req.watch_id {
192 Some(wid) => wid.clone(),
193 None => {
194 return WsResponse::error(req.req_id.clone(), "watchPoll requires 'watchId'");
195 }
196 };
197
198 match state.watches.poll(&watch_id, username, &state.graph) {
199 Some(changed) => {
200 let rows: Vec<Value> = changed.iter().map(encode_entity).collect();
201 WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
202 }
203 None => WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}")),
204 }
205}
206
207fn handle_watch_unsub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
209 let watch_id = match &req.watch_id {
210 Some(wid) => wid.clone(),
211 None => {
212 return WsResponse::error(req.req_id.clone(), "watchUnsub requires 'watchId'");
213 }
214 };
215
216 if let Some(ids) = &req.ids
219 && !ids.is_empty()
220 {
221 let clean: Vec<String> = ids
222 .iter()
223 .map(|id| id.strip_prefix('@').unwrap_or(id).to_string())
224 .collect();
225 if !state.watches.remove_ids(&watch_id, username, &clean) {
226 return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
227 }
228 return WsResponse::ok(req.req_id.clone(), vec![], Some(watch_id));
229 }
230
231 if !state.watches.unsubscribe(&watch_id, username) {
232 return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
233 }
234 WsResponse::ok(req.req_id.clone(), vec![], None)
235}
236
/// One subscription: the entities to report on and the graph version last reported.
struct Watch {
    /// Username that created the watch; poll, add/remove ids and unsubscribe require a match.
    owner: String,
    /// Graph version the watch was last brought up to date with.
    last_version: u64,
    /// Ids of the entities being watched.
    entity_ids: HashSet<String>,
}
246
247pub struct WatchManager {
254 watches: RwLock<HashMap<String, Watch>>,
255 encode_cache: RwLock<HashMap<(String, u64), Value>>,
257 cache_version: RwLock<u64>,
259}
260
261impl WatchManager {
262 pub fn new() -> Self {
264 Self {
265 watches: RwLock::new(HashMap::new()),
266 encode_cache: RwLock::new(HashMap::new()),
267 cache_version: RwLock::new(0),
268 }
269 }
270
271 pub fn subscribe(
275 &self,
276 username: &str,
277 ids: Vec<String>,
278 graph_version: u64,
279 ) -> Result<String, String> {
280 let mut watches = self.watches.write();
281 if watches.len() >= MAX_WATCHES {
282 return Err("maximum number of watches reached".to_string());
283 }
284 let user_count = watches.values().filter(|w| w.owner == username).count();
285 if user_count >= MAX_WATCHES_PER_USER {
286 return Err(format!(
287 "user '{}' has reached the maximum of {} watches",
288 username, MAX_WATCHES_PER_USER
289 ));
290 }
291 if ids.len() > MAX_ENTITY_IDS_PER_WATCH {
292 return Err(format!(
293 "too many entity IDs (max {})",
294 MAX_ENTITY_IDS_PER_WATCH
295 ));
296 }
297 let user_total: usize = watches
299 .values()
300 .filter(|w| w.owner == username)
301 .map(|w| w.entity_ids.len())
302 .sum();
303 if user_total + ids.len() > MAX_TOTAL_WATCHED_IDS {
304 return Err(format!(
305 "user '{}' would exceed the maximum of {} total watched IDs",
306 username, MAX_TOTAL_WATCHED_IDS
307 ));
308 }
309 let watch_id = Uuid::new_v4().to_string();
310 let watch = Watch {
311 entity_ids: ids.into_iter().collect(),
312 last_version: graph_version,
313 owner: username.to_string(),
314 };
315 watches.insert(watch_id.clone(), watch);
316 Ok(watch_id)
317 }
318
319 pub fn poll(&self, watch_id: &str, username: &str, graph: &SharedGraph) -> Option<Vec<HDict>> {
324 let (entity_ids, last_version) = {
328 let mut watches = self.watches.write();
329 let watch = watches.get_mut(watch_id)?;
330 if watch.owner != username {
331 return None;
332 }
333
334 let current_version = graph.version();
335 if current_version == watch.last_version {
336 return Some(Vec::new());
338 }
339
340 let ids = watch.entity_ids.clone();
341 let last = watch.last_version;
342 watch.last_version = current_version;
343 (ids, last)
344 }; let changes = match graph.changes_since(last_version) {
348 Ok(c) => c,
349 Err(_gap) => {
350 return Some(entity_ids.iter().filter_map(|id| graph.get(id)).collect());
352 }
353 };
354 let changed_refs: HashSet<&str> = changes.iter().map(|d| d.ref_val.as_str()).collect();
355
356 Some(
357 entity_ids
358 .iter()
359 .filter(|id| changed_refs.contains(id.as_str()))
360 .filter_map(|id| graph.get(id))
361 .collect(),
362 )
363 }
364
365 pub fn unsubscribe(&self, watch_id: &str, username: &str) -> bool {
369 let mut watches = self.watches.write();
370 match watches.get(watch_id) {
371 Some(watch) if watch.owner == username => {
372 watches.remove(watch_id);
373 true
374 }
375 _ => false,
376 }
377 }
378
379 pub fn add_ids(&self, watch_id: &str, username: &str, ids: Vec<String>) -> bool {
384 let mut watches = self.watches.write();
385
386 let (owner_ok, per_watch_ok, user_total) = match watches.get(watch_id) {
388 Some(watch) => (
389 watch.owner == username,
390 watch.entity_ids.len() + ids.len() <= MAX_ENTITY_IDS_PER_WATCH,
391 watches
392 .values()
393 .filter(|w| w.owner == username)
394 .map(|w| w.entity_ids.len())
395 .sum::<usize>(),
396 ),
397 None => return false,
398 };
399
400 if !owner_ok || !per_watch_ok {
401 return false;
402 }
403 if user_total + ids.len() > MAX_TOTAL_WATCHED_IDS {
404 return false;
405 }
406
407 if let Some(watch) = watches.get_mut(watch_id) {
408 watch.entity_ids.extend(ids);
409 }
410 true
411 }
412
413 pub fn remove_ids(&self, watch_id: &str, username: &str, ids: &[String]) -> bool {
418 let mut watches = self.watches.write();
419 if let Some(watch) = watches.get_mut(watch_id) {
420 if watch.owner != username {
421 return false;
422 }
423 for id in ids {
424 watch.entity_ids.remove(id);
425 }
426 true
427 } else {
428 false
429 }
430 }
431
432 pub fn get_ids(&self, watch_id: &str) -> Option<Vec<String>> {
436 let watches = self.watches.read();
437 watches
438 .get(watch_id)
439 .map(|w| w.entity_ids.iter().cloned().collect())
440 }
441
442 pub fn len(&self) -> usize {
444 self.watches.read().len()
445 }
446
447 pub fn is_empty(&self) -> bool {
449 self.watches.read().is_empty()
450 }
451
452 pub fn encode_cached(&self, ref_val: &str, graph_version: u64, entity: &HDict) -> Value {
455 {
457 let mut cv = self.cache_version.write();
458 if graph_version > *cv {
459 self.encode_cache.write().clear();
460 *cv = graph_version;
461 }
462 }
463
464 let key = (ref_val.to_string(), graph_version);
465 if let Some(cached) = self.encode_cache.read().get(&key) {
466 return cached.clone();
467 }
468
469 let encoded = encode_entity(entity);
470 let mut cache = self.encode_cache.write();
471 cache.insert(key, encoded.clone());
472 if cache.len() > MAX_ENCODE_CACHE_ENTRIES {
473 let to_remove = cache.len() / 4;
475 let keys: Vec<_> = cache.keys().take(to_remove).cloned().collect();
476 for k in keys {
477 cache.remove(&k);
478 }
479 }
480 encoded
481 }
482
483 pub fn all_watched_ids(&self) -> HashSet<String> {
486 let watches = self.watches.read();
487 watches
488 .values()
489 .flat_map(|w| w.entity_ids.iter().cloned())
490 .collect()
491 }
492
493 pub fn watches_affected_by(
496 &self,
497 changed_refs: &HashSet<&str>,
498 ) -> Vec<(String, String, Vec<String>)> {
499 let watches = self.watches.read();
500 let mut affected = Vec::new();
501 for (wid, watch) in watches.iter() {
502 let matched: Vec<String> = watch
503 .entity_ids
504 .iter()
505 .filter(|id| changed_refs.contains(id.as_str()))
506 .cloned()
507 .collect();
508 if !matched.is_empty() {
509 affected.push((wid.clone(), watch.owner.clone(), matched));
510 }
511 }
512 affected
513 }
514}
515
516impl Default for WatchManager {
517 fn default() -> Self {
518 Self::new()
519 }
520}
521
522fn maybe_compress_response(text: &str) -> WsPayload {
529 if text.len() < COMPRESSION_THRESHOLD {
530 return WsPayload::Text(text.to_string());
531 }
532 let compressed = compress_deflate(text.as_bytes());
533 if compressed.len() < text.len() {
534 WsPayload::Binary(compressed)
535 } else {
536 WsPayload::Text(text.to_string())
537 }
538}
539
540fn compress_deflate(data: &[u8]) -> Vec<u8> {
541 use flate2::Compression;
542 use flate2::write::DeflateEncoder;
543 use std::io::Write;
544
545 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
546 let _ = encoder.write_all(data);
547 encoder.finish().unwrap_or_else(|_| data.to_vec())
548}
549
550const MAX_DECOMPRESSED_SIZE: u64 = 10 * 1024 * 1024;
552
553fn decompress_deflate(data: &[u8]) -> Result<String, std::io::Error> {
554 use flate2::read::DeflateDecoder;
555 use std::io::Read;
556
557 let decoder = DeflateDecoder::new(data);
558 let mut limited = decoder.take(MAX_DECOMPRESSED_SIZE);
559 let mut output = String::new();
560 limited.read_to_string(&mut output)?;
561 Ok(output)
562}
563
/// An outgoing frame queued for a connection's writer task.
enum WsPayload {
    /// JSON text, written with `Session::text`.
    Text(String),
    /// Deflate-compressed JSON, written with `Session::binary`.
    Binary(Vec<u8>),
}
568
569pub async fn ws_handler(
586 req: HttpRequest,
587 stream: web::Payload,
588 state: web::Data<AppState>,
589) -> actix_web::Result<HttpResponse> {
590 let username = if state.auth.is_enabled() {
592 match req.extensions().get::<crate::auth::AuthUser>() {
593 Some(u) => u.username.clone(),
594 None => {
595 return Err(crate::error::HaystackError::new(
596 "authentication required for WebSocket connections",
597 actix_web::http::StatusCode::UNAUTHORIZED,
598 )
599 .into());
600 }
601 }
602 } else {
603 req.extensions()
604 .get::<crate::auth::AuthUser>()
605 .map(|u| u.username.clone())
606 .unwrap_or_else(|| "anonymous".to_string())
607 };
608
609 let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
610
611 actix_rt::spawn(async move {
612 use actix_ws::Message;
613 use tokio::sync::mpsc;
614
615 let (tx, mut rx) = mpsc::channel::<WsPayload>(CHANNEL_CAPACITY);
616
617 let mut session_clone = session.clone();
619 actix_rt::spawn(async move {
620 while let Some(payload) = rx.recv().await {
621 let result = match payload {
622 WsPayload::Text(text) => session_clone.text(text).await,
623 WsPayload::Binary(data) => session_clone.binary(data).await,
624 };
625 if result.is_err() {
626 break;
627 }
628 }
629 });
630
631 let mut last_activity = Instant::now();
633 let mut ping_interval = tokio::time::interval(PING_INTERVAL);
634 ping_interval.tick().await; let mut awaiting_pong = false;
636 let mut send_failures: u32 = 0;
637
638 let mut last_push_version = state.graph.version();
640
641 let mut push_interval = tokio::time::interval(Duration::from_millis(500));
643 push_interval.tick().await;
644
645 use futures_util::StreamExt as _;
646
647 loop {
648 tokio::select! {
649 msg = msg_stream.next() => {
651 let Some(Ok(msg)) = msg else { break };
652 last_activity = Instant::now();
653 awaiting_pong = false;
654
655 match msg {
656 Message::Text(text) => {
657 let response_text = match serde_json::from_str::<WsRequest>(&text) {
658 Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
659 Err(e) => {
660 let err = WsResponse::error(None, format!("invalid request: {e}"));
661 serde_json::to_string(&err).unwrap()
662 }
663 };
664 let payload = maybe_compress_response(&response_text);
665 if tx.try_send(payload).is_err() {
666 send_failures += 1;
667 if send_failures >= MAX_SEND_FAILURES {
668 log::warn!("closing slow WS client ({})", username);
669 break;
670 }
671 } else {
672 send_failures = 0;
673 }
674 }
675 Message::Binary(data) => {
676 if let Ok(text) = decompress_deflate(&data) {
678 let response_text = match serde_json::from_str::<WsRequest>(&text) {
679 Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
680 Err(e) => {
681 let err = WsResponse::error(None, format!("invalid request: {e}"));
682 serde_json::to_string(&err).unwrap()
683 }
684 };
685 let payload = maybe_compress_response(&response_text);
686 let _ = tx.try_send(payload);
687 }
688 }
689 Message::Ping(bytes) => {
690 let _ = session.pong(&bytes).await;
691 }
692 Message::Pong(_) => {
693 awaiting_pong = false;
694 }
695 Message::Close(_) => {
696 break;
697 }
698 _ => {}
699 }
700 }
701
702 _ = ping_interval.tick() => {
704 if awaiting_pong && last_activity.elapsed() > PONG_TIMEOUT {
705 log::info!("closing stale WS connection ({}): no pong", username);
706 break;
707 }
708 let _ = session.ping(b"").await;
709 awaiting_pong = true;
710 }
711
712 _ = push_interval.tick() => {
714 let current_version = state.graph.version();
715 if current_version > last_push_version {
716 let changes = match state.graph.changes_since(last_push_version) {
717 Ok(c) => c,
718 Err(_gap) => {
719 last_push_version = current_version;
721 continue;
722 }
723 };
724 let changed_refs: HashSet<&str> =
725 changes.iter().map(|d| d.ref_val.as_str()).collect();
726
727 let affected = state.watches.watches_affected_by(&changed_refs);
728 for (watch_id, owner, changed_ids) in &affected {
729 if owner != &username {
730 continue;
731 }
732 let rows: Vec<Value> = changed_ids
733 .iter()
734 .filter_map(|id| {
735 let entity = state.graph.get(id)?;
736 Some(state.watches.encode_cached(id, current_version, &entity))
737 })
738 .collect();
739 if !rows.is_empty() {
740 let push_msg = serde_json::json!({
741 "type": "push",
742 "watchId": watch_id,
743 "rows": rows,
744 });
745 if let Ok(text) = serde_json::to_string(&push_msg) {
746 let payload = maybe_compress_response(&text);
747 let _ = tx.try_send(payload);
748 }
749 }
750 }
751 last_push_version = current_version;
752 }
753 }
754 }
755 }
756
757 let _ = session.close(None).await;
758 });
759
760 Ok(response)
761}
762
#[cfg(test)]
mod tests {
    use super::*;
    use haystack_core::graph::{EntityGraph, SharedGraph};
    use haystack_core::kinds::{HRef, Kind};

    /// Builds a graph holding one `site` entity whose id is `id`.
    fn make_graph_with_entity(id: &str) -> SharedGraph {
        let mut site = HDict::new();
        site.set("id", Kind::Ref(HRef::from_val(id)));
        site.set("site", Kind::Marker);
        site.set("dis", Kind::Str(format!("Site {id}")));
        let graph = SharedGraph::new(EntityGraph::new());
        graph.add(site).unwrap();
        graph
    }

    /// Turns string literals into the owned ids the watch API takes.
    fn ids(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Whether `list` contains `id`.
    fn has(list: &[String], id: &str) -> bool {
        list.iter().any(|item| item == id)
    }

    #[test]
    fn subscribe_returns_watch_id() {
        let manager = WatchManager::new();
        let wid = manager.subscribe("admin", ids(&["site-1"]), 0).unwrap();
        assert!(!wid.is_empty());
    }

    #[test]
    fn poll_no_changes() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", ids(&["site-1"]), graph.version())
            .unwrap();

        let changed = manager.poll(&wid, "admin", &graph).unwrap();
        assert!(changed.is_empty());
    }

    #[test]
    fn poll_with_changes() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", ids(&["site-1"]), graph.version())
            .unwrap();

        let mut patch = HDict::new();
        patch.set("dis", Kind::Str("Updated".into()));
        graph.update("site-1", patch).unwrap();

        let changed = manager.poll(&wid, "admin", &graph).unwrap();
        assert_eq!(changed.len(), 1);
    }

    #[test]
    fn poll_unknown_watch() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        assert!(manager.poll("unknown", "admin", &graph).is_none());
    }

    #[test]
    fn poll_wrong_owner() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", ids(&["site-1"]), graph.version())
            .unwrap();

        // Another user must not be able to observe the watch.
        assert!(manager.poll(&wid, "other-user", &graph).is_none());
    }

    #[test]
    fn unsubscribe_removes_watch() {
        let manager = WatchManager::new();
        let wid = manager.subscribe("admin", ids(&["site-1"]), 0).unwrap();
        assert!(manager.unsubscribe(&wid, "admin"));
        // Already gone: a second attempt reports failure.
        assert!(!manager.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn unsubscribe_wrong_owner() {
        let manager = WatchManager::new();
        let wid = manager.subscribe("admin", ids(&["site-1"]), 0).unwrap();
        assert!(!manager.unsubscribe(&wid, "other-user"));
        // The rejected attempt left the watch in place for its owner.
        assert!(manager.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn remove_ids_selective() {
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", ids(&["site-1", "site-2", "site-3"]), 0)
            .unwrap();

        assert!(manager.remove_ids(&wid, "admin", &ids(&["site-2"])));

        let remaining = manager.get_ids(&wid).unwrap();
        assert_eq!(remaining.len(), 2);
        assert!(has(&remaining, "site-1"));
        assert!(has(&remaining, "site-3"));
        assert!(!has(&remaining, "site-2"));
    }

    #[test]
    fn remove_ids_nonexistent_watch() {
        let manager = WatchManager::new();
        assert!(!manager.remove_ids("no-such-watch", "admin", &ids(&["site-1"])));
    }

    #[test]
    fn remove_ids_wrong_owner() {
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", ids(&["site-1", "site-2"]), 0)
            .unwrap();

        assert!(!manager.remove_ids(&wid, "other-user", &ids(&["site-1"])));

        // Nothing was removed by the rejected call.
        assert_eq!(manager.get_ids(&wid).unwrap().len(), 2);
    }

    #[test]
    fn remove_ids_leaves_watch_alive() {
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", ids(&["site-1", "site-2"]), 0)
            .unwrap();

        assert!(manager.remove_ids(&wid, "admin", &ids(&["site-1", "site-2"])));
        assert!(manager.get_ids(&wid).unwrap().is_empty());

        // An emptied watch still exists until explicitly unsubscribed.
        assert!(manager.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn unsubscribe_full_removal() {
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", ids(&["site-1", "site-2"]), 0)
            .unwrap();

        assert!(manager.unsubscribe(&wid, "admin"));
        assert!(manager.get_ids(&wid).is_none());
        assert!(!manager.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn add_ids_ownership_check() {
        let manager = WatchManager::new();
        let wid = manager.subscribe("admin", ids(&["site-1"]), 0).unwrap();

        assert!(!manager.add_ids(&wid, "other-user", ids(&["site-2"])));
        assert!(manager.add_ids(&wid, "admin", ids(&["site-2"])));

        let current = manager.get_ids(&wid).unwrap();
        assert_eq!(current.len(), 2);
        assert!(has(&current, "site-1"));
        assert!(has(&current, "site-2"));
    }

    #[test]
    fn get_ids_returns_none_for_unknown_watch() {
        assert!(WatchManager::new().get_ids("nonexistent").is_none());
    }

    #[test]
    fn ws_request_deserialization() {
        let raw = r#"{
            "op": "watchSub",
            "reqId": "abc-123",
            "watchDis": "my-watch",
            "ids": ["@ref1", "@ref2"]
        }"#;
        let req: WsRequest = serde_json::from_str(raw).unwrap();
        assert_eq!(req.op, "watchSub");
        assert_eq!(req.req_id.as_deref(), Some("abc-123"));
        assert_eq!(req.watch_dis.as_deref(), Some("my-watch"));
        assert!(req.watch_id.is_none());
        assert_eq!(req.ids.unwrap(), vec!["@ref1", "@ref2"]);
    }

    #[test]
    fn ws_request_deserialization_minimal() {
        let raw = r#"{"op": "watchPoll", "watchId": "w-1"}"#;
        let req: WsRequest = serde_json::from_str(raw).unwrap();
        assert_eq!(req.op, "watchPoll");
        assert!(req.req_id.is_none());
        assert!(req.watch_dis.is_none());
        assert_eq!(req.watch_id.as_deref(), Some("w-1"));
        assert!(req.ids.is_none());
    }

    #[test]
    fn ws_response_serialization() {
        let rows = vec![serde_json::json!({"id": "r:site-1"})];
        let resp = WsResponse::ok(Some("r-1".into()), rows, Some("w-1".into()));
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "r-1");
        assert_eq!(json["watchId"], "w-1");
        assert!(json["rows"].is_array());
        assert_eq!(json["rows"][0]["id"], "r:site-1");
        assert!(json.get("error").is_none());
    }

    #[test]
    fn ws_response_omits_none_fields() {
        let json = serde_json::to_value(WsResponse::ok(None, vec![], None)).unwrap();
        assert!(json.get("reqId").is_none());
        assert!(json.get("error").is_none());
        assert!(json.get("watchId").is_none());
        assert!(json["rows"].is_array());
    }

    #[test]
    fn ws_response_includes_req_id() {
        let resp = WsResponse::error(Some("req-42".into()), "something went wrong");
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["reqId"], "req-42");
        assert_eq!(json["error"], "something went wrong");
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }

    #[test]
    fn ws_error_response_format() {
        let json = serde_json::to_value(WsResponse::error(None, "bad request")).unwrap();
        assert_eq!(json["error"], "bad request");
        assert!(json.get("reqId").is_none());
        assert!(json.get("rows").is_none());
        assert!(json.get("watchId").is_none());
    }
}