Skip to main content

rustant_tools/
git.rs

1//! Git integration tools: status, diff, and commit.
2
3use crate::registry::Tool;
4use async_trait::async_trait;
5use rustant_core::error::ToolError;
6use rustant_core::types::{RiskLevel, ToolOutput};
7use std::path::PathBuf;
8use tracing::debug;
9
10/// Show git repository status.
11pub struct GitStatusTool {
12    workspace: PathBuf,
13}
14
15impl GitStatusTool {
16    pub fn new(workspace: PathBuf) -> Self {
17        Self { workspace }
18    }
19
20    async fn run_git(&self, args: &[&str]) -> Result<String, ToolError> {
21        let output = tokio::process::Command::new("git")
22            .args(args)
23            .current_dir(&self.workspace)
24            .output()
25            .await
26            .map_err(|e| ToolError::ExecutionFailed {
27                name: "git".into(),
28                message: format!("Failed to run git: {}", e),
29            })?;
30
31        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
32        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
33
34        if !output.status.success() {
35            return Err(ToolError::ExecutionFailed {
36                name: "git".into(),
37                message: format!("git {} failed: {}", args.join(" "), stderr),
38            });
39        }
40
41        Ok(if stdout.is_empty() { stderr } else { stdout })
42    }
43}
44
45#[async_trait]
46impl Tool for GitStatusTool {
47    fn name(&self) -> &str {
48        "git_status"
49    }
50
51    fn description(&self) -> &str {
52        "Show the current git repository status, including staged, modified, and untracked files."
53    }
54
55    fn parameters_schema(&self) -> serde_json::Value {
56        serde_json::json!({
57            "type": "object",
58            "properties": {}
59        })
60    }
61
62    async fn execute(&self, _args: serde_json::Value) -> Result<ToolOutput, ToolError> {
63        debug!(workspace = %self.workspace.display(), "Getting git status");
64        let status = self.run_git(&["status", "--short"]).await?;
65        let branch = self.run_git(&["branch", "--show-current"]).await?;
66
67        let output = format!(
68            "Branch: {}\n{}",
69            branch.trim(),
70            if status.trim().is_empty() {
71                "Working tree clean".to_string()
72            } else {
73                status
74            }
75        );
76
77        Ok(ToolOutput::text(output))
78    }
79
80    fn risk_level(&self) -> RiskLevel {
81        RiskLevel::ReadOnly
82    }
83}
84
85/// Show git diff of working tree changes.
86pub struct GitDiffTool {
87    workspace: PathBuf,
88}
89
90impl GitDiffTool {
91    pub fn new(workspace: PathBuf) -> Self {
92        Self { workspace }
93    }
94
95    async fn run_git(&self, args: &[&str]) -> Result<String, ToolError> {
96        let output = tokio::process::Command::new("git")
97            .args(args)
98            .current_dir(&self.workspace)
99            .output()
100            .await
101            .map_err(|e| ToolError::ExecutionFailed {
102                name: "git_diff".into(),
103                message: format!("Failed to run git: {}", e),
104            })?;
105
106        Ok(String::from_utf8_lossy(&output.stdout).to_string())
107    }
108}
109
110#[async_trait]
111impl Tool for GitDiffTool {
112    fn name(&self) -> &str {
113        "git_diff"
114    }
115
116    fn description(&self) -> &str {
117        "Show the diff of changes in the working tree. Optionally specify a file path to see changes for a specific file."
118    }
119
120    fn parameters_schema(&self) -> serde_json::Value {
121        serde_json::json!({
122            "type": "object",
123            "properties": {
124                "path": {
125                    "type": "string",
126                    "description": "Optional file path to diff"
127                },
128                "staged": {
129                    "type": "boolean",
130                    "description": "Show staged changes instead of unstaged. Default: false."
131                }
132            }
133        })
134    }
135
136    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
137        let staged = args["staged"].as_bool().unwrap_or(false);
138        let path = args["path"].as_str();
139
140        let mut git_args = vec!["diff"];
141        if staged {
142            git_args.push("--cached");
143        }
144        if let Some(p) = path {
145            git_args.push("--");
146            git_args.push(p);
147        }
148
149        debug!(staged, path = ?path, "Getting git diff");
150
151        let diff = self.run_git(&git_args).await?;
152
153        let output = if diff.trim().is_empty() {
154            let scope = if staged { "staged" } else { "unstaged" };
155            format!("No {} changes", scope)
156        } else {
157            diff
158        };
159
160        Ok(ToolOutput::text(output))
161    }
162
163    fn risk_level(&self) -> RiskLevel {
164        RiskLevel::ReadOnly
165    }
166}
167
168/// Stage files and create a git commit.
169pub struct GitCommitTool {
170    workspace: PathBuf,
171}
172
173impl GitCommitTool {
174    pub fn new(workspace: PathBuf) -> Self {
175        Self { workspace }
176    }
177
178    async fn run_git(&self, args: &[&str]) -> Result<String, ToolError> {
179        let output = tokio::process::Command::new("git")
180            .args(args)
181            .current_dir(&self.workspace)
182            .output()
183            .await
184            .map_err(|e| ToolError::ExecutionFailed {
185                name: "git_commit".into(),
186                message: format!("Failed to run git: {}", e),
187            })?;
188
189        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
190        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
191
192        if !output.status.success() {
193            return Err(ToolError::ExecutionFailed {
194                name: "git_commit".into(),
195                message: format!("git {} failed: {}", args.join(" "), stderr),
196            });
197        }
198
199        Ok(if stdout.is_empty() { stderr } else { stdout })
200    }
201}
202
203#[async_trait]
204impl Tool for GitCommitTool {
205    fn name(&self) -> &str {
206        "git_commit"
207    }
208
209    fn description(&self) -> &str {
210        "Stage files and create a git commit. Specify files to stage and a commit message."
211    }
212
213    fn parameters_schema(&self) -> serde_json::Value {
214        serde_json::json!({
215            "type": "object",
216            "properties": {
217                "message": {
218                    "type": "string",
219                    "description": "The commit message"
220                },
221                "files": {
222                    "type": "array",
223                    "items": { "type": "string" },
224                    "description": "Files to stage before committing. Use [\".\"] for all changes."
225                }
226            },
227            "required": ["message"]
228        })
229    }
230
231    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
232        let message = args["message"]
233            .as_str()
234            .ok_or_else(|| ToolError::InvalidArguments {
235                name: "git_commit".into(),
236                reason: "'message' parameter is required".into(),
237            })?;
238
239        // Stage files if specified
240        if let Some(files) = args["files"].as_array() {
241            for file in files {
242                if let Some(f) = file.as_str() {
243                    debug!(file = f, "Staging file");
244                    self.run_git(&["add", f]).await?;
245                }
246            }
247        }
248
249        // Create commit
250        debug!(message = message, "Creating commit");
251        let result = self.run_git(&["commit", "-m", message]).await?;
252
253        Ok(ToolOutput::text(format!("Committed: {}", result.trim())))
254    }
255
256    fn risk_level(&self) -> RiskLevel {
257        RiskLevel::Write
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use tempfile::TempDir;

    /// Runs `git <args>` in `dir` and returns its stdout.
    ///
    /// Panics (failing the test) if git cannot be spawned or exits non-zero, so
    /// a broken fixture surfaces immediately instead of as a confusing later failure.
    fn git(dir: &Path, args: &[&str]) -> String {
        let out = std::process::Command::new("git")
            .args(args)
            .current_dir(dir)
            .output()
            .expect("failed to spawn git");
        assert!(
            out.status.success(),
            "git {:?} failed: {}",
            args,
            String::from_utf8_lossy(&out.stderr)
        );
        String::from_utf8_lossy(&out.stdout).into_owned()
    }

    /// Creates a temporary repository with a local identity, signing disabled,
    /// and one commit containing `README.md`.
    fn setup_git_repo() -> TempDir {
        let dir = TempDir::new().unwrap();
        let path = dir.path();
        git(path, &["init"]);
        git(path, &["config", "user.email", "test@example.com"]);
        git(path, &["config", "user.name", "Test User"]);
        // Disable commit signing for tests (avoids GPG/SSH signing issues)
        git(path, &["config", "commit.gpgsign", "false"]);

        // Create an initial commit
        std::fs::write(path.join("README.md"), "# Test\n").unwrap();
        git(path, &["add", "."]);
        git(path, &["commit", "-m", "Initial commit"]);

        dir
    }

    #[tokio::test]
    async fn test_git_status_clean() {
        let dir = setup_git_repo();
        let tool = GitStatusTool::new(dir.path().to_path_buf());

        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.contains("Branch:"));
        assert!(result.content.contains("Working tree clean"));
    }

    #[tokio::test]
    async fn test_git_status_with_changes() {
        let dir = setup_git_repo();
        std::fs::write(dir.path().join("new_file.txt"), "new content").unwrap();

        let tool = GitStatusTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.contains("new_file.txt"));
    }

    #[tokio::test]
    async fn test_git_diff_no_changes() {
        let dir = setup_git_repo();
        let tool = GitDiffTool::new(dir.path().to_path_buf());

        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.contains("No unstaged changes"));
    }

    #[tokio::test]
    async fn test_git_diff_with_changes() {
        let dir = setup_git_repo();
        std::fs::write(dir.path().join("README.md"), "# Updated\n").unwrap();

        let tool = GitDiffTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.contains("+# Updated"));
    }

    #[tokio::test]
    async fn test_git_commit() {
        let dir = setup_git_repo();
        std::fs::write(dir.path().join("new_file.txt"), "content").unwrap();

        let tool = GitCommitTool::new(dir.path().to_path_buf());
        let result = tool
            .execute(serde_json::json!({
                "message": "Add new file",
                "files": ["new_file.txt"]
            }))
            .await
            .unwrap();

        assert!(result.content.contains("Committed"));
        // The commit really landed and nothing was left behind.
        let subject = git(dir.path(), &["log", "-1", "--format=%s"]);
        assert_eq!(subject.trim(), "Add new file");
        assert!(git(dir.path(), &["status", "--porcelain"]).trim().is_empty());
    }

    #[test]
    fn test_git_tool_properties() {
        let ws = PathBuf::from("/tmp");
        let status = GitStatusTool::new(ws.clone());
        assert_eq!(status.name(), "git_status");
        assert_eq!(status.risk_level(), RiskLevel::ReadOnly);

        let diff = GitDiffTool::new(ws.clone());
        assert_eq!(diff.name(), "git_diff");
        assert_eq!(diff.risk_level(), RiskLevel::ReadOnly);

        let commit = GitCommitTool::new(ws);
        assert_eq!(commit.name(), "git_commit");
        assert_eq!(commit.risk_level(), RiskLevel::Write);
    }

    #[tokio::test]
    async fn test_git_commit_missing_message() {
        let dir = setup_git_repo();
        let tool = GitCommitTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::InvalidArguments { name, reason } => {
                assert_eq!(name, "git_commit");
                assert!(reason.contains("message"));
            }
            e => panic!("Expected InvalidArguments, got: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_git_commit_null_message() {
        let dir = setup_git_repo();
        let tool = GitCommitTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({"message": null})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_git_diff_staged_no_changes() {
        let dir = setup_git_repo();
        let tool = GitDiffTool::new(dir.path().to_path_buf());
        let result = tool
            .execute(serde_json::json!({"staged": true}))
            .await
            .unwrap();
        assert!(result.content.contains("No staged changes"));
    }

    #[tokio::test]
    async fn test_git_status_in_non_repo() {
        let dir = TempDir::new().unwrap(); // Not a git repo
        let tool = GitStatusTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({})).await;
        // Should fail since it's not a git repo
        assert!(result.is_err());
    }

    #[test]
    fn test_git_commit_schema_required() {
        let tool = GitCommitTool::new(PathBuf::from("/tmp"));
        let schema = tool.parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::json!("message")));
    }

    #[test]
    fn test_git_diff_schema_no_required() {
        let tool = GitDiffTool::new(PathBuf::from("/tmp"));
        let schema = tool.parameters_schema();
        // diff has no required params (path and staged are optional)
        assert!(schema.get("required").is_none());
    }

    #[test]
    fn test_git_status_schema_no_required() {
        let tool = GitStatusTool::new(PathBuf::from("/tmp"));
        let schema = tool.parameters_schema();
        assert!(schema.get("required").is_none());
    }
}