1use petgraph::algo::{dijkstra, has_path_connecting};
6use petgraph::graph::{DiGraph, NodeIndex};
7use petgraph::visit::EdgeRef;
8use petgraph::Direction;
9use std::collections::HashMap;
10
11use crate::core::{Event, EventId, GeoBounds, Location, TimeRange};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
15pub struct NodeId(pub(crate) NodeIndex);
16
17impl NodeId {
18 pub fn index(&self) -> usize {
20 self.0.index()
21 }
22}
23
/// Kind of relationship an edge expresses between two events.
///
/// [`EdgeType::Temporal`] is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EdgeType {
    /// Created by `NarrativeGraph::connect_temporal` between events that are
    /// adjacent when sorted by timestamp.
    #[default]
    Temporal,
    /// Created by `NarrativeGraph::connect_spatial` between events whose
    /// locations lie within the requested distance.
    Spatial,
    /// NOTE(review): never created automatically in this file; presumably a
    /// cause/effect link supplied by callers — confirm.
    Causal,
    /// Created by `NarrativeGraph::connect_thematic` between events that share
    /// at least one tag.
    Thematic,
    /// NOTE(review): never created automatically in this file; presumably one
    /// event citing another — confirm.
    Reference,
    /// NOTE(review): never created automatically in this file; presumably a
    /// caller-defined relationship — confirm.
    Custom,
}
41
42#[derive(Debug, Clone)]
44pub struct EdgeWeight {
45 pub edge_type: EdgeType,
47 pub weight: f64,
49 pub label: Option<String>,
51}
52
53impl EdgeWeight {
54 pub fn new(edge_type: EdgeType) -> Self {
56 Self {
57 edge_type,
58 weight: 1.0,
59 label: None,
60 }
61 }
62
63 pub fn with_weight(edge_type: EdgeType, weight: f64) -> Self {
65 Self {
66 edge_type,
67 weight: weight.clamp(0.0, 1.0),
68 label: None,
69 }
70 }
71
72 pub fn with_label(mut self, label: impl Into<String>) -> Self {
74 self.label = Some(label.into());
75 self
76 }
77}
78
79impl Default for EdgeWeight {
80 fn default() -> Self {
81 Self::new(EdgeType::Temporal)
82 }
83}
84
/// Directed graph whose nodes are [`Event`]s and whose edges carry an
/// [`EdgeWeight`].
#[derive(Debug)]
pub struct NarrativeGraph {
    /// Underlying petgraph storage.
    graph: DiGraph<Event, EdgeWeight>,
    /// Lookup from an event's id to the node holding it. When the same id is
    /// added more than once, the entry points at the most recently added node.
    id_map: HashMap<EventId, NodeIndex>,
}
99
100impl NarrativeGraph {
101 pub fn new() -> Self {
103 Self {
104 graph: DiGraph::new(),
105 id_map: HashMap::new(),
106 }
107 }
108
109 pub fn from_events(events: impl IntoIterator<Item = Event>) -> Self {
113 let mut graph = Self::new();
114 for event in events {
115 graph.add_event(event);
116 }
117 graph
118 }
119
120 pub fn add_event(&mut self, event: Event) -> NodeId {
124 let event_id = event.id.clone();
125 let idx = self.graph.add_node(event);
126 self.id_map.insert(event_id, idx);
127 NodeId(idx)
128 }
129
130 pub fn get_node(&self, event_id: &EventId) -> Option<NodeId> {
132 self.id_map.get(event_id).map(|&idx| NodeId(idx))
133 }
134
135 pub fn event(&self, node: NodeId) -> Option<&Event> {
137 self.graph.node_weight(node.0)
138 }
139
140 pub fn event_mut(&mut self, node: NodeId) -> Option<&mut Event> {
142 self.graph.node_weight_mut(node.0)
143 }
144
145 pub fn connect(&mut self, from: NodeId, to: NodeId, edge_type: EdgeType) {
147 self.graph
148 .add_edge(from.0, to.0, EdgeWeight::new(edge_type));
149 }
150
151 pub fn connect_weighted(&mut self, from: NodeId, to: NodeId, weight: EdgeWeight) {
153 self.graph.add_edge(from.0, to.0, weight);
154 }
155
156 pub fn are_connected(&self, from: NodeId, to: NodeId) -> bool {
158 self.graph.contains_edge(from.0, to.0)
159 }
160
161 pub fn has_path(&self, from: NodeId, to: NodeId) -> bool {
163 has_path_connecting(&self.graph, from.0, to.0, None)
164 }
165
166 pub fn successors(&self, node: NodeId) -> Vec<NodeId> {
168 self.graph
169 .neighbors_directed(node.0, Direction::Outgoing)
170 .map(NodeId)
171 .collect()
172 }
173
174 pub fn predecessors(&self, node: NodeId) -> Vec<NodeId> {
176 self.graph
177 .neighbors_directed(node.0, Direction::Incoming)
178 .map(NodeId)
179 .collect()
180 }
181
182 pub fn node_count(&self) -> usize {
184 self.graph.node_count()
185 }
186
187 pub fn edge_count(&self) -> usize {
189 self.graph.edge_count()
190 }
191
192 pub fn is_empty(&self) -> bool {
194 self.graph.node_count() == 0
195 }
196
197 pub fn nodes(&self) -> impl Iterator<Item = (NodeId, &Event)> {
199 self.graph.node_indices().filter_map(|idx| {
200 self.graph
201 .node_weight(idx)
202 .map(|event| (NodeId(idx), event))
203 })
204 }
205
206 pub fn edges(&self) -> impl Iterator<Item = (NodeId, NodeId, &EdgeWeight)> {
208 self.graph
209 .edge_references()
210 .map(|edge| (NodeId(edge.source()), NodeId(edge.target()), edge.weight()))
211 }
212
213 pub fn shortest_path(&self, from: NodeId, to: NodeId) -> Option<PathInfo> {
217 let costs = dijkstra(&self.graph, from.0, Some(to.0), |e| 1.0 - e.weight().weight);
219
220 if !costs.contains_key(&to.0) {
221 return None;
222 }
223
224 let mut path = vec![to];
226 let mut current = to.0;
227
228 while current != from.0 {
229 let predecessors: Vec<_> = self
230 .graph
231 .neighbors_directed(current, Direction::Incoming)
232 .collect();
233
234 let best = predecessors
235 .iter()
236 .filter(|&&n| costs.contains_key(&n))
237 .min_by(|&&a, &&b| costs[&a].partial_cmp(&costs[&b]).unwrap());
238
239 if let Some(&next) = best {
240 path.push(NodeId(next));
241 current = next;
242 } else {
243 break;
244 }
245 }
246
247 path.reverse();
248
249 Some(PathInfo {
250 nodes: path,
251 total_weight: costs[&to.0],
252 })
253 }
254
255 pub fn edges_of_type(&self, edge_type: EdgeType) -> Vec<(NodeId, NodeId)> {
257 self.graph
258 .edge_references()
259 .filter(|e| e.weight().edge_type == edge_type)
260 .map(|e| (NodeId(e.source()), NodeId(e.target())))
261 .collect()
262 }
263
264 pub fn connect_temporal(&mut self) {
268 let mut nodes: Vec<_> = self
269 .graph
270 .node_indices()
271 .filter_map(|idx| {
272 self.graph
273 .node_weight(idx)
274 .map(|e| (idx, e.timestamp.clone()))
275 })
276 .collect();
277
278 nodes.sort_by(|a, b| a.1.cmp(&b.1));
279
280 for window in nodes.windows(2) {
281 if let [a, b] = window {
282 if !self.graph.contains_edge(a.0, b.0) {
283 self.graph
284 .add_edge(a.0, b.0, EdgeWeight::new(EdgeType::Temporal));
285 }
286 }
287 }
288 }
289
290 pub fn connect_spatial(&mut self, max_distance_km: f64) {
294 let nodes: Vec<_> = self
295 .graph
296 .node_indices()
297 .filter_map(|idx| {
298 self.graph
299 .node_weight(idx)
300 .map(|e| (idx, e.location.clone()))
301 })
302 .collect();
303
304 for i in 0..nodes.len() {
305 for j in (i + 1)..nodes.len() {
306 let dist = haversine_distance(&nodes[i].1, &nodes[j].1);
307 if dist <= max_distance_km {
308 let weight = 1.0 - (dist / max_distance_km);
309 let edge = EdgeWeight::with_weight(EdgeType::Spatial, weight);
310
311 if !self.graph.contains_edge(nodes[i].0, nodes[j].0) {
313 self.graph.add_edge(nodes[i].0, nodes[j].0, edge.clone());
314 }
315 if !self.graph.contains_edge(nodes[j].0, nodes[i].0) {
316 self.graph.add_edge(nodes[j].0, nodes[i].0, edge);
317 }
318 }
319 }
320 }
321 }
322
323 pub fn connect_thematic(&mut self) {
325 let nodes: Vec<_> = self
326 .graph
327 .node_indices()
328 .filter_map(|idx| self.graph.node_weight(idx).map(|e| (idx, e.tags.clone())))
329 .collect();
330
331 for i in 0..nodes.len() {
332 for j in (i + 1)..nodes.len() {
333 let shared: usize = nodes[i].1.iter().filter(|t| nodes[j].1.contains(t)).count();
334
335 if shared > 0 {
336 let total = nodes[i].1.len().max(nodes[j].1.len());
337 let weight = shared as f64 / total as f64;
338 let edge = EdgeWeight::with_weight(EdgeType::Thematic, weight);
339
340 if !self.graph.contains_edge(nodes[i].0, nodes[j].0) {
342 self.graph.add_edge(nodes[i].0, nodes[j].0, edge.clone());
343 }
344 if !self.graph.contains_edge(nodes[j].0, nodes[i].0) {
345 self.graph.add_edge(nodes[j].0, nodes[i].0, edge);
346 }
347 }
348 }
349 }
350 }
351
352 pub fn subgraph_temporal(&self, range: &TimeRange) -> SubgraphResult {
354 let nodes: Vec<NodeId> = self
355 .nodes()
356 .filter(|(_, event)| range.contains(&event.timestamp))
357 .map(|(id, _)| id)
358 .collect();
359
360 self.subgraph_from_nodes(&nodes)
361 }
362
363 pub fn subgraph_spatial(&self, bounds: &GeoBounds) -> SubgraphResult {
365 let nodes: Vec<NodeId> = self
366 .nodes()
367 .filter(|(_, event)| bounds.contains(&event.location))
368 .map(|(id, _)| id)
369 .collect();
370
371 self.subgraph_from_nodes(&nodes)
372 }
373
374 fn subgraph_from_nodes(&self, nodes: &[NodeId]) -> SubgraphResult {
376 let mut new_graph = NarrativeGraph::new();
377 let mut id_map = HashMap::new();
378
379 for &node_id in nodes {
381 if let Some(event) = self.event(node_id) {
382 let new_id = new_graph.add_event(event.clone());
383 id_map.insert(node_id, new_id);
384 }
385 }
386
387 for (from, to, weight) in self.edges() {
389 if let (Some(&new_from), Some(&new_to)) = (id_map.get(&from), id_map.get(&to)) {
390 new_graph.connect_weighted(new_from, new_to, weight.clone());
391 }
392 }
393
394 SubgraphResult {
395 graph: new_graph,
396 node_mapping: id_map,
397 }
398 }
399
400 pub fn in_degree(&self, node: NodeId) -> usize {
402 self.graph
403 .edges_directed(node.0, Direction::Incoming)
404 .count()
405 }
406
407 pub fn out_degree(&self, node: NodeId) -> usize {
409 self.graph
410 .edges_directed(node.0, Direction::Outgoing)
411 .count()
412 }
413
414 pub fn roots(&self) -> Vec<NodeId> {
416 self.graph
417 .node_indices()
418 .filter(|&idx| self.graph.edges_directed(idx, Direction::Incoming).count() == 0)
419 .map(NodeId)
420 .collect()
421 }
422
423 pub fn leaves(&self) -> Vec<NodeId> {
425 self.graph
426 .node_indices()
427 .filter(|&idx| self.graph.edges_directed(idx, Direction::Outgoing).count() == 0)
428 .map(NodeId)
429 .collect()
430 }
431
432 pub fn to_dot(&self) -> String {
455 self.to_dot_with_options(DotOptions::default())
456 }
457
458 pub fn to_dot_with_options(&self, options: DotOptions) -> String {
460 let mut output = String::new();
461 output.push_str("digraph NarrativeGraph {\n");
462
463 output.push_str(&format!(" rankdir={};\n", options.rank_direction));
465 output.push_str(&format!(
466 " node [shape={}, fontname=\"{}\"];\n",
467 options.node_shape, options.font_name
468 ));
469 output.push_str(&format!(" edge [fontname=\"{}\"];\n", options.font_name));
470 output.push('\n');
471
472 for idx in self.graph.node_indices() {
474 let event = &self.graph[idx];
475 let label = Self::escape_dot_string(&Self::truncate_text(&event.text, 30));
476 let tooltip = Self::escape_dot_string(&format!(
477 "{}\\n({:.4}, {:.4})\\n{}",
478 event.text,
479 event.location.lat,
480 event.location.lon,
481 event.timestamp.to_rfc3339()
482 ));
483
484 let color = self.get_node_color(NodeId(idx));
486
487 output.push_str(&format!(
488 " n{} [label=\"{}\", tooltip=\"{}\", fillcolor=\"{}\", style=filled];\n",
489 idx.index(),
490 label,
491 tooltip,
492 color
493 ));
494 }
495
496 output.push('\n');
497
498 for edge in self.graph.edge_references() {
500 let weight = edge.weight();
501 let color = Self::edge_type_color(&weight.edge_type);
502 let style = Self::edge_type_style(&weight.edge_type);
503 let label = weight.label.as_deref().unwrap_or("");
504
505 output.push_str(&format!(
506 " n{} -> n{} [color=\"{}\", style={}, label=\"{}\", penwidth={}];\n",
507 edge.source().index(),
508 edge.target().index(),
509 color,
510 style,
511 Self::escape_dot_string(label),
512 1.0 + weight.weight * 2.0
513 ));
514 }
515
516 output.push_str("}\n");
517 output
518 }
519
520 pub fn to_json(&self) -> String {
524 let nodes: Vec<serde_json::Value> = self
525 .graph
526 .node_indices()
527 .map(|idx| {
528 let event = &self.graph[idx];
529 serde_json::json!({
530 "id": idx.index(),
531 "event_id": event.id.to_string(),
532 "text": event.text,
533 "location": {
534 "lat": event.location.lat,
535 "lon": event.location.lon,
536 "elevation": event.location.elevation,
537 "name": event.location.name
538 },
539 "timestamp": event.timestamp.to_rfc3339(),
540 "tags": event.tags
541 })
542 })
543 .collect();
544
545 let edges: Vec<serde_json::Value> = self
546 .graph
547 .edge_references()
548 .map(|edge| {
549 let weight = edge.weight();
550 serde_json::json!({
551 "source": edge.source().index(),
552 "target": edge.target().index(),
553 "type": format!("{:?}", weight.edge_type),
554 "weight": weight.weight,
555 "label": weight.label
556 })
557 })
558 .collect();
559
560 serde_json::json!({
561 "nodes": nodes,
562 "edges": edges,
563 "metadata": {
564 "node_count": self.node_count(),
565 "edge_count": self.edge_count()
566 }
567 })
568 .to_string()
569 }
570
571 pub fn to_json_pretty(&self) -> String {
573 let nodes: Vec<serde_json::Value> = self
574 .graph
575 .node_indices()
576 .map(|idx| {
577 let event = &self.graph[idx];
578 serde_json::json!({
579 "id": idx.index(),
580 "event_id": event.id.to_string(),
581 "text": event.text,
582 "location": {
583 "lat": event.location.lat,
584 "lon": event.location.lon,
585 "elevation": event.location.elevation,
586 "name": event.location.name
587 },
588 "timestamp": event.timestamp.to_rfc3339(),
589 "tags": event.tags
590 })
591 })
592 .collect();
593
594 let edges: Vec<serde_json::Value> = self
595 .graph
596 .edge_references()
597 .map(|edge| {
598 let weight = edge.weight();
599 serde_json::json!({
600 "source": edge.source().index(),
601 "target": edge.target().index(),
602 "type": format!("{:?}", weight.edge_type),
603 "weight": weight.weight,
604 "label": weight.label
605 })
606 })
607 .collect();
608
609 serde_json::to_string_pretty(&serde_json::json!({
610 "nodes": nodes,
611 "edges": edges,
612 "metadata": {
613 "node_count": self.node_count(),
614 "edge_count": self.edge_count()
615 }
616 }))
617 .unwrap_or_default()
618 }
619
620 fn escape_dot_string(s: &str) -> String {
622 s.replace('\\', "\\\\")
623 .replace('"', "\\\"")
624 .replace('\n', "\\n")
625 }
626
627 fn truncate_text(text: &str, max_len: usize) -> String {
628 if text.len() <= max_len {
629 text.to_string()
630 } else {
631 format!("{}...", &text[..max_len.saturating_sub(3)])
632 }
633 }
634
635 fn get_node_color(&self, node: NodeId) -> &'static str {
636 let in_deg = self.in_degree(node);
638 let out_deg = self.out_degree(node);
639
640 if in_deg == 0 && out_deg > 0 {
641 "#90EE90" } else if out_deg == 0 && in_deg > 0 {
643 "#FFB6C1" } else if in_deg > 2 || out_deg > 2 {
645 "#87CEEB" } else {
647 "#FFFACD" }
649 }
650
651 fn edge_type_color(edge_type: &EdgeType) -> &'static str {
652 match edge_type {
653 EdgeType::Temporal => "#2E86AB", EdgeType::Spatial => "#A23B72", EdgeType::Causal => "#F18F01", EdgeType::Thematic => "#C73E1D", EdgeType::Reference => "#6B8E23", EdgeType::Custom => "#808080", }
660 }
661
662 fn edge_type_style(edge_type: &EdgeType) -> &'static str {
663 match edge_type {
664 EdgeType::Temporal => "solid",
665 EdgeType::Spatial => "dashed",
666 EdgeType::Causal => "bold",
667 EdgeType::Thematic => "dotted",
668 EdgeType::Reference => "solid",
669 EdgeType::Custom => "solid",
670 }
671 }
672}
673
/// Layout settings applied when rendering a graph to Graphviz DOT.
#[derive(Debug, Clone)]
pub struct DotOptions {
    /// Value written to the graph-level `rankdir` attribute.
    pub rank_direction: String,
    /// Value written to the default node `shape` attribute.
    pub node_shape: String,
    /// Value written to the `fontname` attribute of nodes and edges.
    pub font_name: String,
}

