1use std::ops::Range;
4
5use schemars::JsonSchema;
6use serde::Deserialize;
7use serde_json::Value;
8use sha2::{Digest as _, Sha256};
9use tokio_util::sync::CancellationToken;
10
11use crate::tool::{AgentTool, AgentToolResult, ToolFuture, validated_schema_for};
12use crate::types::ContentBlock;
13
14pub struct EditFileTool {
19 schema: Value,
20}
21
22impl EditFileTool {
23 #[must_use]
25 pub fn new() -> Self {
26 Self {
27 schema: validated_schema_for::<Params>(),
28 }
29 }
30}
31
32impl Default for EditFileTool {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
// NOTE(review): `deny_unknown_fields` is applied only through `schemars`, so the
// generated schema forbids extra keys but serde deserialization still ignores them.
/// One find-and-replace operation.
#[derive(Deserialize, JsonSchema)]
#[schemars(deny_unknown_fields)]
struct EditOp {
    /// Text to replace; must not be empty. An exact match is tried first; if
    /// there is none, whole lines are compared ignoring trailing whitespace.
    old_string: String,
    /// Replacement text.
    new_string: String,
    /// Replace every match. When false (the default), old_string must match
    /// exactly once unless line_hint is given.
    #[serde(default)]
    replace_all: bool,
    /// 1-based line number used when old_string matches more than once: the
    /// match starting closest to this line is replaced. Ignored when
    /// replace_all is true.
    line_hint: Option<u32>,
}
55
// NOTE(review): `deny_unknown_fields` is applied only through `schemars`, so the
// generated schema forbids extra keys but serde deserialization still ignores them.
/// Parameters accepted by the `edit_file` tool.
#[derive(Deserialize, JsonSchema)]
#[schemars(deny_unknown_fields)]
struct Params {
    /// Path of the file to edit. The file must already exist and contain
    /// valid UTF-8.
    path: String,
    /// Edits applied in order, each to the result of the previous one. If any
    /// edit fails, the file is not written.
    edits: Vec<EditOp>,
    /// Optional hex-encoded SHA-256 of the file's current bytes (compared
    /// case-insensitively). If it does not match the file on disk, no edits
    /// are applied.
    expected_hash: Option<String>,
}
67
68fn sha256_hex(data: &[u8]) -> String {
74 Sha256::digest(data)
75 .iter()
76 .fold(String::with_capacity(64), |mut s, b| {
77 use std::fmt::Write as _;
78 let _ = write!(s, "{b:02x}");
79 s
80 })
81}
82
/// Splits `s` on `'\n'` and pairs each line with its starting byte offset.
///
/// Lines exclude the `'\n'` terminator. A trailing `'\n'` yields a final empty
/// line, and `""` yields a single empty line at offset 0.
fn line_spans(s: &str) -> Vec<(usize, &str)> {
    let mut spans = Vec::new();
    let mut pos = 0;
    for line in s.split('\n') {
        spans.push((pos, line));
        // Step past the line and the '\n' that terminated it.
        pos += line.len() + 1;
    }
    spans
}

/// Returns the byte ranges of all non-overlapping occurrences of `pattern` in
/// `content`, scanning left to right. An empty `pattern` matches nothing.
fn find_exact(content: &str, pattern: &str) -> Vec<Range<usize>> {
    if pattern.is_empty() {
        return Vec::new();
    }
    content
        .match_indices(pattern)
        .map(|(at, hit)| at..at + hit.len())
        .collect()
}

/// Finds `pattern` in `content` line by line, ignoring trailing whitespace.
///
/// This is the fallback used when an exact match fails. The pattern's core
/// (the pattern with leading/trailing `'\n'` removed) must match consecutive
/// whole lines of `content` after `trim_end` on both sides. Each hit covers
/// those lines, from the start of the first to the end of the last, including
/// the last line's trailing whitespace.
///
/// Newlines stripped from the pattern's ends are then absorbed back into the
/// range when `content` has them there. Without this step, replacing
/// `"line\n"` would keep the original newline *and* add the replacement's,
/// producing a spurious blank line. Hits never overlap.
fn find_normalized(content: &str, pattern: &str) -> Vec<Range<usize>> {
    let core = pattern.trim_matches('\n');
    if core.is_empty() {
        return Vec::new();
    }
    // Number of '\n' characters trimmed from each end of the pattern.
    let lead = pattern.len() - pattern.trim_start_matches('\n').len();
    let trail = pattern.len() - pattern.trim_end_matches('\n').len();

    let core_lines: Vec<&str> = core.split('\n').collect();
    let n = core_lines.len();
    let spans = line_spans(content);

    let mut ranges: Vec<Range<usize>> = Vec::new();
    let mut i = 0;
    while i + n <= spans.len() {
        let prev_end = ranges.last().map_or(0, |r| r.end);
        let line_start = spans[i].0;
        // A candidate inside the previous hit (possible only after absorbing
        // trailing blank lines) is skipped so ranges stay disjoint.
        let all_match = line_start >= prev_end
            && core_lines
                .iter()
                .zip(&spans[i..i + n])
                .all(|(pl, (_, cl))| cl.trim_end() == pl.trim_end());
        if !all_match {
            i += 1;
            continue;
        }
        let (last_start, last_line) = spans[i + n - 1];
        let start = absorb_leading_newlines(content, line_start, lead, prev_end);
        let end = absorb_trailing_newlines(content, last_start + last_line.len(), trail);
        ranges.push(start..end);
        i += n;
    }
    ranges
}

/// Moves `start` (the beginning of a line) backwards over up to `count`
/// newlines, never below `floor`.
///
/// The first newline is the one ending the previous line. Each further one
/// requires the line in between to be blank (whitespace only).
fn absorb_leading_newlines(content: &str, start: usize, count: usize, floor: usize) -> usize {
    let mut start = start;
    for k in 0..count {
        let newline = if k == 0 {
            // `start` begins a line, so the byte before it (if any) is '\n'.
            start.checked_sub(1)
        } else {
            content[..start]
                .rfind('\n')
                .filter(|&p| content[p + 1..start].trim_end().is_empty())
        };
        match newline {
            Some(p) if p >= floor => start = p,
            _ => break,
        }
    }
    start
}

/// Moves `end` (the end of a line, before its '\n') forwards over up to
/// `count` newlines.
///
/// The first newline is the one terminating that line, if present. Each
/// further one requires the line in between to be blank (whitespace only).
fn absorb_trailing_newlines(content: &str, end: usize, count: usize) -> usize {
    let mut end = end;
    for k in 0..count {
        let newline = if k == 0 {
            // `end` is at a '\n' unless it is the end of `content`.
            (end < content.len()).then_some(end)
        } else {
            content[end..]
                .find('\n')
                .map(|off| end + off)
                .filter(|&p| content[end..p].trim_end().is_empty())
        };
        match newline {
            Some(p) => end = p + 1,
            None => break,
        }
    }
    end
}
152
/// Returns the 1-based line number containing byte offset `byte_pos` of
/// `content`.
///
/// Counts `'\n'` bytes before `byte_pos`. Scanning raw bytes gives the same
/// count as scanning chars, because `'\n'` never appears inside a multi-byte
/// UTF-8 sequence, and it skips UTF-8 decoding.
///
/// # Panics
///
/// Panics if `byte_pos > content.len()`.
fn line_number_at(content: &str, byte_pos: usize) -> usize {
    content.as_bytes()[..byte_pos]
        .iter()
        .filter(|&&b| b == b'\n')
        .count()
        + 1
}
157
/// Returns `content` with every byte range in `ranges` replaced by
/// `replacement`.
///
/// `ranges` must be sorted ascending and non-overlapping, with bounds on char
/// boundaries; otherwise the slicing below panics.
fn replace_ranges(content: &str, ranges: &[Range<usize>], replacement: &str) -> String {
    let (mut output, tail_start) = ranges.iter().fold(
        (String::with_capacity(content.len()), 0),
        |(mut acc, from), range| {
            // Keep the untouched text before this range, then splice in the replacement.
            acc.push_str(&content[from..range.start]);
            acc.push_str(replacement);
            (acc, range.end)
        },
    );
    output.push_str(&content[tail_start..]);
    output
}
172
173fn apply_op(content: &str, op: &EditOp) -> Result<String, String> {
176 if op.old_string.is_empty() {
177 return Err("old_string must not be empty".to_owned());
178 }
179
180 let candidates: Vec<Range<usize>> = {
182 let exact = find_exact(content, &op.old_string);
183 if exact.is_empty() {
184 let norm = find_normalized(content, &op.old_string);
185 if norm.is_empty() {
186 return Err(format!(
187 "old_string not found (tried exact and whitespace-normalised match):\n{}",
188 op.old_string
189 ));
190 }
191 norm
192 } else {
193 exact
194 }
195 };
196
197 if op.replace_all {
198 return Ok(replace_ranges(content, &candidates, &op.new_string));
199 }
200
201 match candidates.len() {
202 0 => unreachable!("candidates is non-empty at this point"),
203 1 => Ok(replace_ranges(content, &candidates, &op.new_string)),
204 n => op.line_hint.map_or_else(
205 || {
206 Err(format!(
207 "old_string matched {n} times; set replace_all to replace every \
208 occurrence, or provide line_hint to select one"
209 ))
210 },
211 |hint| {
212 let best = candidates
213 .iter()
214 .min_by_key(|r| {
215 let line =
216 i64::try_from(line_number_at(content, r.start)).unwrap_or(i64::MAX);
217 (line - i64::from(hint)).abs()
218 })
219 .expect("candidates is non-empty");
220 Ok(replace_ranges(
221 content,
222 std::slice::from_ref(best),
223 &op.new_string,
224 ))
225 },
226 ),
227 }
228}
229
230async fn atomic_write(path: &std::path::Path, content: &str) -> std::io::Result<()> {
238 let tmp = {
239 let name = path
240 .file_name()
241 .unwrap_or_default()
242 .to_string_lossy()
243 .into_owned();
244 path.with_file_name(format!("{name}.swink-edit.tmp"))
245 };
246 tokio::fs::write(&tmp, content).await?;
247 tokio::fs::rename(&tmp, path).await
248}
249
250#[allow(clippy::unnecessary_literal_bound)]
255impl AgentTool for EditFileTool {
256 fn name(&self) -> &str {
257 "edit_file"
258 }
259
260 fn label(&self) -> &str {
261 "Edit File"
262 }
263
264 fn description(&self) -> &str {
265 "Apply one or more surgical find-and-replace edits to a file. \
266 Edits are applied top-to-bottom. Trailing whitespace is ignored \
267 during matching when an exact match is not found. The write is \
268 atomic: the file is never left in a partially-written state."
269 }
270
271 fn parameters_schema(&self) -> &Value {
272 &self.schema
273 }
274
275 fn requires_approval(&self) -> bool {
276 true
277 }
278
279 fn execute(
280 &self,
281 _tool_call_id: &str,
282 params: Value,
283 cancellation_token: CancellationToken,
284 _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
285 _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
286 _credential: Option<crate::credential::ResolvedCredential>,
287 ) -> ToolFuture<'_> {
288 Box::pin(async move {
289 let parsed: Params = match serde_json::from_value(params) {
290 Ok(p) => p,
291 Err(e) => return AgentToolResult::error(format!("invalid parameters: {e}")),
292 };
293
294 if cancellation_token.is_cancelled() {
295 return AgentToolResult::error("cancelled");
296 }
297
298 let path = std::path::Path::new(&parsed.path);
299
300 let raw_bytes = match tokio::fs::read(path).await {
301 Ok(b) => b,
302 Err(e) => {
303 return AgentToolResult::error(format!("failed to read {}: {e}", parsed.path));
304 }
305 };
306
307 let original = match std::str::from_utf8(&raw_bytes) {
308 Ok(s) => s.to_owned(),
309 Err(_) => {
310 return AgentToolResult::error(format!("{} is not valid UTF-8", parsed.path));
311 }
312 };
313
314 if let Some(expected) = &parsed.expected_hash {
316 let actual = sha256_hex(&raw_bytes);
317 if actual != expected.to_ascii_lowercase() {
318 return AgentToolResult::error(format!(
319 "{} has changed since it was last read (hash mismatch); \
320 re-read the file before editing",
321 parsed.path
322 ));
323 }
324 }
325
326 if parsed.edits.is_empty() {
327 return AgentToolResult::text("no edits specified; file unchanged");
328 }
329
330 let mut content = original.clone();
332 for (i, op) in parsed.edits.iter().enumerate() {
333 content = match apply_op(&content, op) {
334 Ok(updated) => updated,
335 Err(msg) => {
336 return AgentToolResult::error(format!("edit {}: {msg}", i + 1));
337 }
338 };
339 }
340
341 if cancellation_token.is_cancelled() {
342 return AgentToolResult::error("cancelled");
343 }
344
345 if let Err(e) = atomic_write(path, &content).await {
346 return AgentToolResult::error(format!("failed to write {}: {e}", parsed.path));
347 }
348
349 let n = parsed.edits.len();
350 AgentToolResult {
351 content: vec![ContentBlock::Text {
352 text: format!(
353 "Applied {} edit{} to {}",
354 n,
355 if n == 1 { "" } else { "s" },
356 parsed.path
357 ),
358 }],
359 details: serde_json::json!({
360 "path": parsed.path,
361 "edits_applied": n,
362 "old_content": original,
363 "new_content": content,
364 }),
365 is_error: false,
366 transfer_signal: None,
367 }
368 })
369 }
370}
371
// Unit tests for `apply_op` and `sha256_hex`, plus two end-to-end `execute`
// runs against files in a temporary directory.
#[cfg(test)]
mod tests {
    use super::*;

