// story_tracker_cli/subcommands/switch.rs

1use crate::subcommands::generate::branch_name;
2use auth_git2::GitAuthenticator;
3use git2::{BranchType, Repository};
4use pivotal_tracker::{
5	client::Client,
6	story::{GetStoryOptions, StoryID},
7};
8use std::{error::Error, process::Command};
9
10#[derive(Debug)]
11pub struct RunOptions<'a> {
12	pub branch_or_story_id: &'a String,
13	pub client: &'a Client,
14}
15
16pub async fn run(options: RunOptions<'_>) -> Result<(), Box<dyn Error>> {
17	let branch_name = {
18		let story_id = options.branch_or_story_id.parse::<StoryID>();
19
20		if story_id.is_ok() {
21			let story_id = story_id?;
22
23			println!("Fetching story {}...", story_id.0);
24
25			let story = options
26				.client
27				.get_story(GetStoryOptions { id: story_id })
28				.await
29				.expect("Failed to get story");
30
31			branch_name(&story)
32		} else {
33			options.branch_or_story_id.to_string()
34		}
35	};
36
37	println!("Fetching remote branch origin/{}...", branch_name);
38
39	let auth = GitAuthenticator::default();
40	let repo = Repository::open_from_env()?;
41	let mut remote = repo.find_remote("origin")?;
42
43	remote
44		.fetch_refspecs()?
45		.iter()
46		.for_each(|x| auth.fetch(&repo, &mut remote, &[x.unwrap()], None).unwrap());
47
48	let remote_has_branch = repo
49		.find_branch(&format!("origin/{}", branch_name), BranchType::Remote)
50		.is_ok();
51
52	if remote_has_branch {
53		checkout_branch(&branch_name);
54	} else {
55		branch_off_of(&get_default_branch(), &branch_name);
56	}
57
58	Ok(())
59}
60
61fn checkout_branch(branch_name: &str) {
62	println!("Checking out branch {}...", branch_name);
63
64	Command::new("git")
65		.args(["checkout", branch_name])
66		.output()
67		.unwrap();
68
69	println!("Pulling latest from {}...", branch_name);
70
71	Command::new("git").args(["pull"]).output().unwrap();
72}
73
74fn branch_off_of(from_branch_name: &str, to_branch_name: &str) {
75	checkout_branch(from_branch_name);
76
77	println!("Creating branch {}...", to_branch_name);
78
79	Command::new("git")
80		.args(["checkout", "-b", to_branch_name])
81		.output()
82		.unwrap();
83}
84
85fn get_default_branch() -> String {
86	let repo = Repository::open_from_env().unwrap();
87	let origin_revspec =
88		repo.revparse_single("refs/remotes/origin/HEAD").unwrap();
89	let origin_commit = origin_revspec.as_commit().unwrap();
90	let origin_commit_id = origin_commit.id();
91	let mut branches = repo.branches(Some(BranchType::Remote)).unwrap();
92	let origin_branch_name = branches
93		.find_map(|branch_tuple| {
94			let (branch, _) = branch_tuple.unwrap();
95			let branch_name = branch.name().unwrap().unwrap().to_string();
96
97			if branch_name == "origin/HEAD" {
98				return None;
99			}
100
101			let branch_commit = branch.get().peel_to_commit().unwrap();
102
103			if branch_commit.id() != origin_commit_id {
104				return None;
105			}
106
107			Some(branch_name)
108		})
109		.unwrap();
110
111	origin_branch_name.replace("origin/", "")
112}