// sqlite_knowledge_graph/rag/mod.rs
//! Paper-driven two-stage RAG engine.
//!
//! Pipeline (per-query):
//! ```text
//! query: &str
//!   → Embedder::embed()                      [plug-in]
//!   → [Stage 1 · MemRL]  TurboQuantIndex ANN (top_k_candidates)
//!   → [Stage 2 · MemRL]  exact cosine rerank (top_k_rerank)
//!   → [RAPO]             BFS graph expansion + vector score
//!   → [combined score]   vector_weight·v + graph_weight·g
//!   → [SuperLocalMemory] quality filter (min thresholds)
//!   → sort desc, take k
//!   → [Memex(RL)]        context BFS for each result (max_context_entities)
//! ```
//!
//! References:
//! - MemRL (2601.03192): two-stage ANN → exact rerank
//! - RAPO  (2603.02958): graph-neighbour expansion
//! - Memex (2603.03561): context-entity sizing
//! - SuperLocalMemory (2602.13398): quality threshold filtering
//! - NN-RAG (2511.20333): retrieval quality over quantity

pub mod embedder;
mod error;

pub use embedder::Embedder;
pub use error::RagError;

use crate::error::Result;
use crate::graph::{get_neighbors, Entity};
use crate::vector::{cosine_similarity, TurboQuantConfig, TurboQuantIndex, VectorStore};
use rusqlite::Connection;
use std::collections::HashMap;

// ─────────────────────────────────────────────────────────────────────────────
// Public types
// ─────────────────────────────────────────────────────────────────────────────

39/// One result row from the RAG engine.
40#[derive(Debug, Clone)]
41pub struct RagResult {
42    /// The matched entity.
43    pub entity: Entity,
44    /// Exact cosine similarity score in [0, 1].
45    pub vector_score: f64,
46    /// Graph connectivity score in [0, 1] (fraction of pool that is a neighbour).
47    pub graph_score: f64,
48    /// Final weighted score: `vector_weight·v + graph_weight·g`.
49    pub combined_score: f64,
50    /// Context neighbours collected by BFS (Memex(RL) sizing).
51    pub context_entities: Vec<Entity>,
52}
/// Tunable parameters for the RAG pipeline.
///
/// Every knob ships with a default tuned to the paper recommendations; override
/// individual fields with struct-update syntax, e.g.
/// `RagConfig { vector_weight: 0.8, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct RagConfig {
    // ── Scoring ──────────────────────────────────────────────────────────────
    /// Weight given to the vector (semantic) score. Default: 0.6.
    pub vector_weight: f64,
    /// Weight given to the graph connectivity score. Default: 0.4.
    pub graph_weight: f64,

    // ── MemRL two-stage retrieval ─────────────────────────────────────────────
    /// Stage-1: number of ANN candidates fetched from TurboQuant. Default: 50.
    pub top_k_candidates: usize,
    /// Stage-2: number of candidates that survive the exact rerank. Default: 20.
    pub top_k_rerank: usize,

    // ── RAPO graph expansion ──────────────────────────────────────────────────
    /// Expand candidates with BFS neighbours when true. Default: true.
    pub enable_graph_expansion: bool,
    /// BFS depth used when collecting neighbours for expansion. Default: 1.
    pub graph_depth: u32,

    // ── Memex(RL) context sizing ──────────────────────────────────────────────
    /// BFS depth used when collecting context entities. Default: 2.
    pub context_depth: u32,
    /// Upper bound on context entities attached per result. Default: 5.
    pub max_context_entities: usize,

    // ── SuperLocalMemory quality thresholds ───────────────────────────────────
    /// Candidates below this vector score are dropped. Default: 0.0 (disabled).
    pub min_vector_score: f32,
    /// Results below this combined score are dropped. Default: 0.0 (disabled).
    pub min_combined_score: f64,

    // ── TurboQuant index ─────────────────────────────────────────────────────
    /// Vector dimension; must match the embedding model. Default: 384.
    pub vector_dimension: usize,
}
93impl Default for RagConfig {
94    fn default() -> Self {
95        Self {
96            vector_weight: 0.6,
97            graph_weight: 0.4,
98            top_k_candidates: 50,
99            top_k_rerank: 20,
100            enable_graph_expansion: true,
101            graph_depth: 1,
102            context_depth: 2,
103            max_context_entities: 5,
104            min_vector_score: 0.0,
105            min_combined_score: 0.0,
106            vector_dimension: 384,
107        }
108    }
109}

