Skip to main content

haystack_core/graph/
shared.rs

1// SharedGraph — thread-safe wrapper around EntityGraph using parking_lot RwLock.
2
3use parking_lot::RwLock;
4use std::sync::Arc;
5use tokio::sync::broadcast;
6
7use crate::data::{HDict, HGrid};
8use crate::ontology::ValidationIssue;
9
10use super::entity_graph::{EntityGraph, GraphError, HierarchyNode};
11
/// Capacity handed to `broadcast::channel` when a `SharedGraph` is created.
const BROADCAST_CAPACITY: usize = 256;
14
15/// Thread-safe, clonable handle to an `EntityGraph`.
16///
17/// Uses `parking_lot::RwLock` for reader-writer locking, allowing
18/// concurrent reads with exclusive writes. Cloning shares the
19/// underlying graph (via `Arc`).
20///
21/// Write operations automatically send the new graph version on an
22/// internal broadcast channel. Call [`subscribe`](SharedGraph::subscribe)
23/// to get a receiver.
24pub struct SharedGraph {
25    inner: Arc<RwLock<EntityGraph>>,
26    tx: broadcast::Sender<u64>,
27}
28
29impl SharedGraph {
30    /// Wrap an `EntityGraph` in a thread-safe handle.
31    pub fn new(graph: EntityGraph) -> Self {
32        let (tx, _) = broadcast::channel(BROADCAST_CAPACITY);
33        Self {
34            inner: Arc::new(RwLock::new(graph)),
35            tx,
36        }
37    }
38
39    /// Subscribe to graph change notifications.
40    ///
41    /// Returns a receiver that yields the new graph version after each
42    /// write operation (add, update, remove).
43    pub fn subscribe(&self) -> broadcast::Receiver<u64> {
44        self.tx.subscribe()
45    }
46
47    /// Number of active subscribers.
48    pub fn subscriber_count(&self) -> usize {
49        self.tx.receiver_count()
50    }
51
52    /// Execute a closure with shared (read) access to the graph.
53    pub fn read<F, R>(&self, f: F) -> R
54    where
55        F: FnOnce(&EntityGraph) -> R,
56    {
57        let guard = self.inner.read();
58        f(&guard)
59    }
60
61    /// Execute a closure with exclusive (write) access to the graph.
62    pub fn write<F, R>(&self, f: F) -> R
63    where
64        F: FnOnce(&mut EntityGraph) -> R,
65    {
66        let mut guard = self.inner.write();
67        f(&mut guard)
68    }
69
70    /// Execute a write closure and broadcast the new version if it changed.
71    fn write_and_notify<F, R>(&self, f: F) -> R
72    where
73        F: FnOnce(&mut EntityGraph) -> R,
74    {
75        let (result, version) = {
76            let mut guard = self.inner.write();
77            let v_before = guard.version();
78            let result = f(&mut guard);
79            let v_after = guard.version();
80            (
81                result,
82                if v_after != v_before {
83                    Some(v_after)
84                } else {
85                    None
86                },
87            )
88        };
89        // Send outside the lock to avoid holding it during broadcast.
90        if let Some(v) = version {
91            let _ = self.tx.send(v);
92        }
93        result
94    }
95
96    // ── Convenience methods ──
97
98    /// Add an entity. See [`EntityGraph::add`].
99    pub fn add(&self, entity: HDict) -> Result<String, GraphError> {
100        self.write_and_notify(|g| g.add(entity))
101    }
102
103    /// Get an entity by ref value.
104    ///
105    /// Returns an owned clone because the read lock is released before
106    /// the caller uses the value.
107    pub fn get(&self, ref_val: &str) -> Option<HDict> {
108        self.read(|g| g.get(ref_val).cloned())
109    }
110
111    /// Update an entity. See [`EntityGraph::update`].
112    pub fn update(&self, ref_val: &str, changes: HDict) -> Result<(), GraphError> {
113        self.write_and_notify(|g| g.update(ref_val, changes))
114    }
115
116    /// Remove an entity. See [`EntityGraph::remove`].
117    pub fn remove(&self, ref_val: &str) -> Result<HDict, GraphError> {
118        self.write_and_notify(|g| g.remove(ref_val))
119    }
120
121    /// Run a filter expression and return a grid.
122    pub fn read_filter(&self, filter_expr: &str, limit: usize) -> Result<HGrid, GraphError> {
123        self.read(|g| g.read(filter_expr, limit))
124    }
125
126    /// Number of entities.
127    pub fn len(&self) -> usize {
128        self.read(|g| g.len())
129    }
130
131    /// Returns `true` if the graph has no entities.
132    pub fn is_empty(&self) -> bool {
133        self.read(|g| g.is_empty())
134    }
135
136    /// Return all entities as owned clones.
137    pub fn all_entities(&self) -> Vec<HDict> {
138        self.read(|g| g.all().into_iter().cloned().collect())
139    }
140
141    /// Check if an entity with the given ref value exists.
142    pub fn contains(&self, ref_val: &str) -> bool {
143        self.read(|g| g.contains(ref_val))
144    }
145
146    /// Current graph version.
147    pub fn version(&self) -> u64 {
148        self.read(|g| g.version())
149    }
150
151    /// Run a filter and return matching entity dicts (cloned).
152    pub fn read_all(&self, filter_expr: &str, limit: usize) -> Result<Vec<HDict>, GraphError> {
153        self.read(|g| {
154            g.read_all(filter_expr, limit)
155                .map(|refs| refs.into_iter().cloned().collect())
156        })
157    }
158
159    /// Get ref values that the given entity points to.
160    pub fn refs_from(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
161        self.read(|g| g.refs_from(ref_val, ref_type))
162    }
163
164    /// Get ref values of entities that point to the given entity.
165    pub fn refs_to(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
166        self.read(|g| g.refs_to(ref_val, ref_type))
167    }
168
169    /// Get changelog entries since a given version.
170    ///
171    /// Returns `Err(ChangelogGap)` if the requested version has been evicted.
172    pub fn changes_since(
173        &self,
174        version: u64,
175    ) -> Result<Vec<super::changelog::GraphDiff>, super::changelog::ChangelogGap> {
176        self.read(|g| {
177            g.changes_since(version)
178                .map(|refs| refs.into_iter().cloned().collect())
179        })
180    }
181
182    /// Find all entities that structurally fit a spec/type name.
183    ///
184    /// Returns owned clones. See [`EntityGraph::entities_fitting`].
185    pub fn entities_fitting(&self, spec_name: &str) -> Vec<HDict> {
186        self.read(|g| g.entities_fitting(spec_name).into_iter().cloned().collect())
187    }
188
189    /// Validate all entities against the namespace and check for dangling refs.
190    ///
191    /// See [`EntityGraph::validate`].
192    pub fn validate(&self) -> Vec<ValidationIssue> {
193        self.read(|g| g.validate())
194    }
195
196    /// Return all edges as `(source_ref, ref_tag, target_ref)` tuples.
197    pub fn all_edges(&self) -> Vec<(String, String, String)> {
198        self.read(|g| g.all_edges())
199    }
200
201    /// BFS neighborhood: entities and edges within `hops` of `ref_val`.
202    pub fn neighbors(
203        &self,
204        ref_val: &str,
205        hops: usize,
206        ref_types: Option<&[&str]>,
207    ) -> (Vec<HDict>, Vec<(String, String, String)>) {
208        self.read(|g| {
209            let (entities, edges) = g.neighbors(ref_val, hops, ref_types);
210            (entities.into_iter().cloned().collect(), edges)
211        })
212    }
213
214    /// BFS shortest path from `from` to `to`.
215    pub fn shortest_path(&self, from: &str, to: &str) -> Vec<String> {
216        self.read(|g| g.shortest_path(from, to))
217    }
218
219    /// Subtree rooted at `root` up to `max_depth` levels.
220    ///
221    /// Returns entities with their depth from root.
222    pub fn subtree(&self, root: &str, max_depth: usize) -> Vec<(HDict, usize)> {
223        self.read(|g| {
224            g.subtree(root, max_depth)
225                .into_iter()
226                .map(|(e, d)| (e.clone(), d))
227                .collect()
228        })
229    }
230
231    /// Walk a chain of ref tags. See [`EntityGraph::ref_chain`].
232    pub fn ref_chain(&self, ref_val: &str, ref_tags: &[&str]) -> Vec<HDict> {
233        self.read(|g| {
234            g.ref_chain(ref_val, ref_tags)
235                .into_iter()
236                .cloned()
237                .collect()
238        })
239    }
240
241    /// Resolve the site for any entity. See [`EntityGraph::site_for`].
242    pub fn site_for(&self, ref_val: &str) -> Option<HDict> {
243        self.read(|g| g.site_for(ref_val).cloned())
244    }
245
246    /// All direct children of an entity. See [`EntityGraph::children`].
247    pub fn children(&self, ref_val: &str) -> Vec<HDict> {
248        self.read(|g| g.children(ref_val).into_iter().cloned().collect())
249    }
250
251    /// All points for an equip, optionally filtered. See [`EntityGraph::equip_points`].
252    pub fn equip_points(&self, equip_ref: &str, filter: Option<&str>) -> Vec<HDict> {
253        self.read(|g| {
254            g.equip_points(equip_ref, filter)
255                .into_iter()
256                .cloned()
257                .collect()
258        })
259    }
260
261    /// Build a hierarchy tree. See [`EntityGraph::hierarchy_tree`].
262    pub fn hierarchy_tree(&self, root: &str, max_depth: usize) -> Option<HierarchyNode> {
263        self.read(|g| g.hierarchy_tree(root, max_depth))
264    }
265
266    /// Classify an entity. See [`EntityGraph::classify`].
267    pub fn classify(&self, ref_val: &str) -> Option<String> {
268        self.read(|g| g.classify(ref_val))
269    }
270}
271
272impl Default for SharedGraph {
273    fn default() -> Self {
274        Self::new(EntityGraph::new())
275    }
276}
277
278impl Clone for SharedGraph {
279    fn clone(&self) -> Self {
280        Self {
281            inner: Arc::clone(&self.inner),
282            tx: self.tx.clone(),
283        }
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kinds::{HRef, Kind};
    use std::thread;

    /// Build a site entity with the given id, a `site` marker and a `dis`.
    fn make_site(id: &str) -> HDict {
        let mut site = HDict::new();
        site.set("id", Kind::Ref(HRef::from_val(id)));
        site.set("site", Kind::Marker);
        site.set("dis", Kind::Str(format!("Site {id}")));
        site
    }

    /// Build an equip entity with the given id whose `siteRef` targets `site_id`.
    fn make_equip(id: &str, site_id: &str) -> HDict {
        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val(id)));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val(site_id)));
        equip
    }

    /// Build a change dict that sets `dis` to the given text.
    fn dis_change(text: &str) -> HDict {
        let mut changes = HDict::new();
        changes.set("dis", Kind::Str(text.to_string()));
        changes
    }

    /// Create an empty shared graph.
    fn empty_graph() -> SharedGraph {
        SharedGraph::new(EntityGraph::new())
    }

    #[test]
    fn thread_safe_add_get() {
        let graph = empty_graph();
        graph.add(make_site("site-1")).unwrap();

        let fetched = graph.get("site-1").unwrap();
        assert!(fetched.has("site"));
    }

    #[test]
    fn concurrent_read_access() {
        let graph = empty_graph();
        graph.add(make_site("site-1")).unwrap();

        // Read through two handles that share the same graph.
        let other = graph.clone();

        let via_first = graph.get("site-1");
        let via_second = other.get("site-1");
        assert!(via_first.is_some());
        assert!(via_second.is_some());
    }

    #[test]
    fn clone_shares_state() {
        let graph = empty_graph();
        let other = graph.clone();

        graph.add(make_site("site-1")).unwrap();

        // The clone must observe the entity added through the original.
        assert!(other.get("site-1").is_some());
        assert_eq!(other.len(), 1);
    }

    #[test]
    fn convenience_methods() {
        let graph = empty_graph();
        assert!(graph.is_empty());
        assert_eq!(graph.version(), 0);

        graph.add(make_site("site-1")).unwrap();
        assert_eq!(graph.len(), 1);
        assert_eq!(graph.version(), 1);

        graph.update("site-1", dis_change("Updated")).unwrap();
        assert_eq!(graph.version(), 2);

        let grid = graph.read_filter("site", 0).unwrap();
        assert_eq!(grid.len(), 1);

        graph.remove("site-1").unwrap();
        assert!(graph.is_empty());
    }

    #[test]
    fn concurrent_writes_from_threads() {
        let graph = empty_graph();

        let writers: Vec<_> = (0..10)
            .map(|i| {
                let handle = graph.clone();
                thread::spawn(move || {
                    let id = format!("site-{i}");
                    handle.add(make_site(&id)).unwrap();
                })
            })
            .collect();

        for writer in writers {
            writer.join().unwrap();
        }

        assert_eq!(graph.len(), 10);
    }

    #[test]
    fn contains_check() {
        let graph = empty_graph();
        graph.add(make_site("site-1")).unwrap();
        assert!(graph.contains("site-1"));
        assert!(!graph.contains("site-2"));
    }

    #[test]
    fn default_creates_empty() {
        let graph = SharedGraph::default();
        assert!(graph.is_empty());
        assert_eq!(graph.len(), 0);
        assert_eq!(graph.version(), 0);
    }

    #[test]
    fn read_all_filter() {
        let graph = empty_graph();
        graph.add(make_site("site-1")).unwrap();
        graph.add(make_site("site-2")).unwrap();
        graph.add(make_equip("equip-1", "site-1")).unwrap();

        let sites = graph.read_all("site", 0).unwrap();
        assert_eq!(sites.len(), 2);
    }

    #[test]
    fn concurrent_reads_from_threads() {
        let graph = empty_graph();
        for i in 0..20 {
            graph.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let readers: Vec<_> = (0..8)
            .map(|_| {
                let handle = graph.clone();
                thread::spawn(move || {
                    assert_eq!(handle.len(), 20);
                    for i in 0..20 {
                        assert!(handle.contains(&format!("site-{i}")));
                    }
                })
            })
            .collect();

        for reader in readers {
            reader.join().unwrap();
        }
    }

    #[test]
    fn concurrent_read_write_mix() {
        let graph = empty_graph();
        // Seed the graph before any thread starts.
        for i in 0..5 {
            graph.add(make_site(&format!("site-{i}"))).unwrap();
        }

        // One writer appends ten more sites.
        let writer = {
            let handle = graph.clone();
            thread::spawn(move || {
                for i in 5..15 {
                    handle.add(make_site(&format!("site-{i}"))).unwrap();
                }
            })
        };

        // Four readers poke at the seeded entities while the writer runs.
        let readers: Vec<_> = (0..4)
            .map(|_| {
                let handle = graph.clone();
                thread::spawn(move || {
                    // Only checks that reads complete without panicking.
                    let _len = handle.len();
                    for i in 0..5 {
                        let _entity = handle.get(&format!("site-{i}"));
                    }
                })
            })
            .collect();

        writer.join().unwrap();
        for reader in readers {
            reader.join().unwrap();
        }

        assert_eq!(graph.len(), 15);
    }

    #[test]
    fn version_tracking_across_operations() {
        let graph = empty_graph();
        assert_eq!(graph.version(), 0);

        graph.add(make_site("site-1")).unwrap();
        assert_eq!(graph.version(), 1);

        graph.update("site-1", dis_change("Updated")).unwrap();
        assert_eq!(graph.version(), 2);

        graph.remove("site-1").unwrap();
        assert_eq!(graph.version(), 3);
    }

    #[test]
    fn refs_from_and_to() {
        let graph = empty_graph();
        graph.add(make_site("site-1")).unwrap();
        graph.add(make_equip("equip-1", "site-1")).unwrap();

        let targets = graph.refs_from("equip-1", None);
        assert_eq!(targets, vec!["site-1".to_string()]);

        let sources = graph.refs_to("site-1", None);
        assert_eq!(sources.len(), 1);
    }

    #[test]
    fn changes_since_through_shared() {
        let graph = empty_graph();
        graph.add(make_site("site-1")).unwrap();
        graph.add(make_site("site-2")).unwrap();

        let since_start = graph.changes_since(0).unwrap();
        assert_eq!(since_start.len(), 2);

        let since_first = graph.changes_since(1).unwrap();
        assert_eq!(since_first.len(), 1);
        assert_eq!(since_first[0].ref_val, "site-2");
    }

    #[test]
    fn subscribe_receives_versions() {
        let graph = empty_graph();
        let mut rx = graph.subscribe();
        assert_eq!(graph.subscriber_count(), 1);

        graph.add(make_site("site-1")).unwrap();
        graph.add(make_site("site-2")).unwrap();

        // One notification per add, then the queue is drained.
        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn broadcast_on_update_and_remove() {
        let graph = empty_graph();
        graph.add(make_site("site-1")).unwrap();

        // Subscribing after the add means only later versions arrive.
        let mut rx = graph.subscribe();

        graph.update("site-1", dis_change("Updated")).unwrap();
        graph.remove("site-1").unwrap();

        assert_eq!(rx.try_recv().unwrap(), 2); // update
        assert_eq!(rx.try_recv().unwrap(), 3); // remove
    }

    #[test]
    fn no_subscribers_does_not_panic() {
        let graph = empty_graph();
        // With nobody subscribed the write must still go through.
        graph.add(make_site("site-1")).unwrap();
        assert_eq!(graph.len(), 1);
    }
}
562}