1pub mod embedder;
24mod error;
25
26pub use embedder::Embedder;
27pub use error::RagError;
28
29use crate::error::Result;
30use crate::graph::{get_neighbors, Entity};
31use crate::vector::{cosine_similarity, TurboQuantConfig, TurboQuantIndex, VectorStore};
32use rusqlite::Connection;
33use std::collections::HashMap;
34
/// One ranked hit returned by `RagEngine::search`.
#[derive(Debug, Clone)]
pub struct RagResult {
    /// The matched entity, as loaded from the graph store.
    pub entity: Entity,
    /// Cosine similarity between the query embedding and this entity's stored
    /// vector (computed as `f32`, widened to `f64`).
    pub vector_score: f64,
    /// Graph connectivity score: how many of the entity's depth-1 neighbours
    /// are also in the candidate pool, relative to the pool size minus one.
    /// `0.0` when the pool holds a single entity.
    pub graph_score: f64,
    /// `vector_weight * vector_score + graph_weight * graph_score`; the value
    /// results are ranked by.
    pub combined_score: f64,
    /// Neighbouring entities attached as context, pool members first, capped
    /// at `RagConfig::max_context_entities`.
    pub context_entities: Vec<Entity>,
}
53
/// Tuning parameters for `RagEngine`.
///
/// Every field is a plain value, so the type is comparable (`PartialEq`) in
/// addition to being `Debug` and `Clone`.
#[derive(Debug, Clone, PartialEq)]
pub struct RagConfig {
    /// Multiplier applied to the vector (cosine) score in the combined score.
    pub vector_weight: f64,
    /// Multiplier applied to the graph score in the combined score.
    pub graph_weight: f64,

    /// Upper bound on the number of candidates requested from the ANN index
    /// (further capped by the number of stored vectors).
    pub top_k_candidates: usize,
    /// Number of candidates kept after exact re-ranking.
    pub top_k_rerank: usize,

    /// When `true`, graph neighbours of the re-ranked candidates are added to
    /// the candidate pool.
    pub enable_graph_expansion: bool,
    /// Traversal depth passed to the neighbour lookup during graph expansion.
    pub graph_depth: u32,

    /// Traversal depth passed to the neighbour lookup when collecting context.
    pub context_depth: u32,
    /// Maximum number of context entities attached to each result.
    pub max_context_entities: usize,

    /// Minimum similarity a candidate must reach to be kept; applied both to
    /// the index score and to the exact cosine similarity.
    pub min_vector_score: f32,
    /// Minimum combined score a result must reach to be returned.
    pub min_combined_score: f64,

