1pub mod algorithms;
21pub mod embed;
22pub mod error;
23pub mod export;
24pub mod extension;
25pub mod functions;
26pub mod graph;
27pub mod migrate;
28pub mod rag;
29pub mod schema;
30pub mod vector;
31pub mod version;
32
33#[cfg(feature = "async")]
34pub mod async_kg;
35#[cfg(feature = "async")]
36pub use async_kg::embed::AsyncEmbeddingGenerator;
37#[cfg(feature = "async")]
38pub use async_kg::AsyncKnowledgeGraph;
39
40pub(crate) fn row_get_weight(row: &rusqlite::Row, col: usize) -> rusqlite::Result<f64> {
46 use rusqlite::types::ValueRef;
47 match row.get_ref(col)? {
48 ValueRef::Real(f) => Ok(f),
49 ValueRef::Integer(i) => Ok(i as f64),
50 ValueRef::Null => Ok(1.0), ValueRef::Blob(b) if b.len() == 8 => {
52 let mut bytes = [0u8; 8];
53 bytes.copy_from_slice(b);
54 Ok(f64::from_le_bytes(bytes)) }
56 ValueRef::Blob(b) if b.len() == 4 => {
57 let mut bytes = [0u8; 4];
58 bytes.copy_from_slice(b);
59 Ok(f32::from_le_bytes(bytes) as f64) }
61 ValueRef::Blob(_) => Ok(1.0), _ => Err(rusqlite::Error::InvalidColumnType(
63 col,
64 "weight".into(),
65 rusqlite::types::Type::Real,
66 )),
67 }
68}
69
70pub use algorithms::{
71 analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
72 PageRankConfig,
73};
74pub use embed::{
75 check_dependencies, get_entities_needing_embedding, EmbeddingConfig, EmbeddingGenerator,
76 EmbeddingStats,
77};
78pub use error::{Error, Result};
79pub use export::{D3ExportGraph, D3ExportMetadata, D3Link, D3Node, DotConfig};
80pub use extension::sqlite3_sqlite_knowledge_graph_init;
81pub use functions::register_functions;
82pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
83pub use graph::{Entity, Neighbor, Relation};
84pub use graph::{HigherOrderNeighbor, HigherOrderPath, HigherOrderPathStep, Hyperedge};
85pub use migrate::{
86 build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
87};
88pub use rag::{embedder::Embedder, embedder::FixedEmbedder, RagConfig, RagEngine, RagResult};
89pub use rag::{RetrievalWeights, SmartRetrieval, SmartSearchResult};
90pub use schema::{create_schema, schema_exists};
91pub use vector::{cosine_similarity, SearchResult, VectorStore};
92pub use vector::{ConfidenceEngine, ConfidenceParams};
93pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
94pub use version::{MergeStrategy, Version, VersionDiff};
95
96use rusqlite::Connection;
97use serde::{Deserialize, Serialize};
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct SearchResultWithEntity {
102 pub entity: Entity,
103 pub similarity: f32,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct GraphContext {
109 pub root_entity: Entity,
110 pub neighbors: Vec<Neighbor>,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct HybridSearchResult {
116 pub entity: Entity,
117 pub similarity: f32,
118 pub context: Option<GraphContext>,
119}
120
121pub struct KnowledgeGraph {
123 conn: Connection,
124 retrieval_weights: std::cell::Cell<RetrievalWeights>,
125}
126
127impl std::fmt::Debug for KnowledgeGraph {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 f.debug_struct("KnowledgeGraph").finish_non_exhaustive()
130 }
131}
132
133impl KnowledgeGraph {
134 pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
136 let conn = Connection::open(path)?;
137
138 conn.execute("PRAGMA foreign_keys = ON", [])?;
140
141 if !schema_exists(&conn)? {
143 create_schema(&conn)?;
144 }
145
146 register_functions(&conn)?;
148
149 Ok(Self {
150 conn,
151 retrieval_weights: std::cell::Cell::new(RetrievalWeights::default()),
152 })
153 }
154
155 pub fn open_in_memory() -> Result<Self> {
157 let conn = Connection::open_in_memory()?;
158
159 conn.execute("PRAGMA foreign_keys = ON", [])?;
161
162 create_schema(&conn)?;
164
165 register_functions(&conn)?;
167
168 Ok(Self {
169 conn,
170 retrieval_weights: std::cell::Cell::new(RetrievalWeights::default()),
171 })
172 }
173
174 pub fn connection(&self) -> &Connection {
176 &self.conn
177 }
178
179 pub fn retrieval_weights(&self) -> RetrievalWeights {
181 self.retrieval_weights.get()
182 }
183
184 pub fn set_retrieval_weights(&self, weights: RetrievalWeights) {
186 self.retrieval_weights.set(weights);
187 }
188
189 pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
191 Ok(self.conn.unchecked_transaction()?)
192 }
193
194 pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
196 graph::insert_entity(&self.conn, entity)
197 }
198
199 pub fn get_entity(&self, id: i64) -> Result<Entity> {
201 graph::get_entity(&self.conn, id)
202 }
203
204 pub fn list_entities(
206 &self,
207 entity_type: Option<&str>,
208 limit: Option<i64>,
209 ) -> Result<Vec<Entity>> {
210 graph::list_entities(&self.conn, entity_type, limit)
211 }
212
213 pub fn update_entity(&self, entity: &Entity) -> Result<()> {
215 graph::update_entity(&self.conn, entity)
216 }
217
218 pub fn delete_entity(&self, id: i64) -> Result<()> {
220 graph::delete_entity(&self.conn, id)
221 }
222
223 pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
225 graph::insert_relation(&self.conn, relation)
226 }
227
228 pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
230 graph::get_neighbors(&self.conn, entity_id, depth)
231 }
232
233 pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
235 let store = VectorStore::new();
236 store.insert_vector(&self.conn, entity_id, vector)
237 }
238
239 pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
241 let store = VectorStore::new();
242 store.search_vectors(&self.conn, query, k)
243 }
244
245 pub fn smart_search(&self, query: Vec<f32>, k: usize) -> Result<Vec<SmartSearchResult>> {
249 let retrieval = SmartRetrieval::new(self.retrieval_weights.get());
250 retrieval.retrieve(&self.conn, &query, k)
251 }
252
253 pub fn create_turboquant_index(
283 &self,
284 config: Option<TurboQuantConfig>,
285 ) -> Result<TurboQuantIndex> {
286 let config = config.unwrap_or_default();
287
288 TurboQuantIndex::new(config)
289 }
290
291 pub fn build_turboquant_index(
294 &self,
295 config: Option<TurboQuantConfig>,
296 ) -> Result<TurboQuantIndex> {
297 let dimension = self.get_vector_dimension()?.unwrap_or(384);
299
300 let config = config.unwrap_or(TurboQuantConfig {
301 dimension,
302 bit_width: 3,
303 seed: 42,
304 });
305
306 let mut index = TurboQuantIndex::new(config)?;
307
308 let vectors = self.load_all_vectors()?;
310
311 for (entity_id, vector) in vectors {
312 index.add_vector(entity_id, &vector)?;
313 }
314
315 Ok(index)
316 }
317
318 fn get_vector_dimension(&self) -> Result<Option<usize>> {
320 let result = self
321 .conn
322 .query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
323 row.get::<_, i64>(0)
324 });
325
326 match result {
327 Ok(dim) => Ok(Some(dim as usize)),
328 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
329 Err(e) => Err(e.into()),
330 }
331 }
332
333 fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
335 let mut stmt = self
336 .conn
337 .prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
338
339 let rows = stmt.query_map([], |row| {
340 let entity_id: i64 = row.get(0)?;
341 let vector_blob: Vec<u8> = row.get(1)?;
342 let dimension: i64 = row.get(2)?;
343
344 let mut vector = Vec::with_capacity(dimension as usize);
345 for chunk in vector_blob.chunks_exact(4) {
346 let bytes: [u8; 4] = chunk.try_into().unwrap();
347 vector.push(f32::from_le_bytes(bytes));
348 }
349
350 Ok((entity_id, vector))
351 })?;
352
353 let mut vectors = Vec::new();
354 for row in rows {
355 vectors.push(row?);
356 }
357
358 Ok(vectors)
359 }
360
361 pub fn insert_hyperedge(&self, hyperedge: &Hyperedge) -> Result<i64> {
365 graph::insert_hyperedge(&self.conn, hyperedge)
366 }
367
368 pub fn get_hyperedge(&self, id: i64) -> Result<Hyperedge> {
370 graph::get_hyperedge(&self.conn, id)
371 }
372
373 pub fn list_hyperedges(
375 &self,
376 hyperedge_type: Option<&str>,
377 min_arity: Option<usize>,
378 max_arity: Option<usize>,
379 limit: Option<i64>,
380 ) -> Result<Vec<Hyperedge>> {
381 graph::list_hyperedges(&self.conn, hyperedge_type, min_arity, max_arity, limit)
382 }
383
384 pub fn update_hyperedge(&self, hyperedge: &Hyperedge) -> Result<()> {
386 graph::update_hyperedge(&self.conn, hyperedge)
387 }
388
389 pub fn delete_hyperedge(&self, id: i64) -> Result<()> {
391 graph::delete_hyperedge(&self.conn, id)
392 }
393
394 pub fn get_higher_order_neighbors(
396 &self,
397 entity_id: i64,
398 min_arity: Option<usize>,
399 max_arity: Option<usize>,
400 ) -> Result<Vec<HigherOrderNeighbor>> {
401 graph::get_higher_order_neighbors(&self.conn, entity_id, min_arity, max_arity)
402 }
403
404 pub fn get_entity_hyperedges(&self, entity_id: i64) -> Result<Vec<Hyperedge>> {
406 graph::get_entity_hyperedges(&self.conn, entity_id)
407 }
408
409 pub fn kg_higher_order_bfs(
411 &self,
412 start_id: i64,
413 max_depth: u32,
414 min_arity: Option<usize>,
415 ) -> Result<Vec<TraversalNode>> {
416 graph::higher_order_bfs(&self.conn, start_id, max_depth, min_arity)
417 }
418
419 pub fn kg_higher_order_shortest_path(
421 &self,
422 from_id: i64,
423 to_id: i64,
424 max_depth: u32,
425 ) -> Result<Option<HigherOrderPath>> {
426 graph::higher_order_shortest_path(&self.conn, from_id, to_id, max_depth)
427 }
428
429 pub fn kg_hyperedge_degree(&self, entity_id: i64) -> Result<f64> {
431 graph::hyperedge_degree(&self.conn, entity_id)
432 }
433
434 pub fn kg_hypergraph_entity_pagerank(
436 &self,
437 damping: Option<f64>,
438 max_iter: Option<usize>,
439 tolerance: Option<f64>,
440 ) -> Result<std::collections::HashMap<i64, f64>> {
441 graph::hypergraph_entity_pagerank(
442 &self.conn,
443 damping.unwrap_or(0.85),
444 max_iter.unwrap_or(100),
445 tolerance.unwrap_or(1e-6),
446 )
447 }
448
449 pub fn kg_semantic_search(
454 &self,
455 query_embedding: Vec<f32>,
456 k: usize,
457 ) -> Result<Vec<SearchResultWithEntity>> {
458 let results = self.search_vectors(query_embedding, k)?;
459
460 let mut entities_with_results = Vec::new();
461 for result in results {
462 let entity = self.get_entity(result.entity_id)?;
463 entities_with_results.push(SearchResultWithEntity {
464 entity,
465 similarity: result.similarity,
466 });
467 }
468
469 Ok(entities_with_results)
470 }
471
472 pub fn version_create(
476 &self,
477 name: &str,
478 branch: &str,
479 parent_id: Option<i64>,
480 description: Option<&str>,
481 ) -> Result<i64> {
482 version::store::create_version(&self.conn, name, branch, parent_id, description)
483 }
484
485 pub fn version_delete(&self, version_id: i64) -> Result<()> {
487 version::store::delete_version(&self.conn, version_id)
488 }
489
490 pub fn version_list(&self, branch: Option<&str>) -> Result<Vec<Version>> {
492 version::store::list_versions(&self.conn, branch)
493 }
494
495 pub fn version_add_entity(&self, version_id: i64, entity_id: i64) -> Result<()> {
497 version::snapshot::version_add_entity(&self.conn, version_id, entity_id)
498 }
499
500 pub fn version_remove_entity(&self, version_id: i64, entity_id: i64) -> Result<()> {
502 version::snapshot::version_remove_entity(&self.conn, version_id, entity_id)
503 }
504
505 pub fn version_add_relation(&self, version_id: i64, relation_id: i64) -> Result<()> {
507 version::snapshot::version_add_relation(&self.conn, version_id, relation_id)
508 }
509
510 pub fn version_remove_relation(&self, version_id: i64, relation_id: i64) -> Result<()> {
512 version::snapshot::version_remove_relation(&self.conn, version_id, relation_id)
513 }
514
515 pub fn version_snapshot_entities(&self, version_id: i64) -> Result<()> {
517 version::snapshot::version_snapshot_entities(&self.conn, version_id)
518 }
519
520 pub fn version_snapshot_relations(&self, version_id: i64) -> Result<()> {
522 version::snapshot::version_snapshot_relations(&self.conn, version_id)
523 }
524
525 pub fn version_entities(
527 &self,
528 version_id: i64,
529 entity_type: Option<&str>,
530 limit: Option<i64>,
531 ) -> Result<Vec<Entity>> {
532 version::query::version_entities(&self.conn, version_id, entity_type, limit)
533 }
534
535 pub fn version_relations(
537 &self,
538 version_id: i64,
539 rel_type: Option<&str>,
540 ) -> Result<Vec<Relation>> {
541 version::query::version_relations(&self.conn, version_id, rel_type, None, None, None)
542 }
543
544 pub fn version_neighbors(
546 &self,
547 entity_id: i64,
548 version_id: i64,
549 depth: u32,
550 ) -> Result<Vec<Neighbor>> {
551 version::query::version_neighbors(&self.conn, entity_id, version_id, depth)
552 }
553
554 pub fn version_compare(&self, v1_id: i64, v2_id: i64) -> Result<VersionDiff> {
556 version::diff::version_compare(&self.conn, v1_id, v2_id)
557 }
558
559 pub fn version_entity_history(&self, entity_id: i64) -> Result<Vec<Version>> {
561 version::diff::version_entity_history(&self.conn, entity_id)
562 }
563
564 pub fn version_merge(
566 &self,
567 source_ids: &[i64],
568 target_name: &str,
569 strategy: MergeStrategy,
570 ) -> Result<i64> {
571 version::merge::version_merge(&self.conn, source_ids, target_name, strategy)
572 }
573
574 pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
577 let root_entity = self.get_entity(entity_id)?;
578 let neighbors = self.get_neighbors(entity_id, depth)?;
579
580 Ok(GraphContext {
581 root_entity,
582 neighbors,
583 })
584 }
585
586 pub fn kg_hybrid_search(
589 &self,
590 _query_text: &str,
591 query_embedding: Vec<f32>,
592 k: usize,
593 ) -> Result<Vec<HybridSearchResult>> {
594 let semantic_results = self.kg_semantic_search(query_embedding, k)?;
595
596 let mut hybrid_results = Vec::new();
597 for result in semantic_results.iter() {
598 let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
599 let context = self.kg_get_context(entity_id, 1)?; hybrid_results.push(HybridSearchResult {
602 entity: result.entity.clone(),
603 similarity: result.similarity,
604 context: Some(context),
605 });
606 }
607
608 Ok(hybrid_results)
609 }
610
611 pub fn kg_similar_entities(
619 &self,
620 entity_id: i64,
621 k: usize,
622 ) -> Result<Vec<SearchResultWithEntity>> {
623 let store = VectorStore::new();
624 let query_vec = store.get_vector(&self.conn, entity_id)?;
625 let results = store.search_vectors(&self.conn, query_vec, k + 1)?;
627
628 let mut out = Vec::new();
629 for r in results {
630 if r.entity_id == entity_id {
631 continue;
632 }
633 let entity = self.get_entity(r.entity_id)?;
634 out.push(SearchResultWithEntity {
635 entity,
636 similarity: r.similarity,
637 });
638 }
639 out.truncate(k);
640 Ok(out)
641 }
642
643 pub fn kg_find_related(
648 &self,
649 entity_id: i64,
650 threshold: f64,
651 ) -> Result<Vec<(graph::Entity, f64)>> {
652 let neighbours = self.get_neighbors(entity_id, 1)?;
653 let mut results: Vec<(graph::Entity, f64)> = neighbours
654 .into_iter()
655 .filter(|n| n.relation.weight >= threshold)
656 .map(|n| (n.entity, n.relation.weight))
657 .collect();
658 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
659 Ok(results)
660 }
661
662 pub fn kg_bfs_traversal(
667 &self,
668 start_id: i64,
669 direction: Direction,
670 max_depth: u32,
671 ) -> Result<Vec<TraversalNode>> {
672 let query = TraversalQuery {
673 direction,
674 max_depth,
675 ..Default::default()
676 };
677 graph::bfs_traversal(&self.conn, start_id, query)
678 }
679
680 pub fn kg_dfs_traversal(
683 &self,
684 start_id: i64,
685 direction: Direction,
686 max_depth: u32,
687 ) -> Result<Vec<TraversalNode>> {
688 let query = TraversalQuery {
689 direction,
690 max_depth,
691 ..Default::default()
692 };
693 graph::dfs_traversal(&self.conn, start_id, query)
694 }
695
696 pub fn kg_shortest_path(
699 &self,
700 from_id: i64,
701 to_id: i64,
702 max_depth: u32,
703 ) -> Result<Option<TraversalPath>> {
704 graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
705 }
706
707 pub fn kg_graph_stats(&self) -> Result<GraphStats> {
709 graph::compute_graph_stats(&self.conn)
710 }
711
712 pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
717 algorithms::pagerank(&self.conn, config.unwrap_or_default())
718 }
719
720 pub fn kg_louvain(&self) -> Result<CommunityResult> {
723 algorithms::louvain_communities(&self.conn)
724 }
725
726 pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
729 algorithms::connected_components(&self.conn)
730 }
731
732 pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
734 algorithms::analyze_graph(&self.conn)
735 }
736
737 pub fn export_json(&self) -> Result<D3ExportGraph> {
751 export::export_d3_json(&self.conn)
752 }
753
754 pub fn export_dot(&self, config: &DotConfig) -> Result<String> {
766 export::export_dot(&self.conn, config)
767 }
768
/// Consumes this graph and wraps it in an [`AsyncKnowledgeGraph`] via
/// `AsyncKnowledgeGraph::from_sync`.
///
/// Only compiled when the `async` feature is enabled.
#[cfg(feature = "async")]
pub fn into_async(self) -> AsyncKnowledgeGraph {
    let sync_graph = self;
    AsyncKnowledgeGraph::from_sync(sync_graph)
}
776}
777
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh in-memory graph, panicking if that fails.
    fn memory_graph() -> KnowledgeGraph {
        KnowledgeGraph::open_in_memory().unwrap()
    }

    /// Inserts a `paper` entity called `name` and returns its id.
    fn add_paper(kg: &KnowledgeGraph, name: &str) -> i64 {
        kg.insert_entity(&Entity::new("paper", name)).unwrap()
    }

    /// Inserts a relation of type `rel_type` from `from` to `to` with `weight`.
    fn link(kg: &KnowledgeGraph, from: i64, to: i64, rel_type: &str, weight: f64) {
        let relation = Relation::new(from, to, rel_type, weight).unwrap();
        kg.insert_relation(&relation).unwrap();
    }

    #[test]
    fn test_open_in_memory() {
        let kg = memory_graph();
        assert!(schema_exists(kg.connection()).unwrap());
    }

    #[test]
    fn test_crud_operations() {
        let kg = memory_graph();

        // Create.
        let mut paper = Entity::new("paper", "Test Paper");
        paper.set_property("author", serde_json::json!("John Doe"));
        let id = kg.insert_entity(&paper).unwrap();

        // Read back by id and by type.
        let stored = kg.get_entity(id).unwrap();
        assert_eq!(stored.name, "Test Paper");
        assert_eq!(kg.list_entities(Some("paper"), None).unwrap().len(), 1);

        // Update.
        let mut edited = stored.clone();
        edited.set_property("year", serde_json::json!(2024));
        kg.update_entity(&edited).unwrap();

        // Delete.
        kg.delete_entity(id).unwrap();
        assert_eq!(kg.list_entities(None, None).unwrap().len(), 0);
    }

    #[test]
    fn test_graph_traversal() {
        let kg = memory_graph();

        let first = add_paper(&kg, "Paper 1");
        let second = add_paper(&kg, "Paper 2");
        let third = add_paper(&kg, "Paper 3");

        link(&kg, first, second, "cites", 0.8);
        link(&kg, second, third, "cites", 0.9);

        assert_eq!(kg.get_neighbors(first, 1).unwrap().len(), 1);
        assert_eq!(kg.get_neighbors(first, 2).unwrap().len(), 2);
    }

    #[test]
    fn test_vector_search() {
        let kg = memory_graph();

        let first = add_paper(&kg, "Paper 1");
        let second = add_paper(&kg, "Paper 2");

        kg.insert_vector(first, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(second, vec![0.0, 1.0, 0.0]).unwrap();

        let hits = kg.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, first);
    }

    #[test]
    fn test_find_related_above_threshold() {
        let kg = memory_graph();
        let a = add_paper(&kg, "A");
        let b = add_paper(&kg, "B");
        let c = add_paper(&kg, "C");

        link(&kg, a, b, "related", 0.9);
        link(&kg, a, c, "related", 0.3);

        let related = kg.kg_find_related(a, 0.5).unwrap();
        assert_eq!(
            related.len(),
            1,
            "only B (weight 0.9) should pass threshold 0.5"
        );
        assert_eq!(related[0].0.id, Some(b));
    }

    #[test]
    fn test_find_related_sorted_descending() {
        let kg = memory_graph();
        let a = add_paper(&kg, "A");
        let b = add_paper(&kg, "B");
        let c = add_paper(&kg, "C");

        link(&kg, a, b, "related", 0.4);
        link(&kg, a, c, "related", 0.9);

        let related = kg.kg_find_related(a, 0.0).unwrap();
        assert_eq!(related.len(), 2);
        assert!(
            related[0].1 >= related[1].1,
            "results should be sorted by weight desc"
        );
        assert_eq!(related[0].0.id, Some(c));
    }

    #[test]
    fn test_find_related_threshold_one() {
        let kg = memory_graph();
        let a = add_paper(&kg, "A");
        let b = add_paper(&kg, "B");
        let c = add_paper(&kg, "C");

        link(&kg, a, b, "related", 1.0);
        link(&kg, a, c, "related", 0.9);

        let related = kg.kg_find_related(a, 1.0).unwrap();
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].0.id, Some(b));
    }

    #[test]
    fn test_find_related_no_neighbours() {
        let kg = memory_graph();
        let lonely = add_paper(&kg, "Isolated");

        let related = kg.kg_find_related(lonely, 0.0).unwrap();
        assert!(related.is_empty(), "isolated entity should return empty");
    }

    #[test]
    fn test_find_related_entity_not_found() {
        let kg = memory_graph();
        let outcome = kg.kg_find_related(9999, 0.5);
        assert!(outcome.is_err(), "non-existent entity should return error");
    }

    #[test]
    fn test_similar_entities() {
        let kg = memory_graph();
        let a = add_paper(&kg, "A");
        let b = add_paper(&kg, "B");
        let c = add_paper(&kg, "C");

        kg.insert_vector(a, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(b, vec![0.9, 0.1, 0.0, 0.0]).unwrap();
        kg.insert_vector(c, vec![0.0, 0.0, 1.0, 0.0]).unwrap();

        let similar = kg.kg_similar_entities(a, 2).unwrap();
        assert_eq!(similar.len(), 2);
        assert_eq!(similar[0].entity.name, "B");
        assert!(similar[0].similarity > similar[1].similarity);
    }

    #[test]
    fn test_similar_entities_excludes_self() {
        let kg = memory_graph();
        let only = add_paper(&kg, "X");
        kg.insert_vector(only, vec![1.0, 0.0, 0.0]).unwrap();

        let similar = kg.kg_similar_entities(only, 5).unwrap();
        assert!(similar.is_empty(), "self should not appear in results");
    }

    #[test]
    fn test_similar_entities_no_vector() {
        let kg = memory_graph();
        let without_vector = add_paper(&kg, "NoVec");

        let outcome = kg.kg_similar_entities(without_vector, 5);
        assert!(outcome.is_err(), "entity without vector should error");
    }
}