1use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use std::path::Path;
10use tree_sitter::{Node, Parser};
11
/// Source languages the smart editor can parse with tree-sitter.
///
/// Each variant is bound to one bundled grammar in `get_parser`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SupportedLanguage {
    Rust,
    Python,
    JavaScript,
    /// Parsed with the plain TypeScript grammar (`.tsx` files included).
    TypeScript,
    Go,
    Java,
    CSharp,
    /// C++ sources and headers; `.h` files are also routed here.
    Cpp,
    Ruby,
}
25
26impl SupportedLanguage {
27 fn from_extension(ext: &str) -> Option<Self> {
28 match ext {
29 "rs" => Some(Self::Rust),
30 "py" => Some(Self::Python),
31 "js" | "mjs" => Some(Self::JavaScript),
32 "ts" | "tsx" => Some(Self::TypeScript),
33 "go" => Some(Self::Go),
34 "java" => Some(Self::Java),
35 "cs" => Some(Self::CSharp),
36 "cpp" | "cc" | "cxx" | "hpp" | "h" => Some(Self::Cpp),
37 "rb" => Some(Self::Ruby),
38 _ => None,
39 }
40 }
41
42 fn get_parser(&self) -> Result<Parser> {
43 use tree_sitter_language::LanguageFn;
44
45 let mut parser = Parser::new();
46 let language_fn: LanguageFn = match self {
47 Self::Rust => tree_sitter_rust::LANGUAGE,
48 Self::Python => tree_sitter_python::LANGUAGE,
49 Self::JavaScript => tree_sitter_javascript::LANGUAGE,
50 Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT,
51 Self::Go => tree_sitter_go::LANGUAGE,
52 Self::Java => tree_sitter_java::LANGUAGE,
53 Self::CSharp => tree_sitter_c_sharp::LANGUAGE,
54 Self::Cpp => tree_sitter_cpp::LANGUAGE,
55 Self::Ruby => tree_sitter_ruby::LANGUAGE,
56 };
57 let language = language_fn.into();
58 parser.set_language(&language)?;
59 Ok(parser)
60 }
61}
62
/// A single structure-aware edit, deserialized from the `edits` array of a
/// smart-edit request.
///
/// Internally tagged: every JSON object carries an `operation` field that
/// selects the variant. `SmartEditor::apply_edit` currently handles
/// `InsertFunction`, `ReplaceFunction`, `AddImport`, `SmartAppend` and
/// `RemoveFunction`. Every other variant is rejected there with
/// "Operation not yet implemented".
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "operation")]
pub enum SmartEdit {
    /// Insert a new function (or method, when `class_name` is set).
    InsertFunction {
        /// Name of the new function.
        name: String,
        /// Target class; also scopes the `after` / `before` lookups.
        #[serde(skip_serializing_if = "Option::is_none")]
        class_name: Option<String>,
        /// NOTE(review): accepted but ignored by `apply_edit` (matched with `..`).
        #[serde(skip_serializing_if = "Option::is_none")]
        namespace: Option<String>,
        /// Text placed directly after the name: parameters, return type and body.
        body: String,
        /// Place the new function after the function with this name.
        #[serde(skip_serializing_if = "Option::is_none")]
        after: Option<String>,
        /// Place it before the function with this name (only used without `after`).
        #[serde(skip_serializing_if = "Option::is_none")]
        before: Option<String>,
        /// Visibility keyword; an absent field deserializes to the empty string.
        #[serde(default)]
        visibility: String, },

    /// Replace the body of an existing function with `new_body`.
    ReplaceFunction {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        class_name: Option<String>,
        new_body: String,
    },

    /// Add an import / `use` statement, optionally aliased.
    AddImport {
        import: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        alias: Option<String>,
    },

    /// Insert a new class. Not yet implemented by `apply_edit`.
    InsertClass {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        namespace: Option<String>,
        body: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        extends: Option<String>,
        #[serde(default)]
        implements: Vec<String>,
    },

    /// Add a method to a class. Not yet implemented by `apply_edit`.
    AddMethod {
        class_name: String,
        method_name: String,
        body: String,
        #[serde(default)]
        visibility: String,
    },

    /// Wrap a line range in a construct such as a conditional.
    /// Not yet implemented by `apply_edit`; line numbering is presumably
    /// 1-based like `FunctionInfo` (TODO confirm).
    WrapCode {
        start_line: usize,
        end_line: usize,
        wrapper_type: String, #[serde(skip_serializing_if = "Option::is_none")]
        condition: Option<String>,
    },

    /// Delete a named element. Not yet implemented by `apply_edit`.
    DeleteElement {
        element_type: String, name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        parent: Option<String>,
    },

    /// Rename a symbol. Not yet implemented by `apply_edit`.
    Rename {
        old_name: String,
        new_name: String,
        #[serde(default)]
        scope: String, },

    /// Attach documentation to an element. Not yet implemented by `apply_edit`.
    AddDocumentation {
        target_type: String, target_name: String,
        documentation: String,
    },

    /// Append raw `content` to a named section: `"imports"`, `"functions"`,
    /// `"classes"` or `"main"`. Any other section appends at the end of the file.
    SmartAppend {
        section: String, content: String,
    },

