1pub mod embedder;
24mod error;
25pub mod smart_retrieval;
26
27pub use embedder::Embedder;
28pub use error::RagError;
29pub use smart_retrieval::{RetrievalWeights, SmartRetrieval, SmartSearchResult};
30
31use crate::error::Result;
32use crate::graph::{get_neighbors, Entity};
33use crate::vector::{cosine_similarity, TurboQuantConfig, TurboQuantIndex, VectorStore};
34use rusqlite::Connection;
35use std::collections::HashMap;
36
37#[derive(Debug, Clone)]
43pub struct RagResult {
44 pub entity: Entity,
46 pub vector_score: f64,
48 pub graph_score: f64,
50 pub combined_score: f64,
52 pub context_entities: Vec<Entity>,
54}
55
/// Tunable parameters for [`RagEngine`].
///
/// `PartialEq` is derived so configurations can be compared directly, which
/// is convenient in tests and when detecting configuration changes.
#[derive(Debug, Clone, PartialEq)]
pub struct RagConfig {
    /// Multiplier applied to the vector (cosine) score in the combined score.
    pub vector_weight: f64,
    /// Multiplier applied to the graph connectivity score in the combined score.
    pub graph_weight: f64,

    /// Maximum number of candidates requested from the approximate index
    /// (further capped by the number of stored vectors).
    pub top_k_candidates: usize,
    /// Number of candidates kept after exact re-ranking.
    pub top_k_rerank: usize,

    /// When `true`, graph neighbours of the re-ranked candidates are added to
    /// the candidate pool before scoring.
    pub enable_graph_expansion: bool,
    /// Traversal depth used when expanding the pool through the graph.
    pub graph_depth: u32,

    /// Traversal depth used when collecting context entities for a result.
    pub context_depth: u32,
    /// Upper bound on the number of context entities attached to one result.
    pub max_context_entities: usize,

    /// Minimum similarity a candidate must reach. It is applied to the
    /// approximate score, to the exact score, and to neighbours added by
    /// graph expansion.
    pub min_vector_score: f32,
    /// Results whose combined score is below this value are dropped.
    pub min_combined_score: f64,

    /// Expected embedding dimension.
    // NOTE(review): nothing in the engine code of this file reads this field;
    // the index dimension is derived from the stored vectors. Confirm whether
    // other modules rely on it.
    pub vector_dimension: usize,
}

