Skip to main content

haystack_core/graph/
subscriber.rs

1// GraphSubscriber — async helper for consuming graph change notifications.
2
3use tokio::sync::broadcast;
4
5use super::changelog::{ChangelogGap, GraphDiff};
6use super::shared::SharedGraph;
7
8/// Async helper that pairs a [`SharedGraph`] with a broadcast receiver
9/// to yield batches of [`GraphDiff`] entries whenever the graph changes.
10///
11/// # Example
12///
13/// ```ignore
14/// let subscriber = GraphSubscriber::new(graph.clone());
15/// loop {
16///     match subscriber.next_batch().await {
17///         Ok(diffs) => { /* process diffs */ }
18///         Err(gap) => { /* full resync needed */ }
19///     }
20/// }
21/// ```
22pub struct GraphSubscriber {
23    graph: SharedGraph,
24    rx: broadcast::Receiver<u64>,
25    last_version: u64,
26}
27
28impl GraphSubscriber {
29    /// Create a new subscriber starting from the graph's current version.
30    pub fn new(graph: SharedGraph) -> Self {
31        let rx = graph.subscribe();
32        let last_version = graph.version();
33        Self {
34            graph,
35            rx,
36            last_version,
37        }
38    }
39
40    /// Create a subscriber starting from a specific version.
41    ///
42    /// Useful for resuming after a reconnect.
43    pub fn from_version(graph: SharedGraph, version: u64) -> Self {
44        let rx = graph.subscribe();
45        Self {
46            graph,
47            rx,
48            last_version: version,
49        }
50    }
51
52    /// Wait for the next batch of changes and return them.
53    ///
54    /// Blocks (async) until at least one write occurs, then returns all
55    /// diffs since the last consumed version. Returns `Err(ChangelogGap)`
56    /// if the subscriber has fallen too far behind.
57    pub async fn next_batch(&mut self) -> Result<Vec<GraphDiff>, ChangelogGap> {
58        // Wait for at least one version notification.
59        // Coalesce: drain any additional pending notifications.
60        let mut _latest = match self.rx.recv().await {
61            Ok(v) => v,
62            Err(broadcast::error::RecvError::Lagged(_)) => {
63                // Missed messages — still try to get diffs from changelog.
64                self.graph.version()
65            }
66            Err(broadcast::error::RecvError::Closed) => {
67                // Channel closed — return current state.
68                return Ok(Vec::new());
69            }
70        };
71
72        // Drain any buffered notifications to coalesce into one batch.
73        while let Ok(v) = self.rx.try_recv() {
74            _latest = v;
75        }
76
77        let diffs = self.graph.changes_since(self.last_version)?;
78        if let Some(last) = diffs.last() {
79            self.last_version = last.version;
80        }
81        Ok(diffs)
82    }
83
84    /// The last version this subscriber has consumed.
85    pub fn version(&self) -> u64 {
86        self.last_version
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::HDict;
    use crate::graph::EntityGraph;
    use crate::kinds::{HRef, Kind};

    /// Build a minimal `site` entity with the given id.
    fn make_site(id: &str) -> HDict {
        let mut d = HDict::new();
        d.set("id", Kind::Ref(HRef::from_val(id)));
        d.set("site", Kind::Marker);
        d
    }

    #[tokio::test]
    async fn subscriber_receives_diffs() {
        let sg = SharedGraph::new(EntityGraph::new());
        let mut sub = GraphSubscriber::new(sg.clone());
        assert_eq!(sub.version(), 0);

        sg.add(make_site("site-1")).unwrap();

        let diffs = sub.next_batch().await.unwrap();
        assert_eq!(diffs.len(), 1);
        assert_eq!(diffs[0].ref_val, "site-1");
        assert_eq!(diffs[0].version, 1);
        assert_eq!(sub.version(), 1);
    }

    #[tokio::test]
    async fn subscriber_coalesces_batches() {
        let sg = SharedGraph::new(EntityGraph::new());
        let mut sub = GraphSubscriber::new(sg.clone());

        // Add multiple entities before the subscriber reads.
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();
        sg.add(make_site("site-3")).unwrap();

        // Let any pending notification delivery settle before reading.
        tokio::task::yield_now().await;

        let diffs = sub.next_batch().await.unwrap();
        assert_eq!(diffs.len(), 3);
        // The coalesced batch preserves write order and versions.
        assert_eq!(diffs[0].ref_val, "site-1");
        assert_eq!(diffs[1].ref_val, "site-2");
        assert_eq!(diffs[2].ref_val, "site-3");
        assert_eq!(diffs[0].version, 1);
        assert_eq!(diffs[2].version, 3);
        assert_eq!(sub.version(), 3);
    }

    #[tokio::test]
    async fn subscriber_consecutive_batches_do_not_repeat() {
        let sg = SharedGraph::new(EntityGraph::new());
        let mut sub = GraphSubscriber::new(sg.clone());

        sg.add(make_site("site-1")).unwrap();
        let first = sub.next_batch().await.unwrap();
        assert_eq!(first.len(), 1);
        assert_eq!(first[0].ref_val, "site-1");

        sg.add(make_site("site-2")).unwrap();
        let second = sub.next_batch().await.unwrap();
        // Only the new write is returned; site-1 is not delivered again.
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].ref_val, "site-2");
        assert_eq!(sub.version(), 2);
    }

    #[tokio::test]
    async fn subscriber_from_version() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        // Start from version 1; only v2 onwards should be returned.
        let mut sub = GraphSubscriber::from_version(sg.clone(), 1);

        sg.add(make_site("site-3")).unwrap();

        let diffs = sub.next_batch().await.unwrap();
        assert_eq!(diffs.len(), 2); // site-2 (v2) and site-3 (v3)
        assert_eq!(diffs[0].ref_val, "site-2");
        assert_eq!(diffs[1].ref_val, "site-3");
        assert_eq!(sub.version(), 3);
    }
}