up_rs/tasks/git/
cherry.rs

1//! `git cherry` command equivalent, finds equivalent commits.
2use crate::tasks::git::branch::get_branch_name;
3use crate::tasks::git::errors::GitError as E;
4use color_eyre::eyre::eyre;
5use color_eyre::eyre::Result;
6use git2::Branch;
7use git2::DiffFormat;
8use git2::DiffOptions;
9use git2::Oid;
10use git2::Repository;
11use git2::Revwalk;
12use ring::digest::Context;
13use ring::digest::Digest;
14use ring::digest::SHA256;
15use std::collections::HashSet;
16use std::io::Read;
17use tracing::trace;
18
19/// Return true if there are commits that aren't in upstream but are in head.
20///
21/// Find the merge-base of `upstream` and `head`, then look at the commits on
22/// `head`, and checks whether there is an equivalent patch for each one in
23/// `upstream`.
24///
25/// Equivalence is deterimined by taking the sha256sum of the patch with
26/// whitespace removed, and then comparing the sha256sums.
27///
28/// Equivalent of `git cherry -v "$up_branch" "$branch" | grep -q '^+'`
29/// Refs: <https://stackoverflow.com/questions/49480468/is-there-a-git-cherry-in-libgit2>
30pub(super) fn unmerged_commits(
31    repo: &Repository,
32    upstream: &Branch,
33    head: &Branch,
34) -> Result<bool> {
35    // TODO(gib): Add tests: https://github.com/git/git/blob/master/t/t3500-cherry.sh
36    let head_name = get_branch_name(head)?;
37    let upstream_name = get_branch_name(upstream)?;
38    let head_oid = head.get().target().ok_or(E::NoOidFound {
39        branch_name: head_name,
40    })?;
41    let upstream_oid = upstream.get().target().ok_or(E::NoOidFound {
42        branch_name: upstream_name,
43    })?;
44
45    let merge_base = repo.merge_base(head_oid, upstream_oid)?;
46    let upstream_ids = rev_list(repo, upstream_oid, merge_base)?;
47
48    let mut upstream_patch_ids = HashSet::new();
49
50    for id in upstream_ids {
51        let id = id?;
52        upstream_patch_ids.insert(patch_id(repo, id)?.as_ref().to_owned());
53    }
54    trace!("Upstream patch IDs: {upstream_patch_ids:?}");
55
56    let merge_base = repo.merge_base(head_oid, upstream_oid)?;
57    let head_ids: Vec<Oid> = rev_list(repo, head_oid, merge_base)?.collect::<Result<_, _>>()?;
58    trace!("Found head IDs: {head_ids:?}");
59
60    for id in head_ids {
61        let head_patch_id = patch_id(repo, id)?;
62        trace!("Head patch ID for '{id:?}': '{head_patch_id:?}'");
63        if !upstream_patch_ids.contains(head_patch_id.as_ref()) {
64            // Found an unmerged commit.
65            return Ok(true);
66        }
67    }
68
69    // We didn't find any unmerged commits.
70    Ok(false)
71}
72
73/// Generate a patch-id for the commit.
74///
75/// Take the sha256sum of the patch with whitespace removed, and
76/// then comparing the sha256sums.
77///
78/// <https://git.uis.cam.ac.uk/man/git-patch-id.html>
79// TODO(gib): consider running in parallel.
80// TODO(gib): Add tests: https://github.com/git/git/blob/306ee63a703ad67c54ba1209dc11dd9ea500dc1f/t/t4204-patch-id.sh
81fn patch_id(repo: &Repository, id: Oid) -> Result<Digest> {
82    // Get commit for Oid.
83    let commit = repo.find_commit(id).map_err(|e| E::NoCommitFound {
84        oid: id.to_string(),
85        source: e,
86    })?;
87    let parent = commit.parent(0)?;
88    // TODO(gib): What diff options are needed? What does git set?
89    let mut diff_opts = DiffOptions::new();
90    // TODO(gib): Extract into parent function.
91    let diff = repo.diff_tree_to_tree(
92        Some(&parent.tree()?),
93        Some(&commit.tree()?),
94        Some(&mut diff_opts),
95    )?;
96
97    let mut trimmed_diff: Vec<u8> = Vec::new();
98
99    // Convert diff to string so we can get the sha256sum.
100    diff.print(DiffFormat::PatchId, |delta, hunk_opt, line| -> bool {
101        trimmed_diff.extend(&u32_to_u8_array(delta.flags().bits()));
102        if let Some(hunk) = hunk_opt {
103            trimmed_diff.extend(hunk.header());
104        }
105        trimmed_diff.extend(line.content());
106        true
107    })?;
108
109    sha256_digest(&trimmed_diff[..])
110}
111
/// Convert a u32 to an array of 4 u8s, most significant byte first.
///
/// For example `0x0102_0304` becomes `[0x01, 0x02, 0x03, 0x04]`.
const fn u32_to_u8_array(x: u32) -> [u8; 4] {
    // Big-endian byte order matches the previous manual shift-and-mask
    // implementation, without needing any truncating casts.
    x.to_be_bytes()
}
122
123/// Get the sha256 checksum of some input data.
124fn sha256_digest<R: Read>(mut reader: R) -> Result<Digest> {
125    let mut context = Context::new(&SHA256);
126    let mut buffer = [0; 1024];
127
128    loop {
129        let count = reader.read(&mut buffer)?;
130        if count == 0 {
131            break;
132        }
133        context.update(buffer.get(..count).ok_or_else(|| {
134            eyre!(
135                "Logic error in up, we should have just confirmed that the buffer wasn't empty."
136            )
137        })?);
138    }
139
140    Ok(context.finish())
141}
142
143/// Get a list of revisions between two references.
144fn rev_list(repo: &Repository, from: Oid, to: Oid) -> Result<Revwalk> {
145    let mut revwalk = repo.revwalk()?;
146    // TODO(gib): do I need to set a revwalk.set_sorting(Sort::REVERSE) here?
147    revwalk.push(from)?;
148    revwalk.hide(to)?;
149
150    Ok(revwalk)
151}