1pub mod embedder;
24mod error;
25
26pub use embedder::Embedder;
27pub use error::RagError;
28
29use crate::error::Result;
30use crate::graph::{get_neighbors, Entity};
31use crate::vector::{cosine_similarity, TurboQuantConfig, TurboQuantIndex, VectorStore};
32use rusqlite::Connection;
33use std::collections::HashMap;
34
/// One ranked hit returned by [`RagEngine::search`].
///
/// Results are ordered by descending `combined_score`.
#[derive(Debug, Clone)]
pub struct RagResult {
    /// The matched entity, as loaded by `crate::graph::get_entity`.
    pub entity: Entity,
    /// Cosine similarity between the query embedding and this entity's stored
    /// vector, widened from `f32`.
    pub vector_score: f64,
    /// Graph-connectivity score.
    ///
    /// Roughly the share of the *other* candidate-pool members that are 1-hop
    /// neighbours of this entity. It is 0.0 when the pool has a single member.
    pub graph_score: f64,
    /// `vector_weight * vector_score + graph_weight * graph_score`, using the
    /// weights from the engine's `RagConfig`.
    pub combined_score: f64,
    /// Up to `max_context_entities` graph neighbours of `entity`.
    ///
    /// Neighbours that are also in the candidate pool are listed first. The
    /// search fills this in after ranking and truncation.
    pub context_entities: Vec<Entity>,
}
53
/// Tuning parameters for [`RagEngine`].
///
/// `RagConfig::default()` provides the baseline values; override individual
/// fields with struct-update syntax.
#[derive(Debug, Clone)]
pub struct RagConfig {
    /// Weight applied to `vector_score` when computing `combined_score`.
    pub vector_weight: f64,
    /// Weight applied to `graph_score` when computing `combined_score`.
    pub graph_weight: f64,

    /// Number of approximate nearest-neighbour candidates requested in stage 1.
    ///
    /// Capped at the number of stored vectors.
    pub top_k_candidates: usize,
    /// Number of exactly-reranked candidates kept as the initial result pool.
    pub top_k_rerank: usize,

    /// Whether graph neighbours of the reranked pool are added before scoring.
    pub enable_graph_expansion: bool,
    /// Hop depth passed to `get_neighbors` during graph expansion.
    pub graph_depth: u32,

    /// Hop depth passed to `get_neighbors` when collecting context entities.
    pub context_depth: u32,
    /// Maximum number of context entities attached to each result.
    pub max_context_entities: usize,

    /// Minimum similarity to the query an entity must reach to enter the
    /// candidate pool.
    pub min_vector_score: f32,
    /// Results whose `combined_score` is below this value are dropped.
    pub min_combined_score: f64,

