Skip to main content

sqlite_knowledge_graph/
lib.rs

//! SQLite-based Knowledge Graph Library
//!
//! This library provides a knowledge graph implementation built on SQLite with support for:
//! - Entities with typed properties
//! - Relations between entities with weights
//! - Higher-order relations (hyperedges) connecting multiple entities
//! - Vector embeddings for semantic search, including a TurboQuant approximate index
//! - Custom SQLite functions for direct SQL operations
//! - RAG (Retrieval-Augmented Generation) query functions
//! - Graph algorithms (PageRank, Louvain, Connected Components)
//! - Visualization export (D3.js JSON and Graphviz DOT)
//!
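//! ## SQLite Extension
//!
//! This crate can also be compiled as a SQLite loadable extension. Load it and query it
//! in the same `sqlite3` session; an extension loaded by one `sqlite3` process is not
//! available to a separate invocation:
//! ```bash
//! cargo build --release
//! sqlite3 -cmd ".load ./target/release/libsqlite_knowledge_graph sqlite3_sqlite_knowledge_graph_init" \
//!     db.db "SELECT kg_version();"
//! ```
//!
//! The entry point is named explicitly because SQLite's default guess keeps only the
//! letters of the file name (`sqlite3_sqliteknowledgegraph_init`). Leaving off the file
//! suffix lets SQLite append the platform's own (for example `.so` on Linux or `.dylib`
//! on macOS).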
19
20pub mod algorithms;
21pub mod embed;
22pub mod error;
23pub mod export;
24pub mod extension;
25pub mod functions;
26pub mod graph;
27pub mod migrate;
28pub mod rag;
29pub mod schema;
30pub mod vector;
31
32/// Read a weight column that may be stored as REAL, INTEGER, NULL, or 8-byte BLOB.
33///
34/// Python's sqlite3 module stores numpy.float64 as a little-endian IEEE 754
35/// BLOB instead of REAL. This helper handles all storage variants so callers
36/// do not crash with `InvalidColumnType` on externally-written databases.
37pub(crate) fn row_get_weight(row: &rusqlite::Row, col: usize) -> rusqlite::Result<f64> {
38    use rusqlite::types::ValueRef;
39    match row.get_ref(col)? {
40        ValueRef::Real(f) => Ok(f),
41        ValueRef::Integer(i) => Ok(i as f64),
42        ValueRef::Null => Ok(1.0), // schema DEFAULT 1.0
43        ValueRef::Blob(b) if b.len() == 8 => {
44            let mut bytes = [0u8; 8];
45            bytes.copy_from_slice(b);
46            Ok(f64::from_le_bytes(bytes)) // numpy on LE systems (x86/arm64)
47        }
48        _ => Err(rusqlite::Error::InvalidColumnType(
49            col,
50            "weight".into(),
51            rusqlite::types::Type::Real,
52        )),
53    }
54}
55
56pub use algorithms::{
57    analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
58    PageRankConfig,
59};
60pub use embed::{
61    check_dependencies, get_entities_needing_embedding, EmbeddingConfig, EmbeddingGenerator,
62    EmbeddingStats,
63};
64pub use error::{Error, Result};
65pub use export::{D3ExportGraph, D3ExportMetadata, D3Link, D3Node, DotConfig};
66pub use extension::sqlite3_sqlite_knowledge_graph_init;
67pub use functions::register_functions;
68pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
69pub use graph::{Entity, Neighbor, Relation};
70pub use graph::{HigherOrderNeighbor, HigherOrderPath, HigherOrderPathStep, Hyperedge};
71pub use migrate::{
72    build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
73};
74pub use rag::{embedder::Embedder, embedder::FixedEmbedder, RagConfig, RagEngine, RagResult};
75pub use schema::{create_schema, schema_exists};
76pub use vector::{cosine_similarity, SearchResult, VectorStore};
77pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
78
79use rusqlite::Connection;
80use serde::{Deserialize, Serialize};
81
82/// Semantic search result with entity information.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct SearchResultWithEntity {
85    pub entity: Entity,
86    pub similarity: f32,
87}
88
89/// Graph context for an entity (root + neighbors).
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct GraphContext {
92    pub root_entity: Entity,
93    pub neighbors: Vec<Neighbor>,
94}
95
96/// Hybrid search result combining semantic similarity and graph context.
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct HybridSearchResult {
99    pub entity: Entity,
100    pub similarity: f32,
101    pub context: Option<GraphContext>,
102}
103
104/// Knowledge Graph Manager - main entry point for the library.
105#[derive(Debug)]
106pub struct KnowledgeGraph {
107    conn: Connection,
108}
109
110impl KnowledgeGraph {
111    /// Open a new knowledge graph database connection.
112    pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
113        let conn = Connection::open(path)?;
114
115        // Enable foreign keys
116        conn.execute("PRAGMA foreign_keys = ON", [])?;
117
118        // Create schema if not exists
119        if !schema_exists(&conn)? {
120            create_schema(&conn)?;
121        }
122
123        // Register custom functions
124        register_functions(&conn)?;
125
126        Ok(Self { conn })
127    }
128
129    /// Open an in-memory knowledge graph (useful for testing).
130    pub fn open_in_memory() -> Result<Self> {
131        let conn = Connection::open_in_memory()?;
132
133        // Enable foreign keys
134        conn.execute("PRAGMA foreign_keys = ON", [])?;
135
136        // Create schema
137        create_schema(&conn)?;
138
139        // Register custom functions
140        register_functions(&conn)?;
141
142        Ok(Self { conn })
143    }
144
145    /// Get a reference to the underlying SQLite connection.
146    pub fn connection(&self) -> &Connection {
147        &self.conn
148    }
149
150    /// Begin a transaction for batch operations.
151    pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
152        Ok(self.conn.unchecked_transaction()?)
153    }
154
155    /// Insert an entity into the knowledge graph.
156    pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
157        graph::insert_entity(&self.conn, entity)
158    }
159
160    /// Get an entity by ID.
161    pub fn get_entity(&self, id: i64) -> Result<Entity> {
162        graph::get_entity(&self.conn, id)
163    }
164
165    /// List entities with optional filtering.
166    pub fn list_entities(
167        &self,
168        entity_type: Option<&str>,
169        limit: Option<i64>,
170    ) -> Result<Vec<Entity>> {
171        graph::list_entities(&self.conn, entity_type, limit)
172    }
173
174    /// Update an entity.
175    pub fn update_entity(&self, entity: &Entity) -> Result<()> {
176        graph::update_entity(&self.conn, entity)
177    }
178
179    /// Delete an entity.
180    pub fn delete_entity(&self, id: i64) -> Result<()> {
181        graph::delete_entity(&self.conn, id)
182    }
183
184    /// Insert a relation between entities.
185    pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
186        graph::insert_relation(&self.conn, relation)
187    }
188
189    /// Get neighbors of an entity using BFS traversal.
190    pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
191        graph::get_neighbors(&self.conn, entity_id, depth)
192    }
193
194    /// Insert a vector embedding for an entity.
195    pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
196        let store = VectorStore::new();
197        store.insert_vector(&self.conn, entity_id, vector)
198    }
199
200    /// Search for similar entities using vector embeddings.
201    pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
202        let store = VectorStore::new();
203        store.search_vectors(&self.conn, query, k)
204    }
205
206    // ========== TurboQuant Vector Index ==========
207
208    /// Create a TurboQuant index for fast approximate nearest neighbor search.
209    ///
210    /// TurboQuant provides:
211    /// - Instant indexing (no training required)
212    /// - 6x memory compression
213    /// - Near-zero accuracy loss
214    ///
215    /// # Arguments
216    /// * `config` - Optional configuration (uses defaults if None)
217    ///
218    /// # Example
219    /// ```ignore
220    /// let config = TurboQuantConfig {
221    ///     dimension: 384,
222    ///     bit_width: 3,
223    ///     seed: 42,
224    /// };
225    /// let mut index = kg.create_turboquant_index(Some(config))?;
226    ///
227    /// // Add vectors to index
228    /// for (entity_id, vector) in all_vectors {
229    ///     index.add_vector(entity_id, &vector)?;
230    /// }
231    ///
232    /// // Fast search
233    /// let results = index.search(&query_vector, 10)?;
234    /// ```
235    pub fn create_turboquant_index(
236        &self,
237        config: Option<TurboQuantConfig>,
238    ) -> Result<TurboQuantIndex> {
239        let config = config.unwrap_or_default();
240
241        TurboQuantIndex::new(config)
242    }
243
244    /// Build a TurboQuant index from all existing vectors in the database.
245    /// This is a convenience method that loads all vectors and indexes them.
246    pub fn build_turboquant_index(
247        &self,
248        config: Option<TurboQuantConfig>,
249    ) -> Result<TurboQuantIndex> {
250        // Get dimension from first vector
251        let dimension = self.get_vector_dimension()?.unwrap_or(384);
252
253        let config = config.unwrap_or(TurboQuantConfig {
254            dimension,
255            bit_width: 3,
256            seed: 42,
257        });
258
259        let mut index = TurboQuantIndex::new(config)?;
260
261        // Load all vectors
262        let vectors = self.load_all_vectors()?;
263
264        for (entity_id, vector) in vectors {
265            index.add_vector(entity_id, &vector)?;
266        }
267
268        Ok(index)
269    }
270
271    /// Get the dimension of stored vectors (if any exist).
272    fn get_vector_dimension(&self) -> Result<Option<usize>> {
273        let result = self
274            .conn
275            .query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
276                row.get::<_, i64>(0)
277            });
278
279        match result {
280            Ok(dim) => Ok(Some(dim as usize)),
281            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
282            Err(e) => Err(e.into()),
283        }
284    }
285
286    /// Load all vectors from the database.
287    fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
288        let mut stmt = self
289            .conn
290            .prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
291
292        let rows = stmt.query_map([], |row| {
293            let entity_id: i64 = row.get(0)?;
294            let vector_blob: Vec<u8> = row.get(1)?;
295            let dimension: i64 = row.get(2)?;
296
297            let mut vector = Vec::with_capacity(dimension as usize);
298            for chunk in vector_blob.chunks_exact(4) {
299                let bytes: [u8; 4] = chunk.try_into().unwrap();
300                vector.push(f32::from_le_bytes(bytes));
301            }
302
303            Ok((entity_id, vector))
304        })?;
305
306        let mut vectors = Vec::new();
307        for row in rows {
308            vectors.push(row?);
309        }
310
311        Ok(vectors)
312    }
313
314    // ========== Higher-Order Relations (Hyperedges) ==========
315
316    /// Insert a hyperedge (higher-order relation) into the knowledge graph.
317    pub fn insert_hyperedge(&self, hyperedge: &Hyperedge) -> Result<i64> {
318        graph::insert_hyperedge(&self.conn, hyperedge)
319    }
320
321    /// Get a hyperedge by ID.
322    pub fn get_hyperedge(&self, id: i64) -> Result<Hyperedge> {
323        graph::get_hyperedge(&self.conn, id)
324    }
325
326    /// List hyperedges with optional filtering.
327    pub fn list_hyperedges(
328        &self,
329        hyperedge_type: Option<&str>,
330        min_arity: Option<usize>,
331        max_arity: Option<usize>,
332        limit: Option<i64>,
333    ) -> Result<Vec<Hyperedge>> {
334        graph::list_hyperedges(&self.conn, hyperedge_type, min_arity, max_arity, limit)
335    }
336
337    /// Update a hyperedge.
338    pub fn update_hyperedge(&self, hyperedge: &Hyperedge) -> Result<()> {
339        graph::update_hyperedge(&self.conn, hyperedge)
340    }
341
342    /// Delete a hyperedge by ID.
343    pub fn delete_hyperedge(&self, id: i64) -> Result<()> {
344        graph::delete_hyperedge(&self.conn, id)
345    }
346
347    /// Get higher-order neighbors of an entity (connected through hyperedges).
348    pub fn get_higher_order_neighbors(
349        &self,
350        entity_id: i64,
351        min_arity: Option<usize>,
352        max_arity: Option<usize>,
353    ) -> Result<Vec<HigherOrderNeighbor>> {
354        graph::get_higher_order_neighbors(&self.conn, entity_id, min_arity, max_arity)
355    }
356
357    /// Get all hyperedges that an entity participates in.
358    pub fn get_entity_hyperedges(&self, entity_id: i64) -> Result<Vec<Hyperedge>> {
359        graph::get_entity_hyperedges(&self.conn, entity_id)
360    }
361
362    /// Higher-order BFS traversal through hyperedges.
363    pub fn kg_higher_order_bfs(
364        &self,
365        start_id: i64,
366        max_depth: u32,
367        min_arity: Option<usize>,
368    ) -> Result<Vec<TraversalNode>> {
369        graph::higher_order_bfs(&self.conn, start_id, max_depth, min_arity)
370    }
371
372    /// Find shortest path between two entities through hyperedges.
373    pub fn kg_higher_order_shortest_path(
374        &self,
375        from_id: i64,
376        to_id: i64,
377        max_depth: u32,
378    ) -> Result<Option<HigherOrderPath>> {
379        graph::higher_order_shortest_path(&self.conn, from_id, to_id, max_depth)
380    }
381
382    /// Compute hyperedge degree centrality for an entity.
383    pub fn kg_hyperedge_degree(&self, entity_id: i64) -> Result<f64> {
384        graph::hyperedge_degree(&self.conn, entity_id)
385    }
386
387    /// Compute entity-level hypergraph PageRank using Zhou formula.
388    pub fn kg_hypergraph_entity_pagerank(
389        &self,
390        damping: Option<f64>,
391        max_iter: Option<usize>,
392        tolerance: Option<f64>,
393    ) -> Result<std::collections::HashMap<i64, f64>> {
394        graph::hypergraph_entity_pagerank(
395            &self.conn,
396            damping.unwrap_or(0.85),
397            max_iter.unwrap_or(100),
398            tolerance.unwrap_or(1e-6),
399        )
400    }
401
402    // ========== RAG Query Functions ==========
403
404    /// Semantic search using vector embeddings.
405    /// Returns entities sorted by similarity score.
406    pub fn kg_semantic_search(
407        &self,
408        query_embedding: Vec<f32>,
409        k: usize,
410    ) -> Result<Vec<SearchResultWithEntity>> {
411        let results = self.search_vectors(query_embedding, k)?;
412
413        let mut entities_with_results = Vec::new();
414        for result in results {
415            let entity = self.get_entity(result.entity_id)?;
416            entities_with_results.push(SearchResultWithEntity {
417                entity,
418                similarity: result.similarity,
419            });
420        }
421
422        Ok(entities_with_results)
423    }
424
425    /// Get context around an entity using graph traversal.
426    /// Returns neighbors up to the specified depth.
427    pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
428        let root_entity = self.get_entity(entity_id)?;
429        let neighbors = self.get_neighbors(entity_id, depth)?;
430
431        Ok(GraphContext {
432            root_entity,
433            neighbors,
434        })
435    }
436
437    /// Hybrid search combining semantic search and graph context.
438    /// Performs semantic search first, then retrieves context for top-k results.
439    pub fn kg_hybrid_search(
440        &self,
441        _query_text: &str,
442        query_embedding: Vec<f32>,
443        k: usize,
444    ) -> Result<Vec<HybridSearchResult>> {
445        let semantic_results = self.kg_semantic_search(query_embedding, k)?;
446
447        let mut hybrid_results = Vec::new();
448        for result in semantic_results.iter() {
449            let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
450            let context = self.kg_get_context(entity_id, 1)?; // Depth 1 context
451
452            hybrid_results.push(HybridSearchResult {
453                entity: result.entity.clone(),
454                similarity: result.similarity,
455                context: Some(context),
456            });
457        }
458
459        Ok(hybrid_results)
460    }
461
462    /// Find entities similar to `entity_id` by vector cosine similarity.
463    ///
464    /// Looks up the stored embedding for the given entity and performs a
465    /// nearest-neighbour search across all vectors.  The source entity is
466    /// excluded from results.
467    ///
468    /// Returns `(entity, similarity)` pairs sorted by similarity descending.
469    pub fn kg_similar_entities(
470        &self,
471        entity_id: i64,
472        k: usize,
473    ) -> Result<Vec<SearchResultWithEntity>> {
474        let store = VectorStore::new();
475        let query_vec = store.get_vector(&self.conn, entity_id)?;
476        // request k+1 since the entity itself will appear
477        let results = store.search_vectors(&self.conn, query_vec, k + 1)?;
478
479        let mut out = Vec::new();
480        for r in results {
481            if r.entity_id == entity_id {
482                continue;
483            }
484            let entity = self.get_entity(r.entity_id)?;
485            out.push(SearchResultWithEntity {
486                entity,
487                similarity: r.similarity,
488            });
489        }
490        out.truncate(k);
491        Ok(out)
492    }
493
494    /// Find entities related to `entity_id` whose connecting relation weight
495    /// is at or above `threshold`.  Depth-1 neighbours only.
496    ///
497    /// Returns `(entity, relation_weight)` pairs sorted by weight descending.
498    pub fn kg_find_related(
499        &self,
500        entity_id: i64,
501        threshold: f64,
502    ) -> Result<Vec<(graph::Entity, f64)>> {
503        let neighbours = self.get_neighbors(entity_id, 1)?;
504        let mut results: Vec<(graph::Entity, f64)> = neighbours
505            .into_iter()
506            .filter(|n| n.relation.weight >= threshold)
507            .map(|n| (n.entity, n.relation.weight))
508            .collect();
509        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
510        Ok(results)
511    }
512
513    // ========== Graph Traversal Functions ==========
514
515    /// BFS traversal from a starting entity.
516    /// Returns all reachable entities within max_depth with depth information.
517    pub fn kg_bfs_traversal(
518        &self,
519        start_id: i64,
520        direction: Direction,
521        max_depth: u32,
522    ) -> Result<Vec<TraversalNode>> {
523        let query = TraversalQuery {
524            direction,
525            max_depth,
526            ..Default::default()
527        };
528        graph::bfs_traversal(&self.conn, start_id, query)
529    }
530
531    /// DFS traversal from a starting entity.
532    /// Returns all reachable entities within max_depth.
533    pub fn kg_dfs_traversal(
534        &self,
535        start_id: i64,
536        direction: Direction,
537        max_depth: u32,
538    ) -> Result<Vec<TraversalNode>> {
539        let query = TraversalQuery {
540            direction,
541            max_depth,
542            ..Default::default()
543        };
544        graph::dfs_traversal(&self.conn, start_id, query)
545    }
546
547    /// Find shortest path between two entities using BFS.
548    /// Returns the path with all intermediate steps (if exists).
549    pub fn kg_shortest_path(
550        &self,
551        from_id: i64,
552        to_id: i64,
553        max_depth: u32,
554    ) -> Result<Option<TraversalPath>> {
555        graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
556    }
557
558    /// Compute graph statistics.
559    pub fn kg_graph_stats(&self) -> Result<GraphStats> {
560        graph::compute_graph_stats(&self.conn)
561    }
562
563    // ========== Graph Algorithms ==========
564
565    /// Compute PageRank scores for all entities.
566    /// Returns a vector of (entity_id, score) sorted by score descending.
567    pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
568        algorithms::pagerank(&self.conn, config.unwrap_or_default())
569    }
570
571    /// Detect communities using Louvain algorithm.
572    /// Returns community memberships and modularity score.
573    pub fn kg_louvain(&self) -> Result<CommunityResult> {
574        algorithms::louvain_communities(&self.conn)
575    }
576
577    /// Find connected components in the graph.
578    /// Returns a list of components, each being a list of entity IDs.
579    pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
580        algorithms::connected_components(&self.conn)
581    }
582
583    /// Run full graph analysis (PageRank + Louvain + Connected Components).
584    pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
585        algorithms::analyze_graph(&self.conn)
586    }
587
588    // ========== Visualization Export ==========
589
590    /// Export the knowledge graph in D3.js JSON format.
591    ///
592    /// Returns a `D3ExportGraph` containing nodes, links, and metadata,
593    /// ready for use with D3.js force-directed graph visualizations.
594    ///
595    /// # Example
596    /// ```ignore
597    /// let graph = kg.export_json()?;
598    /// let json = serde_json::to_string_pretty(&graph)?;
599    /// std::fs::write("graph.json", json)?;
600    /// ```
601    pub fn export_json(&self) -> Result<D3ExportGraph> {
602        export::export_d3_json(&self.conn)
603    }
604
605    /// Export the knowledge graph in DOT (Graphviz) format.
606    ///
607    /// Returns a DOT format string that can be rendered with Graphviz tools
608    /// (`dot`, `neato`, `fdp`, etc.).
609    ///
610    /// # Example
611    /// ```ignore
612    /// let dot = kg.export_dot(&DotConfig::default())?;
613    /// std::fs::write("graph.dot", dot)?;
614    /// // Then: dot -Tpng graph.dot -o graph.png
615    /// ```
616    pub fn export_dot(&self, config: &DotConfig) -> Result<String> {
617        export::export_dot(&self.conn, config)
618    }
619}
620
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_open_in_memory() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        assert!(schema_exists(kg.connection()).unwrap());
    }

    #[test]
    fn test_crud_operations() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();

        // Create entity
        let mut entity = Entity::new("paper", "Test Paper");
        entity.set_property("author", serde_json::json!("John Doe"));
        let id = kg.insert_entity(&entity).unwrap();

        // Read entity
        let retrieved = kg.get_entity(id).unwrap();
        assert_eq!(retrieved.name, "Test Paper");

        // List entities
        let entities = kg.list_entities(Some("paper"), None).unwrap();
        assert_eq!(entities.len(), 1);

        // Update entity
        let mut updated = retrieved.clone();
        updated.set_property("year", serde_json::json!(2024));
        kg.update_entity(&updated).unwrap();

        // Delete entity
        kg.delete_entity(id).unwrap();
        let entities = kg.list_entities(None, None).unwrap();
        assert_eq!(entities.len(), 0);
    }

    #[test]
    fn test_graph_traversal() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();

        // Create entities
        let id1 = kg.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("paper", "Paper 3")).unwrap();

        // Create relations
        kg.insert_relation(&Relation::new(id1, id2, "cites", 0.8).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id2, id3, "cites", 0.9).unwrap())
            .unwrap();

        // Get neighbors depth 1
        let neighbors = kg.get_neighbors(id1, 1).unwrap();
        assert_eq!(neighbors.len(), 1);

        // Get neighbors depth 2
        let neighbors = kg.get_neighbors(id1, 2).unwrap();
        assert_eq!(neighbors.len(), 2);
    }

    #[test]
    fn test_vector_search() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();

        // Create entities
        let id1 = kg.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();

        // Insert vectors
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.0, 1.0, 0.0]).unwrap();

        // Search
        let results = kg.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity_id, id1);
    }

    // ── RAG query tests ───────────────────────────────────────────────────────

    #[test]
    fn test_semantic_search_joins_entities() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&Entity::new("paper", "Near")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Far")).unwrap();
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.0, 1.0, 0.0]).unwrap();

        let results = kg.kg_semantic_search(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity.id, Some(id1));
        assert_eq!(results[0].entity.name, "Near");
        assert!(results[0].similarity >= results[1].similarity);
    }

    #[test]
    fn test_get_context_includes_root_and_neighbors() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&Entity::new("paper", "Root")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "Child")).unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "cites", 0.5).unwrap())
            .unwrap();

        let ctx = kg.kg_get_context(id1, 1).unwrap();
        assert_eq!(ctx.root_entity.id, Some(id1));
        assert_eq!(ctx.neighbors.len(), 1);
        assert_eq!(ctx.neighbors[0].entity.id, Some(id2));
    }

    #[test]
    fn test_hybrid_search_attaches_depth_one_context() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&Entity::new("paper", "A")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "B")).unwrap();
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.0, 1.0, 0.0]).unwrap();
        kg.insert_relation(&Relation::new(id1, id2, "cites", 0.7).unwrap())
            .unwrap();

        let results = kg
            .kg_hybrid_search("ignored", vec![1.0, 0.0, 0.0], 1)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity.id, Some(id1));
        let ctx = results[0]
            .context
            .as_ref()
            .expect("hybrid results carry graph context");
        assert_eq!(ctx.root_entity.id, Some(id1));
        assert_eq!(ctx.neighbors.len(), 1);
    }

    // ── kg_find_related tests ─────────────────────────────────────────────────

    #[test]
    fn test_find_related_above_threshold() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&Entity::new("paper", "A")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "B")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("paper", "C")).unwrap();

        kg.insert_relation(&Relation::new(id1, id2, "related", 0.9).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id1, id3, "related", 0.3).unwrap())
            .unwrap();

        let results = kg.kg_find_related(id1, 0.5).unwrap();
        assert_eq!(
            results.len(),
            1,
            "only B (weight 0.9) should pass threshold 0.5"
        );
        assert_eq!(results[0].0.id, Some(id2));
    }

    #[test]
    fn test_find_related_sorted_descending() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&Entity::new("paper", "A")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "B")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("paper", "C")).unwrap();

        kg.insert_relation(&Relation::new(id1, id2, "related", 0.4).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id1, id3, "related", 0.9).unwrap())
            .unwrap();

        let results = kg.kg_find_related(id1, 0.0).unwrap();
        assert_eq!(results.len(), 2);
        assert!(
            results[0].1 >= results[1].1,
            "results should be sorted by weight desc"
        );
        assert_eq!(results[0].0.id, Some(id3)); // weight 0.9 first
    }

    #[test]
    fn test_find_related_threshold_one() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&Entity::new("paper", "A")).unwrap();
        let id2 = kg.insert_entity(&Entity::new("paper", "B")).unwrap();
        let id3 = kg.insert_entity(&Entity::new("paper", "C")).unwrap();

        kg.insert_relation(&Relation::new(id1, id2, "related", 1.0).unwrap())
            .unwrap();
        kg.insert_relation(&Relation::new(id1, id3, "related", 0.9).unwrap())
            .unwrap();

        let results = kg.kg_find_related(id1, 1.0).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.id, Some(id2));
    }

    #[test]
    fn test_find_related_no_neighbours() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&Entity::new("paper", "Isolated")).unwrap();

        let results = kg.kg_find_related(id1, 0.0).unwrap();
        assert!(results.is_empty(), "isolated entity should return empty");
    }

    #[test]
    fn test_find_related_entity_not_found() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let result = kg.kg_find_related(9999, 0.5);
        assert!(result.is_err(), "non-existent entity should return error");
    }

    // ── kg_similar_entities tests ─────────────────────────────────────────────

    #[test]
    fn test_similar_entities() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&graph::Entity::new("paper", "A")).unwrap();
        let id2 = kg.insert_entity(&graph::Entity::new("paper", "B")).unwrap();
        let id3 = kg.insert_entity(&graph::Entity::new("paper", "C")).unwrap();

        // A and B are very similar, C is different
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.9, 0.1, 0.0, 0.0]).unwrap();
        kg.insert_vector(id3, vec![0.0, 0.0, 1.0, 0.0]).unwrap();

        let results = kg.kg_similar_entities(id1, 2).unwrap();
        assert_eq!(results.len(), 2);
        // B should be most similar to A
        assert_eq!(results[0].entity.name, "B");
        assert!(results[0].similarity > results[1].similarity);
    }

    #[test]
    fn test_similar_entities_excludes_self() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&graph::Entity::new("paper", "X")).unwrap();
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();

        let results = kg.kg_similar_entities(id1, 5).unwrap();
        assert!(results.is_empty(), "self should not appear in results");
    }

    #[test]
    fn test_similar_entities_zero_k() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg.insert_entity(&graph::Entity::new("paper", "A")).unwrap();
        let id2 = kg.insert_entity(&graph::Entity::new("paper", "B")).unwrap();
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.0, 1.0, 0.0]).unwrap();

        let results = kg.kg_similar_entities(id1, 0).unwrap();
        assert!(results.is_empty(), "k = 0 should return no results");
    }

    #[test]
    fn test_similar_entities_no_vector() {
        let kg = KnowledgeGraph::open_in_memory().unwrap();
        let id1 = kg
            .insert_entity(&graph::Entity::new("paper", "NoVec"))
            .unwrap();
        let result = kg.kg_similar_entities(id1, 5);
        assert!(result.is_err(), "entity without vector should error");
    }
}
812            .unwrap();
813        let result = kg.kg_similar_entities(id1, 5);
814        assert!(result.is_err(), "entity without vector should error");
815    }
816}