1use crate::patch::commute::commute;
14use crate::patch::conflict::Conflict;
15use crate::patch::types::{Patch, PatchId};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19#[derive(Error, Debug)]
21pub enum MergeError {
22 #[error("patch not found: {0}")]
23 PatchNotFound(String),
24
25 #[error("no common ancestor found between branches")]
26 NoCommonAncestor,
27
28 #[error("merge already in progress")]
29 MergeInProgress,
30
31 #[error("empty branch: {0}")]
32 EmptyBranch(String),
33
34 #[error("{0}")]
35 Custom(String),
36}
37
38#[derive(Clone, Debug)]
40pub struct MergeResult {
41 pub common_patches: Vec<PatchId>,
43 pub patches_a_only: Vec<PatchId>,
45 pub patches_b_only: Vec<PatchId>,
47 pub conflicts: Vec<Conflict>,
49 pub is_clean: bool,
51}
52
53impl MergeResult {
54 pub fn all_patch_ids(&self) -> Vec<PatchId> {
56 let mut ids = Vec::new();
57 ids.extend(self.common_patches.iter());
58 ids.extend(self.patches_a_only.iter());
59 ids.extend(self.patches_b_only.iter());
60 ids
61 }
62}
63
64pub fn merge(
81 base_patches: &[PatchId],
82 branch_a_patches: &[PatchId],
83 branch_b_patches: &[PatchId],
84 all_patches: &HashMap<PatchId, Patch>,
85) -> Result<MergeResult, MergeError> {
86 let base_set: HashSet<&PatchId> = base_patches.iter().collect();
87
88 let patches_a_only: Vec<PatchId> = branch_a_patches
90 .iter()
91 .filter(|p| !base_set.contains(p))
92 .copied()
93 .collect();
94
95 let patches_b_only: Vec<PatchId> = branch_b_patches
96 .iter()
97 .filter(|p| !base_set.contains(p))
98 .copied()
99 .collect();
100
101 let branch_a_set: HashSet<&PatchId> = branch_a_patches.iter().collect();
103 let branch_b_set: HashSet<&PatchId> = branch_b_patches.iter().collect();
104 let common_patches: Vec<PatchId> = base_patches
105 .iter()
106 .filter(|p| branch_a_set.contains(p) && branch_b_set.contains(p))
107 .copied()
108 .collect();
109
110 let mut conflicts = Vec::new();
112
113 for patch_a_id in &patches_a_only {
114 let patch_a = all_patches
115 .get(patch_a_id)
116 .ok_or_else(|| MergeError::PatchNotFound(patch_a_id.to_hex()))?;
117
118 if patch_a.is_identity() {
120 continue;
121 }
122
123 for patch_b_id in &patches_b_only {
124 let patch_b = all_patches
125 .get(patch_b_id)
126 .ok_or_else(|| MergeError::PatchNotFound(patch_b_id.to_hex()))?;
127
128 if patch_b.is_identity() {
129 continue;
130 }
131
132 match commute(patch_a, patch_b) {
133 crate::patch::CommuteResult::DoesNotCommute { conflict_addresses } => {
134 conflicts.push(Conflict::new(*patch_a_id, *patch_b_id, conflict_addresses));
135 }
136 crate::patch::CommuteResult::Commutes => {
137 }
139 }
140 }
141 }
142
143 let is_clean = conflicts.is_empty();
144
145 Ok(MergeResult {
146 common_patches,
147 patches_a_only,
148 patches_b_only,
149 conflicts,
150 is_clean,
151 })
152}
153
154pub fn detect_conflicts(patches_a: &[Patch], patches_b: &[Patch]) -> Vec<Conflict> {
159 let mut conflicts = Vec::new();
160
161 for patch_a in patches_a {
162 if patch_a.is_identity() {
163 continue;
164 }
165 for patch_b in patches_b {
166 if patch_b.is_identity() {
167 continue;
168 }
169 match commute(patch_a, patch_b) {
170 crate::patch::CommuteResult::DoesNotCommute { conflict_addresses } => {
171 conflicts.push(Conflict::new(patch_a.id, patch_b.id, conflict_addresses));
172 }
173 crate::patch::CommuteResult::Commutes => {}
174 }
175 }
176 }
177
178 conflicts
179}
180
181#[cfg(test)]
182mod tests {
183 use super::*;
184 use crate::patch::types::{OperationType, Patch, TouchSet};
185
186 fn patch(addr: &str, name: &str) -> Patch {
187 Patch::new(
188 OperationType::Modify,
189 TouchSet::single(addr),
190 Some(format!("file_{}", addr)),
191 vec![],
192 vec![],
193 name.to_string(),
194 format!("edit {}", addr),
195 )
196 }
197
198 fn make_patches(patches: &[Patch]) -> (Vec<PatchId>, HashMap<PatchId, Patch>) {
199 let ids: Vec<PatchId> = patches.iter().map(|p| p.id).collect();
200 let map: HashMap<PatchId, Patch> = patches.iter().map(|p| (p.id, p.clone())).collect();
201 (ids, map)
202 }
203
204 #[test]
205 fn test_clean_merge_disjoint() {
206 let base = patch("Z0", "base");
207 let pa = patch("A1", "branch_a");
208 let pb = patch("B1", "branch_b");
209
210 let (base_ids, mut all) = make_patches(std::slice::from_ref(&base));
211 let (a_ids, a_map) = make_patches(&[base.clone(), pa.clone()]);
212 let (b_ids, b_map) = make_patches(&[base.clone(), pb.clone()]);
213
214 all.extend(a_map);
215 all.extend(b_map);
216
217 let result = merge(&base_ids, &a_ids, &b_ids, &all).unwrap();
218 assert!(result.is_clean);
219 assert!(result.conflicts.is_empty());
220 assert!(result.patches_a_only.contains(&pa.id));
221 assert!(result.patches_b_only.contains(&pb.id));
222 }
223
224 #[test]
225 fn test_conflicting_merge() {
226 let base = patch("Z0", "base");
227 let pa = patch("A1", "branch_a");
228 let pb = patch("A1", "branch_b"); let (base_ids, mut all) = make_patches(std::slice::from_ref(&base));
231 let (a_ids, a_map) = make_patches(&[base.clone(), pa.clone()]);
232 let (b_ids, b_map) = make_patches(&[base.clone(), pb.clone()]);
233
234 all.extend(a_map);
235 all.extend(b_map);
236
237 let result = merge(&base_ids, &a_ids, &b_ids, &all).unwrap();
238 assert!(!result.is_clean);
239 assert_eq!(result.conflicts.len(), 1);
240 assert_eq!(result.conflicts[0].conflict_addresses, vec!["A1"]);
241 }
242
243 #[test]
244 fn test_empty_branches() {
245 let base = patch("Z0", "base");
246 let (base_ids, all) = make_patches(&[base]);
247
248 let result = merge(&base_ids, &[], &[], &all).unwrap();
249 assert!(result.is_clean);
250 assert!(result.patches_a_only.is_empty());
251 assert!(result.patches_b_only.is_empty());
252 }
253
254 #[test]
255 fn test_single_branch_changed() {
256 let base = patch("Z0", "base");
257 let pa = patch("A1", "branch_a");
258
259 let (base_ids, mut all) = make_patches(std::slice::from_ref(&base));
260 let (a_ids, a_map) = make_patches(&[base.clone(), pa.clone()]);
261 all.extend(a_map);
262
263 let result = merge(&base_ids, &a_ids, &base_ids, &all).unwrap();
264 assert!(result.is_clean);
265 assert!(result.patches_a_only.contains(&pa.id));
266 assert!(result.patches_b_only.is_empty());
267 }
268
269 #[test]
270 fn test_merge_deterministic() {
271 let base = patch("Z0", "base");
272 let pa1 = patch("A1", "a1");
273 let pa2 = patch("A2", "a2");
274 let pb1 = patch("B1", "b1");
275 let pb2 = patch("B2", "b2");
276
277 let (base_ids, mut all) = make_patches(std::slice::from_ref(&base));
278 let (a_ids, a_map) = make_patches(&[base.clone(), pa1.clone(), pa2.clone()]);
279 let (b_ids, b_map) = make_patches(&[base.clone(), pb1.clone(), pb2.clone()]);
280 all.extend(a_map);
281 all.extend(b_map);
282
283 let r1 = merge(&base_ids, &a_ids, &b_ids, &all).unwrap();
284 let r2 = merge(&base_ids, &b_ids, &a_ids, &all).unwrap();
285
286 let mut ids1 = r1.all_patch_ids();
288 let mut ids2 = r2.all_patch_ids();
289 ids1.sort();
290 ids2.sort();
291 assert_eq!(
292 ids1, ids2,
293 "Merge must be deterministic regardless of order"
294 );
295 assert_eq!(r1.conflicts.len(), r2.conflicts.len());
296 }
297
298 #[test]
299 fn test_detect_conflicts() {
300 let pa = patch("A1", "a");
301 let pb = patch("A1", "b"); let conflicts = detect_conflicts(std::slice::from_ref(&pa), std::slice::from_ref(&pb));
304 assert_eq!(conflicts.len(), 1);
305
306 let pc = patch("C1", "c"); let no_conflicts = detect_conflicts(&[pa], &[pc]);
308 assert!(no_conflicts.is_empty());
309 }
310
311 #[test]
312 fn test_partial_overlap_merge() {
313 let base = patch("Z0", "base");
314 let pa1 = patch("A1", "a1");
315 let pa2 = patch("B1", "a2"); let pb1 = patch("C1", "b1"); let pb2 = patch("B1", "b2"); let (base_ids, mut all) = make_patches(std::slice::from_ref(&base));
320 let (a_ids, a_map) = make_patches(&[base.clone(), pa1.clone(), pa2.clone()]);
321 let (b_ids, b_map) = make_patches(&[base.clone(), pb1.clone(), pb2.clone()]);
322 all.extend(a_map);
323 all.extend(b_map);
324
325 let result = merge(&base_ids, &a_ids, &b_ids, &all).unwrap();
326 assert!(!result.is_clean);
327 assert_eq!(result.conflicts.len(), 1); }
329}