1use anyhow::{Context, Result, bail};
2use semver::Version;
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::process::Command;
7
8use crate::commit::Commit;
9use crate::error::ReleaseError;
10
11fn sha256_hex(data: &[u8]) -> String {
12 let mut hasher = Sha256::new();
13 hasher.update(data);
14 format!("{:x}", hasher.finalize())
15}
16
#[derive(Debug, Clone)]
/// A git tag paired with the semantic version it represents.
pub struct TagInfo {
    /// Tag name as known to git (NOTE(review): presumably including any
    /// configured prefix such as `v` — confirm against implementors).
    pub name: String,
    /// Semantic version associated with the tag.
    pub version: Version,
    /// Object id associated with the tag (NOTE(review): presumably the SHA of
    /// the tagged commit — confirm).
    pub sha: String,
}
24
/// Git operations needed by the release tooling, expressed as a trait so the
/// backend can be swapped (NOTE(review): presumably to allow test doubles —
/// confirm).
///
/// Every method reports failure as [`ReleaseError`]. The `Send + Sync` bound
/// lets implementations be shared across threads.
pub trait GitRepository: Send + Sync {
    /// Latest tag for `prefix`, or `None` when there is none.
    ///
    /// NOTE(review): how "latest" is chosen (version order vs. date) and how
    /// `prefix` is matched are defined by the implementation — confirm.
    fn latest_tag(&self, prefix: &str) -> Result<Option<TagInfo>, ReleaseError>;

    /// Commits made after `from` (presumably a tag or ref name); `None`
    /// presumably means the whole history — confirm with implementors.
    fn commits_since(&self, from: Option<&str>) -> Result<Vec<Commit>, ReleaseError>;

    /// Create tag `name` with `message`; `sign` presumably requests a signed
    /// tag.
    fn create_tag(&self, name: &str, message: &str, sign: bool) -> Result<(), ReleaseError>;

    /// Push tag `name` to the remote (which remote is implementation-defined).
    fn push_tag(&self, name: &str) -> Result<(), ReleaseError>;

    /// Stage `paths` and commit them with `message`.
    ///
    /// NOTE(review): the `bool` presumably reports whether a commit was
    /// actually created (e.g. `false` when nothing changed) — confirm.
    fn stage_and_commit(&self, paths: &[&str], message: &str) -> Result<bool, ReleaseError>;

    /// Whether the working tree has uncommitted changes (exact definition is
    /// up to the implementation).
    fn is_dirty(&self) -> Result<bool, ReleaseError>;

    /// Push the current branch (presumably to its upstream).
    fn push(&self) -> Result<(), ReleaseError>;

    /// Whether tag `name` exists in the local repository.
    fn tag_exists(&self, name: &str) -> Result<bool, ReleaseError>;

    /// Whether tag `name` exists on the remote.
    fn remote_tag_exists(&self, name: &str) -> Result<bool, ReleaseError>;

    /// All tags for `prefix` (ordering is implementation-defined).
    fn all_tags(&self, prefix: &str) -> Result<Vec<TagInfo>, ReleaseError>;

    /// Commits between `from` and `to`; `None` presumably means "from the
    /// beginning of history" — confirm endpoint inclusivity with implementors.
    fn commits_between(&self, from: Option<&str>, to: &str) -> Result<Vec<Commit>, ReleaseError>;

    /// Date of tag `tag_name` as a string (the format is not fixed by the
    /// trait).
    fn tag_date(&self, tag_name: &str) -> Result<String, ReleaseError>;

    /// Create tag `name`, replacing an existing tag of the same name.
    fn force_create_tag(&self, name: &str) -> Result<(), ReleaseError>;

    /// Push tag `name`, overwriting the remote's copy.
    fn force_push_tag(&self, name: &str) -> Result<(), ReleaseError>;

    /// Object id of the current `HEAD`.
    fn head_sha(&self) -> Result<String, ReleaseError>;

    /// Like [`GitRepository::commits_since`], restricted to commits touching
    /// `path`.
    ///
    /// The default implementation ignores `path` and returns the unfiltered
    /// `commits_since(from)` result; implementations able to filter by path
    /// should override it.
    fn commits_since_in_path(
        &self,
        from: Option<&str>,
        path: &str,
    ) -> Result<Vec<Commit>, ReleaseError> {
        // Unused by the fallback; bound only to silence the warning.
        let _ = path;
        self.commits_since(from)
    }

