Skip to main content

haystack_server/
ws.rs

1//! WebSocket handler and watch subscription manager.
2//!
3//! This module provides two major components:
4//!
5//! 1. **`WatchManager`** — a thread-safe subscription registry that manages
6//!    watch lifecycles (subscribe, poll, unsubscribe, add/remove IDs). Each
7//!    watch tracks a set of entity IDs and the graph version at last poll,
8//!    enabling efficient change detection.
9//!
10//! 2. **`ws_handler`** — an Actix-Web WebSocket upgrade endpoint (`GET /api/ws`)
11//!    that handles Haystack watch operations over JSON messages. Supports
12//!    server-initiated ping/pong liveness, deflate compression for large
13//!    payloads, and automatic server-push of graph changes to watching clients.
14
15use std::collections::{HashMap, HashSet};
16use std::time::{Duration, Instant};
17
18use actix_web::{HttpMessage, HttpRequest, HttpResponse, web};
19use parking_lot::RwLock;
20use serde_json::{Map, Value};
21use uuid::Uuid;
22
23use haystack_core::codecs::json::v3 as json_v3;
24use haystack_core::data::HDict;
25use haystack_core::graph::SharedGraph;
26
27use crate::state::AppState;
28
29// ---------------------------------------------------------------------------
30// Tuning constants
31// ---------------------------------------------------------------------------
32
// Watch registry limits.

/// Upper bound on concurrently registered watches, summed over all users.
const MAX_WATCHES: usize = 100;
/// Upper bound on entity IDs a single watch may track.
const MAX_ENTITY_IDS_PER_WATCH: usize = 1_000;
/// Upper bound on entity IDs one user may track, summed over all of their watches.
const MAX_TOTAL_WATCHED_IDS: usize = 5_000;
/// Upper bound on watches one user may own at the same time.
const MAX_WATCHES_PER_USER: usize = 20;

/// Size cap for the shared entity-encoding cache; exceeding it triggers eviction.
const MAX_ENCODE_CACHE_ENTRIES: usize = 50_000;

/// How often the server pings each client to check that the connection is alive.
const PING_INTERVAL: Duration = Duration::from_secs(30);
/// How long a ping may go unanswered before the connection is treated as dead
/// and closed.
const PONG_TIMEOUT: Duration = Duration::from_secs(10);

/// Capacity of each connection's bounded outbound message channel.
const CHANNEL_CAPACITY: usize = 64;
/// Consecutive failed non-blocking sends tolerated before a slow client is
/// disconnected.
const MAX_SEND_FAILURES: u32 = 3;

