Skip to main content

haystack_core/graph/
subscriber.rs

1// GraphSubscriber — async helper for consuming graph change notifications.
2
3use tokio::sync::broadcast;
4
5use super::changelog::{ChangelogGap, GraphDiff};
6use super::shared::SharedGraph;
7
8/// Async helper that pairs a [`SharedGraph`] with a broadcast receiver
9/// to yield batches of [`GraphDiff`] entries whenever the graph changes.
10///
11/// # Example
12///
13/// ```ignore
14/// let subscriber = GraphSubscriber::new(graph.clone());
15/// loop {
16///     match subscriber.next_batch().await {
17///         Ok(diffs) => { /* process diffs */ }
18///         Err(gap) => { /* full resync needed */ }
19///     }
20/// }
21/// ```
22pub struct GraphSubscriber {
23    graph: SharedGraph,
24    rx: broadcast::Receiver<u64>,
25    last_version: u64,
26}
27
28impl GraphSubscriber {
29    /// Create a new subscriber starting from the graph's current version.
30    pub fn new(graph: SharedGraph) -> Self {
31        let (rx, last_version) = graph.subscribe_with_version();
32        Self {
33            graph,
34            rx,
35            last_version,
36        }
37    }
38
39    /// Create a subscriber starting from a specific version.
40    ///
41    /// Useful for resuming after a reconnect.
42    pub fn from_version(graph: SharedGraph, version: u64) -> Self {
43        let rx = graph.subscribe();
44        Self {
45            graph,
46            rx,
47            last_version: version,
48        }
49    }
50
51    /// Wait for the next batch of changes and return them.
52    ///
53    /// Blocks (async) until at least one write occurs, then returns all
54    /// diffs since the last consumed version. Returns `Err(ChangelogGap)`
55    /// if the subscriber has fallen too far behind.
56    pub async fn next_batch(&mut self) -> Result<Vec<GraphDiff>, ChangelogGap> {
57        // Wait for at least one version notification.
58        // Coalesce: drain any additional pending notifications.
59        let mut _latest = match self.rx.recv().await {
60            Ok(v) => v,
61            Err(broadcast::error::RecvError::Lagged(_)) => {
62                // Missed messages — still try to get diffs from changelog.
63                self.graph.version()
64            }
65            Err(broadcast::error::RecvError::Closed) => {
66                // Channel closed — return current state.
67                return Ok(Vec::new());
68            }
69        };
70
71        // Drain any buffered notifications to coalesce into one batch.
72        while let Ok(v) = self.rx.try_recv() {
73            _latest = v;
74        }
75
76        let diffs = self.graph.changes_since(self.last_version)?;
77        if let Some(last) = diffs.last() {
78            self.last_version = last.version;
79        }
80        Ok(diffs)
81    }
82
83    /// The last version this subscriber has consumed.
84    pub fn version(&self) -> u64 {
85        self.last_version
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::HDict;
    use crate::graph::EntityGraph;
    use crate::kinds::{HRef, Kind};

    /// Build a minimal site entity: an `id` ref plus the `site` marker tag.
    fn make_site(id: &str) -> HDict {
        let mut entity = HDict::new();
        entity.set("id", Kind::Ref(HRef::from_val(id)));
        entity.set("site", Kind::Marker);
        entity
    }

    /// A shared handle around a brand-new, empty entity graph.
    fn empty_graph() -> SharedGraph {
        SharedGraph::new(EntityGraph::new())
    }

    #[tokio::test]
    async fn subscriber_receives_diffs() {
        let graph = empty_graph();
        let mut subscriber = GraphSubscriber::new(graph.clone());
        assert_eq!(subscriber.version(), 0);

        graph.add(make_site("site-1")).unwrap();

        let batch = subscriber.next_batch().await.unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].ref_val, "site-1");
        assert_eq!(subscriber.version(), 1);
    }

    #[tokio::test]
    async fn subscriber_coalesces_batches() {
        let graph = empty_graph();
        let mut subscriber = GraphSubscriber::new(graph.clone());

        // Several writes land before the subscriber reads anything.
        for id in ["site-1", "site-2", "site-3"] {
            graph.add(make_site(id)).unwrap();
        }

        // Let the runtime run other tasks before reading.
        tokio::task::yield_now().await;

        let batch = subscriber.next_batch().await.unwrap();
        assert_eq!(batch.len(), 3);
        assert_eq!(subscriber.version(), 3);
    }

    #[tokio::test]
    async fn subscriber_from_version() {
        let graph = empty_graph();
        for id in ["site-1", "site-2"] {
            graph.add(make_site(id)).unwrap();
        }

        // Resume after version 1: expect v2 and everything after it.
        let mut subscriber = GraphSubscriber::from_version(graph.clone(), 1);

        graph.add(make_site("site-3")).unwrap();

        let batch = subscriber.next_batch().await.unwrap();
        assert_eq!(batch.len(), 2); // site-2 (v2) and site-3 (v3)
        assert_eq!(subscriber.version(), 3);
    }
}