Skip to main content

rustant_tools/
smart_edit.rs

1//! Semantic code edit tool with fuzzy location matching and diff preview.
2//!
3//! Accepts natural language descriptions of edit locations (e.g., "the function
4//! that handles authentication") and edit operations, then applies precise edits
5//! with unified diff output and auto-checkpoint support.
6
7use crate::checkpoint::CheckpointManager;
8use crate::registry::Tool;
9use async_trait::async_trait;
10use rustant_core::error::ToolError;
11use rustant_core::types::{Artifact, RiskLevel, ToolOutput};
12use similar::TextDiff;
13use std::path::{Path, PathBuf};
14use tokio::sync::Mutex;
15use tracing::debug;
16
17/// Smart editing tool that accepts fuzzy location descriptions and generates
18/// precise edits with diff preview and optional auto-checkpoint.
19pub struct SmartEditTool {
20    workspace: PathBuf,
21    checkpoint_mgr: Mutex<CheckpointManager>,
22}
23
24impl SmartEditTool {
25    pub fn new(workspace: PathBuf) -> Self {
26        let checkpoint_mgr = CheckpointManager::new(workspace.clone());
27        Self {
28            workspace,
29            checkpoint_mgr: Mutex::new(checkpoint_mgr),
30        }
31    }
32}
33
/// The kinds of modification `smart_edit` can perform at a located span.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EditType {
    /// Swap the located span for the new text.
    Replace,
    /// Place the new text immediately after the located span.
    InsertAfter,
    /// Place the new text immediately before the located span.
    InsertBefore,
    /// Remove the located span entirely.
    Delete,
}

