1use anyhow::{anyhow, Result};
67use serde::{Deserialize, Serialize};
68use std::collections::{HashMap, HashSet, VecDeque};
69
/// A node of the graph: an identifier, an embedding vector, a type label and
/// free-form JSON properties.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Identifier of the entity; `GraphRAG` stores the entity under this key.
    pub id: String,

    /// Embedding vector. `GraphRAG` rejects entities whose vector length
    /// differs from the graph dimension.
    pub embedding: Vec<f32>,

    /// Type label; `GraphQuery::entity_type_filter` is compared against it
    /// for exact equality.
    pub entity_type: String,

    /// Arbitrary JSON values keyed by property name.
    pub properties: HashMap<String, serde_json::Value>,
}
85
86impl Entity {
87 pub fn new(id: impl Into<String>, embedding: Vec<f32>, entity_type: impl Into<String>) -> Self {
89 Self {
90 id: id.into(),
91 embedding,
92 entity_type: entity_type.into(),
93 properties: HashMap::new(),
94 }
95 }
96
97 pub fn with_property(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
99 self.properties.insert(key.into(), value);
100 self
101 }
102}
103
/// A directed, weighted, typed edge between two entities, with free-form
/// JSON properties.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
    /// Id of the source entity.
    pub from: String,

    /// Id of the target entity.
    pub to: String,

    /// Type label; `GraphQuery::relation_type_filter` is matched against it.
    pub relation_type: String,

    /// Edge weight; relations whose weight is below
    /// `GraphQuery::min_relation_weight` are skipped during graph expansion.
    pub weight: f32,

    /// Arbitrary JSON values keyed by property name.
    pub properties: HashMap<String, serde_json::Value>,
}
122
123impl Relation {
124 pub fn new(
126 from: impl Into<String>,
127 to: impl Into<String>,
128 relation_type: impl Into<String>,
129 weight: f32,
130 ) -> Self {
131 Self {
132 from: from.into(),
133 to: to.into(),
134 relation_type: relation_type.into(),
135 weight,
136 properties: HashMap::new(),
137 }
138 }
139
140 pub fn with_property(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
142 self.properties.insert(key.into(), value);
143 self
144 }
145}
146
/// Parameters of a similarity search followed by graph expansion.
///
/// Derives `Debug` like the other public data types in this file so queries
/// can be logged and used in assertion messages.
#[derive(Debug, Clone)]
pub struct GraphQuery {
    /// Query embedding; its length must equal the graph dimension.
    pub embedding: Vec<f32>,

    /// Maximum number of best-scoring entities used as seeds.
    pub limit: usize,

    /// Maximum number of relation hops followed from each seed.
    pub max_hops: usize,

    /// When set, only entities with exactly this type are considered as seeds.
    pub entity_type_filter: Option<String>,

    /// When set, only relations whose type is in this list are followed.
    pub relation_type_filter: Option<Vec<String>>,

    /// Relations with a weight below this value are not followed.
    pub min_relation_weight: f32,
}
168
169impl GraphQuery {
170 pub fn new(embedding: Vec<f32>) -> Self {
172 Self {
173 embedding,
174 limit: 10,
175 max_hops: 1,
176 entity_type_filter: None,
177 relation_type_filter: None,
178 min_relation_weight: 0.0,
179 }
180 }
181
182 pub fn with_limit(mut self, limit: usize) -> Self {
184 self.limit = limit;
185 self
186 }
187
188 pub fn with_max_hops(mut self, max_hops: usize) -> Self {
190 self.max_hops = max_hops;
191 self
192 }
193
194 pub fn with_entity_type(mut self, entity_type: impl Into<String>) -> Self {
196 self.entity_type_filter = Some(entity_type.into());
197 self
198 }
199
200 pub fn with_relation_types(mut self, types: Vec<String>) -> Self {
202 self.relation_type_filter = Some(types);
203 self
204 }
205
206 pub fn with_min_relation_weight(mut self, weight: f32) -> Self {
208 self.min_relation_weight = weight;
209 self
210 }
211}
212
/// One hit returned by `GraphRAG::search`: a seed entity together with the
/// entities reached from it by graph expansion.
#[derive(Debug, Clone)]
pub struct GraphResult {
    /// The matched seed entity (a clone of the stored entity).
    pub entity: Entity,

    /// Similarity score, computed by `search` as `1 / (1 + euclidean distance)`.
    pub score: f32,