// ─────────────────────────────────────────────────────────────────────────────
// RagEngine
// ─────────────────────────────────────────────────────────────────────────────

115/// Hybrid RAG engine backed by SQLite.
116pub struct RagEngine {
117    config: RagConfig,
118}
120impl RagEngine {
121    pub fn new(config: RagConfig) -> Self {
122        Self { config }
123    }
124
125    /// Full hybrid search.
126    ///
127    /// # Arguments
128    /// * `conn`   – active SQLite connection (read/write access required)
129    /// * `embedder` – embedding backend (see `embedder::SubprocessEmbedder`)
130    /// * `query`  – raw query text
131    /// * `k`      – how many results to return
132    pub fn search(
133        &self,
134        conn: &Connection,
135        embedder: &dyn Embedder,
136        query: &str,
137        k: usize,
138    ) -> Result<Vec<RagResult>> {
139        // ── 0. Embed query ───────────────────────────────────────────────────
140        let query_vec = embedder.embed(query)?;
141
142        // ── Stage 1 · MemRL – fast ANN via TurboQuant ───────────────────────
143        let ann_candidates = self.stage1_ann(conn, &query_vec)?;
144
145        if ann_candidates.is_empty() {
146            return Ok(Vec::new());
147        }
148
149        // ── Stage 2 · MemRL – exact cosine rerank ───────────────────────────
150        let mut reranked = self.stage2_rerank(conn, &query_vec, ann_candidates)?;
151        reranked.truncate(self.config.top_k_rerank);
152
153        // ── RAPO – expand with graph neighbours ─────────────────────────────
154        let mut pool: HashMap<i64, f32> = reranked.into_iter().collect();
155        if self.config.enable_graph_expansion {
156            self.rapo_expand(conn, &query_vec, &mut pool)?;
157        }
158
159        // ── Score & filter (SuperLocalMemory) ───────────────────────────────
160        let pool_size = pool.len();
161        let mut scored = self.score_and_filter(conn, &pool, pool_size)?;
162
163        // Sort by combined_score descending, take top k
164        scored.sort_by(|a, b| b.combined_score.partial_cmp(&a.combined_score).unwrap());
165        scored.truncate(k);
166
167        // ── Memex(RL) – attach context neighbours ───────────────────────────
168        for result in &mut scored {
169            let entity_id = result.entity.id.unwrap_or(0);
170            result.context_entities = self.collect_context(conn, entity_id, &pool)?;
171        }
172
173        Ok(scored)
174    }
175
176    // ─────────────────────────────────────────────────────────────────────────
177    // Private helpers
178    // ─────────────────────────────────────────────────────────────────────────
179
180    /// Stage 1: ANN via TurboQuant.  Returns (entity_id, approx_score) pairs.
181    ///
182    /// The TurboQuant index is persisted in `kg_turboquant_cache` and rebuilt
183    /// whenever either the *count* or the *checksum* of vectors in `kg_vectors`
184    /// has changed.  Using both metrics catches the case where one vector is
185    /// deleted and a different one is inserted (count stays the same but the
186    /// set of entity_ids — and therefore their SUM — changes).
187    fn stage1_ann(&self, conn: &Connection, query_vec: &[f32]) -> Result<Vec<(i64, f32)>> {
188        let vector_count: i64 =
189            conn.query_row("SELECT COUNT(*) FROM kg_vectors", [], |r| r.get(0))?;
190
191        if vector_count == 0 {
192            return Ok(Vec::new());
193        }
194
195        // Lightweight fingerprint: SUM of all entity_ids in kg_vectors.
196        // Autoincrement IDs are never reused, so any insert/delete changes this.
197        let vectors_checksum: i64 = conn.query_row(
198            "SELECT COALESCE(SUM(entity_id), 0) FROM kg_vectors",
199            [],
200            |r| r.get(0),
201        )?;
202
203        // Try to load a valid cached index first.
204        let cached = load_turboquant_cache(conn, vector_count, vectors_checksum)?;
205        let index = match cached {
206            Some(idx) => idx,
207            None => {
208                // Cache miss or stale — rebuild from kg_vectors.
209                let all_vectors = load_all_vectors(conn)?;
210                let dim = all_vectors[0].1.len();
211                let config = TurboQuantConfig {
212                    dimension: dim,
213                    bit_width: 3,
214                    seed: 42,
215                };
216                let mut idx = TurboQuantIndex::new(config)?;
217                for (entity_id, vec) in &all_vectors {
218                    idx.add_vector(*entity_id, vec)?;
219                }
220                save_turboquant_cache(conn, &idx, vector_count, vectors_checksum)?;
221                idx
222            }
223        };
224
225        let k = self.config.top_k_candidates.min(vector_count as usize);
226        index.search(query_vec, k)
227    }
228
229    /// Stage 2: exact cosine rerank.
230    fn stage2_rerank(
231        &self,
232        conn: &Connection,
233        query_vec: &[f32],
234        candidates: Vec<(i64, f32)>,
235    ) -> Result<Vec<(i64, f32)>> {
236        let store = VectorStore::new();
237        let mut scored: Vec<(i64, f32)> = Vec::with_capacity(candidates.len());
238
239        for (entity_id, approx) in candidates {
240            // SuperLocalMemory: drop if even the ANN score is below threshold
241            if approx < self.config.min_vector_score {
242                continue;
243            }
244            match store.get_vector(conn, entity_id) {
245                Ok(vec) => {
246                    let exact = cosine_similarity(query_vec, &vec);
247                    if exact >= self.config.min_vector_score {
248                        scored.push((entity_id, exact));
249                    }
250                }
251                Err(_) => {
252                    // entity_id no longer has a vector – skip silently
253                }
254            }
255        }
256
257        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
258        Ok(scored)
259    }
260
261    /// RAPO: BFS expand candidate pool with graph neighbours,
262    /// computing their vector scores on the fly.
263    fn rapo_expand(
264        &self,
265        conn: &Connection,
266        query_vec: &[f32],
267        pool: &mut HashMap<i64, f32>,
268    ) -> Result<()> {
269        let store = VectorStore::new();
270        let seeds: Vec<i64> = pool.keys().copied().collect();
271
272        for seed_id in seeds {
273            let neighbours = match get_neighbors(conn, seed_id, self.config.graph_depth) {
274                Ok(n) => n,
275                Err(_) => continue,
276            };
277
278            for nbr in neighbours {
279                let nbr_id = match nbr.entity.id {
280                    Some(id) => id,
281                    None => continue,
282                };
283
284                if pool.contains_key(&nbr_id) {
285                    continue;
286                }
287
288                // Score the new candidate
289                if let Ok(vec) = store.get_vector(conn, nbr_id) {
290                    let score = cosine_similarity(query_vec, &vec);
291                    if score >= self.config.min_vector_score {
292                        pool.insert(nbr_id, score);
293                    }
294                }
295            }
296        }
297
298        Ok(())
299    }
300
301    /// Compute graph score, combined score, apply SuperLocalMemory filter,
302    /// and build partial RagResult (context filled later).
303    fn score_and_filter(
304        &self,
305        conn: &Connection,
306        pool: &HashMap<i64, f32>,
307        pool_size: usize,
308    ) -> Result<Vec<RagResult>> {
309        let mut results = Vec::new();
310
311        for (&entity_id, &v_score) in pool {
312            let vector_score = v_score as f64;
313
314            // Graph score: fraction of pool that is a direct neighbour
315            let graph_score = if pool_size > 1 {
316                let neighbours = get_neighbors(conn, entity_id, 1).unwrap_or_default();
317                let overlap = neighbours
318                    .iter()
319                    .filter(|n| {
320                        n.entity
321                            .id
322                            .map(|id| pool.contains_key(&id))
323                            .unwrap_or(false)
324                    })
325                    .count();
326                overlap as f64 / (pool_size - 1) as f64
327            } else {
328                0.0
329            };
330
331            let combined_score =
332                self.config.vector_weight * vector_score + self.config.graph_weight * graph_score;
333
334            // SuperLocalMemory quality filter
335            if combined_score < self.config.min_combined_score {
336                continue;
337            }
338
339            let entity = match crate::graph::get_entity(conn, entity_id) {
340                Ok(e) => e,
341                Err(_) => continue,
342            };
343
344            results.push(RagResult {
345                entity,
346                vector_score,
347                graph_score,
348                combined_score,
349                context_entities: Vec::new(), // filled in next pass
350            });
351        }
352
353        Ok(results)
354    }
355
356    /// Memex(RL): collect context neighbours for a result entity via BFS,
357    /// prioritising entities already in the retrieval pool.
358    fn collect_context(
359        &self,
360        conn: &Connection,
361        entity_id: i64,
362        pool: &HashMap<i64, f32>,
363    ) -> Result<Vec<Entity>> {
364        let neighbours = match get_neighbors(conn, entity_id, self.config.context_depth) {
365            Ok(n) => n,
366            Err(_) => return Ok(Vec::new()),
367        };
368
369        // Sort: pool members first (high relevance), then by graph-BFS order
370        let mut in_pool: Vec<Entity> = Vec::new();
371        let mut not_in_pool: Vec<Entity> = Vec::new();
372
373        for nbr in neighbours {
374            if let Some(id) = nbr.entity.id {
375                if pool.contains_key(&id) {
376                    in_pool.push(nbr.entity);
377                } else {
378                    not_in_pool.push(nbr.entity);
379                }
380            }
381        }
382
383        in_pool.extend(not_in_pool);
384        in_pool.truncate(self.config.max_context_entities);
385        Ok(in_pool)
386    }
387}