    /// Expected embedding dimension.
    ///
    /// NOTE(review): the engine code in this file derives the index dimension
    /// from the stored vectors and never reads this field — confirm whether
    /// other modules rely on it.
    pub vector_dimension: usize,
}
92
93impl Default for RagConfig {
94 fn default() -> Self {
95 Self {
96 vector_weight: 0.6,
97 graph_weight: 0.4,
98 top_k_candidates: 50,
99 top_k_rerank: 20,
100 enable_graph_expansion: true,
101 graph_depth: 1,
102 context_depth: 2,
103 max_context_entities: 5,
104 min_vector_score: 0.0,
105 min_combined_score: 0.0,
106 vector_dimension: 384,
107 }
108 }
109}
110
111pub struct RagEngine {
117 config: RagConfig,
118}
119
120impl RagEngine {
121 pub fn new(config: RagConfig) -> Self {
122 Self { config }
123 }
124
125 pub fn search(
133 &self,
134 conn: &Connection,
135 embedder: &dyn Embedder,
136 query: &str,
137 k: usize,
138 ) -> Result<Vec<RagResult>> {
139 let query_vec = embedder.embed(query)?;
141
142 let ann_candidates = self.stage1_ann(conn, &query_vec)?;
144
145 if ann_candidates.is_empty() {
146 return Ok(Vec::new());
147 }
148
149 let mut reranked = self.stage2_rerank(conn, &query_vec, ann_candidates)?;
151 reranked.truncate(self.config.top_k_rerank);
152
153 let mut pool: HashMap<i64, f32> = reranked.into_iter().collect();
155 if self.config.enable_graph_expansion {
156 self.rapo_expand(conn, &query_vec, &mut pool)?;
157 }
158
159 let pool_size = pool.len();
161 let mut scored = self.score_and_filter(conn, &pool, pool_size)?;
162
163 scored.sort_by(|a, b| b.combined_score.partial_cmp(&a.combined_score).unwrap());
165 scored.truncate(k);
166
167 for result in &mut scored {
169 let entity_id = result.entity.id.unwrap_or(0);
170 result.context_entities = self.collect_context(conn, entity_id, &pool)?;
171 }
172
173 Ok(scored)
174 }
175
176 fn stage1_ann(&self, conn: &Connection, query_vec: &[f32]) -> Result<Vec<(i64, f32)>> {
185 let vector_count: i64 =
186 conn.query_row("SELECT COUNT(*) FROM kg_vectors", [], |r| r.get(0))?;
187
188 if vector_count == 0 {
189 return Ok(Vec::new());
190 }
191
192 let cached = load_turboquant_cache(conn, vector_count)?;
194 let index = match cached {
195 Some(idx) => idx,
196 None => {
197 let all_vectors = load_all_vectors(conn)?;
199 let dim = all_vectors[0].1.len();
200 let config = TurboQuantConfig {
201 dimension: dim,
202 bit_width: 3,
203 seed: 42,
204 };
205 let mut idx = TurboQuantIndex::new(config)?;
206 for (entity_id, vec) in &all_vectors {
207 idx.add_vector(*entity_id, vec)?;
208 }
209 save_turboquant_cache(conn, &idx, vector_count)?;
210 idx
211 }
212 };
213
214 let k = self.config.top_k_candidates.min(vector_count as usize);
215 index.search(query_vec, k)
216 }
217
218 fn stage2_rerank(
220 &self,
221 conn: &Connection,
222 query_vec: &[f32],
223 candidates: Vec<(i64, f32)>,
224 ) -> Result<Vec<(i64, f32)>> {
225 let store = VectorStore::new();
226 let mut scored: Vec<(i64, f32)> = Vec::with_capacity(candidates.len());
227
228 for (entity_id, approx) in candidates {
229 if approx < self.config.min_vector_score {
231 continue;
232 }
233 match store.get_vector(conn, entity_id) {
234 Ok(vec) => {
235 let exact = cosine_similarity(query_vec, &vec);
236 if exact >= self.config.min_vector_score {
237 scored.push((entity_id, exact));
238 }
239 }
240 Err(_) => {
241 }
243 }
244 }
245
246 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
247 Ok(scored)
248 }
249
250 fn rapo_expand(
253 &self,
254 conn: &Connection,
255 query_vec: &[f32],
256 pool: &mut HashMap<i64, f32>,
257 ) -> Result<()> {
258 let store = VectorStore::new();
259 let seeds: Vec<i64> = pool.keys().copied().collect();
260
261 for seed_id in seeds {
262 let neighbours = match get_neighbors(conn, seed_id, self.config.graph_depth) {
263 Ok(n) => n,
264 Err(_) => continue,
265 };
266
267 for nbr in neighbours {
268 let nbr_id = match nbr.entity.id {
269 Some(id) => id,
270 None => continue,
271 };
272
273 if pool.contains_key(&nbr_id) {
274 continue;
275 }
276
277 if let Ok(vec) = store.get_vector(conn, nbr_id) {
279 let score = cosine_similarity(query_vec, &vec);
280 if score >= self.config.min_vector_score {
281 pool.insert(nbr_id, score);
282 }
283 }
284 }
285 }
286
287 Ok(())
288 }
289
290 fn score_and_filter(
293 &self,
294 conn: &Connection,
295 pool: &HashMap<i64, f32>,
296 pool_size: usize,
297 ) -> Result<Vec<RagResult>> {
298 let mut results = Vec::new();
299
300 for (&entity_id, &v_score) in pool {
301 let vector_score = v_score as f64;
302
303 let graph_score = if pool_size > 1 {
305 let neighbours = get_neighbors(conn, entity_id, 1).unwrap_or_default();
306 let overlap = neighbours
307 .iter()
308 .filter(|n| {
309 n.entity
310 .id
311 .map(|id| pool.contains_key(&id))
312 .unwrap_or(false)
313 })
314 .count();
315 overlap as f64 / (pool_size - 1) as f64
316 } else {
317 0.0
318 };
319
320 let combined_score =
321 self.config.vector_weight * vector_score + self.config.graph_weight * graph_score;
322
323 if combined_score < self.config.min_combined_score {
325 continue;
326 }
327
328 let entity = match crate::graph::get_entity(conn, entity_id) {
329 Ok(e) => e,
330 Err(_) => continue,
331 };
332
333 results.push(RagResult {
334 entity,
335 vector_score,
336 graph_score,
337 combined_score,
338 context_entities: Vec::new(), });
340 }
341
342 Ok(results)
343 }
344
345 fn collect_context(
348 &self,
349 conn: &Connection,
350 entity_id: i64,
351 pool: &HashMap<i64, f32>,
352 ) -> Result<Vec<Entity>> {
353 let neighbours = match get_neighbors(conn, entity_id, self.config.context_depth) {
354 Ok(n) => n,
355 Err(_) => return Ok(Vec::new()),
356 };
357
358 let mut in_pool: Vec<Entity> = Vec::new();
360 let mut not_in_pool: Vec<Entity> = Vec::new();
361
362 for nbr in neighbours {
363 if let Some(id) = nbr.entity.id {
364 if pool.contains_key(&id) {
365 in_pool.push(nbr.entity);
366 } else {
367 not_in_pool.push(nbr.entity);
368 }
369 }
370 }
371
372 in_pool.extend(not_in_pool);
373 in_pool.truncate(self.config.max_context_entities);
374 Ok(in_pool)
375 }
376}
377
378fn load_turboquant_cache(
387 conn: &Connection,
388 current_count: i64,
389) -> Result<Option<TurboQuantIndex>> {
390 let mut stmt = conn.prepare(
391 "SELECT index_blob, vector_count FROM kg_turboquant_cache WHERE id = 1",
392 )?;
393
394 let result = stmt.query_row([], |row| {
395 let blob: Vec<u8> = row.get(0)?;
396 let cached_count: i64 = row.get(1)?;
397 Ok((blob, cached_count))
398 });
399
400 match result {
401 Ok((blob, cached_count)) if cached_count == current_count => {
402 let index = TurboQuantIndex::from_bytes(&blob)
403 .map_err(|e| crate::error::Error::Other(e.to_string()))?;
404 Ok(Some(index))
405 }
406 Ok(_) | Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
407 Err(e) => Err(e.into()),
408 }
409}
410
411fn save_turboquant_cache(
413 conn: &Connection,
414 index: &TurboQuantIndex,
415 vector_count: i64,
416) -> Result<()> {
417 let blob = index
418 .to_bytes()
419 .map_err(|e| crate::error::Error::Other(e.to_string()))?;
420 conn.execute(
421 "INSERT INTO kg_turboquant_cache (id, index_blob, vector_count) \
422 VALUES (1, ?1, ?2) \
423 ON CONFLICT(id) DO UPDATE SET index_blob = excluded.index_blob, \
424 vector_count = excluded.vector_count",
425 rusqlite::params![blob, vector_count],
426 )?;
427 Ok(())
428}
429
430fn load_all_vectors(conn: &Connection) -> Result<Vec<(i64, Vec<f32>)>> {
431 let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
432
433 let rows = stmt.query_map([], |row| {
434 let entity_id: i64 = row.get(0)?;
435 let blob: Vec<u8> = row.get(1)?;
436 let dim: i64 = row.get(2)?;
437
438 let mut vec = Vec::with_capacity(dim as usize);
439 for chunk in blob.chunks_exact(4) {
440 vec.push(f32::from_le_bytes(chunk.try_into().unwrap()));
441 }
442
443 Ok((entity_id, vec))
444 })?;
445
446 let mut out = Vec::new();
447 for row in rows {
448 out.push(row?);
449 }
450 Ok(out)
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};
    use crate::rag::embedder::FixedEmbedder;
    use crate::vector::VectorStore;
    use rusqlite::Connection;

    /// Opens an in-memory database with the crate schema applied.
    fn empty_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Builds a `dim`-length zero vector with the given `(index, value)`
    /// components set.
    fn sparse_vec(dim: usize, components: &[(usize, f32)]) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        for &(i, value) in components {
            v[i] = value;
        }
        v
    }

    /// Three documents: A along axis 0, B along axis 1, C between them
    /// (0.8, 0.6). A is related to both B and C.
    fn setup(dim: usize) -> (Connection, Vec<i64>) {
        let conn = empty_db();

        let ids: Vec<i64> = ["Doc A", "Doc B", "Doc C"]
            .iter()
            .map(|name| insert_entity(&conn, &Entity::new("doc", *name)).unwrap())
            .collect();

        let store = VectorStore::new();
        let vectors = vec![
            sparse_vec(dim, &[(0, 1.0)]),
            sparse_vec(dim, &[(1, 1.0)]),
            sparse_vec(dim, &[(0, 0.8), (1, 0.6)]),
        ];
        for (id, v) in ids.iter().zip(vectors.into_iter()) {
            store.insert_vector(&conn, *id, v).unwrap();
        }

        insert_relation(&conn, &Relation::new(ids[0], ids[1], "related", 0.3).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(ids[0], ids[2], "related", 0.9).unwrap()).unwrap();

        (conn, ids)
    }

    /// A query aligned with Doc A must rank Doc A first with similarity 1.
    #[test]
    fn test_basic_search() {
        let dim = 4;
        let (conn, ids) = setup(dim);

        let embedder = FixedEmbedder(sparse_vec(dim, &[(0, 1.0)]));
        let config = RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        };
        let engine = RagEngine::new(config);

        let results = engine.search(&conn, &embedder, "test query", 2).unwrap();
        assert!(!results.is_empty(), "should return at least one result");

        let top = &results[0];
        assert_eq!(top.entity.id, Some(ids[0]));
        assert!((top.vector_score - 1.0).abs() < 1e-5);
    }

    /// Searching a database without vectors yields no results.
    #[test]
    fn test_empty_db() {
        let conn = empty_db();

        let embedder = FixedEmbedder(vec![1.0, 0.0, 0.0]);
        let engine = RagEngine::new(RagConfig::default());

        let results = engine.search(&conn, &embedder, "anything", 5).unwrap();
        assert!(results.is_empty());
    }

    /// With a single ANN candidate, the linked neighbour must still be
    /// returned through graph expansion.
    #[test]
    fn test_graph_expansion() {
        let dim = 4;
        let conn = empty_db();
        let store = VectorStore::new();

        let e1 = insert_entity(&conn, &Entity::new("doc", "A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "B")).unwrap();

        store
            .insert_vector(&conn, e1, sparse_vec(dim, &[(0, 1.0)]))
            .unwrap();
        store
            .insert_vector(&conn, e2, sparse_vec(dim, &[(1, 1.0)]))
            .unwrap();

        insert_relation(&conn, &Relation::new(e1, e2, "link", 1.0).unwrap()).unwrap();

        let embedder = FixedEmbedder(sparse_vec(dim, &[(0, 1.0)]));
        let config = RagConfig {
            vector_dimension: dim,
            top_k_candidates: 1,
            top_k_rerank: 1,
            enable_graph_expansion: true,
            ..Default::default()
        };
        let engine = RagEngine::new(config);

        let results = engine.search(&conn, &embedder, "q", 5).unwrap();
        let returned = |id: i64| results.iter().any(|r| r.entity.id == Some(id));
        assert!(returned(e1));
        assert!(returned(e2), "RAPO should expand to e2");
    }

    /// The result for Doc A must carry its graph neighbours as context.
    #[test]
    fn test_context_attached() {
        let dim = 4;
        let (conn, ids) = setup(dim);

        let embedder = FixedEmbedder(sparse_vec(dim, &[(0, 1.0)]));
        let config = RagConfig {
            vector_dimension: dim,
            context_depth: 1,
            max_context_entities: 3,
            ..Default::default()
        };
        let engine = RagEngine::new(config);

        let results = engine.search(&conn, &embedder, "q", 3).unwrap();

        let e1_result = results.iter().find(|r| r.entity.id == Some(ids[0]));
        assert!(e1_result.is_some());
        let ctx = &e1_result.unwrap().context_entities;
        assert!(!ctx.is_empty(), "e1 should have context neighbours");
    }
}