Skip to main content

saorsa_agent/tools/
edit.rs

1//! Edit tool for surgical file editing with ambiguity detection.
2
3use std::fs;
4use std::path::PathBuf;
5
6use serde::{Deserialize, Serialize};
7
8use super::{generate_diff, resolve_path};
9use crate::error::{Result, SaorsaAgentError};
10use crate::tool::Tool;
11
/// Tool for surgical file editing.
///
/// Replaces exact text matches inside a single file. It refuses to guess when the
/// search text occurs more than once unless the caller explicitly asks for
/// `replace_all`.
#[derive(Debug, Clone)]
pub struct EditTool {
    /// Base directory for resolving relative paths.
    working_dir: PathBuf,
}
17
18/// Input parameters for the Edit tool.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20struct EditInput {
21    /// Path to the file to edit.
22    file_path: String,
23    /// Text to search for and replace.
24    old_text: String,
25    /// Replacement text.
26    new_text: String,
27    /// Replace all occurrences (default: false).
28    #[serde(default)]
29    replace_all: bool,
30}
31
32impl EditTool {
33    /// Create a new Edit tool with the given working directory.
34    pub fn new(working_dir: impl Into<PathBuf>) -> Self {
35        Self {
36            working_dir: working_dir.into(),
37        }
38    }
39}
40
41#[async_trait::async_trait]
42impl Tool for EditTool {
43    fn name(&self) -> &str {
44        "edit"
45    }
46
47    fn description(&self) -> &str {
48        "Edit a file by replacing exact text matches, with ambiguity detection"
49    }
50
51    fn input_schema(&self) -> serde_json::Value {
52        serde_json::json!({
53            "type": "object",
54            "properties": {
55                "file_path": {
56                    "type": "string",
57                    "description": "Path to the file to edit (absolute or relative to working directory)"
58                },
59                "old_text": {
60                    "type": "string",
61                    "description": "Exact text to search for and replace"
62                },
63                "new_text": {
64                    "type": "string",
65                    "description": "Replacement text"
66                },
67                "replace_all": {
68                    "type": "boolean",
69                    "description": "Replace all occurrences (default: false, errors if multiple matches found)",
70                    "default": false
71                }
72            },
73            "required": ["file_path", "old_text", "new_text"]
74        })
75    }
76
77    async fn execute(&self, input: serde_json::Value) -> Result<String> {
78        let input: EditInput = serde_json::from_value(input)
79            .map_err(|e| SaorsaAgentError::Tool(format!("Invalid input: {e}")))?;
80
81        let path = resolve_path(&self.working_dir, &input.file_path);
82
83        // Check if file exists
84        if !path.exists() {
85            return Err(SaorsaAgentError::Tool(format!(
86                "File not found: {}",
87                path.display()
88            )));
89        }
90
91        // Check if path is a file
92        if !path.is_file() {
93            return Err(SaorsaAgentError::Tool(format!(
94                "Path is not a file: {}",
95                path.display()
96            )));
97        }
98
99        // Read file contents
100        let content = fs::read_to_string(&path)
101            .map_err(|e| SaorsaAgentError::Tool(format!("Failed to read file: {e}")))?;
102
103        // Count occurrences of old_text
104        let match_count = content.matches(&input.old_text).count();
105
106        if match_count == 0 {
107            return Err(SaorsaAgentError::Tool(format!(
108                "Text not found in file: '{}'",
109                input.old_text
110            )));
111        }
112
113        // Check for ambiguity
114        if match_count > 1 && !input.replace_all {
115            return Err(SaorsaAgentError::Tool(format!(
116                "Ambiguous: found {} matches for '{}'. Use replace_all: true to replace all occurrences, or provide more context to make the match unique.",
117                match_count, input.old_text
118            )));
119        }
120
121        // Perform replacement
122        let new_content = if input.replace_all {
123            content.replace(&input.old_text, &input.new_text)
124        } else {
125            content.replacen(&input.old_text, &input.new_text, 1)
126        };
127
128        // Write the updated content
129        fs::write(&path, &new_content)
130            .map_err(|e| SaorsaAgentError::Tool(format!("Failed to write file: {e}")))?;
131
132        // Build response
133        let mut response = if input.replace_all {
134            format!(
135                "Replaced {} occurrence(s) of text in: {}\n\n",
136                match_count,
137                path.display()
138            )
139        } else {
140            format!("Replaced text in: {}\n\n", path.display())
141        };
142
143        // Add diff
144        response.push_str("Diff:\n");
145        response.push_str(&generate_diff(&content, &new_content, &path, "edited"));
146
147        Ok(response)
148    }
149}
150
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a temporary file containing each entry of `lines` followed by `\n`.
    fn temp_file_with_lines(lines: &[&str]) -> NamedTempFile {
        let mut temp = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(temp, "{line}").unwrap();
        }
        temp.flush().unwrap();
        temp
    }

    /// An `EditTool` rooted at the current directory.
    fn tool() -> EditTool {
        EditTool::new(std::env::current_dir().unwrap())
    }

    /// Path of a temp file as the `&str` the JSON input expects.
    fn path_str(temp: &NamedTempFile) -> &str {
        temp.path().to_str().unwrap()
    }

    /// Extract the message of a `SaorsaAgentError::Tool`, or panic showing what came back.
    fn tool_error_message(result: Result<String>) -> String {
        match result {
            Err(SaorsaAgentError::Tool(msg)) => msg,
            other => panic!("Expected Tool error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn edit_single_replacement() {
        let temp = temp_file_with_lines(&["Line 1", "Line 2", "Line 3"]);
        let input = serde_json::json!({
            "file_path": path_str(&temp),
            "old_text": "Line 2",
            "new_text": "Modified Line 2"
        });

        // `unwrap` (rather than `assert!(is_ok())`) prints the error if the edit fails.
        let response = tool().execute(input).await.unwrap();
        assert!(response.contains("Replaced text"));
        assert!(response.contains("Diff:"));
        assert!(response.contains("-Line 2"));
        assert!(response.contains("+Modified Line 2"));

        let content = fs::read_to_string(temp.path()).unwrap();
        assert_eq!(content, "Line 1\nModified Line 2\nLine 3\n");
    }

    #[tokio::test]
    async fn edit_ambiguous_without_replace_all() {
        let temp = temp_file_with_lines(&["foo bar", "foo baz", "foo qux"]);
        let input = serde_json::json!({
            "file_path": path_str(&temp),
            "old_text": "foo",
            "new_text": "FOO"
        });

        let msg = tool_error_message(tool().execute(input).await);
        assert!(msg.contains("Ambiguous"));
        assert!(msg.contains("3 matches"));
        assert!(msg.contains("replace_all"));

        // A rejected edit must leave the file untouched.
        let content = fs::read_to_string(temp.path()).unwrap();
        assert_eq!(content, "foo bar\nfoo baz\nfoo qux\n");
    }

    #[tokio::test]
    async fn edit_replace_all() {
        let temp = temp_file_with_lines(&["foo bar", "foo baz", "foo qux"]);
        let input = serde_json::json!({
            "file_path": path_str(&temp),
            "old_text": "foo",
            "new_text": "FOO",
            "replace_all": true
        });

        let response = tool().execute(input).await.unwrap();
        assert!(response.contains("Replaced 3 occurrence(s)"));
        assert!(response.contains("Diff:"));

        let content = fs::read_to_string(temp.path()).unwrap();
        assert_eq!(content, "FOO bar\nFOO baz\nFOO qux\n");
    }

    #[tokio::test]
    async fn edit_text_not_found() {
        let temp = temp_file_with_lines(&["Some content"]);
        let input = serde_json::json!({
            "file_path": path_str(&temp),
            "old_text": "Nonexistent text",
            "new_text": "Replacement"
        });

        let msg = tool_error_message(tool().execute(input).await);
        assert!(msg.contains("Text not found"));
    }

    #[tokio::test]
    async fn edit_file_not_found() {
        let input = serde_json::json!({
            "file_path": "/nonexistent/file.txt",
            "old_text": "old",
            "new_text": "new"
        });

        let msg = tool_error_message(tool().execute(input).await);
        assert!(msg.contains("File not found"));
    }

    #[tokio::test]
    async fn edit_directory_path_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let input = serde_json::json!({
            "file_path": dir.path().to_str().unwrap(),
            "old_text": "old",
            "new_text": "new"
        });

        let msg = tool_error_message(tool().execute(input).await);
        assert!(msg.contains("Path is not a file"));
    }

    #[tokio::test]
    async fn edit_multiline_text() {
        let temp = temp_file_with_lines(&["Line 1", "Line 2", "Line 3", "Line 4"]);
        let input = serde_json::json!({
            "file_path": path_str(&temp),
            "old_text": "Line 2\nLine 3",
            "new_text": "Modified Lines 2-3"
        });

        tool().execute(input).await.unwrap();

        let content = fs::read_to_string(temp.path()).unwrap();
        assert_eq!(content, "Line 1\nModified Lines 2-3\nLine 4\n");
    }

    #[tokio::test]
    async fn edit_preserve_other_content() {
        let temp = temp_file_with_lines(&["Before", "Target", "After"]);
        let input = serde_json::json!({
            "file_path": path_str(&temp),
            "old_text": "Target",
            "new_text": "Modified"
        });

        tool().execute(input).await.unwrap();

        let content = fs::read_to_string(temp.path()).unwrap();
        assert_eq!(content, "Before\nModified\nAfter\n");
    }
}