// semantic_code_edit_mcp/validation/context_validator.rs

use std::fmt::Write as _;

use tree_sitter::{Node, Query, QueryCursor, StreamingIterator, Tree};

/// Tree-sitter based context validator for semantic code editing.
///
/// This is a fieldless type; its functionality is provided through associated
/// functions such as [`ContextValidator::validate_tree`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ContextValidator;

6#[derive(Debug)]
7pub struct ValidationResult<'tree, 'source> {
8    pub is_valid: bool,
9    pub violations: Vec<ContextViolation<'tree>>,
10    pub source_code: &'source str,
11}
12
13#[derive(Debug)]
14pub struct ContextViolation<'tree> {
15    pub node: Node<'tree>,
16    pub message: String, // Human-readable error
17    pub suggestion: &'static str,
18}
19
20impl ContextValidator {
21    /// Validate if content can be safely inserted at the target location
22    pub fn validate_tree<'tree, 'source>(
23        tree: &'tree Tree,
24        query: &Query,
25        source_code: &'source str,
26    ) -> ValidationResult<'tree, 'source> {
27        // Run validation queries against the temporary tree
28        let mut cursor = QueryCursor::new();
29        let mut matches = cursor.matches(query, tree.root_node(), source_code.as_bytes());
30
31        let mut violations = Vec::new();
32
33        while let Some(m) = matches.next() {
34            for capture in m.captures {
35                let node = capture.node;
36
37                // Extract violation type from capture name
38                if let Some(violation_type) = Self::extract_violation_type(capture.index, query) {
39                    // Only process "invalid" captures
40                    if violation_type.starts_with("invalid.") {
41                        violations.push(ContextViolation {
42                            node,
43                            message: Self::get_violation_message(&violation_type),
44                            suggestion: Self::get_violation_suggestion(&violation_type),
45                        });
46                    }
47                }
48            }
49        }
50
51        ValidationResult {
52            is_valid: violations.is_empty(),
53            source_code,
54            violations,
55        }
56    }
57
58    fn extract_violation_type(capture_index: u32, query: &Query) -> Option<String> {
59        query
60            .capture_names()
61            .get(capture_index as usize)
62            .map(|s| s.to_string())
63    }
64
65    fn get_violation_message(violation_type: &str) -> String {
66        match violation_type {
67            "invalid.function.in.struct.fields" => {
68                "Functions cannot be defined inside struct field lists".to_string()
69            }
70            "invalid.function.in.enum.variants" => {
71                "Functions cannot be defined inside enum variant lists".to_string()
72            }
73            "invalid.type.in.function.body" => {
74                "Type definitions cannot be placed inside function bodies".to_string()
75            }
76            "invalid.impl.in.function.body" => {
77                "Impl blocks cannot be placed inside function bodies".to_string()
78            }
79            "invalid.trait.in.function.body" => {
80                "Trait definitions cannot be placed inside function bodies".to_string()
81            }
82            "invalid.impl.nested" => "Impl blocks can only be defined at module level".to_string(),
83            "invalid.trait.nested" => {
84                "Trait definitions can only be defined at module level".to_string()
85            }
86            "invalid.use.in.item.body" => "Use declarations should be at module level".to_string(),
87            "invalid.const.in.function.body" => {
88                "Const/static items should be at module level".to_string()
89            }
90            "invalid.mod.in.function.body" => {
91                "Module declarations cannot be inside function bodies".to_string()
92            }
93            "invalid.item.nested.in.item" => {
94                "Items cannot be nested inside other items".to_string()
95            }
96            "invalid.expression.as.type" => "Expressions cannot be used as types".to_string(),
97            _ => format!(
98                "Invalid placement: {}",
99                violation_type
100                    .strip_prefix("invalid.")
101                    .unwrap_or(violation_type)
102            ),
103        }
104    }
105
106    fn get_violation_suggestion(violation_type: &str) -> &'static str {
107        match violation_type {
108            "invalid.function.in.struct.fields" | "invalid.function.in.enum.variants" => {
109                "Place the function after the type definition"
110            }
111            "invalid.type.in.function.body"
112            | "invalid.impl.in.function.body"
113            | "invalid.trait.in.function.body" => "Move this to module level",
114
115            "invalid.use.in.item.body" => "Move use declarations to the top of the file",
116            _ => "Consider placing this construct in an appropriate context",
117        }
118    }
119}
120
121impl ValidationResult<'_, '_> {
122    pub fn format_errors(&self) -> String {
123        if self.is_valid {
124            return "✅ All validations passed".to_string();
125        }
126
127        let mut response = String::new();
128        response.push_str("❌ Invalid placement detected:\n\n");
129
130        for violation in &self.violations {
131            response.push_str(&format!("• {}:\n", violation.message));
132            let parent = violation.node.parent().unwrap_or(violation.node);
133            response.push_str(&self.source_code[parent.byte_range()]);
134            response.push_str("\n\n");
135            response.push_str(&format!("  💡 Suggestion: {}\n", violation.suggestion));
136        }
137
138        response
139    }
140}