Skip to main content

haystack_core/graph/
shared.rs

1// SharedGraph — thread-safe wrapper around EntityGraph using parking_lot RwLock.
2
3use parking_lot::RwLock;
4use std::sync::Arc;
5
6use crate::data::{HDict, HGrid};
7use crate::ontology::ValidationIssue;
8
9use super::entity_graph::{EntityGraph, GraphError};
10
11/// Thread-safe, clonable handle to an `EntityGraph`.
12///
13/// Uses `parking_lot::RwLock` for reader-writer locking, allowing
14/// concurrent reads with exclusive writes. Cloning shares the
15/// underlying graph (via `Arc`).
16pub struct SharedGraph {
17    inner: Arc<RwLock<EntityGraph>>,
18}
19
20impl SharedGraph {
21    /// Wrap an `EntityGraph` in a thread-safe handle.
22    pub fn new(graph: EntityGraph) -> Self {
23        Self {
24            inner: Arc::new(RwLock::new(graph)),
25        }
26    }
27
28    /// Execute a closure with shared (read) access to the graph.
29    pub fn read<F, R>(&self, f: F) -> R
30    where
31        F: FnOnce(&EntityGraph) -> R,
32    {
33        let guard = self.inner.read();
34        f(&guard)
35    }
36
37    /// Execute a closure with exclusive (write) access to the graph.
38    pub fn write<F, R>(&self, f: F) -> R
39    where
40        F: FnOnce(&mut EntityGraph) -> R,
41    {
42        let mut guard = self.inner.write();
43        f(&mut guard)
44    }
45
46    // ── Convenience methods ──
47
48    /// Add an entity. See [`EntityGraph::add`].
49    pub fn add(&self, entity: HDict) -> Result<String, GraphError> {
50        self.write(|g| g.add(entity))
51    }
52
53    /// Get an entity by ref value.
54    ///
55    /// Returns an owned clone because the read lock is released before
56    /// the caller uses the value.
57    pub fn get(&self, ref_val: &str) -> Option<HDict> {
58        self.read(|g| g.get(ref_val).cloned())
59    }
60
61    /// Update an entity. See [`EntityGraph::update`].
62    pub fn update(&self, ref_val: &str, changes: HDict) -> Result<(), GraphError> {
63        self.write(|g| g.update(ref_val, changes))
64    }
65
66    /// Remove an entity. See [`EntityGraph::remove`].
67    pub fn remove(&self, ref_val: &str) -> Result<HDict, GraphError> {
68        self.write(|g| g.remove(ref_val))
69    }
70
71    /// Run a filter expression and return a grid.
72    pub fn read_filter(&self, filter_expr: &str, limit: usize) -> Result<HGrid, GraphError> {
73        self.read(|g| g.read(filter_expr, limit))
74    }
75
76    /// Number of entities.
77    pub fn len(&self) -> usize {
78        self.read(|g| g.len())
79    }
80
81    /// Returns `true` if the graph has no entities.
82    pub fn is_empty(&self) -> bool {
83        self.read(|g| g.is_empty())
84    }
85
86    /// Return all entities as owned clones.
87    pub fn all_entities(&self) -> Vec<HDict> {
88        self.read(|g| g.all().into_iter().cloned().collect())
89    }
90
91    /// Check if an entity with the given ref value exists.
92    pub fn contains(&self, ref_val: &str) -> bool {
93        self.read(|g| g.contains(ref_val))
94    }
95
96    /// Current graph version.
97    pub fn version(&self) -> u64 {
98        self.read(|g| g.version())
99    }
100
101    /// Run a filter and return matching entity dicts (cloned).
102    pub fn read_all(&self, filter_expr: &str, limit: usize) -> Result<Vec<HDict>, GraphError> {
103        self.read(|g| {
104            g.read_all(filter_expr, limit)
105                .map(|refs| refs.into_iter().cloned().collect())
106        })
107    }
108
109    /// Get ref values that the given entity points to.
110    pub fn refs_from(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
111        self.read(|g| g.refs_from(ref_val, ref_type))
112    }
113
114    /// Get ref values of entities that point to the given entity.
115    pub fn refs_to(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
116        self.read(|g| g.refs_to(ref_val, ref_type))
117    }
118
119    /// Get changelog entries since a given version.
120    pub fn changes_since(&self, version: u64) -> Vec<super::changelog::GraphDiff> {
121        self.read(|g| g.changes_since(version).into_iter().cloned().collect())
122    }
123
124    /// Find all entities that structurally fit a spec/type name.
125    ///
126    /// Returns owned clones. See [`EntityGraph::entities_fitting`].
127    pub fn entities_fitting(&self, spec_name: &str) -> Vec<HDict> {
128        self.read(|g| g.entities_fitting(spec_name).into_iter().cloned().collect())
129    }
130
131    /// Validate all entities against the namespace and check for dangling refs.
132    ///
133    /// See [`EntityGraph::validate`].
134    pub fn validate(&self) -> Vec<ValidationIssue> {
135        self.read(|g| g.validate())
136    }
137
138    /// Return all edges as `(source_ref, ref_tag, target_ref)` tuples.
139    pub fn all_edges(&self) -> Vec<(String, String, String)> {
140        self.read(|g| g.all_edges())
141    }
142
143    /// BFS neighborhood: entities and edges within `hops` of `ref_val`.
144    pub fn neighbors(
145        &self,
146        ref_val: &str,
147        hops: usize,
148        ref_types: Option<&[&str]>,
149    ) -> (Vec<HDict>, Vec<(String, String, String)>) {
150        self.read(|g| {
151            let (entities, edges) = g.neighbors(ref_val, hops, ref_types);
152            (entities.into_iter().cloned().collect(), edges)
153        })
154    }
155
156    /// BFS shortest path from `from` to `to`.
157    pub fn shortest_path(&self, from: &str, to: &str) -> Vec<String> {
158        self.read(|g| g.shortest_path(from, to))
159    }
160
161    /// Subtree rooted at `root` up to `max_depth` levels.
162    ///
163    /// Returns entities with their depth from root.
164    pub fn subtree(&self, root: &str, max_depth: usize) -> Vec<(HDict, usize)> {
165        self.read(|g| {
166            g.subtree(root, max_depth)
167                .into_iter()
168                .map(|(e, d)| (e.clone(), d))
169                .collect()
170        })
171    }
172}
173
174impl Default for SharedGraph {
175    fn default() -> Self {
176        Self::new(EntityGraph::new())
177    }
178}
179
180impl Clone for SharedGraph {
181    fn clone(&self) -> Self {
182        Self {
183            inner: Arc::clone(&self.inner),
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kinds::{HRef, Kind};

    /// Build a minimal site entity whose `id` ref is `id`.
    fn make_site(id: &str) -> HDict {
        let mut d = HDict::new();
        d.set("id", Kind::Ref(HRef::from_val(id)));
        d.set("site", Kind::Marker);
        d.set("dis", Kind::Str(format!("Site {id}")));
        d
    }

    #[test]
    fn thread_safe_add_get() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let entity = sg.get("site-1").unwrap();
        assert!(entity.has("site"));
    }

    #[test]
    fn concurrent_read_access() {
        use std::sync::{Arc, Barrier};
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        // Each thread takes the read lock and then waits on the barrier while
        // still holding it. Both threads can only get past the barrier if the
        // two read locks are held at the same time. If reads were exclusive,
        // this test would hang instead of passing.
        let barrier = Arc::new(Barrier::new(2));
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let sg_clone = sg.clone();
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    sg_clone.read(|g| {
                        barrier.wait();
                        g.get("site-1").is_some()
                    })
                })
            })
            .collect();

        for h in handles {
            assert!(h.join().unwrap());
        }
    }

    #[test]
    fn clone_shares_state() {
        let sg = SharedGraph::new(EntityGraph::new());
        let sg2 = sg.clone();

        sg.add(make_site("site-1")).unwrap();

        // sg2 should see the entity added via sg.
        assert!(sg2.get("site-1").is_some());
        assert_eq!(sg2.len(), 1);
    }

    #[test]
    fn convenience_methods() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert!(sg.is_empty());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.len(), 1);
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        let grid = sg.read_filter("site", 0).unwrap();
        assert_eq!(grid.len(), 1);

        sg.remove("site-1").unwrap();
        assert!(sg.is_empty());
    }

    #[test]
    fn concurrent_writes_from_threads() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        let mut handles = Vec::new();

        for i in 0..10 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                let id = format!("site-{i}");
                sg_clone.add(make_site(&id)).unwrap();
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 10);
    }

    #[test]
    fn contains_check() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        assert!(sg.contains("site-1"));
        assert!(!sg.contains("site-2"));
    }

    #[test]
    fn default_creates_empty() {
        let sg = SharedGraph::default();
        assert!(sg.is_empty());
        assert_eq!(sg.len(), 0);
        assert_eq!(sg.version(), 0);
    }

    #[test]
    fn read_all_filter() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val("equip-1")));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val("site-1")));
        sg.add(equip).unwrap();

        let results = sg.read_all("site", 0).unwrap();
        assert_eq!(results.len(), 2);
        // Only the two sites match; the equip must have been filtered out.
        assert!(results.iter().all(|e| e.has("site")));
    }

    #[test]
    fn concurrent_reads_from_threads() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        for i in 0..20 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();
        for _ in 0..8 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                assert_eq!(sg_clone.len(), 20);
                for i in 0..20 {
                    assert!(sg_clone.contains(&format!("site-{i}")));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn concurrent_read_write_mix() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        // Pre-populate before any thread starts.
        for i in 0..5 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();

        // Writer thread: add more entities.
        let sg_writer = sg.clone();
        handles.push(thread::spawn(move || {
            for i in 5..15 {
                sg_writer.add(make_site(&format!("site-{i}"))).unwrap();
            }
        }));

        // Reader threads: the pre-populated sites must always be visible, and
        // the size can only grow from 5 to 15 while the writer runs.
        for _ in 0..4 {
            let sg_reader = sg.clone();
            handles.push(thread::spawn(move || {
                let len = sg_reader.len();
                assert!((5..=15).contains(&len), "unexpected len {len}");
                for i in 0..5 {
                    assert!(sg_reader.get(&format!("site-{i}")).is_some());
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 15);
    }

    #[test]
    fn version_tracking_across_operations() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        sg.remove("site-1").unwrap();
        assert_eq!(sg.version(), 3);
    }

    #[test]
    fn refs_from_and_to() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val("equip-1")));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val("site-1")));
        sg.add(equip).unwrap();

        let targets = sg.refs_from("equip-1", None);
        assert_eq!(targets, vec!["site-1".to_string()]);

        let sources = sg.refs_to("site-1", None);
        assert_eq!(sources.len(), 1);
    }

    #[test]
    fn changes_since_through_shared() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let changes = sg.changes_since(0);
        assert_eq!(changes.len(), 2);

        let changes = sg.changes_since(1);
        assert_eq!(changes.len(), 1);
        assert_eq!(changes[0].ref_val, "site-2");
    }

    #[test]
    fn get_missing_returns_none() {
        let sg = SharedGraph::default();
        assert!(sg.get("missing").is_none());

        sg.add(make_site("site-1")).unwrap();
        assert!(sg.get("site-2").is_none());
    }

    #[test]
    fn remove_returns_removed_entity() {
        let sg = SharedGraph::default();
        sg.add(make_site("site-1")).unwrap();

        let removed = sg.remove("site-1").unwrap();
        assert!(removed.has("site"));
        assert!(sg.get("site-1").is_none());
        assert!(!sg.contains("site-1"));
    }

    #[test]
    fn all_entities_returns_every_entity() {
        let sg = SharedGraph::default();
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let all = sg.all_entities();
        assert_eq!(all.len(), 2);
        assert!(all.iter().all(|e| e.has("site")));
    }

    #[test]
    fn write_returns_closure_result() {
        let sg = SharedGraph::default();
        let len = sg.write(|g| {
            g.add(make_site("site-1")).unwrap();
            g.len()
        });
        assert_eq!(len, 1);
        assert!(sg.contains("site-1"));
    }
}