impl Default for RagConfig {
    /// Returns the stock configuration: 0.6/0.4 vector/graph weighting,
    /// 50 approximate candidates re-ranked down to 20, one-hop graph
    /// expansion, two-hop context limited to 5 entities, no score
    /// thresholds, and 384-dimensional vectors.
    fn default() -> Self {
        Self {
            vector_weight: 0.6,
            graph_weight: 0.4,
            top_k_candidates: 50,
            top_k_rerank: 20,
            enable_graph_expansion: true,
            graph_depth: 1,
            context_depth: 2,
            max_context_entities: 5,
            min_vector_score: 0.0,
            min_combined_score: 0.0,
            vector_dimension: 384,
        }
    }
}
112
113pub struct RagEngine {
119 config: RagConfig,
120}
121
122impl RagEngine {
123 pub fn new(config: RagConfig) -> Self {
124 Self { config }
125 }
126
127 pub fn search(
135 &self,
136 conn: &Connection,
137 embedder: &dyn Embedder,
138 query: &str,
139 k: usize,
140 ) -> Result<Vec<RagResult>> {
141 let query_vec = embedder.embed(query)?;
143
144 let ann_candidates = self.stage1_ann(conn, &query_vec)?;
146
147 if ann_candidates.is_empty() {
148 return Ok(Vec::new());
149 }
150
151 let mut reranked = self.stage2_rerank(conn, &query_vec, ann_candidates)?;
153 reranked.truncate(self.config.top_k_rerank);
154
155 let mut pool: HashMap<i64, f32> = reranked.into_iter().collect();
157 if self.config.enable_graph_expansion {
158 self.rapo_expand(conn, &query_vec, &mut pool)?;
159 }
160
161 let pool_size = pool.len();
163 let mut scored = self.score_and_filter(conn, &pool, pool_size)?;
164
165 scored.sort_by(|a, b| b.combined_score.partial_cmp(&a.combined_score).unwrap());
167 scored.truncate(k);
168
169 for result in &mut scored {
171 let entity_id = result.entity.id.unwrap_or(0);
172 result.context_entities = self.collect_context(conn, entity_id, &pool)?;
173 }
174
175 Ok(scored)
176 }
177
178 fn stage1_ann(&self, conn: &Connection, query_vec: &[f32]) -> Result<Vec<(i64, f32)>> {
190 let vector_count: i64 =
191 conn.query_row("SELECT COUNT(*) FROM kg_vectors", [], |r| r.get(0))?;
192
193 if vector_count == 0 {
194 return Ok(Vec::new());
195 }
196
197 let vectors_checksum: i64 = conn.query_row(
200 "SELECT COALESCE(SUM(entity_id), 0) FROM kg_vectors",
201 [],
202 |r| r.get(0),
203 )?;
204
205 let cached = load_turboquant_cache(conn, vector_count, vectors_checksum)?;
207 let index = match cached {
208 Some(idx) => idx,
209 None => {
210 let all_vectors = load_all_vectors(conn)?;
212 let dim = all_vectors[0].1.len();
213 let config = TurboQuantConfig {
214 dimension: dim,
215 bit_width: 3,
216 seed: 42,
217 };
218 let mut idx = TurboQuantIndex::new(config)?;
219 for (entity_id, vec) in &all_vectors {
220 idx.add_vector(*entity_id, vec)?;
221 }
222 save_turboquant_cache(conn, &idx, vector_count, vectors_checksum)?;
223 idx
224 }
225 };
226
227 let k = self.config.top_k_candidates.min(vector_count as usize);
228 index.search(query_vec, k)
229 }
230
231 fn stage2_rerank(
233 &self,
234 conn: &Connection,
235 query_vec: &[f32],
236 candidates: Vec<(i64, f32)>,
237 ) -> Result<Vec<(i64, f32)>> {
238 let store = VectorStore::new();
239 let mut scored: Vec<(i64, f32)> = Vec::with_capacity(candidates.len());
240
241 for (entity_id, approx) in candidates {
242 if approx < self.config.min_vector_score {
244 continue;
245 }
246 match store.get_vector(conn, entity_id) {
247 Ok(vec) => {
248 let exact = cosine_similarity(query_vec, &vec);
249 if exact >= self.config.min_vector_score {
250 scored.push((entity_id, exact));
251 }
252 }
253 Err(_) => {
254 }
256 }
257 }
258
259 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
260 Ok(scored)
261 }
262
263 fn rapo_expand(
266 &self,
267 conn: &Connection,
268 query_vec: &[f32],
269 pool: &mut HashMap<i64, f32>,
270 ) -> Result<()> {
271 let store = VectorStore::new();
272 let seeds: Vec<i64> = pool.keys().copied().collect();
273
274 for seed_id in seeds {
275 let neighbours = match get_neighbors(conn, seed_id, self.config.graph_depth) {
276 Ok(n) => n,
277 Err(_) => continue,
278 };
279
280 for nbr in neighbours {
281 let nbr_id = match nbr.entity.id {
282 Some(id) => id,
283 None => continue,
284 };
285
286 if pool.contains_key(&nbr_id) {
287 continue;
288 }
289
290 if let Ok(vec) = store.get_vector(conn, nbr_id) {
292 let score = cosine_similarity(query_vec, &vec);
293 if score >= self.config.min_vector_score {
294 pool.insert(nbr_id, score);
295 }
296 }
297 }
298 }
299
300 Ok(())
301 }
302
303 fn score_and_filter(
306 &self,
307 conn: &Connection,
308 pool: &HashMap<i64, f32>,
309 pool_size: usize,
310 ) -> Result<Vec<RagResult>> {
311 let mut results = Vec::new();
312
313 for (&entity_id, &v_score) in pool {
314 let vector_score = v_score as f64;
315
316 let graph_score = if pool_size > 1 {
318 let neighbours = get_neighbors(conn, entity_id, 1).unwrap_or_default();
319 let overlap = neighbours
320 .iter()
321 .filter(|n| {
322 n.entity
323 .id
324 .map(|id| pool.contains_key(&id))
325 .unwrap_or(false)
326 })
327 .count();
328 overlap as f64 / (pool_size - 1) as f64
329 } else {
330 0.0
331 };
332
333 let combined_score =
334 self.config.vector_weight * vector_score + self.config.graph_weight * graph_score;
335
336 if combined_score < self.config.min_combined_score {
338 continue;
339 }
340
341 let entity = match crate::graph::get_entity(conn, entity_id) {
342 Ok(e) => e,
343 Err(_) => continue,
344 };
345
346 results.push(RagResult {
347 entity,
348 vector_score,
349 graph_score,
350 combined_score,
351 context_entities: Vec::new(), });
353 }
354
355 Ok(results)
356 }
357
358 fn collect_context(
361 &self,
362 conn: &Connection,
363 entity_id: i64,
364 pool: &HashMap<i64, f32>,
365 ) -> Result<Vec<Entity>> {
366 let neighbours = match get_neighbors(conn, entity_id, self.config.context_depth) {
367 Ok(n) => n,
368 Err(_) => return Ok(Vec::new()),
369 };
370
371 let mut in_pool: Vec<Entity> = Vec::new();
373 let mut not_in_pool: Vec<Entity> = Vec::new();
374
375 for nbr in neighbours {
376 if let Some(id) = nbr.entity.id {
377 if pool.contains_key(&id) {
378 in_pool.push(nbr.entity);
379 } else {
380 not_in_pool.push(nbr.entity);
381 }
382 }
383 }
384
385 in_pool.extend(not_in_pool);
386 in_pool.truncate(self.config.max_context_entities);
387 Ok(in_pool)
388 }
389}
390
391fn load_turboquant_cache(
400 conn: &Connection,
401 current_count: i64,
402 current_checksum: i64,
403) -> Result<Option<TurboQuantIndex>> {
404 let mut stmt = conn.prepare(
405 "SELECT index_blob, vector_count, vectors_checksum \
406 FROM kg_turboquant_cache WHERE id = 1",
407 )?;
408
409 let result = stmt.query_row([], |row| {
410 let blob: Vec<u8> = row.get(0)?;
411 let cached_count: i64 = row.get(1)?;
412 let cached_checksum: i64 = row.get(2)?;
413 Ok((blob, cached_count, cached_checksum))
414 });
415
416 match result {
417 Ok((blob, cached_count, cached_checksum))
418 if cached_count == current_count && cached_checksum == current_checksum =>
419 {
420 let index = TurboQuantIndex::from_bytes(&blob)
421 .map_err(|e| crate::error::Error::Other(e.to_string()))?;
422 Ok(Some(index))
423 }
424 Ok(_) | Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
425 Err(e) => Err(e.into()),
426 }
427}
428
429fn save_turboquant_cache(
431 conn: &Connection,
432 index: &TurboQuantIndex,
433 vector_count: i64,
434 vectors_checksum: i64,
435) -> Result<()> {
436 let blob = index
437 .to_bytes()
438 .map_err(|e| crate::error::Error::Other(e.to_string()))?;
439 conn.execute(
440 "INSERT INTO kg_turboquant_cache \
441 (id, index_blob, vector_count, vectors_checksum) \
442 VALUES (1, ?1, ?2, ?3) \
443 ON CONFLICT(id) DO UPDATE SET \
444 index_blob = excluded.index_blob, \
445 vector_count = excluded.vector_count, \
446 vectors_checksum = excluded.vectors_checksum",
447 rusqlite::params![blob, vector_count, vectors_checksum],
448 )?;
449 Ok(())
450}
451
452fn load_all_vectors(conn: &Connection) -> Result<Vec<(i64, Vec<f32>)>> {
453 let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
454
455 let rows = stmt.query_map([], |row| {
456 let entity_id: i64 = row.get(0)?;
457 let blob: Vec<u8> = row.get(1)?;
458 let dim: i64 = row.get(2)?;
459
460 let mut vec = Vec::with_capacity(dim as usize);
461 for chunk in blob.chunks_exact(4) {
462 vec.push(f32::from_le_bytes(chunk.try_into().unwrap()));
463 }
464
465 Ok((entity_id, vec))
466 })?;
467
468 let mut out = Vec::new();
469 for row in rows {
470 out.push(row?);
471 }
472 Ok(out)
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};
    use crate::rag::embedder::FixedEmbedder;
    use crate::vector::VectorStore;
    use rusqlite::Connection;

    /// Builds a `dim`-length vector that is zero everywhere except at the
    /// listed `(index, value)` components.
    fn sparse(dim: usize, components: &[(usize, f32)]) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        for &(i, x) in components {
            v[i] = x;
        }
        v
    }

    /// Opens an in-memory database with the project schema installed.
    fn fresh_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Stores `vector` for entity `id`.
    fn put_vector(conn: &Connection, id: i64, vector: Vec<f32>) {
        VectorStore::new().insert_vector(conn, id, vector).unwrap();
    }

    /// Embedder that always returns the unit vector along axis 0.
    fn axis0_embedder(dim: usize) -> FixedEmbedder {
        FixedEmbedder(sparse(dim, &[(0, 1.0)]))
    }

    /// Engine configuration shared by most tests.
    fn small_engine(dim: usize) -> RagEngine {
        RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        })
    }

    /// Runs a single-column, single-row query and returns the value.
    fn scalar(conn: &Connection, sql: &str) -> i64 {
        conn.query_row(sql, [], |r| r.get(0)).unwrap()
    }

    /// Reads `(vector_count, vectors_checksum)` from the cache row.
    fn cached_count_and_checksum(conn: &Connection) -> (i64, i64) {
        conn.query_row(
            "SELECT vector_count, vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
            [],
            |r| Ok((r.get(0)?, r.get(1)?)),
        )
        .unwrap()
    }

    /// Three documents: A along axis 0, B along axis 1, C between them,
    /// with relations A-B and A-C.
    fn setup(dim: usize) -> (Connection, Vec<i64>) {
        let conn = fresh_db();

        let e1 = insert_entity(&conn, &Entity::new("doc", "Doc A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "Doc B")).unwrap();
        let e3 = insert_entity(&conn, &Entity::new("doc", "Doc C")).unwrap();

        put_vector(&conn, e1, sparse(dim, &[(0, 1.0)]));
        put_vector(&conn, e2, sparse(dim, &[(1, 1.0)]));
        put_vector(&conn, e3, sparse(dim, &[(0, 0.8), (1, 0.6)]));

        insert_relation(&conn, &Relation::new(e1, e2, "related", 0.3).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(e1, e3, "related", 0.9).unwrap()).unwrap();

        (conn, vec![e1, e2, e3])
    }

    #[test]
    fn test_basic_search() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let embedder = axis0_embedder(dim);
        let engine = small_engine(dim);

        let results = engine.search(&conn, &embedder, "test query", 2).unwrap();
        assert!(!results.is_empty(), "should return at least one result");

        let top = &results[0];
        assert_eq!(top.entity.id, Some(ids[0]));
        assert!((top.vector_score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_db() {
        let conn = fresh_db();
        let embedder = FixedEmbedder(vec![1.0, 0.0, 0.0]);
        let engine = RagEngine::new(RagConfig::default());

        let results = engine.search(&conn, &embedder, "anything", 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_graph_expansion() {
        let dim = 4;
        let conn = fresh_db();

        let e1 = insert_entity(&conn, &Entity::new("doc", "A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "B")).unwrap();
        put_vector(&conn, e1, sparse(dim, &[(0, 1.0)]));
        put_vector(&conn, e2, sparse(dim, &[(1, 1.0)]));
        insert_relation(&conn, &Relation::new(e1, e2, "link", 1.0).unwrap()).unwrap();

        let embedder = axis0_embedder(dim);
        // Only one ANN candidate survives, so e2 can only arrive via the graph.
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 1,
            top_k_rerank: 1,
            enable_graph_expansion: true,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "q", 5).unwrap();
        let ids: Vec<i64> = results.iter().filter_map(|r| r.entity.id).collect();
        assert!(ids.contains(&e1));
        assert!(ids.contains(&e2), "RAPO should expand to e2");
    }

    #[test]
    fn test_context_attached() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let embedder = axis0_embedder(dim);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            context_depth: 1,
            max_context_entities: 3,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "q", 3).unwrap();

        let e1_result = results.iter().find(|r| r.entity.id == Some(ids[0]));
        assert!(e1_result.is_some());
        let ctx = &e1_result.unwrap().context_entities;
        assert!(!ctx.is_empty(), "e1 should have context neighbours");
    }

    #[test]
    fn test_cache_written_on_first_query() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let embedder = axis0_embedder(dim);
        let engine = small_engine(dim);

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let count = scalar(
            &conn,
            "SELECT COUNT(*) FROM kg_turboquant_cache WHERE id = 1",
        );
        assert_eq!(count, 1, "cache row should be created after first query");
    }

    #[test]
    fn test_cache_hit_on_second_query() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let embedder = axis0_embedder(dim);
        let engine = small_engine(dim);

        let r1 = engine.search(&conn, &embedder, "q", 2).unwrap();
        let r2 = engine.search(&conn, &embedder, "q", 2).unwrap();

        assert_eq!(
            r1[0].entity.id, r2[0].entity.id,
            "cache hit should return identical results"
        );
    }

    #[test]
    fn test_cache_stores_checksum() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let embedder = axis0_embedder(dim);
        let engine = small_engine(dim);

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let (count, checksum) = cached_count_and_checksum(&conn);
        assert_eq!(count, 3);
        assert!(checksum > 0, "checksum should reflect entity_id sum");
    }

    #[test]
    fn test_cache_invalidated_on_same_count_different_entity() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let embedder = axis0_embedder(dim);
        let engine = small_engine(dim);

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let checksum_before = scalar(
            &conn,
            "SELECT vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
        );

        // Swap one vector for another: the count stays at 3, the id set changes.
        conn.execute("DELETE FROM kg_vectors WHERE entity_id = ?1", [ids[2]])
            .unwrap();
        let e_new = insert_entity(&conn, &Entity::new("doc", "Doc Swap")).unwrap();
        put_vector(&conn, e_new, sparse(dim, &[(3, 1.0)]));

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let (count_after, checksum_after) = cached_count_and_checksum(&conn);
        assert_eq!(count_after, 3, "vector count should still be 3 after swap");
        assert_ne!(
            checksum_after, checksum_before,
            "checksum must change after swapping one vector"
        );
    }

    #[test]
    fn test_cache_invalidated_after_new_vector() {
        const COUNT_SQL: &str = "SELECT vector_count FROM kg_turboquant_cache WHERE id = 1";

        let dim = 4;
        let (conn, _ids) = setup(dim);
        let embedder = axis0_embedder(dim);
        let engine = small_engine(dim);

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let cached_count_before = scalar(&conn, COUNT_SQL);
        assert_eq!(cached_count_before, 3);

        let e4 = insert_entity(&conn, &Entity::new("doc", "Doc D")).unwrap();
        put_vector(&conn, e4, sparse(dim, &[(2, 1.0)]));

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let cached_count_after = scalar(&conn, COUNT_SQL);
        assert_eq!(
            cached_count_after, 4,
            "cache should be rebuilt after new vector added"
        );
    }
}