1use std::collections::{HashMap, HashSet, VecDeque};
2use std::sync::RwLock;
3
4use crate::config::KgConfig;
5use crate::error::KgError;
6use crate::traits::KnowledgeGraph;
7use crate::types::consistency::{ConsistencyIssue, ConsistencyIssueType, IssueSeverity};
8use crate::types::entity::Entity;
9use crate::types::graph::{GraphStats, PathStep, SubGraph};
10use crate::types::import::{ImportResult, MergeStrategy, UpsertResult};
11use crate::types::query::{EntityPage, EntityQuery, RelationQuery};
12use crate::types::relation::{Relation, WeightStrategy};
13
/// In-memory [`KnowledgeGraph`] backed by two `RwLock`-protected hash maps.
///
/// Entities and relations are stored by their `id`. Relations are not required to
/// reference existing entities; `check_consistency` reports such dangling relations.
#[derive(Debug)]
pub struct InMemoryKnowledgeGraph {
    /// Entities keyed by `Entity::id`.
    entities: RwLock<HashMap<String, Entity>>,
    /// Relations keyed by `Relation::id`.
    relations: RwLock<HashMap<String, Relation>>,
    // Copied from `KgConfig` but not consulted yet: `upsert_entity` always replaces an
    // existing entity wholesale, so without this allow the field trips `dead_code`.
    #[allow(dead_code)]
    merge_strategy: MergeStrategy,
    /// Maps a relation to a traversal cost for `shortest_path` and `all_paths` ranking.
    weight_strategy: WeightStrategy,
    /// Largest `depth` accepted by `get_neighbors`.
    max_bfs_depth: u32,
    /// Largest `max_depth` accepted by `all_paths`.
    max_path_search: u32,
}
25
26impl InMemoryKnowledgeGraph {
27 pub fn new(config: KgConfig) -> Self {
28 Self {
29 entities: RwLock::new(HashMap::new()),
30 relations: RwLock::new(HashMap::new()),
31 merge_strategy: config.merge_strategy,
32 weight_strategy: config.weight_strategy,
33 max_bfs_depth: config.max_bfs_depth,
34 max_path_search: config.max_path_search,
35 }
36 }
37}
38
39#[async_trait::async_trait]
40impl KnowledgeGraph for InMemoryKnowledgeGraph {
41 async fn upsert_entity(&self, entity: Entity) -> Result<UpsertResult, KgError> {
42 let mut entities = self.entities.write().unwrap();
43 if let Some(_existing) = entities.get(&entity.id) {
44 entities.insert(entity.id.clone(), entity);
45 Ok(UpsertResult::Updated {
46 changed_fields: vec!["*".into()],
47 conflicts: vec![],
48 })
49 } else {
50 entities.insert(entity.id.clone(), entity);
51 Ok(UpsertResult::Created)
52 }
53 }
54
55 async fn get_entity(&self, id: &str) -> Result<Option<Entity>, KgError> {
56 Ok(self.entities.read().unwrap().get(id).cloned())
57 }
58
59 async fn search_entities(&self, query: &EntityQuery) -> Result<EntityPage, KgError> {
60 let entities = self.entities.read().unwrap();
61 let mut items: Vec<Entity> = entities.values().cloned().collect();
62
63 if let Some(ref name) = query.name_contains {
65 let name = name.to_lowercase();
66 items.retain(|e| e.name.to_lowercase().contains(&name));
67 }
68 if let Some(ref types) = query.entity_types {
70 if !types.is_empty() {
71 items.retain(|e| types.contains(&e.entity_type));
72 }
73 }
74 if let Some(ref source) = query.source {
76 items.retain(|e| e.source.as_deref() == Some(source.as_str()));
77 }
78
79 let total = items.len();
80 items.sort_by_key(|e| std::cmp::Reverse(e.updated_at));
81
82 let page = &query.page;
83 let has_more = page.offset + page.limit < total;
84 let items = items
85 .into_iter()
86 .skip(page.offset)
87 .take(page.limit)
88 .collect();
89
90 Ok(EntityPage { items, total, has_more })
91 }
92
93 async fn delete_entity(&self, id: &str) -> Result<usize, KgError> {
94 let mut entities = self.entities.write().unwrap();
95 let mut relations = self.relations.write().unwrap();
96
97 let relation_count = relations
98 .values()
99 .filter(|r| r.source_id == id || r.target_id == id)
100 .count();
101
102 relations.retain(|_, r| r.source_id != id && r.target_id != id);
103 entities.remove(id);
104
105 Ok(relation_count)
106 }
107
108 async fn get_entities_batch(&self, ids: &[&str]) -> Result<Vec<Entity>, KgError> {
109 let entities = self.entities.read().unwrap();
110 Ok(ids.iter().filter_map(|id| entities.get(*id).cloned()).collect())
111 }
112
113 async fn upsert_relation(&self, relation: Relation) -> Result<UpsertResult, KgError> {
116 let mut relations = self.relations.write().unwrap();
117 let existed = relations.contains_key(&relation.id);
118 relations.insert(relation.id.clone(), relation);
119 if existed {
120 Ok(UpsertResult::Updated {
121 changed_fields: vec!["*".into()],
122 conflicts: vec![],
123 })
124 } else {
125 Ok(UpsertResult::Created)
126 }
127 }
128
129 async fn get_relations(&self, entity_id: &str) -> Result<Vec<Relation>, KgError> {
130 let relations = self.relations.read().unwrap();
131 Ok(relations
132 .values()
133 .filter(|r| r.source_id == entity_id || r.target_id == entity_id)
134 .cloned()
135 .collect())
136 }
137
138 async fn query_relations(&self, query: &RelationQuery) -> Result<Vec<Relation>, KgError> {
139 let relations = self.relations.read().unwrap();
140 let mut items: Vec<Relation> = relations.values().cloned().collect();
141
142 if let Some(ref sid) = query.source_id {
143 items.retain(|r| &r.source_id == sid);
144 }
145 if let Some(ref tid) = query.target_id {
146 items.retain(|r| &r.target_id == tid);
147 }
148 if let Some(ref eid) = query.entity_id {
149 items.retain(|r| &r.source_id == eid || &r.target_id == eid);
150 }
151 if let Some(ref rt) = query.relation_type {
152 items.retain(|r| &r.relation_type == rt);
153 }
154
155 Ok(items)
156 }
157
158 async fn delete_relation(&self, id: &str) -> Result<(), KgError> {
159 let mut relations = self.relations.write().unwrap();
160 if relations.remove(id).is_none() {
161 return Err(KgError::RelationNotFound(id.to_string()));
162 }
163 Ok(())
164 }
165
166 async fn get_neighbors(&self, entity_id: &str, depth: u32) -> Result<SubGraph, KgError> {
169 if depth > self.max_bfs_depth {
170 return Err(KgError::MaxDepthExceeded {
171 depth,
172 max: self.max_bfs_depth,
173 });
174 }
175
176 let entities = self.entities.read().unwrap();
177 let relations = self.relations.read().unwrap();
178
179 let center = entities
180 .get(entity_id)
181 .cloned()
182 .ok_or_else(|| KgError::EntityNotFound(entity_id.to_string()))?;
183
184 let mut visited_entities: HashMap<String, Entity> = HashMap::new();
185 let mut visited_relations: Vec<Relation> = Vec::new();
186 let mut queue: VecDeque<(String, u32)> = VecDeque::new();
187
188 visited_entities.insert(entity_id.to_string(), center.clone());
189 queue.push_back((entity_id.to_string(), 0));
190
191 while let Some((current_id, current_depth)) = queue.pop_front() {
192 if current_depth >= depth {
193 continue;
194 }
195
196 let neighbors: Vec<Relation> = relations
197 .values()
198 .filter(|r| r.source_id == current_id || r.target_id == current_id)
199 .cloned()
200 .collect();
201
202 for rel in neighbors {
203 let neighbor_id = if rel.source_id == current_id {
204 rel.target_id.clone()
205 } else {
206 rel.source_id.clone()
207 };
208
209 visited_relations.push(rel);
210
211 if !visited_entities.contains_key(&neighbor_id) {
212 if let Some(entity) = entities.get(&neighbor_id).cloned() {
213 visited_entities.insert(neighbor_id.clone(), entity);
214 queue.push_back((neighbor_id, current_depth + 1));
215 }
216 }
217 }
218 }
219
220 let result_entities: Vec<Entity> = visited_entities
221 .into_iter()
222 .filter(|(id, _)| id != entity_id)
223 .map(|(_, e)| e)
224 .collect();
225
226 Ok(SubGraph {
227 center,
228 entities: result_entities,
229 relations: visited_relations,
230 })
231 }
232
233 async fn shortest_path(
234 &self,
235 from: &str,
236 to: &str,
237 ) -> Result<Option<Vec<PathStep>>, KgError> {
238 if from == to {
239 return Ok(Some(vec![]));
240 }
241
242 let entities = self.entities.read().unwrap();
243 let relations = self.relations.read().unwrap();
244
245 let mut adj: HashMap<String, Vec<(String, Relation)>> = HashMap::new();
247 for rel in relations.values() {
248 adj.entry(rel.source_id.clone())
249 .or_default()
250 .push((rel.target_id.clone(), rel.clone()));
251 adj.entry(rel.target_id.clone())
252 .or_default()
253 .push((rel.source_id.clone(), rel.clone()));
254 }
255
256 let mut dist: HashMap<String, f32> = HashMap::new();
257 let mut prev: HashMap<String, (String, Relation)> = HashMap::new();
258
259 for id in entities.keys() {
260 dist.insert(id.clone(), f32::MAX);
261 }
262 dist.insert(from.to_string(), 0.0);
263
264 let mut queue: Vec<(f32, String)> = vec![(0.0, from.to_string())];
265
266 while let Some((_d, u)) = queue.pop() {
267 if let Some(neighbors) = adj.get(&u) {
268 for (v, rel) in neighbors {
269 let weight = self.weight_strategy.relation_cost(rel);
270 let alt = dist.get(&u).copied().unwrap_or(f32::MAX) + weight;
271 if alt < dist.get(v).copied().unwrap_or(f32::MAX) {
272 dist.insert(v.clone(), alt);
273 prev.insert(v.clone(), (u.clone(), rel.clone()));
274 queue.push((-alt, v.clone()));
275 }
276 }
277 }
278 queue.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
279 }
280
281 if !prev.contains_key(to) && from != to {
282 return Ok(None);
283 }
284
285 let mut path = Vec::new();
286 let mut current = to.to_string();
287 while current != from {
288 if let Some((prev_node, rel)) = prev.get(¤t) {
289 let entity = entities.get(¤t).cloned().unwrap();
290 path.push(PathStep { entity, relation: rel.clone() });
291 current = prev_node.clone();
292 } else {
293 break;
294 }
295 }
296 path.reverse();
297 Ok(Some(path))
298 }
299
300 async fn all_paths(
301 &self,
302 from: &str,
303 to: &str,
304 max_depth: u32,
305 ) -> Result<Vec<Vec<PathStep>>, KgError> {
306 if max_depth > self.max_path_search {
307 return Err(KgError::MaxDepthExceeded {
308 depth: max_depth,
309 max: self.max_path_search,
310 });
311 }
312
313 let entities = self.entities.read().unwrap();
314 let relations = self.relations.read().unwrap();
315
316 let mut adj: HashMap<String, Vec<(String, Relation)>> = HashMap::new();
317 for rel in relations.values() {
318 adj.entry(rel.source_id.clone())
319 .or_default()
320 .push((rel.target_id.clone(), rel.clone()));
321 adj.entry(rel.target_id.clone())
322 .or_default()
323 .push((rel.source_id.clone(), rel.clone()));
324 }
325
326 let mut all_paths = Vec::new();
327 let mut visited = HashSet::new();
328 let mut current_path = Vec::new();
329
330 dfs_memory(from, to, max_depth, &entities, &adj, &mut visited, &mut current_path, &mut all_paths);
331
332 all_paths.sort_by(|a, b| {
333 let a_cost: f32 = a.iter().map(|step| self.weight_strategy.relation_cost(&step.relation)).sum();
334 let b_cost: f32 = b.iter().map(|step| self.weight_strategy.relation_cost(&step.relation)).sum();
335 a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
336 });
337
338 Ok(all_paths)
339 }
340
341 async fn batch_import(
344 &self,
345 entities: Vec<Entity>,
346 relations: Vec<Relation>,
347 ) -> Result<ImportResult, KgError> {
348 let mut result = ImportResult::default();
349
350 for entity in entities {
351 match self.upsert_entity(entity).await? {
352 UpsertResult::Created => result.entities_created += 1,
353 UpsertResult::Updated { .. } => result.entities_updated += 1,
354 UpsertResult::Unchanged => result.entities_skipped += 1,
355 }
356 }
357 for rel in relations {
358 match self.upsert_relation(rel).await? {
359 UpsertResult::Created => result.relations_created += 1,
360 UpsertResult::Updated { .. } => result.relations_updated += 1,
361 _ => {}
362 }
363 }
364
365 Ok(result)
366 }
367
368 async fn check_consistency(&self) -> Result<Vec<ConsistencyIssue>, KgError> {
371 let entities = self.entities.read().unwrap();
372 let relations = self.relations.read().unwrap();
373 let mut issues = Vec::new();
374
375 for rel in relations.values() {
377 if !entities.contains_key(&rel.source_id) || !entities.contains_key(&rel.target_id) {
378 issues.push(ConsistencyIssue {
379 severity: IssueSeverity::Error,
380 issue_type: ConsistencyIssueType::OrphanRelation,
381 description: format!("Relation {} references a non-existent entity", rel.id),
382 related_entities: vec![rel.source_id.clone(), rel.target_id.clone()],
383 related_relations: vec![rel.id.clone()],
384 });
385 }
386 }
387
388 for rel in relations.values() {
390 if rel.source_id == rel.target_id {
391 issues.push(ConsistencyIssue {
392 severity: IssueSeverity::Warning,
393 issue_type: ConsistencyIssueType::SelfReferencing,
394 description: format!("Relation {} self-references entity {}", rel.id, rel.source_id),
395 related_entities: vec![rel.source_id.clone()],
396 related_relations: vec![rel.id.clone()],
397 });
398 }
399 }
400
401 for (id, entity) in entities.iter() {
403 let has_relation = relations.values().any(|r| r.source_id == *id || r.target_id == *id);
404 if !has_relation {
405 issues.push(ConsistencyIssue {
406 severity: IssueSeverity::Info,
407 issue_type: ConsistencyIssueType::OrphanEntity,
408 description: format!("Entity {} ({}) has no relations", entity.name, id),
409 related_entities: vec![id.clone()],
410 related_relations: vec![],
411 });
412 }
413 }
414
415 let now = current_epoch_ms();
417 for rel in relations.values() {
418 if let Some(valid_to) = rel.valid_to {
419 if valid_to < now {
420 issues.push(ConsistencyIssue {
421 severity: IssueSeverity::Warning,
422 issue_type: ConsistencyIssueType::ExpiredRelation,
423 description: format!("Relation {} has expired (valid_to < now)", rel.id),
424 related_entities: vec![rel.source_id.clone(), rel.target_id.clone()],
425 related_relations: vec![rel.id.clone()],
426 });
427 }
428 }
429 }
430
431 Ok(issues)
432 }
433
434 async fn stats(&self) -> Result<GraphStats, KgError> {
437 let entities = self.entities.read().unwrap();
438 let relations = self.relations.read().unwrap();
439
440 let mut entity_types: HashMap<String, usize> = HashMap::new();
441 for e in entities.values() {
442 *entity_types.entry(e.entity_type.as_str()).or_default() += 1;
443 }
444
445 let mut relation_types: HashMap<String, usize> = HashMap::new();
446 for r in relations.values() {
447 *relation_types.entry(r.relation_type.clone()).or_default() += 1;
448 }
449
450 let mut degrees: HashMap<String, usize> = HashMap::new();
451 for r in relations.values() {
452 *degrees.entry(r.source_id.clone()).or_default() += 1;
453 *degrees.entry(r.target_id.clone()).or_default() += 1;
454 }
455
456 let degree_values: Vec<usize> = degrees.values().copied().collect();
457 let avg_degree = if degree_values.is_empty() {
458 0.0
459 } else {
460 degree_values.iter().sum::<usize>() as f64 / degree_values.len() as f64
461 };
462 let max_degree = degree_values.iter().max().copied().unwrap_or(0);
463
464 let orphan_entities = entities
465 .keys()
466 .filter(|id| !relations.values().any(|r| &&r.source_id == id || &&r.target_id == id))
467 .count();
468
469 Ok(GraphStats {
470 total_entities: entities.len(),
471 total_relations: relations.len(),
472 entity_types,
473 relation_types,
474 avg_degree,
475 max_degree,
476 orphan_entities,
477 db_size_bytes: 0,
478 })
479 }
480}
481
482fn dfs_memory(
483 current: &str,
484 target: &str,
485 max_depth: u32,
486 entities: &HashMap<String, Entity>,
487 adj: &HashMap<String, Vec<(String, Relation)>>,
488 visited: &mut HashSet<String>,
489 current_path: &mut Vec<PathStep>,
490 all_paths: &mut Vec<Vec<PathStep>>,
491) {
492 if current == target {
493 all_paths.push(current_path.clone());
494 return;
495 }
496 if current_path.len() >= max_depth as usize {
497 return;
498 }
499 visited.insert(current.to_string());
500
501 if let Some(neighbors) = adj.get(current) {
502 for (neighbor, rel) in neighbors {
503 if visited.contains(neighbor.as_str()) {
504 continue;
505 }
506 if let Some(entity) = entities.get(neighbor).cloned() {
507 current_path.push(PathStep {
508 entity,
509 relation: rel.clone(),
510 });
511 dfs_memory(neighbor, target, max_depth, entities, adj, visited, current_path, all_paths);
512 current_path.pop();
513 }
514 }
515 }
516
517 visited.remove(current);
518}
519
/// Milliseconds elapsed since the Unix epoch, according to the system clock.
///
/// Returns 0 if the clock reports a time before the epoch.
fn current_epoch_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}