    // A single exact occurrence is replaced in place.
    #[test]
    fn exact_single_replacement() {
        let content = "hello world\n";
        let op = EditOp {
            old_string: "world".into(),
            new_string: "Rust".into(),
            replace_all: false,
            line_hint: None,
        };
        assert_eq!(apply_op(content, &op).unwrap(), "hello Rust\n");
    }

    #[test]
    fn normalised_trailing_whitespace_match() {
        // The first content line has a trailing space after `{` that
        // old_string lacks. The exact search fails, and the line-normalised
        // fallback replaces the span including that space.
        let content = "fn foo() { \n let x = 1;\n}\n";
        let op = EditOp {
            old_string: "fn foo() {\n let x = 1;\n}".into(),
            new_string: "fn foo() {\n let x = 2;\n}".into(),
            replace_all: false,
            line_hint: None,
        };
        assert_eq!(
            apply_op(content, &op).unwrap(),
            "fn foo() {\n let x = 2;\n}\n"
        );
    }

    // replace_all rewrites every occurrence instead of demanding uniqueness.
    #[test]
    fn replace_all_occurrences() {
        let content = "foo bar foo baz foo\n";
        let op = EditOp {
            old_string: "foo".into(),
            new_string: "qux".into(),
            replace_all: true,
            line_hint: None,
        };
        assert_eq!(apply_op(content, &op).unwrap(), "qux bar qux baz qux\n");
    }

