wok_dev/
repo.rs

1use std::{fmt, path};
2
3use anyhow::*;
4use git2::build::CheckoutBuilder;
5use std::result::Result::Ok;
6
/// Outcome of [`Repo::merge`].
///
/// The enum carries no data, so it is `Copy` and fully comparable (`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeResult {
    /// No changes were needed, or there was no upstream to merge from.
    UpToDate,
    /// The branch was fast-forwarded to its upstream commit.
    FastForward,
    /// A merge commit joining local and upstream history was created.
    Merged,
    /// The merge produced conflicts, so nothing was committed.
    Conflicts,
}
14
15pub struct Repo {
16    pub git_repo: git2::Repository,
17    pub work_dir: path::PathBuf,
18    pub head: String,
19    pub subrepos: Vec<Repo>,
20}
21
22impl Repo {
23    pub fn new(work_dir: &path::Path, head_name: Option<&str>) -> Result<Self> {
24        let git_repo = git2::Repository::open(work_dir)
25            .with_context(|| format!("Cannot open repo at `{}`", work_dir.display()))?;
26
27        let head = match head_name {
28            Some(name) => String::from(name),
29            None => {
30                if git_repo.head_detached().with_context(|| {
31                    format!(
32                        "Cannot determine head state for repo at `{}`",
33                        work_dir.display()
34                    )
35                })? {
36                    bail!(
37                        "Cannot operate on a detached head for repo at `{}`",
38                        work_dir.display()
39                    )
40                }
41
42                String::from(git_repo.head().with_context(|| {
43                    format!(
44                        "Cannot find the head branch for repo at `{}`. Is it detached?",
45                        work_dir.display()
46                    )
47                })?.shorthand().with_context(|| {
48                    format!(
49                        "Cannot find a human readable representation of the head ref for repo at `{}`",
50                        work_dir.display(),
51                    )
52                })?)
53            },
54        };
55
56        let subrepos = git_repo
57            .submodules()
58            .with_context(|| {
59                format!(
60                    "Cannot load submodules for repo at `{}`",
61                    work_dir.display()
62                )
63            })?
64            .iter()
65            .map(|submodule| Repo::new(&work_dir.join(submodule.path()), Some(&head)))
66            .collect::<Result<Vec<Repo>>>()?;
67
68        Ok(Repo {
69            git_repo,
70            work_dir: path::PathBuf::from(work_dir),
71            head,
72            subrepos,
73        })
74    }
75
76    pub fn get_subrepo_by_path(&self, subrepo_path: &path::PathBuf) -> Option<&Repo> {
77        self.subrepos
78            .iter()
79            .find(|subrepo| subrepo.work_dir == self.work_dir.join(subrepo_path))
80    }
81
82    pub fn sync(&self) -> Result<()> {
83        self.switch(&self.head)?;
84        Ok(())
85    }
86
87    pub fn switch(&self, head: &str) -> Result<()> {
88        self.git_repo.set_head(&self.resolve_reference(head)?)?;
89        self.git_repo.checkout_head(None)?;
90        Ok(())
91    }
92
93    pub fn fetch(&self) -> Result<()> {
94        // Get the remote for the current branch
95        let head_ref = self.git_repo.head()?;
96        let branch_name = head_ref.shorthand().with_context(|| {
97            format!(
98                "Cannot get branch name for repo at `{}`",
99                self.work_dir.display()
100            )
101        })?;
102
103        let tracking = match self.tracking_branch(branch_name)? {
104            Some(tracking) => tracking,
105            None => {
106                // No upstream configured, skip fetch
107                return Ok(());
108            },
109        };
110
111        // Check if remote exists
112        match self.git_repo.find_remote(&tracking.remote) {
113            Ok(mut remote) => {
114                let mut fetch_options = git2::FetchOptions::new();
115                fetch_options.remote_callbacks(self.remote_callbacks()?);
116
117                remote
118                    .fetch::<&str>(&[], Some(&mut fetch_options), None)
119                    .with_context(|| {
120                        format!(
121                            "Failed to fetch from remote '{}' for repo at `{}`",
122                            tracking.remote,
123                            self.work_dir.display()
124                        )
125                    })?;
126            },
127            Err(_) => {
128                // No remote configured, skip fetch
129                return Ok(());
130            },
131        }
132
133        Ok(())
134    }
135
136    pub fn merge(&self, branch_name: &str) -> Result<MergeResult> {
137        // First, fetch the latest changes
138        self.fetch()?;
139
140        // Resolve the tracking branch reference
141        let tracking = match self.tracking_branch(branch_name)? {
142            Some(tracking) => tracking,
143            None => {
144                // No upstream configured, treat as up to date
145                return Ok(MergeResult::UpToDate);
146            },
147        };
148
149        // Check if remote branch exists
150        let remote_branch_oid = match self.git_repo.refname_to_id(&tracking.remote_ref)
151        {
152            Ok(oid) => oid,
153            Err(_) => {
154                // No remote branch, just return up to date
155                return Ok(MergeResult::UpToDate);
156            },
157        };
158
159        let remote_commit = self.git_repo.find_commit(remote_branch_oid)?;
160        let local_commit = self.git_repo.head()?.peel_to_commit()?;
161
162        // Check if we're already up to date
163        if local_commit.id() == remote_commit.id() {
164            return Ok(MergeResult::UpToDate);
165        }
166
167        // Check if we can fast-forward
168        if self
169            .git_repo
170            .graph_descendant_of(remote_commit.id(), local_commit.id())?
171        {
172            // Fast-forward merge
173            self.git_repo.reference(
174                &format!("refs/heads/{}", branch_name),
175                remote_commit.id(),
176                true,
177                &format!("Fast-forward '{}' to {}", branch_name, tracking.remote_ref),
178            )?;
179            self.git_repo
180                .set_head(&format!("refs/heads/{}", branch_name))?;
181            let mut checkout = CheckoutBuilder::new();
182            checkout.force();
183            self.git_repo.checkout_head(Some(&mut checkout))?;
184            return Ok(MergeResult::FastForward);
185        }
186
187        // Perform a merge
188        let mut merge_opts = git2::MergeOptions::new();
189        merge_opts.fail_on_conflict(false); // Don't fail on conflicts, we'll handle them
190
191        let _merge_result = self.git_repo.merge_commits(
192            &local_commit,
193            &remote_commit,
194            Some(&merge_opts),
195        )?;
196
197        // Check if there are conflicts by examining the index
198        let mut index = self.git_repo.index()?;
199        let has_conflicts = index.has_conflicts();
200
201        if !has_conflicts {
202            // No conflicts, merge was successful
203            let signature = self.git_repo.signature()?;
204            let tree_id = index.write_tree()?;
205            let tree = self.git_repo.find_tree(tree_id)?;
206
207            self.git_repo.commit(
208                Some(&format!("refs/heads/{}", branch_name)),
209                &signature,
210                &signature,
211                &format!("Merge remote-tracking branch '{}'", tracking.remote_ref),
212                &tree,
213                &[&local_commit, &remote_commit],
214            )?;
215
216            self.git_repo.cleanup_state()?;
217
218            Ok(MergeResult::Merged)
219        } else {
220            // There are conflicts
221            Ok(MergeResult::Conflicts)
222        }
223    }
224
225    pub fn get_remote_name_for_branch(&self, branch_name: &str) -> Result<String> {
226        if let Some(tracking) = self.tracking_branch(branch_name)? {
227            Ok(tracking.remote)
228        } else {
229            // Fall back to origin if no tracking branch is configured
230            Ok("origin".to_string())
231        }
232    }
233
234    pub fn remote_callbacks(&self) -> Result<git2::RemoteCallbacks<'static>> {
235        let config = self.git_repo.config()?;
236
237        let mut callbacks = git2::RemoteCallbacks::new();
238        callbacks.credentials(move |url, username_from_url, allowed| {
239            if allowed.contains(git2::CredentialType::SSH_KEY)
240                && let Some(username) = username_from_url
241                && let Ok(cred) = git2::Cred::ssh_key_from_agent(username)
242            {
243                return Ok(cred);
244            }
245
246            if (allowed.contains(git2::CredentialType::USER_PASS_PLAINTEXT)
247                || allowed.contains(git2::CredentialType::SSH_KEY)
248                || allowed.contains(git2::CredentialType::DEFAULT))
249                && let Ok(cred) =
250                    git2::Cred::credential_helper(&config, url, username_from_url)
251            {
252                return Ok(cred);
253            }
254
255            if allowed.contains(git2::CredentialType::USERNAME) {
256                if let Some(username) = username_from_url {
257                    return git2::Cred::username(username);
258                } else {
259                    return git2::Cred::username("git");
260                }
261            }
262
263            git2::Cred::default()
264        });
265
266        Ok(callbacks)
267    }
268
269    fn resolve_reference(&self, short_name: &str) -> Result<String> {
270        Ok(self
271            .git_repo
272            .resolve_reference_from_short_name(short_name)?
273            .name()
274            .with_context(|| {
275                format!(
276                    "Cannot resolve head reference for repo at `{}`",
277                    self.work_dir.display()
278                )
279            })?
280            .to_owned())
281    }
282
283    fn tracking_branch(&self, branch_name: &str) -> Result<Option<TrackingBranch>> {
284        let config = self.git_repo.config()?;
285
286        let remote_key = format!("branch.{}.remote", branch_name);
287        let merge_key = format!("branch.{}.merge", branch_name);
288
289        let remote = match config.get_string(&remote_key) {
290            Ok(name) => name,
291            Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
292            Err(err) => return Err(err.into()),
293        };
294
295        let merge_ref = match config.get_string(&merge_key) {
296            Ok(name) => name,
297            Err(err) if err.code() == git2::ErrorCode::NotFound => return Ok(None),
298            Err(err) => return Err(err.into()),
299        };
300
301        let branch_short = merge_ref
302            .strip_prefix("refs/heads/")
303            .unwrap_or(&merge_ref)
304            .to_owned();
305
306        let remote_ref = format!("refs/remotes/{}/{}", remote, branch_short);
307
308        Ok(Some(TrackingBranch { remote, remote_ref }))
309    }
310}
311
312impl fmt::Debug for Repo {
313    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
314        f.debug_struct("Repo")
315            .field("work_dir", &self.work_dir)
316            .field("head", &self.head)
317            .field("subrepos", &self.subrepos)
318            .finish()
319    }
320}
321
/// Upstream configuration of a local branch, as resolved by `Repo::tracking_branch`.
struct TrackingBranch {
    /// Remote name taken from `branch.<name>.remote`.
    remote: String,
    /// Full ref name of the upstream branch, e.g. `refs/remotes/origin/main`.
    remote_ref: String,
}