    /// Hop count of this result; `search` currently always sets it to 0.
    pub hops: usize,

    /// Path of entity ids; `search` currently fills it with the seed id only.
    pub path: Vec<String>,

    /// Entities reached from the seed while following outgoing relations.
    pub neighbors: Vec<Entity>,
}
231
232pub struct GraphRAG {
234 dimension: usize,
236
237 entities: HashMap<String, Entity>,
239
240 outgoing: HashMap<String, Vec<Relation>>,
242
243 incoming: HashMap<String, Vec<Relation>>,
245}
246
247impl GraphRAG {
248 pub fn new(dimension: usize) -> Result<Self> {
250 Ok(Self {
251 dimension,
252 entities: HashMap::new(),
253 outgoing: HashMap::new(),
254 incoming: HashMap::new(),
255 })
256 }
257
258 pub fn add_entity(
260 &mut self,
261 id: impl Into<String>,
262 embedding: Vec<f32>,
263 entity_type: impl Into<String>,
264 ) -> Result<()> {
265 let id = id.into();
266
267 if embedding.len() != self.dimension {
268 return Err(anyhow!(
269 "Embedding dimension {} doesn't match graph dimension {}",
270 embedding.len(),
271 self.dimension
272 ));
273 }
274
275 let entity = Entity::new(id.clone(), embedding, entity_type);
276 self.entities.insert(id.clone(), entity);
277
278 self.outgoing.entry(id.clone()).or_insert_with(Vec::new);
280 self.incoming.entry(id).or_insert_with(Vec::new);
281
282 Ok(())
283 }
284
285 pub fn add_entity_with_properties(&mut self, entity: Entity) -> Result<()> {
287 if entity.embedding.len() != self.dimension {
288 return Err(anyhow!(
289 "Embedding dimension {} doesn't match graph dimension {}",
290 entity.embedding.len(),
291 self.dimension
292 ));
293 }
294
295 let id = entity.id.clone();
296 self.entities.insert(id.clone(), entity);
297
298 self.outgoing.entry(id.clone()).or_insert_with(Vec::new);
300 self.incoming.entry(id).or_insert_with(Vec::new);
301
302 Ok(())
303 }
304
305 pub fn add_relation(
307 &mut self,
308 from: impl Into<String>,
309 to: impl Into<String>,
310 relation_type: impl Into<String>,
311 weight: f32,
312 ) -> Result<()> {
313 let from = from.into();
314 let to = to.into();
315
316 if !self.entities.contains_key(&from) {
318 return Err(anyhow!("Source entity '{}' not found", from));
319 }
320 if !self.entities.contains_key(&to) {
321 return Err(anyhow!("Target entity '{}' not found", to));
322 }
323
324 let relation = Relation::new(from.clone(), to.clone(), relation_type, weight);
325
326 self.outgoing
328 .entry(from.clone())
329 .or_insert_with(Vec::new)
330 .push(relation.clone());
331
332 self.incoming
334 .entry(to)
335 .or_insert_with(Vec::new)
336 .push(relation);
337
338 Ok(())
339 }
340
341 pub fn add_relation_with_properties(&mut self, relation: Relation) -> Result<()> {
343 if !self.entities.contains_key(&relation.from) {
345 return Err(anyhow!("Source entity '{}' not found", relation.from));
346 }
347 if !self.entities.contains_key(&relation.to) {
348 return Err(anyhow!("Target entity '{}' not found", relation.to));
349 }
350
351 let from = relation.from.clone();
352 let to = relation.to.clone();
353
354 self.outgoing
356 .entry(from)
357 .or_insert_with(Vec::new)
358 .push(relation.clone());
359
360 self.incoming
362 .entry(to)
363 .or_insert_with(Vec::new)
364 .push(relation);
365
366 Ok(())
367 }
368
369 pub fn search(&self, query: &GraphQuery) -> Result<Vec<GraphResult>> {
371 if query.embedding.len() != self.dimension {
372 return Err(anyhow!(
373 "Query embedding dimension {} doesn't match graph dimension {}",
374 query.embedding.len(),
375 self.dimension
376 ));
377 }
378
379 let mut candidates: Vec<(String, f32)> = self
381 .entities
382 .iter()
383 .filter(|(_, entity)| {
384 if let Some(ref filter) = query.entity_type_filter {
385 &entity.entity_type == filter
386 } else {
387 true
388 }
389 })
390 .map(|(id, entity)| {
391 let distance = euclidean_distance(&query.embedding, &entity.embedding);
392 let score = 1.0 / (1.0 + distance);
393 (id.clone(), score)
394 })
395 .collect();
396
397 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
399 candidates.truncate(query.limit);
400
401 let mut results = Vec::new();
403 let mut visited = HashSet::new();
404
405 for (entity_id, score) in candidates {
406 if visited.contains(&entity_id) {
407 continue;
408 }
409
410 let subgraph = self.expand_subgraph(&entity_id, query, &mut visited)?;
412
413 let entity = self.entities.get(&entity_id).unwrap().clone();
415 let neighbors: Vec<Entity> = subgraph
416 .iter()
417 .filter_map(|id| self.entities.get(id).cloned())
418 .collect();
419
420 results.push(GraphResult {
421 entity,
422 score,
423 hops: 0,
424 path: vec![entity_id.clone()],
425 neighbors,
426 });
427
428 visited.insert(entity_id);
429 }
430
431 Ok(results)
432 }
433
434 fn expand_subgraph(
436 &self,
437 start_id: &str,
438 query: &GraphQuery,
439 visited: &mut HashSet<String>,
440 ) -> Result<Vec<String>> {
441 let mut queue = VecDeque::new();
442 let mut subgraph = Vec::new();
443
444 queue.push_back((start_id.to_string(), 0));
445
446 while let Some((entity_id, hops)) = queue.pop_front() {
447 if hops >= query.max_hops {
448 continue;
449 }
450
451 if let Some(relations) = self.outgoing.get(&entity_id) {
452 for relation in relations {
453 if relation.weight < query.min_relation_weight {
455 continue;
456 }
457
458 if let Some(ref filter) = query.relation_type_filter {
459 if !filter.contains(&relation.relation_type) {
460 continue;
461 }
462 }
463
464 if !visited.contains(&relation.to) {
465 subgraph.push(relation.to.clone());
466 queue.push_back((relation.to.clone(), hops + 1));
467 visited.insert(relation.to.clone());
468 }
469 }
470 }
471 }
472
473 Ok(subgraph)
474 }
475
476 pub fn get_entity(&self, id: &str) -> Option<&Entity> {
478 self.entities.get(id)
479 }
480
481 pub fn get_outgoing(&self, id: &str) -> Vec<&Relation> {
483 self.outgoing
484 .get(id)
485 .map(|rels| rels.iter().collect())
486 .unwrap_or_default()
487 }
488
489 pub fn get_incoming(&self, id: &str) -> Vec<&Relation> {
491 self.incoming
492 .get(id)
493 .map(|rels| rels.iter().collect())
494 .unwrap_or_default()
495 }
496
497 pub fn get_neighbors(&self, id: &str) -> Vec<String> {
499 let mut neighbors = HashSet::new();
500
501 if let Some(relations) = self.outgoing.get(id) {
503 for rel in relations {
504 neighbors.insert(rel.to.clone());
505 }
506 }
507
508 if let Some(relations) = self.incoming.get(id) {
510 for rel in relations {
511 neighbors.insert(rel.from.clone());
512 }
513 }
514
515 neighbors.into_iter().collect()
516 }
517
518 pub fn stats(&self) -> GraphStats {
520 let total_relations: usize = self.outgoing.values().map(|v| v.len()).sum();
521
522 let mut entity_types: HashMap<String, usize> = HashMap::new();
523 for entity in self.entities.values() {
524 *entity_types.entry(entity.entity_type.clone()).or_insert(0) += 1;
525 }
526
527 let mut relation_types: HashMap<String, usize> = HashMap::new();
528 for relations in self.outgoing.values() {
529 for rel in relations {
530 *relation_types.entry(rel.relation_type.clone()).or_insert(0) += 1;
531 }
532 }
533
534 GraphStats {
535 num_entities: self.entities.len(),
536 num_relations: total_relations,
537 entity_types,
538 relation_types,
539 }
540 }
541
542 pub fn remove_entity(&mut self, id: &str) -> Result<bool> {
544 if !self.entities.contains_key(id) {
545 return Ok(false);
546 }
547
548 self.entities.remove(id);
550
551 self.outgoing.remove(id);
553
554 for relations in self.incoming.values_mut() {
556 relations.retain(|r| r.from != id);
557 }
558
559 self.incoming.remove(id);
561
562 for relations in self.outgoing.values_mut() {
564 relations.retain(|r| r.to != id);
565 }
566
567 Ok(true)
568 }
569
570 pub fn len(&self) -> usize {
572 self.entities.len()
573 }
574
575 pub fn is_empty(&self) -> bool {
577 self.entities.is_empty()
578 }
579
580 pub fn dimension(&self) -> usize {
582 self.dimension
583 }
584}
585
/// Summary counts produced by `GraphRAG::stats`.
#[derive(Debug, Clone)]
pub struct GraphStats {
    /// Number of entities in the graph.
    pub num_entities: usize,
    /// Number of relations, counted over the outgoing adjacency lists.
    pub num_relations: usize,
    /// Number of entities per entity type label.
    pub entity_types: HashMap<String, usize>,
    /// Number of relations per relation type label.
    pub relation_types: HashMap<String, usize>,
}
594
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; when the slices differ in length the
/// extra components of the longer one are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| (lhs - rhs).powi(2))
        .sum();

    squared_sum.sqrt()
}
603
/// Unit tests for the graph store: construction, search, traversal, filters,
/// removal and statistics.
#[cfg(test)]
mod tests {
    use super::*;

