Skip to main content

sqlite_knowledge_graph/rag/
mod.rs

//! Paper-driven two-stage RAG engine.
//!
//! Pipeline (per-query):
//! ```text
//! query: &str
//!   → Embedder::embed()                      [plug-in]
//!   → [Stage 1 · MemRL]  TurboQuantIndex ANN (top_k_candidates)
//!   → [Stage 2 · MemRL]  exact cosine rerank (top_k_rerank)
//!   → [RAPO]             BFS graph expansion + vector score
//!   → [combined score]   vector_weight·v + graph_weight·g
//!   → [SuperLocalMemory] quality filter (min thresholds)
//!   → sort desc, take k
//!   → [Memex(RL)]        context BFS for each result (max_context_entities)
//! ```
//!
//! References:
//! - MemRL (2601.03192): two-stage ANN → exact rerank
//! - RAPO  (2603.02958): graph-neighbour expansion
//! - Memex (2603.03561): context-entity sizing
//! - SuperLocalMemory (2602.13398): quality threshold filtering
//! - NN-RAG (2511.20333): retrieval quality over quantity
22
23pub mod embedder;
24mod error;
25
26pub use embedder::Embedder;
27pub use error::RagError;
28
29use crate::error::Result;
30use crate::graph::{get_neighbors, Entity};
31use crate::vector::{cosine_similarity, TurboQuantConfig, TurboQuantIndex, VectorStore};
32use rusqlite::Connection;
33use std::collections::HashMap;
34
35// ─────────────────────────────────────────────────────────────────────────────
36// Public types
37// ─────────────────────────────────────────────────────────────────────────────
38
39/// One result row from the RAG engine.
40#[derive(Debug, Clone)]
41pub struct RagResult {
42    /// The matched entity.
43    pub entity: Entity,
44    /// Exact cosine similarity score in [0, 1].
45    pub vector_score: f64,
46    /// Graph connectivity score in [0, 1] (fraction of pool that is a neighbour).
47    pub graph_score: f64,
48    /// Final weighted score: `vector_weight·v + graph_weight·g`.
49    pub combined_score: f64,
50    /// Context neighbours collected by BFS (Memex(RL) sizing).
51    pub context_entities: Vec<Entity>,
52}
53
/// Tunable parameters of the RAG pipeline.
///
/// Every field has a default (see the [`Default`] impl), so callers usually
/// override a few fields with struct-update syntax:
/// `RagConfig { top_k_rerank: 10, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct RagConfig {
    // ── Scoring ──────────────────────────────────────────────────────────────
    /// Multiplier for the vector (semantic) score in the combined score.
    /// Default: 0.6.
    pub vector_weight: f64,
    /// Multiplier for the graph connectivity score in the combined score.
    /// Default: 0.4.
    pub graph_weight: f64,

    // ── MemRL two-stage retrieval ─────────────────────────────────────────────
    /// Stage 1: number of ANN candidates requested from the TurboQuant index
    /// (capped at the number of stored vectors). Default: 50.
    pub top_k_candidates: usize,
    /// Stage 2: number of candidates kept after the exact cosine rerank.
    /// Default: 20.
    pub top_k_rerank: usize,

    // ── RAPO graph expansion ──────────────────────────────────────────────────
    /// When `true`, graph neighbours of the reranked candidates are added to
    /// the candidate pool. Default: true.
    pub enable_graph_expansion: bool,
    /// Neighbour-search depth used for that expansion. Default: 1.
    pub graph_depth: u32,

    // ── Memex(RL) context sizing ──────────────────────────────────────────────
    /// Neighbour-search depth used when collecting context entities.
    /// Default: 2.
    pub context_depth: u32,
    /// Upper bound on context entities attached to each result. Default: 5.
    pub max_context_entities: usize,

    // ── SuperLocalMemory quality thresholds ───────────────────────────────────
    /// Candidates whose vector score is below this value are dropped.
    /// Default: 0.0, which still rejects negative cosine similarities.
    pub min_vector_score: f32,
    /// Results whose combined score is below this value are dropped.
    /// Default: 0.0.
    pub min_combined_score: f64,

    // ── TurboQuant index ─────────────────────────────────────────────────────
    /// Embedding dimension; set it to match your embedding model.
    /// Default: 384.
    pub vector_dimension: usize,
}

