Skip to main content

sqlite_knowledge_graph/version/
merge.rs

1//! Version merge operations — union and intersection strategies.
2
3use rusqlite::params;
4
5use super::store;
6use super::MergeStrategy;
7use crate::error::{Error, Result};
8
9/// Merge two or more versions into a new version.
10///
11/// The entire merge — new-version creation, `is_merged` flag, and all validity
12/// updates — runs in one transaction, so a failure leaves no half-built version.
13pub fn version_merge(
14    conn: &rusqlite::Connection,
15    source_ids: &[i64],
16    target_name: &str,
17    strategy: MergeStrategy,
18) -> Result<i64> {
19    if source_ids.len() < 2 {
20        return Err(Error::InvalidMerge(
21            "merge requires at least 2 source versions".to_string(),
22        ));
23    }
24
25    // Combine source slots into one mask, validating each version exists.
26    let mut source_mask: i64 = 0;
27    for &sid in source_ids {
28        source_mask |= store::version_bit_for(conn, sid)?;
29    }
30
31    let tx = conn.unchecked_transaction()?;
32
33    let new_id = store::create_version(
34        &tx,
35        target_name,
36        "main",
37        Some(source_ids[0]),
38        Some(&format!("merge of {:?}", source_ids)),
39    )?;
40    tx.execute(
41        "UPDATE kg_versions SET is_merged = 1 WHERE id = ?1",
42        [new_id],
43    )?;
44
45    let new_bit = store::version_bit_for(&tx, new_id)?;
46    match strategy {
47        MergeStrategy::Union => apply_merge(&tx, new_bit, source_mask, MergeStrategy::Union)?,
48        MergeStrategy::Intersection => {
49            apply_merge(&tx, new_bit, source_mask, MergeStrategy::Intersection)?
50        }
51    }
52
53    tx.commit()?;
54    Ok(new_id)
55}
56
57/// Set `new_bit` on every entity and relation that matches the strategy:
58/// - Union: validity overlaps ANY source bit → `(validity & mask) != 0`
59/// - Intersection: validity covers ALL source bits → `(validity & mask) = mask`
60fn apply_merge(
61    conn: &rusqlite::Connection,
62    new_bit: i64,
63    source_mask: i64,
64    strategy: MergeStrategy,
65) -> Result<()> {
66    let predicate = match strategy {
67        MergeStrategy::Union => "(validity & ?2) != 0",
68        MergeStrategy::Intersection => "(validity & ?2) = ?2",
69    };
70
71    for table in ["kg_entities", "kg_relations"] {
72        // `table` is a hard-coded literal, never user input.
73        conn.execute(
74            &format!(
75                "UPDATE {table} SET validity = validity | ?1 \
76                 WHERE validity IS NOT NULL AND {predicate}"
77            ),
78            params![new_bit, source_mask],
79        )?;
80    }
81    Ok(())
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Fresh in-memory database with the crate schema applied.
    fn setup() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&db).unwrap();
        db
    }

    /// Insert an entity of type `test` and return its rowid.
    fn add_entity(conn: &Connection, name: &str) -> i64 {
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('test', ?1)",
            [name],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    /// Insert a `rel` relation from `src` to `tgt` and return its rowid.
    fn add_relation(conn: &Connection, src: i64, tgt: i64) -> i64 {
        conn.execute(
            "INSERT INTO kg_relations (source_id, target_id, rel_type) VALUES (?1, ?2, 'rel')",
            rusqlite::params![src, tgt],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    /// Create a version on `"main"` with no parent and no description.
    fn make_version(conn: &Connection, name: &str) -> i64 {
        store::create_version(conn, name, "main", None, None).unwrap()
    }

    /// Overwrite the `validity` mask of row `id` in `table`.
    fn set_validity(conn: &Connection, table: &str, id: i64, val: i64) {
        conn.execute(
            &format!("UPDATE {table} SET validity = ?1 WHERE id = ?2"),
            rusqlite::params![val, id],
        )
        .unwrap();
    }

    /// Read the (nullable) `validity` mask of row `id` in `table`.
    fn get_validity(conn: &Connection, table: &str, id: i64) -> Option<i64> {
        conn.query_row(
            &format!("SELECT validity FROM {table} WHERE id = ?1"),
            [id],
            |r| r.get(0),
        )
        .unwrap()
    }

    /// True when row `id` of `table` has any of the bits in `bit` set.
    /// Panics if the row's validity is NULL.
    fn has_bit(conn: &Connection, table: &str, id: i64, bit: i64) -> bool {
        get_validity(conn, table, id).unwrap() & bit != 0
    }

    // The tests below assume the first two versions created on a fresh
    // database take the first two slots: v1 -> 0b01, v2 -> 0b10.

    #[test]
    fn test_union_merge() {
        let conn = setup();
        let (a, b, c) = (
            add_entity(&conn, "A"),
            add_entity(&conn, "B"),
            add_entity(&conn, "C"),
        );
        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");

        // v1 holds A and B; v2 holds B and C.
        set_validity(&conn, "kg_entities", a, 0b01);
        set_validity(&conn, "kg_entities", b, 0b01);
        set_validity(&conn, "kg_entities", c, 0b10);
        conn.execute(
            "UPDATE kg_entities SET validity = validity | 2 WHERE id = ?1",
            [b],
        )
        .unwrap(); // B is now in both: 0b11

        let merged =
            version_merge(&conn, &[v1, v2], "merged-union", MergeStrategy::Union).unwrap();
        let bit = store::version_bit_for(&conn, merged).unwrap();

        // Union: every entity present in either source gets the merged bit.
        for id in [a, b, c] {
            assert!(has_bit(&conn, "kg_entities", id, bit));
        }
    }

    #[test]
    fn test_intersection_merge() {
        let conn = setup();
        let (a, b, c) = (
            add_entity(&conn, "A"),
            add_entity(&conn, "B"),
            add_entity(&conn, "C"),
        );
        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");

        // v1 holds A and B; v2 holds B and C. Only B is in both.
        set_validity(&conn, "kg_entities", a, 0b01);
        set_validity(&conn, "kg_entities", b, 0b11);
        set_validity(&conn, "kg_entities", c, 0b10);

        let merged = version_merge(
            &conn,
            &[v1, v2],
            "merged-intersect",
            MergeStrategy::Intersection,
        )
        .unwrap();
        let bit = store::version_bit_for(&conn, merged).unwrap();

        // Intersection: only B carries the merged bit.
        assert!(!has_bit(&conn, "kg_entities", a, bit));
        assert!(has_bit(&conn, "kg_entities", b, bit));
        assert!(!has_bit(&conn, "kg_entities", c, bit));
    }

    #[test]
    fn test_merge_applies_to_relations() {
        let conn = setup();
        let a = add_entity(&conn, "A");
        let b = add_entity(&conn, "B");
        let c = add_entity(&conn, "C");
        let both = add_relation(&conn, a, b);
        let v2_only = add_relation(&conn, b, c);
        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");

        set_validity(&conn, "kg_relations", both, 0b11);
        set_validity(&conn, "kg_relations", v2_only, 0b10);

        // Union: both relations carry the merged bit.
        let union_id = version_merge(&conn, &[v1, v2], "u", MergeStrategy::Union).unwrap();
        let union_bit = store::version_bit_for(&conn, union_id).unwrap();
        assert!(has_bit(&conn, "kg_relations", both, union_bit));
        assert!(has_bit(&conn, "kg_relations", v2_only, union_bit));

        // Intersection: only the relation valid in both sources does.
        let inter_id =
            version_merge(&conn, &[v1, v2], "i", MergeStrategy::Intersection).unwrap();
        let inter_bit = store::version_bit_for(&conn, inter_id).unwrap();
        assert!(has_bit(&conn, "kg_relations", both, inter_bit));
        assert!(!has_bit(&conn, "kg_relations", v2_only, inter_bit));
    }

    #[test]
    fn test_merge_creates_version_row() {
        let conn = setup();
        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");

        let merged = version_merge(&conn, &[v1, v2], "merged", MergeStrategy::Union).unwrap();

        let row = store::get_version(&conn, merged).unwrap();
        assert_eq!(row.name, "merged");
        assert_eq!(row.parent_id, Some(v1));
        assert!(row.is_merged);
    }

    #[test]
    fn test_delete_parent_after_merge_reclaims_slot() {
        // Production connections enable foreign keys. ON DELETE SET NULL only
        // fires when enforcement is on, so opt in explicitly here.
        let conn = Connection::open_in_memory().unwrap();
        conn.execute("PRAGMA foreign_keys = ON", []).unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");
        let merged = version_merge(&conn, &[v1, v2], "m", MergeStrategy::Union).unwrap();
        let freed_bit = store::version_bit_for(&conn, v1).unwrap();

        // v1 is the merged version's parent, so deleting it must succeed (not RESTRICT).
        store::delete_version(&conn, v1).unwrap();

        // The merged child's parent_id is nulled rather than left dangling.
        assert_eq!(store::get_version(&conn, merged).unwrap().parent_id, None);

        // v1's slot is reclaimed: the next version reuses exactly that bit.
        let v3 = make_version(&conn, "v3");
        assert_eq!(store::version_bit_for(&conn, v3).unwrap(), freed_bit);
    }

    #[test]
    fn test_merge_single_source_rejected() {
        let conn = setup();
        let only = make_version(&conn, "v1");
        let err = version_merge(&conn, &[only], "bad", MergeStrategy::Union).unwrap_err();
        assert!(matches!(err, Error::InvalidMerge(_)));
    }

    #[test]
    fn test_merge_nonexistent_version_rejected() {
        let conn = setup();
        let real = make_version(&conn, "v1");
        let err = version_merge(&conn, &[real, 999], "bad", MergeStrategy::Union).unwrap_err();
        assert!(matches!(err, Error::VersionNotFound(999)));
    }
}