    // Two matches with neither replace_all nor line_hint is ambiguous.
    #[test]
    fn multiple_matches_without_hint_is_error() {
        let content = "fn foo() {}\nfn foo() {}\n";
        let op = EditOp {
            old_string: "fn foo() {}".into(),
            new_string: "fn bar() {}".into(),
            replace_all: false,
            line_hint: None,
        };
        let err = apply_op(content, &op).unwrap_err();
        assert!(err.contains("matched 2 times"), "unexpected error: {err}");
    }

    #[test]
    fn line_hint_picks_closest_match() {
        // Matches start on lines 1 and 3; a hint of 3 selects the second one.
        let content = "fn foo() {}\nfn bar() {}\nfn foo() {}\n";
        let op = EditOp {
            old_string: "fn foo() {}".into(),
            new_string: "fn baz() {}".into(),
            replace_all: false,
            line_hint: Some(3),
        };
        assert_eq!(
            apply_op(content, &op).unwrap(),
            "fn foo() {}\nfn bar() {}\nfn baz() {}\n"
        );
    }

    // Neither the exact nor the normalised search finds the text.
    #[test]
    fn not_found_returns_error() {
        let content = "hello world\n";
        let op = EditOp {
            old_string: "missing".into(),
            new_string: "x".into(),
            replace_all: false,
            line_hint: None,
        };
        assert!(apply_op(content, &op).is_err());
    }

