Skip to main content

saorsa_agent/tools/
write.rs

1//! Write tool for writing file contents with diff display.
2
3use std::fs;
4use std::path::{Path, PathBuf};
5
6use serde::{Deserialize, Serialize};
7use similar::{ChangeTag, TextDiff};
8
9use crate::error::{Result, SaorsaAgentError};
10use crate::tool::Tool;
11
/// Agent tool that writes text content to a file on disk.
///
/// Relative file paths given to the tool are resolved against `working_dir`;
/// absolute paths are used unchanged.
#[derive(Debug, Clone)]
pub struct WriteTool {
    /// Directory used as the base when resolving relative file paths.
    working_dir: PathBuf,
}
17
18/// Input parameters for the Write tool.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20struct WriteInput {
21    /// Path to the file to write.
22    file_path: String,
23    /// Content to write to the file.
24    content: String,
25}
26
27impl WriteTool {
28    /// Create a new Write tool with the given working directory.
29    pub fn new(working_dir: impl Into<PathBuf>) -> Self {
30        Self {
31            working_dir: working_dir.into(),
32        }
33    }
34
35    /// Resolve a file path relative to the working directory.
36    fn resolve_path(&self, path: &str) -> PathBuf {
37        let path = Path::new(path);
38        if path.is_absolute() {
39            path.to_path_buf()
40        } else {
41            self.working_dir.join(path)
42        }
43    }
44
45    /// Generate a unified diff between old and new content.
46    fn generate_diff(old_content: &str, new_content: &str, file_path: &Path) -> String {
47        let diff = TextDiff::from_lines(old_content, new_content);
48
49        let mut output = String::new();
50        output.push_str(&format!("--- {}\n", file_path.display()));
51        output.push_str(&format!("+++ {} (new)\n", file_path.display()));
52
53        for change in diff.iter_all_changes() {
54            let sign = match change.tag() {
55                ChangeTag::Delete => "-",
56                ChangeTag::Insert => "+",
57                ChangeTag::Equal => " ",
58            };
59            output.push_str(&format!("{}{}", sign, change));
60        }
61
62        output
63    }
64}
65
66#[async_trait::async_trait]
67impl Tool for WriteTool {
68    fn name(&self) -> &str {
69        "write"
70    }
71
72    fn description(&self) -> &str {
73        "Write content to a file, creating parent directories if needed, with diff for existing files"
74    }
75
76    fn input_schema(&self) -> serde_json::Value {
77        serde_json::json!({
78            "type": "object",
79            "properties": {
80                "file_path": {
81                    "type": "string",
82                    "description": "Path to the file to write (absolute or relative to working directory)"
83                },
84                "content": {
85                    "type": "string",
86                    "description": "Content to write to the file"
87                }
88            },
89            "required": ["file_path", "content"]
90        })
91    }
92
93    async fn execute(&self, input: serde_json::Value) -> Result<String> {
94        let input: WriteInput = serde_json::from_value(input)
95            .map_err(|e| SaorsaAgentError::Tool(format!("Invalid input: {e}")))?;
96
97        let path = self.resolve_path(&input.file_path);
98
99        // Check if file already exists and generate diff if so
100        let (old_content, file_exists) = if path.exists() {
101            if path.is_dir() {
102                return Err(SaorsaAgentError::Tool(format!(
103                    "Path is a directory, cannot write: {}",
104                    path.display()
105                )));
106            }
107
108            let content = fs::read_to_string(&path).map_err(|e| {
109                SaorsaAgentError::Tool(format!("Failed to read existing file: {e}"))
110            })?;
111            (Some(content), true)
112        } else {
113            (None, false)
114        };
115
116        // Create parent directories if they don't exist
117        if let Some(parent) = path.parent() {
118            fs::create_dir_all(parent).map_err(|e| {
119                SaorsaAgentError::Tool(format!("Failed to create parent directories: {e}"))
120            })?;
121        }
122
123        // Write the file
124        fs::write(&path, &input.content)
125            .map_err(|e| SaorsaAgentError::Tool(format!("Failed to write file: {e}")))?;
126
127        // Build response
128        let mut response = if file_exists {
129            format!("File updated: {}\n\n", path.display())
130        } else {
131            format!("File created: {}\n\n", path.display())
132        };
133
134        // Add diff if file was updated
135        if let Some(old) = old_content {
136            if old != input.content {
137                response.push_str("Diff:\n");
138                response.push_str(&Self::generate_diff(&old, &input.content, &path));
139            } else {
140                response.push_str("(No changes - content identical)");
141            }
142        } else {
143            response.push_str(&format!("Wrote {} bytes", input.content.len()));
144        }
145
146        Ok(response)
147    }
148}
149
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    // Tool results are unwrapped directly instead of first asserting
    // `result.is_ok()`: on failure `unwrap` prints the underlying error, whereas
    // `assert!(result.is_ok())` only reports that the assertion was false.

    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[tokio::test]
    async fn write_new_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::new(temp_dir.path());

        let file_path = temp_dir.path().join("new_file.txt");
        let input = serde_json::json!({
            "file_path": file_path.to_str().unwrap(),
            "content": "Hello, World!"
        });

        let response = tool.execute(input).await.unwrap();
        assert!(response.contains("File created"));
        assert!(response.contains("13 bytes")); // "Hello, World!" is 13 bytes

        // Verify file was created with the requested content.
        assert!(file_path.exists());
        assert_eq!(fs::read_to_string(&file_path).unwrap(), "Hello, World!");
    }

    #[tokio::test]
    async fn write_update_existing_file() {
        let mut temp = NamedTempFile::new().unwrap();
        writeln!(temp, "Original content").unwrap();
        temp.flush().unwrap();

        let tool = WriteTool::new(std::env::current_dir().unwrap());
        let input = serde_json::json!({
            "file_path": temp.path().to_str().unwrap(),
            "content": "New content"
        });

        let response = tool.execute(input).await.unwrap();
        assert!(response.contains("File updated"));
        assert!(response.contains("Diff:"));
        assert!(response.contains("-Original content"));
        assert!(response.contains("+New content"));

        // Verify file was overwritten.
        assert_eq!(fs::read_to_string(temp.path()).unwrap(), "New content");
    }

    #[tokio::test]
    async fn write_identical_content() {
        let mut temp = NamedTempFile::new().unwrap();
        writeln!(temp, "Same content").unwrap();
        temp.flush().unwrap();

        let tool = WriteTool::new(std::env::current_dir().unwrap());
        let input = serde_json::json!({
            "file_path": temp.path().to_str().unwrap(),
            "content": "Same content\n"
        });

        let response = tool.execute(input).await.unwrap();
        assert!(response.contains("File updated"));
        assert!(response.contains("No changes - content identical"));
        assert!(!response.contains("Diff:"));

        // The file still holds the same content.
        assert_eq!(fs::read_to_string(temp.path()).unwrap(), "Same content\n");
    }

    #[tokio::test]
    async fn write_create_parent_directories() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::new(temp_dir.path());

        let file_path = temp_dir.path().join("subdir/nested/file.txt");
        let input = serde_json::json!({
            "file_path": file_path.to_str().unwrap(),
            "content": "Nested file content"
        });

        tool.execute(input).await.unwrap();

        // Verify parent directories were created
        assert!(file_path.parent().unwrap().exists());
        assert!(file_path.exists());
        assert_eq!(fs::read_to_string(&file_path).unwrap(), "Nested file content");
    }

    #[tokio::test]
    async fn write_to_directory_fails() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::new(temp_dir.path());

        let input = serde_json::json!({
            "file_path": temp_dir.path().to_str().unwrap(),
            "content": "This should fail"
        });

        match tool.execute(input).await {
            Err(SaorsaAgentError::Tool(msg)) => assert!(msg.contains("is a directory")),
            other => panic!("Expected Tool error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn write_invalid_input_fails() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::new(temp_dir.path());

        // `content` is a required field of the input schema.
        let input = serde_json::json!({ "file_path": "missing_content.txt" });

        match tool.execute(input).await {
            Err(SaorsaAgentError::Tool(msg)) => assert!(msg.contains("Invalid input")),
            other => panic!("Expected Tool error, got {other:?}"),
        }
        assert!(!temp_dir.path().join("missing_content.txt").exists());
    }

    #[test]
    fn diff_generation() {
        let old = "Line 1\nLine 2\nLine 3\n";
        let new = "Line 1\nModified Line 2\nLine 3\n";
        let path = Path::new("test.txt");

        let diff = WriteTool::generate_diff(old, new, path);

        assert!(diff.contains("--- test.txt"));
        assert!(diff.contains("+++ test.txt (new)"));
        assert!(diff.contains("-Line 2"));
        assert!(diff.contains("+Modified Line 2"));
    }

    #[tokio::test]
    async fn write_relative_path() {
        let temp_dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::new(temp_dir.path());

        let input = serde_json::json!({
            "file_path": "relative/path/file.txt",
            "content": "Content in relative path"
        });

        let response = tool.execute(input).await.unwrap();
        assert!(response.contains("File created"));

        // The relative path must land under the tool's working directory.
        let file_path = temp_dir.path().join("relative/path/file.txt");
        assert!(file_path.exists());
        assert_eq!(fs::read_to_string(&file_path).unwrap(), "Content in relative path");
    }
}