Skip to main content

sem_core/parser/
verify.rs

1//! Contract verification: check that callers pass the correct number of
2//! arguments to callees. Heuristic, not perfect, but catches obvious mismatches.
3
4use std::path::Path;
5
6use crate::model::entity::SemanticEntity;
7use crate::parser::graph::EntityGraph;
8use crate::parser::registry::ParserRegistry;
9
/// A call site whose argument count disagrees with the callee's declaration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContractViolation {
    /// Name of the called function or method.
    pub entity_name: String,
    /// File that declares the callee.
    pub file_path: String,
    /// Number of arguments the callee's declaration expects.
    pub expected_params: usize,
    /// Name of the entity containing the offending call.
    pub caller_name: String,
    /// File that contains the caller.
    pub caller_file: String,
    /// Number of arguments passed at the call site.
    pub actual_args: usize,
}
19
20/// Verify function call contracts across the codebase.
21///
22/// For each `Calls` edge in the graph, extracts expected param count from
23/// the callee's first line and actual arg count from the call site in the
24/// caller's content. Flags mismatches.
25///
26/// If `target_file` is Some, only report violations for callees in that file.
27pub fn verify_contracts(
28    root: &Path,
29    file_paths: &[String],
30    registry: &ParserRegistry,
31    target_file: Option<&str>,
32) -> Vec<ContractViolation> {
33    let graph = EntityGraph::build(root, file_paths, registry);
34
35    // Build content map: entity_id -> content
36    let mut content_map: std::collections::HashMap<String, String> = std::collections::HashMap::new();
37    for fp in file_paths {
38        let full = root.join(fp);
39        let content = match std::fs::read_to_string(&full) {
40            Ok(c) => c,
41            Err(_) => continue,
42        };
43        let plugin = match registry.get_plugin(fp) {
44            Some(p) => p,
45            None => continue,
46        };
47        for entity in plugin.extract_entities(&content, fp) {
48            content_map.insert(entity.id.clone(), entity.content.clone());
49        }
50    }
51
52    let mut violations = Vec::new();
53
54    for edge in &graph.edges {
55        if edge.ref_type != crate::parser::graph::RefType::Calls {
56            continue;
57        }
58
59        let callee = match graph.entities.get(&edge.to_entity) {
60            Some(e) => e,
61            None => continue,
62        };
63
64        // Filter to target file if specified
65        if let Some(tf) = target_file {
66            if callee.file_path != tf {
67                continue;
68            }
69        }
70
71        // Only check functions/methods
72        if !matches!(
73            callee.entity_type.as_str(),
74            "function" | "method" | "arrow_function"
75        ) {
76            continue;
77        }
78
79        let callee_content = match content_map.get(&edge.to_entity) {
80            Some(c) => c,
81            None => continue,
82        };
83
84        let caller = match graph.entities.get(&edge.from_entity) {
85            Some(e) => e,
86            None => continue,
87        };
88
89        let caller_content = match content_map.get(&edge.from_entity) {
90            Some(c) => c,
91            None => continue,
92        };
93
94        let expected = extract_param_count(callee_content);
95        if expected == 0 {
96            continue; // can't verify zero-param functions meaningfully
97        }
98
99        if let Some(actual) = count_call_args(caller_content, &callee.name) {
100            if actual != expected {
101                violations.push(ContractViolation {
102                    entity_name: callee.name.clone(),
103                    file_path: callee.file_path.clone(),
104                    expected_params: expected,
105                    caller_name: caller.name.clone(),
106                    caller_file: caller.file_path.clone(),
107                    actual_args: actual,
108                });
109            }
110        }
111    }
112
113    violations
114}
115
116/// Like `verify_contracts`, but accepts a pre-built graph + entities to avoid
117/// redundant work when the caller already has them cached.
118pub fn verify_contracts_with_graph(
119    graph: &EntityGraph,
120    all_entities: &[SemanticEntity],
121    target_file: Option<&str>,
122) -> Vec<ContractViolation> {
123    let content_map: std::collections::HashMap<String, String> = all_entities
124        .iter()
125        .map(|e| (e.id.clone(), e.content.clone()))
126        .collect();
127
128    let mut violations = Vec::new();
129
130    for edge in &graph.edges {
131        if edge.ref_type != crate::parser::graph::RefType::Calls {
132            continue;
133        }
134
135        let callee = match graph.entities.get(&edge.to_entity) {
136            Some(e) => e,
137            None => continue,
138        };
139
140        if let Some(tf) = target_file {
141            if callee.file_path != tf {
142                continue;
143            }
144        }
145
146        if !matches!(
147            callee.entity_type.as_str(),
148            "function" | "method" | "arrow_function"
149        ) {
150            continue;
151        }
152
153        let callee_content = match content_map.get(&edge.to_entity) {
154            Some(c) => c,
155            None => continue,
156        };
157
158        let caller = match graph.entities.get(&edge.from_entity) {
159            Some(e) => e,
160            None => continue,
161        };
162
163        let caller_content = match content_map.get(&edge.from_entity) {
164            Some(c) => c,
165            None => continue,
166        };
167
168        let expected = extract_param_count(callee_content);
169        if expected == 0 {
170            continue;
171        }
172
173        if let Some(actual) = count_call_args(caller_content, &callee.name) {
174            if actual != expected {
175                violations.push(ContractViolation {
176                    entity_name: callee.name.clone(),
177                    file_path: callee.file_path.clone(),
178                    expected_params: expected,
179                    caller_name: caller.name.clone(),
180                    caller_file: caller.file_path.clone(),
181                    actual_args: actual,
182                });
183            }
184        }
185    }
186
187    violations
188}
189
/// Extract the number of parameters a function/method declaration expects.
///
/// The parameter list starts at the first `(` on the declaration's first line,
/// skipping Rust visibility restrictions such as `pub(crate)`. Its matching `)`
/// may sit on a later line, so multi-line signatures are counted too. Every
/// declared parameter is counted, including receivers (`def bar(self, x):`
/// yields 2). Returns 0 for an empty list or when no list can be located.
fn extract_param_count(content: &str) -> usize {
    let first_line = content.lines().next().unwrap_or("");

    let open = match signature_open_paren(first_line) {
        Some(i) => i,
        None => return 0,
    };

    // `first_line` is a prefix of `content`, so `open` is also a valid offset
    // into `content`; matching against the whole content lets the closing
    // paren of a multi-line signature be found.
    let after_open = &content[open + 1..];
    match find_matching_paren(after_open) {
        Some(close) => count_list_items(&after_open[..close]),
        None => 0,
    }
}