    // An empty old_string is rejected before any searching.
    #[test]
    fn empty_old_string_is_error() {
        let op = EditOp {
            old_string: String::new(),
            new_string: "x".into(),
            replace_all: false,
            line_hint: None,
        };
        assert!(apply_op("anything", &op).is_err());
    }

    // Mirrors execute's loop: each edit runs on the previous edit's output.
    #[test]
    fn multiple_edits_applied_in_order() {
        let mut content = "a b c\n".to_owned();
        let ops = [
            EditOp {
                old_string: "a".into(),
                new_string: "1".into(),
                replace_all: false,
                line_hint: None,
            },
            EditOp {
                old_string: "b".into(),
                new_string: "2".into(),
                replace_all: false,
                line_hint: None,
            },
            EditOp {
                old_string: "c".into(),
                new_string: "3".into(),
                replace_all: false,
                line_hint: None,
            },
        ];
        for op in &ops {
            content = apply_op(&content, op).unwrap();
        }
        assert_eq!(content, "1 2 3\n");
    }

    // SHA-256("abc") is the standard test vector beginning ba7816bf.
    #[test]
    fn sha256_hex_known_value() {
        let digest = sha256_hex(b"abc");
        assert!(digest.starts_with("ba7816bf"), "got: {digest}");
        assert_eq!(digest.len(), 64);
    }

