Skip to main content

sem_core/parser/
verify.rs

1//! Contract verification: check that callers pass the correct number of
2//! arguments to callees. Uses tree-sitter AST for accurate param/arg counting.
3
4use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call site whose argument count differs from the callee's parameter count, as reported
/// by the string-based [`verify_contracts`] / [`verify_contracts_with_graph`] checks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContractViolation {
    /// Name of the called function/method.
    pub entity_name: String,
    /// File that defines the callee.
    pub file_path: String,
    /// Parameter count parsed from the callee's signature.
    pub expected_params: usize,
    /// Name of the calling entity.
    pub caller_name: String,
    /// File that contains the caller.
    pub caller_file: String,
    /// Number of arguments found at the caller's call site.
    pub actual_args: usize,
}
21
/// Positional-arity summary of a function signature, produced by tree-sitter analysis.
///
/// A call is compatible when it passes between `min_params` and `max_params` arguments
/// (inclusive); callers in this module skip the check entirely when `is_variadic` is set.
/// Receiver parameters (Python `self`/`cls`, Rust `self`) are not counted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamInfo {
    /// Parameters every call must supply (no default value, not optional).
    pub min_params: usize,
    /// Total positional parameters, including those with defaults or marked optional.
    pub max_params: usize,
    /// `true` if the signature has a rest/splat/variadic parameter.
    pub is_variadic: bool,
}
29
/// Arity mismatch found across the dependency graph by the tree-sitter based checks
/// ([`find_arity_mismatches`], [`find_broken_callers`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArityMismatch {
    /// Name of the calling entity.
    pub caller_entity: String,
    /// Name of the called function/method.
    pub callee_entity: String,
    /// Minimum number of arguments the callee accepts.
    pub expected_min: usize,
    /// Maximum number of arguments the callee accepts.
    pub expected_max: usize,
    /// Number of arguments found at the call site.
    pub actual_args: usize,
    /// File containing the caller.
    pub file_path: String,
    /// Start line of the calling entity (not the exact call-site line).
    pub line: usize,
    /// Whether the callee is variadic. The checks in this module skip variadic callees,
    /// so they always set this to `false`.
    pub is_variadic: bool,
}
42
43/// Verify function call contracts across the codebase.
44pub fn verify_contracts(
45    root: &Path,
46    file_paths: &[String],
47    registry: &ParserRegistry,
48    target_file: Option<&str>,
49) -> Vec<ContractViolation> {
50    let graph = EntityGraph::build(root, file_paths, registry);
51
52    let mut content_map: HashMap<String, String> = HashMap::new();
53    for fp in file_paths {
54        let full = root.join(fp);
55        let content = match std::fs::read_to_string(&full) {
56            Ok(c) => c,
57            Err(_) => continue,
58        };
59        let plugin = match registry.get_plugin_with_content(fp, &content) {
60            Some(p) => p,
61            None => continue,
62        };
63        for entity in plugin.extract_entities(&content, fp) {
64            content_map.insert(entity.id.clone(), entity.content.clone());
65        }
66    }
67
68    let mut violations = Vec::new();
69
70    for edge in &graph.edges {
71        if edge.ref_type != RefType::Calls {
72            continue;
73        }
74
75        let callee = match graph.entities.get(&edge.to_entity) {
76            Some(e) => e,
77            None => continue,
78        };
79
80        if let Some(tf) = target_file {
81            if callee.file_path != tf {
82                continue;
83            }
84        }
85
86        if !matches!(
87            callee.entity_type.as_str(),
88            "function" | "method" | "arrow_function"
89        ) {
90            continue;
91        }
92
93        let callee_content = match content_map.get(&edge.to_entity) {
94            Some(c) => c,
95            None => continue,
96        };
97
98        let caller = match graph.entities.get(&edge.from_entity) {
99            Some(e) => e,
100            None => continue,
101        };
102
103        let caller_content = match content_map.get(&edge.from_entity) {
104            Some(c) => c,
105            None => continue,
106        };
107
108        let expected = extract_param_count(callee_content);
109        if expected == 0 {
110            continue;
111        }
112
113        if let Some(actual) = count_call_args(caller_content, &callee.name) {
114            if actual != expected {
115                violations.push(ContractViolation {
116                    entity_name: callee.name.clone(),
117                    file_path: callee.file_path.clone(),
118                    expected_params: expected,
119                    caller_name: caller.name.clone(),
120                    caller_file: caller.file_path.clone(),
121                    actual_args: actual,
122                });
123            }
124        }
125    }
126
127    violations
128}
129
130/// Like `verify_contracts`, but accepts a pre-built graph + entities.
131pub fn verify_contracts_with_graph(
132    graph: &EntityGraph,
133    all_entities: &[SemanticEntity],
134    target_file: Option<&str>,
135) -> Vec<ContractViolation> {
136    let content_map: HashMap<String, String> = all_entities
137        .iter()
138        .map(|e| (e.id.clone(), e.content.clone()))
139        .collect();
140
141    let mut violations = Vec::new();
142
143    for edge in &graph.edges {
144        if edge.ref_type != RefType::Calls {
145            continue;
146        }
147
148        let callee = match graph.entities.get(&edge.to_entity) {
149            Some(e) => e,
150            None => continue,
151        };
152
153        if let Some(tf) = target_file {
154            if callee.file_path != tf {
155                continue;
156            }
157        }
158
159        if !matches!(
160            callee.entity_type.as_str(),
161            "function" | "method" | "arrow_function"
162        ) {
163            continue;
164        }
165
166        let callee_content = match content_map.get(&edge.to_entity) {
167            Some(c) => c,
168            None => continue,
169        };
170
171        let caller = match graph.entities.get(&edge.from_entity) {
172            Some(e) => e,
173            None => continue,
174        };
175
176        let caller_content = match content_map.get(&edge.from_entity) {
177            Some(c) => c,
178            None => continue,
179        };
180
181        let expected = extract_param_count(callee_content);
182        if expected == 0 {
183            continue;
184        }
185
186        if let Some(actual) = count_call_args(caller_content, &callee.name) {
187            if actual != expected {
188                violations.push(ContractViolation {
189                    entity_name: callee.name.clone(),
190                    file_path: callee.file_path.clone(),
191                    expected_params: expected,
192                    caller_name: caller.name.clone(),
193                    caller_file: caller.file_path.clone(),
194                    actual_args: actual,
195                });
196            }
197        }
198    }
199
200    violations
201}
202
203// ─── Tree-sitter based arity analysis ───────────────────────────────────────
204
/// Map a file extension, including its leading dot (e.g. `".py"`), to the language label
/// used by the parameter-counting code.
///
/// JavaScript extensions share the `"typescript"` label. Anything not listed maps to
/// `"unknown"`.
fn lang_from_ext(ext: &str) -> &'static str {
    const EXTENSIONS: &[(&str, &str)] = &[
        (".py", "python"),
        (".pyi", "python"),
        (".ts", "typescript"),
        (".tsx", "typescript"),
        (".mts", "typescript"),
        (".cts", "typescript"),
        (".js", "typescript"),
        (".jsx", "typescript"),
        (".mjs", "typescript"),
        (".cjs", "typescript"),
        (".rs", "rust"),
        (".go", "go"),
    ];

    EXTENSIONS
        .iter()
        .find(|&&(candidate, _)| candidate == ext)
        .map_or("unknown", |&(_, lang)| lang)
}
215
216/// Extract parameter info from entity content using tree-sitter.
217pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
218    let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
219    let lang = lang_from_ext(ext);
220    if lang == "unknown" {
221        return None;
222    }
223    let config = get_language_config(ext)?;
224    let language = (config.get_language)()?;
225
226    let mut parser = tree_sitter::Parser::new();
227    let _ = parser.set_language(&language);
228    let tree = parser.parse(content.as_bytes(), None)?;
229
230    extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
231}
232
233fn extract_param_info_from_node(
234    root: tree_sitter::Node,
235    source: &[u8],
236    lang: &str,
237) -> Option<ParamInfo> {
238    // Find the first function-like node
239    let func_node = find_first_function(root)?;
240    let params_node = func_node.child_by_field_name("parameters")?;
241
242    let mut min_params = 0usize;
243    let mut max_params = 0usize;
244    let mut is_variadic = false;
245
246    let mut cursor = params_node.walk();
247    for child in params_node.named_children(&mut cursor) {
248        let kind = child.kind();
249        match lang {
250            "python" => {
251                if kind == "identifier" {
252                    let name = child.utf8_text(source).unwrap_or("");
253                    if name == "self" || name == "cls" {
254                        continue;
255                    }
256                    min_params += 1;
257                    max_params += 1;
258                } else if kind == "typed_parameter" {
259                    let name = child
260                        .child_by_field_name("name")
261                        .or_else(|| child.named_child(0))
262                        .and_then(|n| n.utf8_text(source).ok())
263                        .unwrap_or("");
264                    if name == "self" || name == "cls" {
265                        continue;
266                    }
267                    min_params += 1;
268                    max_params += 1;
269                } else if kind == "default_parameter" || kind == "typed_default_parameter" {
270                    max_params += 1;
271                } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
272                    is_variadic = true;
273                }
274            }
275            "typescript" => {
276                if kind == "required_parameter" {
277                    min_params += 1;
278                    max_params += 1;
279                } else if kind == "optional_parameter" {
280                    max_params += 1;
281                } else if kind == "rest_pattern" {
282                    is_variadic = true;
283                }
284            }
285            "rust" => {
286                if kind == "parameter" {
287                    let pat = child
288                        .child_by_field_name("pattern")
289                        .and_then(|n| n.utf8_text(source).ok())
290                        .unwrap_or("");
291                    // Skip self/&self/&mut self
292                    let base = pat.trim_start_matches('&').trim();
293                    let base = base.strip_prefix("mut ").unwrap_or(base).trim();
294                    if base == "self" {
295                        continue;
296                    }
297                    min_params += 1;
298                    max_params += 1;
299                } else if kind == "self_parameter" {
300                    continue;
301                }
302            }
303            "go" => {
304                if kind == "parameter_declaration" {
305                    // Check for variadic: ...Type
306                    let type_text = child
307                        .child_by_field_name("type")
308                        .and_then(|n| n.utf8_text(source).ok())
309                        .unwrap_or("");
310                    if type_text.starts_with("...") {
311                        is_variadic = true;
312                    } else {
313                        min_params += 1;
314                        max_params += 1;
315                    }
316                }
317            }
318            _ => {}
319        }
320    }
321
322    Some(ParamInfo {
323        min_params,
324        max_params,
325        is_variadic,
326    })
327}
328
329fn find_first_function(node: tree_sitter::Node) -> Option<tree_sitter::Node> {
330    let kind = node.kind();
331    if matches!(
332        kind,
333        "function_definition"
334            | "function_item"
335            | "function_declaration"
336            | "method_definition"
337            | "method_declaration"
338            | "arrow_function"
339    ) {
340        return Some(node);
341    }
342    let mut cursor = node.walk();
343    for child in node.named_children(&mut cursor) {
344        if let Some(f) = find_first_function(child) {
345            return Some(f);
346        }
347    }
348    None
349}
350
351/// Count call arguments at a specific call site using tree-sitter.
352pub fn count_call_args_ts(
353    caller_content: &str,
354    callee_name: &str,
355    file_path: &str,
356) -> Option<usize> {
357    let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
358    let config = get_language_config(ext)?;
359    let language = (config.get_language)()?;
360
361    let mut parser = tree_sitter::Parser::new();
362    let _ = parser.set_language(&language);
363    let tree = parser.parse(caller_content.as_bytes(), None)?;
364
365    find_call_arg_count(tree.root_node(), caller_content.as_bytes(), callee_name)
366}
367
368fn find_call_arg_count(
369    node: tree_sitter::Node,
370    source: &[u8],
371    callee_name: &str,
372) -> Option<usize> {
373    let kind = node.kind();
374
375    if kind == "call" || kind == "call_expression" {
376        let func = node.child_by_field_name("function")?;
377        let func_name = match func.kind() {
378            "identifier" => func.utf8_text(source).unwrap_or(""),
379            "attribute" | "member_expression" | "field_expression" => func
380                .child_by_field_name("attribute")
381                .or_else(|| func.child_by_field_name("property"))
382                .or_else(|| func.child_by_field_name("field"))
383                .and_then(|n| n.utf8_text(source).ok())
384                .unwrap_or(""),
385            "selector_expression" => func
386                .child_by_field_name("field")
387                .and_then(|n| n.utf8_text(source).ok())
388                .unwrap_or(""),
389            "scoped_identifier" => {
390                let text = func.utf8_text(source).unwrap_or("");
391                text.rsplit("::").next().unwrap_or("")
392            }
393            _ => "",
394        };
395
396        if func_name == callee_name {
397            let args = node.child_by_field_name("arguments")?;
398            let mut count = 0;
399            let mut cursor = args.walk();
400            for child in args.named_children(&mut cursor) {
401                // Skip comment nodes
402                if !child.kind().contains("comment") {
403                    count += 1;
404                }
405            }
406            return Some(count);
407        }
408    }
409
410    let mut cursor = node.walk();
411    for child in node.named_children(&mut cursor) {
412        if let Some(count) = find_call_arg_count(child, source, callee_name) {
413            return Some(count);
414        }
415    }
416    None
417}
418
419/// Find arity mismatches across all Calls edges in the graph.
420pub fn find_arity_mismatches(
421    graph: &EntityGraph,
422    all_entities: &[SemanticEntity],
423) -> Vec<ArityMismatch> {
424    let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
425        .iter()
426        .map(|e| (e.id.as_str(), e))
427        .collect();
428
429    // Cache param info per callee entity
430    let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
431
432    let mut mismatches = Vec::new();
433
434    for edge in &graph.edges {
435        if edge.ref_type != RefType::Calls {
436            continue;
437        }
438
439        let callee_info = match graph.entities.get(&edge.to_entity) {
440            Some(e) => e,
441            None => continue,
442        };
443
444        if !matches!(
445            callee_info.entity_type.as_str(),
446            "function" | "method" | "arrow_function"
447        ) {
448            continue;
449        }
450
451        let callee = match entity_by_id.get(edge.to_entity.as_str()) {
452            Some(e) => *e,
453            None => continue,
454        };
455
456        let caller = match entity_by_id.get(edge.from_entity.as_str()) {
457            Some(e) => *e,
458            None => continue,
459        };
460
461        // Get callee param info (cached)
462        let param_info = param_cache
463            .entry(callee.id.clone())
464            .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
465            .clone();
466
467        let param_info = match param_info {
468            Some(pi) => pi,
469            None => continue,
470        };
471
472        // Skip variadic functions
473        if param_info.is_variadic {
474            continue;
475        }
476
477        // Count call args using tree-sitter
478        let actual = match count_call_args_ts(
479            &caller.content,
480            &callee.name,
481            &caller.file_path,
482        ) {
483            Some(a) => a,
484            None => continue,
485        };
486
487        if actual < param_info.min_params || actual > param_info.max_params {
488            mismatches.push(ArityMismatch {
489                caller_entity: caller.name.clone(),
490                callee_entity: callee.name.clone(),
491                expected_min: param_info.min_params,
492                expected_max: param_info.max_params,
493                actual_args: actual,
494                file_path: caller.file_path.clone(),
495                line: caller.start_line,
496                is_variadic: false,
497            });
498        }
499    }
500
501    mismatches
502}
503
504/// Find callers broken by signature changes between old and new entities.
505/// Compares param counts of functions that exist in both old and new,
506/// then checks if any callers in new_graph pass the wrong arg count.
507pub fn find_broken_callers(
508    old_entities: &[SemanticEntity],
509    new_graph: &EntityGraph,
510    new_entities: &[SemanticEntity],
511) -> Vec<ArityMismatch> {
512    // Build old param info map: entity_id -> ParamInfo
513    let old_params: HashMap<String, Option<ParamInfo>> = old_entities
514        .iter()
515        .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
516        .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
517        .collect();
518
519    // Build new entity lookup
520    let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
521        .iter()
522        .map(|e| (e.id.as_str(), e))
523        .collect();
524
525    // Find entities whose param counts changed
526    let mut changed_entities: Vec<&str> = Vec::new();
527    for new_entity in new_entities {
528        if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
529            continue;
530        }
531        let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
532            Some(pi) => pi,
533            None => continue,
534        };
535        if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
536            if old_info.min_params != new_info.min_params
537                || old_info.max_params != new_info.max_params
538            {
539                changed_entities.push(&new_entity.id);
540            }
541        }
542    }
543
544    if changed_entities.is_empty() {
545        return Vec::new();
546    }
547
548    // Check all callers of changed entities
549    let mut mismatches = Vec::new();
550
551    for edge in &new_graph.edges {
552        if edge.ref_type != RefType::Calls {
553            continue;
554        }
555        if !changed_entities.contains(&edge.to_entity.as_str()) {
556            continue;
557        }
558
559        let callee = match new_by_id.get(edge.to_entity.as_str()) {
560            Some(e) => *e,
561            None => continue,
562        };
563        let caller = match new_by_id.get(edge.from_entity.as_str()) {
564            Some(e) => *e,
565            None => continue,
566        };
567
568        let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
569            Some(pi) => pi,
570            None => continue,
571        };
572
573        if new_info.is_variadic {
574            continue;
575        }
576
577        let actual = match count_call_args_ts(&caller.content, &callee.name, &caller.file_path) {
578            Some(a) => a,
579            None => continue,
580        };
581
582        if actual < new_info.min_params || actual > new_info.max_params {
583            mismatches.push(ArityMismatch {
584                caller_entity: caller.name.clone(),
585                callee_entity: callee.name.clone(),
586                expected_min: new_info.min_params,
587                expected_max: new_info.max_params,
588                actual_args: actual,
589                file_path: caller.file_path.clone(),
590                line: caller.start_line,
591                is_variadic: false,
592            });
593        }
594    }
595
596    mismatches
597}
598
599// ─── String-based helpers (kept for backward compatibility) ──────────────────
600
601/// Extract param count from the first line of a function/method.
602fn extract_param_count(content: &str) -> usize {
603    let first_line = content.lines().next().unwrap_or("");
604
605    let open = match first_line.find('(') {
606        Some(i) => i,
607        None => return 0,
608    };
609
610    let after_open = &first_line[open + 1..];
611    let close = match find_matching_paren(after_open) {
612        Some(i) => i,
613        None => return 0,
614    };
615
616    let params_str = after_open[..close].trim();
617    if params_str.is_empty() {
618        return 0;
619    }
620
621    count_top_level_commas(params_str) + 1
622}
623
624/// Count arguments at a call site: find `callee_name(...)` in content and count args.
625fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
626    let bytes = content.as_bytes();
627    let name_bytes = callee_name.as_bytes();
628    let mut search_start = 0;
629
630    while let Some(rel_pos) = content[search_start..].find(callee_name) {
631        let pos = search_start + rel_pos;
632        let after = pos + name_bytes.len();
633
634        let is_boundary = pos == 0 || {
635            let prev = bytes[pos - 1];
636            !prev.is_ascii_alphanumeric() && prev != b'_'
637        };
638
639        if is_boundary && after < bytes.len() && bytes[after] == b'(' {
640            let args_start = &content[after + 1..];
641            if let Some(close) = find_matching_paren(args_start) {
642                let args_str = args_start[..close].trim();
643                if args_str.is_empty() {
644                    return Some(0);
645                }
646                return Some(count_top_level_commas(args_str) + 1);
647            }
648        }
649
650        search_start = pos + 1;
651        while search_start < content.len() && !content.is_char_boundary(search_start) {
652            search_start += 1;
653        }
654    }
655
656    None
657}
658
/// Return the byte offset of the `)` that closes an already-opened `(`.
///
/// `s` is the text immediately after the opening parenthesis. Nested parentheses are
/// tracked. Parentheses inside double-quoted string literals (with backslash escapes) are
/// ignored, so an argument like `")"` does not end the search early. Single quotes are not
/// treated as string delimiters, because `'` also introduces Rust lifetimes and char
/// literals. Returns `None` if the parenthesis is never closed.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut in_string = false;
    let mut escaped = false;

    for (i, ch) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            '"' => in_string = true,
            '(' => depth += 1,
            ')' if depth == 0 => return Some(i),
            ')' => depth -= 1,
            _ => {}
        }
    }
    None
}
675
/// Count the commas in `s` that are not nested inside brackets or string literals.
///
/// `()`, `[]`, `{}` and generic angle brackets `<>` raise the nesting depth. Several
/// look-alike cases do not affect depth:
/// - `->` and `=>` arrows;
/// - the `<=` comparison;
/// - a stray closing bracket or comparison `>` at depth zero.
///
/// The last case previously drove the depth negative and hid every later comma. Commas
/// inside double-quoted strings (with backslash escapes) are ignored.
fn count_top_level_commas(s: &str) -> usize {
    let mut depth = 0usize;
    let mut count = 0;
    let mut in_string = false;
    let mut escaped = false;
    let mut prev = '\0';
    let mut chars = s.chars().peekable();

    while let Some(ch) = chars.next() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
        } else {
            match ch {
                '"' => in_string = true,
                '(' | '[' | '{' => depth += 1,
                // `<=` is a comparison, not an opening angle bracket.
                '<' if chars.peek() != Some(&'=') => depth += 1,
                ')' | ']' | '}' => depth = depth.saturating_sub(1),
                // `->` / `=>` are arrows, not closing angle brackets.
                '>' if prev != '-' && prev != '=' => depth = depth.saturating_sub(1),
                ',' if depth == 0 => count += 1,
                _ => {}
            }
        }
        prev = ch;
    }
    count
}
689
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyze `src` as if it lived at `path`, failing the test if nothing is found.
    fn param_info(src: &str, path: &str) -> ParamInfo {
        extract_param_info_ts(src, path).expect("expected parameter info")
    }

    #[test]
    fn test_extract_param_count_basic() {
        let cases: [(&str, usize); 4] = [
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ];
        for &(signature, expected) in &cases {
            assert_eq!(extract_param_count(signature), expected, "{}", signature);
        }
    }

    #[test]
    fn test_extract_param_count_nested() {
        let signature = "function foo(a, fn(x, y), c) {";
        assert_eq!(extract_param_count(signature), 3);
    }

    #[test]
    fn test_count_call_args() {
        let cases: [(&str, &str, Option<usize>); 4] = [
            ("let x = foo(1, 2, 3);", "foo", Some(3)),
            ("foo()", "foo", Some(0)),
            ("bar(1)", "foo", None),
            ("foo(a, b)", "foo", Some(2)),
        ];
        for &(content, callee, expected) in &cases {
            assert_eq!(count_call_args(content, callee), expected, "{}", content);
        }
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        let cases: [(&str, &str, Option<usize>); 3] = [
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ];
        for &(content, callee, expected) in &cases {
            assert_eq!(count_call_args(content, callee), expected, "{}", content);
        }
    }

    #[test]
    fn test_extract_param_info_python() {
        let info = param_info("def foo(a, b, c=3):\n    pass", "test.py");
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 3);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_python_self() {
        let info = param_info("def foo(self, a, b):\n    pass", "test.py");
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let info = param_info("def foo(a, *args, **kwargs):\n    pass", "test.py");
        assert!(info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        let info = param_info(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        );
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 3);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_rust() {
        let info = param_info("fn foo(&self, a: i32, b: String) -> bool { true }", "test.rs");
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    #[test]
    fn test_extract_param_info_go() {
        let info = param_info("func foo(a string, b int) error { return nil }", "test.go");
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    #[test]
    fn test_count_call_args_ts() {
        let caller = "function bar() { foo(1, 2, 3); }";
        assert_eq!(count_call_args_ts(caller, "foo", "test.ts"), Some(3));
    }

    #[test]
    fn test_count_call_args_ts_method() {
        let caller = "function bar() { obj.foo(1, 2); }";
        assert_eq!(count_call_args_ts(caller, "foo", "test.ts"), Some(2));
    }
}