1use anyhow::{Context, Result, bail};
2use semver::Version;
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::process::Command;
7
8use crate::commit::Commit;
9use crate::error::ReleaseError;
10
11fn sha256_hex(data: &[u8]) -> String {
12 let mut hasher = Sha256::new();
13 hasher.update(data);
14 format!("{:x}", hasher.finalize())
15}
16
/// A tag name together with a parsed semantic version and a commit SHA.
///
/// NOTE(review): presumably `version` is parsed from `name` (after stripping a
/// tag prefix) by the `GitRepository` implementation — confirm.
#[derive(Debug, Clone)]
pub struct TagInfo {
    /// Tag name as it appears in the repository.
    pub name: String,
    /// Semantic version associated with the tag.
    pub version: Version,
    /// SHA associated with the tag. NOTE(review): confirm whether this is the
    /// peeled commit SHA or the annotated tag object's SHA.
    pub sha: String,
}
24
/// Git operations used by the release flow, expressed as a trait so that more
/// than one backend can provide them. Implementors must be `Send + Sync`.
///
/// NOTE(review): the semantics below are inferred from method names and
/// signatures; no implementation is visible in this file — confirm against
/// the implementors.
pub trait GitRepository: Send + Sync {
    /// Latest tag whose name starts with `prefix`, or `None` if there is none.
    /// NOTE(review): "latest" (by version or by date) is implementation-defined.
    fn latest_tag(&self, prefix: &str) -> Result<Option<TagInfo>, ReleaseError>;

    /// Commits after `from`; presumably `None` means the whole history.
    fn commits_since(&self, from: Option<&str>) -> Result<Vec<Commit>, ReleaseError>;

    /// Creates tag `name` with `message`; `sign` presumably requests a signed tag.
    fn create_tag(&self, name: &str, message: &str, sign: bool) -> Result<(), ReleaseError>;

    /// Pushes tag `name` to a remote.
    fn push_tag(&self, name: &str) -> Result<(), ReleaseError>;

    /// Stages `paths` and commits them with `message`.
    /// NOTE(review): the `bool` presumably reports whether a commit was created.
    fn stage_and_commit(&self, paths: &[&str], message: &str) -> Result<bool, ReleaseError>;

    /// Pushes the current branch; remote/branch selection is up to the implementor.
    fn push(&self) -> Result<(), ReleaseError>;

    /// Whether tag `name` exists locally.
    fn tag_exists(&self, name: &str) -> Result<bool, ReleaseError>;

    /// Whether tag `name` exists on the remote.
    fn remote_tag_exists(&self, name: &str) -> Result<bool, ReleaseError>;

    /// All tags whose names start with `prefix` (ordering not specified here).
    fn all_tags(&self, prefix: &str) -> Result<Vec<TagInfo>, ReleaseError>;

    /// Commits between `from` and `to`; presumably `from = None` starts at the root.
    fn commits_between(&self, from: Option<&str>, to: &str) -> Result<Vec<Commit>, ReleaseError>;

    /// Date of tag `tag_name` as a string. NOTE(review): format not specified here.
    fn tag_date(&self, tag_name: &str) -> Result<String, ReleaseError>;

    /// Creates tag `name`, presumably replacing an existing tag of that name.
    fn force_create_tag(&self, name: &str) -> Result<(), ReleaseError>;

    /// Pushes tag `name`, presumably overwriting the remote tag.
    fn force_push_tag(&self, name: &str) -> Result<(), ReleaseError>;

    /// SHA of the commit at `HEAD`.
    fn head_sha(&self) -> Result<String, ReleaseError>;

    /// Like [`GitRepository::commits_since`] but restricted to commits touching `path`.
    ///
    /// The default implementation ignores `path` and returns every commit from
    /// `commits_since(from)`; implementors able to filter by path should override it.
    fn commits_since_in_path(
        &self,
        from: Option<&str>,
        path: &str,
    ) -> Result<Vec<Commit>, ReleaseError> {
        let _ = path;
        self.commits_since(from)
    }

