1use rusqlite::params;
4
5use super::store;
6use super::MergeStrategy;
7use crate::error::{Error, Result};
8
9pub fn version_merge(
14 conn: &rusqlite::Connection,
15 source_ids: &[i64],
16 target_name: &str,
17 strategy: MergeStrategy,
18) -> Result<i64> {
19 if source_ids.len() < 2 {
20 return Err(Error::InvalidMerge(
21 "merge requires at least 2 source versions".to_string(),
22 ));
23 }
24
25 let mut source_mask: i64 = 0;
27 for &sid in source_ids {
28 source_mask |= store::version_bit_for(conn, sid)?;
29 }
30
31 let tx = conn.unchecked_transaction()?;
32
33 let new_id = store::create_version(
34 &tx,
35 target_name,
36 "main",
37 Some(source_ids[0]),
38 Some(&format!("merge of {:?}", source_ids)),
39 )?;
40 tx.execute(
41 "UPDATE kg_versions SET is_merged = 1 WHERE id = ?1",
42 [new_id],
43 )?;
44
45 let new_bit = store::version_bit_for(&tx, new_id)?;
46 match strategy {
47 MergeStrategy::Union => apply_merge(&tx, new_bit, source_mask, MergeStrategy::Union)?,
48 MergeStrategy::Intersection => {
49 apply_merge(&tx, new_bit, source_mask, MergeStrategy::Intersection)?
50 }
51 }
52
53 tx.commit()?;
54 Ok(new_id)
55}
56
57fn apply_merge(
61 conn: &rusqlite::Connection,
62 new_bit: i64,
63 source_mask: i64,
64 strategy: MergeStrategy,
65) -> Result<()> {
66 let predicate = match strategy {
67 MergeStrategy::Union => "(validity & ?2) != 0",
68 MergeStrategy::Intersection => "(validity & ?2) = ?2",
69 };
70
71 for table in ["kg_entities", "kg_relations"] {
72 conn.execute(
74 &format!(
75 "UPDATE {table} SET validity = validity | ?1 \
76 WHERE validity IS NOT NULL AND {predicate}"
77 ),
78 params![new_bit, source_mask],
79 )?;
80 }
81 Ok(())
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Fresh in-memory database with the project schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Inserts an entity of type `test` and returns its rowid.
    fn add_entity(conn: &Connection, name: &str) -> i64 {
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('test', ?1)",
            [name],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    /// Inserts a `rel` relation from `src` to `tgt` and returns its rowid.
    fn add_relation(conn: &Connection, src: i64, tgt: i64) -> i64 {
        conn.execute(
            "INSERT INTO kg_relations (source_id, target_id, rel_type) VALUES (?1, ?2, 'rel')",
            rusqlite::params![src, tgt],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    /// Creates a parentless version on `main` and returns its id.
    fn make_version(conn: &Connection, name: &str) -> i64 {
        store::create_version(conn, name, "main", None, None).unwrap()
    }

    /// Bit assigned to version `id`. The tests build masks from this instead
    /// of assuming which slot the store hands out.
    fn bit(conn: &Connection, id: i64) -> i64 {
        store::version_bit_for(conn, id).unwrap()
    }

    /// Overwrites the `validity` mask of row `id` in `table`.
    fn set_validity(conn: &Connection, table: &str, id: i64, val: i64) {
        conn.execute(
            &format!("UPDATE {table} SET validity = ?1 WHERE id = ?2"),
            rusqlite::params![val, id],
        )
        .unwrap();
    }

    /// Reads the (possibly NULL) `validity` mask of row `id` in `table`.
    fn get_validity(conn: &Connection, table: &str, id: i64) -> Option<i64> {
        conn.query_row(
            &format!("SELECT validity FROM {table} WHERE id = ?1"),
            [id],
            |r| r.get(0),
        )
        .unwrap()
    }

    #[test]
    fn test_union_merge() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let e3 = add_entity(&conn, "C");
        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");
        let (b1, b2) = (bit(&conn, v1), bit(&conn, v2));
        assert_ne!(b1, b2);

        // e1 only in v1, e2 in both, e3 only in v2: all belong to the union.
        set_validity(&conn, "kg_entities", e1, b1);
        set_validity(&conn, "kg_entities", e2, b1 | b2);
        set_validity(&conn, "kg_entities", e3, b2);

        let merged = version_merge(&conn, &[v1, v2], "merged-union", MergeStrategy::Union).unwrap();
        let mb = bit(&conn, merged);

        for e in [e1, e2, e3] {
            assert_ne!(get_validity(&conn, "kg_entities", e).unwrap() & mb, 0);
        }
    }

    #[test]
    fn test_intersection_merge() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let e3 = add_entity(&conn, "C");
        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");
        let (b1, b2) = (bit(&conn, v1), bit(&conn, v2));
        assert_ne!(b1, b2);

        // Only e2 is valid in both sources.
        set_validity(&conn, "kg_entities", e1, b1);
        set_validity(&conn, "kg_entities", e2, b1 | b2);
        set_validity(&conn, "kg_entities", e3, b2);

        let merged = version_merge(
            &conn,
            &[v1, v2],
            "merged-intersect",
            MergeStrategy::Intersection,
        )
        .unwrap();
        let mb = bit(&conn, merged);

        assert_eq!(get_validity(&conn, "kg_entities", e1).unwrap() & mb, 0);
        assert_ne!(get_validity(&conn, "kg_entities", e2).unwrap() & mb, 0);
        assert_eq!(get_validity(&conn, "kg_entities", e3).unwrap() & mb, 0);
    }

    #[test]
    fn test_merge_applies_to_relations() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let e3 = add_entity(&conn, "C");
        let r1 = add_relation(&conn, e1, e2);
        let r2 = add_relation(&conn, e2, e3);
        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");
        let (b1, b2) = (bit(&conn, v1), bit(&conn, v2));
        assert_ne!(b1, b2);

        // r1 is valid in both sources, r2 only in v2.
        set_validity(&conn, "kg_relations", r1, b1 | b2);
        set_validity(&conn, "kg_relations", r2, b2);

        let mu = version_merge(&conn, &[v1, v2], "u", MergeStrategy::Union).unwrap();
        let mub = bit(&conn, mu);
        assert_ne!(get_validity(&conn, "kg_relations", r1).unwrap() & mub, 0);
        assert_ne!(get_validity(&conn, "kg_relations", r2).unwrap() & mub, 0);

        let mi = version_merge(&conn, &[v1, v2], "i", MergeStrategy::Intersection).unwrap();
        let mib = bit(&conn, mi);
        assert_ne!(get_validity(&conn, "kg_relations", r1).unwrap() & mib, 0);
        assert_eq!(get_validity(&conn, "kg_relations", r2).unwrap() & mib, 0);
    }

    #[test]
    fn test_merge_creates_version_row() {
        let conn = setup();
        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");

        let merged = version_merge(&conn, &[v1, v2], "merged", MergeStrategy::Union).unwrap();

        let v = store::get_version(&conn, merged).unwrap();
        assert_eq!(v.name, "merged");
        assert_eq!(v.parent_id, Some(v1));
        assert!(v.is_merged);
    }

    #[test]
    fn test_delete_parent_after_merge_reclaims_slot() {
        let conn = Connection::open_in_memory().unwrap();
        // Foreign-key enforcement is enabled before the schema is created.
        // Presumably the `parent_id` reset checked below depends on it
        // (e.g. an ON DELETE SET NULL constraint) — confirm against the schema.
        conn.execute("PRAGMA foreign_keys = ON", []).unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let v1 = make_version(&conn, "v1");
        let v2 = make_version(&conn, "v2");
        let merged = version_merge(&conn, &[v1, v2], "m", MergeStrategy::Union).unwrap();
        let freed_bit = bit(&conn, v1);

        store::delete_version(&conn, v1).unwrap();

        // The merged version no longer points at the deleted parent.
        assert_eq!(store::get_version(&conn, merged).unwrap().parent_id, None);

        // The bit released by v1 is handed to the next version created.
        let v3 = make_version(&conn, "v3");
        assert_eq!(bit(&conn, v3), freed_bit);
    }

    #[test]
    fn test_merge_single_source_rejected() {
        let conn = setup();
        let v1 = make_version(&conn, "v1");
        let err = version_merge(&conn, &[v1], "bad", MergeStrategy::Union).unwrap_err();
        assert!(matches!(err, Error::InvalidMerge(_)));
    }

    #[test]
    fn test_merge_nonexistent_version_rejected() {
        let conn = setup();
        let v1 = make_version(&conn, "v1");
        let err = version_merge(&conn, &[v1, 999], "bad", MergeStrategy::Union).unwrap_err();
        assert!(matches!(err, Error::VersionNotFound(999)));
    }
}