semantic_code_edit_mcp/validation/
context_validator.rs1use tree_sitter::{Node, Query, QueryCursor, StreamingIterator, Tree};
2
/// Validator that checks where syntax-tree nodes are placed, using a tree-sitter query.
///
/// The type carries no state; its functionality is exposed through associated
/// functions such as `ContextValidator::validate_tree`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ContextValidator;
5
6#[derive(Debug)]
7pub struct ValidationResult<'tree, 'source> {
8 pub is_valid: bool,
9 pub violations: Vec<ContextViolation<'tree>>,
10 pub source_code: &'source str,
11}
12
13#[derive(Debug)]
14pub struct ContextViolation<'tree> {
15 pub node: Node<'tree>,
16 pub message: String, pub suggestion: &'static str,
18}
19
20impl ContextValidator {
21 pub fn validate_tree<'tree, 'source>(
23 tree: &'tree Tree,
24 query: &Query,
25 source_code: &'source str,
26 ) -> ValidationResult<'tree, 'source> {
27 let mut cursor = QueryCursor::new();
29 let mut matches = cursor.matches(query, tree.root_node(), source_code.as_bytes());
30
31 let mut violations = Vec::new();
32
33 while let Some(m) = matches.next() {
34 for capture in m.captures {
35 let node = capture.node;
36
37 if let Some(violation_type) = Self::extract_violation_type(capture.index, query) {
39 if violation_type.starts_with("invalid.") {
41 violations.push(ContextViolation {
42 node,
43 message: Self::get_violation_message(&violation_type),
44 suggestion: Self::get_violation_suggestion(&violation_type),
45 });
46 }
47 }
48 }
49 }
50
51 ValidationResult {
52 is_valid: violations.is_empty(),
53 source_code,
54 violations,
55 }
56 }
57
58 fn extract_violation_type(capture_index: u32, query: &Query) -> Option<String> {
59 query
60 .capture_names()
61 .get(capture_index as usize)
62 .map(|s| s.to_string())
63 }
64
65 fn get_violation_message(violation_type: &str) -> String {
66 match violation_type {
67 "invalid.function.in.struct.fields" => {
68 "Functions cannot be defined inside struct field lists".to_string()
69 }
70 "invalid.function.in.enum.variants" => {
71 "Functions cannot be defined inside enum variant lists".to_string()
72 }
73 "invalid.type.in.function.body" => {
74 "Type definitions cannot be placed inside function bodies".to_string()
75 }
76 "invalid.impl.in.function.body" => {
77 "Impl blocks cannot be placed inside function bodies".to_string()
78 }
79 "invalid.trait.in.function.body" => {
80 "Trait definitions cannot be placed inside function bodies".to_string()
81 }
82 "invalid.impl.nested" => "Impl blocks can only be defined at module level".to_string(),
83 "invalid.trait.nested" => {
84 "Trait definitions can only be defined at module level".to_string()
85 }
86 "invalid.use.in.item.body" => "Use declarations should be at module level".to_string(),
87 "invalid.const.in.function.body" => {
88 "Const/static items should be at module level".to_string()
89 }
90 "invalid.mod.in.function.body" => {
91 "Module declarations cannot be inside function bodies".to_string()
92 }
93 "invalid.item.nested.in.item" => {
94 "Items cannot be nested inside other items".to_string()
95 }
96 "invalid.expression.as.type" => "Expressions cannot be used as types".to_string(),
97 _ => format!(
98 "Invalid placement: {}",
99 violation_type
100 .strip_prefix("invalid.")
101 .unwrap_or(violation_type)
102 ),
103 }
104 }
105
106 fn get_violation_suggestion(violation_type: &str) -> &'static str {
107 match violation_type {
108 "invalid.function.in.struct.fields" | "invalid.function.in.enum.variants" => {
109 "Place the function after the type definition"
110 }
111 "invalid.type.in.function.body"
112 | "invalid.impl.in.function.body"
113 | "invalid.trait.in.function.body" => "Move this to module level",
114
115 "invalid.use.in.item.body" => "Move use declarations to the top of the file",
116 _ => "Consider placing this construct in an appropriate context",
117 }
118 }
119}
120
121impl ValidationResult<'_, '_> {
122 pub fn format_errors(&self) -> String {
123 if self.is_valid {
124 return "✅ All validations passed".to_string();
125 }
126
127 let mut response = String::new();
128 response.push_str("❌ Invalid placement detected:\n\n");
129
130 for violation in &self.violations {
131 response.push_str(&format!("• {}:\n", violation.message));
132 let parent = violation.node.parent().unwrap_or(violation.node);
133 response.push_str(&self.source_code[parent.byte_range()]);
134 response.push_str("\n\n");
135 response.push_str(&format!(" 💡 Suggestion: {}\n", violation.suggestion));
136 }
137
138 response
139 }
140}