sara_core/graph/
knowledge_graph.rs1use petgraph::Direction;
4use petgraph::graph::{DiGraph, NodeIndex};
5use petgraph::visit::EdgeRef;
6use std::collections::HashMap;
7
8use crate::model::{Item, ItemId, ItemType, RelationshipType};
9
10#[derive(Debug)]
12pub struct KnowledgeGraph {
13 graph: DiGraph<Item, RelationshipType>,
15
16 index: HashMap<ItemId, NodeIndex>,
18
19 strict_mode: bool,
21}
22
23impl KnowledgeGraph {
24 pub fn new(strict_mode: bool) -> Self {
26 Self {
27 graph: DiGraph::new(),
28 index: HashMap::new(),
29 strict_mode,
30 }
31 }
32
33 pub fn is_strict_mode(&self) -> bool {
35 self.strict_mode
36 }
37
38 pub fn item_count(&self) -> usize {
40 self.graph.node_count()
41 }
42
43 pub fn relationship_count(&self) -> usize {
45 self.graph.edge_count()
46 }
47
48 pub fn add_item(&mut self, item: Item) -> NodeIndex {
50 let id = item.id.clone();
51 let idx = self.graph.add_node(item);
52 self.index.insert(id, idx);
53 idx
54 }
55
56 pub fn add_relationship(
58 &mut self,
59 from: &ItemId,
60 to: &ItemId,
61 rel_type: RelationshipType,
62 ) -> Option<()> {
63 let from_idx = self.index.get(from)?;
64 let to_idx = self.index.get(to)?;
65 self.graph.add_edge(*from_idx, *to_idx, rel_type);
66 Some(())
67 }
68
69 pub fn get(&self, id: &ItemId) -> Option<&Item> {
71 let idx = self.index.get(id)?;
72 self.graph.node_weight(*idx)
73 }
74
75 pub fn get_mut(&mut self, id: &ItemId) -> Option<&mut Item> {
77 let idx = self.index.get(id)?;
78 self.graph.node_weight_mut(*idx)
79 }
80
81 pub fn contains(&self, id: &ItemId) -> bool {
83 self.index.contains_key(id)
84 }
85
86 pub fn items(&self) -> impl Iterator<Item = &Item> {
88 self.graph.node_weights()
89 }
90
91 pub fn item_ids(&self) -> impl Iterator<Item = &ItemId> {
93 self.index.keys()
94 }
95
96 pub fn items_by_type(&self, item_type: ItemType) -> Vec<&Item> {
98 self.graph
99 .node_weights()
100 .filter(|item| item.item_type == item_type)
101 .collect()
102 }
103
104 pub fn count_by_type(&self) -> HashMap<ItemType, usize> {
106 let mut counts = HashMap::new();
107 for item in self.graph.node_weights() {
108 *counts.entry(item.item_type).or_insert(0) += 1;
109 }
110 counts
111 }
112
113 pub fn parents(&self, id: &ItemId) -> Vec<&Item> {
115 let Some(idx) = self.index.get(id) else {
116 return Vec::new();
117 };
118
119 self.graph
120 .edges_directed(*idx, Direction::Outgoing)
121 .filter(|edge| edge.weight().is_upstream())
122 .filter_map(|edge| self.graph.node_weight(edge.target()))
123 .collect()
124 }
125
126 pub fn children(&self, id: &ItemId) -> Vec<&Item> {
128 let Some(idx) = self.index.get(id) else {
129 return Vec::new();
130 };
131
132 self.graph
133 .edges_directed(*idx, Direction::Incoming)
134 .filter(|edge| edge.weight().is_upstream())
135 .filter_map(|edge| self.graph.node_weight(edge.source()))
136 .collect()
137 }
138
139 pub fn orphans(&self) -> Vec<&Item> {
141 self.graph
142 .node_weights()
143 .filter(|item| {
144 if item.item_type.is_root() {
146 return false;
147 }
148 item.upstream.is_empty()
150 })
151 .collect()
152 }
153
154 pub fn inner(&self) -> &DiGraph<Item, RelationshipType> {
156 &self.graph
157 }
158
159 pub fn inner_mut(&mut self) -> &mut DiGraph<Item, RelationshipType> {
161 &mut self.graph
162 }
163
164 pub fn node_index(&self, id: &ItemId) -> Option<NodeIndex> {
166 self.index.get(id).copied()
167 }
168
169 pub fn has_cycles(&self) -> bool {
171 petgraph::algo::is_cyclic_directed(&self.graph)
172 }
173
174 pub fn relationships(&self) -> Vec<(ItemId, ItemId, RelationshipType)> {
176 self.graph
177 .edge_references()
178 .filter_map(|edge| {
179 let from = self.graph.node_weight(edge.source())?;
180 let to = self.graph.node_weight(edge.target())?;
181 Some((from.id.clone(), to.id.clone(), *edge.weight()))
182 })
183 .collect()
184 }
185}
186
187impl Default for KnowledgeGraph {
188 fn default() -> Self {
189 Self::new(false)
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{ItemBuilder, SourceLocation};
    use std::path::PathBuf;

    /// Builds a minimal item of `item_type` with id `id`.
    ///
    /// A specification is set only for types that require one.
    fn create_test_item(id: &str, item_type: ItemType) -> Item {
        let source = SourceLocation::new(PathBuf::from("/repo"), format!("{}.md", id));
        let mut builder = ItemBuilder::new()
            .id(ItemId::new_unchecked(id))
            .item_type(item_type)
            .name(format!("Test {}", id))
            .source(source);

        if item_type.requires_specification() {
            builder = builder.specification("Test specification");
        }

        builder.build().unwrap()
    }

    #[test]
    fn test_add_and_get_item() {
        let mut graph = KnowledgeGraph::new(false);
        let item = create_test_item("SOL-001", ItemType::Solution);
        graph.add_item(item);

        let id = ItemId::new_unchecked("SOL-001");
        assert!(graph.contains(&id));
        assert_eq!(graph.get(&id).unwrap().name, "Test SOL-001");
    }

    #[test]
    fn test_items_by_type() {
        let mut graph = KnowledgeGraph::new(false);
        graph.add_item(create_test_item("SOL-001", ItemType::Solution));
        graph.add_item(create_test_item("UC-001", ItemType::UseCase));
        graph.add_item(create_test_item("UC-002", ItemType::UseCase));

        let solutions = graph.items_by_type(ItemType::Solution);
        assert_eq!(solutions.len(), 1);

        let use_cases = graph.items_by_type(ItemType::UseCase);
        assert_eq!(use_cases.len(), 2);
    }

    #[test]
    fn test_item_count() {
        let mut graph = KnowledgeGraph::new(false);
        assert_eq!(graph.item_count(), 0);

        graph.add_item(create_test_item("SOL-001", ItemType::Solution));
        assert_eq!(graph.item_count(), 1);

        graph.add_item(create_test_item("UC-001", ItemType::UseCase));
        assert_eq!(graph.item_count(), 2);
    }

    #[test]
    fn test_strict_mode_flag() {
        assert!(KnowledgeGraph::new(true).is_strict_mode());
        assert!(!KnowledgeGraph::new(false).is_strict_mode());
        assert!(!KnowledgeGraph::default().is_strict_mode());
    }

    #[test]
    fn test_lookups_for_unknown_id() {
        let mut graph = KnowledgeGraph::new(false);
        graph.add_item(create_test_item("SOL-001", ItemType::Solution));

        let missing = ItemId::new_unchecked("UC-404");
        assert!(!graph.contains(&missing));
        assert!(graph.get(&missing).is_none());
        assert!(graph.node_index(&missing).is_none());
        assert!(graph.parents(&missing).is_empty());
        assert!(graph.children(&missing).is_empty());
    }

    #[test]
    fn test_count_by_type() {
        let mut graph = KnowledgeGraph::new(false);
        graph.add_item(create_test_item("SOL-001", ItemType::Solution));
        graph.add_item(create_test_item("UC-001", ItemType::UseCase));
        graph.add_item(create_test_item("UC-002", ItemType::UseCase));

        let counts = graph.count_by_type();
        assert_eq!(counts.get(&ItemType::Solution), Some(&1));
        assert_eq!(counts.get(&ItemType::UseCase), Some(&2));
        assert_eq!(counts.values().sum::<usize>(), graph.item_count());
    }

    #[test]
    fn test_empty_graph_has_no_relationships_or_cycles() {
        let graph = KnowledgeGraph::default();
        assert_eq!(graph.relationship_count(), 0);
        assert!(graph.relationships().is_empty());
        assert!(!graph.has_cycles());
    }
}