Skip to main content

sqlite_knowledge_graph/rag/
smart_retrieval.rs

//! Four-signal retrieval scoring for SmartVector.
//!
//! Final score = w1·cosine + w2·temporal + w3·confidence + w4·graph_importance
//!
//! Default weights: w1 = 0.5, w2 = 0.2, w3 = 0.2, w4 = 0.1 (they sum to 1.0).
6
7use crate::error::{Error, Result};
8use crate::graph::get_entity;
9use crate::vector::confidence::{now_unix, ConfidenceEngine};
10use crate::vector::VectorStore;
11use rusqlite::Connection;
12use std::collections::HashMap;
13use tracing::debug;
14
/// Exponential decay rate, per day, applied to the temporal score once an
/// entity's `valid_until` has passed: `score = exp(-rate · days_past_until)`.
const TEMPORAL_DECAY_FACTOR: f64 = 0.1;
/// Seconds in one day; converts Unix-second differences into days.
const SECS_PER_DAY: f64 = 24.0 * 60.0 * 60.0;
17
18// ─────────────────────────────────────────────────────────────────────────────
19// Public types
20// ─────────────────────────────────────────────────────────────────────────────
21
/// Relative weights used to blend the four retrieval signals into one score.
///
/// The weights are applied exactly as given (no normalisation), so keeping
/// their sum at `1.0` keeps scores on the same scale as the defaults.
#[derive(Debug, Clone, Copy)]
pub struct RetrievalWeights {
    /// Weight of the cosine similarity between query and stored vector.
    pub w1: f64,
    /// Weight of the temporal-validity score.
    pub w2: f64,
    /// Weight of the live confidence score.
    pub w3: f64,
    /// Weight of the normalised graph importance (in-degree).
    pub w4: f64,
}

