Skip to main content

sqlite_knowledge_graph/rag/
mod.rs

1//! Paper-driven two-stage RAG engine.
2//!
3//! Pipeline (per-query):
4//! ```text
5//! query: &str
6//!   → Embedder::embed()                      [plug-in]
7//!   → [Stage 1 · MemRL]  TurboQuantIndex ANN (top_k_candidates)
8//!   → [Stage 2 · MemRL]  exact cosine rerank (top_k_rerank)
9//!   → [RAPO]             BFS graph expansion + vector score
10//!   → [combined score]   vector_weight·v + graph_weight·g
11//!   → [SuperLocalMemory] quality filter (min thresholds)
12//!   → sort desc, take k
13//!   → [Memex(RL)]        context BFS for each result (max_context_entities)
14//! ```
15//!
16//! References:
17//! - MemRL (2601.03192): two-stage ANN → exact rerank
18//! - RAPO  (2603.02958): graph-neighbour expansion
19//! - Memex (2603.03561): context-entity sizing
20//! - SuperLocalMemory (2602.13398): quality threshold filtering
21//! - NN-RAG (2511.20333): retrieval quality over quantity
22
23pub mod embedder;
24mod error;
25pub mod smart_retrieval;
26
27pub use embedder::Embedder;
28pub use error::RagError;
29pub use smart_retrieval::{RetrievalWeights, SmartRetrieval, SmartSearchResult};
30
31use crate::error::Result;
32use crate::graph::{get_neighbors, Entity};
33use crate::vector::{cosine_similarity, TurboQuantConfig, TurboQuantIndex, VectorStore};
34use rusqlite::Connection;
35use std::collections::HashMap;
36
37// ─────────────────────────────────────────────────────────────────────────────
38// Public types
39// ─────────────────────────────────────────────────────────────────────────────
40
41/// One result row from the RAG engine.
42#[derive(Debug, Clone)]
43pub struct RagResult {
44    /// The matched entity.
45    pub entity: Entity,
46    /// Exact cosine similarity score in [0, 1].
47    pub vector_score: f64,
48    /// Graph connectivity score in [0, 1] (fraction of pool that is a neighbour).
49    pub graph_score: f64,
50    /// Final weighted score: `vector_weight·v + graph_weight·g`.
51    pub combined_score: f64,
52    /// Context neighbours collected by BFS (Memex(RL) sizing).
53    pub context_entities: Vec<Entity>,
54}
55
56/// Configuration knobs for the RAG pipeline.
57/// All fields have defaults tuned to the paper recommendations.
58#[derive(Debug, Clone)]
59pub struct RagConfig {
60    // ── Scoring ──────────────────────────────────────────────────────────────
61    /// Weight applied to the vector (semantic) score. Default: 0.6.
62    pub vector_weight: f64,
63    /// Weight applied to the graph connectivity score. Default: 0.4.
64    pub graph_weight: f64,
65
66    // ── MemRL two-stage retrieval ─────────────────────────────────────────────
67    /// Stage-1: how many ANN candidates to fetch from TurboQuant. Default: 50.
68    pub top_k_candidates: usize,
69    /// Stage-2: how many candidates survive after exact rerank. Default: 20.
70    pub top_k_rerank: usize,
71
72    // ── RAPO graph expansion ──────────────────────────────────────────────────
73    /// Whether to expand candidates via BFS neighbours. Default: true.
74    pub enable_graph_expansion: bool,
75    /// BFS depth for graph-score neighbour collection. Default: 1.
76    pub graph_depth: u32,
77
78    // ── Memex(RL) context sizing ──────────────────────────────────────────────
79    /// BFS depth for context collection. Default: 2.
80    pub context_depth: u32,
81    /// Maximum context entities attached to each result. Default: 5.
82    pub max_context_entities: usize,
83
84    // ── SuperLocalMemory quality thresholds ───────────────────────────────────
85    /// Minimum vector score for a candidate to survive. Default: 0.0 (off).
86    pub min_vector_score: f32,
87    /// Minimum combined score for a result to survive. Default: 0.0 (off).
88    pub min_combined_score: f64,
89
90    // ── TurboQuant index ─────────────────────────────────────────────────────
91    /// Vector dimension; set to match your embedding model. Default: 384.
92    pub vector_dimension: usize,
93}
94
95impl Default for RagConfig {
96    fn default() -> Self {
97        Self {
98            vector_weight: 0.6,
99            graph_weight: 0.4,
100            top_k_candidates: 50,
101            top_k_rerank: 20,
102            enable_graph_expansion: true,
103            graph_depth: 1,
104            context_depth: 2,
105            max_context_entities: 5,
106            min_vector_score: 0.0,
107            min_combined_score: 0.0,
108            vector_dimension: 384,
109        }
110    }
111}
112
113// ─────────────────────────────────────────────────────────────────────────────
114// RagEngine
115// ─────────────────────────────────────────────────────────────────────────────
116
117/// Hybrid RAG engine backed by SQLite.
118pub struct RagEngine {
119    config: RagConfig,
120}
121
122impl RagEngine {
123    pub fn new(config: RagConfig) -> Self {
124        Self { config }
125    }
126
127    /// Full hybrid search.
128    ///
129    /// # Arguments
130    /// * `conn`   – active SQLite connection (read/write access required)
131    /// * `embedder` – embedding backend (see `embedder::SubprocessEmbedder`)
132    /// * `query`  – raw query text
133    /// * `k`      – how many results to return
134    pub fn search(
135        &self,
136        conn: &Connection,
137        embedder: &dyn Embedder,
138        query: &str,
139        k: usize,
140    ) -> Result<Vec<RagResult>> {
141        // ── 0. Embed query ───────────────────────────────────────────────────
142        let query_vec = embedder.embed(query)?;
143
144        // ── Stage 1 · MemRL – fast ANN via TurboQuant ───────────────────────
145        let ann_candidates = self.stage1_ann(conn, &query_vec)?;
146
147        if ann_candidates.is_empty() {
148            return Ok(Vec::new());
149        }
150
151        // ── Stage 2 · MemRL – exact cosine rerank ───────────────────────────
152        let mut reranked = self.stage2_rerank(conn, &query_vec, ann_candidates)?;
153        reranked.truncate(self.config.top_k_rerank);
154
155        // ── RAPO – expand with graph neighbours ─────────────────────────────
156        let mut pool: HashMap<i64, f32> = reranked.into_iter().collect();
157        if self.config.enable_graph_expansion {
158            self.rapo_expand(conn, &query_vec, &mut pool)?;
159        }
160
161        // ── Score & filter (SuperLocalMemory) ───────────────────────────────
162        let pool_size = pool.len();
163        let mut scored = self.score_and_filter(conn, &pool, pool_size)?;
164
165        // Sort by combined_score descending, take top k
166        scored.sort_by(|a, b| b.combined_score.partial_cmp(&a.combined_score).unwrap());
167        scored.truncate(k);
168
169        // ── Memex(RL) – attach context neighbours ───────────────────────────
170        for result in &mut scored {
171            let entity_id = result.entity.id.unwrap_or(0);
172            result.context_entities = self.collect_context(conn, entity_id, &pool)?;
173        }
174
175        Ok(scored)
176    }
177
178    // ─────────────────────────────────────────────────────────────────────────
179    // Private helpers
180    // ─────────────────────────────────────────────────────────────────────────
181
182    /// Stage 1: ANN via TurboQuant.  Returns (entity_id, approx_score) pairs.
183    ///
184    /// The TurboQuant index is persisted in `kg_turboquant_cache` and rebuilt
185    /// whenever either the *count* or the *checksum* of vectors in `kg_vectors`
186    /// has changed.  Using both metrics catches the case where one vector is
187    /// deleted and a different one is inserted (count stays the same but the
188    /// set of entity_ids — and therefore their SUM — changes).
189    fn stage1_ann(&self, conn: &Connection, query_vec: &[f32]) -> Result<Vec<(i64, f32)>> {
190        let vector_count: i64 =
191            conn.query_row("SELECT COUNT(*) FROM kg_vectors", [], |r| r.get(0))?;
192
193        if vector_count == 0 {
194            return Ok(Vec::new());
195        }
196
197        // Lightweight fingerprint: SUM of all entity_ids in kg_vectors.
198        // Autoincrement IDs are never reused, so any insert/delete changes this.
199        let vectors_checksum: i64 = conn.query_row(
200            "SELECT COALESCE(SUM(entity_id), 0) FROM kg_vectors",
201            [],
202            |r| r.get(0),
203        )?;
204
205        // Try to load a valid cached index first.
206        let cached = load_turboquant_cache(conn, vector_count, vectors_checksum)?;
207        let index = match cached {
208            Some(idx) => idx,
209            None => {
210                // Cache miss or stale — rebuild from kg_vectors.
211                let all_vectors = load_all_vectors(conn)?;
212                let dim = all_vectors[0].1.len();
213                let config = TurboQuantConfig {
214                    dimension: dim,
215                    bit_width: 3,
216                    seed: 42,
217                };
218                let mut idx = TurboQuantIndex::new(config)?;
219                for (entity_id, vec) in &all_vectors {
220                    idx.add_vector(*entity_id, vec)?;
221                }
222                save_turboquant_cache(conn, &idx, vector_count, vectors_checksum)?;
223                idx
224            }
225        };
226
227        let k = self.config.top_k_candidates.min(vector_count as usize);
228        index.search(query_vec, k)
229    }
230
231    /// Stage 2: exact cosine rerank.
232    fn stage2_rerank(
233        &self,
234        conn: &Connection,
235        query_vec: &[f32],
236        candidates: Vec<(i64, f32)>,
237    ) -> Result<Vec<(i64, f32)>> {
238        let store = VectorStore::new();
239        let mut scored: Vec<(i64, f32)> = Vec::with_capacity(candidates.len());
240
241        for (entity_id, approx) in candidates {
242            // SuperLocalMemory: drop if even the ANN score is below threshold
243            if approx < self.config.min_vector_score {
244                continue;
245            }
246            match store.get_vector(conn, entity_id) {
247                Ok(vec) => {
248                    let exact = cosine_similarity(query_vec, &vec);
249                    if exact >= self.config.min_vector_score {
250                        scored.push((entity_id, exact));
251                    }
252                }
253                Err(_) => {
254                    // entity_id no longer has a vector – skip silently
255                }
256            }
257        }
258
259        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
260        Ok(scored)
261    }
262
263    /// RAPO: BFS expand candidate pool with graph neighbours,
264    /// computing their vector scores on the fly.
265    fn rapo_expand(
266        &self,
267        conn: &Connection,
268        query_vec: &[f32],
269        pool: &mut HashMap<i64, f32>,
270    ) -> Result<()> {
271        let store = VectorStore::new();
272        let seeds: Vec<i64> = pool.keys().copied().collect();
273
274        for seed_id in seeds {
275            let neighbours = match get_neighbors(conn, seed_id, self.config.graph_depth) {
276                Ok(n) => n,
277                Err(_) => continue,
278            };
279
280            for nbr in neighbours {
281                let nbr_id = match nbr.entity.id {
282                    Some(id) => id,
283                    None => continue,
284                };
285
286                if pool.contains_key(&nbr_id) {
287                    continue;
288                }
289
290                // Score the new candidate
291                if let Ok(vec) = store.get_vector(conn, nbr_id) {
292                    let score = cosine_similarity(query_vec, &vec);
293                    if score >= self.config.min_vector_score {
294                        pool.insert(nbr_id, score);
295                    }
296                }
297            }
298        }
299
300        Ok(())
301    }
302
303    /// Compute graph score, combined score, apply SuperLocalMemory filter,
304    /// and build partial RagResult (context filled later).
305    fn score_and_filter(
306        &self,
307        conn: &Connection,
308        pool: &HashMap<i64, f32>,
309        pool_size: usize,
310    ) -> Result<Vec<RagResult>> {
311        let mut results = Vec::new();
312
313        for (&entity_id, &v_score) in pool {
314            let vector_score = v_score as f64;
315
316            // Graph score: fraction of pool that is a direct neighbour
317            let graph_score = if pool_size > 1 {
318                let neighbours = get_neighbors(conn, entity_id, 1).unwrap_or_default();
319                let overlap = neighbours
320                    .iter()
321                    .filter(|n| {
322                        n.entity
323                            .id
324                            .map(|id| pool.contains_key(&id))
325                            .unwrap_or(false)
326                    })
327                    .count();
328                overlap as f64 / (pool_size - 1) as f64
329            } else {
330                0.0
331            };
332
333            let combined_score =
334                self.config.vector_weight * vector_score + self.config.graph_weight * graph_score;
335
336            // SuperLocalMemory quality filter
337            if combined_score < self.config.min_combined_score {
338                continue;
339            }
340
341            let entity = match crate::graph::get_entity(conn, entity_id) {
342                Ok(e) => e,
343                Err(_) => continue,
344            };
345
346            results.push(RagResult {
347                entity,
348                vector_score,
349                graph_score,
350                combined_score,
351                context_entities: Vec::new(), // filled in next pass
352            });
353        }
354
355        Ok(results)
356    }
357
358    /// Memex(RL): collect context neighbours for a result entity via BFS,
359    /// prioritising entities already in the retrieval pool.
360    fn collect_context(
361        &self,
362        conn: &Connection,
363        entity_id: i64,
364        pool: &HashMap<i64, f32>,
365    ) -> Result<Vec<Entity>> {
366        let neighbours = match get_neighbors(conn, entity_id, self.config.context_depth) {
367            Ok(n) => n,
368            Err(_) => return Ok(Vec::new()),
369        };
370
371        // Sort: pool members first (high relevance), then by graph-BFS order
372        let mut in_pool: Vec<Entity> = Vec::new();
373        let mut not_in_pool: Vec<Entity> = Vec::new();
374
375        for nbr in neighbours {
376            if let Some(id) = nbr.entity.id {
377                if pool.contains_key(&id) {
378                    in_pool.push(nbr.entity);
379                } else {
380                    not_in_pool.push(nbr.entity);
381                }
382            }
383        }
384
385        in_pool.extend(not_in_pool);
386        in_pool.truncate(self.config.max_context_entities);
387        Ok(in_pool)
388    }
389}
390
391// ─────────────────────────────────────────────────────────────────────────────
392// Utility
393// ─────────────────────────────────────────────────────────────────────────────
394
395/// Load a TurboQuant index from the SQLite cache if it is still valid.
396///
397/// Returns `None` if no cache row exists or either the vector count or the
398/// checksum does not match, indicating the index is stale.
399fn load_turboquant_cache(
400    conn: &Connection,
401    current_count: i64,
402    current_checksum: i64,
403) -> Result<Option<TurboQuantIndex>> {
404    let mut stmt = conn.prepare(
405        "SELECT index_blob, vector_count, vectors_checksum \
406         FROM kg_turboquant_cache WHERE id = 1",
407    )?;
408
409    let result = stmt.query_row([], |row| {
410        let blob: Vec<u8> = row.get(0)?;
411        let cached_count: i64 = row.get(1)?;
412        let cached_checksum: i64 = row.get(2)?;
413        Ok((blob, cached_count, cached_checksum))
414    });
415
416    match result {
417        Ok((blob, cached_count, cached_checksum))
418            if cached_count == current_count && cached_checksum == current_checksum =>
419        {
420            let index = TurboQuantIndex::from_bytes(&blob)
421                .map_err(|e| crate::error::Error::Other(e.to_string()))?;
422            Ok(Some(index))
423        }
424        Ok(_) | Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
425        Err(e) => Err(e.into()),
426    }
427}
428
429/// Persist a TurboQuant index into `kg_turboquant_cache` (upsert).
430fn save_turboquant_cache(
431    conn: &Connection,
432    index: &TurboQuantIndex,
433    vector_count: i64,
434    vectors_checksum: i64,
435) -> Result<()> {
436    let blob = index
437        .to_bytes()
438        .map_err(|e| crate::error::Error::Other(e.to_string()))?;
439    conn.execute(
440        "INSERT INTO kg_turboquant_cache \
441             (id, index_blob, vector_count, vectors_checksum) \
442         VALUES (1, ?1, ?2, ?3) \
443         ON CONFLICT(id) DO UPDATE SET \
444             index_blob        = excluded.index_blob, \
445             vector_count      = excluded.vector_count, \
446             vectors_checksum  = excluded.vectors_checksum",
447        rusqlite::params![blob, vector_count, vectors_checksum],
448    )?;
449    Ok(())
450}
451
452fn load_all_vectors(conn: &Connection) -> Result<Vec<(i64, Vec<f32>)>> {
453    let mut stmt = conn.prepare("SELECT entity_id, vector, dimension FROM kg_vectors")?;
454
455    let rows = stmt.query_map([], |row| {
456        let entity_id: i64 = row.get(0)?;
457        let blob: Vec<u8> = row.get(1)?;
458        let dim: i64 = row.get(2)?;
459
460        let mut vec = Vec::with_capacity(dim as usize);
461        for chunk in blob.chunks_exact(4) {
462            vec.push(f32::from_le_bytes(chunk.try_into().unwrap()));
463        }
464
465        Ok((entity_id, vec))
466    })?;
467
468    let mut out = Vec::new();
469    for row in rows {
470        out.push(row?);
471    }
472    Ok(out)
473}
474
475// ─────────────────────────────────────────────────────────────────────────────
476// Tests
477// ─────────────────────────────────────────────────────────────────────────────
478
479#[cfg(test)]
480mod tests {
481    use super::*;
482    use crate::graph::entity::{insert_entity, Entity};
483    use crate::graph::relation::{insert_relation, Relation};
484    use crate::rag::embedder::FixedEmbedder;
485    use crate::vector::VectorStore;
486    use rusqlite::Connection;
487
488    fn setup(dim: usize) -> (Connection, Vec<i64>) {
489        let conn = Connection::open_in_memory().unwrap();
490        crate::schema::create_schema(&conn).unwrap();
491
492        let e1 = insert_entity(&conn, &Entity::new("doc", "Doc A")).unwrap();
493        let e2 = insert_entity(&conn, &Entity::new("doc", "Doc B")).unwrap();
494        let e3 = insert_entity(&conn, &Entity::new("doc", "Doc C")).unwrap();
495
496        let store = VectorStore::new();
497        // e1 is very similar to query [1, 0, …]
498        let mut v1 = vec![0.0f32; dim];
499        v1[0] = 1.0;
500        store.insert_vector(&conn, e1, v1).unwrap();
501
502        // e2 is orthogonal to query
503        let mut v2 = vec![0.0f32; dim];
504        v2[1] = 1.0;
505        store.insert_vector(&conn, e2, v2).unwrap();
506
507        // e3 is somewhat similar
508        let mut v3 = vec![0.0f32; dim];
509        v3[0] = 0.8;
510        v3[1] = 0.6;
511        store.insert_vector(&conn, e3, v3).unwrap();
512
513        // e1 → e2 (weak link); e1 → e3 (strong link)
514        insert_relation(&conn, &Relation::new(e1, e2, "related", 0.3).unwrap()).unwrap();
515        insert_relation(&conn, &Relation::new(e1, e3, "related", 0.9).unwrap()).unwrap();
516
517        (conn, vec![e1, e2, e3])
518    }
519
520    #[test]
521    fn test_basic_search() {
522        let dim = 4;
523        let (conn, ids) = setup(dim);
524
525        let mut query = vec![0.0f32; dim];
526        query[0] = 1.0;
527
528        let embedder = FixedEmbedder(query);
529        let engine = RagEngine::new(RagConfig {
530            vector_dimension: dim,
531            top_k_candidates: 10,
532            top_k_rerank: 5,
533            ..Default::default()
534        });
535
536        let results = engine.search(&conn, &embedder, "test query", 2).unwrap();
537        assert!(!results.is_empty(), "should return at least one result");
538
539        // e1 must be the top result (similarity = 1.0)
540        assert_eq!(results[0].entity.id, Some(ids[0]));
541        assert!((results[0].vector_score - 1.0).abs() < 1e-5);
542    }
543
544    #[test]
545    fn test_empty_db() {
546        let conn = Connection::open_in_memory().unwrap();
547        crate::schema::create_schema(&conn).unwrap();
548
549        let embedder = FixedEmbedder(vec![1.0, 0.0, 0.0]);
550        let engine = RagEngine::new(RagConfig::default());
551
552        let results = engine.search(&conn, &embedder, "anything", 5).unwrap();
553        assert!(results.is_empty());
554    }
555
556    #[test]
557    fn test_graph_expansion() {
558        // Verify that RAPO brings in neighbours that were not in the ANN results
559        let dim = 4;
560        let conn = Connection::open_in_memory().unwrap();
561        crate::schema::create_schema(&conn).unwrap();
562
563        let store = VectorStore::new();
564
565        // e1 very similar to query; e2 orthogonal; e1→e2 link
566        let e1 = insert_entity(&conn, &Entity::new("doc", "A")).unwrap();
567        let e2 = insert_entity(&conn, &Entity::new("doc", "B")).unwrap();
568
569        let mut v1 = vec![0.0f32; dim];
570        v1[0] = 1.0;
571        store.insert_vector(&conn, e1, v1).unwrap();
572
573        let mut v2 = vec![0.0f32; dim];
574        v2[1] = 1.0;
575        store.insert_vector(&conn, e2, v2).unwrap();
576
577        insert_relation(&conn, &Relation::new(e1, e2, "link", 1.0).unwrap()).unwrap();
578
579        let mut query = vec![0.0f32; dim];
580        query[0] = 1.0;
581
582        let embedder = FixedEmbedder(query);
583        let engine = RagEngine::new(RagConfig {
584            vector_dimension: dim,
585            top_k_candidates: 1, // only ANN fetches e1; RAPO adds e2
586            top_k_rerank: 1,
587            enable_graph_expansion: true,
588            ..Default::default()
589        });
590
591        let results = engine.search(&conn, &embedder, "q", 5).unwrap();
592        let ids: Vec<i64> = results.iter().filter_map(|r| r.entity.id).collect();
593        assert!(ids.contains(&e1));
594        assert!(ids.contains(&e2), "RAPO should expand to e2");
595    }
596
597    #[test]
598    fn test_context_attached() {
599        let dim = 4;
600        let (conn, ids) = setup(dim);
601
602        let mut query = vec![0.0f32; dim];
603        query[0] = 1.0;
604
605        let embedder = FixedEmbedder(query);
606        let engine = RagEngine::new(RagConfig {
607            vector_dimension: dim,
608            context_depth: 1,
609            max_context_entities: 3,
610            ..Default::default()
611        });
612
613        let results = engine.search(&conn, &embedder, "q", 3).unwrap();
614
615        // e1's result should have context neighbours (e2 and e3)
616        let e1_result = results.iter().find(|r| r.entity.id == Some(ids[0]));
617        assert!(e1_result.is_some());
618        let ctx = &e1_result.unwrap().context_entities;
619        assert!(!ctx.is_empty(), "e1 should have context neighbours");
620    }
621
622    // ── TurboQuant cache tests ────────────────────────────────────────────────
623
624    #[test]
625    fn test_cache_written_on_first_query() {
626        let dim = 4;
627        let (conn, _ids) = setup(dim);
628
629        let mut query = vec![0.0f32; dim];
630        query[0] = 1.0;
631        let embedder = FixedEmbedder(query);
632        let engine = RagEngine::new(RagConfig {
633            vector_dimension: dim,
634            top_k_candidates: 10,
635            top_k_rerank: 5,
636            ..Default::default()
637        });
638
639        engine.search(&conn, &embedder, "q", 2).unwrap();
640
641        // Cache row must exist after the first search
642        let count: i64 = conn
643            .query_row(
644                "SELECT COUNT(*) FROM kg_turboquant_cache WHERE id = 1",
645                [],
646                |r| r.get(0),
647            )
648            .unwrap();
649        assert_eq!(count, 1, "cache row should be created after first query");
650    }
651
652    #[test]
653    fn test_cache_hit_on_second_query() {
654        let dim = 4;
655        let (conn, _ids) = setup(dim);
656
657        let mut query = vec![0.0f32; dim];
658        query[0] = 1.0;
659        let embedder = FixedEmbedder(query);
660        let engine = RagEngine::new(RagConfig {
661            vector_dimension: dim,
662            top_k_candidates: 10,
663            top_k_rerank: 5,
664            ..Default::default()
665        });
666
667        let r1 = engine.search(&conn, &embedder, "q", 2).unwrap();
668        let r2 = engine.search(&conn, &embedder, "q", 2).unwrap();
669
670        // Both searches return the same top entity
671        assert_eq!(
672            r1[0].entity.id, r2[0].entity.id,
673            "cache hit should return identical results"
674        );
675    }
676
677    #[test]
678    fn test_cache_stores_checksum() {
679        let dim = 4;
680        let (conn, _ids) = setup(dim);
681
682        let query = {
683            let mut q = vec![0.0f32; dim];
684            q[0] = 1.0;
685            q
686        };
687        let embedder = FixedEmbedder(query);
688        let engine = RagEngine::new(RagConfig {
689            vector_dimension: dim,
690            top_k_candidates: 10,
691            top_k_rerank: 5,
692            ..Default::default()
693        });
694
695        engine.search(&conn, &embedder, "q", 2).unwrap();
696
697        // Both count and checksum must be stored.
698        let (count, checksum): (i64, i64) = conn
699            .query_row(
700                "SELECT vector_count, vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
701                [],
702                |r| Ok((r.get(0)?, r.get(1)?)),
703            )
704            .unwrap();
705        assert_eq!(count, 3);
706        // SUM of entity_ids for 3 autoincrement entities must be > 0.
707        assert!(checksum > 0, "checksum should reflect entity_id sum");
708    }
709
710    #[test]
711    fn test_cache_invalidated_on_same_count_different_entity() {
712        // Verifies that swapping one vector (same count, different entity_id)
713        // causes a cache miss via the checksum.
714        let dim = 4;
715        let (conn, ids) = setup(dim); // 3 vectors: entity_ids ids[0], ids[1], ids[2]
716
717        let query = {
718            let mut q = vec![0.0f32; dim];
719            q[0] = 1.0;
720            q
721        };
722        let embedder = FixedEmbedder(query);
723        let engine = RagEngine::new(RagConfig {
724            vector_dimension: dim,
725            top_k_candidates: 10,
726            top_k_rerank: 5,
727            ..Default::default()
728        });
729
730        // First search — writes cache (count=3, checksum=ids[0]+ids[1]+ids[2]).
731        engine.search(&conn, &embedder, "q", 2).unwrap();
732
733        let checksum_before: i64 = conn
734            .query_row(
735                "SELECT vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
736                [],
737                |r| r.get(0),
738            )
739            .unwrap();
740
741        // Delete one vector and insert a brand-new one (autoincrement → higher id).
742        conn.execute("DELETE FROM kg_vectors WHERE entity_id = ?1", [ids[2]])
743            .unwrap();
744        let e_new = crate::graph::entity::insert_entity(
745            &conn,
746            &crate::graph::entity::Entity::new("doc", "Doc Swap"),
747        )
748        .unwrap();
749        let store = VectorStore::new();
750        let mut v_new = vec![0.0f32; dim];
751        v_new[3] = 1.0;
752        store.insert_vector(&conn, e_new, v_new).unwrap();
753        // count is still 3, but checksum changed because e_new.id != ids[2].
754
755        // Second search — must detect checksum mismatch → rebuild.
756        engine.search(&conn, &embedder, "q", 2).unwrap();
757
758        let (count_after, checksum_after): (i64, i64) = conn
759            .query_row(
760                "SELECT vector_count, vectors_checksum FROM kg_turboquant_cache WHERE id = 1",
761                [],
762                |r| Ok((r.get(0)?, r.get(1)?)),
763            )
764            .unwrap();
765        assert_eq!(count_after, 3, "vector count should still be 3 after swap");
766        assert_ne!(
767            checksum_after, checksum_before,
768            "checksum must change after swapping one vector"
769        );
770    }
771
772    #[test]
773    fn test_cache_invalidated_after_new_vector() {
774        let dim = 4;
775        let (conn, _ids) = setup(dim);
776
777        let mut query = vec![0.0f32; dim];
778        query[0] = 1.0;
779        let embedder = FixedEmbedder(query);
780        let engine = RagEngine::new(RagConfig {
781            vector_dimension: dim,
782            top_k_candidates: 10,
783            top_k_rerank: 5,
784            ..Default::default()
785        });
786
787        // First search — writes cache with vector_count = 3
788        engine.search(&conn, &embedder, "q", 2).unwrap();
789
790        let cached_count_before: i64 = conn
791            .query_row(
792                "SELECT vector_count FROM kg_turboquant_cache WHERE id = 1",
793                [],
794                |r| r.get(0),
795            )
796            .unwrap();
797        assert_eq!(cached_count_before, 3);
798
799        // Add a 4th vector
800        let e4 = crate::graph::entity::insert_entity(
801            &conn,
802            &crate::graph::entity::Entity::new("doc", "Doc D"),
803        )
804        .unwrap();
805        let store = VectorStore::new();
806        let mut v4 = vec![0.0f32; dim];
807        v4[2] = 1.0;
808        store.insert_vector(&conn, e4, v4).unwrap();
809
810        // Second search — must rebuild and update cache to vector_count = 4
811        engine.search(&conn, &embedder, "q", 2).unwrap();
812
813        let cached_count_after: i64 = conn
814            .query_row(
815                "SELECT vector_count FROM kg_turboquant_cache WHERE id = 1",
816                [],
817                |r| r.get(0),
818            )
819            .unwrap();
820        assert_eq!(
821            cached_count_after, 4,
822            "cache should be rebuilt after new vector added"
823        );
824    }
825}