sara_core/graph/
knowledge_graph.rs

1//! Knowledge graph implementation using petgraph.
2
3use petgraph::Direction;
4use petgraph::graph::{DiGraph, NodeIndex};
5use petgraph::visit::EdgeRef;
6use std::collections::HashMap;
7
8use crate::model::{Item, ItemId, ItemType, RelationshipType};
9
/// The main knowledge graph container.
///
/// Wraps a petgraph directed graph whose nodes carry [`Item`] values and whose
/// edges carry a [`RelationshipType`], together with a side index that maps
/// each [`ItemId`] to the node holding that item.
#[derive(Debug)]
pub struct KnowledgeGraph {
    /// The underlying directed graph: node weights are items, edge weights
    /// are the relationship type between the two connected items.
    graph: DiGraph<Item, RelationshipType>,

    /// Index for O(1) (average-case hash lookup) resolution of an `ItemId`
    /// to its node. Populated by `add_item`; every ID-based accessor goes
    /// through this map first.
    index: HashMap<ItemId, NodeIndex>,

    /// Validation mode (strict orphan checking).
    ///
    /// NOTE(review): this flag is only stored and exposed via
    /// `is_strict_mode` in this file; presumably it is consumed by
    /// validation code elsewhere — confirm against callers.
    strict_mode: bool,
}
22
23impl KnowledgeGraph {
24    /// Creates a new empty knowledge graph.
25    pub fn new(strict_mode: bool) -> Self {
26        Self {
27            graph: DiGraph::new(),
28            index: HashMap::new(),
29            strict_mode,
30        }
31    }
32
33    /// Returns whether strict mode is enabled.
34    pub fn is_strict_mode(&self) -> bool {
35        self.strict_mode
36    }
37
38    /// Returns the number of items in the graph.
39    pub fn item_count(&self) -> usize {
40        self.graph.node_count()
41    }
42
43    /// Returns the number of relationships in the graph.
44    pub fn relationship_count(&self) -> usize {
45        self.graph.edge_count()
46    }
47
48    /// Adds an item to the graph.
49    pub fn add_item(&mut self, item: Item) -> NodeIndex {
50        let id = item.id.clone();
51        let idx = self.graph.add_node(item);
52        self.index.insert(id, idx);
53        idx
54    }
55
56    /// Adds a relationship between two items.
57    pub fn add_relationship(
58        &mut self,
59        from: &ItemId,
60        to: &ItemId,
61        rel_type: RelationshipType,
62    ) -> Option<()> {
63        let from_idx = self.index.get(from)?;
64        let to_idx = self.index.get(to)?;
65        self.graph.add_edge(*from_idx, *to_idx, rel_type);
66        Some(())
67    }
68
69    /// Gets an item by ID.
70    pub fn get(&self, id: &ItemId) -> Option<&Item> {
71        let idx = self.index.get(id)?;
72        self.graph.node_weight(*idx)
73    }
74
75    /// Gets a mutable reference to an item by ID.
76    pub fn get_mut(&mut self, id: &ItemId) -> Option<&mut Item> {
77        let idx = self.index.get(id)?;
78        self.graph.node_weight_mut(*idx)
79    }
80
81    /// Checks if an item exists in the graph.
82    pub fn contains(&self, id: &ItemId) -> bool {
83        self.index.contains_key(id)
84    }
85
86    /// Returns all items in the graph.
87    pub fn items(&self) -> impl Iterator<Item = &Item> {
88        self.graph.node_weights()
89    }
90
91    /// Returns all item IDs in the graph.
92    pub fn item_ids(&self) -> impl Iterator<Item = &ItemId> {
93        self.index.keys()
94    }
95
96    /// Returns all items of a specific type.
97    pub fn items_by_type(&self, item_type: ItemType) -> Vec<&Item> {
98        self.graph
99            .node_weights()
100            .filter(|item| item.item_type == item_type)
101            .collect()
102    }
103
104    /// Returns the count of items by type.
105    pub fn count_by_type(&self) -> HashMap<ItemType, usize> {
106        let mut counts = HashMap::new();
107        for item in self.graph.node_weights() {
108            *counts.entry(item.item_type).or_insert(0) += 1;
109        }
110        counts
111    }
112
113    /// Returns direct parents of an item (items that this item relates to upstream).
114    pub fn parents(&self, id: &ItemId) -> Vec<&Item> {
115        let Some(idx) = self.index.get(id) else {
116            return Vec::new();
117        };
118
119        self.graph
120            .edges_directed(*idx, Direction::Outgoing)
121            .filter(|edge| edge.weight().is_upstream())
122            .filter_map(|edge| self.graph.node_weight(edge.target()))
123            .collect()
124    }
125
126    /// Returns direct children of an item (items that relate to this item downstream).
127    pub fn children(&self, id: &ItemId) -> Vec<&Item> {
128        let Some(idx) = self.index.get(id) else {
129            return Vec::new();
130        };
131
132        self.graph
133            .edges_directed(*idx, Direction::Incoming)
134            .filter(|edge| edge.weight().is_upstream())
135            .filter_map(|edge| self.graph.node_weight(edge.source()))
136            .collect()
137    }
138
139    /// Returns all items with no upstream parents (potential orphans).
140    pub fn orphans(&self) -> Vec<&Item> {
141        self.graph
142            .node_weights()
143            .filter(|item| {
144                // Solutions are allowed to have no parents
145                if item.item_type.is_root() {
146                    return false;
147                }
148                // Check if item has any upstream references
149                item.upstream.is_empty()
150            })
151            .collect()
152    }
153
154    /// Returns the underlying petgraph for advanced operations.
155    pub fn inner(&self) -> &DiGraph<Item, RelationshipType> {
156        &self.graph
157    }
158
159    /// Returns a mutable reference to the underlying petgraph.
160    pub fn inner_mut(&mut self) -> &mut DiGraph<Item, RelationshipType> {
161        &mut self.graph
162    }
163
164    /// Returns the node index for an item ID.
165    pub fn node_index(&self, id: &ItemId) -> Option<NodeIndex> {
166        self.index.get(id).copied()
167    }
168
169    /// Checks if the graph has cycles.
170    pub fn has_cycles(&self) -> bool {
171        petgraph::algo::is_cyclic_directed(&self.graph)
172    }
173
174    /// Returns all relationships in the graph.
175    pub fn relationships(&self) -> Vec<(ItemId, ItemId, RelationshipType)> {
176        self.graph
177            .edge_references()
178            .filter_map(|edge| {
179                let from = self.graph.node_weight(edge.source())?;
180                let to = self.graph.node_weight(edge.target())?;
181                Some((from.id.clone(), to.id.clone(), *edge.weight()))
182            })
183            .collect()
184    }
185}
186
187impl Default for KnowledgeGraph {
188    fn default() -> Self {
189        Self::new(false)
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{ItemBuilder, SourceLocation};
    use std::path::PathBuf;

    /// Builds a minimal valid item named `Test <id>` located at `/repo/<id>.md`.
    ///
    /// A placeholder specification is attached only for item types that
    /// report `requires_specification()`.
    fn create_test_item(id: &str, item_type: ItemType) -> Item {
        let location = SourceLocation::new(PathBuf::from("/repo"), format!("{id}.md"));
        let base = ItemBuilder::new()
            .id(ItemId::new_unchecked(id))
            .item_type(item_type)
            .name(format!("Test {id}"))
            .source(location);

        let builder = if item_type.requires_specification() {
            base.specification("Test specification")
        } else {
            base
        };

        builder.build().unwrap()
    }

    /// An added item can be found by ID and keeps its name.
    #[test]
    fn test_add_and_get_item() {
        let mut graph = KnowledgeGraph::new(false);
        graph.add_item(create_test_item("SOL-001", ItemType::Solution));

        let key = ItemId::new_unchecked("SOL-001");
        assert!(graph.contains(&key));

        let stored = graph.get(&key).unwrap();
        assert_eq!(stored.name, "Test SOL-001");
    }

    /// Filtering by type returns only the items of that type.
    #[test]
    fn test_items_by_type() {
        let mut graph = KnowledgeGraph::new(false);
        let fixtures = [
            ("SOL-001", ItemType::Solution),
            ("UC-001", ItemType::UseCase),
            ("UC-002", ItemType::UseCase),
        ];
        for (id, item_type) in fixtures {
            graph.add_item(create_test_item(id, item_type));
        }

        assert_eq!(graph.items_by_type(ItemType::Solution).len(), 1);
        assert_eq!(graph.items_by_type(ItemType::UseCase).len(), 2);
    }

    /// The item count starts at zero and grows by one per distinct item.
    #[test]
    fn test_item_count() {
        let mut graph = KnowledgeGraph::new(false);
        assert_eq!(graph.item_count(), 0);

        let additions = [("SOL-001", ItemType::Solution), ("UC-001", ItemType::UseCase)];
        for (expected, (id, item_type)) in (1..).zip(additions) {
            graph.add_item(create_test_item(id, item_type));
            assert_eq!(graph.item_count(), expected);
        }
    }
}
250}