Skip to main content

swink_agent/tools/
edit_file.rs

1//! Built-in tool for making surgical find-and-replace edits to a file.
2
3use std::ops::Range;
4
5use schemars::JsonSchema;
6use serde::Deserialize;
7use serde_json::Value;
8use sha2::{Digest as _, Sha256};
9use tokio_util::sync::CancellationToken;
10
11use crate::tool::{AgentTool, AgentToolResult, ToolFuture, validated_schema_for};
12use crate::types::ContentBlock;
13
14/// Built-in tool for making precise, surgical edits to a file.
15///
16/// Supports multiple edits per call, atomic writes, stale-read detection,
17/// whitespace-normalised matching, and line-number-based disambiguation.
18pub struct EditFileTool {
19    schema: Value,
20}
21
22impl EditFileTool {
23    /// Create a new `EditFileTool`.
24    #[must_use]
25    pub fn new() -> Self {
26        Self {
27            schema: validated_schema_for::<Params>(),
28        }
29    }
30}
31
32impl Default for EditFileTool {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38/// One find-and-replace operation.
39#[derive(Deserialize, JsonSchema)]
40#[schemars(deny_unknown_fields)]
41struct EditOp {
42    /// Text to find in the file.  Exact match is tried first; if that fails a
43    /// line-by-line match that ignores trailing whitespace is attempted.
44    old_string: String,
45    /// Replacement text.
46    new_string: String,
47    /// When `true`, every occurrence is replaced.  When `false` (the default)
48    /// exactly one occurrence must exist, or `line_hint` must be provided.
49    #[serde(default)]
50    replace_all: bool,
51    /// 1-based line number of the desired occurrence.  Used to pick among
52    /// multiple matches when `replace_all` is `false`.
53    line_hint: Option<u32>,
54}
55
56#[derive(Deserialize, JsonSchema)]
57#[schemars(deny_unknown_fields)]
58struct Params {
59    /// Absolute path to the file to edit.
60    path: String,
61    /// Edits to apply in order (top-to-bottom).
62    edits: Vec<EditOp>,
63    /// SHA-256 hex digest of the file content as previously read.  When
64    /// provided the edit is rejected if the file has changed since.
65    expected_hash: Option<String>,
66}
67
68// ---------------------------------------------------------------------------
69// Matching helpers
70// ---------------------------------------------------------------------------
71
72/// Compute the SHA-256 hex digest of `data`.
73fn sha256_hex(data: &[u8]) -> String {
74    Sha256::digest(data)
75        .iter()
76        .fold(String::with_capacity(64), |mut s, b| {
77            use std::fmt::Write as _;
78            let _ = write!(s, "{b:02x}");
79            s
80        })
81}
82
/// Split `s` on `'\n'` and pair every line with the byte offset it starts at.
///
/// The `'\n'` terminators are not part of the returned slices, and a trailing
/// `'\n'` yields a final empty line. In Windows `\r\n` files each slice keeps
/// its trailing `\r`; the normalised matcher removes it with [`str::trim_end`].
fn line_spans(s: &str) -> Vec<(usize, &str)> {
    let mut next_start = 0;
    s.split('\n')
        .map(|line| {
            let start = next_start;
            // Advance past this line plus the '\n' that ended it.
            next_start += line.len() + 1;
            (start, line)
        })
        .collect()
}
97
/// Locate every non-overlapping, exact occurrence of `pattern` in `content`.
///
/// Matches are reported left to right as byte ranges into `content`. An
/// empty `pattern` yields no matches.
fn find_exact(content: &str, pattern: &str) -> Vec<Range<usize>> {
    if pattern.is_empty() {
        return Vec::new();
    }
    content
        .match_indices(pattern)
        .map(|(start, matched)| start..start + matched.len())
        .collect()
}
112
113/// Find all non-overlapping byte ranges in `content` that match `pattern`
114/// line-by-line, ignoring trailing whitespace on each line.
115///
116/// Leading and trailing blank lines in `pattern` are stripped before
117/// matching.  The returned ranges refer to byte positions in the original
118/// (un-normalised) `content`.
119fn find_normalized(content: &str, pattern: &str) -> Vec<Range<usize>> {
120    let pattern = pattern.trim_matches('\n');
121    if pattern.is_empty() {
122        return Vec::new();
123    }
124    let pattern_lines: Vec<&str> = pattern.split('\n').collect();
125    let spans = line_spans(content);
126    let n = pattern_lines.len();
127
128    if n > spans.len() {
129        return Vec::new();
130    }
131
132    let mut ranges = Vec::new();
133    let mut i = 0;
134    while i + n <= spans.len() {
135        let all_match = pattern_lines
136            .iter()
137            .enumerate()
138            .all(|(j, &pl)| spans[i + j].1.trim_end() == pl.trim_end());
139
140        if all_match {
141            let byte_start = spans[i].0;
142            let last = &spans[i + n - 1];
143            let byte_end = last.0 + last.1.len();
144            ranges.push(byte_start..byte_end);
145            i += n; // skip past the match so occurrences don't overlap
146        } else {
147            i += 1;
148        }
149    }
150    ranges
151}
152
/// Convert a byte offset into a 1-based line number by counting the `'\n'`
/// characters that precede it.
///
/// Panics if `byte_pos` is past the end of `content` or not on a `char`
/// boundary.
fn line_number_at(content: &str, byte_pos: usize) -> usize {
    let preceding = &content[..byte_pos];
    // '\n' is ASCII and never appears inside a multi-byte UTF-8 sequence, so
    // counting bytes gives the same answer as counting chars.
    1 + preceding.bytes().filter(|&b| b == b'\n').count()
}
157
/// Build a copy of `content` in which every span in `ranges` is replaced by
/// `replacement`.
///
/// `ranges` must be sorted ascending and non-overlapping, and every bound
/// must lie on a `char` boundary of `content` (slicing panics otherwise).
fn replace_ranges(content: &str, ranges: &[Range<usize>], replacement: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut copied_up_to = 0;
    for &Range { start, end } in ranges {
        // Keep the untouched text before this span, then the replacement.
        result += &content[copied_up_to..start];
        result += replacement;
        copied_up_to = end;
    }
    result += &content[copied_up_to..];
    result
}
172
173/// Apply a single [`EditOp`] to `content`, returning the modified string or
174/// an error message.
175fn apply_op(content: &str, op: &EditOp) -> Result<String, String> {
176    if op.old_string.is_empty() {
177        return Err("old_string must not be empty".to_owned());
178    }
179
180    // Prefer exact match; fall back to whitespace-normalised line matching.
181    let candidates: Vec<Range<usize>> = {
182        let exact = find_exact(content, &op.old_string);
183        if exact.is_empty() {
184            let norm = find_normalized(content, &op.old_string);
185            if norm.is_empty() {
186                return Err(format!(
187                    "old_string not found (tried exact and whitespace-normalised match):\n{}",
188                    op.old_string
189                ));
190            }
191            norm
192        } else {
193            exact
194        }
195    };
196
197    if op.replace_all {
198        return Ok(replace_ranges(content, &candidates, &op.new_string));
199    }
200
201    match candidates.len() {
202        0 => unreachable!("candidates is non-empty at this point"),
203        1 => Ok(replace_ranges(content, &candidates, &op.new_string)),
204        n => op.line_hint.map_or_else(
205            || {
206                Err(format!(
207                    "old_string matched {n} times; set replace_all to replace every \
208                     occurrence, or provide line_hint to select one"
209                ))
210            },
211            |hint| {
212                let best = candidates
213                    .iter()
214                    .min_by_key(|r| {
215                        let line =
216                            i64::try_from(line_number_at(content, r.start)).unwrap_or(i64::MAX);
217                        (line - i64::from(hint)).abs()
218                    })
219                    .expect("candidates is non-empty");
220                Ok(replace_ranges(
221                    content,
222                    std::slice::from_ref(best),
223                    &op.new_string,
224                ))
225            },
226        ),
227    }
228}
229
230// ---------------------------------------------------------------------------
231// Atomic write
232// ---------------------------------------------------------------------------
233
234/// Write `content` to `path` atomically: write to a sibling `.swink-edit.tmp`
235/// file then rename it over the target.  On most Unix filesystems `rename` is
236/// atomic when src and dst share a directory.
237async fn atomic_write(path: &std::path::Path, content: &str) -> std::io::Result<()> {
238    let tmp = {
239        let name = path
240            .file_name()
241            .unwrap_or_default()
242            .to_string_lossy()
243            .into_owned();
244        path.with_file_name(format!("{name}.swink-edit.tmp"))
245    };
246    tokio::fs::write(&tmp, content).await?;
247    tokio::fs::rename(&tmp, path).await
248}
249
250// ---------------------------------------------------------------------------
251// AgentTool impl
252// ---------------------------------------------------------------------------
253
254#[allow(clippy::unnecessary_literal_bound)]
255impl AgentTool for EditFileTool {
256    fn name(&self) -> &str {
257        "edit_file"
258    }
259
260    fn label(&self) -> &str {
261        "Edit File"
262    }
263
264    fn description(&self) -> &str {
265        "Apply one or more surgical find-and-replace edits to a file. \
266         Edits are applied top-to-bottom. Trailing whitespace is ignored \
267         during matching when an exact match is not found. The write is \
268         atomic: the file is never left in a partially-written state."
269    }
270
271    fn parameters_schema(&self) -> &Value {
272        &self.schema
273    }
274
275    fn requires_approval(&self) -> bool {
276        true
277    }
278
279    fn execute(
280        &self,
281        _tool_call_id: &str,
282        params: Value,
283        cancellation_token: CancellationToken,
284        _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
285        _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
286        _credential: Option<crate::credential::ResolvedCredential>,
287    ) -> ToolFuture<'_> {
288        Box::pin(async move {
289            let parsed: Params = match serde_json::from_value(params) {
290                Ok(p) => p,
291                Err(e) => return AgentToolResult::error(format!("invalid parameters: {e}")),
292            };
293
294            if cancellation_token.is_cancelled() {
295                return AgentToolResult::error("cancelled");
296            }
297
298            let path = std::path::Path::new(&parsed.path);
299
300            let raw_bytes = match tokio::fs::read(path).await {
301                Ok(b) => b,
302                Err(e) => {
303                    return AgentToolResult::error(format!("failed to read {}: {e}", parsed.path));
304                }
305            };
306
307            let original = match std::str::from_utf8(&raw_bytes) {
308                Ok(s) => s.to_owned(),
309                Err(_) => {
310                    return AgentToolResult::error(format!("{} is not valid UTF-8", parsed.path));
311                }
312            };
313
314            // Stale-read check.
315            if let Some(expected) = &parsed.expected_hash {
316                let actual = sha256_hex(&raw_bytes);
317                if actual != expected.to_ascii_lowercase() {
318                    return AgentToolResult::error(format!(
319                        "{} has changed since it was last read (hash mismatch); \
320                         re-read the file before editing",
321                        parsed.path
322                    ));
323                }
324            }
325
326            if parsed.edits.is_empty() {
327                return AgentToolResult::text("no edits specified; file unchanged");
328            }
329
330            // Apply all edits in-memory (fail-fast — no partial writes).
331            let mut content = original.clone();
332            for (i, op) in parsed.edits.iter().enumerate() {
333                content = match apply_op(&content, op) {
334                    Ok(updated) => updated,
335                    Err(msg) => {
336                        return AgentToolResult::error(format!("edit {}: {msg}", i + 1));
337                    }
338                };
339            }
340
341            if cancellation_token.is_cancelled() {
342                return AgentToolResult::error("cancelled");
343            }
344
345            if let Err(e) = atomic_write(path, &content).await {
346                return AgentToolResult::error(format!("failed to write {}: {e}", parsed.path));
347            }
348
349            let n = parsed.edits.len();
350            AgentToolResult {
351                content: vec![ContentBlock::Text {
352                    text: format!(
353                        "Applied {} edit{} to {}",
354                        n,
355                        if n == 1 { "" } else { "s" },
356                        parsed.path
357                    ),
358                }],
359                details: serde_json::json!({
360                    "path": parsed.path,
361                    "edits_applied": n,
362                    "old_content": original,
363                    "new_content": content,
364                }),
365                is_error: false,
366                transfer_signal: None,
367            }
368        })
369    }
370}
371
372// ---------------------------------------------------------------------------
373// Tests
374// ---------------------------------------------------------------------------
375
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use serde_json::json;

    use super::*;
    use crate::SessionState;

    /// Build a single-occurrence edit with no line hint.
    fn edit(old: &str, new: &str) -> EditOp {
        EditOp {
            old_string: old.into(),
            new_string: new.into(),
            replace_all: false,
            line_hint: None,
        }
    }

    /// Run the tool end-to-end with fresh session state and no credential.
    async fn run_tool(params: Value) -> AgentToolResult {
        let tool = EditFileTool::new();
        tool.execute(
            "id",
            params,
            CancellationToken::new(),
            None,
            Arc::new(RwLock::new(SessionState::default())),
            None,
        )
        .await
    }

    // ── apply_op unit tests ──────────────────────────────────────────────────

    #[test]
    fn exact_single_replacement() {
        let out = apply_op("hello world\n", &edit("world", "Rust")).unwrap();
        assert_eq!(out, "hello Rust\n");
    }

    #[test]
    fn normalised_trailing_whitespace_match() {
        // File has trailing spaces; old_string does not — should still match.
        let content = "fn foo() {   \n    let x = 1;\n}\n";
        let op = edit("fn foo() {\n    let x = 1;\n}", "fn foo() {\n    let x = 2;\n}");
        assert_eq!(
            apply_op(content, &op).unwrap(),
            "fn foo() {\n    let x = 2;\n}\n"
        );
    }

    #[test]
    fn replace_all_occurrences() {
        let op = EditOp {
            replace_all: true,
            ..edit("foo", "qux")
        };
        assert_eq!(
            apply_op("foo bar foo baz foo\n", &op).unwrap(),
            "qux bar qux baz qux\n"
        );
    }

    #[test]
    fn multiple_matches_without_hint_is_error() {
        let content = "fn foo() {}\nfn foo() {}\n";
        let err = apply_op(content, &edit("fn foo() {}", "fn bar() {}")).unwrap_err();
        assert!(err.contains("matched 2 times"), "unexpected error: {err}");
    }

    #[test]
    fn line_hint_picks_closest_match() {
        // "fn foo() {}" appears on lines 1 and 3; hint=3 should pick line 3.
        let content = "fn foo() {}\nfn bar() {}\nfn foo() {}\n";
        let op = EditOp {
            line_hint: Some(3),
            ..edit("fn foo() {}", "fn baz() {}")
        };
        assert_eq!(
            apply_op(content, &op).unwrap(),
            "fn foo() {}\nfn bar() {}\nfn baz() {}\n"
        );
    }

    #[test]
    fn not_found_returns_error() {
        assert!(apply_op("hello world\n", &edit("missing", "x")).is_err());
    }

    #[test]
    fn empty_old_string_is_error() {
        assert!(apply_op("anything", &edit("", "x")).is_err());
    }

    #[test]
    fn multiple_edits_applied_in_order() {
        let ops = [edit("a", "1"), edit("b", "2"), edit("c", "3")];
        let content = ops
            .iter()
            .try_fold("a b c\n".to_owned(), |acc, op| apply_op(&acc, op))
            .unwrap();
        assert_eq!(content, "1 2 3\n");
    }

    // ── sha256_hex ───────────────────────────────────────────────────────────

    #[test]
    fn sha256_hex_known_value() {
        // echo -n "abc" | sha256sum → ba7816bf…
        let digest = sha256_hex(b"abc");
        assert!(digest.starts_with("ba7816bf"), "got: {digest}");
        assert_eq!(digest.len(), 64);
    }

    // ── Integration: execute via tempfile ────────────────────────────────────

    #[tokio::test]
    async fn execute_edits_file_and_returns_diff() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        tokio::fs::write(&file, "hello world\n").await.unwrap();

        let result = run_tool(json!({
            "path": file.to_str().unwrap(),
            "edits": [{ "old_string": "world", "new_string": "Rust" }]
        }))
        .await;

        assert!(!result.is_error);
        let on_disk = tokio::fs::read_to_string(&file).await.unwrap();
        assert_eq!(on_disk, "hello Rust\n");
        assert_eq!(result.details["old_content"], "hello world\n");
        assert_eq!(result.details["new_content"], "hello Rust\n");
    }

    #[tokio::test]
    async fn execute_rejects_stale_hash() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        tokio::fs::write(&file, "hello world\n").await.unwrap();

        let result = run_tool(json!({
            "path": file.to_str().unwrap(),
            "edits": [{ "old_string": "world", "new_string": "Rust" }],
            "expected_hash": "0".repeat(64)
        }))
        .await;

        assert!(result.is_error);
        let text = match &result.content[0] {
            ContentBlock::Text { text } => text.as_str(),
            _ => panic!("expected text block"),
        };
        assert!(text.contains("hash mismatch"), "got: {text}");
    }
}