    /// Like [`GitRepository::commits_between`], restricted to commits touching
    /// `path`.
    ///
    /// The default implementation ignores `path` and returns the unfiltered
    /// `commits_between(from, to)` result.
    fn commits_between_in_path(
        &self,
        from: Option<&str>,
        to: &str,
        path: &str,
    ) -> Result<Vec<Commit>, ReleaseError> {
        let _ = path;
        self.commits_between(from, to)
    }
}
97
/// Decode a path as printed by git in C-quoted form (see `core.quotePath`).
///
/// The input is trimmed first. If it is not wrapped in double quotes it is
/// returned as-is. Otherwise the surrounding quotes are removed and git's
/// escapes are decoded: `\\`, `\"`, `\n`, `\t`, `\r`, `\a`, `\b`, `\f`, `\v`,
/// and octal byte escapes of up to three digits such as `\303\251`. Unknown
/// escapes are kept verbatim (backslash included). The decoded bytes are read
/// as UTF-8, with invalid sequences replaced by U+FFFD.
fn git_unquote(s: &str) -> String {
    let s = s.trim();
    // A lone `"` both starts and ends with a quote but is not a quoted string;
    // `strip_prefix`/`strip_suffix` reject it instead of slicing `[1..0]`.
    let inner = match s.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner,
        None => return s.to_string(),
    };

    let bytes = inner.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let byte = bytes[i];
        i += 1;
        // Ordinary byte, or a trailing lone backslash with nothing to escape.
        if byte != b'\\' || i == bytes.len() {
            out.push(byte);
            continue;
        }
        let escaped = bytes[i];
        i += 1;
        match escaped {
            b'\\' => out.push(b'\\'),
            b'"' => out.push(b'"'),
            b'n' => out.push(b'\n'),
            b't' => out.push(b'\t'),
            b'r' => out.push(b'\r'),
            b'a' => out.push(0x07),
            b'b' => out.push(0x08),
            b'f' => out.push(0x0C),
            b'v' => out.push(0x0B),
            b'0'..=b'3' => {
                // Up to two more *octal* digits (`8`/`9` end the escape). With
                // a leading digit of at most 3 the value is at most
                // 0o377 = 255, so this u8 arithmetic cannot overflow.
                let mut value = escaped - b'0';
                for _ in 0..2 {
                    match bytes.get(i) {
                        Some(&digit @ b'0'..=b'7') => {
                            value = value * 8 + (digit - b'0');
                            i += 1;
                        }
                        _ => break,
                    }
                }
                out.push(value);
            }
            other => {
                out.push(b'\\');
                out.push(other);
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}
150
/// Handle to a git work tree, identified by its top-level directory.
///
/// The git commands issued by this type run with `git -C <root>`, so they do
/// not depend on the process's current directory.
pub struct GitRepo {
    /// Top-level directory of the work tree (as printed by
    /// `git rev-parse --show-toplevel` when built via `discover`).
    root: PathBuf,
}
154
155#[allow(dead_code)]
156impl GitRepo {
157 pub fn discover() -> Result<Self> {
158 let output = Command::new("git")
159 .args(["rev-parse", "--show-toplevel"])
160 .output()
161 .context("failed to run git")?;
162
163 if !output.status.success() {
164 bail!("not in a git repository");
165 }
166
167 let root = String::from_utf8(output.stdout)
168 .context("invalid utf-8 from git")?
169 .trim()
170 .into();
171
172 Ok(Self { root })
173 }
174
175 pub fn root(&self) -> &PathBuf {
176 &self.root
177 }
178
179 fn git(&self, args: &[&str]) -> Result<String> {
180 let output = Command::new("git")
181 .args(["-C", self.root.to_str().unwrap()])
182 .args(args)
183 .output()
184 .with_context(|| format!("failed to run git {}", args.join(" ")))?;
185
186 if !output.status.success() {
187 let stderr = String::from_utf8_lossy(&output.stderr);
188 bail!("git {} failed: {}", args.join(" "), stderr.trim());
189 }
190
191 Ok(String::from_utf8_lossy(&output.stdout).to_string())
192 }
193
194 fn git_allow_failure(&self, args: &[&str]) -> Result<(bool, String)> {
195 let output = Command::new("git")
196 .args(["-C", self.root.to_str().unwrap()])
197 .args(args)
198 .output()
199 .with_context(|| format!("failed to run git {}", args.join(" ")))?;
200
201 Ok((
202 output.status.success(),
203 String::from_utf8_lossy(&output.stdout).to_string(),
204 ))
205 }
206
207 pub fn has_staged_changes(&self) -> Result<bool> {
208 let out = self.git(&["diff", "--cached", "--name-only"])?;
209 Ok(!out.trim().is_empty())
210 }
211
212 pub fn has_any_changes(&self) -> Result<bool> {
213 let out = self.git(&["status", "--porcelain"])?;
214 Ok(!out.trim().is_empty())
215 }
216
217 pub fn has_head(&self) -> Result<bool> {
218 let (ok, _) = self.git_allow_failure(&["rev-parse", "HEAD"])?;
219 Ok(ok)
220 }
221
222 pub fn reset_head(&self) -> Result<()> {
223 if self.has_head()? {
224 self.git(&["reset", "HEAD", "--quiet"])?;
225 } else {
226 let _ = self.git_allow_failure(&["rm", "--cached", "-r", ".", "--quiet"]);
228 }
229 Ok(())
230 }
231
232 pub fn stage_file(&self, file: &str) -> Result<bool> {
233 let (ok, _) = self.git_allow_failure(&["add", "--", file])?;
241 Ok(ok)
242 }
243
244 pub fn has_staged_after_add(&self) -> Result<bool> {
245 self.has_staged_changes()
246 }
247
248 pub fn commit(&self, message: &str) -> Result<()> {
249 let output = Command::new("git")
250 .args(["-C", self.root.to_str().unwrap()])
251 .args(["commit", "-F", "-"])
252 .stdin(std::process::Stdio::piped())
253 .stdout(std::process::Stdio::piped())
254 .stderr(std::process::Stdio::piped())
255 .spawn()
256 .context("failed to spawn git commit")?;
257
258 use std::io::Write;
259 let mut child = output;
260 if let Some(mut stdin) = child.stdin.take() {
261 stdin.write_all(message.as_bytes())?;
262 }
263
264 let out = child.wait_with_output()?;
265 if !out.status.success() {
266 let stderr = String::from_utf8_lossy(&out.stderr);
267 bail!("git commit failed: {}", stderr.trim());
268 }
269
270 Ok(())
271 }
272
273 pub fn recent_commits(&self, count: usize) -> Result<String> {
274 self.git(&["--no-pager", "log", "--oneline", &format!("-{count}")])
275 }
276
277 pub fn diff_cached(&self) -> Result<String> {
278 self.git(&["diff", "--cached"])
279 }
280
281 pub fn diff_cached_stat(&self) -> Result<String> {
282 self.git(&["diff", "--cached", "--stat"])
283 }
284
285 pub fn diff_head(&self) -> Result<String> {
286 let (ok, out) = self.git_allow_failure(&["diff", "HEAD"])?;
287 if ok { Ok(out) } else { self.git(&["diff"]) }
288 }
289
290 pub fn diff_unified(&self, staged: bool, context: usize, files: &[String]) -> Result<String> {
293 let ctx_flag = format!("-U{context}");
294 let mut args: Vec<&str> = vec!["diff", &ctx_flag];
295 if staged {
296 args.push("--cached");
297 } else {
298 args.push("HEAD");
299 }
300 if !files.is_empty() {
301 args.push("--");
302 for f in files {
303 args.push(f.as_str());
304 }
305 }
306 let (ok, out) = self.git_allow_failure(&args)?;
307 if ok {
308 Ok(out)
309 } else if !staged && files.is_empty() {
310 self.git(&["diff", &ctx_flag])
312 } else {
313 Ok(out)
314 }
315 }
316
317 pub fn diff_numstat(
319 &self,
320 staged: bool,
321 files: &[String],
322 ) -> Result<Vec<(usize, usize, String)>> {
323 let mut args: Vec<&str> = vec!["diff", "--numstat"];
324 if staged {
325 args.push("--cached");
326 } else {
327 args.push("HEAD");
328 }
329 if !files.is_empty() {
330 args.push("--");
331 for f in files {
332 args.push(f.as_str());
333 }
334 }
335 let (ok, out) = self.git_allow_failure(&args)?;
336 if !ok && !staged && files.is_empty() {
337 let out = self.git(&["diff", "--numstat"])?;
338 return Self::parse_numstat(&out);
339 }
340 Self::parse_numstat(&out)
341 }
342
343 fn parse_numstat(out: &str) -> Result<Vec<(usize, usize, String)>> {
344 let mut result = Vec::new();
345 for line in out.lines() {
346 let parts: Vec<&str> = line.splitn(3, '\t').collect();
347 if parts.len() == 3 {
348 let add = parts[0].parse().unwrap_or(0);
350 let del = parts[1].parse().unwrap_or(0);
351 let path = if let Some(pos) = parts[2].find(" => ") {
352 git_unquote(&parts[2][pos + 4..])
354 } else {
355 git_unquote(parts[2])
356 };
357 result.push((add, del, path));
358 }
359 }
360 Ok(result)
361 }
362
363 pub fn status_porcelain(&self) -> Result<String> {
364 self.git(&["status", "--porcelain"])
365 }
366
367 pub fn untracked_files(&self) -> Result<String> {
368 self.git(&["ls-files", "--others", "--exclude-standard"])
369 }
370
371 pub fn show(&self, rev: &str) -> Result<String> {
372 self.git(&["show", rev])
373 }
374
375 pub fn log_range(&self, base: &str, count: Option<usize>) -> Result<String> {
376 let mut args = vec!["--no-pager", "log", "--oneline"];
377 let count_str;
378 if let Some(n) = count {
379 count_str = format!("-{n}");
380 args.push(&count_str);
381 }
382 args.push(base);
383 self.git(&args)
384 }
385
386 pub fn diff_range(&self, base: &str) -> Result<String> {
387 self.git(&["diff", base])
388 }
389
390 pub fn current_branch(&self) -> Result<String> {
391 let out = self.git(&["rev-parse", "--abbrev-ref", "HEAD"])?;
392 Ok(out.trim().to_string())
393 }
394
395 pub fn head_short(&self) -> Result<String> {
396 let out = self.git(&["rev-parse", "--short", "HEAD"])?;
397 Ok(out.trim().to_string())
398 }
399
400 pub fn commits_since_last_tag(&self) -> Result<usize> {
402 let (ok, tag) = self.git_allow_failure(&["describe", "--tags", "--abbrev=0"])?;
404 let tag = tag.trim();
405
406 let out = if ok && !tag.is_empty() {
407 self.git(&["rev-list", &format!("{tag}..HEAD"), "--count"])?
408 } else {
409 self.git(&["rev-list", "HEAD", "--count"])?
410 };
411
412 out.trim()
413 .parse::<usize>()
414 .context("failed to parse commit count")
415 }
416
417 pub fn log_detailed(&self, count: usize) -> Result<String> {
419 let out = self.git(&[
420 "--no-pager",
421 "log",
422 "--reverse",
423 &format!("-{count}"),
424 "--format=%h %s%n%b%n---",
425 ])?;
426 Ok(out)
427 }
428
429 pub fn file_statuses(&self) -> Result<HashMap<String, char>> {
430 let out = self.git(&["status", "--porcelain"])?;
431 let mut map = HashMap::new();
432 for line in out.lines() {
433 if line.len() < 3 {
434 continue;
435 }
436 let xy = &line.as_bytes()[..2];
437 let path = line[3..].to_string();
438 let (x, y) = (xy[0], xy[1]);
439 let is_rename = matches!((x, y), (b'R', _) | (_, b'R'));
440 if is_rename {
441 if let Some(pos) = path.find(" -> ") {
442 let old_path = git_unquote(&path[..pos]);
443 let new_path = git_unquote(&path[pos + 4..]);
444 map.insert(old_path, 'D');
445 map.insert(new_path, 'R');
446 } else {
447 map.insert(git_unquote(&path), 'R');
448 }
449 } else {
450 let status = match (x, y) {
451 (b'?', b'?') => 'A',
452 (b'A', _) | (_, b'A') => 'A',
453 (b'D', _) | (_, b'D') => 'D',
454 (b'M', _) | (_, b'M') | (b'T', _) | (_, b'T') => 'M',
455 _ => '~',
456 };
457 map.insert(git_unquote(&path), status);
458 }
459 }
460 Ok(map)
461 }
462
463 pub fn snapshot_working_tree(&self) -> Result<PathBuf> {
476 let snapshot_dir = snapshot_dir_for(&self.root)
477 .context("failed to resolve snapshot directory (no data directory available)")?;
478 if snapshot_dir.exists() {
480 std::fs::remove_dir_all(&snapshot_dir).ok();
481 }
482 std::fs::create_dir_all(&snapshot_dir).context("failed to create snapshot directory")?;
483
484 let files_dir = snapshot_dir.join("files");
485 std::fs::create_dir_all(&files_dir)?;
486
487 std::fs::write(
489 snapshot_dir.join("repo_root"),
490 self.root.to_string_lossy().as_bytes(),
491 )
492 .context("failed to write repo_root")?;
493
494 let (has_head, head_ref) = self.git_allow_failure(&["rev-parse", "HEAD"])?;
496 if has_head {
497 std::fs::write(snapshot_dir.join("head_ref"), head_ref.trim())
498 .context("failed to write head_ref")?;
499 }
500
501 let porcelain = self.git(&["status", "--porcelain"])?;
504 let staged_names = self.git(&["diff", "--cached", "--name-only", "-z"])?;
505 let staged_set: std::collections::HashSet<String> = staged_names
506 .split('\0')
507 .map(|l| l.trim().to_string())
508 .filter(|l| !l.is_empty())
509 .collect();
510
511 #[derive(serde::Serialize, serde::Deserialize)]
512 struct ManifestEntry {
513 path: String,
514 index_status: char,
516 worktree_status: char,
518 staged: bool,
520 has_content: bool,
522 }
523
524 let mut manifest: Vec<ManifestEntry> = Vec::new();
525
526 for line in porcelain.lines() {
527 if line.len() < 3 {
528 continue;
529 }
530 let bytes = line.as_bytes();
531 let x = bytes[0] as char;
532 let y = bytes[1] as char;
533 let raw = line[3..].to_string();
534 let path = if let Some(pos) = raw.find(" -> ") {
536 git_unquote(&raw[pos + 4..])
537 } else {
538 git_unquote(&raw)
539 };
540
541 let src = self.root.join(&path);
542 let has_content = src.exists() && src.is_file();
543
544 if has_content {
545 let dest = files_dir.join(&path);
546 if let Some(parent) = dest.parent() {
547 std::fs::create_dir_all(parent).ok();
548 }
549 if let Err(e) = std::fs::copy(&src, &dest) {
550 eprintln!("warning: failed to snapshot {path}: {e}");
551 }
552 }
553
554 manifest.push(ManifestEntry {
555 staged: staged_set.contains(path.as_str()),
556 path,
557 index_status: x,
558 worktree_status: y,
559 has_content,
560 });
561 }
562
563 let manifest_json =
564 serde_json::to_string_pretty(&manifest).context("failed to serialize manifest")?;
565 std::fs::write(snapshot_dir.join("manifest.json"), manifest_json)
566 .context("failed to write manifest.json")?;
567
568 let now = std::time::SystemTime::now()
570 .duration_since(std::time::UNIX_EPOCH)
571 .unwrap_or_default()
572 .as_secs();
573 std::fs::write(snapshot_dir.join("timestamp"), now.to_string())
574 .context("failed to write timestamp")?;
575
576 Ok(snapshot_dir)
577 }
578
579 pub fn restore_snapshot(&self) -> Result<()> {
589 let snapshot_dir = self.snapshot_dir()?;
590 if !snapshot_dir.join("timestamp").exists() {
591 bail!("no valid snapshot found");
592 }
593
594 let files_dir = snapshot_dir.join("files");
595
596 let head_ref_path = snapshot_dir.join("head_ref");
598 if head_ref_path.exists() {
599 let original_head = std::fs::read_to_string(&head_ref_path)?;
600 let original_head = original_head.trim();
601 if !original_head.is_empty() {
602 let _ = self.git_allow_failure(&["reset", "--soft", original_head]);
603 }
604 }
605
606 self.reset_head()?;
608
609 let manifest_path = snapshot_dir.join("manifest.json");
611 if !manifest_path.exists() {
612 bail!("snapshot manifest.json missing — cannot restore");
613 }
614
615 #[derive(serde::Deserialize)]
616 struct ManifestEntry {
617 path: String,
618 index_status: char,
619 worktree_status: char,
620 staged: bool,
621 has_content: bool,
622 }
623
624 let manifest_data = std::fs::read_to_string(&manifest_path)?;
625 let manifest: Vec<ManifestEntry> =
626 serde_json::from_str(&manifest_data).context("failed to parse snapshot manifest")?;
627
628 let mut restored = 0usize;
629 let mut failed = 0usize;
630
631 for entry in &manifest {
632 let dest = self.root.join(&entry.path);
633
634 if entry.has_content {
635 let src = files_dir.join(&entry.path);
637 if src.exists() {
638 if let Some(parent) = dest.parent() {
639 std::fs::create_dir_all(parent).ok();
640 }
641 match std::fs::copy(&src, &dest) {
642 Ok(_) => restored += 1,
643 Err(e) => {
644 eprintln!("warning: failed to restore {}: {e}", entry.path);
645 failed += 1;
646 }
647 }
648 } else {
649 eprintln!("warning: snapshot missing content for {}", entry.path);
650 failed += 1;
651 }
652 } else if entry.index_status == 'D' || entry.worktree_status == 'D' {
653 if dest.exists() {
655 std::fs::remove_file(&dest).ok();
656 }
657 }
658
659 if entry.staged {
661 let _ = self.git_allow_failure(&["add", "--", &entry.path]);
662 }
663 }
664
665 if failed > 0 {
666 eprintln!("sr: restored {restored} files, {failed} failed");
667 }
668
669 Ok(())
670 }
671
672 pub fn clear_snapshot(&self) {
674 if let Ok(dir) = self.snapshot_dir() {
675 let _ = std::fs::remove_dir_all(&dir);
676 }
677 }
678
679 pub fn snapshot_dir(&self) -> Result<PathBuf> {
681 snapshot_dir_for(&self.root)
682 .context("failed to resolve snapshot directory (no data directory available)")
683 }
684
685 pub fn has_snapshot(&self) -> bool {
687 self.snapshot_dir()
688 .map(|d| d.join("timestamp").exists())
689 .unwrap_or(false)
690 }
691}
692
693fn snapshot_dir_for(repo_root: &std::path::Path) -> Option<PathBuf> {
696 let base = dirs::data_local_dir()?;
697 let repo_id = &sha256_hex(repo_root.to_string_lossy().as_bytes())[..16];
698 Some(base.join("sr").join("snapshots").join(repo_id))
699}
700
701pub struct SnapshotGuard<'a> {
704 repo: &'a GitRepo,
705 succeeded: bool,
706}
707
708impl<'a> SnapshotGuard<'a> {
709 pub fn new(repo: &'a GitRepo) -> Result<Self> {
711 repo.snapshot_working_tree()?;
712 Ok(Self {
713 repo,
714 succeeded: false,
715 })
716 }
717
718 pub fn success(mut self) {
720 self.succeeded = true;
721 self.repo.clear_snapshot();
722 }
723}
724
725impl Drop for SnapshotGuard<'_> {
726 fn drop(&mut self) {
727 if !self.succeeded && self.repo.has_snapshot() {
728 eprintln!("sr: operation failed, restoring working tree from snapshot...");
729 if let Err(e) = self.repo.restore_snapshot() {
730 eprintln!("sr: warning: snapshot restore failed: {e}");
731 if let Ok(dir) = self.repo.snapshot_dir() {
732 eprintln!(
733 "sr: snapshot preserved at {} for manual recovery",
734 dir.display()
735 );
736 }
737 } else {
738 self.repo.clear_snapshot();
739 }
740 }
741 }
742}
743
#[cfg(test)]
/// Integration tests that drive a real `git` binary against throwaway
/// repositories.
///
/// Snapshots are stored under the user's local data directory, keyed by the
/// repository path, so snapshot tests clear theirs when done.
mod tests {
    use super::*;
    use std::fs;

