Skip to main content

haystack_core/graph/
shared.rs

// SharedGraph — thread-safe, clonable wrapper around `EntityGraph` using `parking_lot::RwLock`.
2
3use parking_lot::RwLock;
4use std::sync::Arc;
5
6use crate::data::{HDict, HGrid};
7use crate::ontology::ValidationIssue;
8
9use super::entity_graph::{EntityGraph, GraphError};
10
11/// Thread-safe, clonable handle to an `EntityGraph`.
12///
13/// Uses `parking_lot::RwLock` for reader-writer locking, allowing
14/// concurrent reads with exclusive writes. Cloning shares the
15/// underlying graph (via `Arc`).
16pub struct SharedGraph {
17    inner: Arc<RwLock<EntityGraph>>,
18}
19
20impl SharedGraph {
21    /// Wrap an `EntityGraph` in a thread-safe handle.
22    pub fn new(graph: EntityGraph) -> Self {
23        Self {
24            inner: Arc::new(RwLock::new(graph)),
25        }
26    }
27
28    /// Execute a closure with shared (read) access to the graph.
29    pub fn read<F, R>(&self, f: F) -> R
30    where
31        F: FnOnce(&EntityGraph) -> R,
32    {
33        let guard = self.inner.read();
34        f(&guard)
35    }
36
37    /// Execute a closure with exclusive (write) access to the graph.
38    pub fn write<F, R>(&self, f: F) -> R
39    where
40        F: FnOnce(&mut EntityGraph) -> R,
41    {
42        let mut guard = self.inner.write();
43        f(&mut guard)
44    }
45
46    // ── Convenience methods ──
47
48    /// Add an entity. See [`EntityGraph::add`].
49    pub fn add(&self, entity: HDict) -> Result<String, GraphError> {
50        self.write(|g| g.add(entity))
51    }
52
53    /// Get an entity by ref value.
54    ///
55    /// Returns an owned clone because the read lock is released before
56    /// the caller uses the value.
57    pub fn get(&self, ref_val: &str) -> Option<HDict> {
58        self.read(|g| g.get(ref_val).cloned())
59    }
60
61    /// Update an entity. See [`EntityGraph::update`].
62    pub fn update(&self, ref_val: &str, changes: HDict) -> Result<(), GraphError> {
63        self.write(|g| g.update(ref_val, changes))
64    }
65
66    /// Remove an entity. See [`EntityGraph::remove`].
67    pub fn remove(&self, ref_val: &str) -> Result<HDict, GraphError> {
68        self.write(|g| g.remove(ref_val))
69    }
70
71    /// Run a filter expression and return a grid.
72    pub fn read_filter(&self, filter_expr: &str, limit: usize) -> Result<HGrid, GraphError> {
73        self.read(|g| g.read(filter_expr, limit))
74    }
75
76    /// Number of entities.
77    pub fn len(&self) -> usize {
78        self.read(|g| g.len())
79    }
80
81    /// Returns `true` if the graph has no entities.
82    pub fn is_empty(&self) -> bool {
83        self.read(|g| g.is_empty())
84    }
85
86    /// Check if an entity with the given ref value exists.
87    pub fn contains(&self, ref_val: &str) -> bool {
88        self.read(|g| g.contains(ref_val))
89    }
90
91    /// Current graph version.
92    pub fn version(&self) -> u64 {
93        self.read(|g| g.version())
94    }
95
96    /// Run a filter and return matching entity dicts (cloned).
97    pub fn read_all(&self, filter_expr: &str, limit: usize) -> Result<Vec<HDict>, GraphError> {
98        self.read(|g| {
99            g.read_all(filter_expr, limit)
100                .map(|refs| refs.into_iter().cloned().collect())
101        })
102    }
103
104    /// Get ref values that the given entity points to.
105    pub fn refs_from(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
106        self.read(|g| g.refs_from(ref_val, ref_type))
107    }
108
109    /// Get ref values of entities that point to the given entity.
110    pub fn refs_to(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
111        self.read(|g| g.refs_to(ref_val, ref_type))
112    }
113
114    /// Get changelog entries since a given version.
115    pub fn changes_since(&self, version: u64) -> Vec<super::changelog::GraphDiff> {
116        self.read(|g| g.changes_since(version).to_vec())
117    }
118
119    /// Find all entities that structurally fit a spec/type name.
120    ///
121    /// Returns owned clones. See [`EntityGraph::entities_fitting`].
122    pub fn entities_fitting(&self, spec_name: &str) -> Vec<HDict> {
123        self.read(|g| g.entities_fitting(spec_name).into_iter().cloned().collect())
124    }
125
126    /// Validate all entities against the namespace and check for dangling refs.
127    ///
128    /// See [`EntityGraph::validate`].
129    pub fn validate(&self) -> Vec<ValidationIssue> {
130        self.read(|g| g.validate())
131    }
132}
133
134impl Default for SharedGraph {
135    fn default() -> Self {
136        Self::new(EntityGraph::new())
137    }
138}
139
140impl Clone for SharedGraph {
141    fn clone(&self) -> Self {
142        Self {
143            inner: Arc::clone(&self.inner),
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kinds::{HRef, Kind};

    /// Builds a minimal site entity: an `id` ref, a `site` marker and a `dis` label.
    fn make_site(id: &str) -> HDict {
        let mut d = HDict::new();
        d.set("id", Kind::Ref(HRef::from_val(id)));
        d.set("site", Kind::Marker);
        d.set("dis", Kind::Str(format!("Site {id}")));
        d
    }

    #[test]
    fn thread_safe_add_get() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let entity = sg.get("site-1").unwrap();
        assert!(entity.has("site"));
    }

    #[test]
    fn concurrent_read_access() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        // Two handles to the same graph can both read it. These reads run one
        // after the other on this thread; real cross-thread concurrency is
        // covered by `concurrent_reads_from_threads`.
        let sg2 = sg.clone();

        let entity1 = sg.get("site-1");
        let entity2 = sg2.get("site-1");
        assert!(entity1.is_some());
        assert!(entity2.is_some());
    }

    #[test]
    fn clone_shares_state() {
        let sg = SharedGraph::new(EntityGraph::new());
        let sg2 = sg.clone();

        sg.add(make_site("site-1")).unwrap();

        // sg2 should see the entity added via sg.
        assert!(sg2.get("site-1").is_some());
        assert_eq!(sg2.len(), 1);
    }

    #[test]
    fn convenience_methods() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert!(sg.is_empty());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.len(), 1);
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        let grid = sg.read_filter("site", 0).unwrap();
        assert_eq!(grid.len(), 1);

        sg.remove("site-1").unwrap();
        assert!(sg.is_empty());
        assert!(!sg.contains("site-1"));
    }

    #[test]
    fn concurrent_writes_from_threads() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        let mut handles = Vec::new();

        for i in 0..10 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                let id = format!("site-{i}");
                sg_clone.add(make_site(&id)).unwrap();
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 10);
        // Every writer's entity must have landed, not just the right count.
        for i in 0..10 {
            assert!(sg.contains(&format!("site-{i}")));
        }
    }

    #[test]
    fn contains_check() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        assert!(sg.contains("site-1"));
        assert!(!sg.contains("site-2"));
    }

    #[test]
    fn default_creates_empty() {
        let sg = SharedGraph::default();
        assert!(sg.is_empty());
        assert_eq!(sg.len(), 0);
        assert_eq!(sg.version(), 0);
    }

    #[test]
    fn read_all_filter() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val("equip-1")));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val("site-1")));
        sg.add(equip).unwrap();

        let results = sg.read_all("site", 0).unwrap();
        assert_eq!(results.len(), 2);
        // Only the two sites match; the equip must have been filtered out.
        assert!(results.iter().all(|d| d.has("site")));
    }

    #[test]
    fn concurrent_reads_from_threads() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        for i in 0..20 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();
        for _ in 0..8 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                assert_eq!(sg_clone.len(), 20);
                for i in 0..20 {
                    assert!(sg_clone.contains(&format!("site-{i}")));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn concurrent_read_write_mix() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        // Pre-populate
        for i in 0..5 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();

        // Writer thread: add more entities
        let sg_writer = sg.clone();
        handles.push(thread::spawn(move || {
            for i in 5..15 {
                sg_writer.add(make_site(&format!("site-{i}"))).unwrap();
            }
        }));

        // Reader threads: the writer only ever adds, so every reader must see
        // between the 5 pre-populated entities and all 15, and every
        // pre-populated entity must remain visible throughout.
        for _ in 0..4 {
            let sg_reader = sg.clone();
            handles.push(thread::spawn(move || {
                let len = sg_reader.len();
                assert!((5..=15).contains(&len), "unexpected len {len}");
                for i in 0..5 {
                    assert!(sg_reader.get(&format!("site-{i}")).is_some());
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 15);
    }

    #[test]
    fn version_tracking_across_operations() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        sg.remove("site-1").unwrap();
        assert_eq!(sg.version(), 3);
    }

    #[test]
    fn refs_from_and_to() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val("equip-1")));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val("site-1")));
        sg.add(equip).unwrap();

        let targets = sg.refs_from("equip-1", None);
        assert_eq!(targets, vec!["site-1".to_string()]);

        // The only entity pointing at site-1 is equip-1.
        let sources = sg.refs_to("site-1", None);
        assert_eq!(sources, vec!["equip-1".to_string()]);
    }

    #[test]
    fn changes_since_through_shared() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let changes = sg.changes_since(0);
        assert_eq!(changes.len(), 2);

        let changes = sg.changes_since(1);
        assert_eq!(changes.len(), 1);
        assert_eq!(changes[0].ref_val, "site-2");
    }
}