Skip to main content

suture_core/patch/
merge.rs

//! Three-way merge algorithm for patch sets.
//!
//! Given a common base and two branches (A and B), the merge algorithm:
//! 1. Identifies patches unique to each branch
//! 2. Checks for conflicts (overlapping touch sets)
//! 3. Produces a merged patch set with conflict nodes where needed
//!
//! # Correctness
//!
//! Per THM-MERGE-001 (YP-ALGEBRA-PATCH-001):
//! The merge result is deterministic and independent of branch processing order.
12
13use crate::patch::commute::commute;
14use crate::patch::conflict::Conflict;
15use crate::patch::types::{Patch, PatchId};
16use std::collections::{HashMap, HashSet};
17use thiserror::Error;
18
19/// Errors that can occur during merge operations.
20#[derive(Error, Debug)]
21pub enum MergeError {
22    #[error("patch not found: {0}")]
23    PatchNotFound(String),
24
25    #[error("no common ancestor found between branches")]
26    NoCommonAncestor,
27
28    #[error("merge already in progress")]
29    MergeInProgress,
30
31    #[error("empty branch: {0}")]
32    EmptyBranch(String),
33
34    #[error("{0}")]
35    Custom(String),
36}
37
38/// Result of a merge operation.
39#[derive(Clone, Debug)]
40pub struct MergeResult {
41    /// Patches that are in both branches (already applied, include once).
42    pub common_patches: Vec<PatchId>,
43    /// Patches unique to branch A (applied in order).
44    pub patches_a_only: Vec<PatchId>,
45    /// Patches unique to branch B (applied in order).
46    pub patches_b_only: Vec<PatchId>,
47    /// Conflicts detected between patches from different branches.
48    pub conflicts: Vec<Conflict>,
49    /// Whether the merge is clean (no conflicts).
50    pub is_clean: bool,
51}
52
53impl MergeResult {
54    /// Get all patch IDs that should be in the merged result.
55    pub fn all_patch_ids(&self) -> Vec<PatchId> {
56        let mut ids = Vec::new();
57        ids.extend(self.common_patches.iter());
58        ids.extend(self.patches_a_only.iter());
59        ids.extend(self.patches_b_only.iter());
60        ids
61    }
62}
63
64/// Perform a three-way merge of two patch sets.
65///
66/// # Algorithm (ALG-MERGE-001)
67///
68/// 1. Compute unique patches on each branch: patches not in the base
69/// 2. For each pair (P_a, P_b) where P_a is unique to A and P_b is unique to B:
70///    a. Check if they commute (disjoint touch sets)
71///    b. If not, create a conflict node
72/// 3. Return the merged patch set + conflicts
73///
74/// # Arguments
75///
76/// * `base_patches` - Patches in the common ancestor
77/// * `branch_a_patches` - Patches on branch A (in application order)
78/// * `branch_b_patches` - Patches on branch B (in application order)
79/// * `all_patches` - HashMap of PatchId -> Patch for looking up patch details
80pub fn merge(
81    base_patches: &[PatchId],
82    branch_a_patches: &[PatchId],
83    branch_b_patches: &[PatchId],
84    all_patches: &HashMap<PatchId, Patch>,
85) -> Result<MergeResult, MergeError> {
86    let base_set: HashSet<&PatchId> = base_patches.iter().collect();
87
88    // Partition patches into common and unique
89    let patches_a_only: Vec<PatchId> = branch_a_patches
90        .iter()
91        .filter(|p| !base_set.contains(p))
92        .copied()
93        .collect();
94
95    let patches_b_only: Vec<PatchId> = branch_b_patches
96        .iter()
97        .filter(|p| !base_set.contains(p))
98        .copied()
99        .collect();
100
101    // Common patches (in base, also in both branches)
102    let branch_a_set: HashSet<&PatchId> = branch_a_patches.iter().collect();
103    let branch_b_set: HashSet<&PatchId> = branch_b_patches.iter().collect();
104    let common_patches: Vec<PatchId> = base_patches
105        .iter()
106        .filter(|p| branch_a_set.contains(p) && branch_b_set.contains(p))
107        .copied()
108        .collect();
109
110    // Detect conflicts between unique patches
111    let mut conflicts = Vec::new();
112
113    for patch_a_id in &patches_a_only {
114        let patch_a = all_patches
115            .get(patch_a_id)
116            .ok_or_else(|| MergeError::PatchNotFound(patch_a_id.to_hex()))?;
117
118        // Skip identity patches for conflict detection
119        if patch_a.is_identity() {
120            continue;
121        }
122
123        for patch_b_id in &patches_b_only {
124            let patch_b = all_patches
125                .get(patch_b_id)
126                .ok_or_else(|| MergeError::PatchNotFound(patch_b_id.to_hex()))?;
127
128            if patch_b.is_identity() {
129                continue;
130            }
131
132            match commute(patch_a, patch_b) {
133                crate::patch::CommuteResult::DoesNotCommute { conflict_addresses } => {
134                    conflicts.push(Conflict::new(*patch_a_id, *patch_b_id, conflict_addresses));
135                }
136                crate::patch::CommuteResult::Commutes => {
137                    // No conflict, both can be included
138                }
139            }
140        }
141    }
142
143    let is_clean = conflicts.is_empty();
144
145    Ok(MergeResult {
146        common_patches,
147        patches_a_only,
148        patches_b_only,
149        conflicts,
150        is_clean,
151    })
152}
153
154/// Detect all conflicts between two patch sets without performing a full merge.
155///
156/// This is useful for showing a preview of what would conflict before
157/// actually committing to a merge.
158pub fn detect_conflicts(patches_a: &[Patch], patches_b: &[Patch]) -> Vec<Conflict> {
159    let mut conflicts = Vec::new();
160
161    for patch_a in patches_a {
162        if patch_a.is_identity() {
163            continue;
164        }
165        for patch_b in patches_b {
166            if patch_b.is_identity() {
167                continue;
168            }
169            match commute(patch_a, patch_b) {
170                crate::patch::CommuteResult::DoesNotCommute { conflict_addresses } => {
171                    conflicts.push(Conflict::new(patch_a.id, patch_b.id, conflict_addresses));
172                }
173                crate::patch::CommuteResult::Commutes => {}
174            }
175        }
176    }
177
178    conflicts
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::patch::types::{OperationType, Patch, TouchSet};

    /// Build a `Modify` patch whose touch set is the single address `addr`;
    /// `name` is forwarded to `Patch::new` unchanged.
    fn patch(addr: &str, name: &str) -> Patch {
        Patch::new(
            OperationType::Modify,
            TouchSet::single(addr),
            Some(format!("file_{}", addr)),
            vec![],
            vec![],
            name.to_string(),
            format!("edit {}", addr),
        )
    }

    /// Split a patch list into its ordered IDs and an ID -> patch lookup table.
    fn make_patches(patches: &[Patch]) -> (Vec<PatchId>, HashMap<PatchId, Patch>) {
        let mut ids = Vec::with_capacity(patches.len());
        let mut by_id = HashMap::with_capacity(patches.len());
        for p in patches {
            ids.push(p.id);
            by_id.insert(p.id, p.clone());
        }
        (ids, by_id)
    }

    /// Merge the given base and branch patch lists, using a lookup table that
    /// holds every patch mentioned in any of the three lists.
    fn run_merge(base: &[Patch], branch_a: &[Patch], branch_b: &[Patch]) -> MergeResult {
        let (base_ids, mut store) = make_patches(base);
        let (a_ids, a_store) = make_patches(branch_a);
        let (b_ids, b_store) = make_patches(branch_b);
        store.extend(a_store);
        store.extend(b_store);

        merge(&base_ids, &a_ids, &b_ids, &store).unwrap()
    }

    #[test]
    fn test_clean_merge_disjoint() {
        let base = patch("Z0", "base");
        let pa = patch("A1", "branch_a");
        let pb = patch("B1", "branch_b");

        let outcome = run_merge(
            std::slice::from_ref(&base),
            &[base.clone(), pa.clone()],
            &[base.clone(), pb.clone()],
        );

        assert!(outcome.is_clean);
        assert!(outcome.conflicts.is_empty());
        assert!(outcome.patches_a_only.contains(&pa.id));
        assert!(outcome.patches_b_only.contains(&pb.id));
    }

    #[test]
    fn test_conflicting_merge() {
        let base = patch("Z0", "base");
        // Both branch patches touch the same address.
        let pa = patch("A1", "branch_a");
        let pb = patch("A1", "branch_b");

        let outcome = run_merge(
            std::slice::from_ref(&base),
            &[base.clone(), pa],
            &[base.clone(), pb],
        );

        assert!(!outcome.is_clean);
        assert_eq!(outcome.conflicts.len(), 1);
        assert_eq!(outcome.conflicts[0].conflict_addresses, vec!["A1"]);
    }

    #[test]
    fn test_empty_branches() {
        let base = patch("Z0", "base");

        let outcome = run_merge(&[base], &[], &[]);

        assert!(outcome.is_clean);
        assert!(outcome.patches_a_only.is_empty());
        assert!(outcome.patches_b_only.is_empty());
    }

    #[test]
    fn test_single_branch_changed() {
        let base = patch("Z0", "base");
        let pa = patch("A1", "branch_a");

        // Branch B is identical to the base.
        let outcome = run_merge(
            std::slice::from_ref(&base),
            &[base.clone(), pa.clone()],
            std::slice::from_ref(&base),
        );

        assert!(outcome.is_clean);
        assert!(outcome.patches_a_only.contains(&pa.id));
        assert!(outcome.patches_b_only.is_empty());
    }

    #[test]
    fn test_merge_deterministic() {
        let base = patch("Z0", "base");
        let side_a = [base.clone(), patch("A1", "a1"), patch("A2", "a2")];
        let side_b = [base.clone(), patch("B1", "b1"), patch("B2", "b2")];

        let forward = run_merge(std::slice::from_ref(&base), &side_a, &side_b);
        let swapped = run_merge(std::slice::from_ref(&base), &side_b, &side_a);

        // Both results should have the same set of unique patches
        let mut forward_ids = forward.all_patch_ids();
        let mut swapped_ids = swapped.all_patch_ids();
        forward_ids.sort();
        swapped_ids.sort();
        assert_eq!(
            forward_ids, swapped_ids,
            "Merge must be deterministic regardless of order"
        );
        assert_eq!(forward.conflicts.len(), swapped.conflicts.len());
    }

    #[test]
    fn test_detect_conflicts() {
        let pa = patch("A1", "a");
        let same_address = patch("A1", "b");
        let other_address = patch("C1", "c");

        let found = detect_conflicts(
            std::slice::from_ref(&pa),
            std::slice::from_ref(&same_address),
        );
        assert_eq!(found.len(), 1);

        let none = detect_conflicts(&[pa], &[other_address]);
        assert!(none.is_empty());
    }

    #[test]
    fn test_partial_overlap_merge() {
        let base = patch("Z0", "base");
        // Only the two "B1" patches overlap across branches.
        let side_a = [base.clone(), patch("A1", "a1"), patch("B1", "a2")];
        let side_b = [base.clone(), patch("C1", "b1"), patch("B1", "b2")];

        let outcome = run_merge(std::slice::from_ref(&base), &side_a, &side_b);

        assert!(!outcome.is_clean);
        assert_eq!(outcome.conflicts.len(), 1); // Only B1 conflicts
    }
}
329}