sqlite_knowledge_graph/vector/
confidence.rs1use crate::error::{Error, Result};
12use rusqlite::Connection;
13use tracing::debug;
14
/// Number of seconds in one day; converts Unix-second differences into the day
/// units in which decay rates (`lambda`) are applied.
const SECS_PER_DAY: f64 = 24.0 * 60.0 * 60.0;
16
/// Tunable parameters for the confidence formula evaluated by `ConfidenceEngine`.
#[derive(Debug, Clone)]
pub struct ConfidenceParams {
    /// Exponential decay rate per day; used for entities whose stored
    /// `decay_rate` is NULL.
    pub lambda: f64,
    /// Weight applied to the `ln(1 + access_count)` reinforcement term.
    pub access_bonus: f64,
}

impl Default for ConfidenceParams {
    /// `lambda = 0.05` per day (a half-life of about 13.9 days) and
    /// `access_bonus = 0.1`.
    fn default() -> Self {
        ConfidenceParams {
            access_bonus: 0.1,
            lambda: 0.05,
        }
    }
}
38
39#[derive(Default)]
41pub struct ConfidenceEngine {
42 pub params: ConfidenceParams,
43}
44
45impl ConfidenceEngine {
46 pub fn new(params: ConfidenceParams) -> Self {
47 Self { params }
48 }
49
50 pub fn compute(
52 &self,
53 base: f64,
54 lambda: f64,
55 elapsed_days: f64,
56 access_count: i64,
57 feedback_sum: f64,
58 ) -> f64 {
59 let fb = feedback_sum.clamp(-1.0, 1.0);
60 base * (-lambda * elapsed_days).exp()
61 + self.params.access_bonus * (1.0 + access_count as f64).ln()
62 + fb
63 }
64
65 pub fn get_confidence(&self, conn: &Connection, entity_id: i64) -> Result<f64> {
67 let (base, stored_lambda, created_at, access_count) = conn
68 .query_row(
69 "SELECT \
70 COALESCE(base_confidence, 1.0), \
71 decay_rate, \
72 COALESCE(created_at, 0), \
73 COALESCE(access_count, 0) \
74 FROM kg_entities WHERE id = ?1",
75 [entity_id],
76 |r| {
77 Ok((
78 r.get::<_, f64>(0)?,
79 r.get::<_, Option<f64>>(1)?,
80 r.get::<_, i64>(2)?,
81 r.get::<_, i64>(3)?,
82 ))
83 },
84 )
85 .map_err(|e| match e {
86 rusqlite::Error::QueryReturnedNoRows => Error::EntityNotFound(entity_id),
87 other => Error::SQLite(other),
88 })?;
89 let lambda = stored_lambda.unwrap_or(self.params.lambda);
90
91 let elapsed_days = (now_unix() - created_at).max(0) as f64 / SECS_PER_DAY;
92 let feedback_sum = feedback_sum_for(conn, entity_id)?;
93 let conf = self.compute(base, lambda, elapsed_days, access_count, feedback_sum);
94
95 debug!(entity_id, elapsed_days, conf, "live confidence computed");
96 Ok(conf)
97 }
98
99 pub fn update_confidence(
101 &self,
102 conn: &Connection,
103 entity_id: i64,
104 feedback: f64,
105 ) -> Result<f64> {
106 let old_conf = self.get_confidence(conn, entity_id)?;
107 let ts = now_unix();
108
109 let tx = conn.unchecked_transaction()?;
110
111 tx.execute(
113 "INSERT INTO kg_confidence_log \
114 (entity_id, old_value, new_value, reason, created_at) \
115 VALUES (?1, ?2, ?3, 'feedback', ?4)",
116 rusqlite::params![entity_id, old_conf, old_conf + feedback, ts],
117 )?;
118 let log_rowid = tx.last_insert_rowid();
119
120 let new_conf = self.get_confidence(&tx, entity_id)?;
122
123 tx.execute(
125 "UPDATE kg_confidence_log SET new_value = ?1 WHERE rowid = ?2",
126 rusqlite::params![new_conf, log_rowid],
127 )?;
128
129 tx.execute(
130 "UPDATE kg_entities SET confidence = ?1 WHERE id = ?2",
131 rusqlite::params![new_conf, entity_id],
132 )?;
133
134 tx.commit()?;
135
136 debug!(
137 entity_id,
138 old_conf, new_conf, feedback, "confidence updated"
139 );
140 Ok(new_conf)
141 }
142}
143
144fn feedback_sum_for(conn: &Connection, entity_id: i64) -> Result<f64> {
150 let sum: f64 = conn.query_row(
151 "SELECT COALESCE(SUM(new_value - old_value), 0.0) \
152 FROM kg_confidence_log \
153 WHERE entity_id = ?1 AND reason = 'feedback'",
154 [entity_id],
155 |r| r.get(0),
156 )?;
157 Ok(sum.clamp(-1.0, 1.0))
158}
159
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Never panics. A clock set before 1970 yields a negative value instead of
/// aborting the process. A duration that does not fit in `i64` saturates to
/// `i64::MAX` (or `i64::MIN` before the epoch).
pub(crate) fn now_unix() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since) => i64::try_from(since.as_secs()).unwrap_or(i64::MAX),
        Err(before) => i64::try_from(before.duration().as_secs()).map_or(i64::MIN, |s| -s),
    }
}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::ensure_schema;

    /// Opens a fresh in-memory database with the crate schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        ensure_schema(&conn).unwrap();
        conn
    }

    /// Inserts a test entity with the given base confidence and decay rate; returns its id.
    fn insert_entity(conn: &Connection, base: f64, lambda: f64) -> i64 {
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name, base_confidence, decay_rate) \
             VALUES ('test', 'E', ?1, ?2)",
            rusqlite::params![base, lambda],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    #[test]
    fn fresh_entity_confidence_is_base() {
        let conn = setup();
        let id = insert_entity(&conn, 1.0, 0.05);
        let engine = ConfidenceEngine::default();
        let conf = engine.get_confidence(&conn, id).unwrap();
        assert!((conf - 1.0).abs() < 0.01, "expected ~1.0, got {conf}");
    }

    #[test]
    fn compute_matches_formula() {
        let engine = ConfidenceEngine::default();
        let got = engine.compute(0.5, 0.1, 10.0, 3, 0.25);
        let expected = 0.5 * (-1.0f64).exp() + 0.1 * 4.0f64.ln() + 0.25;
        assert!((got - expected).abs() < 1e-12, "expected {expected}, got {got}");
    }

    #[test]
    fn confidence_decays_over_time() {
        let engine = ConfidenceEngine::default();
        let conf_now = engine.compute(1.0, 0.05, 0.0, 0, 0.0);
        let conf_30d = engine.compute(1.0, 0.05, 30.0, 0, 0.0);
        assert!(conf_30d < conf_now, "confidence should decay over time");
        assert!(conf_30d > 0.0, "confidence should stay positive");
    }

    #[test]
    fn access_reinforces_confidence() {
        let engine = ConfidenceEngine::default();
        let low = engine.compute(1.0, 0.05, 30.0, 0, 0.0);
        let high = engine.compute(1.0, 0.05, 30.0, 10, 0.0);
        assert!(high > low, "access should reinforce confidence");
    }

    #[test]
    fn feedback_adjusts_confidence() {
        let conn = setup();
        let id = insert_entity(&conn, 0.8, 0.0);
        let engine = ConfidenceEngine::default();

        let before = engine.get_confidence(&conn, id).unwrap();
        assert!((before - 0.8).abs() < 0.01, "expected ~0.8, got {before}");

        let after = engine.update_confidence(&conn, id, -0.2).unwrap();
        assert!((after - 0.6).abs() < 0.01, "expected ~0.6, got {after}");
    }

    #[test]
    fn feedback_sum_bounded() {
        let engine = ConfidenceEngine::default();
        // No decay, no accesses: a -5.0 feedback sum is capped at -1.0.
        let c = engine.compute(1.0, 0.0, 0.0, 0, -5.0);
        assert!(c.abs() < 1e-9, "expected 0.0, got {c}");
    }

    #[test]
    fn feedback_saturates_at_cap() {
        let conn = setup();
        let id = insert_entity(&conn, 1.0, 0.0);
        let engine = ConfidenceEngine::default();

        let c1 = engine.update_confidence(&conn, id, -0.8).unwrap();
        assert!((c1 - 0.2).abs() < 1e-9, "expected ~0.2, got {c1}");

        // Pushes past the -1.0 cap; only the effective change is recorded.
        let c2 = engine.update_confidence(&conn, id, -0.8).unwrap();
        assert!(c2.abs() < 1e-9, "expected ~0.0, got {c2}");

        // Recovery starts from the cap, not from an uncapped -1.6.
        let c3 = engine.update_confidence(&conn, id, 0.5).unwrap();
        assert!((c3 - 0.5).abs() < 1e-9, "expected ~0.5, got {c3}");
    }

    #[test]
    fn update_caches_confidence_on_entity() {
        let conn = setup();
        let id = insert_entity(&conn, 1.0, 0.0);
        let engine = ConfidenceEngine::default();

        let returned = engine.update_confidence(&conn, id, -0.3).unwrap();
        let cached: f64 = conn
            .query_row(
                "SELECT confidence FROM kg_entities WHERE id = ?1",
                [id],
                |r| r.get(0),
            )
            .unwrap();
        assert!((cached - returned).abs() < 1e-12, "cached {cached} != returned {returned}");
    }

    #[test]
    fn missing_entity_is_reported_and_not_logged() {
        let conn = setup();
        let engine = ConfidenceEngine::default();

        assert!(matches!(
            engine.get_confidence(&conn, 9_999),
            Err(Error::EntityNotFound(9_999))
        ));
        assert!(engine.update_confidence(&conn, 9_999, -0.1).is_err());

        let logged: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM kg_confidence_log",
                rusqlite::params![],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(logged, 0, "failed update must not leave a log row");
    }

    #[test]
    fn change_logged_to_confidence_log() {
        let conn = setup();
        let id = insert_entity(&conn, 1.0, 0.0);
        let engine = ConfidenceEngine::default();
        engine.update_confidence(&conn, id, -0.1).unwrap();

        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM kg_confidence_log WHERE entity_id = ?1 AND reason = 'feedback'",
                [id],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1, "feedback entry should be logged");
    }
}