rustic_git/commands/branch.rs

1use crate::types::Hash;
2use crate::utils::git;
3use crate::{Repository, Result};
4use std::fmt;
5
/// Kind of branch reported by `git branch --all`.
///
/// A small fieldless enum, so it is `Copy` and `Hash`: it can be passed by
/// value and used as a key in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BranchType {
    /// A local branch (listed without a `remotes/` prefix).
    Local,
    /// A remote-tracking branch (listed with a `remotes/` prefix).
    RemoteTracking,
}
11
12impl fmt::Display for BranchType {
13    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
14        match self {
15            BranchType::Local => write!(f, "local"),
16            BranchType::RemoteTracking => write!(f, "remote-tracking"),
17        }
18    }
19}
20
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct Branch {
23    pub name: String,
24    pub branch_type: BranchType,
25    pub is_current: bool,
26    pub commit_hash: Hash,
27    pub upstream: Option<String>,
28}
29
30impl Branch {
31    /// Check if this is a local branch
32    pub fn is_local(&self) -> bool {
33        matches!(self.branch_type, BranchType::Local)
34    }
35
36    /// Check if this is a remote-tracking branch
37    pub fn is_remote(&self) -> bool {
38        matches!(self.branch_type, BranchType::RemoteTracking)
39    }
40
41    /// Get the short name of the branch (without remote prefix for remote branches)
42    pub fn short_name(&self) -> &str {
43        if self.is_remote() && self.name.contains('/') {
44            self.name.split('/').nth(1).unwrap_or(&self.name)
45        } else {
46            &self.name
47        }
48    }
49}
50
51impl fmt::Display for Branch {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        let marker = if self.is_current { "*" } else { " " };
54        write!(f, "{} {}", marker, self.name)
55    }
56}
57
58#[derive(Debug, Clone, PartialEq)]
59pub struct BranchList {
60    branches: Box<[Branch]>,
61}
62
63impl BranchList {
64    /// Create a new BranchList from a vector of branches
65    pub fn new(branches: Vec<Branch>) -> Self {
66        Self {
67            branches: branches.into_boxed_slice(),
68        }
69    }
70
71    /// Get all branches
72    pub fn all(&self) -> &[Branch] {
73        &self.branches
74    }
75
76    /// Get an iterator over all branches
77    pub fn iter(&self) -> impl Iterator<Item = &Branch> {
78        self.branches.iter()
79    }
80
81    /// Get an iterator over local branches
82    pub fn local(&self) -> impl Iterator<Item = &Branch> {
83        self.branches.iter().filter(|b| b.is_local())
84    }
85
86    /// Get an iterator over remote-tracking branches
87    pub fn remote(&self) -> impl Iterator<Item = &Branch> {
88        self.branches.iter().filter(|b| b.is_remote())
89    }
90
91    /// Get the current branch
92    pub fn current(&self) -> Option<&Branch> {
93        self.branches.iter().find(|b| b.is_current)
94    }
95
96    /// Find a branch by name
97    pub fn find(&self, name: &str) -> Option<&Branch> {
98        self.branches.iter().find(|b| b.name == name)
99    }
100
101    /// Find a branch by short name (useful for remote branches)
102    pub fn find_by_short_name(&self, short_name: &str) -> Option<&Branch> {
103        self.branches.iter().find(|b| b.short_name() == short_name)
104    }
105
106    /// Check if the list is empty
107    pub fn is_empty(&self) -> bool {
108        self.branches.is_empty()
109    }
110
111    /// Get the count of branches
112    pub fn len(&self) -> usize {
113        self.branches.len()
114    }
115
116    /// Get count of local branches
117    pub fn local_count(&self) -> usize {
118        self.local().count()
119    }
120
121    /// Get count of remote-tracking branches
122    pub fn remote_count(&self) -> usize {
123        self.remote().count()
124    }
125}
126
127impl fmt::Display for BranchList {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        for branch in &self.branches {
130            writeln!(f, "{}", branch)?;
131        }
132        Ok(())
133    }
134}
135
136impl Repository {
137    /// List all branches in the repository
138    pub fn branches(&self) -> Result<BranchList> {
139        Self::ensure_git()?;
140
141        // Use git branch -vv --all for comprehensive branch information
142        let stdout = git(&["branch", "-vv", "--all"], Some(self.repo_path()))?;
143
144        let branches = parse_branch_output(&stdout)?;
145        Ok(BranchList::new(branches))
146    }
147
148    /// Get the current branch
149    pub fn current_branch(&self) -> Result<Option<Branch>> {
150        Self::ensure_git()?;
151
152        let stdout = git(&["branch", "--show-current"], Some(self.repo_path()))?;
153        let current_name = stdout.trim();
154
155        if current_name.is_empty() {
156            // Might be in detached HEAD state
157            return Ok(None);
158        }
159
160        // Get detailed info about the current branch
161        let branches = self.branches()?;
162        Ok(branches.current().cloned())
163    }
164
165    /// Create a new branch
166    pub fn create_branch(&self, name: &str, start_point: Option<&str>) -> Result<Branch> {
167        Self::ensure_git()?;
168
169        let mut args = vec!["branch", name];
170        if let Some(start) = start_point {
171            args.push(start);
172        }
173
174        let _stdout = git(&args, Some(self.repo_path()))?;
175
176        // Get information about the newly created branch
177        let branches = self.branches()?;
178        branches.find(name).cloned().ok_or_else(|| {
179            crate::error::GitError::CommandFailed(format!("Failed to create branch: {}", name))
180        })
181    }
182
183    /// Delete a branch
184    pub fn delete_branch(&self, branch: &Branch, force: bool) -> Result<()> {
185        Self::ensure_git()?;
186
187        if branch.is_current {
188            return Err(crate::error::GitError::CommandFailed(
189                "Cannot delete the current branch".to_string(),
190            ));
191        }
192
193        let flag = if force { "-D" } else { "-d" };
194        let args = vec!["branch", flag, &branch.name];
195
196        let _stdout = git(&args, Some(self.repo_path()))?;
197        Ok(())
198    }
199
200    /// Switch to an existing branch
201    pub fn checkout(&self, branch: &Branch) -> Result<()> {
202        Self::ensure_git()?;
203
204        let branch_name = if branch.is_remote() {
205            branch.short_name()
206        } else {
207            &branch.name
208        };
209
210        let _stdout = git(&["checkout", branch_name], Some(self.repo_path()))?;
211        Ok(())
212    }
213
214    /// Create a new branch and switch to it
215    pub fn checkout_new(&self, name: &str, start_point: Option<&str>) -> Result<Branch> {
216        Self::ensure_git()?;
217
218        let mut args = vec!["checkout", "-b", name];
219        if let Some(start) = start_point {
220            args.push(start);
221        }
222
223        let _stdout = git(&args, Some(self.repo_path()))?;
224
225        // Get information about the newly created and checked out branch
226        self.current_branch()?.ok_or_else(|| {
227            crate::error::GitError::CommandFailed(format!(
228                "Failed to create and checkout branch: {}",
229                name
230            ))
231        })
232    }
233}
234
235/// Parse the output of `git branch -vv --all`
236fn parse_branch_output(output: &str) -> Result<Vec<Branch>> {
237    let mut branches = Vec::new();
238
239    for line in output.lines() {
240        let line = line.trim();
241        if line.is_empty() {
242            continue;
243        }
244
245        // Skip the line that shows HEAD -> branch mapping for remotes
246        if line.contains("->") {
247            continue;
248        }
249
250        let is_current = line.starts_with('*');
251        let line = if is_current {
252            line[1..].trim() // Skip the '*' and trim
253        } else {
254            line.trim() // Just trim whitespace for non-current branches
255        };
256
257        // Parse branch name (first word)
258        let parts: Vec<&str> = line.split_whitespace().collect();
259        if parts.is_empty() {
260            continue;
261        }
262
263        let name = parts[0].to_string();
264
265        // Determine branch type
266        let branch_type = if name.starts_with("remotes/") {
267            BranchType::RemoteTracking
268        } else {
269            BranchType::Local
270        };
271
272        // Extract commit hash (second part if available)
273        let commit_hash = if parts.len() > 1 {
274            Hash::from(parts[1].to_string())
275        } else {
276            Hash::from("0000000000000000000000000000000000000000".to_string())
277        };
278
279        // Extract upstream information (look for [upstream] pattern)
280        let upstream = if let Some(bracket_start) = line.find('[') {
281            if let Some(bracket_end) = line.find(']') {
282                let upstream_info = &line[bracket_start + 1..bracket_end];
283                // Extract just the upstream branch name, ignore ahead/behind info
284                let upstream_branch = upstream_info
285                    .split(':')
286                    .next()
287                    .unwrap_or(upstream_info)
288                    .trim();
289                if upstream_branch.is_empty() {
290                    None
291                } else {
292                    Some(upstream_branch.to_string())
293                }
294            } else {
295                None
296            }
297        } else {
298            None
299        };
300
301        // Clean up remote branch names
302        let clean_name = if branch_type == BranchType::RemoteTracking {
303            name.strip_prefix("remotes/").unwrap_or(&name).to_string()
304        } else {
305            name
306        };
307
308        let branch = Branch {
309            name: clean_name,
310            branch_type,
311            is_current,
312            commit_hash,
313            upstream,
314        };
315
316        branches.push(branch);
317    }
318
319    Ok(branches)
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::Path;

    /// Builds a `Branch` fixture so each test only spells out what it checks.
    fn make_branch(
        name: &str,
        branch_type: BranchType,
        is_current: bool,
        hash: &str,
        upstream: Option<&str>,
    ) -> Branch {
        Branch {
            name: name.to_string(),
            branch_type,
            is_current,
            commit_hash: Hash::from(hash.to_string()),
            upstream: upstream.map(String::from),
        }
    }

    /// Scratch directory for repository tests.
    ///
    /// Any leftover from a previous run is removed up front, and the
    /// directory is removed again on drop. That way cleanup also happens when
    /// an assertion fails mid-test.
    struct TestRepoDir<'a>(&'a str);

    impl<'a> TestRepoDir<'a> {
        fn new(path: &'a str) -> Self {
            if Path::new(path).exists() {
                fs::remove_dir_all(path).unwrap();
            }
            TestRepoDir(path)
        }
    }

    impl Drop for TestRepoDir<'_> {
        fn drop(&mut self) {
            // Ignore errors: panicking in `drop` while already unwinding from
            // a failed assertion would abort the test process.
            let _ = fs::remove_dir_all(self.0);
        }
    }

    #[test]
    fn test_branch_type_display() {
        assert_eq!(format!("{}", BranchType::Local), "local");
        assert_eq!(format!("{}", BranchType::RemoteTracking), "remote-tracking");
    }

    #[test]
    fn test_branch_is_local() {
        let b = make_branch("main", BranchType::Local, true, "abc123", None);

        assert!(b.is_local());
        assert!(!b.is_remote());
    }

    #[test]
    fn test_branch_is_remote() {
        let b = make_branch("origin/main", BranchType::RemoteTracking, false, "abc123", None);

        assert!(b.is_remote());
        assert!(!b.is_local());
    }

    #[test]
    fn test_branch_short_name() {
        let local_branch = make_branch("feature", BranchType::Local, false, "abc123", None);
        let remote_branch =
            make_branch("origin/feature", BranchType::RemoteTracking, false, "abc123", None);

        assert_eq!(local_branch.short_name(), "feature");
        assert_eq!(remote_branch.short_name(), "feature");
    }

    #[test]
    fn test_branch_display() {
        let current_branch = make_branch("main", BranchType::Local, true, "abc123", None);
        let other_branch = make_branch("feature", BranchType::Local, false, "def456", None);

        assert_eq!(format!("{}", current_branch), "* main");
        assert_eq!(format!("{}", other_branch), "  feature");
    }

    #[test]
    fn test_branch_list_creation() {
        let branch_list = BranchList::new(vec![
            make_branch("main", BranchType::Local, true, "abc123", Some("origin/main")),
            make_branch("origin/main", BranchType::RemoteTracking, false, "abc123", None),
        ]);

        assert_eq!(branch_list.len(), 2);
        assert_eq!(branch_list.local_count(), 1);
        assert_eq!(branch_list.remote_count(), 1);
        assert!(!branch_list.is_empty());
    }

    #[test]
    fn test_branch_list_find() {
        let branch_list = BranchList::new(vec![
            make_branch("main", BranchType::Local, true, "abc123", None),
            make_branch("origin/feature", BranchType::RemoteTracking, false, "def456", None),
        ]);

        assert!(branch_list.find("main").is_some());
        assert!(branch_list.find("origin/feature").is_some());
        assert!(branch_list.find("nonexistent").is_none());

        assert!(branch_list.find_by_short_name("main").is_some());
        assert!(branch_list.find_by_short_name("feature").is_some());
    }

    #[test]
    fn test_branch_list_current() {
        let branch_list = BranchList::new(vec![
            make_branch("main", BranchType::Local, true, "abc123", None),
            make_branch("feature", BranchType::Local, false, "def456", None),
        ]);
        let current = branch_list.current().unwrap();

        assert_eq!(current.name, "main");
        assert!(current.is_current);
    }

    #[test]
    fn test_parse_branch_output() {
        let output = r#"
* main                abc1234 [origin/main] Initial commit
  feature             def5678 Feature branch
  remotes/origin/main abc1234 Initial commit
"#;

        let branches = parse_branch_output(output).unwrap();

        assert_eq!(branches.len(), 3);

        // Check main branch
        let main_branch = branches.iter().find(|b| b.name == "main");
        assert!(main_branch.is_some());
        let main_branch = main_branch.unwrap();
        assert!(main_branch.is_current);
        assert_eq!(main_branch.branch_type, BranchType::Local);
        assert_eq!(main_branch.upstream, Some("origin/main".to_string()));

        // Check feature branch
        let feature_branch = branches.iter().find(|b| b.name == "feature").unwrap();
        assert!(!feature_branch.is_current);
        assert_eq!(feature_branch.branch_type, BranchType::Local);
        assert_eq!(feature_branch.upstream, None);

        // Check remote branch
        let remote_branch = branches.iter().find(|b| b.name == "origin/main").unwrap();
        assert!(!remote_branch.is_current);
        assert_eq!(remote_branch.branch_type, BranchType::RemoteTracking);
    }

    #[test]
    fn test_parse_branch_output_skips_symbolic_refs() {
        let output = "  main                abc1234 Initial commit\n  remotes/origin/HEAD -> origin/main\n  remotes/origin/main abc1234 Initial commit\n";

        let branches = parse_branch_output(output).unwrap();
        let names: Vec<&str> = branches.iter().map(|b| b.name.as_str()).collect();

        assert_eq!(names, vec!["main", "origin/main"]);
    }

    #[test]
    fn test_parse_branch_output_ignores_blank_input() {
        assert!(parse_branch_output("").unwrap().is_empty());
        assert!(parse_branch_output("\n   \n").unwrap().is_empty());
    }

    #[test]
    fn test_repository_current_branch() {
        let test_path = "/tmp/test_current_branch_repo";
        let _dir = TestRepoDir::new(test_path);

        // Create a repository and test current branch
        let repo = Repository::init(test_path, false).unwrap();

        // In a new repo there might not be a current branch until the first
        // commit; this only checks that the call succeeds.
        let _current = repo.current_branch().unwrap();
    }

    #[test]
    fn test_repository_create_branch() {
        let test_path = "/tmp/test_create_branch_repo";
        let _dir = TestRepoDir::new(test_path);

        // Create a repository with an initial commit
        let repo = Repository::init(test_path, false).unwrap();

        // Configure git user for this repository to enable commits
        repo.config()
            .set_user("Test User", "test@example.com")
            .unwrap();

        // Create a test file and commit to have a valid HEAD
        fs::write(format!("{}/test.txt", test_path), "test content").unwrap();
        repo.add(&["test.txt"]).unwrap();
        repo.commit("Initial commit").unwrap();

        // Create a new branch
        let created = repo.create_branch("feature", None).unwrap();
        assert_eq!(created.name, "feature");
        assert_eq!(created.branch_type, BranchType::Local);
        assert!(!created.is_current);

        // Verify the branch exists in the branch list
        let branches = repo.branches().unwrap();
        assert!(branches.find("feature").is_some());
    }
}