sqlite_knowledge_graph/rag/
smart_retrieval.rs1use crate::error::{Error, Result};
8use crate::graph::get_entity;
9use crate::vector::confidence::{now_unix, ConfidenceEngine};
10use crate::vector::VectorStore;
11use rusqlite::Connection;
12use std::collections::HashMap;
13use tracing::debug;
14
/// Exponential decay rate applied per day elapsed since an entity's
/// `valid_until` timestamp when computing its temporal score.
const TEMPORAL_DECAY_FACTOR: f64 = 0.1;

/// Number of seconds in one day; converts a difference between second-based
/// timestamps into days.
const SECS_PER_DAY: f64 = 86_400.0;
17
/// Linear weights used to blend the four retrieval signals into one score.
///
/// The final score of a candidate is
/// `w1 * similarity + w2 * temporal + w3 * confidence + w4 * importance`.
/// The weights are used as given; nothing in this type normalises them.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RetrievalWeights {
    /// Weight of the vector-similarity signal.
    pub w1: f64,
    /// Weight of the temporal-validity signal.
    pub w2: f64,
    /// Weight of the confidence signal.
    pub w3: f64,
    /// Weight of the graph-importance (normalised in-degree) signal.
    pub w4: f64,
}

impl Default for RetrievalWeights {
    /// Similarity-dominant defaults: `0.5 / 0.2 / 0.2 / 0.1`.
    fn default() -> Self {
        Self {
            w1: 0.5,
            w2: 0.2,
            w3: 0.2,
            w4: 0.1,
        }
    }
}
41
42#[derive(Debug, Clone)]
44pub struct SmartSearchResult {
45 pub entity: crate::graph::Entity,
46 pub final_score: f64,
47 pub cosine_score: f64,
48 pub temporal_score: f64,
49 pub confidence_score: f64,
50 pub graph_importance: f64,
51}
52
53#[derive(Default)]
59pub struct SmartRetrieval {
60 pub weights: RetrievalWeights,
61}
62
63impl SmartRetrieval {
64 pub fn new(weights: RetrievalWeights) -> Self {
65 Self { weights }
66 }
67
68 pub fn set_weights(&mut self, weights: RetrievalWeights) {
69 self.weights = weights;
70 }
71
72 pub fn retrieve(
74 &self,
75 conn: &Connection,
76 query: &[f32],
77 top_k: usize,
78 ) -> Result<Vec<SmartSearchResult>> {
79 let store = VectorStore::new();
80 let pool_size = (top_k * 3).max(20);
82 let candidates = store.search_vectors(conn, query.to_vec(), pool_size)?;
83
84 if candidates.is_empty() {
85 return Ok(vec![]);
86 }
87
88 let ids: Vec<i64> = candidates.iter().map(|c| c.entity_id).collect();
89 let indegrees = load_indegrees(conn, &ids)?;
90 let max_indegree = indegrees.values().copied().fold(0u32, u32::max);
91 let now = now_unix();
92
93 let mut results = Vec::with_capacity(candidates.len());
94 for candidate in &candidates {
95 let eid = candidate.entity_id;
96 let cosine = candidate.similarity as f64;
97 let temporal = temporal_validity(conn, eid, now)?;
98 let conf = ConfidenceEngine::default().get_confidence(conn, eid)?;
99 let importance = if max_indegree > 0 {
100 *indegrees.get(&eid).unwrap_or(&0) as f64 / max_indegree as f64
101 } else {
102 0.0
103 };
104
105 let final_score = self.weights.w1 * cosine
106 + self.weights.w2 * temporal
107 + self.weights.w3 * conf
108 + self.weights.w4 * importance;
109
110 let entity = get_entity(conn, eid)?;
111 results.push(SmartSearchResult {
112 entity,
113 final_score,
114 cosine_score: cosine,
115 temporal_score: temporal,
116 confidence_score: conf,
117 graph_importance: importance,
118 });
119 }
120
121 results.sort_by(|a, b| {
122 b.final_score
123 .partial_cmp(&a.final_score)
124 .unwrap_or(std::cmp::Ordering::Equal)
125 });
126 results.truncate(top_k);
127
128 debug!(
129 top_k,
130 found = results.len(),
131 "four-signal retrieval complete"
132 );
133 Ok(results)
134 }
135}
136
137fn load_indegrees(conn: &Connection, ids: &[i64]) -> Result<HashMap<i64, u32>> {
142 if ids.is_empty() {
143 return Ok(HashMap::new());
144 }
145 let placeholders = ids
147 .iter()
148 .enumerate()
149 .map(|(i, _)| format!("?{}", i + 1))
150 .collect::<Vec<_>>()
151 .join(", ");
152 let sql = format!(
153 "SELECT target_id, COUNT(*) FROM kg_dependencies WHERE target_id IN ({placeholders}) GROUP BY target_id"
154 );
155 let mut stmt = conn.prepare(&sql)?;
156 let params = rusqlite::params_from_iter(ids.iter());
157 let mut map: HashMap<i64, u32> = ids.iter().map(|&id| (id, 0)).collect();
158 let rows = stmt.query_map(params, |r| Ok((r.get::<_, i64>(0)?, r.get::<_, u32>(1)?)))?;
159 for row in rows {
160 let (id, count) = row?;
161 map.insert(id, count);
162 }
163 Ok(map)
164}
165
166fn temporal_validity(conn: &Connection, entity_id: i64, now: i64) -> Result<f64> {
168 let (valid_from, valid_until): (Option<i64>, Option<i64>) = conn
169 .query_row(
170 "SELECT valid_from, valid_until FROM kg_entities WHERE id = ?1",
171 [entity_id],
172 |r| Ok((r.get(0)?, r.get(1)?)),
173 )
174 .map_err(|e| match e {
175 rusqlite::Error::QueryReturnedNoRows => Error::EntityNotFound(entity_id),
176 other => Error::SQLite(other),
177 })?;
178
179 if let Some(from) = valid_from {
180 if now < from {
181 return Ok(0.0); }
183 }
184 if let Some(until) = valid_until {
185 if now > until {
186 let days_over = (now - until) as f64 / SECS_PER_DAY;
187 return Ok((-TEMPORAL_DECAY_FACTOR * days_over).exp());
188 }
189 }
190 Ok(1.0)
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::ensure_schema;

    /// Opens an in-memory database with the schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        ensure_schema(&conn).unwrap();
        conn
    }

    /// Inserts an entity of type `'t'` named `name`, stores `vec` as its
    /// vector, and returns the new entity id.
    fn add_entity_with_vector(conn: &Connection, name: &str, vec: &[f32]) -> i64 {
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('t', ?1)",
            [name],
        )
        .unwrap();
        let id = conn.last_insert_rowid();
        let store = VectorStore::new();
        store.insert_vector(conn, id, vec.to_vec()).unwrap();
        id
    }

    /// With three stored entities and `top_k = 2`, exactly two results come
    /// back and the first is not markedly less similar than the second.
    #[test]
    fn retrieves_top_k_results() {
        let conn = setup();
        add_entity_with_vector(&conn, "A", &[1.0, 0.0, 0.0]);
        add_entity_with_vector(&conn, "B", &[0.9, 0.1, 0.0]);
        add_entity_with_vector(&conn, "C", &[0.0, 0.0, 1.0]);

        let sr = SmartRetrieval::default();
        let results = sr.retrieve(&conn, &[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].cosine_score >= results[1].cosine_score - 0.1);
    }

    /// An entity whose `valid_until` lies a year in the past decays to a
    /// near-zero temporal score.
    #[test]
    fn temporal_past_window_decays_score() {
        let conn = setup();
        let id = add_entity_with_vector(&conn, "old", &[1.0, 0.0]);
        // Read the clock once so the stored window and the evaluation time
        // are derived from the same instant.
        let now = now_unix();
        let past = now - 365 * 86400;
        conn.execute(
            "UPDATE kg_entities SET valid_until = ?1 WHERE id = ?2",
            rusqlite::params![past, id],
        )
        .unwrap();

        let score = temporal_validity(&conn, id, now).unwrap();
        assert!(
            score < 0.01,
            "expired entity should have near-zero temporal score"
        );
    }

    /// An entity whose `valid_from` lies in the future scores exactly zero.
    #[test]
    fn temporal_future_window_returns_zero() {
        let conn = setup();
        let id = add_entity_with_vector(&conn, "future", &[1.0, 0.0]);
        let now = now_unix();
        let future = now + 86400;
        conn.execute(
            "UPDATE kg_entities SET valid_from = ?1 WHERE id = ?2",
            rusqlite::params![future, id],
        )
        .unwrap();

        let score = temporal_validity(&conn, id, now).unwrap();
        assert_eq!(
            score, 0.0,
            "not-yet-valid entity should have zero temporal score"
        );
    }

    /// With the confidence weight dominant, the less similar but more
    /// confident entity B outranks A.
    #[test]
    fn configurable_weights_affect_ranking() {
        let conn = setup();
        let _id_a = add_entity_with_vector(&conn, "A", &[1.0, 0.0]);
        let id_b = add_entity_with_vector(&conn, "B", &[0.5, 0.5]);

        // NOTE(review): relies on the confidence engine reflecting
        // `base_confidence` in the value it returns — confirm.
        conn.execute(
            "UPDATE kg_entities SET base_confidence = 2.0 WHERE id = ?1",
            [id_b],
        )
        .unwrap();

        let mut sr = SmartRetrieval::default();
        sr.set_weights(RetrievalWeights {
            w1: 0.1,
            w2: 0.1,
            w3: 0.7,
            w4: 0.1,
        });
        let results = sr.retrieve(&conn, &[1.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entity.id, Some(id_b));
    }

    /// With only the graph-importance weight active, the entity that five
    /// other entities depend on ranks first.
    #[test]
    fn graph_importance_boosts_score() {
        let conn = setup();
        let _id_a = add_entity_with_vector(&conn, "A", &[1.0, 0.0]);
        let id_b = add_entity_with_vector(&conn, "B", &[1.0, 0.0]);

        // Give B five incoming dependencies from vector-less helper entities.
        for _ in 0..5 {
            conn.execute(
                "INSERT INTO kg_entities (entity_type, name) VALUES ('dep', 'dep')",
                [],
            )
            .unwrap();
            let dep_id = conn.last_insert_rowid();
            conn.execute(
                "INSERT INTO kg_dependencies (source_id, target_id, dep_type) VALUES (?1, ?2, 'depends_on')",
                rusqlite::params![dep_id, id_b],
            )
            .unwrap();
        }

        let sr = SmartRetrieval::new(RetrievalWeights {
            w1: 0.0,
            w2: 0.0,
            w3: 0.0,
            w4: 1.0,
        });
        let results = sr.retrieve(&conn, &[1.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(
            results[0].entity.id,
            Some(id_b),
            "high in-degree entity should rank first"
        );
        assert!(
            results[0].graph_importance > results[1].graph_importance,
            "importance should be normalised"
        );
    }
}