/// Responses shorter than this many bytes are always sent uncompressed.
const COMPRESSION_THRESHOLD: usize = 512;
58
59// ---------------------------------------------------------------------------
60// WebSocket message types
61// ---------------------------------------------------------------------------
62
63/// Incoming JSON message from a WebSocket client.
64#[derive(serde::Deserialize, Debug)]
65struct WsRequest {
66    op: String,
67    #[serde(rename = "reqId")]
68    req_id: Option<String>,
69    #[serde(rename = "watchDis")]
70    #[allow(dead_code)]
71    watch_dis: Option<String>,
72    #[serde(rename = "watchId")]
73    watch_id: Option<String>,
74    ids: Option<Vec<String>>,
75}
76
77/// Outgoing JSON message sent to a WebSocket client.
78#[derive(serde::Serialize, Debug)]
79struct WsResponse {
80    #[serde(rename = "reqId", skip_serializing_if = "Option::is_none")]
81    req_id: Option<String>,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    error: Option<String>,
84    #[serde(skip_serializing_if = "Option::is_none")]
85    rows: Option<Vec<Value>>,
86    #[serde(rename = "watchId", skip_serializing_if = "Option::is_none")]
87    watch_id: Option<String>,
88}
89
90impl WsResponse {
91    /// Build an error response, preserving the request ID for correlation.
92    fn error(req_id: Option<String>, msg: impl Into<String>) -> Self {
93        Self {
94            req_id,
95            error: Some(msg.into()),
96            rows: None,
97            watch_id: None,
98        }
99    }
100
101    /// Build a success response with rows and an optional watch ID.
102    fn ok(req_id: Option<String>, rows: Vec<Value>, watch_id: Option<String>) -> Self {
103        Self {
104            req_id,
105            error: None,
106            rows: Some(rows),
107            watch_id,
108        }
109    }
110}
111
112// ---------------------------------------------------------------------------
113// Entity encoding helper
114// ---------------------------------------------------------------------------
115
116/// Encode an `HDict` entity as a JSON object using the Haystack JSON v3
117/// encoding for individual tag values.
118fn encode_entity(entity: &HDict) -> Value {
119    let mut m = Map::new();
120    let mut keys: Vec<&String> = entity.tags().keys().collect();
121    keys.sort();
122    for k in keys {
123        let v = &entity.tags()[k];
124        if let Ok(encoded) = json_v3::encode_kind(v) {
125            m.insert(k.clone(), encoded);
126        }
127    }
128    Value::Object(m)
129}
130
131// ---------------------------------------------------------------------------
132// WebSocket op dispatch
133// ---------------------------------------------------------------------------
134
135/// Handle a parsed `WsRequest` by dispatching to the appropriate watch op.
136///
137/// Returns the serialized JSON response string.
138fn handle_ws_request(req: &WsRequest, username: &str, state: &AppState) -> String {
139    let resp = match req.op.as_str() {
140        "watchSub" => handle_watch_sub(req, username, state),
141        "watchPoll" => handle_watch_poll(req, username, state),
142        "watchUnsub" => handle_watch_unsub(req, username, state),
143        other => WsResponse::error(req.req_id.clone(), format!("unknown op: {other}")),
144    };
145    // Serialization of WsResponse should never fail in practice.
146    serde_json::to_string(&resp).unwrap_or_else(|e| {
147        let fallback = WsResponse::error(req.req_id.clone(), format!("serialization error: {e}"));
148        serde_json::to_string(&fallback).unwrap()
149    })
150}
151
152/// Handle the `watchSub` op: create a new watch and return the initial
153/// state of the subscribed entities.
154fn handle_watch_sub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
155    let ids = match &req.ids {
156        Some(ids) if !ids.is_empty() => ids.clone(),
157        _ => {
158            return WsResponse::error(
159                req.req_id.clone(),
160                "watchSub requires non-empty 'ids' array",
161            );
162        }
163    };
164
165    // Strip leading '@' from ref strings if present.
166    let ids: Vec<String> = ids
167        .into_iter()
168        .map(|id| id.strip_prefix('@').unwrap_or(&id).to_string())
169        .collect();
170
171    let graph_version = state.graph.version();
172    let watch_id = match state
173        .watches
174        .subscribe(username, ids.clone(), graph_version)
175    {
176        Ok(wid) => wid,
177        Err(e) => return WsResponse::error(req.req_id.clone(), e),
178    };
179
180    // Resolve initial entity values.
181    let rows: Vec<Value> = ids
182        .iter()
183        .filter_map(|id| state.graph.get(id).map(|e| encode_entity(&e)))
184        .collect();
185
186    WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
187}
188
189/// Handle the `watchPoll` op: poll an existing watch for changes.
190fn handle_watch_poll(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
191    let watch_id = match &req.watch_id {
192        Some(wid) => wid.clone(),
193        None => {
194            return WsResponse::error(req.req_id.clone(), "watchPoll requires 'watchId'");
195        }
196    };
197
198    match state.watches.poll(&watch_id, username, &state.graph) {
199        Some(changed) => {
200            let rows: Vec<Value> = changed.iter().map(encode_entity).collect();
201            WsResponse::ok(req.req_id.clone(), rows, Some(watch_id))
202        }
203        None => WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}")),
204    }
205}
206
207/// Handle the `watchUnsub` op: remove a watch or specific IDs from it.
208fn handle_watch_unsub(req: &WsRequest, username: &str, state: &AppState) -> WsResponse {
209    let watch_id = match &req.watch_id {
210        Some(wid) => wid.clone(),
211        None => {
212            return WsResponse::error(req.req_id.clone(), "watchUnsub requires 'watchId'");
213        }
214    };
215
216    // If specific IDs are provided, remove only those; otherwise remove the
217    // entire watch.
218    if let Some(ids) = &req.ids
219        && !ids.is_empty()
220    {
221        let clean: Vec<String> = ids
222            .iter()
223            .map(|id| id.strip_prefix('@').unwrap_or(id).to_string())
224            .collect();
225        if !state.watches.remove_ids(&watch_id, username, &clean) {
226            return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
227        }
228        return WsResponse::ok(req.req_id.clone(), vec![], Some(watch_id));
229    }
230
231    if !state.watches.unsubscribe(&watch_id, username) {
232        return WsResponse::error(req.req_id.clone(), format!("watch not found: {watch_id}"));
233    }
234    WsResponse::ok(req.req_id.clone(), vec![], None)
235}
236
/// State for one watch subscription.
struct Watch {
    /// User who created the watch; only this user may poll or modify it.
    owner: String,
    /// Watched entity IDs, kept in a set for constant-time membership checks.
    entity_ids: HashSet<String>,
    /// Graph version observed at the most recent poll (or at subscription).
    last_version: u64,
}
246
247/// Manages watch subscriptions for change polling.
248///
249/// Watches are keyed by a UUID watch ID and owned by a specific user.
250/// The manager enforces global and per-user watch limits, per-watch entity
251/// ID limits, and provides an entity encoding cache for efficient
252/// WebSocket server-push serialization.
253pub struct WatchManager {
254    watches: RwLock<HashMap<String, Watch>>,
255    /// Cached entity encodings keyed by (ref_val, version) for watch poll.
256    encode_cache: RwLock<HashMap<(String, u64), Value>>,
257    /// Graph version at which the encode cache was last validated.
258    cache_version: RwLock<u64>,
259}
260
261impl WatchManager {
262    /// Create a new empty WatchManager.
263    pub fn new() -> Self {
264        Self {
265            watches: RwLock::new(HashMap::new()),
266            encode_cache: RwLock::new(HashMap::new()),
267            cache_version: RwLock::new(0),
268        }
269    }
270
271    /// Subscribe to changes on a set of entity IDs.
272    ///
273    /// Returns the watch ID, or an error if a growth cap would be exceeded.
274    pub fn subscribe(
275        &self,
276        username: &str,
277        ids: Vec<String>,
278        graph_version: u64,
279    ) -> Result<String, String> {
280        let mut watches = self.watches.write();
281        if watches.len() >= MAX_WATCHES {
282            return Err("maximum number of watches reached".to_string());
283        }
284        let user_count = watches.values().filter(|w| w.owner == username).count();
285        if user_count >= MAX_WATCHES_PER_USER {
286            return Err(format!(
287                "user '{}' has reached the maximum of {} watches",
288                username, MAX_WATCHES_PER_USER
289            ));
290        }
291        if ids.len() > MAX_ENTITY_IDS_PER_WATCH {
292            return Err(format!(
293                "too many entity IDs (max {})",
294                MAX_ENTITY_IDS_PER_WATCH
295            ));
296        }
297        // Check total watched IDs across all watches for this user.
298        let user_total: usize = watches
299            .values()
300            .filter(|w| w.owner == username)
301            .map(|w| w.entity_ids.len())
302            .sum();
303        if user_total + ids.len() > MAX_TOTAL_WATCHED_IDS {
304            return Err(format!(
305                "user '{}' would exceed the maximum of {} total watched IDs",
306                username, MAX_TOTAL_WATCHED_IDS
307            ));
308        }
309        let watch_id = Uuid::new_v4().to_string();
310        let watch = Watch {
311            entity_ids: ids.into_iter().collect(),
312            last_version: graph_version,
313            owner: username.to_string(),
314        };
315        watches.insert(watch_id.clone(), watch);
316        Ok(watch_id)
317    }
318
319    /// Poll for changes since the last poll.
320    ///
321    /// Returns the current state of watched entities that have changed,
322    /// or `None` if the watch ID is not found or the caller is not the owner.
323    pub fn poll(&self, watch_id: &str, username: &str, graph: &SharedGraph) -> Option<Vec<HDict>> {
324        // Acquire the write lock only long enough to read watch state and
325        // update last_version.  Graph reads happen outside the lock to
326        // avoid holding it during potentially expensive I/O.
327        let (entity_ids, last_version) = {
328            let mut watches = self.watches.write();
329            let watch = watches.get_mut(watch_id)?;
330            if watch.owner != username {
331                return None;
332            }
333
334            let current_version = graph.version();
335            if current_version == watch.last_version {
336                // No changes
337                return Some(Vec::new());
338            }
339
340            let ids = watch.entity_ids.clone();
341            let last = watch.last_version;
342            watch.last_version = current_version;
343            (ids, last)
344        }; // write lock released here
345
346        // Graph reads happen without the WatchManager write lock held.
347        let changes = match graph.changes_since(last_version) {
348            Ok(c) => c,
349            Err(_gap) => {
350                // Subscriber fell behind — treat all watched entities as changed.
351                return Some(entity_ids.iter().filter_map(|id| graph.get(id)).collect());
352            }
353        };
354        let changed_refs: HashSet<&str> = changes.iter().map(|d| d.ref_val.as_str()).collect();
355
356        Some(
357            entity_ids
358                .iter()
359                .filter(|id| changed_refs.contains(id.as_str()))
360                .filter_map(|id| graph.get(id))
361                .collect(),
362        )
363    }
364
365    /// Unsubscribe a watch by ID.
366    ///
367    /// Returns `true` if the watch existed, was owned by `username`, and was removed.
368    pub fn unsubscribe(&self, watch_id: &str, username: &str) -> bool {
369        let mut watches = self.watches.write();
370        match watches.get(watch_id) {
371            Some(watch) if watch.owner == username => {
372                watches.remove(watch_id);
373                true
374            }
375            _ => false,
376        }
377    }
378
379    /// Add entity IDs to an existing watch.
380    ///
381    /// Returns `true` if the watch exists, is owned by `username`, and
382    /// the addition would not exceed the per-watch entity ID limit.
383    pub fn add_ids(&self, watch_id: &str, username: &str, ids: Vec<String>) -> bool {
384        let mut watches = self.watches.write();
385
386        // Pre-check ownership and compute total before taking a mutable ref.
387        let (owner_ok, per_watch_ok, user_total) = match watches.get(watch_id) {
388            Some(watch) => (
389                watch.owner == username,
390                watch.entity_ids.len() + ids.len() <= MAX_ENTITY_IDS_PER_WATCH,
391                watches
392                    .values()
393                    .filter(|w| w.owner == username)
394                    .map(|w| w.entity_ids.len())
395                    .sum::<usize>(),
396            ),
397            None => return false,
398        };
399
400        if !owner_ok || !per_watch_ok {
401            return false;
402        }
403        if user_total + ids.len() > MAX_TOTAL_WATCHED_IDS {
404            return false;
405        }
406
407        if let Some(watch) = watches.get_mut(watch_id) {
408            watch.entity_ids.extend(ids);
409        }
410        true
411    }
412
413    /// Remove specific entity IDs from an existing watch.
414    ///
415    /// Returns `true` if the watch exists and is owned by `username`.
416    /// If all IDs are removed, the watch remains but with an empty entity set.
417    pub fn remove_ids(&self, watch_id: &str, username: &str, ids: &[String]) -> bool {
418        let mut watches = self.watches.write();
419        if let Some(watch) = watches.get_mut(watch_id) {
420            if watch.owner != username {
421                return false;
422            }
423            for id in ids {
424                watch.entity_ids.remove(id);
425            }
426            true
427        } else {
428            false
429        }
430    }
431
432    /// Return the list of entity IDs for a given watch.
433    ///
434    /// Returns `None` if the watch does not exist.
435    pub fn get_ids(&self, watch_id: &str) -> Option<Vec<String>> {
436        let watches = self.watches.read();
437        watches
438            .get(watch_id)
439            .map(|w| w.entity_ids.iter().cloned().collect())
440    }
441
442    /// Return the number of active watches.
443    pub fn len(&self) -> usize {
444        self.watches.read().len()
445    }
446
447    /// Return whether there are no active watches.
448    pub fn is_empty(&self) -> bool {
449        self.watches.read().is_empty()
450    }
451
452    /// Encode an entity using the cache. Returns cached value if the entity
453    /// version hasn't changed; otherwise encodes and caches the result.
454    pub fn encode_cached(&self, ref_val: &str, graph_version: u64, entity: &HDict) -> Value {
455        // Invalidate entire cache when graph version advances beyond what we've seen.
456        {
457            let mut cv = self.cache_version.write();
458            if graph_version > *cv {
459                self.encode_cache.write().clear();
460                *cv = graph_version;
461            }
462        }
463
464        let key = (ref_val.to_string(), graph_version);
465        if let Some(cached) = self.encode_cache.read().get(&key) {
466            return cached.clone();
467        }
468
469        let encoded = encode_entity(entity);
470        let mut cache = self.encode_cache.write();
471        cache.insert(key, encoded.clone());
472        if cache.len() > MAX_ENCODE_CACHE_ENTRIES {
473            // Evict oldest quarter of entries
474            let to_remove = cache.len() / 4;
475            let keys: Vec<_> = cache.keys().take(to_remove).cloned().collect();
476            for k in keys {
477                cache.remove(&k);
478            }
479        }
480        encoded
481    }
482
483    /// Get the IDs of all entities watched by any watch, for server-push
484    /// change detection.
485    pub fn all_watched_ids(&self) -> HashSet<String> {
486        let watches = self.watches.read();
487        watches
488            .values()
489            .flat_map(|w| w.entity_ids.iter().cloned())
490            .collect()
491    }
492
493    /// Find watches that contain any of the given changed ref_vals.
494    /// Returns (watch_id, owner, changed_entity_ids) tuples.
495    pub fn watches_affected_by(
496        &self,
497        changed_refs: &HashSet<&str>,
498    ) -> Vec<(String, String, Vec<String>)> {
499        let watches = self.watches.read();
500        let mut affected = Vec::new();
501        for (wid, watch) in watches.iter() {
502            let matched: Vec<String> = watch
503                .entity_ids
504                .iter()
505                .filter(|id| changed_refs.contains(id.as_str()))
506                .cloned()
507                .collect();
508            if !matched.is_empty() {
509                affected.push((wid.clone(), watch.owner.clone(), matched));
510            }
511        }
512        affected
513    }
514}
515
516impl Default for WatchManager {
517    fn default() -> Self {
518        Self::new()
519    }
520}
521
522// ---------------------------------------------------------------------------
523// Compression helpers (application-level deflate)
524// ---------------------------------------------------------------------------
525
526/// Compress a response string with deflate if it exceeds the threshold.
527/// Returns the original text if compression doesn't save space.
528fn maybe_compress_response(text: &str) -> WsPayload {
529    if text.len() < COMPRESSION_THRESHOLD {
530        return WsPayload::Text(text.to_string());
531    }
532    let compressed = compress_deflate(text.as_bytes());
533    if compressed.len() < text.len() {
534        WsPayload::Binary(compressed)
535    } else {
536        WsPayload::Text(text.to_string())
537    }
538}
539
540fn compress_deflate(data: &[u8]) -> Vec<u8> {
541    use flate2::Compression;
542    use flate2::write::DeflateEncoder;
543    use std::io::Write;
544
545    let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast());
546    let _ = encoder.write_all(data);
547    encoder.finish().unwrap_or_else(|_| data.to_vec())
548}
549
550/// Maximum decompressed payload size (10 MB) to prevent zip bomb attacks.
551const MAX_DECOMPRESSED_SIZE: u64 = 10 * 1024 * 1024;
552
553fn decompress_deflate(data: &[u8]) -> Result<String, std::io::Error> {
554    use flate2::read::DeflateDecoder;
555    use std::io::Read;
556
557    let decoder = DeflateDecoder::new(data);
558    let mut limited = decoder.take(MAX_DECOMPRESSED_SIZE);
559    let mut output = String::new();
560    limited.read_to_string(&mut output)?;
561    Ok(output)
562}
563
/// One outbound WebSocket frame queued for the connection's writer task.
enum WsPayload {
    /// Uncompressed JSON, sent as a text frame.
    Text(String),
    /// Deflate-compressed JSON, sent as a binary frame.
    Binary(Vec<u8>),
}
568
569// ---------------------------------------------------------------------------
570// WebSocket handler
571// ---------------------------------------------------------------------------
572
573/// WebSocket upgrade handler for `/api/ws`.
574///
575/// Upgrades the HTTP connection to a WebSocket and handles Haystack
576/// watch operations (watchSub, watchPoll, watchUnsub) over JSON
577/// messages.  Each client request may include a `reqId` field which
578/// is echoed back in the response for correlation.
579///
580/// Features:
581/// - Server-initiated ping every [`PING_INTERVAL`] for liveness detection
582/// - Backpressure: slow clients are disconnected after [`MAX_SEND_FAILURES`]
583/// - Deflate compression for large responses (binary frames)
584/// - Server-push: graph changes are pushed to watching clients automatically
585pub async fn ws_handler(
586    req: HttpRequest,
587    stream: web::Payload,
588    state: web::Data<AppState>,
589) -> actix_web::Result<HttpResponse> {
590    // Require authenticated user when auth is enabled
591    let username = if state.auth.is_enabled() {
592        match req.extensions().get::<crate::auth::AuthUser>() {
593            Some(u) => u.username.clone(),
594            None => {
595                return Err(crate::error::HaystackError::new(
596                    "authentication required for WebSocket connections",
597                    actix_web::http::StatusCode::UNAUTHORIZED,
598                )
599                .into());
600            }
601        }
602    } else {
603        req.extensions()
604            .get::<crate::auth::AuthUser>()
605            .map(|u| u.username.clone())
606            .unwrap_or_else(|| "anonymous".to_string())
607    };
608
609    let (response, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
610
611    actix_rt::spawn(async move {
612        use actix_ws::Message;
613        use tokio::sync::mpsc;
614
615        let (tx, mut rx) = mpsc::channel::<WsPayload>(CHANNEL_CAPACITY);
616
617        // Spawn a task to forward messages from the channel to the WS session.
618        let mut session_clone = session.clone();
619        actix_rt::spawn(async move {
620            while let Some(payload) = rx.recv().await {
621                let result = match payload {
622                    WsPayload::Text(text) => session_clone.text(text).await,
623                    WsPayload::Binary(data) => session_clone.binary(data).await,
624                };
625                if result.is_err() {
626                    break;
627                }
628            }
629        });
630
631        // Track connection liveness.
632        let mut last_activity = Instant::now();
633        let mut ping_interval = tokio::time::interval(PING_INTERVAL);
634        ping_interval.tick().await; // consume the immediate first tick
635        let mut awaiting_pong = false;
636        let mut send_failures: u32 = 0;
637
638        // Track graph version for server-push change detection.
639        let mut last_push_version = state.graph.version();
640
641        // Server-push check interval (faster than ping but slower than a busy loop).
642        let mut push_interval = tokio::time::interval(Duration::from_millis(500));
643        push_interval.tick().await;
644
645        use futures_util::StreamExt as _;
646
647        loop {
648            tokio::select! {
649                // ── Incoming WS messages ──
650                msg = msg_stream.next() => {
651                    let Some(Ok(msg)) = msg else { break };
652                    last_activity = Instant::now();
653                    awaiting_pong = false;
654
655                    match msg {
656                        Message::Text(text) => {
657                            let response_text = match serde_json::from_str::<WsRequest>(&text) {
658                                Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
659                                Err(e) => {
660                                    let err = WsResponse::error(None, format!("invalid request: {e}"));
661                                    serde_json::to_string(&err).unwrap()
662                                }
663                            };
664                            let payload = maybe_compress_response(&response_text);
665                            if tx.try_send(payload).is_err() {
666                                send_failures += 1;
667                                if send_failures >= MAX_SEND_FAILURES {
668                                    log::warn!("closing slow WS client ({})", username);
669                                    break;
670                                }
671                            } else {
672                                send_failures = 0;
673                            }
674                        }
675                        Message::Binary(data) => {
676                            // Compressed request from client.
677                            if let Ok(text) = decompress_deflate(&data) {
678                                let response_text = match serde_json::from_str::<WsRequest>(&text) {
679                                    Ok(ws_req) => handle_ws_request(&ws_req, &username, &state),
680                                    Err(e) => {
681                                        let err = WsResponse::error(None, format!("invalid request: {e}"));
682                                        serde_json::to_string(&err).unwrap()
683                                    }
684                                };
685                                let payload = maybe_compress_response(&response_text);
686                                let _ = tx.try_send(payload);
687                            }
688                        }
689                        Message::Ping(bytes) => {
690                            let _ = session.pong(&bytes).await;
691                        }
692                        Message::Pong(_) => {
693                            awaiting_pong = false;
694                        }
695                        Message::Close(_) => {
696                            break;
697                        }
698                        _ => {}
699                    }
700                }
701
702                // ── Server-initiated ping for liveness ──
703                _ = ping_interval.tick() => {
704                    if awaiting_pong && last_activity.elapsed() > PONG_TIMEOUT {
705                        log::info!("closing stale WS connection ({}): no pong", username);
706                        break;
707                    }
708                    let _ = session.ping(b"").await;
709                    awaiting_pong = true;
710                }
711
712                // ── Server-push: check for graph changes ──
713                _ = push_interval.tick() => {
714                    let current_version = state.graph.version();
715                    if current_version > last_push_version {
716                        let changes = match state.graph.changes_since(last_push_version) {
717                            Ok(c) => c,
718                            Err(_gap) => {
719                                // Subscriber fell behind — skip to current version.
720                                last_push_version = current_version;
721                                continue;
722                            }
723                        };
724                        let changed_refs: HashSet<&str> =
725                            changes.iter().map(|d| d.ref_val.as_str()).collect();
726
727                        let affected = state.watches.watches_affected_by(&changed_refs);
728                        for (watch_id, owner, changed_ids) in &affected {
729                            if owner != &username {
730                                continue;
731                            }
732                            let rows: Vec<Value> = changed_ids
733                                .iter()
734                                .filter_map(|id| {
735                                    let entity = state.graph.get(id)?;
736                                    Some(state.watches.encode_cached(id, current_version, &entity))
737                                })
738                                .collect();
739                            if !rows.is_empty() {
740                                let push_msg = serde_json::json!({
741                                    "type": "push",
742                                    "watchId": watch_id,
743                                    "rows": rows,
744                                });
745                                if let Ok(text) = serde_json::to_string(&push_msg) {
746                                    let payload = maybe_compress_response(&text);
747                                    let _ = tx.try_send(payload);
748                                }
749                            }
750                        }
751                        last_push_version = current_version;
752                    }
753                }
754            }
755        }
756
757        let _ = session.close(None).await;
758    });
759
760    Ok(response)
761}
762
#[cfg(test)]
mod tests {
    // Unit tests covering two areas:
    //   * `WatchManager` lifecycle and per-user ownership checks
    //     (subscribe / poll / unsubscribe / add_ids / remove_ids / get_ids);
    //   * the JSON wire shape of `WsRequest` (deserialization) and
    //     `WsResponse` (serialization, including omitted `None` fields).
    use super::*;
    use haystack_core::graph::{EntityGraph, SharedGraph};
    use haystack_core::kinds::{HRef, Kind};

    /// Returns a new `SharedGraph` containing exactly one entity tagged
    /// `site`, with ref id `id` and display name `"Site {id}"`.
    fn make_graph_with_entity(id: &str) -> SharedGraph {
        let mut site = HDict::new();
        site.set("id", Kind::Ref(HRef::from_val(id)));
        site.set("site", Kind::Marker);
        site.set("dis", Kind::Str(format!("Site {id}")));

        let shared = SharedGraph::new(EntityGraph::new());
        shared.add(site).unwrap();
        shared
    }

    // -----------------------------------------------------------------------
    // WatchManager: subscribe / poll
    // -----------------------------------------------------------------------

    #[test]
    fn subscribe_returns_watch_id() {
        let manager = WatchManager::new();
        let wid = manager.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        assert!(!wid.is_empty());
    }

    #[test]
    fn poll_no_changes() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", vec!["site-1".into()], graph.version())
            .unwrap();

        // The watch starts at the graph's current version, so nothing is new yet.
        let delta = manager.poll(&wid, "admin", &graph).unwrap();
        assert!(delta.is_empty());
    }

    #[test]
    fn poll_with_changes() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", vec!["site-1".into()], graph.version())
            .unwrap();

        // Mutate the watched entity after subscribing so the poll reports it.
        let mut patch = HDict::new();
        patch.set("dis", Kind::Str("Updated".into()));
        graph.update("site-1", patch).unwrap();

        let delta = manager.poll(&wid, "admin", &graph).unwrap();
        assert_eq!(delta.len(), 1);
    }

    #[test]
    fn poll_unknown_watch() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        assert!(manager.poll("unknown", "admin", &graph).is_none());
    }

    #[test]
    fn poll_wrong_owner() {
        let graph = make_graph_with_entity("site-1");
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", vec!["site-1".into()], graph.version())
            .unwrap();

        // Only the user that created the watch may poll it.
        assert!(manager.poll(&wid, "other-user", &graph).is_none());
    }

    // -----------------------------------------------------------------------
    // WatchManager: unsubscribe
    // -----------------------------------------------------------------------

    #[test]
    fn unsubscribe_removes_watch() {
        let manager = WatchManager::new();
        let wid = manager.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        assert!(manager.unsubscribe(&wid, "admin"));
        // Nothing left to remove on a repeat call.
        assert!(!manager.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn unsubscribe_wrong_owner() {
        let manager = WatchManager::new();
        let wid = manager.subscribe("admin", vec!["site-1".into()], 0).unwrap();
        // A foreign user is refused, and the watch survives for its owner.
        assert!(!manager.unsubscribe(&wid, "other-user"));
        assert!(manager.unsubscribe(&wid, "admin"));
    }

    // -----------------------------------------------------------------------
    // WatchManager: remove_ids / add_ids / get_ids
    // -----------------------------------------------------------------------

    #[test]
    fn remove_ids_selective() {
        let manager = WatchManager::new();
        let wid = manager
            .subscribe(
                "admin",
                vec!["site-1".into(), "site-2".into(), "site-3".into()],
                0,
            )
            .unwrap();

        // Drop only the middle id.
        assert!(manager.remove_ids(&wid, "admin", &["site-2".into()]));

        let left = manager.get_ids(&wid).unwrap();
        assert_eq!(left.len(), 2);
        for kept in ["site-1", "site-3"] {
            assert!(left.contains(&kept.to_string()));
        }
        assert!(!left.contains(&"site-2".to_string()));
    }

    #[test]
    fn remove_ids_nonexistent_watch() {
        let manager = WatchManager::new();
        assert!(!manager.remove_ids("no-such-watch", "admin", &["site-1".into()]));
    }

    #[test]
    fn remove_ids_wrong_owner() {
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        // A foreign user cannot shrink someone else's watch.
        assert!(!manager.remove_ids(&wid, "other-user", &["site-1".into()]));

        // The id set is untouched.
        let left = manager.get_ids(&wid).unwrap();
        assert_eq!(left.len(), 2);
    }

    #[test]
    fn remove_ids_leaves_watch_alive() {
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        // Emptying the id set one-by-one does not delete the watch itself.
        assert!(manager.remove_ids(&wid, "admin", &["site-1".into(), "site-2".into()]));

        let left = manager.get_ids(&wid).unwrap();
        assert!(left.is_empty());

        // Proof the watch is still registered: its owner can unsubscribe.
        assert!(manager.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn unsubscribe_full_removal() {
        let manager = WatchManager::new();
        let wid = manager
            .subscribe("admin", vec!["site-1".into(), "site-2".into()], 0)
            .unwrap();

        // Unsubscribing deletes the whole watch...
        assert!(manager.unsubscribe(&wid, "admin"));

        // ...so its id set is gone...
        assert!(manager.get_ids(&wid).is_none());

        // ...and a repeat unsubscribe reports failure.
        assert!(!manager.unsubscribe(&wid, "admin"));
    }

    #[test]
    fn add_ids_ownership_check() {
        let manager = WatchManager::new();
        let wid = manager.subscribe("admin", vec!["site-1".into()], 0).unwrap();

        // Foreign user is rejected; the owner succeeds.
        assert!(!manager.add_ids(&wid, "other-user", vec!["site-2".into()]));
        assert!(manager.add_ids(&wid, "admin", vec!["site-2".into()]));

        let all = manager.get_ids(&wid).unwrap();
        assert_eq!(all.len(), 2);
        for expected in ["site-1", "site-2"] {
            assert!(all.contains(&expected.to_string()));
        }
    }

    #[test]
    fn get_ids_returns_none_for_unknown_watch() {
        let manager = WatchManager::new();
        assert!(manager.get_ids("nonexistent").is_none());
    }

    // -----------------------------------------------------------------------
    // WebSocket message format tests
    // -----------------------------------------------------------------------

    #[test]
    fn ws_request_deserialization() {
        let raw = r#"{
            "op": "watchSub",
            "reqId": "abc-123",
            "watchDis": "my-watch",
            "ids": ["@ref1", "@ref2"]
        }"#;
        let req: WsRequest = serde_json::from_str(raw).unwrap();
        assert_eq!(req.op, "watchSub");
        assert_eq!(req.req_id.as_deref(), Some("abc-123"));
        assert_eq!(req.watch_dis.as_deref(), Some("my-watch"));
        assert!(req.watch_id.is_none());
        assert_eq!(req.ids.unwrap(), vec!["@ref1", "@ref2"]);
    }

    #[test]
    fn ws_request_deserialization_minimal() {
        // `op` is the only mandatory field; everything else may be omitted.
        let raw = r#"{"op": "watchPoll", "watchId": "w-1"}"#;
        let req: WsRequest = serde_json::from_str(raw).unwrap();
        assert_eq!(req.op, "watchPoll");
        assert!(req.req_id.is_none());
        assert!(req.watch_dis.is_none());
        assert_eq!(req.watch_id.as_deref(), Some("w-1"));
        assert!(req.ids.is_none());
    }

    #[test]
    fn ws_response_serialization() {
        let resp = WsResponse::ok(
            Some("r-1".into()),
            vec![serde_json::json!({"id": "r:site-1"})],
            Some("w-1".into()),
        );
        let value = serde_json::to_value(&resp).unwrap();
        assert_eq!(value["reqId"], "r-1");
        assert_eq!(value["watchId"], "w-1");
        assert!(value["rows"].is_array());
        assert_eq!(value["rows"][0]["id"], "r:site-1");
        // A `None` error must be omitted entirely rather than emitted as null.
        assert!(value.get("error").is_none());
    }

    #[test]
    fn ws_response_omits_none_fields() {
        let resp = WsResponse::ok(None, vec![], None);
        let value = serde_json::to_value(&resp).unwrap();
        // Every `None` field is dropped from the object...
        for absent in ["reqId", "error", "watchId"] {
            assert!(value.get(absent).is_none());
        }
        // ...while `rows` is still present as an (empty) array.
        assert!(value["rows"].is_array());
    }

    #[test]
    fn ws_response_includes_req_id() {
        let resp = WsResponse::error(Some("req-42".into()), "something went wrong");
        let value = serde_json::to_value(&resp).unwrap();
        assert_eq!(value["reqId"], "req-42");
        assert_eq!(value["error"], "something went wrong");
        // Error responses carry neither rows nor a watch id.
        for absent in ["rows", "watchId"] {
            assert!(value.get(absent).is_none());
        }
    }

    #[test]
    fn ws_error_response_format() {
        let resp = WsResponse::error(None, "bad request");
        let value = serde_json::to_value(&resp).unwrap();
        assert_eq!(value["error"], "bad request");
        // No request id was supplied, and error responses omit rows/watchId.
        for absent in ["reqId", "rows", "watchId"] {
            assert!(value.get(absent).is_none());
        }
    }
}