    /// Create a temporary repository with a local identity and one commit
    /// containing `init.txt`.
    ///
    /// The returned `TempDir` must stay alive for the whole test; dropping it
    /// removes the directory. The setup commands' exit statuses are not
    /// checked.
    fn temp_repo() -> (tempfile::TempDir, GitRepo) {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_path_buf();

        let git = |args: &[&str]| {
            Command::new("git")
                .args(["-C", root.to_str().unwrap()])
                .args(args)
                .output()
                .unwrap()
        };

        git(&["init"]);
        git(&["config", "user.email", "test@test.com"]);
        git(&["config", "user.name", "Test"]);
        // An initial commit gives the repo a HEAD to reset against.
        fs::write(root.join("init.txt"), "init").unwrap();
        git(&["add", "init.txt"]);
        git(&["commit", "-m", "initial"]);

        let repo = GitRepo { root };
        (dir, repo)
    }

    /// A staged new file appears in the manifest as staged, with its content
    /// copied under `files/` and `head_ref` recorded.
    #[test]
    fn snapshot_creates_manifest_with_staged_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("new.go"), "package main").unwrap();
        repo.git(&["add", "new.go"]).unwrap();

        let snap_dir = repo.snapshot_working_tree().unwrap();

        let manifest_path = snap_dir.join("manifest.json");
        assert!(manifest_path.exists(), "manifest.json should exist");

        // Matches the pretty-printed JSON text directly.
        let data = fs::read_to_string(&manifest_path).unwrap();
        assert!(data.contains("new.go"), "manifest should list new.go");
        assert!(
            data.contains("\"staged\": true"),
            "new.go should be marked staged"
        );

        assert!(
            snap_dir.join("files/new.go").exists(),
            "file content should be copied"
        );
        assert_eq!(
            fs::read_to_string(snap_dir.join("files/new.go")).unwrap(),
            "package main"
        );

        assert!(snap_dir.join("head_ref").exists());

        repo.clear_snapshot();
    }

    /// After a partial commit, restoring undoes the commit and re-stages every
    /// originally staged file with its original content.
    #[test]
    fn snapshot_restore_recovers_staged_new_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("a.go"), "package a").unwrap();
        fs::write(repo.root.join("b.go"), "package b").unwrap();
        repo.git(&["add", "a.go", "b.go"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        // Simulate an interrupted operation: unstage everything, commit only a.go.
        repo.reset_head().unwrap();
        repo.git(&["add", "a.go"]).unwrap();
        repo.git(&["commit", "-m", "partial"]).unwrap();

        repo.restore_snapshot().unwrap();

        assert!(repo.root.join("a.go").exists());
        assert!(repo.root.join("b.go").exists());
        assert_eq!(
            fs::read_to_string(repo.root.join("a.go")).unwrap(),
            "package a"
        );
        assert_eq!(
            fs::read_to_string(repo.root.join("b.go")).unwrap(),
            "package b"
        );

        let staged = repo.git(&["diff", "--cached", "--name-only"]).unwrap();
        assert!(staged.contains("a.go"), "a.go should be re-staged");
        assert!(staged.contains("b.go"), "b.go should be re-staged");

        let log = repo.git(&["log", "--oneline"]).unwrap();
        assert!(
            !log.contains("partial"),
            "partial commit should be undone by HEAD reset"
        );

        repo.clear_snapshot();
    }

    /// Restoring succeeds even when the index already holds staged changes.
    #[test]
    fn snapshot_restore_with_dirty_index_does_not_conflict() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("file.rs"), "fn main() {}").unwrap();
        repo.git(&["add", "file.rs"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        // Re-stage the same file so the index is dirty when restore runs.
        repo.reset_head().unwrap();
        repo.git(&["add", "file.rs"]).unwrap();
        let result = repo.restore_snapshot();
        assert!(
            result.is_ok(),
            "restore should succeed with dirty index: {result:?}"
        );

        assert_eq!(
            fs::read_to_string(repo.root.join("file.rs")).unwrap(),
            "fn main() {}"
        );

        repo.clear_snapshot();
    }

    /// A modified tracked file is restored to its snapshotted content after
    /// being overwritten.
    #[test]
    fn snapshot_handles_modified_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("init.txt"), "modified content").unwrap();
        repo.git(&["add", "init.txt"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        repo.reset_head().unwrap();
        fs::write(repo.root.join("init.txt"), "wrong content").unwrap();

        repo.restore_snapshot().unwrap();

        assert_eq!(
            fs::read_to_string(repo.root.join("init.txt")).unwrap(),
            "modified content"
        );

        repo.clear_snapshot();
    }

    /// Dropping a guard without calling `success` restores deleted files.
    #[test]
    fn snapshot_guard_restores_on_drop() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("guarded.txt"), "important").unwrap();
        repo.git(&["add", "guarded.txt"]).unwrap();

        {
            let _guard = SnapshotGuard::new(&repo).unwrap();
            repo.reset_head().unwrap();
            fs::remove_file(repo.root.join("guarded.txt")).ok();
            // `_guard` is dropped here without `success()`.
        }

        assert!(repo.root.join("guarded.txt").exists());
        assert_eq!(
            fs::read_to_string(repo.root.join("guarded.txt")).unwrap(),
            "important"
        );
    }

    /// Calling `success` removes the snapshot.
    #[test]
    fn snapshot_guard_clears_on_success() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("ok.txt"), "data").unwrap();
        repo.git(&["add", "ok.txt"]).unwrap();

        let guard = SnapshotGuard::new(&repo).unwrap();
        assert!(repo.has_snapshot());
        guard.success();

        assert!(!repo.has_snapshot());
    }

    /// A staged `git mv` yields `D` for the old path and `R` for the new one.
    #[test]
    fn file_statuses_includes_both_sides_of_rename() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("old_name.txt"), "content").unwrap();
        repo.git(&["add", "old_name.txt"]).unwrap();
        repo.git(&["commit", "-m", "add old_name"]).unwrap();

        repo.git(&["mv", "old_name.txt", "new_name.txt"]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        assert_eq!(
            statuses.get("old_name.txt").copied(),
            Some('D'),
            "old path should appear as deleted"
        );
        assert_eq!(
            statuses.get("new_name.txt").copied(),
            Some('R'),
            "new path should appear as renamed"
        );
    }

    /// Every path reported by `file_statuses` can be re-staged after
    /// `reset_head`. The mix covers renames, deletions, modifications and
    /// additions.
    #[test]
    fn stage_file_handles_many_moves_and_deletes_after_reset() {
        let (_dir, repo) = temp_repo();

        for i in 0..30 {
            fs::write(
                repo.root.join(format!("file_{i}.txt")),
                format!("content {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add files"]).unwrap();

        // 0..10: moved into a subdirectory.
        fs::create_dir_all(repo.root.join("moved")).unwrap();
        for i in 0..10 {
            repo.git(&[
                "mv",
                &format!("file_{i}.txt"),
                &format!("moved/file_{i}.txt"),
            ])
            .unwrap();
        }

        // 10..20: deleted.
        for i in 10..20 {
            repo.git(&["rm", &format!("file_{i}.txt")]).unwrap();
        }

        // 20..30: modified.
        for i in 20..30 {
            fs::write(
                repo.root.join(format!("file_{i}.txt")),
                format!("modified {i}"),
            )
            .unwrap();
            repo.git(&["add", &format!("file_{i}.txt")]).unwrap();
        }

        // 30..35: brand-new files.
        for i in 30..35 {
            fs::write(repo.root.join(format!("new_{i}.txt")), format!("new {i}")).unwrap();
            repo.git(&["add", &format!("new_{i}.txt")]).unwrap();
        }

        let statuses = repo.file_statuses().unwrap();
        assert!(
            statuses.len() >= 30,
            "should have many file statuses, got {}",
            statuses.len()
        );

        repo.reset_head().unwrap();

        // Collect failures so the assertion reports all of them at once.
        let mut failed = Vec::new();
        for (file, status) in &statuses {
            if file == "init.txt" {
                continue;
            }
            let ok = repo.stage_file(file).unwrap();
            if !ok {
                failed.push((file.clone(), *status));
            }
        }

        assert!(
            failed.is_empty(),
            "stage_file failed for {} files: {:?}",
            failed.len(),
            failed
        );
    }

    /// A directory renamed on disk (not via `git mv`) and then staged with
    /// `add -A` can be re-staged path by path after `reset_head`.
    #[test]
    fn stage_file_handles_manual_moves_after_reset() {
        let (_dir, repo) = temp_repo();

        fs::create_dir_all(repo.root.join("old_dir")).unwrap();
        for i in 0..10 {
            fs::write(
                repo.root.join(format!("old_dir/file_{i}.txt")),
                format!("content {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add directory"]).unwrap();

        fs::rename(repo.root.join("old_dir"), repo.root.join("new_dir")).unwrap();

        repo.git(&["add", "-A"]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        repo.reset_head().unwrap();

        let mut failed = Vec::new();
        for (file, status) in &statuses {
            if file == "init.txt" {
                continue;
            }
            let ok = repo.stage_file(file).unwrap();
            if !ok {
                failed.push((file.clone(), *status));
            }
        }

        assert!(
            failed.is_empty(),
            "stage_file failed for {} files after manual move: {:?}",
            failed.len(),
            failed
        );
    }

    /// Renames, a deletion and untracked-then-added files mixed together can
    /// all be re-staged after `reset_head`.
    #[test]
    fn stage_file_handles_new_files_mixed_with_moves() {
        let (_dir, repo) = temp_repo();

        for i in 0..5 {
            fs::write(
                repo.root.join(format!("existing_{i}.txt")),
                format!("existing {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add existing files"]).unwrap();

        fs::create_dir_all(repo.root.join("moved")).unwrap();
        for i in 0..3 {
            repo.git(&[
                "mv",
                &format!("existing_{i}.txt"),
                &format!("moved/existing_{i}.txt"),
            ])
            .unwrap();
        }

        repo.git(&["rm", "existing_3.txt"]).unwrap();

        for i in 0..5 {
            fs::write(
                repo.root.join(format!("brand_new_{i}.txt")),
                format!("new {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        repo.reset_head().unwrap();

        let mut failed = Vec::new();
        for (file, status) in &statuses {
            if file == "init.txt" {
                continue;
            }
            let ok = repo.stage_file(file).unwrap();
            if !ok {
                failed.push((file.clone(), *status));
            }
        }

        assert!(
            failed.is_empty(),
            "stage_file failed for {} files: {:?}",
            failed.len(),
            failed
        );
    }

    /// Git quotes paths containing spaces in porcelain output. The keys must
    /// come back unquoted so they can be passed straight to `stage_file`.
    #[test]
    fn stage_file_handles_quoted_paths_from_moves() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("old name.txt"), "content").unwrap();
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add file with spaces"]).unwrap();

        repo.git(&["mv", "old name.txt", "new name.txt"]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        assert!(
            statuses.contains_key("old name.txt"),
            "old path should be unquoted; got keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );
        assert!(
            statuses.contains_key("new name.txt"),
            "new path should be unquoted; got keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );

        repo.reset_head().unwrap();

        let old_ok = repo.stage_file("old name.txt").unwrap();
        assert!(old_ok, "stage_file should succeed for old (deleted) path");

        let new_ok = repo.stage_file("new name.txt").unwrap();
        assert!(new_ok, "stage_file should succeed for new (added) path");
    }

    /// Modified, deleted and new files whose names contain spaces all appear
    /// under their unquoted names.
    #[test]
    fn file_statuses_unquotes_paths_with_special_chars() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("my file.txt"), "content").unwrap();
        fs::write(repo.root.join("to delete.txt"), "delete me").unwrap();
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add spaced files"]).unwrap();

        fs::write(repo.root.join("my file.txt"), "modified").unwrap();
        repo.git(&["rm", "to delete.txt"]).unwrap();
        fs::write(repo.root.join("brand new file.txt"), "new").unwrap();
        repo.git(&["add", "."]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        assert!(
            statuses.contains_key("my file.txt"),
            "modified file should be unquoted; keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );
        assert!(
            statuses.contains_key("to delete.txt"),
            "deleted file should be unquoted; keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );
        assert!(
            statuses.contains_key("brand new file.txt"),
            "new file should be unquoted; keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );
    }

    /// Splitting a rename across two commits works. First the new paths are
    /// staged and committed; the old (deleted) paths can still be staged in a
    /// later commit.
    #[test]
    fn stage_file_works_across_sequential_commits_with_moves() {
        let (_dir, repo) = temp_repo();

        for i in 0..10 {
            fs::write(
                repo.root.join(format!("src_{i}.txt")),
                format!("content {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add source files"]).unwrap();

        fs::create_dir_all(repo.root.join("dst")).unwrap();
        for i in 0..10 {
            repo.git(&["mv", &format!("src_{i}.txt"), &format!("dst/src_{i}.txt")])
                .unwrap();
        }

        let statuses = repo.file_statuses().unwrap();
        repo.reset_head().unwrap();

        // First commit: only the destination paths.
        for i in 0..10 {
            let file = format!("dst/src_{i}.txt");
            let ok = repo.stage_file(&file).unwrap();
            assert!(ok, "should stage new path {file}");
        }
        repo.commit("feat: add new paths").unwrap();

        // Second pass: the source paths are still tracked at the new HEAD.
        let mut failed = Vec::new();
        for i in 0..10 {
            let file = format!("src_{i}.txt");
            if let Some(&status) = statuses.get(&file) {
                let ok = repo.stage_file(&file).unwrap();
                if !ok {
                    failed.push((file, status));
                }
            }
        }

        assert!(
            failed.is_empty(),
            "stage_file failed for old paths after prior commit: {:?}",
            failed
        );
    }
}