Skip to main content

xz_knowledge_graph/store/
memory.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2use std::sync::RwLock;
3
4use crate::config::KgConfig;
5use crate::error::KgError;
6use crate::traits::KnowledgeGraph;
7use crate::types::consistency::{ConsistencyIssue, ConsistencyIssueType, IssueSeverity};
8use crate::types::entity::Entity;
9use crate::types::graph::{GraphStats, PathStep, SubGraph};
10use crate::types::import::{ImportResult, MergeStrategy, UpsertResult};
11use crate::types::query::{EntityPage, EntityQuery, RelationQuery};
12use crate::types::relation::{Relation, WeightStrategy};
13
14/// In-memory knowledge graph implementation (for testing).
15#[derive(Debug)]
16pub struct InMemoryKnowledgeGraph {
17    entities: RwLock<HashMap<String, Entity>>,
18    relations: RwLock<HashMap<String, Relation>>,
19    #[allow(dead_code)]
20    merge_strategy: MergeStrategy,
21    weight_strategy: WeightStrategy,
22    max_bfs_depth: u32,
23    max_path_search: u32,
24}
25
26impl InMemoryKnowledgeGraph {
27    pub fn new(config: KgConfig) -> Self {
28        Self {
29            entities: RwLock::new(HashMap::new()),
30            relations: RwLock::new(HashMap::new()),
31            merge_strategy: config.merge_strategy,
32            weight_strategy: config.weight_strategy,
33            max_bfs_depth: config.max_bfs_depth,
34            max_path_search: config.max_path_search,
35        }
36    }
37}
38
39#[async_trait::async_trait]
40impl KnowledgeGraph for InMemoryKnowledgeGraph {
41    async fn upsert_entity(&self, entity: Entity) -> Result<UpsertResult, KgError> {
42        let mut entities = self.entities.write().unwrap();
43        if let Some(_existing) = entities.get(&entity.id) {
44            entities.insert(entity.id.clone(), entity);
45            Ok(UpsertResult::Updated {
46                changed_fields: vec!["*".into()],
47                conflicts: vec![],
48            })
49        } else {
50            entities.insert(entity.id.clone(), entity);
51            Ok(UpsertResult::Created)
52        }
53    }
54
55    async fn get_entity(&self, id: &str) -> Result<Option<Entity>, KgError> {
56        Ok(self.entities.read().unwrap().get(id).cloned())
57    }
58
59    async fn search_entities(&self, query: &EntityQuery) -> Result<EntityPage, KgError> {
60        let entities = self.entities.read().unwrap();
61        let mut items: Vec<Entity> = entities.values().cloned().collect();
62
63        // Filter by name_contains
64        if let Some(ref name) = query.name_contains {
65            let name = name.to_lowercase();
66            items.retain(|e| e.name.to_lowercase().contains(&name));
67        }
68        // Filter by entity_types
69        if let Some(ref types) = query.entity_types {
70            if !types.is_empty() {
71                items.retain(|e| types.contains(&e.entity_type));
72            }
73        }
74        // Filter by source
75        if let Some(ref source) = query.source {
76            items.retain(|e| e.source.as_deref() == Some(source.as_str()));
77        }
78
79        let total = items.len();
80        items.sort_by_key(|e| std::cmp::Reverse(e.updated_at));
81
82        let page = &query.page;
83        let has_more = page.offset + page.limit < total;
84        let items = items
85            .into_iter()
86            .skip(page.offset)
87            .take(page.limit)
88            .collect();
89
90        Ok(EntityPage { items, total, has_more })
91    }
92
93    async fn delete_entity(&self, id: &str) -> Result<usize, KgError> {
94        let mut entities = self.entities.write().unwrap();
95        let mut relations = self.relations.write().unwrap();
96
97        let relation_count = relations
98            .values()
99            .filter(|r| r.source_id == id || r.target_id == id)
100            .count();
101
102        relations.retain(|_, r| r.source_id != id && r.target_id != id);
103        entities.remove(id);
104
105        Ok(relation_count)
106    }
107
108    async fn get_entities_batch(&self, ids: &[&str]) -> Result<Vec<Entity>, KgError> {
109        let entities = self.entities.read().unwrap();
110        Ok(ids.iter().filter_map(|id| entities.get(*id).cloned()).collect())
111    }
112
    // === Relation Operations ===
114
115    async fn upsert_relation(&self, relation: Relation) -> Result<UpsertResult, KgError> {
116        let mut relations = self.relations.write().unwrap();
117        let existed = relations.contains_key(&relation.id);
118        relations.insert(relation.id.clone(), relation);
119        if existed {
120            Ok(UpsertResult::Updated {
121                changed_fields: vec!["*".into()],
122                conflicts: vec![],
123            })
124        } else {
125            Ok(UpsertResult::Created)
126        }
127    }
128
129    async fn get_relations(&self, entity_id: &str) -> Result<Vec<Relation>, KgError> {
130        let relations = self.relations.read().unwrap();
131        Ok(relations
132            .values()
133            .filter(|r| r.source_id == entity_id || r.target_id == entity_id)
134            .cloned()
135            .collect())
136    }
137
138    async fn query_relations(&self, query: &RelationQuery) -> Result<Vec<Relation>, KgError> {
139        let relations = self.relations.read().unwrap();
140        let mut items: Vec<Relation> = relations.values().cloned().collect();
141
142        if let Some(ref sid) = query.source_id {
143            items.retain(|r| &r.source_id == sid);
144        }
145        if let Some(ref tid) = query.target_id {
146            items.retain(|r| &r.target_id == tid);
147        }
148        if let Some(ref eid) = query.entity_id {
149            items.retain(|r| &r.source_id == eid || &r.target_id == eid);
150        }
151        if let Some(ref rt) = query.relation_type {
152            items.retain(|r| &r.relation_type == rt);
153        }
154
155        Ok(items)
156    }
157
158    async fn delete_relation(&self, id: &str) -> Result<(), KgError> {
159        let mut relations = self.relations.write().unwrap();
160        if relations.remove(id).is_none() {
161            return Err(KgError::RelationNotFound(id.to_string()));
162        }
163        Ok(())
164    }
165
    // === Graph Traversal ===
167
168    async fn get_neighbors(&self, entity_id: &str, depth: u32) -> Result<SubGraph, KgError> {
169        if depth > self.max_bfs_depth {
170            return Err(KgError::MaxDepthExceeded {
171                depth,
172                max: self.max_bfs_depth,
173            });
174        }
175
176        let entities = self.entities.read().unwrap();
177        let relations = self.relations.read().unwrap();
178
179        let center = entities
180            .get(entity_id)
181            .cloned()
182            .ok_or_else(|| KgError::EntityNotFound(entity_id.to_string()))?;
183
184        let mut visited_entities: HashMap<String, Entity> = HashMap::new();
185        let mut visited_relations: Vec<Relation> = Vec::new();
186        let mut queue: VecDeque<(String, u32)> = VecDeque::new();
187
188        visited_entities.insert(entity_id.to_string(), center.clone());
189        queue.push_back((entity_id.to_string(), 0));
190
191        while let Some((current_id, current_depth)) = queue.pop_front() {
192            if current_depth >= depth {
193                continue;
194            }
195
196            let neighbors: Vec<Relation> = relations
197                .values()
198                .filter(|r| r.source_id == current_id || r.target_id == current_id)
199                .cloned()
200                .collect();
201
202            for rel in neighbors {
203                let neighbor_id = if rel.source_id == current_id {
204                    rel.target_id.clone()
205                } else {
206                    rel.source_id.clone()
207                };
208
209                visited_relations.push(rel);
210
211                if !visited_entities.contains_key(&neighbor_id) {
212                    if let Some(entity) = entities.get(&neighbor_id).cloned() {
213                        visited_entities.insert(neighbor_id.clone(), entity);
214                        queue.push_back((neighbor_id, current_depth + 1));
215                    }
216                }
217            }
218        }
219
220        let result_entities: Vec<Entity> = visited_entities
221            .into_iter()
222            .filter(|(id, _)| id != entity_id)
223            .map(|(_, e)| e)
224            .collect();
225
226        Ok(SubGraph {
227            center,
228            entities: result_entities,
229            relations: visited_relations,
230        })
231    }
232
233    async fn shortest_path(
234        &self,
235        from: &str,
236        to: &str,
237    ) -> Result<Option<Vec<PathStep>>, KgError> {
238        if from == to {
239            return Ok(Some(vec![]));
240        }
241
242        let entities = self.entities.read().unwrap();
243        let relations = self.relations.read().unwrap();
244
245        // Build adjacency
246        let mut adj: HashMap<String, Vec<(String, Relation)>> = HashMap::new();
247        for rel in relations.values() {
248            adj.entry(rel.source_id.clone())
249                .or_default()
250                .push((rel.target_id.clone(), rel.clone()));
251            adj.entry(rel.target_id.clone())
252                .or_default()
253                .push((rel.source_id.clone(), rel.clone()));
254        }
255
256        let mut dist: HashMap<String, f32> = HashMap::new();
257        let mut prev: HashMap<String, (String, Relation)> = HashMap::new();
258
259        for id in entities.keys() {
260            dist.insert(id.clone(), f32::MAX);
261        }
262        dist.insert(from.to_string(), 0.0);
263
264        let mut queue: Vec<(f32, String)> = vec![(0.0, from.to_string())];
265
266        while let Some((_d, u)) = queue.pop() {
267            if let Some(neighbors) = adj.get(&u) {
268                for (v, rel) in neighbors {
269                    let weight = self.weight_strategy.relation_cost(rel);
270                    let alt = dist.get(&u).copied().unwrap_or(f32::MAX) + weight;
271                    if alt < dist.get(v).copied().unwrap_or(f32::MAX) {
272                        dist.insert(v.clone(), alt);
273                        prev.insert(v.clone(), (u.clone(), rel.clone()));
274                        queue.push((-alt, v.clone()));
275                    }
276                }
277            }
278            queue.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
279        }
280
281        if !prev.contains_key(to) && from != to {
282            return Ok(None);
283        }
284
285        let mut path = Vec::new();
286        let mut current = to.to_string();
287        while current != from {
288            if let Some((prev_node, rel)) = prev.get(&current) {
289                let entity = entities.get(&current).cloned().unwrap();
290                path.push(PathStep { entity, relation: rel.clone() });
291                current = prev_node.clone();
292            } else {
293                break;
294            }
295        }
296        path.reverse();
297        Ok(Some(path))
298    }
299
300    async fn all_paths(
301        &self,
302        from: &str,
303        to: &str,
304        max_depth: u32,
305    ) -> Result<Vec<Vec<PathStep>>, KgError> {
306        if max_depth > self.max_path_search {
307            return Err(KgError::MaxDepthExceeded {
308                depth: max_depth,
309                max: self.max_path_search,
310            });
311        }
312
313        let entities = self.entities.read().unwrap();
314        let relations = self.relations.read().unwrap();
315
316        let mut adj: HashMap<String, Vec<(String, Relation)>> = HashMap::new();
317        for rel in relations.values() {
318            adj.entry(rel.source_id.clone())
319                .or_default()
320                .push((rel.target_id.clone(), rel.clone()));
321            adj.entry(rel.target_id.clone())
322                .or_default()
323                .push((rel.source_id.clone(), rel.clone()));
324        }
325
326        let mut all_paths = Vec::new();
327        let mut visited = HashSet::new();
328        let mut current_path = Vec::new();
329
330        dfs_memory(from, to, max_depth, &entities, &adj, &mut visited, &mut current_path, &mut all_paths);
331
332        all_paths.sort_by(|a, b| {
333            let a_cost: f32 = a.iter().map(|step| self.weight_strategy.relation_cost(&step.relation)).sum();
334            let b_cost: f32 = b.iter().map(|step| self.weight_strategy.relation_cost(&step.relation)).sum();
335            a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
336        });
337
338        Ok(all_paths)
339    }
340
    // === Batch Operations ===
342
343    async fn batch_import(
344        &self,
345        entities: Vec<Entity>,
346        relations: Vec<Relation>,
347    ) -> Result<ImportResult, KgError> {
348        let mut result = ImportResult::default();
349
350        for entity in entities {
351            match self.upsert_entity(entity).await? {
352                UpsertResult::Created => result.entities_created += 1,
353                UpsertResult::Updated { .. } => result.entities_updated += 1,
354                UpsertResult::Unchanged => result.entities_skipped += 1,
355            }
356        }
357        for rel in relations {
358            match self.upsert_relation(rel).await? {
359                UpsertResult::Created => result.relations_created += 1,
360                UpsertResult::Updated { .. } => result.relations_updated += 1,
361                _ => {}
362            }
363        }
364
365        Ok(result)
366    }
367
    // === Consistency ===
369
370    async fn check_consistency(&self) -> Result<Vec<ConsistencyIssue>, KgError> {
371        let entities = self.entities.read().unwrap();
372        let relations = self.relations.read().unwrap();
373        let mut issues = Vec::new();
374
375        // Orphan relations
376        for rel in relations.values() {
377            if !entities.contains_key(&rel.source_id) || !entities.contains_key(&rel.target_id) {
378                issues.push(ConsistencyIssue {
379                    severity: IssueSeverity::Error,
380                    issue_type: ConsistencyIssueType::OrphanRelation,
381                    description: format!("Relation {} references a non-existent entity", rel.id),
382                    related_entities: vec![rel.source_id.clone(), rel.target_id.clone()],
383                    related_relations: vec![rel.id.clone()],
384                });
385            }
386        }
387
388        // Self-referencing
389        for rel in relations.values() {
390            if rel.source_id == rel.target_id {
391                issues.push(ConsistencyIssue {
392                    severity: IssueSeverity::Warning,
393                    issue_type: ConsistencyIssueType::SelfReferencing,
394                    description: format!("Relation {} self-references entity {}", rel.id, rel.source_id),
395                    related_entities: vec![rel.source_id.clone()],
396                    related_relations: vec![rel.id.clone()],
397                });
398            }
399        }
400
401        // Orphan entities
402        for (id, entity) in entities.iter() {
403            let has_relation = relations.values().any(|r| r.source_id == *id || r.target_id == *id);
404            if !has_relation {
405                issues.push(ConsistencyIssue {
406                    severity: IssueSeverity::Info,
407                    issue_type: ConsistencyIssueType::OrphanEntity,
408                    description: format!("Entity {} ({}) has no relations", entity.name, id),
409                    related_entities: vec![id.clone()],
410                    related_relations: vec![],
411                });
412            }
413        }
414
415        // Expired relations
416        let now = current_epoch_ms();
417        for rel in relations.values() {
418            if let Some(valid_to) = rel.valid_to {
419                if valid_to < now {
420                    issues.push(ConsistencyIssue {
421                        severity: IssueSeverity::Warning,
422                        issue_type: ConsistencyIssueType::ExpiredRelation,
423                        description: format!("Relation {} has expired (valid_to < now)", rel.id),
424                        related_entities: vec![rel.source_id.clone(), rel.target_id.clone()],
425                        related_relations: vec![rel.id.clone()],
426                    });
427                }
428            }
429        }
430
431        Ok(issues)
432    }
433
    // === Statistics ===
435
436    async fn stats(&self) -> Result<GraphStats, KgError> {
437        let entities = self.entities.read().unwrap();
438        let relations = self.relations.read().unwrap();
439
440        let mut entity_types: HashMap<String, usize> = HashMap::new();
441        for e in entities.values() {
442            *entity_types.entry(e.entity_type.as_str()).or_default() += 1;
443        }
444
445        let mut relation_types: HashMap<String, usize> = HashMap::new();
446        for r in relations.values() {
447            *relation_types.entry(r.relation_type.clone()).or_default() += 1;
448        }
449
450        let mut degrees: HashMap<String, usize> = HashMap::new();
451        for r in relations.values() {
452            *degrees.entry(r.source_id.clone()).or_default() += 1;
453            *degrees.entry(r.target_id.clone()).or_default() += 1;
454        }
455
456        let degree_values: Vec<usize> = degrees.values().copied().collect();
457        let avg_degree = if degree_values.is_empty() {
458            0.0
459        } else {
460            degree_values.iter().sum::<usize>() as f64 / degree_values.len() as f64
461        };
462        let max_degree = degree_values.iter().max().copied().unwrap_or(0);
463
464        let orphan_entities = entities
465            .keys()
466            .filter(|id| !relations.values().any(|r| &&r.source_id == id || &&r.target_id == id))
467            .count();
468
469        Ok(GraphStats {
470            total_entities: entities.len(),
471            total_relations: relations.len(),
472            entity_types,
473            relation_types,
474            avg_degree,
475            max_degree,
476            orphan_entities,
477            db_size_bytes: 0,
478        })
479    }
480}
481
482fn dfs_memory(
483    current: &str,
484    target: &str,
485    max_depth: u32,
486    entities: &HashMap<String, Entity>,
487    adj: &HashMap<String, Vec<(String, Relation)>>,
488    visited: &mut HashSet<String>,
489    current_path: &mut Vec<PathStep>,
490    all_paths: &mut Vec<Vec<PathStep>>,
491) {
492    if current == target {
493        all_paths.push(current_path.clone());
494        return;
495    }
496    if current_path.len() >= max_depth as usize {
497        return;
498    }
499    visited.insert(current.to_string());
500
501    if let Some(neighbors) = adj.get(current) {
502        for (neighbor, rel) in neighbors {
503            if visited.contains(neighbor.as_str()) {
504                continue;
505            }
506            if let Some(entity) = entities.get(neighbor).cloned() {
507                current_path.push(PathStep {
508                    entity,
509                    relation: rel.clone(),
510                });
511                dfs_memory(neighbor, target, max_depth, entities, adj, visited, current_path, all_paths);
512                current_path.pop();
513            }
514        }
515    }
516
517    visited.remove(current);
518}
519
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn current_epoch_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        // Clock set before 1970: fall back to a zero duration.
        Err(_) => 0,
    }
}