    /// NOTE(review): not read anywhere in this module.
    ///
    /// Stage 1 takes the index dimension from the stored vectors instead.
    /// Presumably this is the expected embedding dimension; confirm with callers.
    pub vector_dimension: usize,
}
92
93impl Default for RagConfig {
94 fn default() -> Self {
95 Self {
96 vector_weight: 0.6,
97 graph_weight: 0.4,
98 top_k_candidates: 50,
99 top_k_rerank: 20,
100 enable_graph_expansion: true,
101 graph_depth: 1,
102 context_depth: 2,
103 max_context_entities: 5,
104 min_vector_score: 0.0,
105 min_combined_score: 0.0,
106 vector_dimension: 384,
107 }
108 }
109}
110
111pub struct RagEngine {
117 config: RagConfig,
118}
119
120impl RagEngine {
121 pub fn new(config: RagConfig) -> Self {
122 Self { config }
123 }
124
125 pub fn search(
133 &self,
134 conn: &Connection,
135 embedder: &dyn Embedder,
136 query: &str,
137 k: usize,
138 ) -> Result<Vec<RagResult>> {
139 let query_vec = embedder.embed(query)?;
141
142 let ann_candidates = self.stage1_ann(conn, &query_vec)?;
144
145 if ann_candidates.is_empty() {
146 return Ok(Vec::new());
147 }
148
149 let mut reranked = self.stage2_rerank(conn, &query_vec, ann_candidates)?;
151 reranked.truncate(self.config.top_k_rerank);
152
153 let mut pool: HashMap<i64, f32> = reranked.into_iter().collect();
155 if self.config.enable_graph_expansion {
156 self.rapo_expand(conn, &query_vec, &mut pool)?;
157 }
158
159 let pool_size = pool.len();
161 let mut scored = self.score_and_filter(conn, &pool, pool_size)?;
162
163 scored.sort_by(|a, b| b.combined_score.partial_cmp(&a.combined_score).unwrap());
165 scored.truncate(k);
166
167 for result in &mut scored {
169 let entity_id = result.entity.id.unwrap_or(0);
170 result.context_entities = self.collect_context(conn, entity_id, &pool)?;
171 }
172
173 Ok(scored)
174 }
175
176 fn stage1_ann(&self, conn: &Connection, query_vec: &[f32]) -> Result<Vec<(i64, f32)>> {
188 let vector_count: i64 =
189 conn.query_row("SELECT COUNT(*) FROM kg_vectors", [], |r| r.get(0))?;
190
191 if vector_count == 0 {
192 return Ok(Vec::new());
193 }
194
195 let vectors_checksum: i64 = conn.query_row(
198 "SELECT COALESCE(SUM(entity_id), 0) FROM kg_vectors",
199 [],
200 |r| r.get(0),
201 )?;
202
203 let cached = load_turboquant_cache(conn, vector_count, vectors_checksum)?;
205 let index = match cached {
206 Some(idx) => idx,
207 None => {
208 let all_vectors = load_all_vectors(conn)?;
210 let dim = all_vectors[0].1.len();
211 let config = TurboQuantConfig {
212 dimension: dim,
213 bit_width: 3,
214 seed: 42,
215 };
216 let mut idx = TurboQuantIndex::new(config)?;
217 for (entity_id, vec) in &all_vectors {
218 idx.add_vector(*entity_id, vec)?;
219 }
220 save_turboquant_cache(conn, &idx, vector_count, vectors_checksum)?;
221 idx
222 }
223 };
224
225 let k = self.config.top_k_candidates.min(vector_count as usize);
226 index.search(query_vec, k)
227 }
228
229 fn stage2_rerank(
231 &self,
232 conn: &Connection,
233 query_vec: &[f32],
234 candidates: Vec<(i64, f32)>,
235 ) -> Result<Vec<(i64, f32)>> {
236 let store = VectorStore::new();
237 let mut scored: Vec<(i64, f32)> = Vec::with_capacity(candidates.len());
238
239 for (entity_id, approx) in candidates {
240 if approx < self.config.min_vector_score {
242 continue;
243 }
244 match store.get_vector(conn, entity_id) {
245 Ok(vec) => {
246 let exact = cosine_similarity(query_vec, &vec);
247 if exact >= self.config.min_vector_score {
248 scored.push((entity_id, exact));
249 }
250 }
251 Err(_) => {
252 }
254 }
255 }
256
257 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
258 Ok(scored)
259 }
260
261 fn rapo_expand(
264 &self,
265 conn: &Connection,
266 query_vec: &[f32],
267 pool: &mut HashMap<i64, f32>,
268 ) -> Result<()> {
269 let store = VectorStore::new();
270 let seeds: Vec<i64> = pool.keys().copied().collect();
271
272 for seed_id in seeds {
273 let neighbours = match get_neighbors(conn, seed_id, self.config.graph_depth) {
274 Ok(n) => n,
275 Err(_) => continue,
276 };
277
278 for nbr in neighbours {
279 let nbr_id = match nbr.entity.id {
280 Some(id) => id,
281 None => continue,
282 };
283
284 if pool.contains_key(&nbr_id) {
285 continue;
286 }
287
288 if let Ok(vec) = store.get_vector(conn, nbr_id) {
290 let score = cosine_similarity(query_vec, &vec);
291 if score >= self.config.min_vector_score {
292 pool.insert(nbr_id, score);
293 }
294 }
295 }
296 }
297
298 Ok(())
299 }
300
301 fn score_and_filter(
304 &self,
305 conn: &Connection,
306 pool: &HashMap<i64, f32>,
307 pool_size: usize,
308 ) -> Result<Vec<RagResult>> {
309 let mut results = Vec::new();
310
311 for (&entity_id, &v_score) in pool {
312 let vector_score = v_score as f64;
313
314 let graph_score = if pool_size > 1 {
316 let neighbours = get_neighbors(conn, entity_id, 1).unwrap_or_default();
317 let overlap = neighbours
318 .iter()
319 .filter(|n| {
320 n.entity
321 .id
322 .map(|id| pool.contains_key(&id))
323 .unwrap_or(false)
324 })
325 .count();
326 overlap as f64 / (pool_size - 1) as f64
327 } else {
328 0.0
329 };
330
331 let combined_score =
332 self.config.vector_weight * vector_score + self.config.graph_weight * graph_score;
333
334 if combined_score < self.config.min_combined_score {
336 continue;
337 }
338
339 let entity = match crate::graph::get_entity(conn, entity_id) {
340 Ok(e) => e,
341 Err(_) => continue,
342 };
343
344 results.push(RagResult {
345 entity,
346 vector_score,
347 graph_score,
348 combined_score,
349 context_entities: Vec::new(), });
351 }
352
353 Ok(results)
354 }
355
356 fn collect_context(
359 &self,
360 conn: &Connection,
361 entity_id: i64,
362 pool: &HashMap<i64, f32>,
363 ) -> Result<Vec<Entity>> {
364 let neighbours = match get_neighbors(conn, entity_id, self.config.context_depth) {
365 Ok(n) => n,
366 Err(_) => return Ok(Vec::new()),
367 };
368
369 let mut in_pool: Vec<Entity> = Vec::new();
371 let mut not_in_pool: Vec<Entity> = Vec::new();
372
373 for nbr in neighbours {
374 if let Some(id) = nbr.entity.id {
375 if pool.contains_key(&id) {
376 in_pool.push(nbr.entity);
377 } else {
378 not_in_pool.push(nbr.entity);
379 }
380 }
381 }
382
383 in_pool.extend(not_in_pool);
384 in_pool.truncate(self.config.max_context_entities);
385 Ok(in_pool)
386 }
387}
388
389fn load_turboquant_cache(
398 conn: &Connection,
399 current_count: i64,
400 current_checksum: i64,
401) -> Result<Option<TurboQuantIndex>> {
402 let mut stmt = conn.prepare(
403 "SELECT index_blob, vector_count, vectors_checksum \
404 FROM kg_turboquant_cache WHERE id = 1",
405 )?;
406
407 let result = stmt.query_row([], |row| {
408 let blob: Vec<u8> = row.get(0)?;
409 let cached_count: i64 = row.get(1)?;
410 let cached_checksum: i64 = row.get(2)?;
411 Ok((blob, cached_count, cached_checksum))
412 });
413
414 match result {
415 Ok((blob, cached_count, cached_checksum))
416 if cached_count == current_count && cached_checksum == current_checksum =>
417 {
418 let index = TurboQuantIndex::from_bytes(&blob)
419 .map_err(|e| crate::error::Error::Other(e.to_string()))?;
420 Ok(Some(index))
421 }
422 Ok(_) | Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
423 Err(e) => Err(e.into()),
424 }
425}
426
427fn save_turboquant_cache(
429 conn: &Connection,
430 index: &TurboQuantIndex,
431 vector_count: i64,
432 vectors_checksum: i64,
433) -> Result<()> {
434 let blob = index
435 .to_bytes()
436 .map_err(|e| crate::error::Error::Other(e.to_string()))?;
437 conn.execute(
438 "INSERT INTO kg_turboquant_cache \
439 (id, index_blob, vector_count, vectors_checksum) \
440 VALUES (1, ?1, ?2, ?3) \
441 ON CONFLICT(id) DO UPDATE SET \
442 index_blob = excluded.index_blob, \
443 vector_count = excluded.vector_count, \
444 vectors_checksum = excluded.vectors_checksum",
445 rusqlite::params![blob, vector_count, vectors_checksum],
446 )?;
447 Ok(())
448}
449
450fn load_all_vectors(conn: &Connection) -> Result<Vec<(i64, Vec<f32>)>> {
451 let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
452
453 let rows = stmt.query_map([], |row| {
454 let entity_id: i64 = row.get(0)?;
455 let blob: Vec<u8> = row.get(1)?;
456 let dim: i64 = row.get(2)?;
457
458 let mut vec = Vec::with_capacity(dim as usize);
459 for chunk in blob.chunks_exact(4) {
460 vec.push(f32::from_le_bytes(chunk.try_into().unwrap()));
461 }
462
463 Ok((entity_id, vec))
464 })?;
465
466 let mut out = Vec::new();
467 for row in rows {
468 out.push(row?);
469 }
470 Ok(out)
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};
    use crate::rag::embedder::FixedEmbedder;
    use crate::vector::VectorStore;
    use rusqlite::Connection;

    /// Build the three-entity fixture used by most tests.
    ///
    /// Vectors (zero-padded to `dim`): e1 = axis 0, e2 = axis 1,
    /// e3 = (0.8, 0.6). Relations: e1→e2 (0.3) and e1→e3 (0.9).
    /// Returns the connection and the ids `[e1, e2, e3]`.
    fn setup(dim: usize) -> (Connection, Vec<i64>) {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let e1 = insert_entity(&conn, &Entity::new("doc", "Doc A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "Doc B")).unwrap();
        let e3 = insert_entity(&conn, &Entity::new("doc", "Doc C")).unwrap();

        let store = VectorStore::new();
        store.insert_vector(&conn, e1, axis_vector(dim, 0)).unwrap();
        store.insert_vector(&conn, e2, axis_vector(dim, 1)).unwrap();

        let mut v3 = vec![0.0f32; dim];
        v3[0] = 0.8;
        v3[1] = 0.6;
        store.insert_vector(&conn, e3, v3).unwrap();

        insert_relation(&conn, &Relation::new(e1, e2, "related", 0.3).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(e1, e3, "related", 0.9).unwrap()).unwrap();

        (conn, vec![e1, e2, e3])
    }

    /// A `dim`-length vector with 1.0 at `axis` and 0.0 elsewhere.
    fn axis_vector(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[axis] = 1.0;
        v
    }

    /// Engine shared by most tests: 10 ANN candidates, 5 kept after rerank.
    fn standard_engine(dim: usize) -> RagEngine {
        RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        })
    }

    /// Read `(vector_count, vectors_checksum)` from the single cache row.
    fn cache_row(conn: &Connection) -> (i64, i64) {
        conn.query_row(
            "SELECT vector_count, vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
            [],
            |r| Ok((r.get(0)?, r.get(1)?)),
        )
        .unwrap()
    }

    /// Entity ids of `results`, in ranked order.
    fn entity_ids(results: &[RagResult]) -> Vec<Option<i64>> {
        results.iter().map(|r| r.entity.id).collect()
    }

    #[test]
    fn test_basic_search() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let embedder = FixedEmbedder(axis_vector(dim, 0));
        let engine = standard_engine(dim);

        let results = engine.search(&conn, &embedder, "test query", 2).unwrap();
        assert!(!results.is_empty(), "should return at least one result");

        // e1's vector equals the query, so it ranks first with similarity 1.
        assert_eq!(results[0].entity.id, Some(ids[0]));
        assert!((results[0].vector_score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_db() {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let embedder = FixedEmbedder(vec![1.0, 0.0, 0.0]);
        let engine = RagEngine::new(RagConfig::default());

        let results = engine.search(&conn, &embedder, "anything", 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_graph_expansion() {
        let dim = 4;
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        let store = VectorStore::new();

        let e1 = insert_entity(&conn, &Entity::new("doc", "A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "B")).unwrap();
        store.insert_vector(&conn, e1, axis_vector(dim, 0)).unwrap();
        store.insert_vector(&conn, e2, axis_vector(dim, 1)).unwrap();
        insert_relation(&conn, &Relation::new(e1, e2, "link", 1.0).unwrap()).unwrap();

        let embedder = FixedEmbedder(axis_vector(dim, 0));
        // Only one candidate survives stages 1-2, so the other entity can only
        // enter the results through graph expansion.
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 1,
            top_k_rerank: 1,
            enable_graph_expansion: true,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "q", 5).unwrap();
        let ids: Vec<i64> = results.iter().filter_map(|r| r.entity.id).collect();
        assert!(ids.contains(&e1));
        assert!(ids.contains(&e2), "RAPO should expand to e2");
    }

    #[test]
    fn test_context_attached() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let embedder = FixedEmbedder(axis_vector(dim, 0));
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            context_depth: 1,
            max_context_entities: 3,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "q", 3).unwrap();

        let e1_result = results
            .iter()
            .find(|r| r.entity.id == Some(ids[0]))
            .expect("e1 should be among the results");
        assert!(
            !e1_result.context_entities.is_empty(),
            "e1 should have context neighbours"
        );
    }

    #[test]
    fn test_cache_written_on_first_query() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let embedder = FixedEmbedder(axis_vector(dim, 0));
        let engine = standard_engine(dim);

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1, "cache row should be created after first query");
    }

    #[test]
    fn test_cache_hit_on_second_query() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let embedder = FixedEmbedder(axis_vector(dim, 0));
        let engine = standard_engine(dim);

        let r1 = engine.search(&conn, &embedder, "q", 2).unwrap();
        let r2 = engine.search(&conn, &embedder, "q", 2).unwrap();

        assert!(!r1.is_empty());
        // Compare the whole ranking, not just the top hit.
        assert_eq!(
            entity_ids(&r1),
            entity_ids(&r2),
            "cache hit should return identical results"
        );
    }

    #[test]
    fn test_cache_stores_checksum() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let embedder = FixedEmbedder(axis_vector(dim, 0));
        let engine = standard_engine(dim);

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let (count, checksum) = cache_row(&conn);
        assert_eq!(count, 3);
        assert!(checksum > 0, "checksum should reflect entity_id sum");
    }

    #[test]
    fn test_cache_invalidated_on_same_count_different_entity() {
        let dim = 4;
        let (conn, ids) = setup(dim);
        let embedder = FixedEmbedder(axis_vector(dim, 0));
        let engine = standard_engine(dim);

        engine.search(&conn, &embedder, "q", 2).unwrap();
        let (_, checksum_before) = cache_row(&conn);

        // Replace e3's vector with a new entity's: the count stays at 3, but the
        // set of entity ids (and therefore their sum) changes.
        conn.execute("DELETE FROM kg_vectors WHERE entity_id = ?1", [ids[2]])
            .unwrap();
        let e_new = insert_entity(&conn, &Entity::new("doc", "Doc Swap")).unwrap();
        VectorStore::new()
            .insert_vector(&conn, e_new, axis_vector(dim, 3))
            .unwrap();
        engine.search(&conn, &embedder, "q", 2).unwrap();

        let (count_after, checksum_after) = cache_row(&conn);
        assert_eq!(count_after, 3, "vector count should still be 3 after swap");
        assert_ne!(
            checksum_after, checksum_before,
            "checksum must change after swapping one vector"
        );
    }

    #[test]
    fn test_cache_invalidated_after_new_vector() {
        let dim = 4;
        let (conn, _ids) = setup(dim);
        let embedder = FixedEmbedder(axis_vector(dim, 0));
        let engine = standard_engine(dim);

        engine.search(&conn, &embedder, "q", 2).unwrap();
        let (cached_count_before, _) = cache_row(&conn);
        assert_eq!(cached_count_before, 3);

        let e4 = insert_entity(&conn, &Entity::new("doc", "Doc D")).unwrap();
        VectorStore::new()
            .insert_vector(&conn, e4, axis_vector(dim, 2))
            .unwrap();

        engine.search(&conn, &embedder, "q", 2).unwrap();

        let (cached_count_after, _) = cache_row(&conn);
        assert_eq!(
            cached_count_after, 4,
            "cache should be rebuilt after new vector added"
        );
    }
}