1use git2::{Oid, Repository, Signature};
8use std::path::{Path, PathBuf};
9
10#[derive(Debug, thiserror::Error)]
12pub enum CheckpointError {
13 #[error("git error: {0}")]
14 Git(#[from] git2::Error),
15 #[error("no checkpoints available")]
16 NoCheckpoints,
17 #[error("repository not found at {0}")]
18 RepoNotFound(PathBuf),
19 #[error("io error: {0}")]
20 Io(#[from] std::io::Error),
21}
22
23#[derive(Debug, Clone)]
25pub struct Checkpoint {
26 pub oid: String,
28 pub label: String,
30 pub timestamp: chrono::DateTime<chrono::Utc>,
32 pub changed_files: Vec<String>,
34}
35
36pub struct CheckpointManager {
38 workspace: PathBuf,
39 checkpoints: Vec<Checkpoint>,
40 ref_prefix: String,
42}
43
44impl CheckpointManager {
45 pub fn new(workspace: PathBuf) -> Self {
47 Self {
48 workspace,
49 checkpoints: Vec::new(),
50 ref_prefix: "refs/rustant/checkpoints".to_string(),
51 }
52 }
53
54 pub fn workspace(&self) -> &Path {
56 &self.workspace
57 }
58
59 fn open_repo(&self) -> Result<Repository, CheckpointError> {
61 Repository::discover(&self.workspace)
62 .map_err(|_| CheckpointError::RepoNotFound(self.workspace.clone()))
63 }
64
65 pub fn create_checkpoint(&mut self, label: &str) -> Result<Checkpoint, CheckpointError> {
70 let repo = self.open_repo()?;
71
72 let head = repo.head()?;
74 let parent_commit = head.peel_to_commit()?;
75
76 let mut index = repo.index()?;
78 index.add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)?;
79 index.write()?;
80 let tree_oid = index.write_tree()?;
81 let tree = repo.find_tree(tree_oid)?;
82
83 let parent_tree = parent_commit.tree()?;
85 let diff = repo.diff_tree_to_tree(Some(&parent_tree), Some(&tree), None)?;
86 let changed_files: Vec<String> = diff
87 .deltas()
88 .filter_map(|d| d.new_file().path().map(|p| p.to_string_lossy().to_string()))
89 .collect();
90
91 let sig = Signature::now("rustant", "rustant@local")?;
93 let message = format!("[checkpoint] {}", label);
94 let oid = repo.commit(
95 None, &sig,
97 &sig,
98 &message,
99 &tree,
100 &[&parent_commit],
101 )?;
102
103 let ref_name = format!("{}/{}", self.ref_prefix, self.checkpoints.len());
105 repo.reference(&ref_name, oid, true, &format!("checkpoint: {}", label))?;
106
107 let checkpoint = Checkpoint {
108 oid: oid.to_string(),
109 label: label.to_string(),
110 timestamp: chrono::Utc::now(),
111 changed_files,
112 };
113
114 self.checkpoints.push(checkpoint.clone());
115 Ok(checkpoint)
116 }
117
118 pub fn restore_checkpoint(
120 &mut self,
121 checkpoint_index: usize,
122 ) -> Result<&Checkpoint, CheckpointError> {
123 if checkpoint_index >= self.checkpoints.len() {
124 return Err(CheckpointError::NoCheckpoints);
125 }
126
127 let checkpoint = &self.checkpoints[checkpoint_index];
128 let repo = self.open_repo()?;
129 let oid = Oid::from_str(&checkpoint.oid)?;
130 let commit = repo.find_commit(oid)?;
131 let tree = commit.tree()?;
132
133 repo.checkout_tree(
135 tree.as_object(),
136 Some(git2::build::CheckoutBuilder::new().force()),
137 )?;
138
139 let mut index = repo.index()?;
141 index.read_tree(&tree)?;
142 index.write()?;
143
144 Ok(checkpoint)
145 }
146
147 pub fn undo(&mut self) -> Result<&Checkpoint, CheckpointError> {
149 if self.checkpoints.is_empty() {
150 return Err(CheckpointError::NoCheckpoints);
151 }
152 let last = self.checkpoints.len() - 1;
153 self.restore_checkpoint(last)
154 }
155
156 pub fn checkpoints(&self) -> &[Checkpoint] {
158 &self.checkpoints
159 }
160
161 pub fn count(&self) -> usize {
163 self.checkpoints.len()
164 }
165
166 pub fn diff_from_last(&self) -> Result<String, CheckpointError> {
168 let repo = self.open_repo()?;
169
170 if self.checkpoints.is_empty() {
171 let head = repo.head()?;
173 let tree = head.peel_to_tree()?;
174 let diff = repo.diff_tree_to_workdir(Some(&tree), None)?;
175 let mut output = Vec::new();
176 diff.print(git2::DiffFormat::Patch, |_, _, line| {
177 output.extend_from_slice(line.content());
178 true
179 })?;
180 return Ok(String::from_utf8_lossy(&output).to_string());
181 }
182
183 let last = &self.checkpoints[self.checkpoints.len() - 1];
184 let oid = Oid::from_str(&last.oid)?;
185 let commit = repo.find_commit(oid)?;
186 let tree = commit.tree()?;
187
188 let diff = repo.diff_tree_to_workdir(Some(&tree), None)?;
189 let mut output = Vec::new();
190 diff.print(git2::DiffFormat::Patch, |_, _, line| {
191 output.extend_from_slice(line.content());
192 true
193 })?;
194
195 Ok(String::from_utf8_lossy(&output).to_string())
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Builds a temporary repository with one commit containing
    /// `initial.txt` ("initial content").
    ///
    /// The `TempDir` guard is returned alongside the path so callers can keep
    /// it bound for the duration of the test.
    fn setup_test_repo() -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().to_path_buf();

        let repo = Repository::init(&path).unwrap();

        fs::write(path.join("initial.txt"), "initial content").unwrap();
        let mut index = repo.index().unwrap();
        index
            .add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)
            .unwrap();
        index.write().unwrap();
        let tree_oid = index.write_tree().unwrap();
        let tree = repo.find_tree(tree_oid).unwrap();
        let sig = Signature::now("test", "test@test.com").unwrap();
        repo.commit(Some("HEAD"), &sig, &sig, "Initial commit", &tree, &[])
            .unwrap();

        (dir, path)
    }

    #[test]
    fn test_checkpoint_manager_new() {
        let (_dir, path) = setup_test_repo();
        let mgr = CheckpointManager::new(path.clone());
        assert_eq!(mgr.workspace(), path);
        assert_eq!(mgr.count(), 0);
        assert!(mgr.checkpoints().is_empty());
    }

    #[test]
    fn test_create_checkpoint() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("initial.txt"), "modified content").unwrap();

        let cp = mgr.create_checkpoint("before tool exec").unwrap();
        assert_eq!(cp.label, "before tool exec");
        assert!(!cp.oid.is_empty());
        // The only difference from HEAD is the modified file.
        assert_eq!(cp.changed_files, vec!["initial.txt".to_string()]);
        assert_eq!(mgr.count(), 1);
    }

    #[test]
    fn test_create_multiple_checkpoints() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("initial.txt"), "v2").unwrap();
        mgr.create_checkpoint("cp1").unwrap();

        fs::write(path.join("initial.txt"), "v3").unwrap();
        mgr.create_checkpoint("cp2").unwrap();

        assert_eq!(mgr.count(), 2);
        assert_eq!(mgr.checkpoints()[0].label, "cp1");
        assert_eq!(mgr.checkpoints()[1].label, "cp2");
    }

    #[test]
    fn test_restore_checkpoint() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        mgr.create_checkpoint("original").unwrap();

        fs::write(path.join("initial.txt"), "CHANGED").unwrap();
        assert_eq!(
            fs::read_to_string(path.join("initial.txt")).unwrap(),
            "CHANGED"
        );

        mgr.restore_checkpoint(0).unwrap();
        // Pin the exact restored content: a mere "not CHANGED" check would
        // also pass if the file had been truncated or corrupted.
        let content = fs::read_to_string(path.join("initial.txt")).unwrap();
        assert_eq!(content, "initial content");
    }

    #[test]
    fn test_restore_checkpoint_out_of_range() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path);
        match mgr.restore_checkpoint(0) {
            Err(CheckpointError::NoCheckpoints) => {}
            other => panic!("Expected NoCheckpoints, got {:?}", other),
        }
    }

    #[test]
    fn test_undo_no_checkpoints() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path);
        match mgr.undo() {
            Err(CheckpointError::NoCheckpoints) => {}
            other => panic!("Expected NoCheckpoints, got {:?}", other),
        }
    }

    #[test]
    fn test_undo_restores_last() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("initial.txt"), "checkpoint state").unwrap();
        mgr.create_checkpoint("before change").unwrap();

        fs::write(path.join("initial.txt"), "after change").unwrap();
        mgr.undo().unwrap();

        let content = fs::read_to_string(path.join("initial.txt")).unwrap();
        assert_eq!(content, "checkpoint state");
    }

    #[test]
    fn test_diff_from_last_no_checkpoints() {
        let (_dir, path) = setup_test_repo();
        let mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("initial.txt"), "modified").unwrap();

        let diff = mgr.diff_from_last().unwrap();
        assert!(!diff.is_empty());
        assert!(diff.contains("modified"));
    }

    #[test]
    fn test_diff_from_last_with_checkpoint() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("initial.txt"), "checkpoint state").unwrap();
        mgr.create_checkpoint("cp1").unwrap();

        fs::write(path.join("initial.txt"), "new state").unwrap();

        // A patch against the checkpoint must show both sides of the change.
        let diff = mgr.diff_from_last().unwrap();
        assert!(diff.contains("checkpoint state"));
        assert!(diff.contains("new state"));
    }

    #[test]
    fn test_checkpoint_changed_files() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("new_file.txt"), "hello").unwrap();
        let cp = mgr.create_checkpoint("added file").unwrap();
        assert!(cp.changed_files.iter().any(|f| f == "new_file.txt"));
    }

    #[test]
    fn test_repo_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = CheckpointManager::new(dir.path().to_path_buf());
        match mgr.diff_from_last() {
            Err(CheckpointError::RepoNotFound(_)) => {}
            other => panic!("Expected RepoNotFound, got {:?}", other),
        }
    }
}