/// Byte offset of the first `(` in `line` that is not part of a Rust
/// visibility restriction such as `pub(crate)` or `pub(in some::path)`.
fn signature_open_paren(line: &str) -> Option<usize> {
    let mut from = 0;
    loop {
        let open = from + line[from..].find('(')?;
        // `pub` must be a whole word: `function epub(` is not a visibility.
        let is_visibility = match line[..open].trim_end().strip_suffix("pub") {
            Some(rest) => !rest.ends_with(|c: char| c.is_alphanumeric() || c == '_'),
            None => false,
        };
        if !is_visibility {
            return Some(open);
        }
        // Skip the whole `pub(...)` group and keep looking.
        let close = find_matching_paren(&line[open + 1..])?;
        from = open + close + 2;
    }
}

/// Number of top-level, comma-separated items in the text between a pair of
/// parens. A single trailing comma (emitted by rustfmt, black and prettier for
/// multi-line lists) does not start a new item; blank text has zero items.
fn count_list_items(list: &str) -> usize {
    let list = list.trim();
    let list = list.strip_suffix(',').unwrap_or(list).trim_end();
    if list.is_empty() {
        0
    } else {
        count_top_level_commas(list) + 1
    }
}

/// Count arguments at the first call site of `callee_name` in `content`.
///
/// A call site is an occurrence of `callee_name` that is not preceded by an
/// identifier character and is immediately followed by `(` with a matching
/// `)`. Returns `None` when there is no such call site, or when `callee_name`
/// is empty (e.g. an anonymous function).
fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
    // An empty needle matches everywhere and would walk the cursor past the
    // end of `content`.
    if callee_name.is_empty() {
        return None;
    }

    let mut search_start = 0;
    while let Some(rel_pos) = content[search_start..].find(callee_name) {
        let pos = search_start + rel_pos;
        let after = pos + callee_name.len();

        // Compare whole chars, not bytes, so a preceding non-ASCII identifier
        // character (e.g. `é`) is not mistaken for a word boundary.
        let is_boundary = match content[..pos].chars().next_back() {
            Some(prev) => !prev.is_alphanumeric() && prev != '_',
            None => true,
        };

        if is_boundary && content[after..].starts_with('(') {
            let args = &content[after + 1..];
            if let Some(close) = find_matching_paren(args) {
                return Some(count_list_items(&args[..close]));
            }
        }

        // Resume just past the first char of this match so overlapping
        // occurrences are still considered; stays on a char boundary.
        let first_char_len = content[pos..].chars().next().map_or(1, char::len_utf8);
        search_start = pos + first_char_len;
    }

    None
}