    /// Like [`GitRepository::commits_between`] but restricted to commits touching `path`.
    ///
    /// The default implementation ignores `path` and delegates to
    /// `commits_between(from, to)`; override it to honour the path filter.
    fn commits_between_in_path(
        &self,
        from: Option<&str>,
        to: &str,
        path: &str,
    ) -> Result<Vec<Commit>, ReleaseError> {
        let _ = path;
        self.commits_between(from, to)
    }
}
94
/// Decodes a path printed with git's C-style quoting.
///
/// Git wraps a path in double quotes and backslash-escapes it when it contains
/// characters such as `"`, `\`, control characters or (with the default
/// `core.quotePath`) non-ASCII bytes, which are printed as three-digit octal
/// escapes (`\303\251` for `é`).
///
/// The input is trimmed first. If it is not wrapped in an opening and a closing
/// quote, the trimmed input is returned unchanged. Unknown escapes are kept
/// verbatim (backslash included). Decoded bytes that are not valid UTF-8 are
/// replaced with U+FFFD.
fn git_unquote(s: &str) -> String {
    let s = s.trim();
    // Require both an opening and a separate closing quote. A lone `"` starts
    // and ends with a quote, and slicing its interior would panic.
    let inner = match s.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner,
        None => return s.to_string(),
    };

    let bytes = inner.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'\\' && i + 1 < bytes.len() {
            i += 1;
            match bytes[i] {
                b'\\' => out.push(b'\\'),
                b'"' => out.push(b'"'),
                b'n' => out.push(b'\n'),
                b't' => out.push(b'\t'),
                b'r' => out.push(b'\r'),
                b'a' => out.push(0x07),
                b'b' => out.push(0x08),
                b'f' => out.push(0x0C),
                b'v' => out.push(0x0B),
                b'0'..=b'3' => {
                    // Up to two more digits follow. Only `0`-`7` are octal
                    // digits; `8`/`9` must not be folded into the byte value.
                    let mut val = u16::from(bytes[i] - b'0');
                    for _ in 0..2 {
                        match bytes.get(i + 1) {
                            Some(&digit @ b'0'..=b'7') => {
                                i += 1;
                                val = val * 8 + u16::from(digit - b'0');
                            }
                            _ => break,
                        }
                    }
                    // The largest escape is `\377` == 255, so this cast is lossless.
                    out.push(val as u8);
                }
                other => {
                    out.push(b'\\');
                    out.push(other);
                }
            }
        } else {
            out.push(bytes[i]);
        }
        i += 1;
    }
    String::from_utf8(out).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).to_string())
}
147
/// Handle to a git repository that performs every operation by running the
/// `git` command-line tool against the repository root.
pub struct GitRepo {
    /// Top-level directory of the work tree; commands run as `git -C <root> ...`.
    root: PathBuf,
}
151
152#[allow(dead_code)]
153impl GitRepo {
154 pub fn discover() -> Result<Self> {
155 let output = Command::new("git")
156 .args(["rev-parse", "--show-toplevel"])
157 .output()
158 .context("failed to run git")?;
159
160 if !output.status.success() {
161 bail!("not in a git repository");
162 }
163
164 let root = String::from_utf8(output.stdout)
165 .context("invalid utf-8 from git")?
166 .trim()
167 .into();
168
169 Ok(Self { root })
170 }
171
172 pub fn root(&self) -> &PathBuf {
173 &self.root
174 }
175
176 fn git(&self, args: &[&str]) -> Result<String> {
177 let output = Command::new("git")
178 .args(["-C", self.root.to_str().unwrap()])
179 .args(args)
180 .output()
181 .with_context(|| format!("failed to run git {}", args.join(" ")))?;
182
183 if !output.status.success() {
184 let stderr = String::from_utf8_lossy(&output.stderr);
185 bail!("git {} failed: {}", args.join(" "), stderr.trim());
186 }
187
188 Ok(String::from_utf8_lossy(&output.stdout).to_string())
189 }
190
191 fn git_allow_failure(&self, args: &[&str]) -> Result<(bool, String)> {
192 let output = Command::new("git")
193 .args(["-C", self.root.to_str().unwrap()])
194 .args(args)
195 .output()
196 .with_context(|| format!("failed to run git {}", args.join(" ")))?;
197
198 Ok((
199 output.status.success(),
200 String::from_utf8_lossy(&output.stdout).to_string(),
201 ))
202 }
203
204 pub fn has_staged_changes(&self) -> Result<bool> {
205 let out = self.git(&["diff", "--cached", "--name-only"])?;
206 Ok(!out.trim().is_empty())
207 }
208
209 pub fn has_any_changes(&self) -> Result<bool> {
210 let out = self.git(&["status", "--porcelain"])?;
211 Ok(!out.trim().is_empty())
212 }
213
214 pub fn has_head(&self) -> Result<bool> {
215 let (ok, _) = self.git_allow_failure(&["rev-parse", "HEAD"])?;
216 Ok(ok)
217 }
218
219 pub fn reset_head(&self) -> Result<()> {
220 if self.has_head()? {
221 self.git(&["reset", "HEAD", "--quiet"])?;
222 } else {
223 let _ = self.git_allow_failure(&["rm", "--cached", "-r", ".", "--quiet"]);
225 }
226 Ok(())
227 }
228
229 pub fn stage_file(&self, file: &str) -> Result<bool> {
230 let (ok, _) = self.git_allow_failure(&["add", "--", file])?;
238 Ok(ok)
239 }
240
241 pub fn has_staged_after_add(&self) -> Result<bool> {
242 self.has_staged_changes()
243 }
244
245 pub fn commit(&self, message: &str) -> Result<()> {
246 let output = Command::new("git")
247 .args(["-C", self.root.to_str().unwrap()])
248 .args(["commit", "-F", "-"])
249 .stdin(std::process::Stdio::piped())
250 .stdout(std::process::Stdio::piped())
251 .stderr(std::process::Stdio::piped())
252 .spawn()
253 .context("failed to spawn git commit")?;
254
255 use std::io::Write;
256 let mut child = output;
257 if let Some(mut stdin) = child.stdin.take() {
258 stdin.write_all(message.as_bytes())?;
259 }
260
261 let out = child.wait_with_output()?;
262 if !out.status.success() {
263 let stderr = String::from_utf8_lossy(&out.stderr);
264 bail!("git commit failed: {}", stderr.trim());
265 }
266
267 Ok(())
268 }
269
270 pub fn recent_commits(&self, count: usize) -> Result<String> {
271 self.git(&["--no-pager", "log", "--oneline", &format!("-{count}")])
272 }
273
274 pub fn diff_cached(&self) -> Result<String> {
275 self.git(&["diff", "--cached"])
276 }
277
278 pub fn diff_cached_stat(&self) -> Result<String> {
279 self.git(&["diff", "--cached", "--stat"])
280 }
281
282 pub fn diff_head(&self) -> Result<String> {
283 let (ok, out) = self.git_allow_failure(&["diff", "HEAD"])?;
284 if ok { Ok(out) } else { self.git(&["diff"]) }
285 }
286
287 pub fn diff_unified(&self, staged: bool, context: usize, files: &[String]) -> Result<String> {
290 let ctx_flag = format!("-U{context}");
291 let mut args: Vec<&str> = vec!["diff", &ctx_flag];
292 if staged {
293 args.push("--cached");
294 } else {
295 args.push("HEAD");
296 }
297 if !files.is_empty() {
298 args.push("--");
299 for f in files {
300 args.push(f.as_str());
301 }
302 }
303 let (ok, out) = self.git_allow_failure(&args)?;
304 if ok {
305 Ok(out)
306 } else if !staged && files.is_empty() {
307 self.git(&["diff", &ctx_flag])
309 } else {
310 Ok(out)
311 }
312 }
313
314 pub fn diff_numstat(
316 &self,
317 staged: bool,
318 files: &[String],
319 ) -> Result<Vec<(usize, usize, String)>> {
320 let mut args: Vec<&str> = vec!["diff", "--numstat"];
321 if staged {
322 args.push("--cached");
323 } else {
324 args.push("HEAD");
325 }
326 if !files.is_empty() {
327 args.push("--");
328 for f in files {
329 args.push(f.as_str());
330 }
331 }
332 let (ok, out) = self.git_allow_failure(&args)?;
333 if !ok && !staged && files.is_empty() {
334 let out = self.git(&["diff", "--numstat"])?;
335 return Self::parse_numstat(&out);
336 }
337 Self::parse_numstat(&out)
338 }
339
340 fn parse_numstat(out: &str) -> Result<Vec<(usize, usize, String)>> {
341 let mut result = Vec::new();
342 for line in out.lines() {
343 let parts: Vec<&str> = line.splitn(3, '\t').collect();
344 if parts.len() == 3 {
345 let add = parts[0].parse().unwrap_or(0);
347 let del = parts[1].parse().unwrap_or(0);
348 let path = if let Some(pos) = parts[2].find(" => ") {
349 git_unquote(&parts[2][pos + 4..])
351 } else {
352 git_unquote(parts[2])
353 };
354 result.push((add, del, path));
355 }
356 }
357 Ok(result)
358 }
359
360 pub fn status_porcelain(&self) -> Result<String> {
361 self.git(&["status", "--porcelain"])
362 }
363
364 pub fn untracked_files(&self) -> Result<String> {
365 self.git(&["ls-files", "--others", "--exclude-standard"])
366 }
367
368 pub fn show(&self, rev: &str) -> Result<String> {
369 self.git(&["show", rev])
370 }
371
372 pub fn log_range(&self, base: &str, count: Option<usize>) -> Result<String> {
373 let mut args = vec!["--no-pager", "log", "--oneline"];
374 let count_str;
375 if let Some(n) = count {
376 count_str = format!("-{n}");
377 args.push(&count_str);
378 }
379 args.push(base);
380 self.git(&args)
381 }
382
383 pub fn diff_range(&self, base: &str) -> Result<String> {
384 self.git(&["diff", base])
385 }
386
387 pub fn current_branch(&self) -> Result<String> {
388 let out = self.git(&["rev-parse", "--abbrev-ref", "HEAD"])?;
389 Ok(out.trim().to_string())
390 }
391
392 pub fn head_short(&self) -> Result<String> {
393 let out = self.git(&["rev-parse", "--short", "HEAD"])?;
394 Ok(out.trim().to_string())
395 }
396
397 pub fn commits_since_last_tag(&self) -> Result<usize> {
399 let (ok, tag) = self.git_allow_failure(&["describe", "--tags", "--abbrev=0"])?;
401 let tag = tag.trim();
402
403 let out = if ok && !tag.is_empty() {
404 self.git(&["rev-list", &format!("{tag}..HEAD"), "--count"])?
405 } else {
406 self.git(&["rev-list", "HEAD", "--count"])?
407 };
408
409 out.trim()
410 .parse::<usize>()
411 .context("failed to parse commit count")
412 }
413
414 pub fn log_detailed(&self, count: usize) -> Result<String> {
416 let out = self.git(&[
417 "--no-pager",
418 "log",
419 "--reverse",
420 &format!("-{count}"),
421 "--format=%h %s%n%b%n---",
422 ])?;
423 Ok(out)
424 }
425
426 pub fn file_statuses(&self) -> Result<HashMap<String, char>> {
427 let out = self.git(&["status", "--porcelain"])?;
428 let mut map = HashMap::new();
429 for line in out.lines() {
430 if line.len() < 3 {
431 continue;
432 }
433 let xy = &line.as_bytes()[..2];
434 let path = line[3..].to_string();
435 let (x, y) = (xy[0], xy[1]);
436 let is_rename = matches!((x, y), (b'R', _) | (_, b'R'));
437 if is_rename {
438 if let Some(pos) = path.find(" -> ") {
439 let old_path = git_unquote(&path[..pos]);
440 let new_path = git_unquote(&path[pos + 4..]);
441 map.insert(old_path, 'D');
442 map.insert(new_path, 'R');
443 } else {
444 map.insert(git_unquote(&path), 'R');
445 }
446 } else {
447 let status = match (x, y) {
448 (b'?', b'?') => 'A',
449 (b'A', _) | (_, b'A') => 'A',
450 (b'D', _) | (_, b'D') => 'D',
451 (b'M', _) | (_, b'M') | (b'T', _) | (_, b'T') => 'M',
452 _ => '~',
453 };
454 map.insert(git_unquote(&path), status);
455 }
456 }
457 Ok(map)
458 }
459
460 pub fn snapshot_working_tree(&self) -> Result<PathBuf> {
473 let snapshot_dir = snapshot_dir_for(&self.root)
474 .context("failed to resolve snapshot directory (no data directory available)")?;
475 if snapshot_dir.exists() {
477 std::fs::remove_dir_all(&snapshot_dir).ok();
478 }
479 std::fs::create_dir_all(&snapshot_dir).context("failed to create snapshot directory")?;
480
481 let files_dir = snapshot_dir.join("files");
482 std::fs::create_dir_all(&files_dir)?;
483
484 std::fs::write(
486 snapshot_dir.join("repo_root"),
487 self.root.to_string_lossy().as_bytes(),
488 )
489 .context("failed to write repo_root")?;
490
491 let (has_head, head_ref) = self.git_allow_failure(&["rev-parse", "HEAD"])?;
493 if has_head {
494 std::fs::write(snapshot_dir.join("head_ref"), head_ref.trim())
495 .context("failed to write head_ref")?;
496 }
497
498 let porcelain = self.git(&["status", "--porcelain"])?;
501 let staged_names = self.git(&["diff", "--cached", "--name-only", "-z"])?;
502 let staged_set: std::collections::HashSet<String> = staged_names
503 .split('\0')
504 .map(|l| l.trim().to_string())
505 .filter(|l| !l.is_empty())
506 .collect();
507
508 #[derive(serde::Serialize, serde::Deserialize)]
509 struct ManifestEntry {
510 path: String,
511 index_status: char,
513 worktree_status: char,
515 staged: bool,
517 has_content: bool,
519 }
520
521 let mut manifest: Vec<ManifestEntry> = Vec::new();
522
523 for line in porcelain.lines() {
524 if line.len() < 3 {
525 continue;
526 }
527 let bytes = line.as_bytes();
528 let x = bytes[0] as char;
529 let y = bytes[1] as char;
530 let raw = line[3..].to_string();
531 let path = if let Some(pos) = raw.find(" -> ") {
533 git_unquote(&raw[pos + 4..])
534 } else {
535 git_unquote(&raw)
536 };
537
538 let src = self.root.join(&path);
539 let has_content = src.exists() && src.is_file();
540
541 if has_content {
542 let dest = files_dir.join(&path);
543 if let Some(parent) = dest.parent() {
544 std::fs::create_dir_all(parent).ok();
545 }
546 if let Err(e) = std::fs::copy(&src, &dest) {
547 eprintln!("warning: failed to snapshot {path}: {e}");
548 }
549 }
550
551 manifest.push(ManifestEntry {
552 staged: staged_set.contains(path.as_str()),
553 path,
554 index_status: x,
555 worktree_status: y,
556 has_content,
557 });
558 }
559
560 let manifest_json =
561 serde_json::to_string_pretty(&manifest).context("failed to serialize manifest")?;
562 std::fs::write(snapshot_dir.join("manifest.json"), manifest_json)
563 .context("failed to write manifest.json")?;
564
565 let now = std::time::SystemTime::now()
567 .duration_since(std::time::UNIX_EPOCH)
568 .unwrap_or_default()
569 .as_secs();
570 std::fs::write(snapshot_dir.join("timestamp"), now.to_string())
571 .context("failed to write timestamp")?;
572
573 Ok(snapshot_dir)
574 }
575
576 pub fn restore_snapshot(&self) -> Result<()> {
586 let snapshot_dir = self.snapshot_dir()?;
587 if !snapshot_dir.join("timestamp").exists() {
588 bail!("no valid snapshot found");
589 }
590
591 let files_dir = snapshot_dir.join("files");
592
593 let head_ref_path = snapshot_dir.join("head_ref");
595 if head_ref_path.exists() {
596 let original_head = std::fs::read_to_string(&head_ref_path)?;
597 let original_head = original_head.trim();
598 if !original_head.is_empty() {
599 let _ = self.git_allow_failure(&["reset", "--soft", original_head]);
600 }
601 }
602
603 self.reset_head()?;
605
606 let manifest_path = snapshot_dir.join("manifest.json");
608 if !manifest_path.exists() {
609 bail!("snapshot manifest.json missing — cannot restore");
610 }
611
612 #[derive(serde::Deserialize)]
613 struct ManifestEntry {
614 path: String,
615 index_status: char,
616 worktree_status: char,
617 staged: bool,
618 has_content: bool,
619 }
620
621 let manifest_data = std::fs::read_to_string(&manifest_path)?;
622 let manifest: Vec<ManifestEntry> =
623 serde_json::from_str(&manifest_data).context("failed to parse snapshot manifest")?;
624
625 let mut restored = 0usize;
626 let mut failed = 0usize;
627
628 for entry in &manifest {
629 let dest = self.root.join(&entry.path);
630
631 if entry.has_content {
632 let src = files_dir.join(&entry.path);
634 if src.exists() {
635 if let Some(parent) = dest.parent() {
636 std::fs::create_dir_all(parent).ok();
637 }
638 match std::fs::copy(&src, &dest) {
639 Ok(_) => restored += 1,
640 Err(e) => {
641 eprintln!("warning: failed to restore {}: {e}", entry.path);
642 failed += 1;
643 }
644 }
645 } else {
646 eprintln!("warning: snapshot missing content for {}", entry.path);
647 failed += 1;
648 }
649 } else if entry.index_status == 'D' || entry.worktree_status == 'D' {
650 if dest.exists() {
652 std::fs::remove_file(&dest).ok();
653 }
654 }
655
656 if entry.staged {
658 let _ = self.git_allow_failure(&["add", "--", &entry.path]);
659 }
660 }
661
662 if failed > 0 {
663 eprintln!("sr: restored {restored} files, {failed} failed");
664 }
665
666 Ok(())
667 }
668
669 pub fn clear_snapshot(&self) {
671 if let Ok(dir) = self.snapshot_dir() {
672 let _ = std::fs::remove_dir_all(&dir);
673 }
674 }
675
676 pub fn snapshot_dir(&self) -> Result<PathBuf> {
678 snapshot_dir_for(&self.root)
679 .context("failed to resolve snapshot directory (no data directory available)")
680 }
681
682 pub fn has_snapshot(&self) -> bool {
684 self.snapshot_dir()
685 .map(|d| d.join("timestamp").exists())
686 .unwrap_or(false)
687 }
688}
689
690fn snapshot_dir_for(repo_root: &std::path::Path) -> Option<PathBuf> {
693 let base = dirs::data_local_dir()?;
694 let repo_id = &sha256_hex(repo_root.to_string_lossy().as_bytes())[..16];
695 Some(base.join("sr").join("snapshots").join(repo_id))
696}
697
/// Scope guard around a risky repository operation.
///
/// Creating it snapshots the working tree. Dropping it restores that snapshot
/// unless `success` was called first.
pub struct SnapshotGuard<'a> {
    /// Repository whose working tree was snapshotted.
    repo: &'a GitRepo,
    /// Set by `success`; when true, `Drop` skips the restore.
    succeeded: bool,
}
704
705impl<'a> SnapshotGuard<'a> {
706 pub fn new(repo: &'a GitRepo) -> Result<Self> {
708 repo.snapshot_working_tree()?;
709 Ok(Self {
710 repo,
711 succeeded: false,
712 })
713 }
714
715 pub fn success(mut self) {
717 self.succeeded = true;
718 self.repo.clear_snapshot();
719 }
720}
721
722impl Drop for SnapshotGuard<'_> {
723 fn drop(&mut self) {
724 if !self.succeeded && self.repo.has_snapshot() {
725 eprintln!("sr: operation failed, restoring working tree from snapshot...");
726 if let Err(e) = self.repo.restore_snapshot() {
727 eprintln!("sr: warning: snapshot restore failed: {e}");
728 if let Ok(dir) = self.repo.snapshot_dir() {
729 eprintln!(
730 "sr: snapshot preserved at {} for manual recovery",
731 dir.display()
732 );
733 }
734 } else {
735 self.repo.clear_snapshot();
736 }
737 }
738 }
739}
740
// Integration tests that drive a real `git` binary inside temporary repositories.
// Snapshots are written under the local data directory (see `snapshot_dir_for`),
// not inside the temp repo, so snapshot tests call `clear_snapshot` themselves.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a throwaway repository containing one committed file, `init.txt`.
    ///
    /// Keep the returned `TempDir` alive for the whole test; dropping it deletes
    /// the repository. Only spawn failures of the setup `git` calls panic; their
    /// exit statuses are not checked.
    fn temp_repo() -> (tempfile::TempDir, GitRepo) {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_path_buf();

        let git = |args: &[&str]| {
            Command::new("git")
                .args(["-C", root.to_str().unwrap()])
                .args(args)
                .output()
                .unwrap()
        };

        git(&["init"]);
        // Repository-local identity so `git commit` works without a global config.
        git(&["config", "user.email", "test@test.com"]);
        git(&["config", "user.name", "Test"]);
        fs::write(root.join("init.txt"), "init").unwrap();
        git(&["add", "init.txt"]);
        git(&["commit", "-m", "initial"]);

        let repo = GitRepo { root };
        (dir, repo)
    }

    /// A staged new file is listed in the manifest as staged and its content is copied.
    #[test]
    fn snapshot_creates_manifest_with_staged_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("new.go"), "package main").unwrap();
        repo.git(&["add", "new.go"]).unwrap();

        let snap_dir = repo.snapshot_working_tree().unwrap();

        let manifest_path = snap_dir.join("manifest.json");
        assert!(manifest_path.exists(), "manifest.json should exist");

        let data = fs::read_to_string(&manifest_path).unwrap();
        assert!(data.contains("new.go"), "manifest should list new.go");
        // Relies on the `"key": value` spacing of `serde_json::to_string_pretty`.
        assert!(
            data.contains("\"staged\": true"),
            "new.go should be marked staged"
        );

        // Contents are mirrored under `files/` at the same relative path.
        assert!(
            snap_dir.join("files/new.go").exists(),
            "file content should be copied"
        );
        assert_eq!(
            fs::read_to_string(snap_dir.join("files/new.go")).unwrap(),
            "package main"
        );

        // temp_repo made a commit, so HEAD exists and must be recorded.
        assert!(snap_dir.join("head_ref").exists());

        repo.clear_snapshot();
    }

    /// Restoring after a partial commit undoes the commit and re-stages every file.
    #[test]
    fn snapshot_restore_recovers_staged_new_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("a.go"), "package a").unwrap();
        fs::write(repo.root.join("b.go"), "package b").unwrap();
        repo.git(&["add", "a.go", "b.go"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        // Simulate an operation that failed midway: unstage, then commit only a.go.
        repo.reset_head().unwrap();
        repo.git(&["add", "a.go"]).unwrap();
        repo.git(&["commit", "-m", "partial"]).unwrap();

        repo.restore_snapshot().unwrap();

        // Both files keep their snapshotted contents.
        assert!(repo.root.join("a.go").exists());
        assert!(repo.root.join("b.go").exists());
        assert_eq!(
            fs::read_to_string(repo.root.join("a.go")).unwrap(),
            "package a"
        );
        assert_eq!(
            fs::read_to_string(repo.root.join("b.go")).unwrap(),
            "package b"
        );

        // Both are staged again, as they were at snapshot time.
        let staged = repo.git(&["diff", "--cached", "--name-only"]).unwrap();
        assert!(staged.contains("a.go"), "a.go should be re-staged");
        assert!(staged.contains("b.go"), "b.go should be re-staged");

        // The soft reset to the recorded HEAD removed the partial commit.
        let log = repo.git(&["log", "--oneline"]).unwrap();
        assert!(
            !log.contains("partial"),
            "partial commit should be undone by HEAD reset"
        );

        repo.clear_snapshot();
    }

    /// Restore succeeds when the index already contains staged changes.
    #[test]
    fn snapshot_restore_with_dirty_index_does_not_conflict() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("file.rs"), "fn main() {}").unwrap();
        repo.git(&["add", "file.rs"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        // Re-stage before restoring so the restore starts from a non-clean index.
        repo.reset_head().unwrap();
        repo.git(&["add", "file.rs"]).unwrap();
        let result = repo.restore_snapshot();
        assert!(
            result.is_ok(),
            "restore should succeed with dirty index: {result:?}"
        );

        assert_eq!(
            fs::read_to_string(repo.root.join("file.rs")).unwrap(),
            "fn main() {}"
        );

        repo.clear_snapshot();
    }

    /// A tracked file modified after the snapshot gets its snapshotted content back.
    #[test]
    fn snapshot_handles_modified_files() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("init.txt"), "modified content").unwrap();
        repo.git(&["add", "init.txt"]).unwrap();

        repo.snapshot_working_tree().unwrap();

        // Clobber the file after the snapshot was taken.
        repo.reset_head().unwrap();
        fs::write(repo.root.join("init.txt"), "wrong content").unwrap();

        repo.restore_snapshot().unwrap();

        assert_eq!(
            fs::read_to_string(repo.root.join("init.txt")).unwrap(),
            "modified content"
        );

        repo.clear_snapshot();
    }

    /// Dropping a guard without calling `success()` restores the working tree.
    #[test]
    fn snapshot_guard_restores_on_drop() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("guarded.txt"), "important").unwrap();
        repo.git(&["add", "guarded.txt"]).unwrap();

        {
            let _guard = SnapshotGuard::new(&repo).unwrap();
            repo.reset_head().unwrap();
            fs::remove_file(repo.root.join("guarded.txt")).ok();
        } // `_guard` drops here and restores.

        assert!(repo.root.join("guarded.txt").exists());
        assert_eq!(
            fs::read_to_string(repo.root.join("guarded.txt")).unwrap(),
            "important"
        );
    }

    /// `success()` deletes the snapshot.
    #[test]
    fn snapshot_guard_clears_on_success() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("ok.txt"), "data").unwrap();
        repo.git(&["add", "ok.txt"]).unwrap();

        let guard = SnapshotGuard::new(&repo).unwrap();
        assert!(repo.has_snapshot());
        guard.success();

        assert!(!repo.has_snapshot());
    }

    /// A staged rename reports the source as 'D' and the destination as 'R'.
    #[test]
    fn file_statuses_includes_both_sides_of_rename() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("old_name.txt"), "content").unwrap();
        repo.git(&["add", "old_name.txt"]).unwrap();
        repo.git(&["commit", "-m", "add old_name"]).unwrap();

        repo.git(&["mv", "old_name.txt", "new_name.txt"]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        assert_eq!(
            statuses.get("old_name.txt").copied(),
            Some('D'),
            "old path should appear as deleted"
        );
        assert_eq!(
            statuses.get("new_name.txt").copied(),
            Some('R'),
            "new path should appear as renamed"
        );
    }

    /// Every path from `file_statuses` can be re-staged individually after the
    /// index is reset.
    ///
    /// The changes mix `git mv`, `git rm`, modifications and new files.
    #[test]
    fn stage_file_handles_many_moves_and_deletes_after_reset() {
        let (_dir, repo) = temp_repo();

        for i in 0..30 {
            fs::write(
                repo.root.join(format!("file_{i}.txt")),
                format!("content {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add files"]).unwrap();

        // Files 0-9: moved into a subdirectory.
        fs::create_dir_all(repo.root.join("moved")).unwrap();
        for i in 0..10 {
            repo.git(&[
                "mv",
                &format!("file_{i}.txt"),
                &format!("moved/file_{i}.txt"),
            ])
            .unwrap();
        }

        // Files 10-19: deleted.
        for i in 10..20 {
            repo.git(&["rm", &format!("file_{i}.txt")]).unwrap();
        }

        // Files 20-29: modified and staged.
        for i in 20..30 {
            fs::write(
                repo.root.join(format!("file_{i}.txt")),
                format!("modified {i}"),
            )
            .unwrap();
            repo.git(&["add", &format!("file_{i}.txt")]).unwrap();
        }

        // Five brand-new staged files.
        for i in 30..35 {
            fs::write(repo.root.join(format!("new_{i}.txt")), format!("new {i}")).unwrap();
            repo.git(&["add", &format!("new_{i}.txt")]).unwrap();
        }

        let statuses = repo.file_statuses().unwrap();
        assert!(
            statuses.len() >= 30,
            "should have many file statuses, got {}",
            statuses.len()
        );

        repo.reset_head().unwrap();

        // Re-stage each path one at a time, collecting any that git rejects.
        let mut failed = Vec::new();
        for (file, status) in &statuses {
            if file == "init.txt" {
                continue;
            }
            let ok = repo.stage_file(file).unwrap();
            if !ok {
                failed.push((file.clone(), *status));
            }
        }

        assert!(
            failed.is_empty(),
            "stage_file failed for {} files: {:?}",
            failed.len(),
            failed
        );
    }

    /// Paths can be re-staged after a directory was renamed on disk.
    ///
    /// The rename is done with `fs::rename` rather than `git mv`, then staged
    /// with `git add -A`.
    #[test]
    fn stage_file_handles_manual_moves_after_reset() {
        let (_dir, repo) = temp_repo();

        fs::create_dir_all(repo.root.join("old_dir")).unwrap();
        for i in 0..10 {
            fs::write(
                repo.root.join(format!("old_dir/file_{i}.txt")),
                format!("content {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add directory"]).unwrap();

        fs::rename(repo.root.join("old_dir"), repo.root.join("new_dir")).unwrap();

        repo.git(&["add", "-A"]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        repo.reset_head().unwrap();

        let mut failed = Vec::new();
        for (file, status) in &statuses {
            if file == "init.txt" {
                continue;
            }
            let ok = repo.stage_file(file).unwrap();
            if !ok {
                failed.push((file.clone(), *status));
            }
        }

        assert!(
            failed.is_empty(),
            "stage_file failed for {} files after manual move: {:?}",
            failed.len(),
            failed
        );
    }

    /// Paths can be re-staged when moves, a deletion and new files are staged together.
    #[test]
    fn stage_file_handles_new_files_mixed_with_moves() {
        let (_dir, repo) = temp_repo();

        for i in 0..5 {
            fs::write(
                repo.root.join(format!("existing_{i}.txt")),
                format!("existing {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add existing files"]).unwrap();

        fs::create_dir_all(repo.root.join("moved")).unwrap();
        for i in 0..3 {
            repo.git(&[
                "mv",
                &format!("existing_{i}.txt"),
                &format!("moved/existing_{i}.txt"),
            ])
            .unwrap();
        }

        repo.git(&["rm", "existing_3.txt"]).unwrap();

        for i in 0..5 {
            fs::write(
                repo.root.join(format!("brand_new_{i}.txt")),
                format!("new {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        repo.reset_head().unwrap();

        let mut failed = Vec::new();
        for (file, status) in &statuses {
            if file == "init.txt" {
                continue;
            }
            let ok = repo.stage_file(file).unwrap();
            if !ok {
                failed.push((file.clone(), *status));
            }
        }

        assert!(
            failed.is_empty(),
            "stage_file failed for {} files: {:?}",
            failed.len(),
            failed
        );
    }

    /// Renamed paths containing spaces are unquoted by `file_statuses` and stageable.
    #[test]
    fn stage_file_handles_quoted_paths_from_moves() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("old name.txt"), "content").unwrap();
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add file with spaces"]).unwrap();

        repo.git(&["mv", "old name.txt", "new name.txt"]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        // Keys must be the plain paths, with no surrounding quotes.
        assert!(
            statuses.contains_key("old name.txt"),
            "old path should be unquoted; got keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );
        assert!(
            statuses.contains_key("new name.txt"),
            "new path should be unquoted; got keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );

        repo.reset_head().unwrap();

        // Staging the missing old path records its removal.
        let old_ok = repo.stage_file("old name.txt").unwrap();
        assert!(old_ok, "stage_file should succeed for old (deleted) path");

        let new_ok = repo.stage_file("new name.txt").unwrap();
        assert!(new_ok, "stage_file should succeed for new (added) path");
    }

    /// Modified, deleted and new paths containing spaces are all reported unquoted.
    #[test]
    fn file_statuses_unquotes_paths_with_special_chars() {
        let (_dir, repo) = temp_repo();

        fs::write(repo.root.join("my file.txt"), "content").unwrap();
        fs::write(repo.root.join("to delete.txt"), "delete me").unwrap();
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add spaced files"]).unwrap();

        fs::write(repo.root.join("my file.txt"), "modified").unwrap();
        repo.git(&["rm", "to delete.txt"]).unwrap();
        fs::write(repo.root.join("brand new file.txt"), "new").unwrap();
        repo.git(&["add", "."]).unwrap();

        let statuses = repo.file_statuses().unwrap();

        assert!(
            statuses.contains_key("my file.txt"),
            "modified file should be unquoted; keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );
        assert!(
            statuses.contains_key("to delete.txt"),
            "deleted file should be unquoted; keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );
        assert!(
            statuses.contains_key("brand new file.txt"),
            "new file should be unquoted; keys: {:?}",
            statuses.keys().collect::<Vec<_>>()
        );
    }

    /// The two sides of a move can be committed separately.
    ///
    /// The new paths are committed first; the old (deleted) paths are still
    /// stageable in a later commit.
    #[test]
    fn stage_file_works_across_sequential_commits_with_moves() {
        let (_dir, repo) = temp_repo();

        for i in 0..10 {
            fs::write(
                repo.root.join(format!("src_{i}.txt")),
                format!("content {i}"),
            )
            .unwrap();
        }
        repo.git(&["add", "."]).unwrap();
        repo.git(&["commit", "-m", "add source files"]).unwrap();

        fs::create_dir_all(repo.root.join("dst")).unwrap();
        for i in 0..10 {
            repo.git(&["mv", &format!("src_{i}.txt"), &format!("dst/src_{i}.txt")])
                .unwrap();
        }

        let statuses = repo.file_statuses().unwrap();
        repo.reset_head().unwrap();

        // First commit: only the destination paths.
        for i in 0..10 {
            let file = format!("dst/src_{i}.txt");
            let ok = repo.stage_file(&file).unwrap();
            assert!(ok, "should stage new path {file}");
        }
        repo.commit("feat: add new paths").unwrap();

        // Second round: the source paths are still tracked but missing on disk,
        // so staging them records their deletion.
        let mut failed = Vec::new();
        for i in 0..10 {
            let file = format!("src_{i}.txt");
            if let Some(&status) = statuses.get(&file) {
                let ok = repo.stage_file(&file).unwrap();
                if !ok {
                    failed.push((file, status));
                }
            }
        }

        assert!(
            failed.is_empty(),
            "stage_file failed for old paths after prior commit: {:?}",
            failed
        );
    }
}
1309}