// ─────────────────────────────────────────────────────────────────────────────
// Utility
// ─────────────────────────────────────────────────────────────────────────────

393/// Load a TurboQuant index from the SQLite cache if it is still valid.
394///
395/// Returns `None` if no cache row exists or either the vector count or the
396/// checksum does not match, indicating the index is stale.
397fn load_turboquant_cache(
398    conn: &Connection,
399    current_count: i64,
400    current_checksum: i64,
401) -> Result<Option<TurboQuantIndex>> {
402    let mut stmt = conn.prepare(
403        "SELECT index_blob, vector_count, vectors_checksum \
404         FROM kg_turboquant_cache WHERE id = 1",
405    )?;
406
407    let result = stmt.query_row([], |row| {
408        let blob: Vec<u8> = row.get(0)?;
409        let cached_count: i64 = row.get(1)?;
410        let cached_checksum: i64 = row.get(2)?;
411        Ok((blob, cached_count, cached_checksum))
412    });
413
414    match result {
415        Ok((blob, cached_count, cached_checksum))
416            if cached_count == current_count && cached_checksum == current_checksum =>
417        {
418            let index = TurboQuantIndex::from_bytes(&blob)
419                .map_err(|e| crate::error::Error::Other(e.to_string()))?;
420            Ok(Some(index))
421        }
422        Ok(_) | Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
423        Err(e) => Err(e.into()),
424    }
425}
427/// Persist a TurboQuant index into `kg_turboquant_cache` (upsert).
428fn save_turboquant_cache(
429    conn: &Connection,
430    index: &TurboQuantIndex,
431    vector_count: i64,
432    vectors_checksum: i64,
433) -> Result<()> {
434    let blob = index
435        .to_bytes()
436        .map_err(|e| crate::error::Error::Other(e.to_string()))?;
437    conn.execute(
438        "INSERT INTO kg_turboquant_cache \
439             (id, index_blob, vector_count, vectors_checksum) \
440         VALUES (1, ?1, ?2, ?3) \
441         ON CONFLICT(id) DO UPDATE SET \
442             index_blob        = excluded.index_blob, \
443             vector_count      = excluded.vector_count, \
444             vectors_checksum  = excluded.vectors_checksum",
445        rusqlite::params![blob, vector_count, vectors_checksum],
446    )?;
447    Ok(())
448}
450fn load_all_vectors(conn: &Connection) -> Result<Vec<(i64, Vec<f32>)>> {
451    let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
452
453    let rows = stmt.query_map([], |row| {
454        let entity_id: i64 = row.get(0)?;
455        let blob: Vec<u8> = row.get(1)?;
456        let dim: i64 = row.get(2)?;
457
458        let mut vec = Vec::with_capacity(dim as usize);
459        for chunk in blob.chunks_exact(4) {
460            vec.push(f32::from_le_bytes(chunk.try_into().unwrap()));
461        }
462
463        Ok((entity_id, vec))
464    })?;
465
466    let mut out = Vec::new();
467    for row in rows {
468        out.push(row?);
469    }
470    Ok(out)
471}

// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};
    use crate::rag::embedder::FixedEmbedder;
    use crate::vector::VectorStore;
    use rusqlite::Connection;

    /// Shared fixture: in-memory DB with three "doc" entities, three vectors
    /// of dimension `dim`, and two relations out of e1.  Returns the
    /// connection plus the entity ids `[e1, e2, e3]`.
    fn setup(dim: usize) -> (Connection, Vec<i64>) {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let e1 = insert_entity(&conn, &Entity::new("doc", "Doc A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "Doc B")).unwrap();
        let e3 = insert_entity(&conn, &Entity::new("doc", "Doc C")).unwrap();

        let store = VectorStore::new();
        // e1 is very similar to query [1, 0, …]
        let mut v1 = vec![0.0f32; dim];
        v1[0] = 1.0;
        store.insert_vector(&conn, e1, v1).unwrap();

        // e2 is orthogonal to query
        let mut v2 = vec![0.0f32; dim];
        v2[1] = 1.0;
        store.insert_vector(&conn, e2, v2).unwrap();

        // e3 is somewhat similar
        let mut v3 = vec![0.0f32; dim];
        v3[0] = 0.8;
        v3[1] = 0.6;
        store.insert_vector(&conn, e3, v3).unwrap();

        // e1 → e2 (weak link); e1 → e3 (strong link)
        insert_relation(&conn, &Relation::new(e1, e2, "related", 0.3).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(e1, e3, "related", 0.9).unwrap()).unwrap();

        (conn, vec![e1, e2, e3])
    }

    #[test]
    fn test_basic_search() {
        let dim = 4;
        let (conn, ids) = setup(dim);

        // Query aligned with e1's vector.
        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;

        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "test query", 2).unwrap();
        assert!(!results.is_empty(), "should return at least one result");

        // e1 must be the top result (similarity = 1.0)
        assert_eq!(results[0].entity.id, Some(ids[0]));
        assert!((results[0].vector_score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_db() {
        // With no entities or vectors, search must return empty, not error.
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let embedder = FixedEmbedder(vec![1.0, 0.0, 0.0]);
        let engine = RagEngine::new(RagConfig::default());

        let results = engine.search(&conn, &embedder, "anything", 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_graph_expansion() {
        // Verify that RAPO brings in neighbours that were not in the ANN results
        let dim = 4;
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let store = VectorStore::new();

        // e1 very similar to query; e2 orthogonal; e1→e2 link
        let e1 = insert_entity(&conn, &Entity::new("doc", "A")).unwrap();
        let e2 = insert_entity(&conn, &Entity::new("doc", "B")).unwrap();

        let mut v1 = vec![0.0f32; dim];
        v1[0] = 1.0;
        store.insert_vector(&conn, e1, v1).unwrap();

        let mut v2 = vec![0.0f32; dim];
        v2[1] = 1.0;
        store.insert_vector(&conn, e2, v2).unwrap();

        insert_relation(&conn, &Relation::new(e1, e2, "link", 1.0).unwrap()).unwrap();

        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;

        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 1, // only ANN fetches e1; RAPO adds e2
            top_k_rerank: 1,
            enable_graph_expansion: true,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "q", 5).unwrap();
        let ids: Vec<i64> = results.iter().filter_map(|r| r.entity.id).collect();
        assert!(ids.contains(&e1));
        assert!(ids.contains(&e2), "RAPO should expand to e2");
    }

    #[test]
    fn test_context_attached() {
        let dim = 4;
        let (conn, ids) = setup(dim);

        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;

        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            context_depth: 1,
            max_context_entities: 3,
            ..Default::default()
        });

        let results = engine.search(&conn, &embedder, "q", 3).unwrap();

        // e1's result should have context neighbours (e2 and e3)
        let e1_result = results.iter().find(|r| r.entity.id == Some(ids[0]));
        assert!(e1_result.is_some());
        let ctx = &e1_result.unwrap().context_entities;
        assert!(!ctx.is_empty(), "e1 should have context neighbours");
    }

    // ── TurboQuant cache tests ────────────────────────────────────────────────

    #[test]
    fn test_cache_written_on_first_query() {
        let dim = 4;
        let (conn, _ids) = setup(dim);

        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });

        engine.search(&conn, &embedder, "q", 2).unwrap();

        // Cache row must exist after the first search
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1, "cache row should be created after first query");
    }

    #[test]
    fn test_cache_hit_on_second_query() {
        let dim = 4;
        let (conn, _ids) = setup(dim);

        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });

        // Second search should take the cached-index path (no rebuild).
        let r1 = engine.search(&conn, &embedder, "q", 2).unwrap();
        let r2 = engine.search(&conn, &embedder, "q", 2).unwrap();

        // Both searches return the same top entity
        assert_eq!(
            r1[0].entity.id, r2[0].entity.id,
            "cache hit should return identical results"
        );
    }

    #[test]
    fn test_cache_stores_checksum() {
        let dim = 4;
        let (conn, _ids) = setup(dim);

        let query = {
            let mut q = vec![0.0f32; dim];
            q[0] = 1.0;
            q
        };
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });

        engine.search(&conn, &embedder, "q", 2).unwrap();

        // Both count and checksum must be stored.
        let (count, checksum): (i64, i64) = conn
            .query_row(
                "SELECT vector_count, vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| Ok((r.get(0)?, r.get(1)?)),
            )
            .unwrap();
        assert_eq!(count, 3);
        // SUM of entity_ids for 3 autoincrement entities must be > 0.
        assert!(checksum > 0, "checksum should reflect entity_id sum");
    }

    #[test]
    fn test_cache_invalidated_on_same_count_different_entity() {
        // Verifies that swapping one vector (same count, different entity_id)
        // causes a cache miss via the checksum.
        let dim = 4;
        let (conn, ids) = setup(dim); // 3 vectors: entity_ids ids[0], ids[1], ids[2]

        let query = {
            let mut q = vec![0.0f32; dim];
            q[0] = 1.0;
            q
        };
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });

        // First search — writes cache (count=3, checksum=ids[0]+ids[1]+ids[2]).
        engine.search(&conn, &embedder, "q", 2).unwrap();

        let checksum_before: i64 = conn
            .query_row(
                "SELECT vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();

        // Delete one vector and insert a brand-new one (autoincrement → higher id).
        conn.execute("DELETE FROM kg_vectors WHERE entity_id = ?1", [ids[2]])
            .unwrap();
        let e_new = crate::graph::entity::insert_entity(
            &conn,
            &crate::graph::entity::Entity::new("doc", "Doc Swap"),
        )
        .unwrap();
        let store = VectorStore::new();
        let mut v_new = vec![0.0f32; dim];
        v_new[3] = 1.0;
        store.insert_vector(&conn, e_new, v_new).unwrap();
        // count is still 3, but checksum changed because e_new.id != ids[2].

        // Second search — must detect checksum mismatch → rebuild.
        engine.search(&conn, &embedder, "q", 2).unwrap();

        let (count_after, checksum_after): (i64, i64) = conn
            .query_row(
                "SELECT vector_count, vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| Ok((r.get(0)?, r.get(1)?)),
            )
            .unwrap();
        assert_eq!(count_after, 3, "vector count should still be 3 after swap");
        assert_ne!(
            checksum_after, checksum_before,
            "checksum must change after swapping one vector"
        );
    }

    #[test]
    fn test_cache_invalidated_after_new_vector() {
        let dim = 4;
        let (conn, _ids) = setup(dim);

        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;
        let embedder = FixedEmbedder(query);
        let engine = RagEngine::new(RagConfig {
            vector_dimension: dim,
            top_k_candidates: 10,
            top_k_rerank: 5,
            ..Default::default()
        });

        // First search — writes cache with vector_count = 3
        engine.search(&conn, &embedder, "q", 2).unwrap();

        let cached_count_before: i64 = conn
            .query_row(
                "SELECT vector_count FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(cached_count_before, 3);

        // Add a 4th vector
        let e4 = crate::graph::entity::insert_entity(
            &conn,
            &crate::graph::entity::Entity::new("doc", "Doc D"),
        )
        .unwrap();
        let store = VectorStore::new();
        let mut v4 = vec![0.0f32; dim];
        v4[2] = 1.0;
        store.insert_vector(&conn, e4, v4).unwrap();

        // Second search — must rebuild and update cache to vector_count = 4
        engine.search(&conn, &embedder, "q", 2).unwrap();

        let cached_count_after: i64 = conn
            .query_row(
                "SELECT vector_count FROM kg_turboquant_cache WHERE id = 1",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(
            cached_count_after, 4,
            "cache should be rebuilt after new vector added"
        );
    }
}