Skip to main content

saorsa_agent/tools/
write.rs

1//! Write tool for writing file contents with diff display.
2
3use std::fs;
4use std::path::PathBuf;
5
6use serde::{Deserialize, Serialize};
7
8use super::{generate_diff, resolve_path};
9use crate::error::{Result, SaorsaAgentError};
10use crate::tool::Tool;
11
/// Tool that writes (creates or overwrites) a file's contents.
///
/// Relative paths supplied to the tool are resolved against `working_dir`.
#[derive(Debug, Clone)]
pub struct WriteTool {
    /// Base directory for resolving relative paths.
    working_dir: PathBuf,
}
17
/// Input parameters for the Write tool.
///
/// Deserialized from the JSON object passed to [`WriteTool`]'s `execute`; the
/// fields mirror the `required` properties declared in the tool's input schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WriteInput {
    /// Path to the file to write (absolute, or relative to the tool's working directory).
    file_path: String,
    /// Content to write to the file; it replaces any existing contents.
    content: String,
}
26
27impl WriteTool {
28    /// Create a new Write tool with the given working directory.
29    pub fn new(working_dir: impl Into<PathBuf>) -> Self {
30        Self {
31            working_dir: working_dir.into(),
32        }
33    }
34}
35
36#[async_trait::async_trait]
37impl Tool for WriteTool {
38    fn name(&self) -> &str {
39        "write"
40    }
41
42    fn description(&self) -> &str {
43        "Write content to a file, creating parent directories if needed, with diff for existing files"
44    }
45
46    fn input_schema(&self) -> serde_json::Value {
47        serde_json::json!({
48            "type": "object",
49            "properties": {
50                "file_path": {
51                    "type": "string",
52                    "description": "Path to the file to write (absolute or relative to working directory)"
53                },
54                "content": {
55                    "type": "string",
56                    "description": "Content to write to the file"
57                }
58            },
59            "required": ["file_path", "content"]
60        })
61    }
62
63    async fn execute(&self, input: serde_json::Value) -> Result<String> {
64        let input: WriteInput = serde_json::from_value(input)
65            .map_err(|e| SaorsaAgentError::Tool(format!("Invalid input: {e}")))?;
66
67        let path = resolve_path(&self.working_dir, &input.file_path);
68
69        // Check if file already exists and generate diff if so
70        let (old_content, file_exists) = if path.exists() {
71            if path.is_dir() {
72                return Err(SaorsaAgentError::Tool(format!(
73                    "Path is a directory, cannot write: {}",
74                    path.display()
75                )));
76            }
77
78            let content = fs::read_to_string(&path).map_err(|e| {
79                SaorsaAgentError::Tool(format!("Failed to read existing file: {e}"))
80            })?;
81            (Some(content), true)
82        } else {
83            (None, false)
84        };
85
86        // Create parent directories if they don't exist
87        if let Some(parent) = path.parent() {
88            fs::create_dir_all(parent).map_err(|e| {
89                SaorsaAgentError::Tool(format!("Failed to create parent directories: {e}"))
90            })?;
91        }
92
93        // Write the file
94        fs::write(&path, &input.content)
95            .map_err(|e| SaorsaAgentError::Tool(format!("Failed to write file: {e}")))?;
96
97        // Build response
98        let mut response = if file_exists {
99            format!("File updated: {}\n\n", path.display())
100        } else {
101            format!("File created: {}\n\n", path.display())
102        };
103
104        // Add diff if file was updated
105        if let Some(old) = old_content {
106            if old != input.content {
107                response.push_str("Diff:\n");
108                response.push_str(&generate_diff(&old, &input.content, &path, "new"));
109            } else {
110                response.push_str("(No changes - content identical)");
111            }
112        } else {
113            response.push_str(&format!("Wrote {} bytes", input.content.len()));
114        }
115
116        Ok(response)
117    }
118}
119
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    // Unit tests for `WriteTool`: creating, overwriting, diffing and path handling.

    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Builds the JSON object accepted by `WriteTool::execute`.
    fn write_request(file_path: &str, content: &str) -> serde_json::Value {
        serde_json::json!({
            "file_path": file_path,
            "content": content
        })
    }

    #[tokio::test]
    async fn write_new_file() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("new_file.txt");
        let tool = WriteTool::new(dir.path());

        let response = tool
            .execute(write_request(target.to_str().unwrap(), "Hello, World!"))
            .await
            .unwrap();

        assert!(response.contains("File created"));
        // "Hello, World!" is 13 bytes long.
        assert!(response.contains("13 bytes"));

        assert!(target.exists());
        assert_eq!(fs::read_to_string(&target).unwrap(), "Hello, World!");
    }

    #[tokio::test]
    async fn write_update_existing_file() {
        let mut existing = NamedTempFile::new().unwrap();
        writeln!(existing, "Original content").unwrap();
        existing.flush().unwrap();

        let tool = WriteTool::new(std::env::current_dir().unwrap());
        let response = tool
            .execute(write_request(existing.path().to_str().unwrap(), "New content"))
            .await
            .unwrap();

        for expected in ["File updated", "Diff:", "-Original content", "+New content"] {
            assert!(
                response.contains(expected),
                "missing {expected:?} in response: {response}"
            );
        }

        assert_eq!(fs::read_to_string(existing.path()).unwrap(), "New content");
    }

    #[tokio::test]
    async fn write_identical_content() {
        let mut existing = NamedTempFile::new().unwrap();
        writeln!(existing, "Same content").unwrap();
        existing.flush().unwrap();

        let tool = WriteTool::new(std::env::current_dir().unwrap());
        let response = tool
            .execute(write_request(existing.path().to_str().unwrap(), "Same content\n"))
            .await
            .unwrap();

        assert!(response.contains("File updated"));
        assert!(response.contains("No changes - content identical"));
    }

    #[tokio::test]
    async fn write_create_parent_directories() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("subdir/nested/file.txt");
        let tool = WriteTool::new(dir.path());

        tool.execute(write_request(target.to_str().unwrap(), "Nested file content"))
            .await
            .unwrap();

        // The intermediate directories and the file itself must both exist now.
        assert!(target.parent().unwrap().exists());
        assert!(target.exists());
        assert_eq!(fs::read_to_string(&target).unwrap(), "Nested file content");
    }

    #[tokio::test]
    async fn write_to_directory_fails() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::new(dir.path());

        let outcome = tool
            .execute(write_request(dir.path().to_str().unwrap(), "This should fail"))
            .await;

        if let Err(SaorsaAgentError::Tool(message)) = outcome {
            assert!(message.contains("is a directory"));
        } else {
            panic!("Expected Tool error");
        }
    }

    #[test]
    fn diff_generation() {
        let before = "Line 1\nLine 2\nLine 3\n";
        let after = "Line 1\nModified Line 2\nLine 3\n";
        let label = std::path::Path::new("test.txt");

        let diff = super::super::generate_diff(before, after, label, "new");

        for expected in [
            "--- test.txt",
            "+++ test.txt (new)",
            "-Line 2",
            "+Modified Line 2",
        ] {
            assert!(diff.contains(expected), "missing {expected:?} in diff: {diff}");
        }
    }

    #[tokio::test]
    async fn write_relative_path() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::new(dir.path());

        tool.execute(write_request("relative/path/file.txt", "Content in relative path"))
            .await
            .unwrap();

        // Relative paths land under the tool's working directory.
        let target = dir.path().join("relative/path/file.txt");
        assert!(target.exists());
        assert_eq!(fs::read_to_string(&target).unwrap(), "Content in relative path");
    }
}
269
270        let content = fs::read_to_string(&file_path).unwrap();
271        assert_eq!(content, "Content in relative path");
272    }
273}