// haystack_core/graph/shared.rs
//
// SharedGraph — thread-safe wrapper around `EntityGraph` using a `parking_lot` `RwLock`.

3use parking_lot::RwLock;
4use std::sync::Arc;
5
6use crate::data::{HDict, HGrid};
7use crate::ontology::ValidationIssue;
8
9use super::entity_graph::{EntityGraph, GraphError};
10
11/// Thread-safe, clonable handle to an `EntityGraph`.
12///
13/// Uses `parking_lot::RwLock` for reader-writer locking, allowing
14/// concurrent reads with exclusive writes. Cloning shares the
15/// underlying graph (via `Arc`).
16pub struct SharedGraph {
17    inner: Arc<RwLock<EntityGraph>>,
18}
19
20impl SharedGraph {
21    /// Wrap an `EntityGraph` in a thread-safe handle.
22    pub fn new(graph: EntityGraph) -> Self {
23        Self {
24            inner: Arc::new(RwLock::new(graph)),
25        }
26    }
27
28    /// Execute a closure with shared (read) access to the graph.
29    pub fn read<F, R>(&self, f: F) -> R
30    where
31        F: FnOnce(&EntityGraph) -> R,
32    {
33        let guard = self.inner.read();
34        f(&guard)
35    }
36
37    /// Execute a closure with exclusive (write) access to the graph.
38    pub fn write<F, R>(&self, f: F) -> R
39    where
40        F: FnOnce(&mut EntityGraph) -> R,
41    {
42        let mut guard = self.inner.write();
43        f(&mut guard)
44    }
45
46    // ── Convenience methods ──
47
48    /// Add an entity. See [`EntityGraph::add`].
49    pub fn add(&self, entity: HDict) -> Result<String, GraphError> {
50        self.write(|g| g.add(entity))
51    }
52
53    /// Get an entity by ref value.
54    ///
55    /// Returns an owned clone because the read lock is released before
56    /// the caller uses the value.
57    pub fn get(&self, ref_val: &str) -> Option<HDict> {
58        self.read(|g| g.get(ref_val).cloned())
59    }
60
61    /// Update an entity. See [`EntityGraph::update`].
62    pub fn update(&self, ref_val: &str, changes: HDict) -> Result<(), GraphError> {
63        self.write(|g| g.update(ref_val, changes))
64    }
65
66    /// Remove an entity. See [`EntityGraph::remove`].
67    pub fn remove(&self, ref_val: &str) -> Result<HDict, GraphError> {
68        self.write(|g| g.remove(ref_val))
69    }
70
71    /// Run a filter expression and return a grid.
72    pub fn read_filter(&self, filter_expr: &str, limit: usize) -> Result<HGrid, GraphError> {
73        self.read(|g| g.read(filter_expr, limit))
74    }
75
76    /// Number of entities.
77    pub fn len(&self) -> usize {
78        self.read(|g| g.len())
79    }
80
81    /// Returns `true` if the graph has no entities.
82    pub fn is_empty(&self) -> bool {
83        self.read(|g| g.is_empty())
84    }
85
86    /// Return all entities as owned clones.
87    pub fn all_entities(&self) -> Vec<HDict> {
88        self.read(|g| g.all().into_iter().cloned().collect())
89    }
90
91    /// Check if an entity with the given ref value exists.
92    pub fn contains(&self, ref_val: &str) -> bool {
93        self.read(|g| g.contains(ref_val))
94    }
95
96    /// Current graph version.
97    pub fn version(&self) -> u64 {
98        self.read(|g| g.version())
99    }
100
101    /// Run a filter and return matching entity dicts (cloned).
102    pub fn read_all(&self, filter_expr: &str, limit: usize) -> Result<Vec<HDict>, GraphError> {
103        self.read(|g| {
104            g.read_all(filter_expr, limit)
105                .map(|refs| refs.into_iter().cloned().collect())
106        })
107    }
108
109    /// Get ref values that the given entity points to.
110    pub fn refs_from(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
111        self.read(|g| g.refs_from(ref_val, ref_type))
112    }
113
114    /// Get ref values of entities that point to the given entity.
115    pub fn refs_to(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
116        self.read(|g| g.refs_to(ref_val, ref_type))
117    }
118
119    /// Get changelog entries since a given version.
120    pub fn changes_since(&self, version: u64) -> Vec<super::changelog::GraphDiff> {
121        self.read(|g| g.changes_since(version).to_vec())
122    }
123
124    /// Find all entities that structurally fit a spec/type name.
125    ///
126    /// Returns owned clones. See [`EntityGraph::entities_fitting`].
127    pub fn entities_fitting(&self, spec_name: &str) -> Vec<HDict> {
128        self.read(|g| g.entities_fitting(spec_name).into_iter().cloned().collect())
129    }
130
131    /// Validate all entities against the namespace and check for dangling refs.
132    ///
133    /// See [`EntityGraph::validate`].
134    pub fn validate(&self) -> Vec<ValidationIssue> {
135        self.read(|g| g.validate())
136    }
137}
138
139impl Default for SharedGraph {
140    fn default() -> Self {
141        Self::new(EntityGraph::new())
142    }
143}
144
145impl Clone for SharedGraph {
146    fn clone(&self) -> Self {
147        Self {
148            inner: Arc::clone(&self.inner),
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kinds::{HRef, Kind};
    use std::thread;

    /// Builds a minimal `site` entity with the given id.
    fn make_site(id: &str) -> HDict {
        let mut d = HDict::new();
        d.set("id", Kind::Ref(HRef::from_val(id)));
        d.set("site", Kind::Marker);
        d.set("dis", Kind::Str(format!("Site {id}")));
        d
    }

    /// Builds a minimal `equip` entity whose `siteRef` points at `site_id`.
    fn make_equip(id: &str, site_id: &str) -> HDict {
        let mut d = HDict::new();
        d.set("id", Kind::Ref(HRef::from_val(id)));
        d.set("equip", Kind::Marker);
        d.set("siteRef", Kind::Ref(HRef::from_val(site_id)));
        d
    }

    #[test]
    fn shared_graph_is_send_and_sync() {
        // Compile-time guarantee that handles can be moved and shared across threads.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SharedGraph>();
    }

    #[test]
    fn thread_safe_add_get() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let entity = sg.get("site-1").unwrap();
        assert!(entity.has("site"));
    }

    #[test]
    fn concurrent_read_access() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        // Read through a clone on another thread while this thread also reads.
        let sg2 = sg.clone();
        let reader = thread::spawn(move || sg2.get("site-1").is_some());

        let entity1 = sg.get("site-1");
        let entity2_found = reader.join().unwrap();
        assert!(entity1.is_some());
        assert!(entity2_found);
    }

    #[test]
    fn clone_shares_state() {
        let sg = SharedGraph::new(EntityGraph::new());
        let sg2 = sg.clone();

        sg.add(make_site("site-1")).unwrap();

        // sg2 should see the entity added via sg.
        assert!(sg2.get("site-1").is_some());
        assert_eq!(sg2.len(), 1);
    }

    #[test]
    fn read_and_write_closures() {
        let sg = SharedGraph::new(EntityGraph::new());

        let id = sg.write(|g| g.add(make_site("site-1"))).unwrap();
        assert!(!id.is_empty());

        assert_eq!(sg.read(|g| g.len()), 1);
        assert!(sg.read(|g| g.contains("site-1")));
    }

    #[test]
    fn convenience_methods() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert!(sg.is_empty());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.len(), 1);
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        let grid = sg.read_filter("site", 0).unwrap();
        assert_eq!(grid.len(), 1);

        sg.remove("site-1").unwrap();
        assert!(sg.is_empty());
    }

    #[test]
    fn concurrent_writes_from_threads() {
        let sg = SharedGraph::new(EntityGraph::new());
        let mut handles = Vec::new();

        for i in 0..10 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                let id = format!("site-{i}");
                sg_clone.add(make_site(&id)).unwrap();
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 10);
        for i in 0..10 {
            assert!(sg.contains(&format!("site-{i}")));
        }
    }

    #[test]
    fn contains_check() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        assert!(sg.contains("site-1"));
        assert!(!sg.contains("site-2"));
    }

    #[test]
    fn default_creates_empty() {
        let sg = SharedGraph::default();
        assert!(sg.is_empty());
        assert_eq!(sg.len(), 0);
        assert_eq!(sg.version(), 0);
    }

    #[test]
    fn all_entities_returns_every_entity() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let all = sg.all_entities();
        assert_eq!(all.len(), 2);
        assert!(all.iter().all(|e| e.has("site")));
    }

    #[test]
    fn read_all_filter() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();
        sg.add(make_equip("equip-1", "site-1")).unwrap();

        let results = sg.read_all("site", 0).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|e| e.has("site")));
    }

    #[test]
    fn concurrent_reads_from_threads() {
        let sg = SharedGraph::new(EntityGraph::new());
        for i in 0..20 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();
        for _ in 0..8 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                assert_eq!(sg_clone.len(), 20);
                for i in 0..20 {
                    assert!(sg_clone.contains(&format!("site-{i}")));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn concurrent_read_write_mix() {
        let sg = SharedGraph::new(EntityGraph::new());
        // Pre-populate; these entities are never removed.
        for i in 0..5 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();

        // Writer thread: add more entities.
        let sg_writer = sg.clone();
        handles.push(thread::spawn(move || {
            for i in 5..15 {
                sg_writer.add(make_site(&format!("site-{i}"))).unwrap();
            }
        }));

        // Reader threads: every observation must be a consistent snapshot.
        for _ in 0..4 {
            let sg_reader = sg.clone();
            handles.push(thread::spawn(move || {
                let len = sg_reader.len();
                assert!((5..=15).contains(&len), "unexpected len {len}");
                for i in 0..5 {
                    assert!(sg_reader.get(&format!("site-{i}")).is_some());
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 15);
    }

    #[test]
    fn version_tracking_across_operations() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        sg.remove("site-1").unwrap();
        assert_eq!(sg.version(), 3);
    }

    #[test]
    fn refs_from_and_to() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_equip("equip-1", "site-1")).unwrap();

        let targets = sg.refs_from("equip-1", None);
        assert_eq!(targets, vec!["site-1".to_string()]);

        let sources = sg.refs_to("site-1", None);
        assert_eq!(sources, vec!["equip-1".to_string()]);
    }

    #[test]
    fn changes_since_through_shared() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let changes = sg.changes_since(0);
        assert_eq!(changes.len(), 2);

        let changes = sg.changes_since(1);
        assert_eq!(changes.len(), 1);
        assert_eq!(changes[0].ref_val, "site-2");
    }
}