impl EditType {
    /// Parse an edit-type name, case-insensitively.
    ///
    /// The insert variants accept either `_` or `-` as the word separator, and
    /// `remove` is an alias for `delete`. Anything else yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        // Only the insert variants contain a separator, so folding '-' into '_'
        // accepts exactly the hyphenated spellings and nothing more.
        let canonical = s.to_lowercase().replace('-', "_");
        let parsed = match canonical.as_str() {
            "replace" => Self::Replace,
            "insert_after" => Self::InsertAfter,
            "insert_before" => Self::InsertBefore,
            "delete" | "remove" => Self::Delete,
            _ => return None,
        };
        Some(parsed)
    }
}
58
/// A span of file content resolved from a `location` description.
#[derive(Debug)]
// `context_preview` is filled in by every matcher, but nothing visible in this
// module reads it back, hence the lint allowance.
#[allow(dead_code)]
struct LocationMatch {
    /// Byte offset where the span begins (inclusive).
    start: usize,
    /// Byte offset where the span ends (exclusive).
    end: usize,
    /// Copy of the text covered by `start..end`.
    matched_text: String,
    /// 1-based number of the line containing `start`.
    line_number: usize,
    /// Numbered source lines around the span, for display.
    context_preview: String,
}
74
75/// Find a location in file content using a search pattern.
76/// Supports exact text, line-number patterns ("line 42"), and fuzzy substring matching.
77fn find_location(content: &str, pattern: &str) -> Result<LocationMatch, String> {
78    // Strategy 1: Try exact match first
79    if let Some(start) = content.find(pattern) {
80        let end = start + pattern.len();
81        let line_number = content[..start].matches('\n').count() + 1;
82        let preview = extract_context(content, start, end, 2);
83        return Ok(LocationMatch {
84            start,
85            end,
86            matched_text: pattern.to_string(),
87            line_number,
88            context_preview: preview,
89        });
90    }
91
92    // Strategy 2: Line number pattern (e.g., "line 42", "lines 10-20")
93    if let Some(m) = parse_line_pattern(pattern) {
94        return find_by_line_range(content, m.0, m.1);
95    }
96
97    // Strategy 3: Function/method name pattern (e.g., "fn handle_request", "def process")
98    if let Some(m) = find_by_function_pattern(content, pattern) {
99        return Ok(m);
100    }
101
102    // Strategy 4: Fuzzy line-by-line matching using similarity scoring
103    if let Some(m) = find_by_fuzzy_match(content, pattern) {
104        return Ok(m);
105    }
106
107    Err(format!(
108        "Could not locate '{}' in the file. Try using exact text, a line number (e.g., 'line 42'), or a function name.",
109        truncate(pattern, 80)
110    ))
111}
112
/// Recognize `line N` (→ `(N, N)`) and `lines N-M` (→ `(N, M)`).
///
/// Matching is case-insensitive and tolerant of surrounding whitespace and of
/// spaces around the dash. Anything else, including non-numeric or extra
/// dash-separated parts, yields `None`.
fn parse_line_pattern(pattern: &str) -> Option<(usize, usize)> {
    let normalized = pattern.trim().to_lowercase();
    let number = |text: &str| text.trim().parse::<usize>().ok();

    // "line 42" — the two prefixes are mutually exclusive ("lines " never
    // starts with "line ").
    if let Some(single) = normalized.strip_prefix("line ") {
        return number(single).map(|n| (n, n));
    }

    // "lines 10-20"
    let range = normalized.strip_prefix("lines ")?;
    let (from, to) = range.split_once('-')?;
    Some((number(from)?, number(to)?))
}
139
140/// Find a location by line range.
141fn find_by_line_range(
142    content: &str,
143    start_line: usize,
144    end_line: usize,
145) -> Result<LocationMatch, String> {
146    let lines: Vec<&str> = content.lines().collect();
147    let total = lines.len();
148
149    if start_line == 0 || start_line > total {
150        return Err(format!(
151            "Line {} is out of range (file has {} lines)",
152            start_line, total
153        ));
154    }
155
156    let end_line = end_line.min(total);
157    let start_idx = start_line - 1;
158    let end_idx = end_line;
159
160    // Calculate byte offsets
161    let mut byte_offset = 0;
162    let mut start_byte = 0;
163    let mut end_byte = content.len();
164
165    for (i, line) in content.lines().enumerate() {
166        if i == start_idx {
167            start_byte = byte_offset;
168        }
169        byte_offset += line.len() + 1; // +1 for newline
170        if i + 1 == end_idx {
171            end_byte = byte_offset.min(content.len());
172        }
173    }
174
175    let matched = &content[start_byte..end_byte];
176    let preview = extract_context(content, start_byte, end_byte, 1);
177
178    Ok(LocationMatch {
179        start: start_byte,
180        end: end_byte,
181        matched_text: matched.to_string(),
182        line_number: start_line,
183        context_preview: preview,
184    })
185}
186
187/// Find a function or block by pattern matching common language constructs.
188fn find_by_function_pattern(content: &str, pattern: &str) -> Option<LocationMatch> {
189    let pattern_lower = pattern.to_lowercase();
190
191    // Common function signature prefixes
192    let fn_prefixes = [
193        "fn ",
194        "def ",
195        "func ",
196        "function ",
197        "pub fn ",
198        "async fn ",
199        "pub async fn ",
200        "impl ",
201        "class ",
202        "struct ",
203        "enum ",
204    ];
205
206    // Check if pattern looks like a function reference
207    let is_fn_pattern = fn_prefixes.iter().any(|p| pattern_lower.starts_with(p))
208        || pattern_lower.starts_with("the ")
209        || pattern_lower.contains(" function")
210        || pattern_lower.contains(" method");
211
212    if !is_fn_pattern {
213        return None;
214    }
215
216    // Extract function name from pattern
217    let name = extract_identifier_from_pattern(&pattern_lower);
218    if name.is_empty() {
219        return None;
220    }
221
222    // Search for function definition containing this name
223    for (i, line) in content.lines().enumerate() {
224        let line_lower = line.to_lowercase();
225        let has_fn_keyword = fn_prefixes.iter().any(|p| line_lower.contains(p));
226
227        if has_fn_keyword && line_lower.contains(&name) {
228            let byte_start: usize = content.lines().take(i).map(|l| l.len() + 1).sum();
229
230            // Find the end of the function block (matching braces or indentation)
231            let block_end = find_block_end(content, byte_start);
232
233            let matched = &content[byte_start..block_end];
234            let preview = extract_context(content, byte_start, block_end, 0);
235
236            return Some(LocationMatch {
237                start: byte_start,
238                end: block_end,
239                matched_text: matched.to_string(),
240                line_number: i + 1,
241                context_preview: preview,
242            });
243        }
244    }
245
246    None
247}
248
/// Extract the most likely identifier from a natural-language `pattern`.
///
/// Words are split on whitespace and stripped of surrounding punctuation
/// (keeping `_`). Language keywords, articles and description fillers
/// (`the`, `method`, `that`, `called`, ...) are skipped as whole words. This
/// never uses substring replacement, so identifiers such as `breathe` or
/// `recalled` survive intact. The first remaining word of at least two bytes
/// that looks like an identifier is returned. Otherwise the last such word is
/// returned, or `""` if there is none.
fn extract_identifier_from_pattern(pattern: &str) -> String {
    // Common language keywords and articles to skip.
    const KEYWORDS: &[&str] = &[
        "fn",
        "def",
        "func",
        "function",
        "pub",
        "async",
        "impl",
        "class",
        "struct",
        "enum",
        "let",
        "const",
        "var",
        "type",
        "trait",
        "interface",
        "the",
        "a",
        "an",
        "in",
        "of",
        "for",
        "with",
        "from",
        "to",
    ];
    // Words that only connect a natural-language description
    // ("the method called foo", "the function that parses").
    const FILLER: &[&str] = &["method", "that", "called"];

    let is_candidate =
        |w: &str| w.len() >= 2 && !KEYWORDS.contains(&w) && !FILLER.contains(&w);
    let looks_like_identifier = |w: &str| {
        w.contains('_')
            || w.chars().any(char::is_uppercase)
            || w.chars().all(|c| c.is_alphanumeric() || c == '_')
    };

    let words: Vec<&str> = pattern
        .split_whitespace()
        .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric() && c != '_'))
        .collect();

    // Prefer the first word that looks like a snake_case/camelCase identifier.
    if let Some(w) = words
        .iter()
        .copied()
        .find(|&w| is_candidate(w) && looks_like_identifier(w))
    {
        return w.to_string();
    }

    // Fall back to the last significant word (e.g. `foo::bar` style tokens).
    words
        .iter()
        .copied()
        .rev()
        .find(|&w| is_candidate(w))
        .unwrap_or("")
        .to_string()
}
310
/// Find the byte offset just past the code block that begins at `start`.
///
/// If the first line contains `{`, braces are counted from `start`. The
/// result is just past the `}` that brings the depth back to zero, or
/// `content.len()` if none does. Braces inside strings/comments are not
/// special-cased. Otherwise indentation decides: the block extends over
/// following lines that are blank or indented deeper than the first line.
/// Offsets include real line terminators, so `\r\n` input is handled exactly.
fn find_block_end(content: &str, start: usize) -> usize {
    /// Text of a `split_inclusive('\n')` segment without its `\n` / `\r\n`.
    fn without_eol(segment: &str) -> &str {
        match segment.strip_suffix('\n') {
            Some(line) => line.strip_suffix('\r').unwrap_or(line),
            None => segment,
        }
    }
    let indent_of = |line: &str| line.len() - line.trim_start().len();

    let rest = &content[start..];
    let mut segments = rest.split_inclusive('\n');
    let Some(first_segment) = segments.next() else {
        return content.len();
    };

    if first_segment.contains('{') {
        // Brace matching.
        let mut depth: i32 = 0;
        for (offset, ch) in rest.char_indices() {
            match ch {
                '{' => depth += 1,
                '}' => {
                    depth -= 1;
                    if depth == 0 {
                        return start + offset + ch.len_utf8();
                    }
                }
                _ => {}
            }
        }
        return content.len();
    }

    // Indentation-based: stop at the first non-blank line that is not
    // indented deeper than the block's first line.
    let base_indent = indent_of(without_eol(first_segment));
    let mut end = start + first_segment.len();
    for segment in segments {
        let line = without_eol(segment);
        if !line.trim().is_empty() && indent_of(line) <= base_indent {
            break;
        }
        end += segment.len();
    }

    end.min(content.len())
}
363
364/// Fuzzy match by finding the line with highest similarity to the pattern.
365/// Uses both bigram similarity and word containment for better matching.
366fn find_by_fuzzy_match(content: &str, pattern: &str) -> Option<LocationMatch> {
367    let pattern_lower = pattern.to_lowercase();
368    let pattern_words: Vec<&str> = pattern_lower.split_whitespace().collect();
369    let mut best_score = 0.0_f64;
370    let mut best_line_idx = None;
371
372    for (i, line) in content.lines().enumerate() {
373        let line_lower = line.to_lowercase();
374        let line_trimmed = line_lower.trim();
375        if line_trimmed.is_empty() {
376            continue;
377        }
378
379        // Combined score: bigram similarity + word containment bonus
380        let bigram_score = similarity_score(line_trimmed, &pattern_lower);
381        let word_score = pattern_words
382            .iter()
383            .filter(|w| line_trimmed.contains(**w))
384            .count() as f64
385            / pattern_words.len().max(1) as f64;
386
387        let score = (bigram_score + word_score) / 2.0;
388        if score > best_score && score > 0.25 {
389            best_score = score;
390            best_line_idx = Some(i);
391        }
392    }
393
394    let line_idx = best_line_idx?;
395    let byte_start: usize = content.lines().take(line_idx).map(|l| l.len() + 1).sum();
396    let line_text = content.lines().nth(line_idx)?;
397    let byte_end = byte_start + line_text.len();
398    let preview = extract_context(content, byte_start, byte_end, 2);
399
400    Some(LocationMatch {
401        start: byte_start,
402        end: byte_end,
403        matched_text: line_text.to_string(),
404        line_number: line_idx + 1,
405        context_preview: preview,
406    })
407}
408
/// Jaccard similarity of the character-bigram sets of `a` and `b`, in `[0, 1]`.
///
/// Returns `0.0` when either string has fewer than two characters (and hence
/// no bigrams).
fn similarity_score(a: &str, b: &str) -> f64 {
    fn bigram_set(text: &str) -> std::collections::HashSet<(char, char)> {
        let chars: Vec<char> = text.chars().collect();
        chars.windows(2).map(|pair| (pair[0], pair[1])).collect()
    }

    let (left, right) = (bigram_set(a), bigram_set(b));
    if left.is_empty() || right.is_empty() {
        return 0.0;
    }

    let shared = left.iter().filter(|bg| right.contains(*bg)).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let combined = left.len() + right.len() - shared;
    shared as f64 / combined as f64
}
429
/// Render the lines touched by the byte range `start..end`, plus
/// `context_lines` lines on each side, as `"NNNN | text"` rows (1-based line
/// numbers, right-aligned to width 4) joined by `\n`.
///
/// `start` and `end` must be char boundaries of `content`.
fn extract_context(content: &str, start: usize, end: usize, context_lines: usize) -> String {
    let lines: Vec<&str> = content.lines().collect();

    // 0-based index of the line containing `start`: the number of newlines
    // strictly before it. (Counting `lines()` of the prefix is one too low when
    // `start` sits at a line start, which pulled in an extra preceding line.)
    let start_line = content[..start].matches('\n').count();
    // Exclusive index past the last line touched; always include the start line.
    let end_line = content[..end].lines().count().max(start_line + 1);

    let from = start_line.saturating_sub(context_lines).min(lines.len());
    let to = (end_line + context_lines).min(lines.len());

    lines[from..to]
        .iter()
        .enumerate()
        .map(|(i, line)| format!("{:4} | {}", from + i + 1, line))
        .collect::<Vec<_>>()
        .join("\n")
}
446
447/// Generate a unified diff between old and new content.
448fn generate_diff(path: &str, old: &str, new: &str) -> String {
449    let diff = TextDiff::from_lines(old, new);
450    let mut output = String::new();
451
452    output.push_str(&format!("--- a/{}\n", path));
453    output.push_str(&format!("+++ b/{}\n", path));
454
455    for hunk in diff.unified_diff().context_radius(3).iter_hunks() {
456        output.push_str(&format!("{}", hunk));
457    }
458
459    output
460}
461
462/// Apply an edit operation to content.
463fn apply_edit(
464    content: &str,
465    location: &LocationMatch,
466    edit_type: EditType,
467    new_text: &str,
468) -> String {
469    match edit_type {
470        EditType::Replace => {
471            let mut result = String::with_capacity(content.len());
472            result.push_str(&content[..location.start]);
473            result.push_str(new_text);
474            result.push_str(&content[location.end..]);
475            result
476        }
477        EditType::InsertAfter => {
478            let mut result = String::with_capacity(content.len() + new_text.len());
479            result.push_str(&content[..location.end]);
480            if !new_text.starts_with('\n') && !content[..location.end].ends_with('\n') {
481                result.push('\n');
482            }
483            result.push_str(new_text);
484            result.push_str(&content[location.end..]);
485            result
486        }
487        EditType::InsertBefore => {
488            let mut result = String::with_capacity(content.len() + new_text.len());
489            result.push_str(&content[..location.start]);
490            result.push_str(new_text);
491            if !new_text.ends_with('\n') && !content[location.start..].starts_with('\n') {
492                result.push('\n');
493            }
494            result.push_str(&content[location.start..]);
495            result
496        }
497        EditType::Delete => {
498            let mut result = String::with_capacity(content.len());
499            result.push_str(&content[..location.start]);
500            result.push_str(&content[location.end..]);
501            result
502        }
503    }
504}
505
/// Shorten `s` for display, marking the cut with `...`.
///
/// Strings of at most `max_len` bytes are returned unchanged. Otherwise the
/// first `max_len - 3` bytes are kept, moved back to the nearest UTF-8
/// character boundary (slicing a `str` mid-character would panic), and `...`
/// is appended.
fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    let mut cut = max_len.saturating_sub(3);
    // Terminates: index 0 is always a char boundary.
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
514
515/// Validate that a path stays inside the workspace.
516fn validate_workspace_path(workspace: &Path, path_str: &str) -> Result<PathBuf, ToolError> {
517    let workspace_canonical = workspace
518        .canonicalize()
519        .unwrap_or_else(|_| workspace.to_path_buf());
520
521    let resolved = if Path::new(path_str).is_absolute() {
522        PathBuf::from(path_str)
523    } else {
524        workspace_canonical.join(path_str)
525    };
526
527    if resolved.exists() {
528        let canonical = resolved
529            .canonicalize()
530            .map_err(|e| ToolError::ExecutionFailed {
531                name: "smart_edit".into(),
532                message: format!("Path resolution failed: {}", e),
533            })?;
534
535        if !canonical.starts_with(&workspace_canonical) {
536            return Err(ToolError::PermissionDenied {
537                name: "smart_edit".into(),
538                reason: format!("Path '{}' is outside the workspace", path_str),
539            });
540        }
541        return Ok(canonical);
542    }
543
544    // Non-existent path: normalize components
545    let mut normalized = Vec::new();
546    for component in resolved.components() {
547        match component {
548            std::path::Component::ParentDir => {
549                if normalized.pop().is_none() {
550                    return Err(ToolError::PermissionDenied {
551                        name: "smart_edit".into(),
552                        reason: format!("Path '{}' escapes the workspace", path_str),
553                    });
554                }
555            }
556            std::path::Component::CurDir => {}
557            other => normalized.push(other),
558        }
559    }
560    let normalized_path: PathBuf = normalized.iter().collect();
561
562    if !normalized_path.starts_with(&workspace_canonical) {
563        return Err(ToolError::PermissionDenied {
564            name: "smart_edit".into(),
565            reason: format!("Path '{}' is outside the workspace", path_str),
566        });
567    }
568
569    Ok(resolved)
570}
571
572#[async_trait]
573impl Tool for SmartEditTool {
574    fn name(&self) -> &str {
575        "smart_edit"
576    }
577
578    fn description(&self) -> &str {
579        "Smart code editor that accepts fuzzy location descriptions (function names, \
580         line numbers, search patterns) and edit types (replace, insert_after, \
581         insert_before, delete). Creates an auto-checkpoint before writing and \
582         returns a unified diff preview."
583    }
584
585    fn parameters_schema(&self) -> serde_json::Value {
586        serde_json::json!({
587            "type": "object",
588            "properties": {
589                "path": {
590                    "type": "string",
591                    "description": "Path to the file to edit (relative to workspace)"
592                },
593                "location": {
594                    "type": "string",
595                    "description": "Where to apply the edit. Supports: exact text to match, \
596                        'line N' or 'lines N-M', function/method names (e.g. 'fn handle_request'), \
597                        or fuzzy descriptions."
598                },
599                "edit_type": {
600                    "type": "string",
601                    "enum": ["replace", "insert_after", "insert_before", "delete"],
602                    "description": "Type of edit to perform"
603                },
604                "new_text": {
605                    "type": "string",
606                    "description": "The new text (required for replace, insert_after, insert_before; \
607                        omit for delete)"
608                }
609            },
610            "required": ["path", "location", "edit_type"]
611        })
612    }
613
614    async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
615        let path_str = args["path"]
616            .as_str()
617            .ok_or_else(|| ToolError::InvalidArguments {
618                name: "smart_edit".into(),
619                reason: "'path' parameter is required".into(),
620            })?;
621
622        let location_str =
623            args["location"]
624                .as_str()
625                .ok_or_else(|| ToolError::InvalidArguments {
626                    name: "smart_edit".into(),
627                    reason: "'location' parameter is required".into(),
628                })?;
629
630        let edit_type_str =
631            args["edit_type"]
632                .as_str()
633                .ok_or_else(|| ToolError::InvalidArguments {
634                    name: "smart_edit".into(),
635                    reason: "'edit_type' parameter is required".into(),
636                })?;
637
638        let edit_type = EditType::from_str(edit_type_str).ok_or_else(|| {
639            ToolError::InvalidArguments {
640                name: "smart_edit".into(),
641                reason: format!(
642                    "Invalid edit_type '{}'. Must be one of: replace, insert_after, insert_before, delete",
643                    edit_type_str
644                ),
645            }
646        })?;
647
648        let new_text = args["new_text"].as_str().unwrap_or("");
649
650        if edit_type != EditType::Delete && new_text.is_empty() {
651            return Err(ToolError::InvalidArguments {
652                name: "smart_edit".into(),
653                reason: "'new_text' is required for replace and insert operations".into(),
654            });
655        }
656
657        // Validate path
658        let _ = validate_workspace_path(&self.workspace, path_str)?;
659        let path = self.workspace.join(path_str);
660
661        // Read file
662        let content =
663            tokio::fs::read_to_string(&path)
664                .await
665                .map_err(|e| ToolError::ExecutionFailed {
666                    name: "smart_edit".into(),
667                    message: format!("Failed to read '{}': {}", path_str, e),
668                })?;
669
670        // Find the location
671        let location =
672            find_location(&content, location_str).map_err(|e| ToolError::ExecutionFailed {
673                name: "smart_edit".into(),
674                message: e,
675            })?;
676
677        debug!(
678            "smart_edit: matched at line {} ({} bytes)",
679            location.line_number,
680            location.matched_text.len()
681        );
682
683        // Apply the edit
684        let new_content = apply_edit(&content, &location, edit_type, new_text);
685
686        // Generate diff
687        let diff = generate_diff(path_str, &content, &new_content);
688
689        // Create checkpoint before writing
690        let checkpoint_result = {
691            let mut mgr = self.checkpoint_mgr.lock().await;
692            mgr.create_checkpoint(&format!("before smart_edit on {}", path_str))
693        };
694
695        if let Err(e) = &checkpoint_result {
696            debug!("Checkpoint creation failed (non-fatal): {}", e);
697        }
698
699        // Write the file
700        tokio::fs::write(&path, &new_content)
701            .await
702            .map_err(|e| ToolError::ExecutionFailed {
703                name: "smart_edit".into(),
704                message: format!("Failed to write '{}': {}", path_str, e),
705            })?;
706
707        // Build output
708        let edit_desc = match edit_type {
709            EditType::Replace => "replaced",
710            EditType::InsertAfter => "inserted after",
711            EditType::InsertBefore => "inserted before",
712            EditType::Delete => "deleted",
713        };
714
715        let checkpoint_note = if checkpoint_result.is_ok() {
716            " (checkpoint created, use /undo to revert)"
717        } else {
718            ""
719        };
720
721        let summary = format!(
722            "Edited '{}': {} at line {}{}\n\nDiff:\n{}",
723            path_str, edit_desc, location.line_number, checkpoint_note, diff
724        );
725
726        let mut output = ToolOutput::text(summary);
727        output.artifacts.push(Artifact::FileModified {
728            path: PathBuf::from(path_str),
729            diff,
730        });
731
732        Ok(output)
733    }
734
735    fn risk_level(&self) -> RiskLevel {
736        RiskLevel::Write
737    }
738}
739
740#[cfg(test)]
741mod tests {
742    use super::*;
743    use std::fs;
744    use tempfile::TempDir;
745
746    #[test]
747    fn test_find_location_exact() {
748        let content = "fn main() {\n    println!(\"hello\");\n}\n";
749        let loc = find_location(content, "println!(\"hello\")").unwrap();
750        assert_eq!(loc.line_number, 2);
751        assert_eq!(loc.matched_text, "println!(\"hello\")");
752    }
753
754    #[test]
755    fn test_find_location_line_number() {
756        let content = "line one\nline two\nline three\n";
757        let loc = find_location(content, "line 2").unwrap();
758        assert_eq!(loc.line_number, 2);
759        assert!(loc.matched_text.contains("line two"));
760    }
761
762    #[test]
763    fn test_find_location_line_range() {
764        let content = "a\nb\nc\nd\ne\n";
765        let loc = find_location(content, "lines 2-4").unwrap();
766        assert_eq!(loc.line_number, 2);
767        assert!(loc.matched_text.contains('b'));
768        assert!(loc.matched_text.contains('c'));
769        assert!(loc.matched_text.contains('d'));
770    }
771
772    #[test]
773    fn test_find_location_function_pattern() {
774        let content = "use std::io;\n\nfn handle_request(req: Request) {\n    process(req);\n}\n\nfn main() {}\n";
775        let loc = find_location(content, "fn handle_request").unwrap();
776        assert_eq!(loc.line_number, 3);
777        assert!(loc.matched_text.contains("handle_request"));
778    }
779
780    #[test]
781    fn test_find_location_fuzzy() {
782        let content = "struct Config {\n    timeout: u64,\n    retries: usize,\n}\n";
783        let loc = find_location(content, "timeout field").unwrap();
784        assert!(loc.matched_text.contains("timeout"));
785    }
786
787    #[test]
788    fn test_find_location_not_found() {
789        let content = "hello world\n";
790        let result = find_location(content, "nonexistent_xyz_123");
791        assert!(result.is_err());
792    }
793
794    #[test]
795    fn test_apply_edit_replace() {
796        let content = "fn old_name() {}\n";
797        let loc = find_location(content, "old_name").unwrap();
798        let result = apply_edit(content, &loc, EditType::Replace, "new_name");
799        assert!(result.contains("new_name"));
800        assert!(!result.contains("old_name"));
801    }
802
803    #[test]
804    fn test_apply_edit_insert_after() {
805        let content = "use std::io;\n\nfn main() {}\n";
806        let loc = find_location(content, "use std::io;").unwrap();
807        let result = apply_edit(content, &loc, EditType::InsertAfter, "use std::fs;");
808        assert!(result.contains("use std::io;\nuse std::fs;"));
809    }
810
811    #[test]
812    fn test_apply_edit_insert_before() {
813        let content = "fn main() {}\n";
814        let loc = find_location(content, "fn main").unwrap();
815        let result = apply_edit(content, &loc, EditType::InsertBefore, "// Entry point\n");
816        assert!(result.starts_with("// Entry point\n"));
817    }
818
819    #[test]
820    fn test_apply_edit_delete() {
821        let content = "line1\nline2\nline3\n";
822        let loc = find_location(content, "line2").unwrap();
823        let result = apply_edit(content, &loc, EditType::Delete, "");
824        assert!(!result.contains("line2"));
825        assert!(result.contains("line1"));
826        assert!(result.contains("line3"));
827    }
828
829    #[test]
830    fn test_generate_diff() {
831        let old = "line1\nline2\nline3\n";
832        let new = "line1\nmodified\nline3\n";
833        let diff = generate_diff("test.rs", old, new);
834        assert!(diff.contains("--- a/test.rs"));
835        assert!(diff.contains("+++ b/test.rs"));
836        assert!(diff.contains("-line2"));
837        assert!(diff.contains("+modified"));
838    }
839
840    #[test]
841    fn test_similarity_score() {
842        let a = "handle_request";
843        let b = "handle_request";
844        assert!((similarity_score(a, b) - 1.0).abs() < 0.01);
845
846        let c = "handle_response";
847        let score = similarity_score(a, c);
848        assert!(score > 0.3); // Similar but not identical
849
850        let d = "totally_different_thing";
851        let score2 = similarity_score(a, d);
852        assert!(score2 < score); // Less similar
853    }
854
855    #[test]
856    fn test_edit_type_from_str() {
857        assert_eq!(EditType::from_str("replace"), Some(EditType::Replace));
858        assert_eq!(
859            EditType::from_str("insert_after"),
860            Some(EditType::InsertAfter)
861        );
862        assert_eq!(
863            EditType::from_str("insert-before"),
864            Some(EditType::InsertBefore)
865        );
866        assert_eq!(EditType::from_str("delete"), Some(EditType::Delete));
867        assert_eq!(EditType::from_str("remove"), Some(EditType::Delete));
868        assert_eq!(EditType::from_str("unknown"), None);
869    }
870
871    #[test]
872    fn test_parse_line_pattern() {
873        assert_eq!(parse_line_pattern("line 42"), Some((42, 42)));
874        assert_eq!(parse_line_pattern("lines 10-20"), Some((10, 20)));
875        assert_eq!(parse_line_pattern("not a line pattern"), None);
876    }
877
878    #[test]
879    fn test_extract_identifier() {
880        assert_eq!(
881            extract_identifier_from_pattern("fn handle_request"),
882            "handle_request"
883        );
884        assert_eq!(
885            extract_identifier_from_pattern("the process_data function"),
886            "process_data"
887        );
888    }
889
890    #[test]
891    fn test_find_block_end_braces() {
892        let content = "fn foo() {\n    bar();\n    baz();\n}\nfn next() {}";
893        let end = find_block_end(content, 0);
894        let block = &content[0..end];
895        assert!(block.contains("baz();"));
896        assert!(block.ends_with('}'));
897    }
898
899    #[test]
900    fn test_truncate() {
901        assert_eq!(truncate("short", 10), "short");
902        assert_eq!(truncate("a long string here", 10), "a long ...");
903    }
904
905    #[tokio::test]
906    async fn test_smart_edit_tool_execute_replace() {
907        let dir = TempDir::new().unwrap();
908        let workspace = dir.path().to_path_buf();
909
910        // Initialize git repo for checkpoint
911        git2::Repository::init(&workspace).unwrap();
912
913        // Create a test file
914        fs::write(
915            workspace.join("test.rs"),
916            "fn old_name() {\n    // body\n}\n",
917        )
918        .unwrap();
919
920        // Initial commit so checkpoint works
921        let repo = git2::Repository::open(&workspace).unwrap();
922        let mut index = repo.index().unwrap();
923        index
924            .add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)
925            .unwrap();
926        index.write().unwrap();
927        let tree_oid = index.write_tree().unwrap();
928        let tree = repo.find_tree(tree_oid).unwrap();
929        let sig = git2::Signature::now("test", "test@test.com").unwrap();
930        repo.commit(Some("HEAD"), &sig, &sig, "init", &tree, &[])
931            .unwrap();
932
933        let tool = SmartEditTool::new(workspace.clone());
934
935        let args = serde_json::json!({
936            "path": "test.rs",
937            "location": "old_name",
938            "edit_type": "replace",
939            "new_text": "new_name"
940        });
941
942        let result = tool.execute(args).await.unwrap();
943        assert!(result.content.contains("Edited"));
944        assert!(result.content.contains("replaced"));
945
946        // Verify file was modified
947        let content = fs::read_to_string(workspace.join("test.rs")).unwrap();
948        assert!(content.contains("new_name"));
949        assert!(!content.contains("old_name"));
950    }
951
952    #[tokio::test]
953    async fn test_smart_edit_tool_execute_delete() {
954        let dir = TempDir::new().unwrap();
955        let workspace = dir.path().to_path_buf();
956
957        git2::Repository::init(&workspace).unwrap();
958        fs::write(
959            workspace.join("test.txt"),
960            "keep this\ndelete this line\nkeep this too\n",
961        )
962        .unwrap();
963
964        let repo = git2::Repository::open(&workspace).unwrap();
965        let mut index = repo.index().unwrap();
966        index
967            .add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)
968            .unwrap();
969        index.write().unwrap();
970        let tree_oid = index.write_tree().unwrap();
971        let tree = repo.find_tree(tree_oid).unwrap();
972        let sig = git2::Signature::now("test", "test@test.com").unwrap();
973        repo.commit(Some("HEAD"), &sig, &sig, "init", &tree, &[])
974            .unwrap();
975
976        let tool = SmartEditTool::new(workspace.clone());
977
978        let args = serde_json::json!({
979            "path": "test.txt",
980            "location": "delete this line",
981            "edit_type": "delete"
982        });
983
984        let result = tool.execute(args).await.unwrap();
985        assert!(result.content.contains("deleted"));
986
987        let content = fs::read_to_string(workspace.join("test.txt")).unwrap();
988        assert!(!content.contains("delete this line"));
989        assert!(content.contains("keep this"));
990    }
991
992    #[tokio::test]
993    async fn test_smart_edit_tool_line_number() {
994        let dir = TempDir::new().unwrap();
995        let workspace = dir.path().to_path_buf();
996
997        git2::Repository::init(&workspace).unwrap();
998        fs::write(workspace.join("test.txt"), "line 1\nline 2\nline 3\n").unwrap();
999
1000        let repo = git2::Repository::open(&workspace).unwrap();
1001        let mut index = repo.index().unwrap();
1002        index
1003            .add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)
1004            .unwrap();
1005        index.write().unwrap();
1006        let tree_oid = index.write_tree().unwrap();
1007        let tree = repo.find_tree(tree_oid).unwrap();
1008        let sig = git2::Signature::now("test", "test@test.com").unwrap();
1009        repo.commit(Some("HEAD"), &sig, &sig, "init", &tree, &[])
1010            .unwrap();
1011
1012        let tool = SmartEditTool::new(workspace.clone());
1013
1014        let args = serde_json::json!({
1015            "path": "test.txt",
1016            "location": "line 2",
1017            "edit_type": "replace",
1018            "new_text": "replaced line\n"
1019        });
1020
1021        let result = tool.execute(args).await.unwrap();
1022        assert!(result.content.contains("replaced"));
1023
1024        let content = fs::read_to_string(workspace.join("test.txt")).unwrap();
1025        assert!(content.contains("replaced line"));
1026        assert!(!content.contains("line 2"));
1027    }
1028
1029    #[tokio::test]
1030    async fn test_smart_edit_tool_missing_new_text() {
1031        let dir = TempDir::new().unwrap();
1032        let workspace = dir.path().to_path_buf();
1033        let tool = SmartEditTool::new(workspace);
1034
1035        let args = serde_json::json!({
1036            "path": "test.txt",
1037            "location": "something",
1038            "edit_type": "replace"
1039        });
1040
1041        let result = tool.execute(args).await;
1042        assert!(result.is_err());
1043    }
1044
1045    #[tokio::test]
1046    async fn test_smart_edit_tool_invalid_edit_type() {
1047        let dir = TempDir::new().unwrap();
1048        let workspace = dir.path().to_path_buf();
1049        let tool = SmartEditTool::new(workspace);
1050
1051        let args = serde_json::json!({
1052            "path": "test.txt",
1053            "location": "something",
1054            "edit_type": "invalid_op",
1055            "new_text": "hello"
1056        });
1057
1058        let result = tool.execute(args).await;
1059        assert!(result.is_err());
1060    }
1061}