Skip to main content

suture_core/dag/
branch.rs

//! Branch management utilities for the Patch DAG.
//!
//! Provides convenience methods on `PatchDag` for common branch queries:
//! - Checking whether a branch exists
//! - Resolving the HEAD branch
//! - Measuring how far two branches have diverged (ahead/behind counts)
7
8use crate::dag::graph::{DagError, PatchDag};
9use crate::patch::types::PatchId;
10use suture_common::BranchName;
11
12impl PatchDag {
13    /// Check if a branch exists.
14    pub fn branch_exists(&self, name: &BranchName) -> bool {
15        self.branches.contains_key(name.as_str())
16    }
17
18    /// Get the current HEAD branch (the branch named "main", or the first branch).
19    pub fn head(&self) -> Option<(String, PatchId)> {
20        self.branches
21            .get("main")
22            .map(|id| ("main".to_string(), *id))
23            .or_else(|| {
24                self.branches
25                    .iter()
26                    .next()
27                    .map(|(name, id)| (name.clone(), *id))
28            })
29    }
30
31    /// Get the number of commits ahead/behind between two branches.
32    ///
33    /// Returns `(ahead, behind)` where:
34    /// - `ahead` is the number of patches on `branch_a` since the LCA
35    /// - `behind` is the number of patches on `branch_b` since the LCA
36    ///
37    /// Computed by finding the Lowest Common Ancestor (LCA) and counting
38    /// patches on each branch's chain between the LCA and the branch tip.
39    pub fn branch_divergence(
40        &self,
41        branch_a: &BranchName,
42        branch_b: &BranchName,
43    ) -> Result<(usize, usize), DagError> {
44        let target_a = self
45            .get_branch(branch_a)
46            .ok_or_else(|| DagError::BranchNotFound(branch_a.as_str().to_string()))?;
47        let target_b = self
48            .get_branch(branch_b)
49            .ok_or_else(|| DagError::BranchNotFound(branch_b.as_str().to_string()))?;
50
51        // Find the LCA (the most recent common ancestor)
52        let lca_id = self
53            .lca(&target_a, &target_b)
54            .ok_or_else(|| DagError::Custom("no common ancestor found".to_string()))?;
55
56        // Get patch chains from each tip back to root (tip-first order)
57        let chain_a = self.patch_chain(&target_a);
58        let chain_b = self.patch_chain(&target_b);
59
60        // Find the LCA position in each chain
61        // ahead = patches on branch_a between tip and LCA (exclusive of LCA)
62        let ahead = match chain_a.iter().position(|id| *id == lca_id) {
63            Some(pos) => pos,
64            None => chain_a.len(),
65        };
66
67        // behind = patches on branch_b between tip and LCA (exclusive of LCA)
68        let behind = match chain_b.iter().position(|id| *id == lca_id) {
69            Some(pos) => pos,
70            None => chain_b.len(),
71        };
72
73        Ok((ahead, behind))
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::patch::types::{OperationType, Patch, TouchSet};

    /// Builds a minimal `Modify` patch touching the single address `addr`.
    fn make_patch(addr: &str) -> Patch {
        Patch::new(
            OperationType::Modify,
            TouchSet::single(addr),
            None,
            vec![],
            vec![],
            "test".to_string(),
            format!("edit {}", addr),
        )
    }

    #[test]
    fn test_branch_exists() {
        let mut dag = PatchDag::new();
        let root = make_patch("root");
        let root_id = dag.add_patch(root, vec![]).unwrap();

        let main = BranchName::new("main").unwrap();
        dag.create_branch(main.clone(), root_id).unwrap();

        assert!(dag.branch_exists(&main));
        assert!(!dag.branch_exists(&BranchName::new("nonexistent").unwrap()));
    }

    #[test]
    fn test_head() {
        let mut dag = PatchDag::new();
        let root = make_patch("root");
        let root_id = dag.add_patch(root, vec![]).unwrap();

        assert!(dag.head().is_none());

        let main = BranchName::new("main").unwrap();
        dag.create_branch(main.clone(), root_id).unwrap();

        let head = dag.head().unwrap();
        assert_eq!(head.0, "main");
        assert_eq!(head.1, root_id);
    }

    #[test]
    fn test_head_falls_back_when_no_main() {
        let mut dag = PatchDag::new();
        let root_id = dag.add_patch(make_patch("root"), vec![]).unwrap();

        // With no "main" branch, the only existing branch becomes HEAD.
        dag.create_branch(BranchName::new("develop").unwrap(), root_id)
            .unwrap();

        assert_eq!(dag.head(), Some(("develop".to_string(), root_id)));
    }

    #[test]
    fn test_branch_divergence() {
        let mut dag = PatchDag::new();
        let root = make_patch("root");
        let root_id = dag.add_patch(root, vec![]).unwrap();

        // Create main branch
        let main = BranchName::new("main").unwrap();
        dag.create_branch(main.clone(), root_id).unwrap();

        // Add a commit to main
        let mc = make_patch("main_commit");
        let mc_id = dag.add_patch(mc, vec![root_id]).unwrap();
        dag.update_branch(&main, mc_id).unwrap();

        // Create feature branch from root
        let feat = BranchName::new("feature").unwrap();
        dag.create_branch(feat.clone(), root_id).unwrap();

        // Add a commit to feature
        let fc = make_patch("feat_commit");
        let fc_id = dag.add_patch(fc, vec![root_id]).unwrap();
        dag.update_branch(&feat, fc_id).unwrap();

        let (ahead, behind) = dag.branch_divergence(&main, &feat).unwrap();
        // Both branches diverge from root_id as LCA.
        // Each has 1 patch between tip and LCA.
        assert_eq!(ahead, 1, "main should be 1 ahead of feature");
        assert_eq!(behind, 1, "main should be 1 behind feature");
    }

    #[test]
    fn test_branch_divergence_one_sided() {
        let mut dag = PatchDag::new();
        let root_id = dag.add_patch(make_patch("root"), vec![]).unwrap();

        let main = BranchName::new("main").unwrap();
        dag.create_branch(main.clone(), root_id).unwrap();

        // main gains two patches; feature stays at root.
        let mc1_id = dag.add_patch(make_patch("main_1"), vec![root_id]).unwrap();
        let mc2_id = dag.add_patch(make_patch("main_2"), vec![mc1_id]).unwrap();
        dag.update_branch(&main, mc2_id).unwrap();

        let feat = BranchName::new("feature").unwrap();
        dag.create_branch(feat.clone(), root_id).unwrap();

        assert_eq!(dag.branch_divergence(&main, &feat).unwrap(), (2, 0));
        assert_eq!(dag.branch_divergence(&feat, &main).unwrap(), (0, 2));
    }

    #[test]
    fn test_branch_divergence_missing_branch() {
        let mut dag = PatchDag::new();
        let root_id = dag.add_patch(make_patch("root"), vec![]).unwrap();

        let main = BranchName::new("main").unwrap();
        dag.create_branch(main.clone(), root_id).unwrap();

        let missing = BranchName::new("missing").unwrap();
        let result = dag.branch_divergence(&main, &missing);
        assert!(matches!(
            &result,
            Err(DagError::BranchNotFound(name)) if name == "missing"
        ));
    }
}