/// Byte offset in `s` of the `)` that closes a paren opened just before `s`,
/// accounting for nested parens. Returns `None` if it is never closed.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut depth = 0i32;
    for (i, ch) in s.char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => {
                if depth == 0 {
                    return Some(i);
                }
                depth -= 1;
            }
            _ => {}
        }
    }
    None
}

/// Count the commas in `s` that separate top-level items. A comma is skipped
/// when it is nested inside `()`, `[]`, `{}` or generic `<...>` brackets.
///
/// How angle brackets are treated:
/// - `<` opens a generic only when it directly follows an identifier
///   character or `:`, as in `Vec<T>` or `collect::<T>()`. A spaced
///   comparison such as `a < b` is not a bracket.
/// - A `>` that is part of `->` or `=>` never closes a generic.
/// - No depth ever drops below zero.
///
/// As a result, arrows, member access (`p->x`) and comparisons cannot hide
/// later commas. String literals are not tokenized.
fn count_top_level_commas(s: &str) -> usize {
    let mut nesting = 0usize;
    let mut generics = 0usize;
    let mut count = 0;
    let mut prev = ' ';
    for ch in s.chars() {
        match ch {
            '(' | '[' | '{' => nesting += 1,
            ')' | ']' | '}' => nesting = nesting.saturating_sub(1),
            '<' if prev.is_alphanumeric() || prev == '_' || prev == ':' => generics += 1,
            '>' if prev != '-' && prev != '=' => generics = generics.saturating_sub(1),
            ',' if nesting == 0 && generics == 0 => count += 1,
            _ => {}
        }
        prev = ch;
    }
    count
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `extract_param_count` against each `(signature, expected)` pair.
    fn check_params(cases: &[(&str, usize)]) {
        for &(signature, expected) in cases {
            assert_eq!(
                extract_param_count(signature),
                expected,
                "signature: {:?}",
                signature
            );
        }
    }

    /// Checks `count_call_args` against each `(content, callee, expected)` triple.
    fn check_calls(cases: &[(&str, &str, Option<usize>)]) {
        for &(content, callee, expected) in cases {
            assert_eq!(
                count_call_args(content, callee),
                expected,
                "callee {:?} in {:?}",
                callee,
                content
            );
        }
    }

    #[test]
    fn test_extract_param_count_basic() {
        check_params(&[
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ]);
    }

    #[test]
    fn test_extract_param_count_nested() {
        check_params(&[("function foo(a, fn(x, y), c) {", 3)]);
    }

    #[test]
    fn test_count_call_args() {
        check_calls(&[
            ("let x = foo(1, 2, 3);", "foo", Some(3)),
            ("foo()", "foo", Some(0)),
            ("bar(1)", "foo", None),
            ("foo(a, b)", "foo", Some(2)),
        ]);
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        // Multi-byte UTF-8 text ahead of the call site must neither panic
        // (slicing inside a character) nor change the counted arguments.
        check_calls(&[
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ]);
    }
}