Skip to main content

rustant_tools/
checkpoint.rs

//! Git-based checkpoint manager for undo/redo functionality.
//!
//! Uses git2 (libgit2 bindings) to create lightweight stash-like snapshots
//! of the workspace. Each checkpoint captures the full working tree state
//! so file modifications can be rolled back.

use git2::{Oid, Repository, Signature};
use std::path::{Path, PathBuf};

10/// Errors specific to checkpoint operations.
11#[derive(Debug, thiserror::Error)]
12pub enum CheckpointError {
13    #[error("git error: {0}")]
14    Git(#[from] git2::Error),
15    #[error("no checkpoints available")]
16    NoCheckpoints,
17    #[error("repository not found at {0}")]
18    RepoNotFound(PathBuf),
19    #[error("io error: {0}")]
20    Io(#[from] std::io::Error),
21}
22
23/// A single checkpoint (snapshot of the working tree).
24#[derive(Debug, Clone)]
25pub struct Checkpoint {
26    /// The commit OID for this checkpoint.
27    pub oid: String,
28    /// Human-readable label.
29    pub label: String,
30    /// Timestamp when the checkpoint was created.
31    pub timestamp: chrono::DateTime<chrono::Utc>,
32    /// Files changed in this checkpoint.
33    pub changed_files: Vec<String>,
34}
35
36/// Manages git-based checkpoints for the workspace.
37pub struct CheckpointManager {
38    workspace: PathBuf,
39    checkpoints: Vec<Checkpoint>,
40    /// Name of the checkpoint ref namespace.
41    ref_prefix: String,
42}
43
44impl CheckpointManager {
45    /// Create a new CheckpointManager for the given workspace.
46    pub fn new(workspace: PathBuf) -> Self {
47        Self {
48            workspace,
49            checkpoints: Vec::new(),
50            ref_prefix: "refs/rustant/checkpoints".to_string(),
51        }
52    }
53
54    /// Get the workspace path.
55    pub fn workspace(&self) -> &Path {
56        &self.workspace
57    }
58
59    /// Open the repository at the workspace path.
60    fn open_repo(&self) -> Result<Repository, CheckpointError> {
61        Repository::discover(&self.workspace)
62            .map_err(|_| CheckpointError::RepoNotFound(self.workspace.clone()))
63    }
64
65    /// Create a checkpoint of the current workspace state.
66    ///
67    /// This stages all changes and creates a commit on a detached ref
68    /// so it doesn't affect the user's branch or history.
69    pub fn create_checkpoint(&mut self, label: &str) -> Result<Checkpoint, CheckpointError> {
70        let repo = self.open_repo()?;
71
72        // Get the current HEAD as the parent
73        let head = repo.head()?;
74        let parent_commit = head.peel_to_commit()?;
75
76        // Build a tree from the current working directory state
77        let mut index = repo.index()?;
78        index.add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)?;
79        index.write()?;
80        let tree_oid = index.write_tree()?;
81        let tree = repo.find_tree(tree_oid)?;
82
83        // Detect changed files by diffing against parent
84        let parent_tree = parent_commit.tree()?;
85        let diff = repo.diff_tree_to_tree(Some(&parent_tree), Some(&tree), None)?;
86        let changed_files: Vec<String> = diff
87            .deltas()
88            .filter_map(|d| d.new_file().path().map(|p| p.to_string_lossy().to_string()))
89            .collect();
90
91        // Create the checkpoint commit
92        let sig = Signature::now("rustant", "rustant@local")?;
93        let message = format!("[checkpoint] {}", label);
94        let oid = repo.commit(
95            None, // don't update any ref yet
96            &sig,
97            &sig,
98            &message,
99            &tree,
100            &[&parent_commit],
101        )?;
102
103        // Store as a named reference
104        let ref_name = format!("{}/{}", self.ref_prefix, self.checkpoints.len());
105        repo.reference(&ref_name, oid, true, &format!("checkpoint: {}", label))?;
106
107        let checkpoint = Checkpoint {
108            oid: oid.to_string(),
109            label: label.to_string(),
110            timestamp: chrono::Utc::now(),
111            changed_files,
112        };
113
114        self.checkpoints.push(checkpoint.clone());
115        Ok(checkpoint)
116    }
117
118    /// Restore the workspace to the state at the given checkpoint.
119    pub fn restore_checkpoint(
120        &mut self,
121        checkpoint_index: usize,
122    ) -> Result<&Checkpoint, CheckpointError> {
123        if checkpoint_index >= self.checkpoints.len() {
124            return Err(CheckpointError::NoCheckpoints);
125        }
126
127        let checkpoint = &self.checkpoints[checkpoint_index];
128        let repo = self.open_repo()?;
129        let oid = Oid::from_str(&checkpoint.oid)?;
130        let commit = repo.find_commit(oid)?;
131        let tree = commit.tree()?;
132
133        // Reset the working directory to the checkpoint tree
134        repo.checkout_tree(
135            tree.as_object(),
136            Some(git2::build::CheckoutBuilder::new().force()),
137        )?;
138
139        // Reset index to match
140        let mut index = repo.index()?;
141        index.read_tree(&tree)?;
142        index.write()?;
143
144        Ok(checkpoint)
145    }
146
147    /// Undo the last change by restoring the most recent checkpoint.
148    pub fn undo(&mut self) -> Result<&Checkpoint, CheckpointError> {
149        if self.checkpoints.is_empty() {
150            return Err(CheckpointError::NoCheckpoints);
151        }
152        let last = self.checkpoints.len() - 1;
153        self.restore_checkpoint(last)
154    }
155
156    /// Get all checkpoints.
157    pub fn checkpoints(&self) -> &[Checkpoint] {
158        &self.checkpoints
159    }
160
161    /// Get the number of checkpoints.
162    pub fn count(&self) -> usize {
163        self.checkpoints.len()
164    }
165
166    /// Get the diff between the current working tree and the last checkpoint.
167    pub fn diff_from_last(&self) -> Result<String, CheckpointError> {
168        let repo = self.open_repo()?;
169
170        if self.checkpoints.is_empty() {
171            // Diff against HEAD
172            let head = repo.head()?;
173            let tree = head.peel_to_tree()?;
174            let diff = repo.diff_tree_to_workdir(Some(&tree), None)?;
175            let mut output = Vec::new();
176            diff.print(git2::DiffFormat::Patch, |_, _, line| {
177                output.extend_from_slice(line.content());
178                true
179            })?;
180            return Ok(String::from_utf8_lossy(&output).to_string());
181        }
182
183        let last = &self.checkpoints[self.checkpoints.len() - 1];
184        let oid = Oid::from_str(&last.oid)?;
185        let commit = repo.find_commit(oid)?;
186        let tree = commit.tree()?;
187
188        let diff = repo.diff_tree_to_workdir(Some(&tree), None)?;
189        let mut output = Vec::new();
190        diff.print(git2::DiffFormat::Patch, |_, _, line| {
191            output.extend_from_slice(line.content());
192            true
193        })?;
194
195        Ok(String::from_utf8_lossy(&output).to_string())
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create a temporary git repository containing a single committed file,
    /// `initial.txt`, with the content `"initial content"`.
    ///
    /// The returned `TempDir` must be kept alive for the duration of the
    /// test; dropping it deletes the directory.
    fn setup_test_repo() -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().to_path_buf();

        // Initialize repo
        let repo = Repository::init(&path).unwrap();

        // Create initial file and commit
        fs::write(path.join("initial.txt"), "initial content").unwrap();
        let mut index = repo.index().unwrap();
        index
            .add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)
            .unwrap();
        index.write().unwrap();
        let tree_oid = index.write_tree().unwrap();
        let tree = repo.find_tree(tree_oid).unwrap();
        let sig = Signature::now("test", "test@test.com").unwrap();
        repo.commit(Some("HEAD"), &sig, &sig, "Initial commit", &tree, &[])
            .unwrap();

        (dir, path)
    }

    #[test]
    fn test_checkpoint_manager_new() {
        let (_dir, path) = setup_test_repo();
        let mgr = CheckpointManager::new(path.clone());
        assert_eq!(mgr.workspace(), path);
        assert_eq!(mgr.count(), 0);
        assert!(mgr.checkpoints().is_empty());
    }

    #[test]
    fn test_create_checkpoint() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        // Modify a file
        fs::write(path.join("initial.txt"), "modified content").unwrap();

        let cp = mgr.create_checkpoint("before tool exec").unwrap();
        assert_eq!(cp.label, "before tool exec");
        assert!(!cp.oid.is_empty());
        assert_eq!(mgr.count(), 1);
    }

    #[test]
    fn test_create_multiple_checkpoints() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("initial.txt"), "v2").unwrap();
        mgr.create_checkpoint("cp1").unwrap();

        fs::write(path.join("initial.txt"), "v3").unwrap();
        mgr.create_checkpoint("cp2").unwrap();

        assert_eq!(mgr.count(), 2);
        assert_eq!(mgr.checkpoints()[0].label, "cp1");
        assert_eq!(mgr.checkpoints()[1].label, "cp2");
    }

    #[test]
    fn test_restore_checkpoint() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        // Create checkpoint with original state
        mgr.create_checkpoint("original").unwrap();

        // Modify file
        fs::write(path.join("initial.txt"), "CHANGED").unwrap();
        assert_eq!(
            fs::read_to_string(path.join("initial.txt")).unwrap(),
            "CHANGED"
        );

        // Restore
        mgr.restore_checkpoint(0).unwrap();

        // The file must hold exactly the content captured by the checkpoint,
        // not merely something other than the modified content.
        let content = fs::read_to_string(path.join("initial.txt")).unwrap();
        assert_eq!(content, "initial content");
    }

    #[test]
    fn test_restore_out_of_range() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path);
        mgr.create_checkpoint("only").unwrap();

        match mgr.restore_checkpoint(5) {
            Err(CheckpointError::NoCheckpoints) => {}
            other => panic!("Expected NoCheckpoints, got {:?}", other),
        }
    }

    #[test]
    fn test_undo_no_checkpoints() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path);
        match mgr.undo() {
            Err(CheckpointError::NoCheckpoints) => {}
            other => panic!("Expected NoCheckpoints, got {:?}", other),
        }
    }

    #[test]
    fn test_undo_restores_last() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("initial.txt"), "checkpoint state").unwrap();
        mgr.create_checkpoint("before change").unwrap();

        fs::write(path.join("initial.txt"), "after change").unwrap();
        mgr.undo().unwrap();

        let content = fs::read_to_string(path.join("initial.txt")).unwrap();
        assert_eq!(content, "checkpoint state");
    }

    #[test]
    fn test_diff_from_last_no_checkpoints() {
        let (_dir, path) = setup_test_repo();
        let mgr = CheckpointManager::new(path.clone());

        // Modify file
        fs::write(path.join("initial.txt"), "modified").unwrap();

        let diff = mgr.diff_from_last().unwrap();
        assert!(!diff.is_empty());
        assert!(diff.contains("modified"));
    }

    #[test]
    fn test_diff_from_last_with_checkpoint() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("initial.txt"), "checkpoint state").unwrap();
        mgr.create_checkpoint("cp1").unwrap();

        fs::write(path.join("initial.txt"), "new state").unwrap();

        // A patch for a modified file carries both the removed and the
        // added line, so both texts must be present.
        let diff = mgr.diff_from_last().unwrap();
        assert!(diff.contains("new state"));
        assert!(diff.contains("checkpoint state"));
    }

    #[test]
    fn test_checkpoint_changed_files() {
        let (_dir, path) = setup_test_repo();
        let mut mgr = CheckpointManager::new(path.clone());

        fs::write(path.join("new_file.txt"), "hello").unwrap();
        let cp = mgr.create_checkpoint("added file").unwrap();
        assert!(cp.changed_files.iter().any(|f| f.contains("new_file")));
    }

    #[test]
    fn test_repo_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let mgr = CheckpointManager::new(dir.path().to_path_buf());
        match mgr.diff_from_last() {
            Err(CheckpointError::RepoNotFound(_)) => {}
            other => panic!("Expected RepoNotFound, got {:?}", other),
        }
    }
}