1use parking_lot::RwLock;
4use std::sync::Arc;
5use tokio::sync::broadcast;
6
7use crate::data::{HDict, HGrid};
8use crate::ontology::ValidationIssue;
9
10use super::entity_graph::{EntityGraph, GraphError, HierarchyNode};
11
/// Capacity of the version-notification broadcast channel: how many
/// published-but-unreceived versions are buffered per channel before a slow
/// subscriber starts observing lag.
const BROADCAST_CAPACITY: usize = 256;
14
15pub struct SharedGraph {
25 inner: Arc<RwLock<EntityGraph>>,
26 tx: broadcast::Sender<u64>,
27}
28
29impl SharedGraph {
30 pub fn new(graph: EntityGraph) -> Self {
32 let (tx, _) = broadcast::channel(BROADCAST_CAPACITY);
33 Self {
34 inner: Arc::new(RwLock::new(graph)),
35 tx,
36 }
37 }
38
39 pub fn subscribe(&self) -> broadcast::Receiver<u64> {
44 self.tx.subscribe()
45 }
46
47 pub fn subscriber_count(&self) -> usize {
49 self.tx.receiver_count()
50 }
51
52 pub fn read<F, R>(&self, f: F) -> R
54 where
55 F: FnOnce(&EntityGraph) -> R,
56 {
57 let guard = self.inner.read();
58 f(&guard)
59 }
60
61 pub fn write<F, R>(&self, f: F) -> R
63 where
64 F: FnOnce(&mut EntityGraph) -> R,
65 {
66 let mut guard = self.inner.write();
67 f(&mut guard)
68 }
69
70 fn write_and_notify<F, R>(&self, f: F) -> R
72 where
73 F: FnOnce(&mut EntityGraph) -> R,
74 {
75 let (result, version) = {
76 let mut guard = self.inner.write();
77 let v_before = guard.version();
78 let result = f(&mut guard);
79 let v_after = guard.version();
80 (
81 result,
82 if v_after != v_before {
83 Some(v_after)
84 } else {
85 None
86 },
87 )
88 };
89 if let Some(v) = version {
91 let _ = self.tx.send(v);
92 }
93 result
94 }
95
96 pub fn add(&self, entity: HDict) -> Result<String, GraphError> {
100 self.write_and_notify(|g| g.add(entity))
101 }
102
103 pub fn get(&self, ref_val: &str) -> Option<HDict> {
108 self.read(|g| g.get(ref_val).cloned())
109 }
110
111 pub fn update(&self, ref_val: &str, changes: HDict) -> Result<(), GraphError> {
113 self.write_and_notify(|g| g.update(ref_val, changes))
114 }
115
116 pub fn remove(&self, ref_val: &str) -> Result<HDict, GraphError> {
118 self.write_and_notify(|g| g.remove(ref_val))
119 }
120
121 pub fn read_filter(&self, filter_expr: &str, limit: usize) -> Result<HGrid, GraphError> {
123 self.read(|g| g.read(filter_expr, limit))
124 }
125
126 pub fn len(&self) -> usize {
128 self.read(|g| g.len())
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.read(|g| g.is_empty())
134 }
135
136 pub fn all_entities(&self) -> Vec<HDict> {
138 self.read(|g| g.all().into_iter().cloned().collect())
139 }
140
141 pub fn contains(&self, ref_val: &str) -> bool {
143 self.read(|g| g.contains(ref_val))
144 }
145
146 pub fn version(&self) -> u64 {
148 self.read(|g| g.version())
149 }
150
151 pub fn read_all(&self, filter_expr: &str, limit: usize) -> Result<Vec<HDict>, GraphError> {
153 self.read(|g| {
154 g.read_all(filter_expr, limit)
155 .map(|refs| refs.into_iter().cloned().collect())
156 })
157 }
158
159 pub fn refs_from(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
161 self.read(|g| g.refs_from(ref_val, ref_type))
162 }
163
164 pub fn refs_to(&self, ref_val: &str, ref_type: Option<&str>) -> Vec<String> {
166 self.read(|g| g.refs_to(ref_val, ref_type))
167 }
168
169 pub fn changes_since(
173 &self,
174 version: u64,
175 ) -> Result<Vec<super::changelog::GraphDiff>, super::changelog::ChangelogGap> {
176 self.read(|g| {
177 g.changes_since(version)
178 .map(|refs| refs.into_iter().cloned().collect())
179 })
180 }
181
182 pub fn entities_fitting(&self, spec_name: &str) -> Vec<HDict> {
186 self.read(|g| g.entities_fitting(spec_name).into_iter().cloned().collect())
187 }
188
189 pub fn validate(&self) -> Vec<ValidationIssue> {
193 self.read(|g| g.validate())
194 }
195
196 pub fn all_edges(&self) -> Vec<(String, String, String)> {
198 self.read(|g| g.all_edges())
199 }
200
201 pub fn neighbors(
203 &self,
204 ref_val: &str,
205 hops: usize,
206 ref_types: Option<&[&str]>,
207 ) -> (Vec<HDict>, Vec<(String, String, String)>) {
208 self.read(|g| {
209 let (entities, edges) = g.neighbors(ref_val, hops, ref_types);
210 (entities.into_iter().cloned().collect(), edges)
211 })
212 }
213
214 pub fn shortest_path(&self, from: &str, to: &str) -> Vec<String> {
216 self.read(|g| g.shortest_path(from, to))
217 }
218
219 pub fn subtree(&self, root: &str, max_depth: usize) -> Vec<(HDict, usize)> {
223 self.read(|g| {
224 g.subtree(root, max_depth)
225 .into_iter()
226 .map(|(e, d)| (e.clone(), d))
227 .collect()
228 })
229 }
230
231 pub fn ref_chain(&self, ref_val: &str, ref_tags: &[&str]) -> Vec<HDict> {
233 self.read(|g| {
234 g.ref_chain(ref_val, ref_tags)
235 .into_iter()
236 .cloned()
237 .collect()
238 })
239 }
240
241 pub fn site_for(&self, ref_val: &str) -> Option<HDict> {
243 self.read(|g| g.site_for(ref_val).cloned())
244 }
245
246 pub fn children(&self, ref_val: &str) -> Vec<HDict> {
248 self.read(|g| g.children(ref_val).into_iter().cloned().collect())
249 }
250
251 pub fn equip_points(&self, equip_ref: &str, filter: Option<&str>) -> Vec<HDict> {
253 self.read(|g| {
254 g.equip_points(equip_ref, filter)
255 .into_iter()
256 .cloned()
257 .collect()
258 })
259 }
260
261 pub fn hierarchy_tree(&self, root: &str, max_depth: usize) -> Option<HierarchyNode> {
263 self.read(|g| g.hierarchy_tree(root, max_depth))
264 }
265
266 pub fn classify(&self, ref_val: &str) -> Option<String> {
268 self.read(|g| g.classify(ref_val))
269 }
270}
271
272impl Default for SharedGraph {
273 fn default() -> Self {
274 Self::new(EntityGraph::new())
275 }
276}
277
278impl Clone for SharedGraph {
279 fn clone(&self) -> Self {
280 Self {
281 inner: Arc::clone(&self.inner),
282 tx: self.tx.clone(),
283 }
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kinds::{HRef, Kind};

    /// Builds a minimal site entity whose `id` ref is `id`.
    fn make_site(id: &str) -> HDict {
        let mut d = HDict::new();
        d.set("id", Kind::Ref(HRef::from_val(id)));
        d.set("site", Kind::Marker);
        d.set("dis", Kind::Str(format!("Site {id}")));
        d
    }

    #[test]
    fn thread_safe_add_get() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let entity = sg.get("site-1").unwrap();
        assert!(entity.has("site"));
    }

    #[test]
    fn concurrent_read_access() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        // Read through a clone on another thread while the original handle
        // reads on this one, so the two reads can actually overlap.
        let sg2 = sg.clone();
        let reader = thread::spawn(move || sg2.get("site-1").is_some());

        let entity1 = sg.get("site-1");
        let entity2_found = reader.join().unwrap();
        assert!(entity1.is_some());
        assert!(entity2_found);
    }

    #[test]
    fn clone_shares_state() {
        let sg = SharedGraph::new(EntityGraph::new());
        let sg2 = sg.clone();

        sg.add(make_site("site-1")).unwrap();

        // A write through one handle is visible through the other.
        assert!(sg2.get("site-1").is_some());
        assert_eq!(sg2.len(), 1);
    }

    #[test]
    fn convenience_methods() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert!(sg.is_empty());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.len(), 1);
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        let grid = sg.read_filter("site", 0).unwrap();
        assert_eq!(grid.len(), 1);

        sg.remove("site-1").unwrap();
        assert!(sg.is_empty());
    }

    #[test]
    fn concurrent_writes_from_threads() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        let mut handles = Vec::new();

        for i in 0..10 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                let id = format!("site-{i}");
                sg_clone.add(make_site(&id)).unwrap();
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 10);
    }

    #[test]
    fn contains_check() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        assert!(sg.contains("site-1"));
        assert!(!sg.contains("site-2"));
    }

    #[test]
    fn default_creates_empty() {
        let sg = SharedGraph::default();
        assert!(sg.is_empty());
        assert_eq!(sg.len(), 0);
        assert_eq!(sg.version(), 0);
    }

    #[test]
    fn read_all_filter() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val("equip-1")));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val("site-1")));
        sg.add(equip).unwrap();

        // Only the two site entities match; the equip does not.
        let results = sg.read_all("site", 0).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn concurrent_reads_from_threads() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        for i in 0..20 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();
        for _ in 0..8 {
            let sg_clone = sg.clone();
            handles.push(thread::spawn(move || {
                assert_eq!(sg_clone.len(), 20);
                for i in 0..20 {
                    assert!(sg_clone.contains(&format!("site-{i}")));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn concurrent_read_write_mix() {
        use std::thread;

        let sg = SharedGraph::new(EntityGraph::new());
        for i in 0..5 {
            sg.add(make_site(&format!("site-{i}"))).unwrap();
        }

        let mut handles = Vec::new();

        // One writer appends site-5..site-14 ...
        let sg_writer = sg.clone();
        handles.push(thread::spawn(move || {
            for i in 5..15 {
                sg_writer.add(make_site(&format!("site-{i}"))).unwrap();
            }
        }));

        // ... while several readers hit the pre-existing entities.
        for _ in 0..4 {
            let sg_reader = sg.clone();
            handles.push(thread::spawn(move || {
                let _len = sg_reader.len();
                for i in 0..5 {
                    let _entity = sg_reader.get(&format!("site-{i}"));
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(sg.len(), 15);
    }

    #[test]
    fn version_tracking_across_operations() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert_eq!(sg.version(), 0);

        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.version(), 1);

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        assert_eq!(sg.version(), 2);

        sg.remove("site-1").unwrap();
        assert_eq!(sg.version(), 3);
    }

    #[test]
    fn refs_from_and_to() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        let mut equip = HDict::new();
        equip.set("id", Kind::Ref(HRef::from_val("equip-1")));
        equip.set("equip", Kind::Marker);
        equip.set("siteRef", Kind::Ref(HRef::from_val("site-1")));
        sg.add(equip).unwrap();

        let targets = sg.refs_from("equip-1", None);
        assert_eq!(targets, vec!["site-1".to_string()]);

        let sources = sg.refs_to("site-1", None);
        assert_eq!(sources.len(), 1);
    }

    #[test]
    fn changes_since_through_shared() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        let changes = sg.changes_since(0).unwrap();
        assert_eq!(changes.len(), 2);

        let changes = sg.changes_since(1).unwrap();
        assert_eq!(changes.len(), 1);
        assert_eq!(changes[0].ref_val, "site-2");
    }

    #[test]
    fn subscribe_receives_versions() {
        let sg = SharedGraph::new(EntityGraph::new());
        let mut rx = sg.subscribe();
        assert_eq!(sg.subscriber_count(), 1);

        sg.add(make_site("site-1")).unwrap();
        sg.add(make_site("site-2")).unwrap();

        assert_eq!(rx.try_recv().unwrap(), 1);
        assert_eq!(rx.try_recv().unwrap(), 2);
        // Exactly one notification per mutation.
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn broadcast_on_update_and_remove() {
        let sg = SharedGraph::new(EntityGraph::new());
        sg.add(make_site("site-1")).unwrap();

        // Subscribed after the add, so version 1 is not delivered.
        let mut rx = sg.subscribe();

        let mut changes = HDict::new();
        changes.set("dis", Kind::Str("Updated".into()));
        sg.update("site-1", changes).unwrap();
        sg.remove("site-1").unwrap();

        assert_eq!(rx.try_recv().unwrap(), 2);
        assert_eq!(rx.try_recv().unwrap(), 3);
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn no_subscribers_does_not_panic() {
        let sg = SharedGraph::new(EntityGraph::new());
        assert_eq!(sg.subscriber_count(), 0);
        // Broadcasting with no receivers must be silently ignored.
        sg.add(make_site("site-1")).unwrap();
        assert_eq!(sg.len(), 1);
    }
}
562}