1use crate::registry::Tool;
4use async_trait::async_trait;
5use rustant_core::error::ToolError;
6use rustant_core::types::{RiskLevel, ToolOutput};
7use std::path::PathBuf;
8use tracing::debug;
9
/// Tool that reports the current branch and the short-format working-tree
/// status of a git workspace.
#[derive(Debug, Clone)]
pub struct GitStatusTool {
    /// Directory in which every `git` command is executed.
    workspace: PathBuf,
}
14
15impl GitStatusTool {
16 pub fn new(workspace: PathBuf) -> Self {
17 Self { workspace }
18 }
19
20 async fn run_git(&self, args: &[&str]) -> Result<String, ToolError> {
21 let output = tokio::process::Command::new("git")
22 .args(args)
23 .current_dir(&self.workspace)
24 .output()
25 .await
26 .map_err(|e| ToolError::ExecutionFailed {
27 name: "git".into(),
28 message: format!("Failed to run git: {}", e),
29 })?;
30
31 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
32 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
33
34 if !output.status.success() {
35 return Err(ToolError::ExecutionFailed {
36 name: "git".into(),
37 message: format!("git {} failed: {}", args.join(" "), stderr),
38 });
39 }
40
41 Ok(if stdout.is_empty() { stderr } else { stdout })
42 }
43}
44
45#[async_trait]
46impl Tool for GitStatusTool {
47 fn name(&self) -> &str {
48 "git_status"
49 }
50
51 fn description(&self) -> &str {
52 "Show the current git repository status, including staged, modified, and untracked files."
53 }
54
55 fn parameters_schema(&self) -> serde_json::Value {
56 serde_json::json!({
57 "type": "object",
58 "properties": {}
59 })
60 }
61
62 async fn execute(&self, _args: serde_json::Value) -> Result<ToolOutput, ToolError> {
63 debug!(workspace = %self.workspace.display(), "Getting git status");
64 let status = self.run_git(&["status", "--short"]).await?;
65 let branch = self.run_git(&["branch", "--show-current"]).await?;
66
67 let output = format!(
68 "Branch: {}\n{}",
69 branch.trim(),
70 if status.trim().is_empty() {
71 "Working tree clean".to_string()
72 } else {
73 status
74 }
75 );
76
77 Ok(ToolOutput::text(output))
78 }
79
80 fn risk_level(&self) -> RiskLevel {
81 RiskLevel::ReadOnly
82 }
83}
84
/// Tool that shows the unstaged (or staged) diff of a git workspace,
/// optionally restricted to a single path.
#[derive(Debug, Clone)]
pub struct GitDiffTool {
    /// Directory in which every `git` command is executed.
    workspace: PathBuf,
}
89
90impl GitDiffTool {
91 pub fn new(workspace: PathBuf) -> Self {
92 Self { workspace }
93 }
94
95 async fn run_git(&self, args: &[&str]) -> Result<String, ToolError> {
96 let output = tokio::process::Command::new("git")
97 .args(args)
98 .current_dir(&self.workspace)
99 .output()
100 .await
101 .map_err(|e| ToolError::ExecutionFailed {
102 name: "git_diff".into(),
103 message: format!("Failed to run git: {}", e),
104 })?;
105
106 Ok(String::from_utf8_lossy(&output.stdout).to_string())
107 }
108}
109
110#[async_trait]
111impl Tool for GitDiffTool {
112 fn name(&self) -> &str {
113 "git_diff"
114 }
115
116 fn description(&self) -> &str {
117 "Show the diff of changes in the working tree. Optionally specify a file path to see changes for a specific file."
118 }
119
120 fn parameters_schema(&self) -> serde_json::Value {
121 serde_json::json!({
122 "type": "object",
123 "properties": {
124 "path": {
125 "type": "string",
126 "description": "Optional file path to diff"
127 },
128 "staged": {
129 "type": "boolean",
130 "description": "Show staged changes instead of unstaged. Default: false."
131 }
132 }
133 })
134 }
135
136 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
137 let staged = args["staged"].as_bool().unwrap_or(false);
138 let path = args["path"].as_str();
139
140 let mut git_args = vec!["diff"];
141 if staged {
142 git_args.push("--cached");
143 }
144 if let Some(p) = path {
145 git_args.push("--");
146 git_args.push(p);
147 }
148
149 debug!(staged, path = ?path, "Getting git diff");
150
151 let diff = self.run_git(&git_args).await?;
152
153 let output = if diff.trim().is_empty() {
154 let scope = if staged { "staged" } else { "unstaged" };
155 format!("No {} changes", scope)
156 } else {
157 diff
158 };
159
160 Ok(ToolOutput::text(output))
161 }
162
163 fn risk_level(&self) -> RiskLevel {
164 RiskLevel::ReadOnly
165 }
166}
167
/// Tool that stages the requested files and creates a git commit in a
/// workspace.
#[derive(Debug, Clone)]
pub struct GitCommitTool {
    /// Directory in which every `git` command is executed.
    workspace: PathBuf,
}
172
173impl GitCommitTool {
174 pub fn new(workspace: PathBuf) -> Self {
175 Self { workspace }
176 }
177
178 async fn run_git(&self, args: &[&str]) -> Result<String, ToolError> {
179 let output = tokio::process::Command::new("git")
180 .args(args)
181 .current_dir(&self.workspace)
182 .output()
183 .await
184 .map_err(|e| ToolError::ExecutionFailed {
185 name: "git_commit".into(),
186 message: format!("Failed to run git: {}", e),
187 })?;
188
189 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
190 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
191
192 if !output.status.success() {
193 return Err(ToolError::ExecutionFailed {
194 name: "git_commit".into(),
195 message: format!("git {} failed: {}", args.join(" "), stderr),
196 });
197 }
198
199 Ok(if stdout.is_empty() { stderr } else { stdout })
200 }
201}
202
203#[async_trait]
204impl Tool for GitCommitTool {
205 fn name(&self) -> &str {
206 "git_commit"
207 }
208
209 fn description(&self) -> &str {
210 "Stage files and create a git commit. Specify files to stage and a commit message."
211 }
212
213 fn parameters_schema(&self) -> serde_json::Value {
214 serde_json::json!({
215 "type": "object",
216 "properties": {
217 "message": {
218 "type": "string",
219 "description": "The commit message"
220 },
221 "files": {
222 "type": "array",
223 "items": { "type": "string" },
224 "description": "Files to stage before committing. Use [\".\"] for all changes."
225 }
226 },
227 "required": ["message"]
228 })
229 }
230
231 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
232 let message = args["message"]
233 .as_str()
234 .ok_or_else(|| ToolError::InvalidArguments {
235 name: "git_commit".into(),
236 reason: "'message' parameter is required".into(),
237 })?;
238
239 if let Some(files) = args["files"].as_array() {
241 for file in files {
242 if let Some(f) = file.as_str() {
243 debug!(file = f, "Staging file");
244 self.run_git(&["add", f]).await?;
245 }
246 }
247 }
248
249 debug!(message = message, "Creating commit");
251 let result = self.run_git(&["commit", "-m", message]).await?;
252
253 Ok(ToolOutput::text(format!("Committed: {}", result.trim())))
254 }
255
256 fn risk_level(&self) -> RiskLevel {
257 RiskLevel::Write
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Runs `git <args>` in `dir` and returns stdout.
    ///
    /// Panics if git cannot be spawned or exits unsuccessfully, so a broken
    /// fixture fails loudly rather than as a confusing assertion later.
    fn git(dir: &std::path::Path, args: &[&str]) -> String {
        let output = std::process::Command::new("git")
            .args(args)
            .current_dir(dir)
            .output()
            .expect("failed to spawn git");
        assert!(
            output.status.success(),
            "git {} failed: {}",
            args.join(" "),
            String::from_utf8_lossy(&output.stderr)
        );
        String::from_utf8_lossy(&output.stdout).into_owned()
    }

    /// Creates a temporary repository containing one commit of `README.md`.
    fn setup_git_repo() -> TempDir {
        let dir = TempDir::new().unwrap();
        let path = dir.path();
        git(path, &["init"]);
        git(path, &["config", "user.email", "test@example.com"]);
        git(path, &["config", "user.name", "Test User"]);
        // Keep commits working on machines that sign commits by default.
        git(path, &["config", "commit.gpgsign", "false"]);

        std::fs::write(path.join("README.md"), "# Test\n").unwrap();
        git(path, &["add", "."]);
        git(path, &["commit", "-m", "Initial commit"]);

        dir
    }

    #[tokio::test]
    async fn test_git_status_clean() {
        let dir = setup_git_repo();
        let tool = GitStatusTool::new(dir.path().to_path_buf());

        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.starts_with("Branch: "));
        assert!(result.content.contains("Working tree clean"));
    }

    #[tokio::test]
    async fn test_git_status_with_changes() {
        let dir = setup_git_repo();
        std::fs::write(dir.path().join("new_file.txt"), "new content").unwrap();

        let tool = GitStatusTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.contains("new_file.txt"));
    }

    #[tokio::test]
    async fn test_git_diff_no_changes() {
        let dir = setup_git_repo();
        let tool = GitDiffTool::new(dir.path().to_path_buf());

        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.contains("No unstaged changes"));
    }

    #[tokio::test]
    async fn test_git_diff_with_changes() {
        let dir = setup_git_repo();
        std::fs::write(dir.path().join("README.md"), "# Updated\n").unwrap();

        let tool = GitDiffTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({})).await.unwrap();
        assert!(result.content.contains("+# Updated"));
    }

    #[tokio::test]
    async fn test_git_commit() {
        let dir = setup_git_repo();
        std::fs::write(dir.path().join("new_file.txt"), "content").unwrap();

        let tool = GitCommitTool::new(dir.path().to_path_buf());
        let result = tool
            .execute(serde_json::json!({
                "message": "Add new file",
                "files": ["new_file.txt"]
            }))
            .await
            .unwrap();

        assert!(result.content.contains("Committed"));
        let subject = git(dir.path(), &["log", "-1", "--format=%s"]);
        assert_eq!(subject.trim(), "Add new file");
    }

    #[test]
    fn test_git_tool_properties() {
        let ws = PathBuf::from("/tmp");
        let status = GitStatusTool::new(ws.clone());
        assert_eq!(status.name(), "git_status");
        assert_eq!(status.risk_level(), RiskLevel::ReadOnly);

        let diff = GitDiffTool::new(ws.clone());
        assert_eq!(diff.name(), "git_diff");
        assert_eq!(diff.risk_level(), RiskLevel::ReadOnly);

        let commit = GitCommitTool::new(ws);
        assert_eq!(commit.name(), "git_commit");
        assert_eq!(commit.risk_level(), RiskLevel::Write);
    }

    #[tokio::test]
    async fn test_git_commit_missing_message() {
        let dir = setup_git_repo();
        let tool = GitCommitTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::InvalidArguments { name, reason } => {
                assert_eq!(name, "git_commit");
                assert!(reason.contains("message"));
            }
            e => panic!("Expected InvalidArguments, got: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_git_commit_null_message() {
        let dir = setup_git_repo();
        let tool = GitCommitTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({"message": null})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_git_diff_staged_no_changes() {
        let dir = setup_git_repo();
        let tool = GitDiffTool::new(dir.path().to_path_buf());
        let result = tool
            .execute(serde_json::json!({"staged": true}))
            .await
            .unwrap();
        assert!(result.content.contains("No staged changes"));
    }

    #[tokio::test]
    async fn test_git_status_in_non_repo() {
        let dir = TempDir::new().unwrap();
        let tool = GitStatusTool::new(dir.path().to_path_buf());
        let result = tool.execute(serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_git_commit_schema_required() {
        let tool = GitCommitTool::new(PathBuf::from("/tmp"));
        let schema = tool.parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::json!("message")));
    }

    #[test]
    fn test_git_diff_schema_no_required() {
        let tool = GitDiffTool::new(PathBuf::from("/tmp"));
        let schema = tool.parameters_schema();
        assert!(schema.get("required").is_none());
    }

    #[test]
    fn test_git_status_schema_no_required() {
        let tool = GitStatusTool::new(PathBuf::from("/tmp"));
        let schema = tool.parameters_schema();
        assert!(schema.get("required").is_none());
    }
}