Skip to main content

sqlite_knowledge_graph/
lib.rs

//! SQLite-based Knowledge Graph Library
//!
//! This library provides a knowledge graph implementation built on SQLite with support for:
//! - Entities with typed properties
//! - Relations between entities with weights
//! - Vector embeddings for semantic search
//! - Custom SQLite functions for direct SQL operations
//! - RAG (Retrieval-Augmented Generation) query functions
//! - Graph algorithms (PageRank, Louvain, Connected Components)
//!
//! ## SQLite Extension
//!
//! This crate can be compiled as a SQLite loadable extension. The example below
//! uses the macOS artifact name (`.dylib`); on Linux the library ends in `.so`
//! and on Windows in `.dll`:
//! ```bash
//! cargo build --release
//! sqlite3 db.db ".load ./target/release/libsqlite_knowledge_graph.dylib"
//! sqlite3 db.db "SELECT kg_version();"
//! ```
19
20pub mod algorithms;
21pub mod embed;
22pub mod error;
23pub mod export;
24pub mod extension;
25pub mod functions;
26pub mod graph;
27pub mod migrate;
28pub mod rag;
29pub mod schema;
30pub mod vector;
31
32#[cfg(feature = "async")]
33pub mod async_kg;
34#[cfg(feature = "async")]
35pub use async_kg::embed::AsyncEmbeddingGenerator;
36#[cfg(feature = "async")]
37pub use async_kg::AsyncKnowledgeGraph;
38
39/// Read a weight column that may be stored as REAL, INTEGER, NULL, or 8-byte BLOB.
40///
41/// Python's sqlite3 module stores numpy.float64 as a little-endian IEEE 754
42/// BLOB instead of REAL. This helper handles all storage variants so callers
43/// do not crash with `InvalidColumnType` on externally-written databases.
44pub(crate) fn row_get_weight(row: &rusqlite::Row, col: usize) -> rusqlite::Result<f64> {
45    use rusqlite::types::ValueRef;
46    match row.get_ref(col)? {
47        ValueRef::Real(f) => Ok(f),
48        ValueRef::Integer(i) => Ok(i as f64),
49        ValueRef::Null => Ok(1.0), // schema DEFAULT 1.0
50        ValueRef::Blob(b) if b.len() == 8 => {
51            let mut bytes = [0u8; 8];
52            bytes.copy_from_slice(b);
53            Ok(f64::from_le_bytes(bytes)) // numpy.float64 on LE systems
54        }
55        ValueRef::Blob(b) if b.len() == 4 => {
56            let mut bytes = [0u8; 4];
57            bytes.copy_from_slice(b);
58            Ok(f32::from_le_bytes(bytes) as f64) // numpy.float32 or truncated
59        }
60        ValueRef::Blob(_) => Ok(1.0), // Unknown blob size, use default
61        _ => Err(rusqlite::Error::InvalidColumnType(
62            col,
63            "weight".into(),
64            rusqlite::types::Type::Real,
65        )),
66    }
67}
68
69pub use algorithms::{
70    analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
71    PageRankConfig,
72};
73pub use embed::{
74    check_dependencies, get_entities_needing_embedding, EmbeddingConfig, EmbeddingGenerator,
75    EmbeddingStats,
76};
77pub use error::{Error, Result};
78pub use export::{D3ExportGraph, D3ExportMetadata, D3Link, D3Node, DotConfig};
79pub use extension::sqlite3_sqlite_knowledge_graph_init;
80pub use functions::register_functions;
81pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
82pub use graph::{Entity, Neighbor, Relation};
83pub use graph::{HigherOrderNeighbor, HigherOrderPath, HigherOrderPathStep, Hyperedge};
84pub use migrate::{
85    build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
86};
87pub use rag::{embedder::Embedder, embedder::FixedEmbedder, RagConfig, RagEngine, RagResult};
88pub use schema::{create_schema, schema_exists};
89pub use vector::{cosine_similarity, SearchResult, VectorStore};
90pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
91
92use rusqlite::Connection;
93use serde::{Deserialize, Serialize};
94
95/// Semantic search result with entity information.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct SearchResultWithEntity {
98    pub entity: Entity,
99    pub similarity: f32,
100}
101
102/// Graph context for an entity (root + neighbors).
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct GraphContext {
105    pub root_entity: Entity,
106    pub neighbors: Vec<Neighbor>,
107}
108
109/// Hybrid search result combining semantic similarity and graph context.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct HybridSearchResult {
112    pub entity: Entity,
113    pub similarity: f32,
114    pub context: Option<GraphContext>,
115}
116
117/// Knowledge Graph Manager - main entry point for the library.
118#[derive(Debug)]
119pub struct KnowledgeGraph {
120    conn: Connection,
121}
122
123impl KnowledgeGraph {
124    /// Open a new knowledge graph database connection.
125    pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
126        let conn = Connection::open(path)?;
127
128        // Enable foreign keys
129        conn.execute("PRAGMA foreign_keys = ON", [])?;
130
131        // Create schema if not exists
132        if !schema_exists(&conn)? {
133            create_schema(&conn)?;
134        }
135
136        // Register custom functions
137        register_functions(&conn)?;
138
139        Ok(Self { conn })
140    }
141
142    /// Open an in-memory knowledge graph (useful for testing).
143    pub fn open_in_memory() -> Result<Self> {
144        let conn = Connection::open_in_memory()?;
145
146        // Enable foreign keys
147        conn.execute("PRAGMA foreign_keys = ON", [])?;
148
149        // Create schema
150        create_schema(&conn)?;
151
152        // Register custom functions
153        register_functions(&conn)?;
154
155        Ok(Self { conn })
156    }
157
158    /// Get a reference to the underlying SQLite connection.
159    pub fn connection(&self) -> &Connection {
160        &self.conn
161    }
162
163    /// Begin a transaction for batch operations.
164    pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
165        Ok(self.conn.unchecked_transaction()?)
166    }
167
168    /// Insert an entity into the knowledge graph.
169    pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
170        graph::insert_entity(&self.conn, entity)
171    }
172
173    /// Get an entity by ID.
174    pub fn get_entity(&self, id: i64) -> Result<Entity> {
175        graph::get_entity(&self.conn, id)
176    }
177
178    /// List entities with optional filtering.
179    pub fn list_entities(
180        &self,
181        entity_type: Option<&str>,
182        limit: Option<i64>,
183    ) -> Result<Vec<Entity>> {
184        graph::list_entities(&self.conn, entity_type, limit)
185    }
186
187    /// Update an entity.
188    pub fn update_entity(&self, entity: &Entity) -> Result<()> {
189        graph::update_entity(&self.conn, entity)
190    }
191
192    /// Delete an entity.
193    pub fn delete_entity(&self, id: i64) -> Result<()> {
194        graph::delete_entity(&self.conn, id)
195    }
196
197    /// Insert a relation between entities.
198    pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
199        graph::insert_relation(&self.conn, relation)
200    }
201
202    /// Get neighbors of an entity using BFS traversal.
203    pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
204        graph::get_neighbors(&self.conn, entity_id, depth)
205    }
206
207    /// Insert a vector embedding for an entity.
208    pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
209        let store = VectorStore::new();
210        store.insert_vector(&self.conn, entity_id, vector)
211    }
212
213    /// Search for similar entities using vector embeddings.
214    pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
215        let store = VectorStore::new();
216        store.search_vectors(&self.conn, query, k)
217    }
218
219    // ========== TurboQuant Vector Index ==========
220
221    /// Create a TurboQuant index for fast approximate nearest neighbor search.
222    ///
223    /// TurboQuant provides:
224    /// - Instant indexing (no training required)
225    /// - 6x memory compression
226    /// - Near-zero accuracy loss
227    ///
228    /// # Arguments
229    /// * `config` - Optional configuration (uses defaults if None)
230    ///
231    /// # Example
232    /// ```ignore
233    /// let config = TurboQuantConfig {
234    ///     dimension: 384,
235    ///     bit_width: 3,
236    ///     seed: 42,
237    /// };
238    /// let mut index = kg.create_turboquant_index(Some(config))?;
239    ///
240    /// // Add vectors to index
241    /// for (entity_id, vector) in all_vectors {
242    ///     index.add_vector(entity_id, &vector)?;
243    /// }
244    ///
245    /// // Fast search
246    /// let results = index.search(&query_vector, 10)?;
247    /// ```
248    pub fn create_turboquant_index(
249        &self,
250        config: Option<TurboQuantConfig>,
251    ) -> Result<TurboQuantIndex> {
252        let config = config.unwrap_or_default();
253
254        TurboQuantIndex::new(config)
255    }
256
257    /// Build a TurboQuant index from all existing vectors in the database.
258    /// This is a convenience method that loads all vectors and indexes them.
259    pub fn build_turboquant_index(
260        &self,
261        config: Option<TurboQuantConfig>,
262    ) -> Result<TurboQuantIndex> {
263        // Get dimension from first vector
264        let dimension = self.get_vector_dimension()?.unwrap_or(384);
265
266        let config = config.unwrap_or(TurboQuantConfig {
267            dimension,
268            bit_width: 3,
269            seed: 42,
270        });
271
272        let mut index = TurboQuantIndex::new(config)?;
273
274        // Load all vectors
275        let vectors = self.load_all_vectors()?;
276
277        for (entity_id, vector) in vectors {
278            index.add_vector(entity_id, &vector)?;
279        }
280
281        Ok(index)
282    }
283
284    /// Get the dimension of stored vectors (if any exist).
285    fn get_vector_dimension(&self) -> Result<Option<usize>> {
286        let result = self
287            .conn
288            .query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
289                row.get::<_, i64>(0)
290            });
291
292        match result {
293            Ok(dim) => Ok(Some(dim as usize)),
294            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
295            Err(e) => Err(e.into()),
296        }
297    }
298
299    /// Load all vectors from the database.
300    fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
301        let mut stmt = self
302            .conn
303            .prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
304
305        let rows = stmt.query_map([], |row| {
306            let entity_id: i64 = row.get(0)?;
307            let vector_blob: Vec<u8> = row.get(1)?;
308            let dimension: i64 = row.get(2)?;
309
310            let mut vector = Vec::with_capacity(dimension as usize);
311            for chunk in vector_blob.chunks_exact(4) {
312                let bytes: [u8; 4] = chunk.try_into().unwrap();
313                vector.push(f32::from_le_bytes(bytes));
314            }
315
316            Ok((entity_id, vector))
317        })?;
318
319        let mut vectors = Vec::new();
320        for row in rows {
321            vectors.push(row?);
322        }
323
324        Ok(vectors)
325    }
326
327    // ========== Higher-Order Relations (Hyperedges) ==========
328
329    /// Insert a hyperedge (higher-order relation) into the knowledge graph.
330    pub fn insert_hyperedge(&self, hyperedge: &Hyperedge) -> Result<i64> {
331        graph::insert_hyperedge(&self.conn, hyperedge)
332    }
333
334    /// Get a hyperedge by ID.
335    pub fn get_hyperedge(&self, id: i64) -> Result<Hyperedge> {
336        graph::get_hyperedge(&self.conn, id)
337    }
338
339    /// List hyperedges with optional filtering.
340    pub fn list_hyperedges(
341        &self,
342        hyperedge_type: Option<&str>,
343        min_arity: Option<usize>,
344        max_arity: Option<usize>,
345        limit: Option<i64>,
346    ) -> Result<Vec<Hyperedge>> {
347        graph::list_hyperedges(&self.conn, hyperedge_type, min_arity, max_arity, limit)
348    }
349
350    /// Update a hyperedge.
351    pub fn update_hyperedge(&self, hyperedge: &Hyperedge) -> Result<()> {
352        graph::update_hyperedge(&self.conn, hyperedge)
353    }
354
355    /// Delete a hyperedge by ID.
356    pub fn delete_hyperedge(&self, id: i64) -> Result<()> {
357        graph::delete_hyperedge(&self.conn, id)
358    }
359
360    /// Get higher-order neighbors of an entity (connected through hyperedges).
361    pub fn get_higher_order_neighbors(
362        &self,
363        entity_id: i64,
364        min_arity: Option<usize>,
365        max_arity: Option<usize>,
366    ) -> Result<Vec<HigherOrderNeighbor>> {
367        graph::get_higher_order_neighbors(&self.conn, entity_id, min_arity, max_arity)
368    }
369
370    /// Get all hyperedges that an entity participates in.
371    pub fn get_entity_hyperedges(&self, entity_id: i64) -> Result<Vec<Hyperedge>> {
372        graph::get_entity_hyperedges(&self.conn, entity_id)
373    }
374
375    /// Higher-order BFS traversal through hyperedges.
376    pub fn kg_higher_order_bfs(
377        &self,
378        start_id: i64,
379        max_depth: u32,
380        min_arity: Option<usize>,
381    ) -> Result<Vec<TraversalNode>> {
382        graph::higher_order_bfs(&self.conn, start_id, max_depth, min_arity)
383    }
384
385    /// Find shortest path between two entities through hyperedges.
386    pub fn kg_higher_order_shortest_path(
387        &self,
388        from_id: i64,
389        to_id: i64,
390        max_depth: u32,
391    ) -> Result<Option<HigherOrderPath>> {
392        graph::higher_order_shortest_path(&self.conn, from_id, to_id, max_depth)
393    }
394
395    /// Compute hyperedge degree centrality for an entity.
396    pub fn kg_hyperedge_degree(&self, entity_id: i64) -> Result<f64> {
397        graph::hyperedge_degree(&self.conn, entity_id)
398    }
399
400    /// Compute entity-level hypergraph PageRank using Zhou formula.
401    pub fn kg_hypergraph_entity_pagerank(
402        &self,
403        damping: Option<f64>,
404        max_iter: Option<usize>,
405        tolerance: Option<f64>,
406    ) -> Result<std::collections::HashMap<i64, f64>> {
407        graph::hypergraph_entity_pagerank(
408            &self.conn,
409            damping.unwrap_or(0.85),
410            max_iter.unwrap_or(100),
411            tolerance.unwrap_or(1e-6),
412        )
413    }
414
415    // ========== RAG Query Functions ==========
416
417    /// Semantic search using vector embeddings.
418    /// Returns entities sorted by similarity score.
419    pub fn kg_semantic_search(
420        &self,
421        query_embedding: Vec<f32>,
422        k: usize,
423    ) -> Result<Vec<SearchResultWithEntity>> {
424        let results = self.search_vectors(query_embedding, k)?;
425
426        let mut entities_with_results = Vec::new();
427        for result in results {
428            let entity = self.get_entity(result.entity_id)?;
429            entities_with_results.push(SearchResultWithEntity {
430                entity,
431                similarity: result.similarity,
432            });
433        }
434
435        Ok(entities_with_results)
436    }
437
438    /// Get context around an entity using graph traversal.
439    /// Returns neighbors up to the specified depth.
440    pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
441        let root_entity = self.get_entity(entity_id)?;
442        let neighbors = self.get_neighbors(entity_id, depth)?;
443
444        Ok(GraphContext {
445            root_entity,
446            neighbors,
447        })
448    }
449
450    /// Hybrid search combining semantic search and graph context.
451    /// Performs semantic search first, then retrieves context for top-k results.
452    pub fn kg_hybrid_search(
453        &self,
454        _query_text: &str,
455        query_embedding: Vec<f32>,
456        k: usize,
457    ) -> Result<Vec<HybridSearchResult>> {
458        let semantic_results = self.kg_semantic_search(query_embedding, k)?;
459
460        let mut hybrid_results = Vec::new();
461        for result in semantic_results.iter() {
462            let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
463            let context = self.kg_get_context(entity_id, 1)?; // Depth 1 context
464
465            hybrid_results.push(HybridSearchResult {
466                entity: result.entity.clone(),
467                similarity: result.similarity,
468                context: Some(context),
469            });
470        }
471
472        Ok(hybrid_results)
473    }
474
475    /// Find entities similar to `entity_id` by vector cosine similarity.
476    ///
477    /// Looks up the stored embedding for the given entity and performs a
478    /// nearest-neighbour search across all vectors.  The source entity is
479    /// excluded from results.
480    ///
481    /// Returns `(entity, similarity)` pairs sorted by similarity descending.
482    pub fn kg_similar_entities(
483        &self,
484        entity_id: i64,
485        k: usize,
486    ) -> Result<Vec<SearchResultWithEntity>> {
487        let store = VectorStore::new();
488        let query_vec = store.get_vector(&self.conn, entity_id)?;
489        // request k+1 since the entity itself will appear
490        let results = store.search_vectors(&self.conn, query_vec, k + 1)?;
491
492        let mut out = Vec::new();
493        for r in results {
494            if r.entity_id == entity_id {
495                continue;
496            }
497            let entity = self.get_entity(r.entity_id)?;
498            out.push(SearchResultWithEntity {
499                entity,
500                similarity: r.similarity,
501            });
502        }
503        out.truncate(k);
504        Ok(out)
505    }
506
507    /// Find entities related to `entity_id` whose connecting relation weight
508    /// is at or above `threshold`.  Depth-1 neighbours only.
509    ///
510    /// Returns `(entity, relation_weight)` pairs sorted by weight descending.
511    pub fn kg_find_related(
512        &self,
513        entity_id: i64,
514        threshold: f64,
515    ) -> Result<Vec<(graph::Entity, f64)>> {
516        let neighbours = self.get_neighbors(entity_id, 1)?;
517        let mut results: Vec<(graph::Entity, f64)> = neighbours
518            .into_iter()
519            .filter(|n| n.relation.weight >= threshold)
520            .map(|n| (n.entity, n.relation.weight))
521            .collect();
522        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
523        Ok(results)
524    }
525
526    // ========== Graph Traversal Functions ==========
527
528    /// BFS traversal from a starting entity.
529    /// Returns all reachable entities within max_depth with depth information.
530    pub fn kg_bfs_traversal(
531        &self,
532        start_id: i64,
533        direction: Direction,
534        max_depth: u32,
535    ) -> Result<Vec<TraversalNode>> {
536        let query = TraversalQuery {
537            direction,
538            max_depth,
539            ..Default::default()
540        };
541        graph::bfs_traversal(&self.conn, start_id, query)
542    }
543
544    /// DFS traversal from a starting entity.
545    /// Returns all reachable entities within max_depth.
546    pub fn kg_dfs_traversal(
547        &self,
548        start_id: i64,
549        direction: Direction,
550        max_depth: u32,
551    ) -> Result<Vec<TraversalNode>> {
552        let query = TraversalQuery {
553            direction,
554            max_depth,
555            ..Default::default()
556        };
557        graph::dfs_traversal(&self.conn, start_id, query)
558    }
559
560    /// Find shortest path between two entities using BFS.
561    /// Returns the path with all intermediate steps (if exists).
562    pub fn kg_shortest_path(
563        &self,
564        from_id: i64,
565        to_id: i64,
566        max_depth: u32,
567    ) -> Result<Option<TraversalPath>> {
568        graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
569    }
570
571    /// Compute graph statistics.
572    pub fn kg_graph_stats(&self) -> Result<GraphStats> {
573        graph::compute_graph_stats(&self.conn)
574    }
575
576    // ========== Graph Algorithms ==========
577
578    /// Compute PageRank scores for all entities.
579    /// Returns a vector of (entity_id, score) sorted by score descending.
580    pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
581        algorithms::pagerank(&self.conn, config.unwrap_or_default())
582    }
583
584    /// Detect communities using Louvain algorithm.
585    /// Returns community memberships and modularity score.
586    pub fn kg_louvain(&self) -> Result<CommunityResult> {
587        algorithms::louvain_communities(&self.conn)
588    }
589
590    /// Find connected components in the graph.
591    /// Returns a list of components, each being a list of entity IDs.
592    pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
593        algorithms::connected_components(&self.conn)
594    }
595
596    /// Run full graph analysis (PageRank + Louvain + Connected Components).
597    pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
598        algorithms::analyze_graph(&self.conn)
599    }
600
601    // ========== Visualization Export ==========
602
603    /// Export the knowledge graph in D3.js JSON format.
604    ///
605    /// Returns a `D3ExportGraph` containing nodes, links, and metadata,
606    /// ready for use with D3.js force-directed graph visualizations.
607    ///
608    /// # Example
609    /// ```ignore
610    /// let graph = kg.export_json()?;
611    /// let json = serde_json::to_string_pretty(&graph)?;
612    /// std::fs::write("graph.json", json)?;
613    /// ```
614    pub fn export_json(&self) -> Result<D3ExportGraph> {
615        export::export_d3_json(&self.conn)
616    }
617
618    /// Export the knowledge graph in DOT (Graphviz) format.
619    ///
620    /// Returns a DOT format string that can be rendered with Graphviz tools
621    /// (`dot`, `neato`, `fdp`, etc.).
622    ///
623    /// # Example
624    /// ```ignore
625    /// let dot = kg.export_dot(&DotConfig::default())?;
626    /// std::fs::write("graph.dot", dot)?;
627    /// // Then: dot -Tpng graph.dot -o graph.png
628    /// ```
629    pub fn export_dot(&self, config: &DotConfig) -> Result<String> {
630        export::export_dot(&self.conn, config)
631    }
632
633    /// Convert this synchronous `KnowledgeGraph` into an `AsyncKnowledgeGraph`.
634    ///
635    /// Requires the `async` feature to be enabled.
636    #[cfg(feature = "async")]
637    pub fn into_async(self) -> AsyncKnowledgeGraph {
638        AsyncKnowledgeGraph::from_sync(self)
639    }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory graph; panics if setup fails.
    fn memory_graph() -> KnowledgeGraph {
        KnowledgeGraph::open_in_memory().unwrap()
    }

    /// Insert a `paper` entity called `name` and return its id.
    fn add_paper(kg: &KnowledgeGraph, name: &str) -> i64 {
        kg.insert_entity(&Entity::new("paper", name)).unwrap()
    }

    /// Insert a relation of type `rel_type` from `from` to `to` with `weight`.
    fn link(kg: &KnowledgeGraph, from: i64, to: i64, rel_type: &str, weight: f64) {
        let relation = Relation::new(from, to, rel_type, weight).unwrap();
        kg.insert_relation(&relation).unwrap();
    }

    #[test]
    fn test_open_in_memory() {
        let kg = memory_graph();
        assert!(schema_exists(kg.connection()).unwrap());
    }

    #[test]
    fn test_crud_operations() {
        let kg = memory_graph();

        // Create
        let mut paper = Entity::new("paper", "Test Paper");
        paper.set_property("author", serde_json::json!("John Doe"));
        let id = kg.insert_entity(&paper).unwrap();

        // Read
        let fetched = kg.get_entity(id).unwrap();
        assert_eq!(fetched.name, "Test Paper");

        // List
        let papers = kg.list_entities(Some("paper"), None).unwrap();
        assert_eq!(papers.len(), 1);

        // Update
        let mut revised = fetched.clone();
        revised.set_property("year", serde_json::json!(2024));
        kg.update_entity(&revised).unwrap();

        // Delete
        kg.delete_entity(id).unwrap();
        let remaining = kg.list_entities(None, None).unwrap();
        assert_eq!(remaining.len(), 0);
    }

    #[test]
    fn test_graph_traversal() {
        let kg = memory_graph();

        let id1 = add_paper(&kg, "Paper 1");
        let id2 = add_paper(&kg, "Paper 2");
        let id3 = add_paper(&kg, "Paper 3");

        // Chain: 1 -> 2 -> 3
        link(&kg, id1, id2, "cites", 0.8);
        link(&kg, id2, id3, "cites", 0.9);

        let depth_one = kg.get_neighbors(id1, 1).unwrap();
        assert_eq!(depth_one.len(), 1);

        let depth_two = kg.get_neighbors(id1, 2).unwrap();
        assert_eq!(depth_two.len(), 2);
    }

    #[test]
    fn test_vector_search() {
        let kg = memory_graph();

        let id1 = add_paper(&kg, "Paper 1");
        let id2 = add_paper(&kg, "Paper 2");

        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.0, 1.0, 0.0]).unwrap();

        let hits = kg.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, id1);
    }

    // ── kg_find_related tests ─────────────────────────────────────────────────

    #[test]
    fn test_find_related_above_threshold() {
        let kg = memory_graph();
        let id1 = add_paper(&kg, "A");
        let id2 = add_paper(&kg, "B");
        let id3 = add_paper(&kg, "C");

        link(&kg, id1, id2, "related", 0.9);
        link(&kg, id1, id3, "related", 0.3);

        let related = kg.kg_find_related(id1, 0.5).unwrap();
        assert_eq!(
            related.len(),
            1,
            "only B (weight 0.9) should pass threshold 0.5"
        );
        assert_eq!(related[0].0.id, Some(id2));
    }

    #[test]
    fn test_find_related_sorted_descending() {
        let kg = memory_graph();
        let id1 = add_paper(&kg, "A");
        let id2 = add_paper(&kg, "B");
        let id3 = add_paper(&kg, "C");

        link(&kg, id1, id2, "related", 0.4);
        link(&kg, id1, id3, "related", 0.9);

        let related = kg.kg_find_related(id1, 0.0).unwrap();
        assert_eq!(related.len(), 2);
        assert!(
            related[0].1 >= related[1].1,
            "results should be sorted by weight desc"
        );
        assert_eq!(related[0].0.id, Some(id3)); // weight 0.9 first
    }

    #[test]
    fn test_find_related_threshold_one() {
        let kg = memory_graph();
        let id1 = add_paper(&kg, "A");
        let id2 = add_paper(&kg, "B");
        let id3 = add_paper(&kg, "C");

        link(&kg, id1, id2, "related", 1.0);
        link(&kg, id1, id3, "related", 0.9);

        let related = kg.kg_find_related(id1, 1.0).unwrap();
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].0.id, Some(id2));
    }

    #[test]
    fn test_find_related_no_neighbours() {
        let kg = memory_graph();
        let id1 = add_paper(&kg, "Isolated");

        let related = kg.kg_find_related(id1, 0.0).unwrap();
        assert!(related.is_empty(), "isolated entity should return empty");
    }

    #[test]
    fn test_find_related_entity_not_found() {
        let kg = memory_graph();
        let outcome = kg.kg_find_related(9999, 0.5);
        assert!(outcome.is_err(), "non-existent entity should return error");
    }

    #[test]
    fn test_similar_entities() {
        let kg = memory_graph();
        let id1 = add_paper(&kg, "A");
        let id2 = add_paper(&kg, "B");
        let id3 = add_paper(&kg, "C");

        // A and B are very similar, C is different
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0, 0.0]).unwrap();
        kg.insert_vector(id2, vec![0.9, 0.1, 0.0, 0.0]).unwrap();
        kg.insert_vector(id3, vec![0.0, 0.0, 1.0, 0.0]).unwrap();

        let similar = kg.kg_similar_entities(id1, 2).unwrap();
        assert_eq!(similar.len(), 2);
        // B should be most similar to A
        assert_eq!(similar[0].entity.name, "B");
        assert!(similar[0].similarity > similar[1].similarity);
    }

    #[test]
    fn test_similar_entities_excludes_self() {
        let kg = memory_graph();
        let id1 = add_paper(&kg, "X");
        kg.insert_vector(id1, vec![1.0, 0.0, 0.0]).unwrap();

        let similar = kg.kg_similar_entities(id1, 5).unwrap();
        assert!(similar.is_empty(), "self should not appear in results");
    }

    #[test]
    fn test_similar_entities_no_vector() {
        let kg = memory_graph();
        let id1 = add_paper(&kg, "NoVec");
        let outcome = kg.kg_similar_entities(id1, 5);
        assert!(outcome.is_err(), "entity without vector should error");
    }
}