    // End to end: the file on disk is rewritten, and details carry both the
    // old and the new text.
    #[tokio::test]
    async fn execute_edits_file_and_returns_diff() {
        use std::sync::{Arc, RwLock};

        use serde_json::json;

        use crate::SessionState;
        use crate::tool::AgentTool;

        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        tokio::fs::write(&file, "hello world\n").await.unwrap();

        let tool = EditFileTool::new();
        let params = json!({
            "path": file.to_str().unwrap(),
            "edits": [{ "old_string": "world", "new_string": "Rust" }]
        });

        let result = tool
            .execute(
                "id",
                params,
                CancellationToken::new(),
                None,
                Arc::new(RwLock::new(SessionState::default())),
                None,
            )
            .await;

        assert!(!result.is_error);
        let on_disk = tokio::fs::read_to_string(&file).await.unwrap();
        assert_eq!(on_disk, "hello Rust\n");
        assert_eq!(result.details["old_content"], "hello world\n");
        assert_eq!(result.details["new_content"], "hello Rust\n");
    }

    // An all-zero expected_hash cannot match the file, so execute must refuse.
    #[tokio::test]
    async fn execute_rejects_stale_hash() {
        use std::sync::{Arc, RwLock};

        use serde_json::json;

        use crate::SessionState;
        use crate::tool::AgentTool;

        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        tokio::fs::write(&file, "hello world\n").await.unwrap();

        let tool = EditFileTool::new();
        let params = json!({
            "path": file.to_str().unwrap(),
            "edits": [{ "old_string": "world", "new_string": "Rust" }],
            "expected_hash": "0000000000000000000000000000000000000000000000000000000000000000"
        });

        let result = tool
            .execute(
                "id",
                params,
                CancellationToken::new(),
                None,
                Arc::new(RwLock::new(SessionState::default())),
                None,
            )
            .await;

        assert!(result.is_error);
        let text = match &result.content[0] {
            ContentBlock::Text { text } => text.clone(),
            _ => panic!("expected text block"),
        };
        assert!(text.contains("hash mismatch"), "got: {text}");
    }
}