impl Default for DotOptions {
    /// `TB` rank direction, `box` nodes, `Arial` font.
    fn default() -> Self {
        Self::preset("TB", "box")
    }
}

impl DotOptions {
    /// Preset with `LR` rank direction and `box` nodes.
    pub fn timeline() -> Self {
        Self::preset("LR", "box")
    }

    /// Preset with `TB` rank direction and `ellipse` nodes.
    pub fn hierarchical() -> Self {
        Self::preset("TB", "ellipse")
    }

    /// Shared constructor for the presets; all of them use the `Arial` font.
    fn preset(rank_direction: &str, node_shape: &str) -> Self {
        Self {
            rank_direction: rank_direction.to_owned(),
            node_shape: node_shape.to_owned(),
            font_name: String::from("Arial"),
        }
    }
}
714
715impl Default for NarrativeGraph {
716 fn default() -> Self {
717 Self::new()
718 }
719}
720
721#[derive(Debug, Clone)]
723pub struct PathInfo {
724 pub nodes: Vec<NodeId>,
726 pub total_weight: f64,
728}
729
730impl PathInfo {
731 pub fn len(&self) -> usize {
733 self.nodes.len()
734 }
735
736 pub fn is_empty(&self) -> bool {
738 self.nodes.is_empty()
739 }
740}
741
/// Outcome of extracting part of a [`NarrativeGraph`] into a new graph.
#[derive(Debug)]
pub struct SubgraphResult {
    /// The extracted graph, holding clones of the selected events and of the
    /// edges between them.
    pub graph: NarrativeGraph,
    /// Maps each selected node's handle in the source graph to the handle of
    /// its copy in `graph`.
    pub node_mapping: HashMap<NodeId, NodeId>,
}
750
751fn haversine_distance(loc1: &Location, loc2: &Location) -> f64 {
753 let r = 6371.0; let lat1 = loc1.lat.to_radians();
756 let lat2 = loc2.lat.to_radians();
757 let dlat = (loc2.lat - loc1.lat).to_radians();
758 let dlon = (loc2.lon - loc1.lon).to_radians();
759
760 let a = (dlat / 2.0).sin().powi(2) + lat1.cos() * lat2.cos() * (dlon / 2.0).sin().powi(2);
761 let c = 2.0 * a.sqrt().asin();
762
763 r * c
764}
765
// Unit tests for graph construction, the automatic connection strategies and
// the distance helper.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Timestamp;

    // Builds an event at the given coordinates; `time` must be parseable by
    // `Timestamp::parse`, otherwise the helper panics.
    fn make_event(lat: f64, lon: f64, time: &str, text: &str) -> Event {
        Event::new(
            Location::new(lat, lon),
            Timestamp::parse(time).unwrap(),
            text,
        )
    }

    // A freshly created graph has no nodes and no edges.
    #[test]
    fn test_graph_new() {
        let graph = NarrativeGraph::new();
        assert!(graph.is_empty());
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
    }

    // Adding an event creates a node that can be read back through its handle.
    #[test]
    fn test_graph_add_event() {
        let mut graph = NarrativeGraph::new();
        let event = make_event(40.7128, -74.0060, "2024-01-01T12:00:00Z", "NYC Event");

        let node = graph.add_event(event.clone());

        assert_eq!(graph.node_count(), 1);
        assert_eq!(graph.event(node).unwrap().text, "NYC Event");
    }

    // `connect` adds exactly one edge, and only in the requested direction.
    #[test]
    fn test_graph_connect() {
        let mut graph = NarrativeGraph::new();
        let n1 = graph.add_event(make_event(40.7, -74.0, "2024-01-01T10:00:00Z", "Event 1"));
        let n2 = graph.add_event(make_event(40.7, -74.0, "2024-01-01T12:00:00Z", "Event 2"));

        graph.connect(n1, n2, EdgeType::Temporal);

        assert!(graph.are_connected(n1, n2));
        assert!(!graph.are_connected(n2, n1));
        assert_eq!(graph.edge_count(), 1);
    }

    // Successor and predecessor lists follow edge direction.
    #[test]
    fn test_graph_successors_predecessors() {
        let mut graph = NarrativeGraph::new();
        let n1 = graph.add_event(make_event(40.7, -74.0, "2024-01-01T10:00:00Z", "Event 1"));
        let n2 = graph.add_event(make_event(40.7, -74.0, "2024-01-01T12:00:00Z", "Event 2"));
        let n3 = graph.add_event(make_event(40.7, -74.0, "2024-01-01T14:00:00Z", "Event 3"));

        graph.connect(n1, n2, EdgeType::Temporal);
        graph.connect(n1, n3, EdgeType::Temporal);

        assert_eq!(graph.successors(n1).len(), 2);
        assert_eq!(graph.predecessors(n2).len(), 1);
        assert_eq!(graph.predecessors(n1).len(), 0);
    }

    // Three events inserted out of order are chained by two temporal edges.
    #[test]
    fn test_graph_connect_temporal() {
        let mut graph = NarrativeGraph::new();
        graph.add_event(make_event(40.7, -74.0, "2024-01-01T14:00:00Z", "Third"));
        graph.add_event(make_event(40.7, -74.0, "2024-01-01T10:00:00Z", "First"));
        graph.add_event(make_event(40.7, -74.0, "2024-01-01T12:00:00Z", "Second"));

        graph.connect_temporal();

        assert_eq!(graph.edge_count(), 2);
    }

    // Only the two events sharing the "politics" tag get linked, once in each
    // direction; the "sports" event stays unconnected.
    #[test]
    fn test_graph_connect_thematic() {
        let mut graph = NarrativeGraph::new();

        let mut e1 = make_event(40.7, -74.0, "2024-01-01T10:00:00Z", "Event 1");
        e1.add_tag("politics");
        e1.add_tag("economy");

        let mut e2 = make_event(40.7, -74.0, "2024-01-01T12:00:00Z", "Event 2");
        e2.add_tag("politics");

        let mut e3 = make_event(40.7, -74.0, "2024-01-01T14:00:00Z", "Event 3");
        e3.add_tag("sports");

        graph.add_event(e1);
        graph.add_event(e2);
        graph.add_event(e3);

        graph.connect_thematic();

        let thematic_edges = graph.edges_of_type(EdgeType::Thematic);
        assert_eq!(thematic_edges.len(), 2);
    }

    // In a three-node chain the first node is the only root and the last node
    // the only leaf.
    #[test]
    fn test_graph_roots_leaves() {
        let mut graph = NarrativeGraph::new();
        let n1 = graph.add_event(make_event(40.7, -74.0, "2024-01-01T10:00:00Z", "Root"));
        let n2 = graph.add_event(make_event(40.7, -74.0, "2024-01-01T12:00:00Z", "Middle"));
        let n3 = graph.add_event(make_event(40.7, -74.0, "2024-01-01T14:00:00Z", "Leaf"));

        graph.connect(n1, n2, EdgeType::Temporal);
        graph.connect(n2, n3, EdgeType::Temporal);

        let roots = graph.roots();
        let leaves = graph.leaves();

        assert_eq!(roots.len(), 1);
        assert_eq!(roots[0], n1);
        assert_eq!(leaves.len(), 1);
        assert_eq!(leaves[0], n3);
    }

    // The computed New York to Los Angeles distance must fall between 3900 km
    // and 4000 km.
    #[test]
    fn test_haversine_distance() {
        let nyc = Location::new(40.7128, -74.0060);
        let la = Location::new(34.0522, -118.2437);

        let distance = haversine_distance(&nyc, &la);

        assert!(distance > 3900.0 && distance < 4000.0);
    }
}