Skip to main content

sr_ai/git/
mod.rs

1use anyhow::{Context, Result, bail};
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::process::Command;
5
/// Handle to a git working tree, identified by its top-level directory.
///
/// The git helper methods run `git -C <root> …`, so they target this
/// repository regardless of the process's current directory.
#[derive(Debug, Clone)]
pub struct GitRepo {
    /// Top-level directory of the work tree, as printed by
    /// `git rev-parse --show-toplevel`.
    root: PathBuf,
}
9
10#[allow(dead_code)]
11impl GitRepo {
12    pub fn discover() -> Result<Self> {
13        let output = Command::new("git")
14            .args(["rev-parse", "--show-toplevel"])
15            .output()
16            .context("failed to run git")?;
17
18        if !output.status.success() {
19            bail!(crate::error::SrAiError::NotAGitRepo);
20        }
21
22        let root = String::from_utf8(output.stdout)
23            .context("invalid utf-8 from git")?
24            .trim()
25            .into();
26
27        Ok(Self { root })
28    }
29
30    pub fn root(&self) -> &PathBuf {
31        &self.root
32    }
33
34    fn git(&self, args: &[&str]) -> Result<String> {
35        let output = Command::new("git")
36            .args(["-C", self.root.to_str().unwrap()])
37            .args(args)
38            .output()
39            .with_context(|| format!("failed to run git {}", args.join(" ")))?;
40
41        if !output.status.success() {
42            let stderr = String::from_utf8_lossy(&output.stderr);
43            bail!(crate::error::SrAiError::GitCommand(format!(
44                "git {} failed: {}",
45                args.join(" "),
46                stderr.trim()
47            )));
48        }
49
50        Ok(String::from_utf8_lossy(&output.stdout).to_string())
51    }
52
53    fn git_allow_failure(&self, args: &[&str]) -> Result<(bool, String)> {
54        let output = Command::new("git")
55            .args(["-C", self.root.to_str().unwrap()])
56            .args(args)
57            .output()
58            .with_context(|| format!("failed to run git {}", args.join(" ")))?;
59
60        Ok((
61            output.status.success(),
62            String::from_utf8_lossy(&output.stdout).to_string(),
63        ))
64    }
65
66    pub fn has_staged_changes(&self) -> Result<bool> {
67        let out = self.git(&["diff", "--cached", "--name-only"])?;
68        Ok(!out.trim().is_empty())
69    }
70
71    pub fn has_any_changes(&self) -> Result<bool> {
72        let out = self.git(&["status", "--porcelain"])?;
73        Ok(!out.trim().is_empty())
74    }
75
76    pub fn has_head(&self) -> Result<bool> {
77        let (ok, _) = self.git_allow_failure(&["rev-parse", "HEAD"])?;
78        Ok(ok)
79    }
80
81    pub fn reset_head(&self) -> Result<()> {
82        if self.has_head()? {
83            self.git(&["reset", "HEAD", "--quiet"])?;
84        } else {
85            // Fresh repo with no commits — unstage via rm --cached
86            let _ = self.git_allow_failure(&["rm", "--cached", "-r", ".", "--quiet"]);
87        }
88        Ok(())
89    }
90
91    pub fn stage_file(&self, file: &str) -> Result<bool> {
92        let full_path = self.root.join(file);
93        let exists = full_path.exists();
94
95        if !exists {
96            // Check if it's a deleted file
97            let out = self.git(&["ls-files", "--deleted"])?;
98            let is_deleted = out.lines().any(|l| l.trim() == file);
99            if !is_deleted {
100                return Ok(false);
101            }
102        }
103
104        let (ok, _) = self.git_allow_failure(&["add", "--", file])?;
105        Ok(ok)
106    }
107
108    pub fn has_staged_after_add(&self) -> Result<bool> {
109        self.has_staged_changes()
110    }
111
112    pub fn commit(&self, message: &str) -> Result<()> {
113        let output = Command::new("git")
114            .args(["-C", self.root.to_str().unwrap()])
115            .args(["commit", "-F", "-"])
116            .stdin(std::process::Stdio::piped())
117            .stdout(std::process::Stdio::piped())
118            .stderr(std::process::Stdio::piped())
119            .spawn()
120            .context("failed to spawn git commit")?;
121
122        use std::io::Write;
123        let mut child = output;
124        if let Some(mut stdin) = child.stdin.take() {
125            stdin.write_all(message.as_bytes())?;
126        }
127
128        let out = child.wait_with_output()?;
129        if !out.status.success() {
130            let stderr = String::from_utf8_lossy(&out.stderr);
131            bail!(crate::error::SrAiError::GitCommand(format!(
132                "git commit failed: {}",
133                stderr.trim()
134            )));
135        }
136
137        Ok(())
138    }
139
140    pub fn recent_commits(&self, count: usize) -> Result<String> {
141        self.git(&["--no-pager", "log", "--oneline", &format!("-{count}")])
142    }
143
144    pub fn diff_cached(&self) -> Result<String> {
145        self.git(&["diff", "--cached"])
146    }
147
148    pub fn diff_cached_stat(&self) -> Result<String> {
149        self.git(&["diff", "--cached", "--stat"])
150    }
151
152    pub fn diff_head(&self) -> Result<String> {
153        let (ok, out) = self.git_allow_failure(&["diff", "HEAD"])?;
154        if ok { Ok(out) } else { self.git(&["diff"]) }
155    }
156
157    pub fn status_porcelain(&self) -> Result<String> {
158        self.git(&["status", "--porcelain"])
159    }
160
161    pub fn untracked_files(&self) -> Result<String> {
162        self.git(&["ls-files", "--others", "--exclude-standard"])
163    }
164
165    pub fn show(&self, rev: &str) -> Result<String> {
166        self.git(&["show", rev])
167    }
168
169    pub fn log_range(&self, base: &str, count: Option<usize>) -> Result<String> {
170        let mut args = vec!["--no-pager", "log", "--oneline"];
171        let count_str;
172        if let Some(n) = count {
173            count_str = format!("-{n}");
174            args.push(&count_str);
175        }
176        args.push(base);
177        self.git(&args)
178    }
179
180    pub fn diff_range(&self, base: &str) -> Result<String> {
181        self.git(&["diff", base])
182    }
183
184    pub fn current_branch(&self) -> Result<String> {
185        let out = self.git(&["rev-parse", "--abbrev-ref", "HEAD"])?;
186        Ok(out.trim().to_string())
187    }
188
189    pub fn head_short(&self) -> Result<String> {
190        let out = self.git(&["rev-parse", "--short", "HEAD"])?;
191        Ok(out.trim().to_string())
192    }
193
194    pub fn file_statuses(&self) -> Result<HashMap<String, char>> {
195        let out = self.git(&["status", "--porcelain"])?;
196        let mut map = HashMap::new();
197        for line in out.lines() {
198            if line.len() < 3 {
199                continue;
200            }
201            let xy = &line.as_bytes()[..2];
202            let mut path = line[3..].to_string();
203            if let Some(pos) = path.find(" -> ") {
204                path = path[pos + 4..].to_string();
205            }
206            let (x, y) = (xy[0], xy[1]);
207            let status = match (x, y) {
208                (b'?', b'?') => 'A',
209                (b'A', _) | (_, b'A') => 'A',
210                (b'D', _) | (_, b'D') => 'D',
211                (b'R', _) | (_, b'R') => 'R',
212                (b'M', _) | (_, b'M') | (b'T', _) | (_, b'T') => 'M',
213                _ => '~',
214            };
215            map.insert(path, status);
216        }
217        Ok(map)
218    }
219
220    /// Create a snapshot of the working tree state into the platform data directory.
221    /// Location: `<data_local_dir>/sr/snapshots/<repo-hash>/`
222    ///   - macOS:   ~/Library/Application Support/sr/snapshots/<hash>/
223    ///   - Linux:   ~/.local/share/sr/snapshots/<hash>/
224    ///   - Windows: %LOCALAPPDATA%/sr/snapshots/<hash>/
225    ///
226    /// The snapshot includes:
227    /// - A `stash` ref created via `git stash create` (staged + unstaged changes)
228    /// - An `untracked.tar` of untracked files
229    /// - The list of staged files in `staged.txt`
230    /// - The repo root path in `repo_root` (so restore targets the right directory)
231    ///
232    /// Lives completely outside the repo so the agent cannot touch it.
233    pub fn snapshot_working_tree(&self) -> Result<PathBuf> {
234        let snapshot_dir = snapshot_dir_for(&self.root)
235            .context("failed to resolve snapshot directory (no data directory available)")?;
236        std::fs::create_dir_all(&snapshot_dir).context("failed to create snapshot directory")?;
237
238        // Record which repo this snapshot belongs to
239        std::fs::write(
240            snapshot_dir.join("repo_root"),
241            self.root.to_string_lossy().as_bytes(),
242        )
243        .context("failed to write repo_root")?;
244
245        // Capture staged file list
246        let staged = self.git(&["diff", "--cached", "--name-only"])?;
247        std::fs::write(snapshot_dir.join("staged.txt"), staged.trim())
248            .context("failed to write staged.txt")?;
249
250        // Create a stash object without modifying working tree or index.
251        // `git stash create` writes a stash commit but doesn't reset anything.
252        let (ok, stash_ref) = self.git_allow_failure(&["stash", "create"])?;
253        let stash_ref = stash_ref.trim().to_string();
254        if ok && !stash_ref.is_empty() {
255            std::fs::write(snapshot_dir.join("stash_ref"), &stash_ref)
256                .context("failed to write stash_ref")?;
257        } else {
258            let _ = std::fs::remove_file(snapshot_dir.join("stash_ref"));
259        }
260
261        // Archive untracked files
262        let untracked = self.git(&["ls-files", "--others", "--exclude-standard"])?;
263        let untracked_files: Vec<&str> = untracked
264            .lines()
265            .map(|l| l.trim())
266            .filter(|l| !l.is_empty())
267            .collect();
268        let tar_path = snapshot_dir.join("untracked.tar");
269        if !untracked_files.is_empty() {
270            let file_list = untracked_files.join("\n");
271            let file_list_path = snapshot_dir.join("untracked_list.txt");
272            std::fs::write(&file_list_path, &file_list)?;
273            let status = Command::new("tar")
274                .args([
275                    "cf",
276                    tar_path.to_str().unwrap(),
277                    "-T",
278                    file_list_path.to_str().unwrap(),
279                ])
280                .current_dir(&self.root)
281                .stdout(std::process::Stdio::null())
282                .stderr(std::process::Stdio::null())
283                .status()
284                .context("failed to run tar")?;
285            if !status.success() {
286                eprintln!("warning: failed to archive untracked files");
287            }
288        } else {
289            let _ = std::fs::remove_file(&tar_path);
290        }
291
292        // Mark snapshot as valid
293        let now = std::time::SystemTime::now()
294            .duration_since(std::time::UNIX_EPOCH)
295            .unwrap_or_default()
296            .as_secs();
297        std::fs::write(snapshot_dir.join("timestamp"), now.to_string())
298            .context("failed to write timestamp")?;
299
300        Ok(snapshot_dir)
301    }
302
303    /// Restore working tree from the latest snapshot (best-effort).
304    pub fn restore_snapshot(&self) -> Result<()> {
305        let snapshot_dir = self.snapshot_dir()?;
306        if !snapshot_dir.join("timestamp").exists() {
307            bail!("no valid snapshot found");
308        }
309
310        // Restore tracked changes from stash ref
311        let stash_ref_path = snapshot_dir.join("stash_ref");
312        if stash_ref_path.exists() {
313            let stash_ref = std::fs::read_to_string(&stash_ref_path)?;
314            let stash_ref = stash_ref.trim();
315            if !stash_ref.is_empty() {
316                let (ok, _) = self.git_allow_failure(&["stash", "apply", stash_ref])?;
317                if !ok {
318                    eprintln!("warning: failed to apply stash {stash_ref}, trying checkout");
319                    let _ = self.git_allow_failure(&["checkout", "."]);
320                    let _ = self.git_allow_failure(&["stash", "apply", stash_ref]);
321                }
322            }
323        }
324
325        // Restore staged files
326        let staged_path = snapshot_dir.join("staged.txt");
327        if staged_path.exists() {
328            let staged = std::fs::read_to_string(&staged_path)?;
329            for file in staged.lines().filter(|l| !l.trim().is_empty()) {
330                let full = self.root.join(file);
331                if full.exists() {
332                    let _ = self.git_allow_failure(&["add", "--", file]);
333                }
334            }
335        }
336
337        // Restore untracked files
338        let tar_path = snapshot_dir.join("untracked.tar");
339        if tar_path.exists() {
340            let _ = Command::new("tar")
341                .args(["xf", tar_path.to_str().unwrap()])
342                .current_dir(&self.root)
343                .stdout(std::process::Stdio::null())
344                .stderr(std::process::Stdio::null())
345                .status();
346        }
347
348        Ok(())
349    }
350
351    /// Remove the snapshot after a successful operation.
352    pub fn clear_snapshot(&self) {
353        if let Ok(dir) = self.snapshot_dir() {
354            let _ = std::fs::remove_dir_all(&dir);
355        }
356    }
357
358    /// Returns the snapshot directory path for this repo.
359    pub fn snapshot_dir(&self) -> Result<PathBuf> {
360        snapshot_dir_for(&self.root)
361            .context("failed to resolve snapshot directory (no data directory available)")
362    }
363
364    /// Check if a valid snapshot exists.
365    pub fn has_snapshot(&self) -> bool {
366        self.snapshot_dir()
367            .map(|d| d.join("timestamp").exists())
368            .unwrap_or(false)
369    }
370}
371
372/// Resolve the snapshot directory for a repo root.
373/// `<data_local_dir>/sr/snapshots/<repo-hash>/`
374fn snapshot_dir_for(repo_root: &std::path::Path) -> Option<PathBuf> {
375    let base = dirs::data_local_dir()?;
376    let repo_id =
377        &crate::cache::fingerprint::sha256_hex(repo_root.to_string_lossy().as_bytes())[..16];
378    Some(base.join("sr").join("snapshots").join(repo_id))
379}
380
381/// Guard that ensures the snapshot is cleaned up on success
382/// and restored on failure (drop without explicit success).
383pub struct SnapshotGuard<'a> {
384    repo: &'a GitRepo,
385    succeeded: bool,
386}
387
388impl<'a> SnapshotGuard<'a> {
389    /// Create a snapshot and return the guard.
390    pub fn new(repo: &'a GitRepo) -> Result<Self> {
391        repo.snapshot_working_tree()?;
392        Ok(Self {
393            repo,
394            succeeded: false,
395        })
396    }
397
398    /// Mark the operation as successful — snapshot will be cleared on drop.
399    pub fn success(mut self) {
400        self.succeeded = true;
401        self.repo.clear_snapshot();
402    }
403}
404
405impl Drop for SnapshotGuard<'_> {
406    fn drop(&mut self) {
407        if !self.succeeded && self.repo.has_snapshot() {
408            eprintln!("sr: operation failed, restoring working tree from snapshot...");
409            if let Err(e) = self.repo.restore_snapshot() {
410                eprintln!("sr: warning: snapshot restore failed: {e}");
411                if let Ok(dir) = self.repo.snapshot_dir() {
412                    eprintln!(
413                        "sr: snapshot preserved at {} for manual recovery",
414                        dir.display()
415                    );
416                }
417            } else {
418                self.repo.clear_snapshot();
419            }
420        }
421    }
422}