    /// Remove a function.
    RemoveFunction {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        class_name: Option<String>,
        /// Remove even if other functions still call it.
        #[serde(default)]
        force: bool, #[serde(default)]
        /// Also remove callees whose only caller was the removed function.
        cascade: bool, },
}
168
/// Summary of one function or method found in the parse tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInfo {
    pub name: String,
    /// 1-based line on which the function node starts.
    pub start_line: usize,
    /// 1-based line on which the function node ends (inclusive).
    pub end_line: usize,
    /// Declaration text from the start of the node up to its body, trimmed.
    pub signature: String,
    /// Enclosing class, when the function was found inside one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_name: Option<String>,
    /// NOTE(review): `extract_function_info` always leaves this `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub namespace: Option<String>,
    /// Visibility text, e.g. `"pub"`, `"public"`, or the `"private"` fallback.
    pub visibility: String,
    /// Names of functions this one calls; `extract_function_info` leaves it empty.
    #[serde(default)]
    pub calls: Vec<String>,
    /// Names of functions calling this one; `extract_function_info` leaves it empty.
    #[serde(default)]
    pub called_by: Vec<String>,
}
186
/// Structural summary of one parsed source file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeStructure {
    /// `Debug` name of the [`SupportedLanguage`], e.g. `"Rust"`.
    pub language: String,
    /// Source text of each recognised import / `use` statement, in source order.
    pub imports: Vec<String>,
    /// Every recognised function, including methods (with `class_name` when known).
    pub functions: Vec<FunctionInfo>,
    pub classes: Vec<ClassInfo>,
    /// Set to `"main"` when a function with that name was found.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub main_function: Option<String>,
    /// Line count as reported by `str::lines` (a trailing newline adds no line).
    pub line_count: usize,
    /// Call graph between functions; an absent field deserializes as empty.
    #[serde(default)]
    pub dependencies: DependencyGraph,
}
200
/// Call relationships between functions, keyed by bare function name.
///
/// `remove_function` reads `called_by` to refuse removing a function that
/// still has callers, and `calls` to find cascade candidates.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DependencyGraph {
    /// Function name -> names of the functions it calls.
    pub calls: std::collections::HashMap<String, Vec<String>>,
    /// Function name -> names of the functions that call it.
    pub called_by: std::collections::HashMap<String, Vec<String>>,
}
209
/// Summary of a class-like definition (a class, or a Rust `struct`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassInfo {
    pub name: String,
    /// 1-based first line of the definition.
    pub start_line: usize,
    /// 1-based last line of the definition (inclusive).
    pub end_line: usize,
    /// NOTE(review): `extract_class_info` never fills this (always `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extends: Option<String>,
    /// NOTE(review): `extract_class_info` never fills this (always empty).
    #[serde(default)]
    pub implements: Vec<String>,
    /// Methods declared in the class body, as collected by `extract_methods`.
    pub methods: Vec<FunctionInfo>,
}
221
/// Structure-aware editor over the text of a single source file.
///
/// Holds the current text, a tree-sitter parser for `language`, the latest
/// parse tree and the structure derived from it. `apply_edit` rewrites
/// `content` and then re-parses and re-analyses it.
pub struct SmartEditor {
    /// Current source text; every edit replaces it wholesale.
    content: String,
    language: SupportedLanguage,
    parser: Parser,
    /// Latest parse tree; `None` when tree-sitter returned no tree.
    tree: Option<tree_sitter::Tree>,
    /// Result of the most recent `analyze_structure` run.
    structure: Option<CodeStructure>,
}
230
231impl SmartEditor {
232 pub fn new(content: String, language: SupportedLanguage) -> Result<Self> {
233 let mut parser = language.get_parser()?;
234 let tree = parser.parse(&content, None);
235
236 let mut editor = Self {
237 content,
238 language,
239 parser,
240 tree,
241 structure: None,
242 };
243
244 editor.analyze_structure()?;
245 Ok(editor)
246 }
247
248 fn analyze_structure(&mut self) -> Result<()> {
250 let tree = self.tree.as_ref().context("No parse tree available")?;
251 let root = tree.root_node();
252
253 let mut structure = CodeStructure {
254 language: format!("{:?}", self.language),
255 imports: Vec::new(),
256 functions: Vec::new(),
257 classes: Vec::new(),
258 main_function: None,
259 line_count: self.content.lines().count(),
260 dependencies: DependencyGraph::default(),
261 };
262
263 self.walk_node(&root, &mut structure, None)?;
265
266 self.structure = Some(structure);
267 Ok(())
268 }
269
270 fn walk_node(
271 &self,
272 node: &Node,
273 structure: &mut CodeStructure,
274 current_class: Option<&str>,
275 ) -> Result<()> {
276 match node.kind() {
277 "use_declaration" => {
279 if let Some(text) = self.node_text(node) {
280 structure.imports.push(text);
281 }
282 }
283 "function_item"
284 | "method_definition"
285 | "function_definition"
286 | "function_declaration" => {
287 if let Some(func_info) = self.extract_function_info(node, current_class) {
288 if func_info.name == "main" {
289 structure.main_function = Some(func_info.name.clone());
290 }
291 structure.functions.push(func_info);
292 }
293 }
294 "struct_item" | "class_definition" | "class_declaration" => {
295 if let Some(class_info) = self.extract_class_info(node) {
296 structure.classes.push(class_info);
297 }
298 }
299 "import_statement" | "import_from_statement" => {
301 if let Some(text) = self.node_text(node) {
302 structure.imports.push(text);
303 }
304 }
305 _ => {}
306 }
307
308 let class_name = match node.kind() {
310 "class_definition" | "class_declaration" => {
311 self.find_child_by_kind(node, "identifier")
313 .and_then(|n| self.node_text(&n))
314 }
315 _ => None,
316 };
317
318 let class_context = class_name.as_deref().or(current_class);
319
320 for child in node.children(&mut node.walk()) {
322 self.walk_node(&child, structure, class_context)?;
323 }
324
325 Ok(())
326 }
327
328 fn node_text(&self, node: &Node) -> Option<String> {
329 node.utf8_text(self.content.as_bytes())
330 .ok()
331 .map(|s| s.to_string())
332 }
333
334 fn extract_function_info(&self, node: &Node, class_name: Option<&str>) -> Option<FunctionInfo> {
335 let name = self
336 .find_child_by_kind(node, "identifier")
337 .or_else(|| self.find_child_by_kind(node, "property_identifier"))
338 .and_then(|n| self.node_text(&n))?;
339
340 let start_line = node.start_position().row + 1;
341 let end_line = node.end_position().row + 1;
342
343 let signature = self.extract_signature(node)?;
344
345 Some(FunctionInfo {
346 name,
347 start_line,
348 end_line,
349 signature,
350 class_name: class_name.map(String::from),
351 namespace: None, visibility: self.extract_visibility(node),
353 calls: Vec::new(), called_by: Vec::new(),
355 })
356 }
357
358 fn extract_class_info(&self, node: &Node) -> Option<ClassInfo> {
359 let name = self
360 .find_child_by_kind(node, "identifier")
361 .or_else(|| self.find_child_by_kind(node, "type_identifier"))
362 .and_then(|n| self.node_text(&n))?;
363
364 let start_line = node.start_position().row + 1;
365 let end_line = node.end_position().row + 1;
366
367 let mut methods = Vec::new();
368 self.extract_methods(node, &name, &mut methods);
369
370 Some(ClassInfo {
371 name,
372 start_line,
373 end_line,
374 extends: None, implements: Vec::new(),
376 methods,
377 })
378 }
379
380 fn extract_methods(&self, node: &Node, class_name: &str, methods: &mut Vec<FunctionInfo>) {
381 for child in node.children(&mut node.walk()) {
382 if matches!(child.kind(), "method_definition" | "function_item") {
383 if let Some(method_info) = self.extract_function_info(&child, Some(class_name)) {
384 methods.push(method_info);
385 }
386 } else if child.kind().contains("body") {
387 self.extract_methods(&child, class_name, methods);
388 }
389 }
390 }
391
392 fn find_child_by_kind<'a>(&self, node: &'a Node, kind: &str) -> Option<Node<'a>> {
393 node.children(&mut node.walk()).find(|n| n.kind() == kind)
394 }
395
396 fn extract_signature(&self, node: &Node) -> Option<String> {
397 let start = node.start_byte();
399 let body_start = self
400 .find_child_by_kind(node, "block")
401 .or_else(|| self.find_child_by_kind(node, "body"))
402 .map(|n| n.start_byte())
403 .unwrap_or(node.end_byte());
404
405 self.content
406 .as_bytes()
407 .get(start..body_start)
408 .and_then(|bytes| std::str::from_utf8(bytes).ok())
409 .map(|s| s.trim().to_string())
410 }
411
412 fn extract_visibility(&self, node: &Node) -> String {
413 for child in node.children(&mut node.walk()) {
415 match child.kind() {
416 "visibility_modifier" => {
417 if let Some(text) = self.node_text(&child) {
418 return text;
419 }
420 }
421 "pub" => return "public".to_string(),
422 "private" => return "private".to_string(),
423 "protected" => return "protected".to_string(),
424 _ => {}
425 }
426 }
427 "private".to_string() }
429
430 pub fn apply_edit(&mut self, edit: &SmartEdit) -> Result<String> {
432 match edit {
433 SmartEdit::InsertFunction {
434 name,
435 class_name,
436 body,
437 after,
438 before,
439 visibility,
440 ..
441 } => {
442 self.insert_function(
443 name,
444 class_name.as_deref(),
445 body,
446 after.as_deref(),
447 before.as_deref(),
448 visibility,
449 )?;
450 }
451 SmartEdit::ReplaceFunction {
452 name,
453 class_name,
454 new_body,
455 } => {
456 self.replace_function(name, class_name.as_deref(), new_body)?;
457 }
458 SmartEdit::AddImport { import, alias } => {
459 self.add_import(import, alias.as_deref())?;
460 }
461 SmartEdit::SmartAppend { section, content } => {
462 self.smart_append(section, content)?;
463 }
464 SmartEdit::RemoveFunction {
465 name,
466 class_name,
467 force,
468 cascade,
469 } => {
470 self.remove_function(name, class_name.as_deref(), *force, *cascade)?;
471 }
472 _ => {
473 return Err(anyhow::anyhow!("Operation not yet implemented"));
474 }
475 }
476
477 self.tree = self.parser.parse(&self.content, None);
479 self.analyze_structure()?;
480
481 Ok(self.content.clone())
482 }
483
484 fn insert_function(
485 &mut self,
486 name: &str,
487 class_name: Option<&str>,
488 body: &str,
489 after: Option<&str>,
490 before: Option<&str>,
491 visibility: &str,
492 ) -> Result<()> {
493 let structure = self.structure.as_ref().context("No structure analyzed")?;
494
495 let insert_line = if let Some(after_name) = after {
497 structure
499 .functions
500 .iter()
501 .find(|f| f.name == after_name && f.class_name.as_deref() == class_name)
502 .map(|f| f.end_line + 1)
503 .with_context(|| format!("Function not found: {}", after_name))?
504 } else if let Some(before_name) = before {
505 structure
507 .functions
508 .iter()
509 .find(|f| f.name == before_name && f.class_name.as_deref() == class_name)
510 .map(|f| f.start_line.saturating_sub(1))
511 .with_context(|| format!("Function not found: {}", before_name))?
512 } else if let Some(class) = class_name {
513 structure
515 .classes
516 .iter()
517 .find(|c| c.name == class)
518 .map(|c| {
519 c.methods
521 .iter()
522 .map(|m| m.end_line)
523 .max()
524 .unwrap_or(c.start_line)
525 + 1
526 })
527 .context("Class not found: {class}")?
528 } else {
529 structure
531 .functions
532 .iter()
533 .filter(|f| f.class_name.is_none())
534 .map(|f| f.end_line)
535 .max()
536 .unwrap_or(structure.imports.len() + 1)
537 + 1
538 };
539
540 let formatted_function =
542 self.format_function(name, body, visibility, class_name.is_some())?;
543
544 let lines: Vec<&str> = self.content.lines().collect();
546 let mut new_lines: Vec<String> = Vec::new();
547
548 for (i, line) in lines.iter().enumerate() {
549 new_lines.push(line.to_string());
550 if i + 1 == insert_line {
551 new_lines.push(String::new());
552 new_lines.push(formatted_function.clone());
553 }
554 }
555
556 if insert_line > lines.len() {
558 new_lines.push(String::new());
559 new_lines.push(formatted_function);
560 }
561
562 self.content = new_lines.join("\n");
563 Ok(())
564 }
565
566 fn format_function(
567 &self,
568 name: &str,
569 body: &str,
570 visibility: &str,
571 is_method: bool,
572 ) -> Result<String> {
573 let formatted = match self.language {
575 SupportedLanguage::Rust => {
576 let vis = if visibility == "public" { "pub " } else { "" };
577 let indent = if is_method { " " } else { "" };
578 format!("{indent}{vis}fn {name}{body}")
579 }
580 SupportedLanguage::Python => {
581 let indent = if is_method { " " } else { "" };
582 format!("{indent}def {name}{body}")
583 }
584 SupportedLanguage::JavaScript | SupportedLanguage::TypeScript => {
585 let indent = if is_method { " " } else { "" };
586 format!("{indent}function {name}{body}")
587 }
588 _ => {
589 format!("{visibility} function {name}{body}")
590 }
591 };
592
593 Ok(formatted)
594 }
595
596 fn replace_function(
597 &mut self,
598 name: &str,
599 class_name: Option<&str>,
600 new_body: &str,
601 ) -> Result<()> {
602 let structure = self.structure.as_ref().context("No structure analyzed")?;
603
604 let function = structure
605 .functions
606 .iter()
607 .find(|f| f.name == name && f.class_name.as_deref() == class_name)
608 .context("Function not found")?;
609
610 let lines: Vec<&str> = self.content.lines().collect();
612 let signature_line = function.start_line - 1;
613
614 let body_start_line = signature_line + 1;
616 let body_end_line = function.end_line - 1;
617
618 let mut new_lines: Vec<String> = Vec::new();
620 for (i, line) in lines.iter().enumerate() {
621 if i < body_start_line || i > body_end_line {
622 new_lines.push(line.to_string());
623 } else if i == body_start_line {
624 new_lines.push(new_body.to_string());
625 }
626 }
627
628 self.content = new_lines.join("\n");
629 Ok(())
630 }
631
632 fn add_import(&mut self, import: &str, alias: Option<&str>) -> Result<()> {
633 let structure = self.structure.as_ref().context("No structure analyzed")?;
634
635 let formatted_import = match self.language {
637 SupportedLanguage::Rust => {
638 if let Some(alias) = alias {
639 format!("use {import} as {alias};")
640 } else {
641 format!("use {import};")
642 }
643 }
644 SupportedLanguage::Python => {
645 if let Some(alias) = alias {
646 format!("import {import} as {alias}")
647 } else {
648 format!("import {import}")
649 }
650 }
651 SupportedLanguage::JavaScript | SupportedLanguage::TypeScript => {
652 if let Some(a) = alias {
654 format!("const {} = require('{}');", a, import)
655 } else {
656 format!("const {} = require('{}');", import, import)
657 }
658 }
659 _ => format!("import {import};"),
660 };
661
662 let insert_line = if structure.imports.is_empty() {
664 1
665 } else {
666 structure.imports.len() + 1
667 };
668
669 let lines: Vec<&str> = self.content.lines().collect();
670 let mut new_lines: Vec<String> = Vec::new();
671
672 for (i, line) in lines.iter().enumerate() {
673 if i + 1 == insert_line {
674 new_lines.push(formatted_import.clone());
675 }
676 new_lines.push(line.to_string());
677 }
678
679 self.content = new_lines.join("\n");
680 Ok(())
681 }
682
683 fn smart_append(&mut self, section: &str, content: &str) -> Result<()> {
684 let structure = self.structure.as_ref().context("No structure analyzed")?;
685
686 let insert_line = match section {
687 "imports" => structure.imports.len() + 1,
688 "functions" => {
689 structure
690 .functions
691 .iter()
692 .filter(|f| f.class_name.is_none())
693 .map(|f| f.end_line)
694 .max()
695 .unwrap_or(structure.imports.len() + 1)
696 + 1
697 }
698 "classes" => {
699 structure
700 .classes
701 .iter()
702 .map(|c| c.end_line)
703 .max()
704 .unwrap_or_else(|| {
705 structure
706 .functions
707 .iter()
708 .map(|f| f.end_line)
709 .max()
710 .unwrap_or(structure.imports.len() + 1)
711 })
712 + 1
713 }
714 "main" => {
715 if let Some(main_fn) = &structure.main_function {
716 structure
717 .functions
718 .iter()
719 .find(|f| &f.name == main_fn)
720 .map(|f| f.end_line - 1)
721 .unwrap_or(structure.line_count)
722 } else {
723 structure.line_count
724 }
725 }
726 _ => structure.line_count,
727 };
728
729 let lines: Vec<&str> = self.content.lines().collect();
730 let mut new_lines: Vec<String> = Vec::new();
731
732 for (i, line) in lines.iter().enumerate() {
733 new_lines.push(line.to_string());
734 if i + 1 == insert_line {
735 new_lines.push(String::new());
736 new_lines.push(content.to_string());
737 }
738 }
739
740 self.content = new_lines.join("\n");
741 Ok(())
742 }
743
744 pub fn get_structure(&self) -> Option<&CodeStructure> {
746 self.structure.as_ref()
747 }
748
749 fn remove_function(
750 &mut self,
751 name: &str,
752 class_name: Option<&str>,
753 force: bool,
754 cascade: bool,
755 ) -> Result<()> {
756 let (function_start, function_end, functions_to_cascade) = {
758 let structure = self.structure.as_ref().context("No structure analyzed")?;
759
760 let function = structure
762 .functions
763 .iter()
764 .find(|f| f.name == name && f.class_name.as_deref() == class_name)
765 .context("Function not found")?;
766
767 if !force {
769 let dependents = structure
770 .dependencies
771 .called_by
772 .get(name)
773 .map(|v| v.as_slice())
774 .unwrap_or(&[]);
775
776 if !dependents.is_empty() {
777 return Err(anyhow::anyhow!(
778 "Function '{}' is called by: {}. Use force=true to remove anyway.",
779 name,
780 dependents.join(", ")
781 ));
782 }
783 }
784
785 let mut functions_to_cascade = Vec::new();
786
787 if cascade {
789 if let Some(calls) = structure.dependencies.calls.get(name) {
790 for called_func in calls {
791 if let Some(callers) = structure.dependencies.called_by.get(called_func) {
793 if callers.len() == 1 && callers[0] == name {
794 functions_to_cascade.push(called_func.clone());
795 }
796 }
797 }
798 }
799 }
800
801 (function.start_line, function.end_line, functions_to_cascade)
802 };
803
804 let lines: Vec<&str> = self.content.lines().collect();
806 let mut new_lines: Vec<String> = Vec::new();
807 let mut skip_lines = false;
808
809 for (i, line) in lines.iter().enumerate() {
810 let line_num = i + 1;
811
812 if line_num == function_start {
813 skip_lines = true;
814 }
815
816 if !skip_lines {
817 new_lines.push(line.to_string());
818 }
819
820 if line_num == function_end {
821 skip_lines = false;
822 }
823 }
824
825 self.content = new_lines.join("\n");
826
827 self.tree = self.parser.parse(&self.content, None);
829 self.analyze_structure()?;
830
831 for func_to_remove in functions_to_cascade {
833 self.remove_function(&func_to_remove, None, true, cascade)?;
834 }
835
836 Ok(())
837 }
838
839 pub fn get_function_tree(&self) -> Result<Value> {
841 let structure = self.structure.as_ref().context("No structure analyzed")?;
842
843 let tree = json!({
845 "language": format!("{:?}", self.language),
846 "file_structure": {
847 "imports": structure.imports,
848 "line_count": structure.line_count,
849 "main_function": structure.main_function,
850 },
851 "functions": structure.functions.iter().map(|f| {
852 json!({
853 "name": f.name,
854 "lines": format!("{}-{}", f.start_line, f.end_line),
855 "class": f.class_name,
856 "visibility": f.visibility,
857 "signature": f.signature,
858 "calls": f.calls,
859 "called_by": f.called_by,
860 })
861 }).collect::<Vec<_>>(),
862 "classes": structure.classes.iter().map(|c| {
863 json!({
864 "name": c.name,
865 "lines": format!("{}-{}", c.start_line, c.end_line),
866 "extends": c.extends,
867 "implements": c.implements,
868 "methods": c.methods.iter().map(|m| {
869 json!({
870 "name": m.name,
871 "lines": format!("{}-{}", m.start_line, m.end_line),
872 "visibility": m.visibility,
873 })
874 }).collect::<Vec<_>>(),
875 })
876 }).collect::<Vec<_>>(),
877 });
878
879 Ok(tree)
880 }
881}
882
883pub async fn handle_smart_edit(params: Option<Value>) -> Result<Value> {
885 let params = params.context("Parameters required")?;
886
887 let file_path = params["file_path"].as_str().context("file_path required")?;
888
889 let edits = params["edits"].as_array().context("edits array required")?;
890
891 let content = std::fs::read_to_string(file_path)?;
893 let original_content = content.clone(); let extension = Path::new(file_path)
895 .extension()
896 .and_then(|e| e.to_str())
897 .context("Could not determine file extension")?;
898
899 let language = SupportedLanguage::from_extension(extension).context("Unsupported language")?;
900
901 let mut editor = SmartEditor::new(content, language)?;
903
904 let initial_structure = editor.get_function_tree()?;
906
907 let mut results = Vec::new();
909 for edit in edits {
910 let smart_edit: SmartEdit = serde_json::from_value(edit.clone())?;
911 match editor.apply_edit(&smart_edit) {
912 Ok(_) => {
913 results.push(json!({
914 "status": "success",
915 "operation": edit["operation"],
916 }));
917 }
918 Err(e) => {
919 results.push(json!({
920 "status": "error",
921 "operation": edit["operation"],
922 "error": e.to_string(),
923 }));
924 }
925 }
926 }
927
928 let final_structure = editor.get_function_tree()?;
930
931 if let Ok(project_root) = std::env::current_dir() {
933 if let Ok(storage) = crate::smart_edit_diff::DiffStorage::new(&project_root) {
934 let _ = storage.store_diff(
936 Path::new(file_path),
937 &original_content, &editor.content, );
940
941 let _ = storage.store_original(Path::new(file_path), &original_content);
943 }
944 }
945
946 std::fs::write(file_path, &editor.content)?;
948
949 let result = json!({
950 "file_path": file_path,
951 "language": format!("{:?}", language),
952 "edits_applied": results,
953 "initial_structure": initial_structure,
954 "final_structure": final_structure,
955 "content_preview": editor.content.lines().take(20).collect::<Vec<_>>().join("\n"),
956 });
957
958 Ok(json!({
960 "content": [{
961 "type": "text",
962 "text": serde_json::to_string_pretty(&result)?
963 }]
964 }))
965}
966
967pub async fn handle_get_function_tree(params: Option<Value>) -> Result<Value> {
969 let params = params.context("Parameters required")?;
970 let file_path = params["file_path"].as_str().context("file_path required")?;
971
972 let content = std::fs::read_to_string(file_path)?;
973 let extension = Path::new(file_path)
974 .extension()
975 .and_then(|e| e.to_str())
976 .context("Could not determine file extension")?;
977
978 let language = SupportedLanguage::from_extension(extension).context("Unsupported language")?;
979
980 let editor = SmartEditor::new(content, language)?;
981 let function_tree = editor.get_function_tree()?;
982
983 Ok(json!({
985 "content": [{
986 "type": "text",
987 "text": serde_json::to_string_pretty(&function_tree)?
988 }]
989 }))
990}
991
992pub async fn handle_insert_function(params: Option<Value>) -> Result<Value> {
994 let params = params.context("Parameters required")?;
995
996 let edit = SmartEdit::InsertFunction {
997 name: params["name"]
998 .as_str()
999 .context("name required")?
1000 .to_string(),
1001 class_name: params["class_name"].as_str().map(String::from),
1002 namespace: params["namespace"].as_str().map(String::from),
1003 body: params["body"]
1004 .as_str()
1005 .context("body required")?
1006 .to_string(),
1007 after: params["after"].as_str().map(String::from),
1008 before: params["before"].as_str().map(String::from),
1009 visibility: params["visibility"]
1010 .as_str()
1011 .unwrap_or("private")
1012 .to_string(),
1013 };
1014
1015 handle_smart_edit(Some(json!({
1016 "file_path": params["file_path"],
1017 "edits": [edit],
1018 })))
1019 .await
1020}
1021
1022pub async fn handle_remove_function(params: Option<Value>) -> Result<Value> {
1024 let params = params.context("Parameters required")?;
1025
1026 let edit = SmartEdit::RemoveFunction {
1027 name: params["name"]
1028 .as_str()
1029 .context("name required")?
1030 .to_string(),
1031 class_name: params["class_name"].as_str().map(String::from),
1032 force: params["force"].as_bool().unwrap_or(false),
1033 cascade: params["cascade"].as_bool().unwrap_or(false),
1034 };
1035
1036 handle_smart_edit(Some(json!({
1037 "file_path": params["file_path"],
1038 "edits": [edit],
1039 })))
1040 .await
1041}
1042
1043pub async fn handle_create_file(params: Option<Value>) -> Result<Value> {
1045 let params = params.context("Parameters required")?;
1046
1047 let file_path = params["file_path"]
1048 .as_str()
1049 .context("file_path required")?;
1050
1051 let content = params["content"]
1052 .as_str()
1053 .unwrap_or(""); if Path::new(file_path).exists() {
1057 return Err(anyhow::anyhow!("File already exists: {}. Use edit operations to modify existing files.", file_path));
1058 }
1059
1060 if let Some(parent) = Path::new(file_path).parent() {
1062 if !parent.exists() {
1063 std::fs::create_dir_all(parent)
1064 .with_context(|| format!("Failed to create parent directories for: {}", file_path))?;
1065 }
1066 }
1067
1068 std::fs::write(file_path, content)
1070 .with_context(|| format!("Failed to create file: {}", file_path))?;
1071
1072 let result = json!({
1074 "status": "success",
1075 "file_path": file_path,
1076 "message": format!("File created: {}", file_path),
1077 "size": content.len(),
1078 });
1079
1080 let pretty = serde_json::to_string_pretty(&result).unwrap_or_else(|_| result.to_string());
1081
1082 Ok(json!({
1083 "content": [
1084 {
1085 "type": "text",
1086 "text": pretty
1087 }
1088 ]
1089 }))
1090}
1091
1092#[cfg(test)]
1093mod tests {
1094 use super::*;
1095
1096 #[test]
1097 fn test_rust_function_insertion() {
1098 let content = r#"
1099use std::io;
1100
1101fn main() {
1102 println!("Hello, world!");
1103}
1104
1105fn helper() {
1106 println!("Helper");
1107}
1108"#
1109 .to_string();
1110
1111 let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();
1112 let edit = SmartEdit::InsertFunction {
1113 name: "new_function".to_string(),
1114 class_name: None,
1115 namespace: None,
1116 body: r#"() -> Result<()> {
1117 println!("New function!");
1118 Ok(())
1119}"#
1120 .to_string(),
1121 after: Some("main".to_string()),
1122 before: None,
1123 visibility: "public".to_string(),
1124 };
1125
1126 editor.apply_edit(&edit).unwrap();
1127 assert!(editor.content.contains("pub fn new_function"));
1128 assert!(
1129 editor.content.find("pub fn new_function").unwrap()
1130 > editor.content.find("fn main").unwrap()
1131 );
1132 }
1133
1134 #[test]
1135 fn test_python_function_insertion() {
1136 let content = r#"
1137import os
1138
1139def main():
1140 print("Hello, world!")
1141
1142def helper():
1143 print("Helper")
1144"#
1145 .to_string();
1146
1147 let mut editor = SmartEditor::new(content, SupportedLanguage::Python).unwrap();
1148 let edit = SmartEdit::InsertFunction {
1149 name: "process_data".to_string(),
1150 class_name: None,
1151 namespace: None,
1152 body: r#"(data):
1153 """Process the data."""
1154 return data * 2"#
1155 .to_string(),
1156 after: Some("main".to_string()),
1157 before: None,
1158 visibility: "public".to_string(),
1159 };
1160
1161 editor.apply_edit(&edit).unwrap();
1162 assert!(editor.content.contains("def process_data(data):"));
1163 assert!(editor.content.contains("return data * 2"));
1164 }
1165
1166 #[test]
1167 fn test_javascript_function_insertion() {
1168 let content = r#"
1169function main() {
1170 console.log("Hello, world!");
1171}
1172
1173function helper() {
1174 console.log("Helper");
1175}
1176"#
1177 .to_string();
1178
1179 let mut editor = SmartEditor::new(content, SupportedLanguage::JavaScript).unwrap();
1180 let edit = SmartEdit::InsertFunction {
1181 name: "processData".to_string(),
1182 class_name: None,
1183 namespace: None,
1184 body: r#"(data) {
1185 return data.map(x => x * 2);
1186}"#
1187 .to_string(),
1188 before: Some("helper".to_string()),
1189 after: None,
1190 visibility: "public".to_string(),
1191 };
1192
1193 editor.apply_edit(&edit).unwrap();
1194 assert!(editor.content.contains("function processData(data)"));
1195 assert!(editor.content.contains("return data.map(x => x * 2)"));
1196 }
1197
1198 #[test]
1199 fn test_add_import() {
1200 let content = r#"
1201use std::io;
1202
1203fn main() {
1204 println!("Hello");
1205}
1206"#
1207 .to_string();
1208
1209 let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();
1210 let edit = SmartEdit::AddImport {
1211 import: "std::collections::HashMap".to_string(),
1212 alias: None,
1213 };
1214
1215 editor.apply_edit(&edit).unwrap();
1216 assert!(editor.content.contains("use std::collections::HashMap;"));
1217
1218 let edit_with_alias = SmartEdit::AddImport {
1220 import: "std::sync::Arc".to_string(),
1221 alias: Some("MyArc".to_string()),
1222 };
1223
1224 editor.apply_edit(&edit_with_alias).unwrap();
1225 assert!(editor.content.contains("use std::sync::Arc as MyArc;"));
1226 }
1227
1228 #[test]
1229 fn test_replace_function() {
1230 let content = r#"
1231fn calculate(x: i32) -> i32 {
1232 x + 1
1233}
1234
1235fn main() {
1236 let result = calculate(5);
1237}
1238"#
1239 .to_string();
1240
1241 let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();
1242
1243 let _ = editor.analyze_structure();
1245
1246 let edit = SmartEdit::ReplaceFunction {
1247 name: "calculate".to_string(),
1248 class_name: None,
1249 new_body: r#"{
1250 // Improved calculation with logging
1251 println!("Calculating for: {}", x);
1252 x * 2
1253}"#
1254 .to_string(),
1255 };
1256
1257 editor.apply_edit(&edit).unwrap();
1258 assert!(editor.content.contains("x * 2"));
1259 assert!(editor.content.contains("Improved calculation"));
1260 assert!(!editor.content.contains("x + 1")); }
1262
1263 #[test]
1264 fn test_smart_append() {
1265 let content = r#"
1266import os
1267
1268def main():
1269 pass
1270
1271class MyClass:
1272 pass
1273"#
1274 .to_string();
1275
1276 let mut editor = SmartEditor::new(content, SupportedLanguage::Python).unwrap();
1277
1278 let import_edit = SmartEdit::SmartAppend {
1280 section: "imports".to_string(),
1281 content: "import sys".to_string(),
1282 };
1283
1284 editor.apply_edit(&import_edit).unwrap();
1285 assert!(editor.content.contains("import sys"));
1286
1287 let func_edit = SmartEdit::SmartAppend {
1289 section: "functions".to_string(),
1290 content: "def helper():\n return True".to_string(),
1291 };
1292
1293 editor.apply_edit(&func_edit).unwrap();
1294 assert!(editor.content.contains("def helper():"));
1295 }
1296
1297 #[test]
1298 fn test_remove_function_with_dependencies() {
1299 let content = r#"
1300fn caller() {
1301 helper();
1302}
1303
1304fn helper() {
1305 println!("I'm helping!");
1306}
1307
1308fn orphan() {
1309 // Only called by helper
1310}
1311"#
1312 .to_string();
1313
1314 let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();
1315
1316 editor.structure = Some(CodeStructure {
1318 language: "Rust".to_string(),
1319 imports: vec![],
1320 functions: vec![
1321 FunctionInfo {
1322 name: "caller".to_string(),
1323 class_name: None,
1324 namespace: None,
1325 start_line: 2,
1326 end_line: 4,
1327 signature: "fn caller()".to_string(),
1328 visibility: "private".to_string(),
1329 calls: vec!["helper".to_string()],
1330 called_by: vec![],
1331 },
1332 FunctionInfo {
1333 name: "helper".to_string(),
1334 class_name: None,
1335 namespace: None,
1336 start_line: 6,
1337 end_line: 8,
1338 signature: "fn helper()".to_string(),
1339 visibility: "private".to_string(),
1340 calls: vec!["orphan".to_string()],
1341 called_by: vec!["caller".to_string()],
1342 },
1343 FunctionInfo {
1344 name: "orphan".to_string(),
1345 class_name: None,
1346 namespace: None,
1347 start_line: 10,
1348 end_line: 12,
1349 signature: "fn orphan()".to_string(),
1350 visibility: "private".to_string(),
1351 calls: vec![],
1352 called_by: vec!["helper".to_string()],
1353 },
1354 ],
1355 classes: vec![],
1356 main_function: None,
1357 line_count: 12,
1358 dependencies: DependencyGraph {
1359 calls: [
1360 ("caller".to_string(), vec!["helper".to_string()]),
1361 ("helper".to_string(), vec!["orphan".to_string()]),
1362 ]
1363 .into_iter()
1364 .collect(),
1365 called_by: [
1366 ("helper".to_string(), vec!["caller".to_string()]),
1367 ("orphan".to_string(), vec!["helper".to_string()]),
1368 ]
1369 .into_iter()
1370 .collect(),
1371 },
1372 });
1373
1374 let remove_edit = SmartEdit::RemoveFunction {
1376 name: "helper".to_string(),
1377 class_name: None,
1378 force: false,
1379 cascade: false,
1380 };
1381
1382 let result = editor.apply_edit(&remove_edit);
1383 assert!(result.is_err());
1384 assert!(result
1385 .unwrap_err()
1386 .to_string()
1387 .contains("called by: caller"));
1388
1389 let force_remove = SmartEdit::RemoveFunction {
1391 name: "helper".to_string(),
1392 class_name: None,
1393 force: true,
1394 cascade: false,
1395 };
1396
1397 editor.apply_edit(&force_remove).unwrap();
1398 assert!(!editor.content.contains("fn helper()"));
1399 assert!(editor.content.contains("fn orphan()")); }
1401
1402 #[test]
1403 fn test_get_function_tree() {
1404 let content = r#"
1405class Calculator:
1406 def add(self, a, b):
1407 return a + b
1408
1409 def multiply(self, a, b):
1410 return self.add(a, b) * b
1411
1412def main():
1413 calc = Calculator()
1414 result = calc.add(5, 3)
1415"#
1416 .to_string();
1417
1418 let editor = SmartEditor::new(content, SupportedLanguage::Python).unwrap();
1419 let tree = editor.get_function_tree().unwrap();
1420
1421 assert!(tree["language"].as_str().unwrap().contains("Python"));
1423 assert!(tree["functions"].is_array());
1424 assert!(tree["classes"].is_array());
1425
1426 let functions = tree["functions"].as_array().unwrap();
1428 assert!(functions.iter().any(|f| f["name"] == "main"));
1429
1430 let classes = tree["classes"].as_array().unwrap();
1431 assert!(classes.iter().any(|c| c["name"] == "Calculator"));
1432 }
1433
1434 #[test]
1435 fn test_multiple_edits() {
1436 let content = r#"
1437fn main() {
1438 println!("Start");
1439}
1440"#
1441 .to_string();
1442
1443 let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();
1444
1445 let edits = vec![
1447 SmartEdit::AddImport {
1448 import: "std::thread".to_string(),
1449 alias: None,
1450 },
1451 SmartEdit::InsertFunction {
1452 name: "worker".to_string(),
1453 class_name: None,
1454 namespace: None,
1455 body: r#"() {
1456 thread::sleep(std::time::Duration::from_secs(1));
1457}"#
1458 .to_string(),
1459 after: Some("main".to_string()),
1460 before: None,
1461 visibility: "private".to_string(),
1462 },
1463 ];
1464
1465 for edit in edits {
1466 editor.apply_edit(&edit).unwrap();
1467 }
1468
1469 assert!(editor.content.contains("use std::thread;"));
1470 assert!(editor.content.contains("fn worker()"));
1471 }
1472
1473 #[test]
1474 fn test_class_method_insertion() {
1475 let content = r#"
1476class MyClass:
1477 def __init__(self):
1478 self.value = 0
1479
1480 def get_value(self):
1481 return self.value
1482"#
1483 .to_string();
1484
1485 let mut editor = SmartEditor::new(content, SupportedLanguage::Python).unwrap();
1486
1487 let edit = SmartEdit::InsertFunction {
1488 name: "set_value".to_string(),
1489 class_name: Some("MyClass".to_string()),
1490 namespace: None,
1491 body: r#"(self, value):
1492 self.value = value"#
1493 .to_string(),
1494 after: Some("get_value".to_string()),
1495 before: None,
1496 visibility: "public".to_string(),
1497 };
1498
1499 editor.apply_edit(&edit).unwrap();
1500 assert!(editor.content.contains("def set_value(self, value):"));
1501 assert!(editor.content.contains("self.value = value"));
1502 }
1503
1504 #[tokio::test]
1505 async fn test_create_file() {
1506 use tempfile::tempdir;
1507
1508 let dir = tempdir().unwrap();
1509 let test_file = dir.path().join("new_test.rs");
1510
1511 let params = json!({
1512 "file_path": test_file.to_str().unwrap(),
1513 "content": "// Test file\npub fn hello() {\n println!(\"Hello!\");\n}\n"
1514 });
1515
1516 let result = handle_create_file(Some(params.clone())).await;
1518 assert!(result.is_ok(), "Failed to create file: {:?}", result.err());
1519
1520 assert!(test_file.exists(), "File was not created");
1522
1523 let content = std::fs::read_to_string(&test_file).unwrap();
1525 assert!(content.contains("pub fn hello()"));
1526 assert!(content.contains("println!"));
1527
1528 let result2 = handle_create_file(Some(params)).await;
1530 assert!(result2.is_err(), "Should fail when file already exists");
1531 assert!(result2.unwrap_err().to_string().contains("already exists"));
1532 }
1533
1534 #[tokio::test]
1535 async fn test_create_file_with_parent_dirs() {
1536 use tempfile::tempdir;
1537
1538 let dir = tempdir().unwrap();
1539 let test_file = dir.path().join("subdir/nested/test.py");
1540
1541 let params = json!({
1542 "file_path": test_file.to_str().unwrap(),
1543 "content": "def main():\n print('Hello')\n"
1544 });
1545
1546 let result = handle_create_file(Some(params)).await;
1548 assert!(result.is_ok(), "Failed to create file with parent dirs: {:?}", result.err());
1549
1550 assert!(test_file.exists(), "File was not created");
1552 assert!(test_file.parent().unwrap().exists(), "Parent directory was not created");
1553
1554 let content = std::fs::read_to_string(&test_file).unwrap();
1556 assert!(content.contains("def main()"));
1557 }
1558
1559 #[tokio::test]
1560 async fn test_create_empty_file() {
1561 use tempfile::tempdir;
1562
1563 let dir = tempdir().unwrap();
1564 let test_file = dir.path().join("empty.txt");
1565
1566 let params = json!({
1567 "file_path": test_file.to_str().unwrap()
1568 });
1570
1571 let result = handle_create_file(Some(params)).await;
1572 assert!(result.is_ok(), "Failed to create empty file: {:?}", result.err());
1573
1574 assert!(test_file.exists());
1576 let content = std::fs::read_to_string(&test_file).unwrap();
1577 assert_eq!(content, "");
1578 }
1579}