sqlite_knowledge_graph/graph/
ripple.rs1use crate::error::Result;
8use crate::vector::confidence::now_unix;
9use rusqlite::Connection;
10use std::collections::{HashSet, VecDeque};
11use tracing::debug;
12
/// How many dependency hops a ripple may travel away from its origin.
/// Entities further away than this are never penalised.
const MAX_DEPTH: usize = 2;

/// Per-hop decay factor: an entity `d` hops from the origin receives
/// `base_penalty * ATTENUATION^d`.
const ATTENUATION: f64 = 0.5;
15
16pub fn propagate(conn: &Connection, origin_id: i64, base_penalty: f64) -> Result<()> {
25 let mut queue: VecDeque<(i64, usize)> = VecDeque::new();
27 let mut visited: HashSet<i64> = HashSet::new();
28 visited.insert(origin_id);
29 queue.push_back((origin_id, 0));
30
31 while let Some((current_id, depth)) = queue.pop_front() {
32 if depth >= MAX_DEPTH {
33 continue;
34 }
35
36 let next_depth = depth + 1;
37 let actual_penalty = base_penalty * ATTENUATION.powi(next_depth as i32);
38
39 let dependents = dependents_of(conn, current_id)?;
41
42 for dep_id in dependents {
43 if visited.contains(&dep_id) {
44 continue;
45 }
46 visited.insert(dep_id);
47
48 apply_penalty(conn, dep_id, actual_penalty)?;
49 debug!(
50 dep_id,
51 depth = next_depth,
52 actual_penalty,
53 "ripple penalty applied"
54 );
55 queue.push_back((dep_id, next_depth));
56 }
57 }
58
59 Ok(())
60}
61
62pub fn add_dependency(
67 conn: &Connection,
68 source_id: i64,
69 target_id: i64,
70 dep_type: &str,
71) -> Result<()> {
72 conn.execute(
73 "INSERT INTO kg_dependencies (source_id, target_id, dep_type, created_at) \
74 VALUES (?1, ?2, ?3, ?4)",
75 rusqlite::params![source_id, target_id, dep_type, now_unix()],
76 )?;
77 Ok(())
78}
79
80fn dependents_of(conn: &Connection, target_id: i64) -> Result<Vec<i64>> {
85 let mut stmt = conn.prepare(
86 "SELECT source_id FROM kg_dependencies \
87 WHERE target_id = ?1 AND dep_type = 'depends_on'",
88 )?;
89 let ids = stmt
90 .query_map([target_id], |r| r.get::<_, i64>(0))?
91 .collect::<std::result::Result<Vec<_>, _>>()?;
92 Ok(ids)
93}
94
95fn apply_penalty(conn: &Connection, entity_id: i64, penalty: f64) -> Result<()> {
96 let old_conf: f64 = conn.query_row(
97 "SELECT COALESCE(confidence, 1.0) FROM kg_entities WHERE id = ?1",
98 [entity_id],
99 |r| r.get(0),
100 )?;
101 let raw_conf = old_conf - penalty;
102 let new_conf = raw_conf.clamp(0.0, 1.0);
103
104 conn.execute(
105 "UPDATE kg_entities SET confidence = ?1 WHERE id = ?2",
106 rusqlite::params![new_conf, entity_id],
107 )?;
108 conn.execute(
109 "INSERT INTO kg_confidence_log \
110 (entity_id, old_value, new_value, reason, created_at) \
111 VALUES (?1, ?2, ?3, 'ripple', ?4)",
112 rusqlite::params![entity_id, old_conf, new_conf, now_unix()],
113 )?;
114 Ok(())
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::ensure_schema;

    /// Fresh in-memory database with the knowledge-graph schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        ensure_schema(&conn).unwrap();
        conn
    }

    /// Inserts an entity with confidence 1.0 and returns its row id.
    fn insert_entity(conn: &Connection, name: &str) -> i64 {
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name, confidence) VALUES ('t', ?1, 1.0)",
            [name],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    fn get_confidence(conn: &Connection, id: i64) -> f64 {
        conn.query_row(
            "SELECT confidence FROM kg_entities WHERE id = ?1",
            [id],
            |r| r.get(0),
        )
        .unwrap()
    }

    /// Number of `kg_confidence_log` rows written with reason `'ripple'`.
    fn ripple_log_count(conn: &Connection) -> i64 {
        conn.query_row(
            "SELECT COUNT(*) FROM kg_confidence_log WHERE reason = 'ripple'",
            [],
            |r| r.get(0),
        )
        .unwrap()
    }

    #[test]
    fn add_dependency_inserts_row() {
        let conn = setup();
        let a = insert_entity(&conn, "A");
        let b = insert_entity(&conn, "B");
        add_dependency(&conn, a, b, "depends_on").unwrap();

        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM kg_dependencies", [], |r| r.get(0))
            .unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn ripple_penalises_direct_dependent() {
        let conn = setup();
        let b = insert_entity(&conn, "B");
        let a = insert_entity(&conn, "A");
        add_dependency(&conn, a, b, "depends_on").unwrap();

        propagate(&conn, b, 0.4).unwrap();

        // One hop: 0.4 * 0.5 = 0.2 penalty.
        let conf_a = get_confidence(&conn, a);
        assert!((conf_a - 0.8).abs() < 1e-9, "expected 0.8, got {conf_a}");
    }

    #[test]
    fn ripple_attenuates_at_depth_two() {
        let conn = setup();
        let c_node = insert_entity(&conn, "C");
        let b = insert_entity(&conn, "B");
        let a = insert_entity(&conn, "A");
        add_dependency(&conn, b, a, "depends_on").unwrap();
        add_dependency(&conn, c_node, b, "depends_on").unwrap();

        propagate(&conn, a, 0.4).unwrap();

        let conf_b = get_confidence(&conn, b);
        let conf_c = get_confidence(&conn, c_node);

        // Hop 1: 0.4 * 0.5 = 0.2; hop 2: 0.4 * 0.25 = 0.1.
        assert!((conf_b - 0.8).abs() < 1e-9, "expected 0.8, got {conf_b}");
        assert!((conf_c - 0.9).abs() < 1e-9, "expected 0.9, got {conf_c}");
    }

    #[test]
    fn ripple_logs_changes() {
        let conn = setup();
        let b = insert_entity(&conn, "B");
        let a = insert_entity(&conn, "A");
        add_dependency(&conn, a, b, "depends_on").unwrap();

        propagate(&conn, b, 0.2).unwrap();

        assert_eq!(ripple_log_count(&conn), 1);
    }

    #[test]
    fn ripple_stops_at_max_depth() {
        let conn = setup();
        let d = insert_entity(&conn, "D");
        let c_node = insert_entity(&conn, "C");
        let b = insert_entity(&conn, "B");
        let a = insert_entity(&conn, "A");

        add_dependency(&conn, b, a, "depends_on").unwrap();
        add_dependency(&conn, c_node, b, "depends_on").unwrap();
        add_dependency(&conn, d, c_node, "depends_on").unwrap();

        propagate(&conn, a, 0.4).unwrap();

        // B and C prove the ripple actually ran; without them, an unchanged D
        // would pass even if propagate did nothing at all.
        let conf_b = get_confidence(&conn, b);
        let conf_c = get_confidence(&conn, c_node);
        let conf_d = get_confidence(&conn, d);
        assert!((conf_b - 0.8).abs() < 1e-9, "expected 0.8, got {conf_b}");
        assert!((conf_c - 0.9).abs() < 1e-9, "expected 0.9, got {conf_c}");
        assert!((conf_d - 1.0).abs() < 1e-9, "D should be unaffected");
    }

    #[test]
    fn ripple_terminates_on_cycle_and_spares_origin() {
        let conn = setup();
        let a = insert_entity(&conn, "A");
        let b = insert_entity(&conn, "B");
        // A depends on B and B depends on A.
        add_dependency(&conn, a, b, "depends_on").unwrap();
        add_dependency(&conn, b, a, "depends_on").unwrap();

        propagate(&conn, b, 0.4).unwrap();

        let conf_a = get_confidence(&conn, a);
        let conf_b = get_confidence(&conn, b);
        assert!((conf_a - 0.8).abs() < 1e-9, "expected 0.8, got {conf_a}");
        assert!(
            (conf_b - 1.0).abs() < 1e-9,
            "origin must not be penalised, got {conf_b}"
        );
        assert_eq!(ripple_log_count(&conn), 1);
    }

    #[test]
    fn ripple_penalises_each_entity_once_in_diamond() {
        let conn = setup();
        let d = insert_entity(&conn, "D");
        let b = insert_entity(&conn, "B");
        let c_node = insert_entity(&conn, "C");
        let a = insert_entity(&conn, "A");
        // A depends on both B and C, which both depend on D.
        add_dependency(&conn, a, b, "depends_on").unwrap();
        add_dependency(&conn, a, c_node, "depends_on").unwrap();
        add_dependency(&conn, b, d, "depends_on").unwrap();
        add_dependency(&conn, c_node, d, "depends_on").unwrap();

        propagate(&conn, d, 0.4).unwrap();

        let conf_a = get_confidence(&conn, a);
        let conf_b = get_confidence(&conn, b);
        let conf_c = get_confidence(&conn, c_node);
        assert!((conf_b - 0.8).abs() < 1e-9, "expected 0.8, got {conf_b}");
        assert!((conf_c - 0.8).abs() < 1e-9, "expected 0.8, got {conf_c}");
        // A is reachable via two paths but must take a single hop-2 penalty.
        assert!((conf_a - 0.9).abs() < 1e-9, "expected 0.9, got {conf_a}");
        assert_eq!(ripple_log_count(&conn), 3);
    }

    #[test]
    fn apply_penalty_clamps_to_zero_when_penalty_exceeds_confidence() {
        let conn = setup();
        let a = insert_entity(&conn, "A");

        apply_penalty(&conn, a, 1.5).unwrap();

        let conf = get_confidence(&conn, a);
        assert!(conf >= 0.0, "confidence must not be negative, got {conf}");
        assert!((conf - 0.0).abs() < 1e-9, "expected 0.0, got {conf}");
    }
}
249}