1use anyhow::{Context, Result, bail};
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::process::Command;
5
6pub struct GitRepo {
7 root: PathBuf,
8}
9
10#[allow(dead_code)]
11impl GitRepo {
12 pub fn discover() -> Result<Self> {
13 let output = Command::new("git")
14 .args(["rev-parse", "--show-toplevel"])
15 .output()
16 .context("failed to run git")?;
17
18 if !output.status.success() {
19 bail!(crate::error::SrAiError::NotAGitRepo);
20 }
21
22 let root = String::from_utf8(output.stdout)
23 .context("invalid utf-8 from git")?
24 .trim()
25 .into();
26
27 Ok(Self { root })
28 }
29
30 pub fn root(&self) -> &PathBuf {
31 &self.root
32 }
33
34 fn git(&self, args: &[&str]) -> Result<String> {
35 let output = Command::new("git")
36 .args(["-C", self.root.to_str().unwrap()])
37 .args(args)
38 .output()
39 .with_context(|| format!("failed to run git {}", args.join(" ")))?;
40
41 if !output.status.success() {
42 let stderr = String::from_utf8_lossy(&output.stderr);
43 bail!(crate::error::SrAiError::GitCommand(format!(
44 "git {} failed: {}",
45 args.join(" "),
46 stderr.trim()
47 )));
48 }
49
50 Ok(String::from_utf8_lossy(&output.stdout).to_string())
51 }
52
53 fn git_allow_failure(&self, args: &[&str]) -> Result<(bool, String)> {
54 let output = Command::new("git")
55 .args(["-C", self.root.to_str().unwrap()])
56 .args(args)
57 .output()
58 .with_context(|| format!("failed to run git {}", args.join(" ")))?;
59
60 Ok((
61 output.status.success(),
62 String::from_utf8_lossy(&output.stdout).to_string(),
63 ))
64 }
65
66 pub fn has_staged_changes(&self) -> Result<bool> {
67 let out = self.git(&["diff", "--cached", "--name-only"])?;
68 Ok(!out.trim().is_empty())
69 }
70
71 pub fn has_any_changes(&self) -> Result<bool> {
72 let out = self.git(&["status", "--porcelain"])?;
73 Ok(!out.trim().is_empty())
74 }
75
76 pub fn has_head(&self) -> Result<bool> {
77 let (ok, _) = self.git_allow_failure(&["rev-parse", "HEAD"])?;
78 Ok(ok)
79 }
80
81 pub fn reset_head(&self) -> Result<()> {
82 if self.has_head()? {
83 self.git(&["reset", "HEAD", "--quiet"])?;
84 } else {
85 let _ = self.git_allow_failure(&["rm", "--cached", "-r", ".", "--quiet"]);
87 }
88 Ok(())
89 }
90
91 pub fn stage_file(&self, file: &str) -> Result<bool> {
92 let full_path = self.root.join(file);
93 let exists = full_path.exists();
94
95 if !exists {
96 let out = self.git(&["ls-files", "--deleted"])?;
98 let is_deleted = out.lines().any(|l| l.trim() == file);
99 if !is_deleted {
100 return Ok(false);
101 }
102 }
103
104 let (ok, _) = self.git_allow_failure(&["add", "--", file])?;
105 Ok(ok)
106 }
107
108 pub fn has_staged_after_add(&self) -> Result<bool> {
109 self.has_staged_changes()
110 }
111
112 pub fn commit(&self, message: &str) -> Result<()> {
113 let output = Command::new("git")
114 .args(["-C", self.root.to_str().unwrap()])
115 .args(["commit", "-F", "-"])
116 .stdin(std::process::Stdio::piped())
117 .stdout(std::process::Stdio::piped())
118 .stderr(std::process::Stdio::piped())
119 .spawn()
120 .context("failed to spawn git commit")?;
121
122 use std::io::Write;
123 let mut child = output;
124 if let Some(mut stdin) = child.stdin.take() {
125 stdin.write_all(message.as_bytes())?;
126 }
127
128 let out = child.wait_with_output()?;
129 if !out.status.success() {
130 let stderr = String::from_utf8_lossy(&out.stderr);
131 bail!(crate::error::SrAiError::GitCommand(format!(
132 "git commit failed: {}",
133 stderr.trim()
134 )));
135 }
136
137 Ok(())
138 }
139
140 pub fn recent_commits(&self, count: usize) -> Result<String> {
141 self.git(&["--no-pager", "log", "--oneline", &format!("-{count}")])
142 }
143
144 pub fn diff_cached(&self) -> Result<String> {
145 self.git(&["diff", "--cached"])
146 }
147
148 pub fn diff_cached_stat(&self) -> Result<String> {
149 self.git(&["diff", "--cached", "--stat"])
150 }
151
152 pub fn diff_head(&self) -> Result<String> {
153 let (ok, out) = self.git_allow_failure(&["diff", "HEAD"])?;
154 if ok { Ok(out) } else { self.git(&["diff"]) }
155 }
156
157 pub fn status_porcelain(&self) -> Result<String> {
158 self.git(&["status", "--porcelain"])
159 }
160
161 pub fn untracked_files(&self) -> Result<String> {
162 self.git(&["ls-files", "--others", "--exclude-standard"])
163 }
164
165 pub fn show(&self, rev: &str) -> Result<String> {
166 self.git(&["show", rev])
167 }
168
169 pub fn log_range(&self, base: &str, count: Option<usize>) -> Result<String> {
170 let mut args = vec!["--no-pager", "log", "--oneline"];
171 let count_str;
172 if let Some(n) = count {
173 count_str = format!("-{n}");
174 args.push(&count_str);
175 }
176 args.push(base);
177 self.git(&args)
178 }
179
180 pub fn diff_range(&self, base: &str) -> Result<String> {
181 self.git(&["diff", base])
182 }
183
184 pub fn current_branch(&self) -> Result<String> {
185 let out = self.git(&["rev-parse", "--abbrev-ref", "HEAD"])?;
186 Ok(out.trim().to_string())
187 }
188
189 pub fn head_short(&self) -> Result<String> {
190 let out = self.git(&["rev-parse", "--short", "HEAD"])?;
191 Ok(out.trim().to_string())
192 }
193
194 pub fn commits_since_last_tag(&self) -> Result<usize> {
196 let (ok, tag) = self.git_allow_failure(&["describe", "--tags", "--abbrev=0"])?;
198 let tag = tag.trim();
199
200 let out = if ok && !tag.is_empty() {
201 self.git(&["rev-list", &format!("{tag}..HEAD"), "--count"])?
202 } else {
203 self.git(&["rev-list", "HEAD", "--count"])?
204 };
205
206 out.trim()
207 .parse::<usize>()
208 .context("failed to parse commit count")
209 }
210
211 pub fn log_detailed(&self, count: usize) -> Result<String> {
213 let out = self.git(&[
214 "--no-pager",
215 "log",
216 "--reverse",
217 &format!("-{count}"),
218 "--format=%h %s%n%b%n---",
219 ])?;
220 Ok(out)
221 }
222
223 pub fn file_statuses(&self) -> Result<HashMap<String, char>> {
224 let out = self.git(&["status", "--porcelain"])?;
225 let mut map = HashMap::new();
226 for line in out.lines() {
227 if line.len() < 3 {
228 continue;
229 }
230 let xy = &line.as_bytes()[..2];
231 let mut path = line[3..].to_string();
232 if let Some(pos) = path.find(" -> ") {
233 path = path[pos + 4..].to_string();
234 }
235 let (x, y) = (xy[0], xy[1]);
236 let status = match (x, y) {
237 (b'?', b'?') => 'A',
238 (b'A', _) | (_, b'A') => 'A',
239 (b'D', _) | (_, b'D') => 'D',
240 (b'R', _) | (_, b'R') => 'R',
241 (b'M', _) | (_, b'M') | (b'T', _) | (_, b'T') => 'M',
242 _ => '~',
243 };
244 map.insert(path, status);
245 }
246 Ok(map)
247 }
248
249 pub fn snapshot_working_tree(&self) -> Result<PathBuf> {
262 let snapshot_dir = snapshot_dir_for(&self.root)
263 .context("failed to resolve snapshot directory (no data directory available)")?;
264 if snapshot_dir.exists() {
266 std::fs::remove_dir_all(&snapshot_dir).ok();
267 }
268 std::fs::create_dir_all(&snapshot_dir).context("failed to create snapshot directory")?;
269
270 let files_dir = snapshot_dir.join("files");
271 std::fs::create_dir_all(&files_dir)?;
272
273 std::fs::write(
275 snapshot_dir.join("repo_root"),
276 self.root.to_string_lossy().as_bytes(),
277 )
278 .context("failed to write repo_root")?;
279
280 let (has_head, head_ref) = self.git_allow_failure(&["rev-parse", "HEAD"])?;
282 if has_head {
283 std::fs::write(snapshot_dir.join("head_ref"), head_ref.trim())
284 .context("failed to write head_ref")?;
285 }
286
287 let porcelain = self.git(&["status", "--porcelain"])?;
290 let staged_names = self.git(&["diff", "--cached", "--name-only"])?;
291 let staged_set: std::collections::HashSet<&str> = staged_names
292 .lines()
293 .map(|l| l.trim())
294 .filter(|l| !l.is_empty())
295 .collect();
296
297 #[derive(serde::Serialize, serde::Deserialize)]
298 struct ManifestEntry {
299 path: String,
300 index_status: char,
302 worktree_status: char,
304 staged: bool,
306 has_content: bool,
308 }
309
310 let mut manifest: Vec<ManifestEntry> = Vec::new();
311
312 for line in porcelain.lines() {
313 if line.len() < 3 {
314 continue;
315 }
316 let bytes = line.as_bytes();
317 let x = bytes[0] as char;
318 let y = bytes[1] as char;
319 let mut path = line[3..].to_string();
320 if let Some(pos) = path.find(" -> ") {
322 path = path[pos + 4..].to_string();
323 }
324
325 let src = self.root.join(&path);
326 let has_content = src.exists() && src.is_file();
327
328 if has_content {
329 let dest = files_dir.join(&path);
330 if let Some(parent) = dest.parent() {
331 std::fs::create_dir_all(parent).ok();
332 }
333 if let Err(e) = std::fs::copy(&src, &dest) {
334 eprintln!("warning: failed to snapshot {path}: {e}");
335 }
336 }
337
338 manifest.push(ManifestEntry {
339 staged: staged_set.contains(path.as_str()),
340 path,
341 index_status: x,
342 worktree_status: y,
343 has_content,
344 });
345 }
346
347 let manifest_json =
348 serde_json::to_string_pretty(&manifest).context("failed to serialize manifest")?;
349 std::fs::write(snapshot_dir.join("manifest.json"), manifest_json)
350 .context("failed to write manifest.json")?;
351
352 let now = std::time::SystemTime::now()
354 .duration_since(std::time::UNIX_EPOCH)
355 .unwrap_or_default()
356 .as_secs();
357 std::fs::write(snapshot_dir.join("timestamp"), now.to_string())
358 .context("failed to write timestamp")?;
359
360 Ok(snapshot_dir)
361 }
362
363 pub fn restore_snapshot(&self) -> Result<()> {
373 let snapshot_dir = self.snapshot_dir()?;
374 if !snapshot_dir.join("timestamp").exists() {
375 bail!("no valid snapshot found");
376 }
377
378 let files_dir = snapshot_dir.join("files");
379
380 let head_ref_path = snapshot_dir.join("head_ref");
382 if head_ref_path.exists() {
383 let original_head = std::fs::read_to_string(&head_ref_path)?;
384 let original_head = original_head.trim();
385 if !original_head.is_empty() {
386 let _ = self.git_allow_failure(&["reset", "--soft", original_head]);
387 }
388 }
389
390 self.reset_head()?;
392
393 let manifest_path = snapshot_dir.join("manifest.json");
395 if !manifest_path.exists() {
396 bail!("snapshot manifest.json missing — cannot restore");
397 }
398
399 #[derive(serde::Deserialize)]
400 struct ManifestEntry {
401 path: String,
402 index_status: char,
403 worktree_status: char,
404 staged: bool,
405 has_content: bool,
406 }
407
408 let manifest_data = std::fs::read_to_string(&manifest_path)?;
409 let manifest: Vec<ManifestEntry> =
410 serde_json::from_str(&manifest_data).context("failed to parse snapshot manifest")?;
411
412 let mut restored = 0usize;
413 let mut failed = 0usize;
414
415 for entry in &manifest {
416 let dest = self.root.join(&entry.path);
417
418 if entry.has_content {
419 let src = files_dir.join(&entry.path);
421 if src.exists() {
422 if let Some(parent) = dest.parent() {
423 std::fs::create_dir_all(parent).ok();
424 }
425 match std::fs::copy(&src, &dest) {
426 Ok(_) => restored += 1,
427 Err(e) => {
428 eprintln!("warning: failed to restore {}: {e}", entry.path);
429 failed += 1;
430 }
431 }
432 } else {
433 eprintln!("warning: snapshot missing content for {}", entry.path);
434 failed += 1;
435 }
436 } else if entry.index_status == 'D' || entry.worktree_status == 'D' {
437 if dest.exists() {
439 std::fs::remove_file(&dest).ok();
440 }
441 }
442
443 if entry.staged {
445 let _ = self.git_allow_failure(&["add", "--", &entry.path]);
446 }
447 }
448
449 if failed > 0 {
450 eprintln!("sr: restored {restored} files, {failed} failed");
451 }
452
453 Ok(())
454 }
455
456 pub fn clear_snapshot(&self) {
458 if let Ok(dir) = self.snapshot_dir() {
459 let _ = std::fs::remove_dir_all(&dir);
460 }
461 }
462
463 pub fn snapshot_dir(&self) -> Result<PathBuf> {
465 snapshot_dir_for(&self.root)
466 .context("failed to resolve snapshot directory (no data directory available)")
467 }
468
469 pub fn has_snapshot(&self) -> bool {
471 self.snapshot_dir()
472 .map(|d| d.join("timestamp").exists())
473 .unwrap_or(false)
474 }
475}
476
477fn snapshot_dir_for(repo_root: &std::path::Path) -> Option<PathBuf> {
480 let base = dirs::data_local_dir()?;
481 let repo_id =
482 &crate::cache::fingerprint::sha256_hex(repo_root.to_string_lossy().as_bytes())[..16];
483 Some(base.join("sr").join("snapshots").join(repo_id))
484}
485
486pub struct SnapshotGuard<'a> {
489 repo: &'a GitRepo,
490 succeeded: bool,
491}
492
493impl<'a> SnapshotGuard<'a> {
494 pub fn new(repo: &'a GitRepo) -> Result<Self> {
496 repo.snapshot_working_tree()?;
497 Ok(Self {
498 repo,
499 succeeded: false,
500 })
501 }
502
503 pub fn success(mut self) {
505 self.succeeded = true;
506 self.repo.clear_snapshot();
507 }
508}
509
510impl Drop for SnapshotGuard<'_> {
511 fn drop(&mut self) {
512 if !self.succeeded && self.repo.has_snapshot() {
513 eprintln!("sr: operation failed, restoring working tree from snapshot...");
514 if let Err(e) = self.repo.restore_snapshot() {
515 eprintln!("sr: warning: snapshot restore failed: {e}");
516 if let Ok(dir) = self.repo.snapshot_dir() {
517 eprintln!(
518 "sr: snapshot preserved at {} for manual recovery",
519 dir.display()
520 );
521 }
522 } else {
523 self.repo.clear_snapshot();
524 }
525 }
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a throwaway repository whose only commit contains `init.txt`.
    ///
    /// Every setup command must succeed. Otherwise a silently failed
    /// `git commit` (e.g. from a global `commit.gpgsign = true`) would leave the
    /// repo without a HEAD and show up later as an unrelated-looking failure.
    fn temp_repo() -> (tempfile::TempDir, GitRepo) {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_path_buf();

        let git = |args: &[&str]| {
            let out = Command::new("git")
                .arg("-C")
                .arg(&root)
                .args(args)
                .output()
                .expect("failed to spawn git");
            assert!(
                out.status.success(),
                "test setup `git {}` failed: {}",
                args.join(" "),
                String::from_utf8_lossy(&out.stderr).trim()
            );
        };

        git(&["init"]);
        git(&["config", "user.email", "test@test.com"]);
        git(&["config", "user.name", "Test"]);
        // Repo-local override so a user's global signing setup cannot break test commits.
        git(&["config", "commit.gpgsign", "false"]);
        fs::write(root.join("init.txt"), "init").unwrap();
        git(&["add", "init.txt"]);
        git(&["commit", "-m", "initial"]);

        let repo = GitRepo { root };
        (dir, repo)
    }

    #[test]
    fn snapshot_creates_manifest_with_staged_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("new.go"), "package main").unwrap();
        repo.git(&["add", "new.go"]).unwrap();

        let snap_dir = repo.snapshot_working_tree().unwrap();

        let manifest_path = snap_dir.join("manifest.json");
        assert!(manifest_path.exists(), "manifest.json should exist");

        let data = fs::read_to_string(&manifest_path).unwrap();
        assert!(data.contains("new.go"), "manifest should list new.go");
        assert!(
            data.contains("\"staged\": true"),
            "new.go should be marked staged"
        );

        assert!(
            snap_dir.join("files/new.go").exists(),
            "file content should be copied"
        );
        assert_eq!(
            fs::read_to_string(snap_dir.join("files/new.go")).unwrap(),
            "package main"
        );

        assert!(snap_dir.join("head_ref").exists());

        repo.clear_snapshot();
    }

    #[test]
    fn snapshot_restore_recovers_staged_new_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("a.go"), "package a").unwrap();
        fs::write(repo.root.join("b.go"), "package b").unwrap();
        repo.git(&["add", "a.go", "b.go"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        // Simulate a failed operation that committed only part of the changes.
        repo.reset_head().unwrap();
        repo.git(&["add", "a.go"]).unwrap();
        repo.git(&["commit", "-m", "partial"]).unwrap();

        repo.restore_snapshot().unwrap();

        assert!(repo.root.join("a.go").exists());
        assert!(repo.root.join("b.go").exists());
        assert_eq!(
            fs::read_to_string(repo.root.join("a.go")).unwrap(),
            "package a"
        );
        assert_eq!(
            fs::read_to_string(repo.root.join("b.go")).unwrap(),
            "package b"
        );

        let staged = repo.git(&["diff", "--cached", "--name-only"]).unwrap();
        assert!(staged.contains("a.go"), "a.go should be re-staged");
        assert!(staged.contains("b.go"), "b.go should be re-staged");

        let log = repo.git(&["log", "--oneline"]).unwrap();
        assert!(
            !log.contains("partial"),
            "partial commit should be undone by HEAD reset"
        );

        repo.clear_snapshot();
    }

    #[test]
    fn snapshot_restore_with_dirty_index_does_not_conflict() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("file.rs"), "fn main() {}").unwrap();
        repo.git(&["add", "file.rs"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        // Leave the index dirty before restoring.
        repo.reset_head().unwrap();
        repo.git(&["add", "file.rs"]).unwrap();
        let result = repo.restore_snapshot();
        assert!(
            result.is_ok(),
            "restore should succeed with dirty index: {result:?}"
        );

        assert_eq!(
            fs::read_to_string(repo.root.join("file.rs")).unwrap(),
            "fn main() {}"
        );

        repo.clear_snapshot();
    }

    #[test]
    fn snapshot_handles_modified_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("init.txt"), "modified content").unwrap();
        repo.git(&["add", "init.txt"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        repo.reset_head().unwrap();
        fs::write(repo.root.join("init.txt"), "wrong content").unwrap();

        repo.restore_snapshot().unwrap();

        assert_eq!(
            fs::read_to_string(repo.root.join("init.txt")).unwrap(),
            "modified content"
        );

        repo.clear_snapshot();
    }

    #[test]
    fn snapshot_guard_restores_on_drop() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("guarded.txt"), "important").unwrap();
        repo.git(&["add", "guarded.txt"]).unwrap();

        {
            let _guard = SnapshotGuard::new(&repo).unwrap();
            repo.reset_head().unwrap();
            fs::remove_file(repo.root.join("guarded.txt")).ok();
            // Guard dropped here without success() -> restore.
        }

        assert!(repo.root.join("guarded.txt").exists());
        assert_eq!(
            fs::read_to_string(repo.root.join("guarded.txt")).unwrap(),
            "important"
        );
    }

    #[test]
    fn snapshot_guard_clears_on_success() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("ok.txt"), "data").unwrap();
        repo.git(&["add", "ok.txt"]).unwrap();

        let guard = SnapshotGuard::new(&repo).unwrap();
        assert!(repo.has_snapshot());
        guard.success();

        assert!(!repo.has_snapshot());
    }
}