Skip to main content

sqlite_knowledge_graph/graph/
ripple.rs

1//! Ripple propagation of confidence penalties along dependency edges.
2//!
3//! When an entity becomes stale or its confidence drops, dependent entities
4//! receive an attenuated penalty via BFS up to `MAX_DEPTH = 2` hops.
5//! Every change is appended to `kg_confidence_log`.
6
7use crate::error::Result;
8use crate::vector::confidence::now_unix;
9use rusqlite::Connection;
10use std::collections::{HashSet, VecDeque};
11use tracing::debug;
12
/// Maximum number of dependency hops a ripple travels away from its origin.
const MAX_DEPTH: usize = 2;
/// Per-hop decay factor: an entity at hop *h* receives
/// `base_penalty * ATTENUATION^h`.
const ATTENUATION: f64 = 1.0 / 2.0;
15
16// ─────────────────────────────────────────────────────────────────────────────
17// Public API
18// ─────────────────────────────────────────────────────────────────────────────
19
20/// Propagate a confidence penalty from `origin_id` to all entities that
21/// transitively depend on it, up to `MAX_DEPTH` hops.
22///
23/// Penalty at hop *h* = `base_penalty * ATTENUATION^h`.
24pub fn propagate(conn: &Connection, origin_id: i64, base_penalty: f64) -> Result<()> {
25    // BFS queue: (entity_id, hop_depth)
26    let mut queue: VecDeque<(i64, usize)> = VecDeque::new();
27    let mut visited: HashSet<i64> = HashSet::new();
28    visited.insert(origin_id);
29    queue.push_back((origin_id, 0));
30
31    while let Some((current_id, depth)) = queue.pop_front() {
32        if depth >= MAX_DEPTH {
33            continue;
34        }
35
36        let next_depth = depth + 1;
37        let actual_penalty = base_penalty * ATTENUATION.powi(next_depth as i32);
38
39        // Find entities whose confidence depends_on current_id
40        let dependents = dependents_of(conn, current_id)?;
41
42        for dep_id in dependents {
43            if visited.contains(&dep_id) {
44                continue;
45            }
46            visited.insert(dep_id);
47
48            apply_penalty(conn, dep_id, actual_penalty)?;
49            debug!(
50                dep_id,
51                depth = next_depth,
52                actual_penalty,
53                "ripple penalty applied"
54            );
55            queue.push_back((dep_id, next_depth));
56        }
57    }
58
59    Ok(())
60}
61
62/// Insert a dependency edge: `source_id` depends on `target_id`.
63///
64/// `dep_type` should be one of: `depends_on`, `depended_by`, `supersedes`,
65/// `contradicts`.
66pub fn add_dependency(
67    conn: &Connection,
68    source_id: i64,
69    target_id: i64,
70    dep_type: &str,
71) -> Result<()> {
72    conn.execute(
73        "INSERT INTO kg_dependencies (source_id, target_id, dep_type, created_at) \
74         VALUES (?1, ?2, ?3, ?4)",
75        rusqlite::params![source_id, target_id, dep_type, now_unix()],
76    )?;
77    Ok(())
78}
79
80// ─────────────────────────────────────────────────────────────────────────────
81// Helpers
82// ─────────────────────────────────────────────────────────────────────────────
83
84fn dependents_of(conn: &Connection, target_id: i64) -> Result<Vec<i64>> {
85    let mut stmt = conn.prepare(
86        "SELECT source_id FROM kg_dependencies \
87         WHERE target_id = ?1 AND dep_type = 'depends_on'",
88    )?;
89    let ids = stmt
90        .query_map([target_id], |r| r.get::<_, i64>(0))?
91        .collect::<std::result::Result<Vec<_>, _>>()?;
92    Ok(ids)
93}
94
95fn apply_penalty(conn: &Connection, entity_id: i64, penalty: f64) -> Result<()> {
96    let old_conf: f64 = conn.query_row(
97        "SELECT COALESCE(confidence, 1.0) FROM kg_entities WHERE id = ?1",
98        [entity_id],
99        |r| r.get(0),
100    )?;
101    let raw_conf = old_conf - penalty;
102    let new_conf = raw_conf.clamp(0.0, 1.0);
103
104    conn.execute(
105        "UPDATE kg_entities SET confidence = ?1 WHERE id = ?2",
106        rusqlite::params![new_conf, entity_id],
107    )?;
108    conn.execute(
109        "INSERT INTO kg_confidence_log \
110         (entity_id, old_value, new_value, reason, created_at) \
111         VALUES (?1, ?2, ?3, 'ripple', ?4)",
112        rusqlite::params![entity_id, old_conf, new_conf, now_unix()],
113    )?;
114    Ok(())
115}
116
117// ─────────────────────────────────────────────────────────────────────────────
118// Tests
119// ─────────────────────────────────────────────────────────────────────────────
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::ensure_schema;

    /// Fresh in-memory database with the knowledge-graph schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        ensure_schema(&conn).unwrap();
        conn
    }

    /// Insert an entity with confidence 1.0 and return its row id.
    fn insert_entity(conn: &Connection, name: &str) -> i64 {
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name, confidence) VALUES ('t', ?1, 1.0)",
            [name],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    fn get_confidence(conn: &Connection, id: i64) -> f64 {
        conn.query_row(
            "SELECT confidence FROM kg_entities WHERE id = ?1",
            [id],
            |r| r.get(0),
        )
        .unwrap()
    }

    #[test]
    fn add_dependency_inserts_row() {
        let conn = setup();
        let a = insert_entity(&conn, "A");
        let b = insert_entity(&conn, "B");
        add_dependency(&conn, a, b, "depends_on").unwrap();

        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM kg_dependencies", [], |r| r.get(0))
            .unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn ripple_penalises_direct_dependent() {
        let conn = setup();
        let b = insert_entity(&conn, "B");
        let a = insert_entity(&conn, "A");
        add_dependency(&conn, a, b, "depends_on").unwrap(); // A depends on B

        propagate(&conn, b, 0.4).unwrap();

        // A should have been penalised at depth 1: penalty = 0.4 * 0.5 = 0.2
        let conf_a = get_confidence(&conn, a);
        assert!((conf_a - 0.8).abs() < 1e-9, "expected 0.8, got {conf_a}");
    }

    #[test]
    fn ripple_attenuates_at_depth_two() {
        let conn = setup();
        let c_node = insert_entity(&conn, "C"); // C depends on B
        let b = insert_entity(&conn, "B"); // B depends on A
        let a = insert_entity(&conn, "A"); // origin: A becomes stale

        add_dependency(&conn, b, a, "depends_on").unwrap();
        add_dependency(&conn, c_node, b, "depends_on").unwrap();

        propagate(&conn, a, 0.4).unwrap();

        let conf_b = get_confidence(&conn, b);
        let conf_c = get_confidence(&conn, c_node);

        // depth 1: 0.4 * 0.5 = 0.2 → conf_b = 0.8
        assert!((conf_b - 0.8).abs() < 1e-9, "expected 0.8, got {conf_b}");
        // depth 2: 0.4 * 0.25 = 0.1 → conf_c = 0.9
        assert!((conf_c - 0.9).abs() < 1e-9, "expected 0.9, got {conf_c}");
    }

    #[test]
    fn ripple_logs_changes() {
        let conn = setup();
        let b = insert_entity(&conn, "B");
        let a = insert_entity(&conn, "A");
        add_dependency(&conn, a, b, "depends_on").unwrap();

        propagate(&conn, b, 0.2).unwrap();

        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM kg_confidence_log WHERE reason = 'ripple'",
                [],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(count, 1);
    }

    #[test]
    fn ripple_stops_at_max_depth() {
        let conn = setup();
        // Chain: D → C → B → A  (A is origin)
        let d = insert_entity(&conn, "D");
        let c_node = insert_entity(&conn, "C");
        let b = insert_entity(&conn, "B");
        let a = insert_entity(&conn, "A");

        add_dependency(&conn, b, a, "depends_on").unwrap();
        add_dependency(&conn, c_node, b, "depends_on").unwrap();
        add_dependency(&conn, d, c_node, "depends_on").unwrap();

        propagate(&conn, a, 0.4).unwrap();

        // D is at depth 3 — should NOT be penalised
        let conf_d = get_confidence(&conn, d);
        assert!((conf_d - 1.0).abs() < 1e-9, "D should be unaffected");
    }

    #[test]
    fn ripple_leaves_origin_untouched_and_terminates_on_cycle() {
        let conn = setup();
        let a = insert_entity(&conn, "A");
        let b = insert_entity(&conn, "B");
        // A depends on B and B depends on A: a cycle through the origin.
        add_dependency(&conn, a, b, "depends_on").unwrap();
        add_dependency(&conn, b, a, "depends_on").unwrap();

        propagate(&conn, b, 0.4).unwrap();

        let conf_a = get_confidence(&conn, a);
        let conf_b = get_confidence(&conn, b);
        assert!((conf_a - 0.8).abs() < 1e-9, "expected 0.8, got {conf_a}");
        assert!(
            (conf_b - 1.0).abs() < 1e-9,
            "origin must be unchanged, got {conf_b}"
        );
    }

    #[test]
    fn ripple_penalises_diamond_dependent_once() {
        let conn = setup();
        let a = insert_entity(&conn, "A"); // origin
        let b = insert_entity(&conn, "B"); // B depends on A
        let c_node = insert_entity(&conn, "C"); // C depends on A
        let d = insert_entity(&conn, "D"); // D depends on both B and C

        add_dependency(&conn, b, a, "depends_on").unwrap();
        add_dependency(&conn, c_node, a, "depends_on").unwrap();
        add_dependency(&conn, d, b, "depends_on").unwrap();
        add_dependency(&conn, d, c_node, "depends_on").unwrap();

        propagate(&conn, a, 0.4).unwrap();

        // D is reachable twice at depth 2 but must lose 0.4 * 0.25 = 0.1 only once.
        let conf_d = get_confidence(&conn, d);
        assert!((conf_d - 0.9).abs() < 1e-9, "expected 0.9, got {conf_d}");

        let d_logs: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM kg_confidence_log \
                 WHERE entity_id = ?1 AND reason = 'ripple'",
                [d],
                |r| r.get(0),
            )
            .unwrap();
        assert_eq!(d_logs, 1);
    }

    #[test]
    fn ripple_ignores_non_depends_on_edges() {
        let conn = setup();
        let b = insert_entity(&conn, "B");
        let a = insert_entity(&conn, "A");
        add_dependency(&conn, a, b, "supersedes").unwrap();

        propagate(&conn, b, 0.4).unwrap();

        let conf_a = get_confidence(&conn, a);
        assert!((conf_a - 1.0).abs() < 1e-9, "expected 1.0, got {conf_a}");
    }

    #[test]
    fn apply_penalty_clamps_to_zero_when_penalty_exceeds_confidence() {
        let conn = setup();
        let a = insert_entity(&conn, "A");

        // penalty larger than the starting confidence of 1.0
        apply_penalty(&conn, a, 1.5).unwrap();

        let conf = get_confidence(&conn, a);
        assert!(conf >= 0.0, "confidence must not be negative, got {conf}");
        assert!((conf - 0.0).abs() < 1e-9, "expected 0.0, got {conf}");
    }
}
249}