    /// Entities and relations are stored and visible from both endpoints.
    #[test]
    fn test_graph_basic() {
        let mut graph = GraphRAG::new(64).unwrap();

        graph.add_entity("rust", vec![0.1; 64], "language").unwrap();
        graph
            .add_entity("python", vec![0.2; 64], "language")
            .unwrap();
        graph.add_entity("wasm", vec![0.3; 64], "platform").unwrap();

        assert_eq!(graph.len(), 3);

        graph
            .add_relation("rust", "wasm", "compiles_to", 1.0)
            .unwrap();
        graph
            .add_relation("python", "wasm", "compiles_to", 0.8)
            .unwrap();

        let rust_out = graph.get_outgoing("rust");
        assert_eq!(rust_out.len(), 1);
        assert_eq!(rust_out[0].to, "wasm");

        let wasm_in = graph.get_incoming("wasm");
        assert_eq!(wasm_in.len(), 2);
    }

    /// The closest entity ranks first and carries its one-hop neighbor.
    #[test]
    fn test_graph_search() {
        let mut graph = GraphRAG::new(32).unwrap();

        graph.add_entity("doc1", vec![0.1; 32], "document").unwrap();
        graph.add_entity("doc2", vec![0.5; 32], "document").unwrap();
        graph.add_entity("topic1", vec![0.3; 32], "topic").unwrap();

        graph.add_relation("doc1", "topic1", "about", 1.0).unwrap();
        graph.add_relation("doc2", "topic1", "about", 0.9).unwrap();

        let query = GraphQuery::new(vec![0.1; 32])
            .with_limit(5)
            .with_max_hops(1);

        let results = graph.search(&query).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].entity.id, "doc1");

        let reached: Vec<&str> = results[0]
            .neighbors
            .iter()
            .map(|e| e.id.as_str())
            .collect();
        assert_eq!(reached, ["topic1"]);
    }

    /// With two hops allowed, both B and C are reached from A, in BFS order.
    #[test]
    fn test_graph_traversal() {
        let mut graph = GraphRAG::new(32).unwrap();

        graph.add_entity("A", vec![0.1; 32], "node").unwrap();
        graph.add_entity("B", vec![0.2; 32], "node").unwrap();
        graph.add_entity("C", vec![0.3; 32], "node").unwrap();

        graph.add_relation("A", "B", "connects", 1.0).unwrap();
        graph.add_relation("B", "C", "connects", 1.0).unwrap();

        let query = GraphQuery::new(vec![0.1; 32])
            .with_limit(1)
            .with_max_hops(2);

        let results = graph.search(&query).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity.id, "A");

        let reached: Vec<&str> = results[0]
            .neighbors
            .iter()
            .map(|e| e.id.as_str())
            .collect();
        assert_eq!(reached, ["B", "C"]);
    }

    /// Only entities of the requested type are returned as results.
    #[test]
    fn test_entity_type_filter() {
        let mut graph = GraphRAG::new(32).unwrap();

        graph.add_entity("rust", vec![0.1; 32], "language").unwrap();
        graph.add_entity("wasm", vec![0.2; 32], "platform").unwrap();
        graph
            .add_entity("python", vec![0.3; 32], "language")
            .unwrap();

        let query = GraphQuery::new(vec![0.15; 32]).with_entity_type("language");

        let results = graph.search(&query).unwrap();

        // Guard against the loop below passing vacuously.
        assert!(!results.is_empty());
        for result in &results {
            assert_eq!(result.entity.entity_type, "language");
        }
    }

    /// Neighbors include both relation directions.
    #[test]
    fn test_neighbors() {
        let mut graph = GraphRAG::new(32).unwrap();

        graph.add_entity("A", vec![0.1; 32], "node").unwrap();
        graph.add_entity("B", vec![0.2; 32], "node").unwrap();
        graph.add_entity("C", vec![0.3; 32], "node").unwrap();

        graph.add_relation("A", "B", "connects", 1.0).unwrap();
        graph.add_relation("C", "A", "connects", 1.0).unwrap();

        let neighbors = graph.get_neighbors("A");

        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&"B".to_string()));
        assert!(neighbors.contains(&"C".to_string()));
    }

    /// Removing an entity drops it and every relation touching it.
    #[test]
    fn test_remove_entity() {
        let mut graph = GraphRAG::new(32).unwrap();

        graph.add_entity("A", vec![0.1; 32], "node").unwrap();
        graph.add_entity("B", vec![0.2; 32], "node").unwrap();

        graph.add_relation("A", "B", "connects", 1.0).unwrap();

        assert_eq!(graph.len(), 2);

        let removed = graph.remove_entity("A").unwrap();
        assert!(removed);
        assert_eq!(graph.len(), 1);
        assert!(graph.get_entity("A").is_none());
        assert!(graph.get_outgoing("A").is_empty());

        let b_in = graph.get_incoming("B");
        assert_eq!(b_in.len(), 0);

        // A second removal reports that nothing was there.
        assert!(!graph.remove_entity("A").unwrap());
        assert_eq!(graph.len(), 1);
    }

    /// Statistics count entities and relations overall and per type.
    #[test]
    fn test_stats() {
        let mut graph = GraphRAG::new(32).unwrap();

        graph.add_entity("rust", vec![0.1; 32], "language").unwrap();
        graph
            .add_entity("python", vec![0.2; 32], "language")
            .unwrap();
        graph.add_entity("wasm", vec![0.3; 32], "platform").unwrap();

        graph
            .add_relation("rust", "wasm", "compiles_to", 1.0)
            .unwrap();
        graph
            .add_relation("python", "wasm", "compiles_to", 0.8)
            .unwrap();

        let stats = graph.stats();

        assert_eq!(stats.num_entities, 3);
        assert_eq!(stats.num_relations, 2);
        assert_eq!(stats.entity_types.get("language"), Some(&2));
        assert_eq!(stats.entity_types.get("platform"), Some(&1));
        assert_eq!(stats.relation_types.get("compiles_to"), Some(&2));
    }

    /// Properties set through the builder survive insertion and lookup.
    #[test]
    fn test_entity_with_properties() {
        let mut graph = GraphRAG::new(32).unwrap();

        let entity = Entity::new("rust", vec![0.1; 32], "language")
            .with_property("year", serde_json::json!(2010))
            .with_property("paradigm", serde_json::json!("systems"));

        graph.add_entity_with_properties(entity).unwrap();

        let retrieved = graph.get_entity("rust").unwrap();
        assert_eq!(
            retrieved.properties.get("year"),
            Some(&serde_json::json!(2010))
        );
        assert_eq!(
            retrieved.properties.get("paradigm"),
            Some(&serde_json::json!("systems"))
        );
    }

    /// Relations below the minimum weight are not followed.
    #[test]
    fn test_relation_weight_filter() {
        let mut graph = GraphRAG::new(32).unwrap();

        graph.add_entity("A", vec![0.1; 32], "node").unwrap();
        graph.add_entity("B", vec![0.2; 32], "node").unwrap();
        graph.add_entity("C", vec![0.3; 32], "node").unwrap();

        graph.add_relation("A", "B", "strong", 0.9).unwrap();
        graph.add_relation("A", "C", "weak", 0.1).unwrap();

        let query = GraphQuery::new(vec![0.1; 32])
            .with_max_hops(1)
            .with_min_relation_weight(0.5);

        let results = graph.search(&query).unwrap();

        assert_eq!(results[0].entity.id, "A");
        assert!(results[0].neighbors.iter().any(|e| e.id == "B"));
        // The weak edge must actually be filtered out.
        assert!(!results[0].neighbors.iter().any(|e| e.id == "C"));
    }
}