1pub mod algorithms;
21pub mod error;
22pub mod extension;
23pub mod functions;
24pub mod graph;
25pub mod migrate;
26pub mod schema;
27pub mod vector;
28
29pub use algorithms::{
30 analyze_graph, connected_components, louvain_communities, pagerank, CommunityResult,
31 PageRankConfig,
32};
33pub use error::{Error, Result};
34pub use extension::sqlite3_sqlite_knowledge_graph_init;
35pub use functions::register_functions;
36pub use graph::{Direction, GraphStats, PathStep, TraversalNode, TraversalPath, TraversalQuery};
37pub use graph::{Entity, Neighbor, Relation};
38pub use migrate::{
39 build_relationships, migrate_all, migrate_papers, migrate_skills, MigrationStats,
40};
41pub use schema::{create_schema, schema_exists};
42pub use vector::{cosine_similarity, SearchResult, VectorStore};
43pub use vector::{TurboQuantConfig, TurboQuantIndex, TurboQuantStats};
44
45use rusqlite::Connection;
46use serde::{Deserialize, Serialize};
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SearchResultWithEntity {
51 pub entity: Entity,
52 pub similarity: f32,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct GraphContext {
58 pub root_entity: Entity,
59 pub neighbors: Vec<Neighbor>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct HybridSearchResult {
65 pub entity: Entity,
66 pub similarity: f32,
67 pub context: Option<GraphContext>,
68}
69
70#[derive(Debug)]
72pub struct KnowledgeGraph {
73 conn: Connection,
74}
75
76impl KnowledgeGraph {
77 pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
79 let conn = Connection::open(path)?;
80
81 conn.execute("PRAGMA foreign_keys = ON", [])?;
83
84 if !schema_exists(&conn)? {
86 create_schema(&conn)?;
87 }
88
89 register_functions(&conn)?;
91
92 Ok(Self { conn })
93 }
94
95 pub fn open_in_memory() -> Result<Self> {
97 let conn = Connection::open_in_memory()?;
98
99 conn.execute("PRAGMA foreign_keys = ON", [])?;
101
102 create_schema(&conn)?;
104
105 register_functions(&conn)?;
107
108 Ok(Self { conn })
109 }
110
111 pub fn connection(&self) -> &Connection {
113 &self.conn
114 }
115
116 pub fn transaction(&self) -> Result<rusqlite::Transaction<'_>> {
118 Ok(self.conn.unchecked_transaction()?)
119 }
120
121 pub fn insert_entity(&self, entity: &Entity) -> Result<i64> {
123 graph::insert_entity(&self.conn, entity)
124 }
125
126 pub fn get_entity(&self, id: i64) -> Result<Entity> {
128 graph::get_entity(&self.conn, id)
129 }
130
131 pub fn list_entities(
133 &self,
134 entity_type: Option<&str>,
135 limit: Option<i64>,
136 ) -> Result<Vec<Entity>> {
137 graph::list_entities(&self.conn, entity_type, limit)
138 }
139
140 pub fn update_entity(&self, entity: &Entity) -> Result<()> {
142 graph::update_entity(&self.conn, entity)
143 }
144
145 pub fn delete_entity(&self, id: i64) -> Result<()> {
147 graph::delete_entity(&self.conn, id)
148 }
149
150 pub fn insert_relation(&self, relation: &Relation) -> Result<i64> {
152 graph::insert_relation(&self.conn, relation)
153 }
154
155 pub fn get_neighbors(&self, entity_id: i64, depth: u32) -> Result<Vec<Neighbor>> {
157 graph::get_neighbors(&self.conn, entity_id, depth)
158 }
159
160 pub fn insert_vector(&self, entity_id: i64, vector: Vec<f32>) -> Result<()> {
162 let store = VectorStore::new();
163 store.insert_vector(&self.conn, entity_id, vector)
164 }
165
166 pub fn search_vectors(&self, query: Vec<f32>, k: usize) -> Result<Vec<SearchResult>> {
168 let store = VectorStore::new();
169 store.search_vectors(&self.conn, query, k)
170 }
171
172 pub fn create_turboquant_index(
202 &self,
203 config: Option<TurboQuantConfig>,
204 ) -> Result<TurboQuantIndex> {
205 let config = config.unwrap_or_default();
206
207 TurboQuantIndex::new(config)
208 }
209
210 pub fn build_turboquant_index(
213 &self,
214 config: Option<TurboQuantConfig>,
215 ) -> Result<TurboQuantIndex> {
216 let dimension = self.get_vector_dimension()?.unwrap_or(384);
218
219 let config = config.unwrap_or(TurboQuantConfig {
220 dimension,
221 bit_width: 3,
222 seed: 42,
223 });
224
225 let mut index = TurboQuantIndex::new(config)?;
226
227 let vectors = self.load_all_vectors()?;
229
230 for (entity_id, vector) in vectors {
231 index.add_vector(entity_id, &vector)?;
232 }
233
234 Ok(index)
235 }
236
237 fn get_vector_dimension(&self) -> Result<Option<usize>> {
239 let result = self
240 .conn
241 .query_row("SELECT dimension FROM kg_vectors LIMIT 1", [], |row| {
242 row.get::<_, i64>(0)
243 });
244
245 match result {
246 Ok(dim) => Ok(Some(dim as usize)),
247 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
248 Err(e) => Err(e.into()),
249 }
250 }
251
252 fn load_all_vectors(&self) -> Result<Vec<(i64, Vec<f32>)>> {
254 let mut stmt = self
255 .conn
256 .prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
257
258 let rows = stmt.query_map([], |row| {
259 let entity_id: i64 = row.get(0)?;
260 let vector_blob: Vec<u8> = row.get(1)?;
261 let dimension: i64 = row.get(2)?;
262
263 let mut vector = Vec::with_capacity(dimension as usize);
264 for chunk in vector_blob.chunks_exact(4) {
265 let bytes: [u8; 4] = chunk.try_into().unwrap();
266 vector.push(f32::from_le_bytes(bytes));
267 }
268
269 Ok((entity_id, vector))
270 })?;
271
272 let mut vectors = Vec::new();
273 for row in rows {
274 vectors.push(row?);
275 }
276
277 Ok(vectors)
278 }
279
280 pub fn kg_semantic_search(
285 &self,
286 query_embedding: Vec<f32>,
287 k: usize,
288 ) -> Result<Vec<SearchResultWithEntity>> {
289 let results = self.search_vectors(query_embedding, k)?;
290
291 let mut entities_with_results = Vec::new();
292 for result in results {
293 let entity = self.get_entity(result.entity_id)?;
294 entities_with_results.push(SearchResultWithEntity {
295 entity,
296 similarity: result.similarity,
297 });
298 }
299
300 Ok(entities_with_results)
301 }
302
303 pub fn kg_get_context(&self, entity_id: i64, depth: u32) -> Result<GraphContext> {
306 let root_entity = self.get_entity(entity_id)?;
307 let neighbors = self.get_neighbors(entity_id, depth)?;
308
309 Ok(GraphContext {
310 root_entity,
311 neighbors,
312 })
313 }
314
315 pub fn kg_hybrid_search(
318 &self,
319 _query_text: &str,
320 query_embedding: Vec<f32>,
321 k: usize,
322 ) -> Result<Vec<HybridSearchResult>> {
323 let semantic_results = self.kg_semantic_search(query_embedding, k)?;
324
325 let mut hybrid_results = Vec::new();
326 for result in semantic_results.iter() {
327 let entity_id = result.entity.id.ok_or(Error::EntityNotFound(0))?;
328 let context = self.kg_get_context(entity_id, 1)?; hybrid_results.push(HybridSearchResult {
331 entity: result.entity.clone(),
332 similarity: result.similarity,
333 context: Some(context),
334 });
335 }
336
337 Ok(hybrid_results)
338 }
339
340 pub fn kg_bfs_traversal(
345 &self,
346 start_id: i64,
347 direction: Direction,
348 max_depth: u32,
349 ) -> Result<Vec<TraversalNode>> {
350 let query = TraversalQuery {
351 direction,
352 max_depth,
353 ..Default::default()
354 };
355 graph::bfs_traversal(&self.conn, start_id, query)
356 }
357
358 pub fn kg_dfs_traversal(
361 &self,
362 start_id: i64,
363 direction: Direction,
364 max_depth: u32,
365 ) -> Result<Vec<TraversalNode>> {
366 let query = TraversalQuery {
367 direction,
368 max_depth,
369 ..Default::default()
370 };
371 graph::dfs_traversal(&self.conn, start_id, query)
372 }
373
374 pub fn kg_shortest_path(
377 &self,
378 from_id: i64,
379 to_id: i64,
380 max_depth: u32,
381 ) -> Result<Option<TraversalPath>> {
382 graph::find_shortest_path(&self.conn, from_id, to_id, max_depth)
383 }
384
385 pub fn kg_graph_stats(&self) -> Result<GraphStats> {
387 graph::compute_graph_stats(&self.conn)
388 }
389
390 pub fn kg_pagerank(&self, config: Option<PageRankConfig>) -> Result<Vec<(i64, f64)>> {
395 algorithms::pagerank(&self.conn, config.unwrap_or_default())
396 }
397
398 pub fn kg_louvain(&self) -> Result<CommunityResult> {
401 algorithms::louvain_communities(&self.conn)
402 }
403
404 pub fn kg_connected_components(&self) -> Result<Vec<Vec<i64>>> {
407 algorithms::connected_components(&self.conn)
408 }
409
410 pub fn kg_analyze(&self) -> Result<algorithms::GraphAnalysis> {
412 algorithms::analyze_graph(&self.conn)
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly opened in-memory graph already has its schema.
    #[test]
    fn test_open_in_memory() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();
        let has_schema = schema_exists(graph.connection()).unwrap();
        assert!(has_schema);
    }

    /// Insert, read, list, update and delete a single entity.
    #[test]
    fn test_crud_operations() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        // Create.
        let mut paper = Entity::new("paper", "Test Paper");
        paper.set_property("author", serde_json::json!("John Doe"));
        let paper_id = graph.insert_entity(&paper).unwrap();

        // Read back by id.
        let fetched = graph.get_entity(paper_id).unwrap();
        assert_eq!(fetched.name, "Test Paper");

        // List filtered by type.
        let papers = graph.list_entities(Some("paper"), None).unwrap();
        assert_eq!(papers.len(), 1);

        // Update.
        let mut revised = fetched.clone();
        revised.set_property("year", serde_json::json!(2024));
        graph.update_entity(&revised).unwrap();

        // Delete, then confirm nothing is left.
        graph.delete_entity(paper_id).unwrap();
        let remaining = graph.list_entities(None, None).unwrap();
        assert_eq!(remaining.len(), 0);
    }

    /// Neighbor lookup on a chain `first -> second -> third`.
    #[test]
    fn test_graph_traversal() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        let first = graph.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let second = graph.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();
        let third = graph.insert_entity(&Entity::new("paper", "Paper 3")).unwrap();

        let first_cites_second = Relation::new(first, second, "cites", 0.8).unwrap();
        graph.insert_relation(&first_cites_second).unwrap();
        let second_cites_third = Relation::new(second, third, "cites", 0.9).unwrap();
        graph.insert_relation(&second_cites_third).unwrap();

        // Depth 1 reaches only the direct neighbor.
        let one_hop = graph.get_neighbors(first, 1).unwrap();
        assert_eq!(one_hop.len(), 1);

        // Depth 2 also reaches the neighbor's neighbor.
        let two_hops = graph.get_neighbors(first, 2).unwrap();
        assert_eq!(two_hops.len(), 2);
    }

    /// The vector identical to the query must rank first.
    #[test]
    fn test_vector_search() {
        let graph = KnowledgeGraph::open_in_memory().unwrap();

        let first = graph.insert_entity(&Entity::new("paper", "Paper 1")).unwrap();
        let second = graph.insert_entity(&Entity::new("paper", "Paper 2")).unwrap();

        graph.insert_vector(first, vec![1.0, 0.0, 0.0]).unwrap();
        graph.insert_vector(second, vec![0.0, 1.0, 0.0]).unwrap();

        let hits = graph.search_vectors(vec![1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, first);
    }
}