impl Default for RagConfig {
    /// Builds the configuration with the default value documented on each field.
    fn default() -> Self {
        RagConfig {
            // Scoring blend.
            vector_weight: 0.6,
            graph_weight: 0.4,
            // Two-stage retrieval sizes.
            top_k_candidates: 50,
            top_k_rerank: 20,
            // Graph expansion.
            enable_graph_expansion: true,
            graph_depth: 1,
            // Context sizing.
            context_depth: 2,
            max_context_entities: 5,
            // Quality thresholds.
            min_vector_score: 0.0,
            min_combined_score: 0.0,
            // Index dimension.
            vector_dimension: 384,
        }
    }
}
110
111// ─────────────────────────────────────────────────────────────────────────────
112// RagEngine
113// ─────────────────────────────────────────────────────────────────────────────
114
115/// Hybrid RAG engine backed by SQLite.
116pub struct RagEngine {
117    config: RagConfig,
118}
119
120impl RagEngine {
121    pub fn new(config: RagConfig) -> Self {
122        Self { config }
123    }
124
125    /// Full hybrid search.
126    ///
127    /// # Arguments
128    /// * `conn`   – active SQLite connection (read/write access required)
129    /// * `embedder` – embedding backend (see `embedder::SubprocessEmbedder`)
130    /// * `query`  – raw query text
131    /// * `k`      – how many results to return
132    pub fn search(
133        &self,
134        conn: &Connection,
135        embedder: &dyn Embedder,
136        query: &str,
137        k: usize,
138    ) -> Result<Vec<RagResult>> {
139        // ── 0. Embed query ───────────────────────────────────────────────────
140        let query_vec = embedder.embed(query)?;
141
142        // ── Stage 1 · MemRL – fast ANN via TurboQuant ───────────────────────
143        let ann_candidates = self.stage1_ann(conn, &query_vec)?;
144
145        if ann_candidates.is_empty() {
146            return Ok(Vec::new());
147        }
148
149        // ── Stage 2 · MemRL – exact cosine rerank ───────────────────────────
150        let mut reranked = self.stage2_rerank(conn, &query_vec, ann_candidates)?;
151        reranked.truncate(self.config.top_k_rerank);
152
153        // ── RAPO – expand with graph neighbours ─────────────────────────────
154        let mut pool: HashMap<i64, f32> = reranked.into_iter().collect();
155        if self.config.enable_graph_expansion {
156            self.rapo_expand(conn, &query_vec, &mut pool)?;
157        }
158
159        // ── Score & filter (SuperLocalMemory) ───────────────────────────────
160        let pool_size = pool.len();
161        let mut scored = self.score_and_filter(conn, &pool, pool_size)?;
162
163        // Sort by combined_score descending, take top k
164        scored.sort_by(|a, b| b.combined_score.partial_cmp(&a.combined_score).unwrap());
165        scored.truncate(k);
166
167        // ── Memex(RL) – attach context neighbours ───────────────────────────
168        for result in &mut scored {
169            let entity_id = result.entity.id.unwrap_or(0);
170            result.context_entities = self.collect_context(conn, entity_id, &pool)?;
171        }
172
173        Ok(scored)
174    }
175
176    // ─────────────────────────────────────────────────────────────────────────
177    // Private helpers
178    // ─────────────────────────────────────────────────────────────────────────
179
180    /// Stage 1: ANN via TurboQuant.  Returns (entity_id, approx_score) pairs.
181    ///
182    /// The TurboQuant index is persisted in `kg_turboquant_cache` and only
183    /// rebuilt when the number of vectors in `kg_vectors` has changed.
184    fn stage1_ann(&self, conn: &Connection, query_vec: &[f32]) -> Result<Vec<(i64, f32)>> {
185        let vector_count: i64 =
186            conn.query_row("SELECT COUNT(*) FROM kg_vectors", [], |r| r.get(0))?;
187
188        if vector_count == 0 {
189            return Ok(Vec::new());
190        }
191
192        // Try to load a valid cached index first.
193        let cached = load_turboquant_cache(conn, vector_count)?;
194        let index = match cached {
195            Some(idx) => idx,
196            None => {
197                // Cache miss or stale — rebuild from kg_vectors.
198                let all_vectors = load_all_vectors(conn)?;
199                let dim = all_vectors[0].1.len();
200                let config = TurboQuantConfig {
201                    dimension: dim,
202                    bit_width: 3,
203                    seed: 42,
204                };
205                let mut idx = TurboQuantIndex::new(config)?;
206                for (entity_id, vec) in &all_vectors {
207                    idx.add_vector(*entity_id, vec)?;
208                }
209                save_turboquant_cache(conn, &idx, vector_count)?;
210                idx
211            }
212        };
213
214        let k = self.config.top_k_candidates.min(vector_count as usize);
215        index.search(query_vec, k)
216    }
217
218    /// Stage 2: exact cosine rerank.
219    fn stage2_rerank(
220        &self,
221        conn: &Connection,
222        query_vec: &[f32],
223        candidates: Vec<(i64, f32)>,
224    ) -> Result<Vec<(i64, f32)>> {
225        let store = VectorStore::new();
226        let mut scored: Vec<(i64, f32)> = Vec::with_capacity(candidates.len());
227
228        for (entity_id, approx) in candidates {
229            // SuperLocalMemory: drop if even the ANN score is below threshold
230            if approx < self.config.min_vector_score {
231                continue;
232            }
233            match store.get_vector(conn, entity_id) {
234                Ok(vec) => {
235                    let exact = cosine_similarity(query_vec, &vec);
236                    if exact >= self.config.min_vector_score {
237                        scored.push((entity_id, exact));
238                    }
239                }
240                Err(_) => {
241                    // entity_id no longer has a vector – skip silently
242                }
243            }
244        }
245
246        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
247        Ok(scored)
248    }
249
250    /// RAPO: BFS expand candidate pool with graph neighbours,
251    /// computing their vector scores on the fly.
252    fn rapo_expand(
253        &self,
254        conn: &Connection,
255        query_vec: &[f32],
256        pool: &mut HashMap<i64, f32>,
257    ) -> Result<()> {
258        let store = VectorStore::new();
259        let seeds: Vec<i64> = pool.keys().copied().collect();
260
261        for seed_id in seeds {
262            let neighbours = match get_neighbors(conn, seed_id, self.config.graph_depth) {
263                Ok(n) => n,
264                Err(_) => continue,
265            };
266
267            for nbr in neighbours {
268                let nbr_id = match nbr.entity.id {
269                    Some(id) => id,
270                    None => continue,
271                };
272
273                if pool.contains_key(&nbr_id) {
274                    continue;
275                }
276
277                // Score the new candidate
278                if let Ok(vec) = store.get_vector(conn, nbr_id) {
279                    let score = cosine_similarity(query_vec, &vec);
280                    if score >= self.config.min_vector_score {
281                        pool.insert(nbr_id, score);
282                    }
283                }
284            }
285        }
286
287        Ok(())
288    }
289
290    /// Compute graph score, combined score, apply SuperLocalMemory filter,
291    /// and build partial RagResult (context filled later).
292    fn score_and_filter(
293        &self,
294        conn: &Connection,
295        pool: &HashMap<i64, f32>,
296        pool_size: usize,
297    ) -> Result<Vec<RagResult>> {
298        let mut results = Vec::new();
299
300        for (&entity_id, &v_score) in pool {
301            let vector_score = v_score as f64;
302
303            // Graph score: fraction of pool that is a direct neighbour
304            let graph_score = if pool_size > 1 {
305                let neighbours = get_neighbors(conn, entity_id, 1).unwrap_or_default();
306                let overlap = neighbours
307                    .iter()
308                    .filter(|n| {
309                        n.entity
310                            .id
311                            .map(|id| pool.contains_key(&id))
312                            .unwrap_or(false)
313                    })
314                    .count();
315                overlap as f64 / (pool_size - 1) as f64
316            } else {
317                0.0
318            };
319
320            let combined_score =
321                self.config.vector_weight * vector_score + self.config.graph_weight * graph_score;
322
323            // SuperLocalMemory quality filter
324            if combined_score < self.config.min_combined_score {
325                continue;
326            }
327
328            let entity = match crate::graph::get_entity(conn, entity_id) {
329                Ok(e) => e,
330                Err(_) => continue,
331            };
332
333            results.push(RagResult {
334                entity,
335                vector_score,
336                graph_score,
337                combined_score,
338                context_entities: Vec::new(), // filled in next pass
339            });
340        }
341
342        Ok(results)
343    }
344
345    /// Memex(RL): collect context neighbours for a result entity via BFS,
346    /// prioritising entities already in the retrieval pool.
347    fn collect_context(
348        &self,
349        conn: &Connection,
350        entity_id: i64,
351        pool: &HashMap<i64, f32>,
352    ) -> Result<Vec<Entity>> {
353        let neighbours = match get_neighbors(conn, entity_id, self.config.context_depth) {
354            Ok(n) => n,
355            Err(_) => return Ok(Vec::new()),
356        };
357
358        // Sort: pool members first (high relevance), then by graph-BFS order
359        let mut in_pool: Vec<Entity> = Vec::new();
360        let mut not_in_pool: Vec<Entity> = Vec::new();
361
362        for nbr in neighbours {
363            if let Some(id) = nbr.entity.id {
364                if pool.contains_key(&id) {
365                    in_pool.push(nbr.entity);
366                } else {
367                    not_in_pool.push(nbr.entity);
368                }
369            }
370        }
371
372        in_pool.extend(not_in_pool);
373        in_pool.truncate(self.config.max_context_entities);
374        Ok(in_pool)
375    }
376}
377
378// ─────────────────────────────────────────────────────────────────────────────
379// Utility
380// ─────────────────────────────────────────────────────────────────────────────
381
382/// Load a TurboQuant index from the SQLite cache if it is still valid.
383///
384/// Returns `None` if no cache row exists or the cached vector count does not
385/// match `current_count` (meaning the index is stale).
386fn load_turboquant_cache(
387    conn: &Connection,
388    current_count: i64,
389) -> Result<Option<TurboQuantIndex>> {
390    let mut stmt = conn.prepare(
391        "SELECT index_blob, vector_count FROM kg_turboquant_cache WHERE id = 1",
392    )?;
393
394    let result = stmt.query_row([], |row| {
395        let blob: Vec<u8> = row.get(0)?;
396        let cached_count: i64 = row.get(1)?;
397        Ok((blob, cached_count))
398    });
399
400    match result {
401        Ok((blob, cached_count)) if cached_count == current_count => {
402            let index = TurboQuantIndex::from_bytes(&blob)
403                .map_err(|e| crate::error::Error::Other(e.to_string()))?;
404            Ok(Some(index))
405        }
406        Ok(_) | Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
407        Err(e) => Err(e.into()),
408    }
409}
410
411/// Persist a TurboQuant index into `kg_turboquant_cache` (upsert).
412fn save_turboquant_cache(
413    conn: &Connection,
414    index: &TurboQuantIndex,
415    vector_count: i64,
416) -> Result<()> {
417    let blob = index
418        .to_bytes()
419        .map_err(|e| crate::error::Error::Other(e.to_string()))?;
420    conn.execute(
421        "INSERT INTO kg_turboquant_cache (id, index_blob, vector_count) \
422         VALUES (1, ?1, ?2) \
423         ON CONFLICT(id) DO UPDATE SET index_blob = excluded.index_blob, \
424                                       vector_count = excluded.vector_count",
425        rusqlite::params![blob, vector_count],
426    )?;
427    Ok(())
428}
429
430fn load_all_vectors(conn: &Connection) -> Result<Vec<(i64, Vec<f32>)>> {
431    let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
432
433    let rows = stmt.query_map([], |row| {
434        let entity_id: i64 = row.get(0)?;
435        let blob: Vec<u8> = row.get(1)?;
436        let dim: i64 = row.get(2)?;
437
438        let mut vec = Vec::with_capacity(dim as usize);
439        for chunk in blob.chunks_exact(4) {
440            vec.push(f32::from_le_bytes(chunk.try_into().unwrap()));
441        }
442
443        Ok((entity_id, vec))
444    })?;
445
446    let mut out = Vec::new();
447    for row in rows {
448        out.push(row?);
449    }
450    Ok(out)
451}
452
453// ─────────────────────────────────────────────────────────────────────────────
454// Tests
455// ─────────────────────────────────────────────────────────────────────────────
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};
    use crate::rag::embedder::FixedEmbedder;
    use crate::vector::VectorStore;
    use rusqlite::Connection;

    /// Opens an in-memory database with the knowledge-graph schema created.
    fn fresh_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Builds a `dim`-long vector that is zero except for the listed
    /// `(index, value)` components.
    fn sparse(dim: usize, components: &[(usize, f32)]) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        for &(idx, value) in components {
            v[idx] = value;
        }
        v
    }

    /// Query vector used by the tests: the unit vector along axis 0.
    fn axis0_query(dim: usize) -> Vec<f32> {
        sparse(dim, &[(0, 1.0)])
    }

    /// Three documents with vectors and two relations starting at the first.
    /// Returns the connection and the ids `[e1, e2, e3]`.
    fn setup(dim: usize) -> (Connection, Vec<i64>) {
        let conn = fresh_db();

        let e1 = insert_entity(&conn, &Entity::new("doc", "Doc A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "Doc B")).unwrap();
        let e3 = insert_entity(&conn, &Entity::new("doc", "Doc C")).unwrap();

        let store = VectorStore::new();
        // e1 is very similar to query [1, 0, …]
        store
            .insert_vector(&conn, e1, sparse(dim, &[(0, 1.0)]))
            .unwrap();
        // e2 is orthogonal to query
        store
            .insert_vector(&conn, e2, sparse(dim, &[(1, 1.0)]))
            .unwrap();
        // e3 is somewhat similar
        store
            .insert_vector(&conn, e3, sparse(dim, &[(0, 0.8), (1, 0.6)]))
            .unwrap();

        // e1 → e2 (weak link); e1 → e3 (strong link)
        insert_relation(&conn, &Relation::new(e1, e2, "related", 0.3).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(e1, e3, "related", 0.9).unwrap()).unwrap();

        (conn, vec![e1, e2, e3])
    }

    #[test]
    fn test_basic_search() {
        let dim = 4;
        let (conn, ids) = setup(dim);

        let embedder = FixedEmbedder(axis0_query(dim));
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "test query", 2).unwrap();
        assert!(!results.is_empty(), "should return at least one result");

        // e1 must be the top result (similarity = 1.0)
        let top = &results[0];
        assert_eq!(top.entity.id, Some(ids[0]));
        assert!((top.vector_score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_db() {
        let conn = fresh_db();

        let embedder = FixedEmbedder(vec![1.0, 0.0, 0.0]);
        let engine = RagEngine::new(RagConfig::default());

        let results = engine.search(&conn, &embedder, "anything", 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_graph_expansion() {
        // Verify that RAPO brings in neighbours that were not in the ANN results
        let dim = 4;
        let conn = fresh_db();
        let store = VectorStore::new();

        // e1 very similar to query; e2 orthogonal; e1→e2 link
        let e1 = insert_entity(&conn, &Entity::new("doc", "A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "B")).unwrap();

        store
            .insert_vector(&conn, e1, sparse(dim, &[(0, 1.0)]))
            .unwrap();
        store
            .insert_vector(&conn, e2, sparse(dim, &[(1, 1.0)]))
            .unwrap();

        insert_relation(&conn, &Relation::new(e1, e2, "link", 1.0).unwrap()).unwrap();

        let embedder = FixedEmbedder(axis0_query(dim));
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 1, // only ANN fetches e1; RAPO adds e2
            top_k_rerank: 1,
            enable_graph_expansion: true,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "q", 5).unwrap();
        let found: Vec<i64> = results.iter().filter_map(|r| r.entity.id).collect();
        assert!(found.contains(&e1));
        assert!(found.contains(&e2), "RAPO should expand to e2");
    }

    #[test]
    fn test_context_attached() {
        let dim = 4;
        let (conn, ids) = setup(dim);

        let embedder = FixedEmbedder(axis0_query(dim));
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            context_depth: 1,
            max_context_entities: 3,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "q", 3).unwrap();

        // e1's result should have context neighbours (e2 and e3)
        let e1_result = results.iter().find(|r| r.entity.id == Some(ids[0]));
        assert!(e1_result.is_some());
        let ctx = &e1_result.unwrap().context_entities;
        assert!(!ctx.is_empty(), "e1 should have context neighbours");
    }
}