impl Default for RetrievalWeights {
    /// Cosine-heavy defaults: 0.5 / 0.2 / 0.2 / 0.1, summing to 1.0.
    fn default() -> Self {
        let (w1, w2, w3, w4) = (0.5, 0.2, 0.2, 0.1);
        Self { w1, w2, w3, w4 }
    }
}
41
42/// One result from four-signal retrieval.
43#[derive(Debug, Clone)]
44pub struct SmartSearchResult {
45    pub entity: crate::graph::Entity,
46    pub final_score: f64,
47    pub cosine_score: f64,
48    pub temporal_score: f64,
49    pub confidence_score: f64,
50    pub graph_importance: f64,
51}
52
53// ─────────────────────────────────────────────────────────────────────────────
54// SmartRetrieval
55// ─────────────────────────────────────────────────────────────────────────────
56
57/// Retrieval engine that combines four signals.
58#[derive(Default)]
59pub struct SmartRetrieval {
60    pub weights: RetrievalWeights,
61}
62
63impl SmartRetrieval {
64    pub fn new(weights: RetrievalWeights) -> Self {
65        Self { weights }
66    }
67
68    pub fn set_weights(&mut self, weights: RetrievalWeights) {
69        self.weights = weights;
70    }
71
72    /// Retrieve top-`k` entities scored by the four-signal formula.
73    pub fn retrieve(
74        &self,
75        conn: &Connection,
76        query: &[f32],
77        top_k: usize,
78    ) -> Result<Vec<SmartSearchResult>> {
79        let store = VectorStore::new();
80        // Fetch a larger candidate pool so re-ranking has enough options.
81        let pool_size = (top_k * 3).max(20);
82        let candidates = store.search_vectors(conn, query.to_vec(), pool_size)?;
83
84        if candidates.is_empty() {
85            return Ok(vec![]);
86        }
87
88        let ids: Vec<i64> = candidates.iter().map(|c| c.entity_id).collect();
89        let indegrees = load_indegrees(conn, &ids)?;
90        let max_indegree = indegrees.values().copied().fold(0u32, u32::max);
91        let now = now_unix();
92
93        let mut results = Vec::with_capacity(candidates.len());
94        for candidate in &candidates {
95            let eid = candidate.entity_id;
96            let cosine = candidate.similarity as f64;
97            let temporal = temporal_validity(conn, eid, now)?;
98            let conf = ConfidenceEngine::default().get_confidence(conn, eid)?;
99            let importance = if max_indegree > 0 {
100                *indegrees.get(&eid).unwrap_or(&0) as f64 / max_indegree as f64
101            } else {
102                0.0
103            };
104
105            let final_score = self.weights.w1 * cosine
106                + self.weights.w2 * temporal
107                + self.weights.w3 * conf
108                + self.weights.w4 * importance;
109
110            let entity = get_entity(conn, eid)?;
111            results.push(SmartSearchResult {
112                entity,
113                final_score,
114                cosine_score: cosine,
115                temporal_score: temporal,
116                confidence_score: conf,
117                graph_importance: importance,
118            });
119        }
120
121        results.sort_by(|a, b| {
122            b.final_score
123                .partial_cmp(&a.final_score)
124                .unwrap_or(std::cmp::Ordering::Equal)
125        });
126        results.truncate(top_k);
127
128        debug!(
129            top_k,
130            found = results.len(),
131            "four-signal retrieval complete"
132        );
133        Ok(results)
134    }
135}
136
137// ─────────────────────────────────────────────────────────────────────────────
138// Signal helpers
139// ─────────────────────────────────────────────────────────────────────────────
140
141fn load_indegrees(conn: &Connection, ids: &[i64]) -> Result<HashMap<i64, u32>> {
142    if ids.is_empty() {
143        return Ok(HashMap::new());
144    }
145    // Build a single grouped query instead of one SELECT per id.
146    let placeholders = ids
147        .iter()
148        .enumerate()
149        .map(|(i, _)| format!("?{}", i + 1))
150        .collect::<Vec<_>>()
151        .join(", ");
152    let sql = format!(
153        "SELECT target_id, COUNT(*) FROM kg_dependencies WHERE target_id IN ({placeholders}) GROUP BY target_id"
154    );
155    let mut stmt = conn.prepare(&sql)?;
156    let params = rusqlite::params_from_iter(ids.iter());
157    let mut map: HashMap<i64, u32> = ids.iter().map(|&id| (id, 0)).collect();
158    let rows = stmt.query_map(params, |r| Ok((r.get::<_, i64>(0)?, r.get::<_, u32>(1)?)))?;
159    for row in rows {
160        let (id, count) = row?;
161        map.insert(id, count);
162    }
163    Ok(map)
164}
165
166/// Returns a score in [0, 1] reflecting how valid the entity is right now.
167fn temporal_validity(conn: &Connection, entity_id: i64, now: i64) -> Result<f64> {
168    let (valid_from, valid_until): (Option<i64>, Option<i64>) = conn
169        .query_row(
170            "SELECT valid_from, valid_until FROM kg_entities WHERE id = ?1",
171            [entity_id],
172            |r| Ok((r.get(0)?, r.get(1)?)),
173        )
174        .map_err(|e| match e {
175            rusqlite::Error::QueryReturnedNoRows => Error::EntityNotFound(entity_id),
176            other => Error::SQLite(other),
177        })?;
178
179    if let Some(from) = valid_from {
180        if now < from {
181            return Ok(0.0); // not yet valid
182        }
183    }
184    if let Some(until) = valid_until {
185        if now > until {
186            let days_over = (now - until) as f64 / SECS_PER_DAY;
187            return Ok((-TEMPORAL_DECAY_FACTOR * days_over).exp());
188        }
189    }
190    Ok(1.0)
191}
192
193// ─────────────────────────────────────────────────────────────────────────────
194// Tests
195// ─────────────────────────────────────────────────────────────────────────────
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::ensure_schema;

    /// Opens an in-memory database with the knowledge-graph schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        ensure_schema(&conn).unwrap();
        conn
    }

    /// Inserts an entity of type `'t'` named `name`, stores `vec` as its
    /// embedding, and returns the new row id.
    fn add_entity_with_vector(conn: &Connection, name: &str, vec: &[f32]) -> i64 {
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('t', ?1)",
            [name],
        )
        .unwrap();
        let id = conn.last_insert_rowid();
        let store = VectorStore::new();
        store.insert_vector(conn, id, vec.to_vec()).unwrap();
        id
    }

    #[test]
    fn retrieves_top_k_results() {
        let conn = setup();
        let id_a = add_entity_with_vector(&conn, "A", &[1.0, 0.0, 0.0]);
        let id_b = add_entity_with_vector(&conn, "B", &[0.9, 0.1, 0.0]);
        let id_c = add_entity_with_vector(&conn, "C", &[0.0, 0.0, 1.0]);

        let sr = SmartRetrieval::default();
        let results = sr.retrieve(&conn, &[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);

        // C is orthogonal to the query and otherwise indistinguishable from A/B,
        // so the two slots must go to A and B.
        let ids: Vec<_> = results.iter().map(|r| r.entity.id).collect();
        assert!(ids.contains(&Some(id_a)));
        assert!(ids.contains(&Some(id_b)));
        assert!(!ids.contains(&Some(id_c)));
        assert!(
            results[0].final_score >= results[1].final_score,
            "results must be sorted by final_score"
        );
    }

    #[test]
    fn zero_top_k_returns_no_results() {
        let conn = setup();
        add_entity_with_vector(&conn, "A", &[1.0, 0.0]);
        let results = SmartRetrieval::default()
            .retrieve(&conn, &[1.0, 0.0], 0)
            .unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn temporal_past_window_decays_score() {
        let conn = setup();
        let id = add_entity_with_vector(&conn, "old", &[1.0, 0.0]);
        // Set valid_until 365 days in the past
        let past = now_unix() - 365 * 86400;
        conn.execute(
            "UPDATE kg_entities SET valid_until = ?1 WHERE id = ?2",
            rusqlite::params![past, id],
        )
        .unwrap();

        let score = temporal_validity(&conn, id, now_unix()).unwrap();
        assert!(
            score < 0.01,
            "expired entity should have near-zero temporal score"
        );
    }

    #[test]
    fn temporal_one_day_past_window_matches_decay_rate() {
        let conn = setup();
        let id = add_entity_with_vector(&conn, "expired", &[1.0, 0.0]);
        let now = now_unix();
        conn.execute(
            "UPDATE kg_entities SET valid_until = ?1 WHERE id = ?2",
            rusqlite::params![now - 86_400, id],
        )
        .unwrap();

        let score = temporal_validity(&conn, id, now).unwrap();
        assert!((score - (-TEMPORAL_DECAY_FACTOR).exp()).abs() < 1e-12);
    }

    #[test]
    fn temporal_inside_window_is_full_score() {
        let conn = setup();
        let id = add_entity_with_vector(&conn, "current", &[1.0, 0.0]);
        let now = now_unix();
        conn.execute(
            "UPDATE kg_entities SET valid_from = ?1, valid_until = ?2 WHERE id = ?3",
            rusqlite::params![now - 86_400, now + 86_400, id],
        )
        .unwrap();

        assert_eq!(temporal_validity(&conn, id, now).unwrap(), 1.0);
    }

    #[test]
    fn temporal_future_window_returns_zero() {
        let conn = setup();
        let id = add_entity_with_vector(&conn, "future", &[1.0, 0.0]);
        let future = now_unix() + 86400;
        conn.execute(
            "UPDATE kg_entities SET valid_from = ?1 WHERE id = ?2",
            rusqlite::params![future, id],
        )
        .unwrap();

        let score = temporal_validity(&conn, id, now_unix()).unwrap();
        assert_eq!(
            score, 0.0,
            "not-yet-valid entity should have zero temporal score"
        );
    }

    #[test]
    fn temporal_missing_entity_is_not_found() {
        let conn = setup();
        assert!(matches!(
            temporal_validity(&conn, 9_999, now_unix()),
            Err(Error::EntityNotFound(9_999))
        ));
    }

    #[test]
    fn configurable_weights_affect_ranking() {
        let conn = setup();
        let _id_a = add_entity_with_vector(&conn, "A", &[1.0, 0.0]);
        let id_b = add_entity_with_vector(&conn, "B", &[0.5, 0.5]);

        // Give B higher base_confidence so the live formula returns a higher score.
        conn.execute(
            "UPDATE kg_entities SET base_confidence = 2.0 WHERE id = ?1",
            [id_b],
        )
        .unwrap();

        // High confidence weight: B might rank above A despite lower cosine
        let mut sr = SmartRetrieval::default();
        sr.set_weights(RetrievalWeights {
            w1: 0.1,
            w2: 0.1,
            w3: 0.7,
            w4: 0.1,
        });
        let results = sr.retrieve(&conn, &[1.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        // B should now rank first due to high confidence weight
        assert_eq!(results[0].entity.id, Some(id_b));
    }

    #[test]
    fn graph_importance_boosts_score() {
        let conn = setup();
        let _id_a = add_entity_with_vector(&conn, "A", &[1.0, 0.0]);
        let id_b = add_entity_with_vector(&conn, "B", &[1.0, 0.0]);

        // Make several entities depend on B (high in-degree)
        for _ in 0..5 {
            conn.execute(
                "INSERT INTO kg_entities (entity_type, name) VALUES ('dep', 'dep')",
                [],
            )
            .unwrap();
            let dep_id = conn.last_insert_rowid();
            conn.execute(
                "INSERT INTO kg_dependencies (source_id, target_id, dep_type) VALUES (?1, ?2, 'depends_on')",
                rusqlite::params![dep_id, id_b],
            )
            .unwrap();
        }

        let sr = SmartRetrieval::new(RetrievalWeights {
            w1: 0.0,
            w2: 0.0,
            w3: 0.0,
            w4: 1.0, // only graph importance
        });
        let results = sr.retrieve(&conn, &[1.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        // B has higher in-degree so should rank first
        assert_eq!(
            results[0].entity.id,
            Some(id_b),
            "high in-degree entity should rank first"
        );
        assert!(
            results[0].graph_importance > results[1].graph_importance,
            "importance should be normalised"
        );
    }
}