Skip to main content

haystack_core/graph/
shared.rs

// SharedGraph — thread-safe wrapper around EntityGraph using parking_lot RwLock.
2
3use parking_lot::RwLock;
4use std::sync::Arc;
5use tokio::sync::broadcast;
6
7use crate::data::{HDict, HGrid};
8use crate::ontology::ValidationIssue;
9
10use super::entity_graph::{EntityGraph, GraphError, HierarchyNode};
11
/// Capacity of the version-notification broadcast channel that
/// `SharedGraph::new` creates for every graph handle.
const BROADCAST_CAPACITY: usize = 256;
14
15/// Thread-safe, clonable handle to an `EntityGraph`.
16///
17/// Uses `parking_lot::RwLock` for reader-writer locking, allowing
18/// concurrent reads with exclusive writes. Cloning shares the
19/// underlying graph (via `Arc`).
20///
21/// Write operations automatically send the new graph version on an
22/// internal broadcast channel. Call [`subscribe`](SharedGraph::subscribe)
23/// to get a receiver.
24pub struct SharedGraph {
25    inner: Arc<RwLock<EntityGraph>>,
26    tx: broadcast::Sender<u64>,
27}
28
29impl SharedGraph {
30    /// Wrap an `EntityGraph` in a thread-safe handle.
31    pub fn new(graph: EntityGraph) -> Self {
32        let (tx, _) = broadcast::channel(BROADCAST_CAPACITY);
33        Self {
34            inner: Arc::new(RwLock::new(graph)),
35            tx,
36        }
37    }
38
39    /// Subscribe to graph change notifications.
40    ///
41    /// Returns a receiver that yields the new graph version after each
42    /// write operation (add, update, remove).
43    pub fn subscribe(&self) -> broadcast::Receiver<u64> {
44        self.tx.subscribe()
45    }
46
47    /// Subscribe and atomically capture the current version under the read lock.
48    ///
49    /// This avoids a TOCTOU race where a write could occur between
50    /// subscribing and reading the version.
51    pub fn subscribe_with_version(&self) -> (broadcast::Receiver<u64>, u64) {
52        let guard = self.inner.read();
53        let version = guard.version();
54        let rx = self.tx.subscribe();
55        (rx, version)
56    }
57
58    /// Number of active subscribers.
59    pub fn subscriber_count(&self) -> usize {
60        self.tx.receiver_count()
61    }
62
63    /// Execute a closure with shared (read) access to the graph.
64    pub fn read<F, R>(&self, f: F) -> R
65    where
66        F: FnOnce(&EntityGraph) -> R,
67    {
68        let guard = self.inner.read();
69        f(&guard)
70    }
71
72    /// Execute a closure with exclusive (write) access to the graph.
73    pub fn write<F, R>(&self, f: F) -> R
74    where
75        F: FnOnce(&mut EntityGraph) -> R,
76    {
77        let mut guard = self.inner.write();
78        f(&mut guard)
79    }
80
81    /// Execute a write closure and broadcast the new version if it changed.
82    fn write_and_notify<F, R>(&self, f: F) -> R
83    where
84        F: FnOnce(&mut EntityGraph) -> R,
85    {
86        let (result, version) = {
87            let mut guard = self.inner.write();
88            let v_before = guard.version();
89            let result = f(&mut guard);
90            let v_after = guard.version();
91            (
92                result,
93                if v_after != v_before {
94                    Some(v_after)
95                } else {
96                    None
97                },
98            )
99        };
100        // Send outside the lock to avoid holding it during broadcast.
101        if let Some(v) = version {
102            let _ = self.tx.send(v);
103        }
104        result
105    }
106
107    // ── Convenience methods ──
108
109    /// Add an entity. See [`EntityGraph::add`].
110    pub fn add(&self, entity: HDict) -> Result<String, GraphError> {
111        self.write_and_notify(|g| g.add(entity))
112    }
113
114    /// Get an entity by ref value.
115    ///
116    /// Returns an owned clone because the read lock is released before
117    /// the caller uses the value.
118    pub fn get(&self, ref_val: &str) -> Option<HDict> {
119        self.read(|g| g.get(ref_val).cloned())
120    }
121
122    /// Update an entity. See [`EntityGraph::update`].
123    pub fn update(&self, ref_val: &str, changes: HDict) -> Result<(), GraphError> {
124        self.write_and_notify(|g| g.update(ref_val, changes))
125    }
126
127    /// Remove an entity. See [`EntityGraph::remove`].
128    pub fn remove(&self, ref_val: &str) -> Result<HDict, GraphError> {
129        self.write_and_notify(|g| g.remove(ref_val))
130    }
131
132    /// Run a filter expression and return a grid.
133    pub fn read_filter(&self, filter_expr: &str, limit: usize) -> Result<HGrid, GraphError> {
134        self.read(|g| g.read(filter_expr, limit))
135    }
136
137    /// Number of entities.
138    pub fn len(&self) -> usize {
139        self.read(|g| g.len())
140    }
141
142    /// Returns `true` if the graph has no entities.
143    pub fn is_empty(&self) -> bool {
144        self.read(|g| g.is_empty())
145    }
146
147    /// Return all entities as owned clones.
148    pub fn all_entities(&self) -> Vec<HDict> {
149        self.read(|g| g.all().into_iter().cloned().collect())
150    }
151
152    /// Check if an entity with the given ref value exists.
153    pub fn contains(&self, ref_val: &str) -> bool {
154        self.read(|g| g.contains(ref_val))
155    }
156
157    /// Current graph version.
158    pub fn version(&self) -> u64 {
159        self.read(|g| g.version())
160    }
161
162    /// Run a filter and return matching entity dicts (cloned).
163    pub fn read_all(&self, filter_expr: &str, limit: usize) -> Result<Vec<HDict>, GraphError> {
164        self.read(|g| {
165            g.read_all(filter_expr, limit)
166                .map(|refs| refs.into_iter().cloned().collect())
167        })
168    }
169
170    /// Get ref values that the given entity points to.
171    pub fn refs_from(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
172        self.read(|g| g.refs_from(ref_val, ref_type))
173    }
174
175    /// Get ref values of entities that point to the given entity.
176    pub fn refs_to(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
177        self.read(|g| g.refs_to(ref_val, ref_type))
178    }
179
180    /// Get changelog entries since a given version.
181    ///
182    /// Returns `Err(ChangelogGap)` if the requested version has been evicted.
183    pub fn changes_since(
184        &self,
185        version: u64,
186    ) -> Result<Vec<super::changelog::GraphDiff>, super::changelog::ChangelogGap> {
187        self.read(|g| {
188            g.changes_since(version)
189                .map(|refs| refs.into_iter().cloned().collect())
190        })
191    }
192
193    /// Validate all entities against the namespace and check for dangling refs.
194    ///
195    /// See [`EntityGraph::validate`].
196    pub fn validate(&self) -> Vec<ValidationIssue> {
197        self.read(|g| g.validate())
198    }
199
200    /// Walk a chain of ref tags. See [`EntityGraph::ref_chain`].
201    pub fn ref_chain(&self, ref_val: &str, ref_tags: &[&str]) -> Vec<HDict> {
202        self.read(|g| {
203            g.ref_chain(ref_val, ref_tags)
204                .into_iter()
205                .cloned()
206                .collect()
207        })
208    }
209
210    /// Resolve the site for any entity. See [`EntityGraph::site_for`].
211    pub fn site_for(&self, ref_val: &str) -> Option<HDict> {
212        self.read(|g| g.site_for(ref_val).cloned())
213    }
214
215    /// All direct children of an entity. See [`EntityGraph::children`].
216    pub fn children(&self, ref_val: &str) -> Vec<HDict> {
217        self.read(|g| g.children(ref_val).into_iter().cloned().collect())
218    }
219
220    /// All points for an equip, optionally filtered. See [`EntityGraph::equip_points`].
221    pub fn equip_points(
222        &self,
223        equip_ref: &str,
224        filter: Option<&str>,
225    ) -> Result<Vec<HDict>, GraphError> {
226        self.read(|g| {
227            g.equip_points(equip_ref, filter)
228                .map(|v| v.into_iter().cloned().collect())
229        })
230    }
231
232    /// Build a hierarchy tree. See [`EntityGraph::hierarchy_tree`].
233    pub fn hierarchy_tree(&self, root: &str, max_depth: usize) -> Option<HierarchyNode> {
234        self.read(|g| g.hierarchy_tree(root, max_depth))
235    }
236
237    /// Classify an entity. See [`EntityGraph::classify`].
238    pub fn classify(&self, ref_val: &str) -> Option<String> {
239        self.read(|g| g.classify(ref_val))
240    }
241}
242
243impl Default for SharedGraph {
244    fn default() -> Self {
245        Self::new(EntityGraph::new())
246    }
247}
248
249impl Clone for SharedGraph {
250    fn clone(&self) -> Self {
251        Self {
252            inner: Arc::clone(&self.inner),
253            tx: self.tx.clone(),
254        }
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kinds::{HRef, Kind};

    /// Build a minimal site entity with the given id.
    fn make_site(id: &str) -> HDict {
        let mut d = HDict::new();
        d.set("id", Kind::Ref(HRef::from_val(id)));
        d.set("site", Kind::Marker);
        d.set("dis", Kind::Str(format!("Site {id}")));
        d
    }

    #[test]
    fn thread_safe_add_get() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let entity = sg.get("site-1").unwrap();
        assert!(entity.has("site"));
    }

    #[test]
    fn concurrent_read_access() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        // Multiple reads at the "same time" via clone.
        let sg2 = sg.clone();

        let entity1 = sg.get("site-1");
        let entity2 = sg2.get("site-1");
        assert!(entity1.is_some());
        assert!(entity2.is_some());
    }

    #[test]
    fn clone_shares_state() {
        let sg = SharedGraph::new(EntityGraph::new());
        let sg2 = sg.clone();

        sg.add(make_site("site-1")).unwrap();

        // sg2 should see the entity added via sg.
        assert!(sg2.get("site-1").is_some());
        assert_eq!(sg2.len(), 1);
    }

    #[test]
    fn convenience_methods() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert!(sg.is_empty());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.len(), 1);
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        let grid = sg.read_filter("site", 0).unwrap();
        assert_eq!(grid.len(), 1);

        sg.remove("site-1").unwrap();
        assert!(sg.is_empty());
    }

    #[test]
    fn concurrent_writes_from_threads() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        let mut handles = Vec::new();

        for i in 0..10 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                let id = format!("site-{i}");
                sg_clone.add(make_site(&id)).unwrap();
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 10);
    }

    #[test]
    fn contains_check() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        assert!(sg.contains("site-1"));
        assert!(!sg.contains("site-2"));
    }

    #[test]
    fn default_creates_empty() {
        let sg = SharedGraph::default();
        assert!(sg.is_empty());
        assert_eq!(sg.len(), 0);
        assert_eq!(sg.version(), 0);
    }

    #[test]
    fn read_all_filter() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val("equip-1")));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val("site-1")));
        sg.add(equip).unwrap();

        let results = sg.read_all("site", 0).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn concurrent_reads_from_threads() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        for i in 0..20 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();
        for _ in 0..8 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                assert_eq!(sg_clone.len(), 20);
                for i in 0..20 {
                    assert!(sg_clone.contains(&format!("site-{i}")));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn concurrent_read_write_mix() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        // Pre-populate
        for i in 0..5 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();

        // Writer thread: add more entities
        let sg_writer = sg.clone();
        handles.push(thread::spawn(move || {
            for i in 5..15 {
                sg_writer.add(make_site(&format!("site-{i}"))).unwrap();
            }
        }));

        // Reader threads: read existing entities
        for _ in 0..4 {
            let sg_reader = sg.clone();
            handles.push(thread::spawn(move || {
                // Just verify no panics and consistent reads
                let _len = sg_reader.len();
                for i in 0..5 {
                    let _entity = sg_reader.get(&format!("site-{i}"));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 15);
    }

    #[test]
    fn version_tracking_across_operations() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        sg.remove("site-1").unwrap();
        assert_eq!(sg.version(), 3);
    }

    #[test]
    fn refs_from_and_to() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val("equip-1")));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val("site-1")));
        sg.add(equip).unwrap();

        let targets = sg.refs_from("equip-1", None);
        assert_eq!(targets, vec!["site-1".to_string()]);

        let sources = sg.refs_to("site-1", None);
        assert_eq!(sources.len(), 1);
    }

    #[test]
    fn changes_since_through_shared() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let changes = sg.changes_since(0).unwrap();
        assert_eq!(changes.len(), 2);

        let changes = sg.changes_since(1).unwrap();
        assert_eq!(changes.len(), 1);
        assert_eq!(changes[0].ref_val, "site-2");
    }

    #[test]
    fn subscribe_receives_versions() {
        let sg = SharedGraph::new(EntityGraph::new());
        let mut rx = sg.subscribe();
        assert_eq!(sg.subscriber_count(), 1);

        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        // Receiver should have the two versions.
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert!(rx.try_recv().is_err()); // no more
    }

    #[test]
    fn subscribe_with_version_captures_current_version() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let (mut rx, version) = sg.subscribe_with_version();
        assert_eq!(version, 1);
        // The write that produced the captured version is not replayed.
        assert!(rx.try_recv().is_err());

        sg.add(make_site("site-2")).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn clone_shares_broadcast_channel() {
        let sg = SharedGraph::new(EntityGraph::new());
        let sg2 = sg.clone();
        let mut rx = sg.subscribe();
        assert_eq!(sg2.subscriber_count(), 1);

        // A write through the clone is seen by a subscriber of the original.
        sg2.add(make_site("site-1")).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 1);
    }

    #[test]
    fn broadcast_on_update_and_remove() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let mut rx = sg.subscribe();

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        sg.remove("site-1").unwrap();

        assert_eq!(rx.try_recv().unwrap(), 2); // update
        assert_eq!(rx.try_recv().unwrap(), 3); // remove
    }

    #[test]
    fn no_subscribers_does_not_panic() {
        let sg = SharedGraph::new(EntityGraph::new());
        // No subscribers — write should still succeed.
        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.len(), 1);
    }
}