// sr_ai/git/mod.rs

use anyhow::{Context, Result, bail};
use std::collections::HashMap;
use std::path::PathBuf;
use std::process::{Command, Stdio};

/// Handle to a git repository, identified by its top-level directory.
///
/// Git commands issued through this handle run against `root` (`git -C <root> ...`),
/// independent of the process's current directory.
#[derive(Debug)]
pub struct GitRepo {
    root: PathBuf,
}

10#[allow(dead_code)]
11impl GitRepo {
12    pub fn discover() -> Result<Self> {
13        let output = Command::new("git")
14            .args(["rev-parse", "--show-toplevel"])
15            .output()
16            .context("failed to run git")?;
17
18        if !output.status.success() {
19            bail!(crate::error::SrAiError::NotAGitRepo);
20        }
21
22        let root = String::from_utf8(output.stdout)
23            .context("invalid utf-8 from git")?
24            .trim()
25            .into();
26
27        Ok(Self { root })
28    }
29
30    pub fn root(&self) -> &PathBuf {
31        &self.root
32    }
33
34    fn git(&self, args: &[&str]) -> Result<String> {
35        let output = Command::new("git")
36            .args(["-C", self.root.to_str().unwrap()])
37            .args(args)
38            .output()
39            .with_context(|| format!("failed to run git {}", args.join(" ")))?;
40
41        if !output.status.success() {
42            let stderr = String::from_utf8_lossy(&output.stderr);
43            bail!(crate::error::SrAiError::GitCommand(format!(
44                "git {} failed: {}",
45                args.join(" "),
46                stderr.trim()
47            )));
48        }
49
50        Ok(String::from_utf8_lossy(&output.stdout).to_string())
51    }
52
53    fn git_allow_failure(&self, args: &[&str]) -> Result<(bool, String)> {
54        let output = Command::new("git")
55            .args(["-C", self.root.to_str().unwrap()])
56            .args(args)
57            .output()
58            .with_context(|| format!("failed to run git {}", args.join(" ")))?;
59
60        Ok((
61            output.status.success(),
62            String::from_utf8_lossy(&output.stdout).to_string(),
63        ))
64    }
65
66    pub fn has_staged_changes(&self) -> Result<bool> {
67        let out = self.git(&["diff", "--cached", "--name-only"])?;
68        Ok(!out.trim().is_empty())
69    }
70
71    pub fn has_any_changes(&self) -> Result<bool> {
72        let out = self.git(&["status", "--porcelain"])?;
73        Ok(!out.trim().is_empty())
74    }
75
76    pub fn has_head(&self) -> Result<bool> {
77        let (ok, _) = self.git_allow_failure(&["rev-parse", "HEAD"])?;
78        Ok(ok)
79    }
80
81    pub fn reset_head(&self) -> Result<()> {
82        if self.has_head()? {
83            self.git(&["reset", "HEAD", "--quiet"])?;
84        } else {
85            // Fresh repo with no commits — unstage via rm --cached
86            let _ = self.git_allow_failure(&["rm", "--cached", "-r", ".", "--quiet"]);
87        }
88        Ok(())
89    }
90
91    pub fn stage_file(&self, file: &str) -> Result<bool> {
92        let full_path = self.root.join(file);
93        let exists = full_path.exists();
94
95        if !exists {
96            // Check if it's a deleted file
97            let out = self.git(&["ls-files", "--deleted"])?;
98            let is_deleted = out.lines().any(|l| l.trim() == file);
99            if !is_deleted {
100                return Ok(false);
101            }
102        }
103
104        let (ok, _) = self.git_allow_failure(&["add", "--", file])?;
105        Ok(ok)
106    }
107
108    pub fn has_staged_after_add(&self) -> Result<bool> {
109        self.has_staged_changes()
110    }
111
112    pub fn commit(&self, message: &str) -> Result<()> {
113        let output = Command::new("git")
114            .args(["-C", self.root.to_str().unwrap()])
115            .args(["commit", "-F", "-"])
116            .stdin(std::process::Stdio::piped())
117            .stdout(std::process::Stdio::piped())
118            .stderr(std::process::Stdio::piped())
119            .spawn()
120            .context("failed to spawn git commit")?;
121
122        use std::io::Write;
123        let mut child = output;
124        if let Some(mut stdin) = child.stdin.take() {
125            stdin.write_all(message.as_bytes())?;
126        }
127
128        let out = child.wait_with_output()?;
129        if !out.status.success() {
130            let stderr = String::from_utf8_lossy(&out.stderr);
131            bail!(crate::error::SrAiError::GitCommand(format!(
132                "git commit failed: {}",
133                stderr.trim()
134            )));
135        }
136
137        Ok(())
138    }
139
140    pub fn recent_commits(&self, count: usize) -> Result<String> {
141        self.git(&["--no-pager", "log", "--oneline", &format!("-{count}")])
142    }
143
144    pub fn diff_cached(&self) -> Result<String> {
145        self.git(&["diff", "--cached"])
146    }
147
148    pub fn diff_cached_stat(&self) -> Result<String> {
149        self.git(&["diff", "--cached", "--stat"])
150    }
151
152    pub fn diff_head(&self) -> Result<String> {
153        let (ok, out) = self.git_allow_failure(&["diff", "HEAD"])?;
154        if ok { Ok(out) } else { self.git(&["diff"]) }
155    }
156
157    pub fn status_porcelain(&self) -> Result<String> {
158        self.git(&["status", "--porcelain"])
159    }
160
161    pub fn untracked_files(&self) -> Result<String> {
162        self.git(&["ls-files", "--others", "--exclude-standard"])
163    }
164
165    pub fn show(&self, rev: &str) -> Result<String> {
166        self.git(&["show", rev])
167    }
168
169    pub fn log_range(&self, base: &str, count: Option<usize>) -> Result<String> {
170        let mut args = vec!["--no-pager", "log", "--oneline"];
171        let count_str;
172        if let Some(n) = count {
173            count_str = format!("-{n}");
174            args.push(&count_str);
175        }
176        args.push(base);
177        self.git(&args)
178    }
179
180    pub fn diff_range(&self, base: &str) -> Result<String> {
181        self.git(&["diff", base])
182    }
183
184    pub fn current_branch(&self) -> Result<String> {
185        let out = self.git(&["rev-parse", "--abbrev-ref", "HEAD"])?;
186        Ok(out.trim().to_string())
187    }
188
189    pub fn head_short(&self) -> Result<String> {
190        let out = self.git(&["rev-parse", "--short", "HEAD"])?;
191        Ok(out.trim().to_string())
192    }
193
194    /// Count commits since the last tag. If no tags exist, counts all commits.
195    pub fn commits_since_last_tag(&self) -> Result<usize> {
196        // Try to find the most recent tag
197        let (ok, tag) = self.git_allow_failure(&["describe", "--tags", "--abbrev=0"])?;
198        let tag = tag.trim();
199
200        let out = if ok && !tag.is_empty() {
201            self.git(&["rev-list", &format!("{tag}..HEAD"), "--count"])?
202        } else {
203            self.git(&["rev-list", "HEAD", "--count"])?
204        };
205
206        out.trim()
207            .parse::<usize>()
208            .context("failed to parse commit count")
209    }
210
211    /// Get detailed log of recent commits (SHA, subject, body) oldest first.
212    pub fn log_detailed(&self, count: usize) -> Result<String> {
213        let out = self.git(&[
214            "--no-pager",
215            "log",
216            "--reverse",
217            &format!("-{count}"),
218            "--format=%h %s%n%b%n---",
219        ])?;
220        Ok(out)
221    }
222
223    pub fn file_statuses(&self) -> Result<HashMap<String, char>> {
224        let out = self.git(&["status", "--porcelain"])?;
225        let mut map = HashMap::new();
226        for line in out.lines() {
227            if line.len() < 3 {
228                continue;
229            }
230            let xy = &line.as_bytes()[..2];
231            let mut path = line[3..].to_string();
232            if let Some(pos) = path.find(" -> ") {
233                path = path[pos + 4..].to_string();
234            }
235            let (x, y) = (xy[0], xy[1]);
236            let status = match (x, y) {
237                (b'?', b'?') => 'A',
238                (b'A', _) | (_, b'A') => 'A',
239                (b'D', _) | (_, b'D') => 'D',
240                (b'R', _) | (_, b'R') => 'R',
241                (b'M', _) | (_, b'M') | (b'T', _) | (_, b'T') => 'M',
242                _ => '~',
243            };
244            map.insert(path, status);
245        }
246        Ok(map)
247    }
248
249    /// Create a snapshot of the working tree state into the platform data directory.
250    /// Location: `<data_local_dir>/sr/snapshots/<repo-hash>/`
251    ///   - macOS:   ~/Library/Application Support/sr/snapshots/<hash>/
252    ///   - Linux:   ~/.local/share/sr/snapshots/<hash>/
253    ///   - Windows: %LOCALAPPDATA%/sr/snapshots/<hash>/
254    ///
255    /// The snapshot includes:
256    /// - A `stash` ref created via `git stash create` (staged + unstaged changes)
257    /// - An `untracked.tar` of untracked files
258    /// - The list of staged files in `staged.txt`
259    /// - The repo root path in `repo_root` (so restore targets the right directory)
260    ///
261    /// Lives completely outside the repo so the agent cannot touch it.
262    pub fn snapshot_working_tree(&self) -> Result<PathBuf> {
263        let snapshot_dir = snapshot_dir_for(&self.root)
264            .context("failed to resolve snapshot directory (no data directory available)")?;
265        std::fs::create_dir_all(&snapshot_dir).context("failed to create snapshot directory")?;
266
267        // Record which repo this snapshot belongs to
268        std::fs::write(
269            snapshot_dir.join("repo_root"),
270            self.root.to_string_lossy().as_bytes(),
271        )
272        .context("failed to write repo_root")?;
273
274        // Capture staged file list
275        let staged = self.git(&["diff", "--cached", "--name-only"])?;
276        std::fs::write(snapshot_dir.join("staged.txt"), staged.trim())
277            .context("failed to write staged.txt")?;
278
279        // Create a stash object without modifying working tree or index.
280        // `git stash create` writes a stash commit but doesn't reset anything.
281        let (ok, stash_ref) = self.git_allow_failure(&["stash", "create"])?;
282        let stash_ref = stash_ref.trim().to_string();
283        if ok && !stash_ref.is_empty() {
284            std::fs::write(snapshot_dir.join("stash_ref"), &stash_ref)
285                .context("failed to write stash_ref")?;
286        } else {
287            let _ = std::fs::remove_file(snapshot_dir.join("stash_ref"));
288        }
289
290        // Archive untracked files
291        let untracked = self.git(&["ls-files", "--others", "--exclude-standard"])?;
292        let untracked_files: Vec<&str> = untracked
293            .lines()
294            .map(|l| l.trim())
295            .filter(|l| !l.is_empty())
296            .collect();
297        let tar_path = snapshot_dir.join("untracked.tar");
298        if !untracked_files.is_empty() {
299            let file_list = untracked_files.join("\n");
300            let file_list_path = snapshot_dir.join("untracked_list.txt");
301            std::fs::write(&file_list_path, &file_list)?;
302            let status = Command::new("tar")
303                .args([
304                    "cf",
305                    tar_path.to_str().unwrap(),
306                    "-T",
307                    file_list_path.to_str().unwrap(),
308                ])
309                .current_dir(&self.root)
310                .stdout(std::process::Stdio::null())
311                .stderr(std::process::Stdio::null())
312                .status()
313                .context("failed to run tar")?;
314            if !status.success() {
315                eprintln!("warning: failed to archive untracked files");
316            }
317        } else {
318            let _ = std::fs::remove_file(&tar_path);
319        }
320
321        // Mark snapshot as valid
322        let now = std::time::SystemTime::now()
323            .duration_since(std::time::UNIX_EPOCH)
324            .unwrap_or_default()
325            .as_secs();
326        std::fs::write(snapshot_dir.join("timestamp"), now.to_string())
327            .context("failed to write timestamp")?;
328
329        Ok(snapshot_dir)
330    }
331
332    /// Restore working tree from the latest snapshot (best-effort).
333    pub fn restore_snapshot(&self) -> Result<()> {
334        let snapshot_dir = self.snapshot_dir()?;
335        if !snapshot_dir.join("timestamp").exists() {
336            bail!("no valid snapshot found");
337        }
338
339        // Restore tracked changes from stash ref
340        let stash_ref_path = snapshot_dir.join("stash_ref");
341        if stash_ref_path.exists() {
342            let stash_ref = std::fs::read_to_string(&stash_ref_path)?;
343            let stash_ref = stash_ref.trim();
344            if !stash_ref.is_empty() {
345                let (ok, _) = self.git_allow_failure(&["stash", "apply", stash_ref])?;
346                if !ok {
347                    eprintln!("warning: failed to apply stash {stash_ref}, trying checkout");
348                    let _ = self.git_allow_failure(&["checkout", "."]);
349                    let _ = self.git_allow_failure(&["stash", "apply", stash_ref]);
350                }
351            }
352        }
353
354        // Restore staged files
355        let staged_path = snapshot_dir.join("staged.txt");
356        if staged_path.exists() {
357            let staged = std::fs::read_to_string(&staged_path)?;
358            for file in staged.lines().filter(|l| !l.trim().is_empty()) {
359                let full = self.root.join(file);
360                if full.exists() {
361                    let _ = self.git_allow_failure(&["add", "--", file]);
362                }
363            }
364        }
365
366        // Restore untracked files
367        let tar_path = snapshot_dir.join("untracked.tar");
368        if tar_path.exists() {
369            let _ = Command::new("tar")
370                .args(["xf", tar_path.to_str().unwrap()])
371                .current_dir(&self.root)
372                .stdout(std::process::Stdio::null())
373                .stderr(std::process::Stdio::null())
374                .status();
375        }
376
377        Ok(())
378    }
379
380    /// Remove the snapshot after a successful operation.
381    pub fn clear_snapshot(&self) {
382        if let Ok(dir) = self.snapshot_dir() {
383            let _ = std::fs::remove_dir_all(&dir);
384        }
385    }
386
387    /// Returns the snapshot directory path for this repo.
388    pub fn snapshot_dir(&self) -> Result<PathBuf> {
389        snapshot_dir_for(&self.root)
390            .context("failed to resolve snapshot directory (no data directory available)")
391    }
392
393    /// Check if a valid snapshot exists.
394    pub fn has_snapshot(&self) -> bool {
395        self.snapshot_dir()
396            .map(|d| d.join("timestamp").exists())
397            .unwrap_or(false)
398    }
399}
400
401/// Resolve the snapshot directory for a repo root.
402/// `<data_local_dir>/sr/snapshots/<repo-hash>/`
403fn snapshot_dir_for(repo_root: &std::path::Path) -> Option<PathBuf> {
404    let base = dirs::data_local_dir()?;
405    let repo_id =
406        &crate::cache::fingerprint::sha256_hex(repo_root.to_string_lossy().as_bytes())[..16];
407    Some(base.join("sr").join("snapshots").join(repo_id))
408}
409
410/// Guard that ensures the snapshot is cleaned up on success
411/// and restored on failure (drop without explicit success).
412pub struct SnapshotGuard<'a> {
413    repo: &'a GitRepo,
414    succeeded: bool,
415}
416
417impl<'a> SnapshotGuard<'a> {
418    /// Create a snapshot and return the guard.
419    pub fn new(repo: &'a GitRepo) -> Result<Self> {
420        repo.snapshot_working_tree()?;
421        Ok(Self {
422            repo,
423            succeeded: false,
424        })
425    }
426
427    /// Mark the operation as successful — snapshot will be cleared on drop.
428    pub fn success(mut self) {
429        self.succeeded = true;
430        self.repo.clear_snapshot();
431    }
432}
433
434impl Drop for SnapshotGuard<'_> {
435    fn drop(&mut self) {
436        if !self.succeeded && self.repo.has_snapshot() {
437            eprintln!("sr: operation failed, restoring working tree from snapshot...");
438            if let Err(e) = self.repo.restore_snapshot() {
439                eprintln!("sr: warning: snapshot restore failed: {e}");
440                if let Ok(dir) = self.repo.snapshot_dir() {
441                    eprintln!(
442                        "sr: snapshot preserved at {} for manual recovery",
443                        dir.display()
444                    );
445                }
446            } else {
447                self.repo.clear_snapshot();
448            }
449        }
450    }
451}