Skip to main content

sem_core/parser/
verify.rs

1//! Contract verification: check that callers pass the correct number of
2//! arguments to callees. Uses tree-sitter AST for accurate param/arg counting.
3
4use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call site whose argument count differs from the parameter count found in
/// the callee's signature, as reported by the string-based contract checks.
#[derive(Clone, Debug)]
pub struct ContractViolation {
    /// Name of the called function or method.
    pub entity_name: String,
    /// File in which the callee is defined.
    pub file_path: String,
    /// Number of parameters counted in the callee's signature.
    pub expected_params: usize,
    /// Name of the calling entity.
    pub caller_name: String,
    /// File in which the caller is defined.
    pub caller_file: String,
    /// Number of arguments counted at the call site.
    pub actual_args: usize,
}
21
/// Arity of a function signature as determined by the tree-sitter analysis.
#[derive(Clone, Debug)]
pub struct ParamInfo {
    /// Number of parameters a caller must always supply.
    pub min_params: usize,
    /// Number of parameters a caller may supply at most (required plus optional).
    pub max_params: usize,
    /// Set when the signature accepts a variable number of arguments.
    pub is_variadic: bool,
}
29
/// A call found in the dependency graph whose argument count falls outside the
/// range accepted by the callee.
#[derive(Clone, Debug)]
pub struct ArityMismatch {
    /// Name of the entity containing the call.
    pub caller_entity: String,
    /// Name of the entity being called.
    pub callee_entity: String,
    /// Fewest arguments the callee accepts.
    pub expected_min: usize,
    /// Most arguments the callee accepts.
    pub expected_max: usize,
    /// Number of arguments counted at the call site.
    pub actual_args: usize,
    /// File containing the caller.
    pub file_path: String,
    /// Start line of the caller entity.
    pub line: usize,
    /// Whether the callee is variadic.
    pub is_variadic: bool,
}
42
43/// Verify function call contracts across the codebase.
44pub fn verify_contracts(
45    root: &Path,
46    file_paths: &[String],
47    registry: &ParserRegistry,
48    target_file: Option<&str>,
49) -> Vec<ContractViolation> {
50    let (graph, _) = EntityGraph::build(root, file_paths, registry);
51
52    let mut content_map: HashMap<String, String> = HashMap::new();
53    for fp in file_paths {
54        let full = root.join(fp);
55        let content = match std::fs::read_to_string(&full) {
56            Ok(c) => c,
57            Err(_) => continue,
58        };
59        for entity in registry.extract_entities(fp, &content) {
60            content_map.insert(entity.id.clone(), entity.content.clone());
61        }
62    }
63
64    let mut violations = Vec::new();
65
66    for edge in &graph.edges {
67        if edge.ref_type != RefType::Calls {
68            continue;
69        }
70
71        let callee = match graph.entities.get(&edge.to_entity) {
72            Some(e) => e,
73            None => continue,
74        };
75
76        if let Some(tf) = target_file {
77            if callee.file_path != tf {
78                continue;
79            }
80        }
81
82        if !matches!(
83            callee.entity_type.as_str(),
84            "function" | "method" | "arrow_function"
85        ) {
86            continue;
87        }
88
89        let callee_content = match content_map.get(&edge.to_entity) {
90            Some(c) => c,
91            None => continue,
92        };
93
94        let caller = match graph.entities.get(&edge.from_entity) {
95            Some(e) => e,
96            None => continue,
97        };
98
99        let caller_content = match content_map.get(&edge.from_entity) {
100            Some(c) => c,
101            None => continue,
102        };
103
104        let expected = extract_param_count(callee_content);
105        if expected == 0 {
106            continue;
107        }
108
109        if let Some(actual) = count_call_args(caller_content, &callee.name) {
110            if actual != expected {
111                violations.push(ContractViolation {
112                    entity_name: callee.name.clone(),
113                    file_path: callee.file_path.clone(),
114                    expected_params: expected,
115                    caller_name: caller.name.clone(),
116                    caller_file: caller.file_path.clone(),
117                    actual_args: actual,
118                });
119            }
120        }
121    }
122
123    violations
124}
125
126/// Like `verify_contracts`, but accepts a pre-built graph + entities.
127pub fn verify_contracts_with_graph(
128    graph: &EntityGraph,
129    all_entities: &[SemanticEntity],
130    target_file: Option<&str>,
131) -> Vec<ContractViolation> {
132    let content_map: HashMap<String, String> = all_entities
133        .iter()
134        .map(|e| (e.id.clone(), e.content.clone()))
135        .collect();
136
137    let mut violations = Vec::new();
138
139    for edge in &graph.edges {
140        if edge.ref_type != RefType::Calls {
141            continue;
142        }
143
144        let callee = match graph.entities.get(&edge.to_entity) {
145            Some(e) => e,
146            None => continue,
147        };
148
149        if let Some(tf) = target_file {
150            if callee.file_path != tf {
151                continue;
152            }
153        }
154
155        if !matches!(
156            callee.entity_type.as_str(),
157            "function" | "method" | "arrow_function"
158        ) {
159            continue;
160        }
161
162        let callee_content = match content_map.get(&edge.to_entity) {
163            Some(c) => c,
164            None => continue,
165        };
166
167        let caller = match graph.entities.get(&edge.from_entity) {
168            Some(e) => e,
169            None => continue,
170        };
171
172        let caller_content = match content_map.get(&edge.from_entity) {
173            Some(c) => c,
174            None => continue,
175        };
176
177        let expected = extract_param_count(callee_content);
178        if expected == 0 {
179            continue;
180        }
181
182        if let Some(actual) = count_call_args(caller_content, &callee.name) {
183            if actual != expected {
184                violations.push(ContractViolation {
185                    entity_name: callee.name.clone(),
186                    file_path: callee.file_path.clone(),
187                    expected_params: expected,
188                    caller_name: caller.name.clone(),
189                    caller_file: caller.file_path.clone(),
190                    actual_args: actual,
191                });
192            }
193        }
194    }
195
196    violations
197}
198
199// ─── Tree-sitter based arity analysis ───────────────────────────────────────
200
/// Maps a file extension (including the leading dot) to the language key used
/// by the parameter analysis, or `"unknown"` when the extension is not listed.
///
/// JavaScript extensions deliberately share the `"typescript"` key.
fn lang_from_ext(ext: &str) -> &'static str {
    const LANGUAGES: &[(&str, &str)] = &[
        (".py", "python"),
        (".pyi", "python"),
        (".ts", "typescript"),
        (".tsx", "typescript"),
        (".mts", "typescript"),
        (".cts", "typescript"),
        (".js", "typescript"),
        (".jsx", "typescript"),
        (".mjs", "typescript"),
        (".cjs", "typescript"),
        (".rs", "rust"),
        (".go", "go"),
    ];

    LANGUAGES
        .iter()
        .find(|&&(suffix, _)| suffix == ext)
        .map_or("unknown", |&(_, lang)| lang)
}
211
212/// Extract parameter info from entity content using tree-sitter.
213pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
214    let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
215    let lang = lang_from_ext(ext);
216    if lang == "unknown" {
217        return None;
218    }
219    let config = get_language_config(ext)?;
220    let language = (config.get_language)()?;
221
222    let mut parser = tree_sitter::Parser::new();
223    let _ = parser.set_language(&language);
224    let tree = parser.parse(content.as_bytes(), None)?;
225
226    extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
227}
228
229fn extract_param_info_from_node(
230    root: tree_sitter::Node,
231    source: &[u8],
232    lang: &str,
233) -> Option<ParamInfo> {
234    // Find the first function-like node
235    let func_node = find_first_function(root)?;
236    let params_node = func_node.child_by_field_name("parameters")?;
237
238    let mut min_params = 0usize;
239    let mut max_params = 0usize;
240    let mut is_variadic = false;
241
242    let mut cursor = params_node.walk();
243    for child in params_node.named_children(&mut cursor) {
244        let kind = child.kind();
245        match lang {
246            "python" => {
247                if kind == "identifier" {
248                    let name = child.utf8_text(source).unwrap_or("");
249                    if name == "self" || name == "cls" {
250                        continue;
251                    }
252                    min_params += 1;
253                    max_params += 1;
254                } else if kind == "typed_parameter" {
255                    let name = child
256                        .child_by_field_name("name")
257                        .or_else(|| child.named_child(0))
258                        .and_then(|n| n.utf8_text(source).ok())
259                        .unwrap_or("");
260                    if name == "self" || name == "cls" {
261                        continue;
262                    }
263                    min_params += 1;
264                    max_params += 1;
265                } else if kind == "default_parameter" || kind == "typed_default_parameter" {
266                    max_params += 1;
267                } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
268                    is_variadic = true;
269                }
270            }
271            "typescript" => {
272                if kind == "required_parameter" {
273                    min_params += 1;
274                    max_params += 1;
275                } else if kind == "optional_parameter" {
276                    max_params += 1;
277                } else if kind == "rest_pattern" {
278                    is_variadic = true;
279                }
280            }
281            "rust" => {
282                if kind == "parameter" {
283                    let pat = child
284                        .child_by_field_name("pattern")
285                        .and_then(|n| n.utf8_text(source).ok())
286                        .unwrap_or("");
287                    // Skip self/&self/&mut self
288                    let base = pat.trim_start_matches('&').trim();
289                    let base = base.strip_prefix("mut ").unwrap_or(base).trim();
290                    if base == "self" {
291                        continue;
292                    }
293                    min_params += 1;
294                    max_params += 1;
295                } else if kind == "self_parameter" {
296                    continue;
297                }
298            }
299            "go" => {
300                if kind == "parameter_declaration" {
301                    // Check for variadic: ...Type
302                    let type_text = child
303                        .child_by_field_name("type")
304                        .and_then(|n| n.utf8_text(source).ok())
305                        .unwrap_or("");
306                    if type_text.starts_with("...") {
307                        is_variadic = true;
308                    } else {
309                        min_params += 1;
310                        max_params += 1;
311                    }
312                }
313            }
314            _ => {}
315        }
316    }
317
318    Some(ParamInfo {
319        min_params,
320        max_params,
321        is_variadic,
322    })
323}
324
325fn find_first_function(node: tree_sitter::Node) -> Option<tree_sitter::Node> {
326    let kind = node.kind();
327    if matches!(
328        kind,
329        "function_definition"
330            | "function_item"
331            | "function_declaration"
332            | "method_definition"
333            | "method_declaration"
334            | "arrow_function"
335    ) {
336        return Some(node);
337    }
338    let mut cursor = node.walk();
339    for child in node.named_children(&mut cursor) {
340        if let Some(f) = find_first_function(child) {
341            return Some(f);
342        }
343    }
344    None
345}
346
347/// Count call arguments at a specific call site using tree-sitter.
348pub fn count_call_args_ts(
349    caller_content: &str,
350    callee_name: &str,
351    file_path: &str,
352) -> Option<usize> {
353    let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
354    let config = get_language_config(ext)?;
355    let language = (config.get_language)()?;
356
357    let mut parser = tree_sitter::Parser::new();
358    let _ = parser.set_language(&language);
359    let tree = parser.parse(caller_content.as_bytes(), None)?;
360
361    find_call_arg_count(tree.root_node(), caller_content.as_bytes(), callee_name)
362}
363
364fn find_call_arg_count(
365    node: tree_sitter::Node,
366    source: &[u8],
367    callee_name: &str,
368) -> Option<usize> {
369    let kind = node.kind();
370
371    if kind == "call" || kind == "call_expression" {
372        let func = node.child_by_field_name("function")?;
373        let func_name = match func.kind() {
374            "identifier" => func.utf8_text(source).unwrap_or(""),
375            "attribute" | "member_expression" | "field_expression" => func
376                .child_by_field_name("attribute")
377                .or_else(|| func.child_by_field_name("property"))
378                .or_else(|| func.child_by_field_name("field"))
379                .and_then(|n| n.utf8_text(source).ok())
380                .unwrap_or(""),
381            "selector_expression" => func
382                .child_by_field_name("field")
383                .and_then(|n| n.utf8_text(source).ok())
384                .unwrap_or(""),
385            "scoped_identifier" => {
386                let text = func.utf8_text(source).unwrap_or("");
387                text.rsplit("::").next().unwrap_or("")
388            }
389            _ => "",
390        };
391
392        if func_name == callee_name {
393            let args = node.child_by_field_name("arguments")?;
394            let mut count = 0;
395            let mut cursor = args.walk();
396            for child in args.named_children(&mut cursor) {
397                // Skip comment nodes
398                if !child.kind().contains("comment") {
399                    count += 1;
400                }
401            }
402            return Some(count);
403        }
404    }
405
406    let mut cursor = node.walk();
407    for child in node.named_children(&mut cursor) {
408        if let Some(count) = find_call_arg_count(child, source, callee_name) {
409            return Some(count);
410        }
411    }
412    None
413}
414
415/// Find arity mismatches across all Calls edges in the graph.
416pub fn find_arity_mismatches(
417    graph: &EntityGraph,
418    all_entities: &[SemanticEntity],
419) -> Vec<ArityMismatch> {
420    let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
421        .iter()
422        .map(|e| (e.id.as_str(), e))
423        .collect();
424
425    // Cache param info per callee entity
426    let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
427
428    let mut mismatches = Vec::new();
429
430    for edge in &graph.edges {
431        if edge.ref_type != RefType::Calls {
432            continue;
433        }
434
435        let callee_info = match graph.entities.get(&edge.to_entity) {
436            Some(e) => e,
437            None => continue,
438        };
439
440        if !matches!(
441            callee_info.entity_type.as_str(),
442            "function" | "method" | "arrow_function"
443        ) {
444            continue;
445        }
446
447        let callee = match entity_by_id.get(edge.to_entity.as_str()) {
448            Some(e) => *e,
449            None => continue,
450        };
451
452        let caller = match entity_by_id.get(edge.from_entity.as_str()) {
453            Some(e) => *e,
454            None => continue,
455        };
456
457        // Get callee param info (cached)
458        let param_info = param_cache
459            .entry(callee.id.clone())
460            .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
461            .clone();
462
463        let param_info = match param_info {
464            Some(pi) => pi,
465            None => continue,
466        };
467
468        // Skip variadic functions
469        if param_info.is_variadic {
470            continue;
471        }
472
473        // Count call args using tree-sitter
474        let actual = match count_call_args_ts(
475            &caller.content,
476            &callee.name,
477            &caller.file_path,
478        ) {
479            Some(a) => a,
480            None => continue,
481        };
482
483        if actual < param_info.min_params || actual > param_info.max_params {
484            mismatches.push(ArityMismatch {
485                caller_entity: caller.name.clone(),
486                callee_entity: callee.name.clone(),
487                expected_min: param_info.min_params,
488                expected_max: param_info.max_params,
489                actual_args: actual,
490                file_path: caller.file_path.clone(),
491                line: caller.start_line,
492                is_variadic: false,
493            });
494        }
495    }
496
497    mismatches
498}
499
500/// Find callers broken by signature changes between old and new entities.
501/// Compares param counts of functions that exist in both old and new,
502/// then checks if any callers in new_graph pass the wrong arg count.
503pub fn find_broken_callers(
504    old_entities: &[SemanticEntity],
505    new_graph: &EntityGraph,
506    new_entities: &[SemanticEntity],
507) -> Vec<ArityMismatch> {
508    // Build old param info map: entity_id -> ParamInfo
509    let old_params: HashMap<String, Option<ParamInfo>> = old_entities
510        .iter()
511        .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
512        .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
513        .collect();
514
515    // Build new entity lookup
516    let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
517        .iter()
518        .map(|e| (e.id.as_str(), e))
519        .collect();
520
521    // Find entities whose param counts changed
522    let mut changed_entities: Vec<&str> = Vec::new();
523    for new_entity in new_entities {
524        if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
525            continue;
526        }
527        let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
528            Some(pi) => pi,
529            None => continue,
530        };
531        if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
532            if old_info.min_params != new_info.min_params
533                || old_info.max_params != new_info.max_params
534            {
535                changed_entities.push(&new_entity.id);
536            }
537        }
538    }
539
540    if changed_entities.is_empty() {
541        return Vec::new();
542    }
543
544    // Check all callers of changed entities
545    let mut mismatches = Vec::new();
546
547    for edge in &new_graph.edges {
548        if edge.ref_type != RefType::Calls {
549            continue;
550        }
551        if !changed_entities.contains(&edge.to_entity.as_str()) {
552            continue;
553        }
554
555        let callee = match new_by_id.get(edge.to_entity.as_str()) {
556            Some(e) => *e,
557            None => continue,
558        };
559        let caller = match new_by_id.get(edge.from_entity.as_str()) {
560            Some(e) => *e,
561            None => continue,
562        };
563
564        let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
565            Some(pi) => pi,
566            None => continue,
567        };
568
569        if new_info.is_variadic {
570            continue;
571        }
572
573        let actual = match count_call_args_ts(&caller.content, &callee.name, &caller.file_path) {
574            Some(a) => a,
575            None => continue,
576        };
577
578        if actual < new_info.min_params || actual > new_info.max_params {
579            mismatches.push(ArityMismatch {
580                caller_entity: caller.name.clone(),
581                callee_entity: callee.name.clone(),
582                expected_min: new_info.min_params,
583                expected_max: new_info.max_params,
584                actual_args: actual,
585                file_path: caller.file_path.clone(),
586                line: caller.start_line,
587                is_variadic: false,
588            });
589        }
590    }
591
592    mismatches
593}
594
595// ─── String-based helpers (kept for backward compatibility) ──────────────────
596
597/// Extract param count from the first line of a function/method.
598fn extract_param_count(content: &str) -> usize {
599    let first_line = content.lines().next().unwrap_or("");
600
601    let open = match first_line.find('(') {
602        Some(i) => i,
603        None => return 0,
604    };
605
606    let after_open = &first_line[open + 1..];
607    let close = match find_matching_paren(after_open) {
608        Some(i) => i,
609        None => return 0,
610    };
611
612    let params_str = after_open[..close].trim();
613    if params_str.is_empty() {
614        return 0;
615    }
616
617    count_top_level_commas(params_str) + 1
618}
619
620/// Count arguments at a call site: find `callee_name(...)` in content and count args.
621fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
622    let bytes = content.as_bytes();
623    let name_bytes = callee_name.as_bytes();
624    let mut search_start = 0;
625
626    while let Some(rel_pos) = content[search_start..].find(callee_name) {
627        let pos = search_start + rel_pos;
628        let after = pos + name_bytes.len();
629
630        let is_boundary = pos == 0 || {
631            let prev = bytes[pos - 1];
632            !prev.is_ascii_alphanumeric() && prev != b'_'
633        };
634
635        if is_boundary && after < bytes.len() && bytes[after] == b'(' {
636            let args_start = &content[after + 1..];
637            if let Some(close) = find_matching_paren(args_start) {
638                let args_str = args_start[..close].trim();
639                if args_str.is_empty() {
640                    return Some(0);
641                }
642                return Some(count_top_level_commas(args_str) + 1);
643            }
644        }
645
646        search_start = pos + 1;
647        while search_start < content.len() && !content.is_char_boundary(search_start) {
648            search_start += 1;
649        }
650    }
651
652    None
653}
654
/// Given the text that follows an opening `(`, returns the byte offset of the
/// `)` that closes it, skipping over nested parenthesis pairs. Returns `None`
/// when the text ends before the closing parenthesis is found.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut nested = 0i32;
    // `(` and `)` are ASCII, so scanning bytes yields the same offsets as
    // scanning characters.
    for (offset, byte) in s.bytes().enumerate() {
        if byte == b'(' {
            nested += 1;
        } else if byte == b')' {
            if nested == 0 {
                return Some(offset);
            }
            nested -= 1;
        }
    }
    None
}
671
/// Counts the commas in `s` that are not nested inside `()`, `[]`, `{}` or
/// `<>` pairs.
///
/// `<`/`>` are treated as brackets so that generic types such as
/// `Map<K, V>` count as one item, with these exceptions so that operators do
/// not unbalance the nesting:
/// - the `>` of `->` and `=>` is ignored;
/// - a `>` with no open `<` directly inside the current bracket is ignored
///   (comparison `a > b`, `a >= b`);
/// - `<=` is not treated as an opening bracket;
/// - a `<` still open when a `)`, `]` or `}` closes is discarded.
///
/// A lone less-than (`a < b, c`) remains indistinguishable from an opening
/// generic bracket, and string literals are not recognised.
fn count_top_level_commas(s: &str) -> usize {
    let mut open: Vec<char> = Vec::new();
    let mut count = 0;
    let mut prev = '\0';

    for ch in s.chars() {
        match ch {
            '(' | '[' | '{' | '<' => open.push(ch),
            // `<=`: undo the `<` that was just pushed.
            '=' if prev == '<' => {
                open.pop();
            }
            '>' => {
                let is_arrow = prev == '-' || prev == '=';
                if !is_arrow && open.last() == Some(&'<') {
                    open.pop();
                }
            }
            ')' | ']' | '}' => {
                // Any `<` left open inside this bracket was an operator.
                while open.last() == Some(&'<') {
                    open.pop();
                }
                open.pop();
            }
            ',' if open.is_empty() => count += 1,
            _ => {}
        }
        prev = ch;
    }

    count
}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyses `src` as if it lived at `path` and returns
    /// `(min_params, max_params, is_variadic)`.
    fn arity(src: &str, path: &str) -> (usize, usize, bool) {
        let info = extract_param_info_ts(src, path).unwrap();
        (info.min_params, info.max_params, info.is_variadic)
    }

    // ── String-based helpers ────────────────────────────────────────────

    #[test]
    fn test_extract_param_count_basic() {
        let cases = [
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ];
        for &(signature, expected) in cases.iter() {
            assert_eq!(extract_param_count(signature), expected);
        }
    }

    #[test]
    fn test_extract_param_count_nested() {
        let signature = "function foo(a, fn(x, y), c) {";
        assert_eq!(extract_param_count(signature), 3);
    }

    #[test]
    fn test_count_call_args() {
        let cases = [
            ("let x = foo(1, 2, 3);", "foo", Some(3)),
            ("foo()", "foo", Some(0)),
            ("bar(1)", "foo", None),
            ("foo(a, b)", "foo", Some(2)),
        ];
        for &(content, name, expected) in cases.iter() {
            assert_eq!(count_call_args(content, name), expected);
        }
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        let cases = [
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ];
        for &(content, name, expected) in cases.iter() {
            assert_eq!(count_call_args(content, name), expected);
        }
    }

    // ── Tree-sitter parameter analysis ──────────────────────────────────

    #[test]
    fn test_extract_param_info_python() {
        let (min, max, variadic) = arity("def foo(a, b, c=3):\n    pass", "test.py");
        assert_eq!(min, 2);
        assert_eq!(max, 3);
        assert!(!variadic);
    }

    #[test]
    fn test_extract_param_info_python_self() {
        let (min, max, _) = arity("def foo(self, a, b):\n    pass", "test.py");
        assert_eq!(min, 2);
        assert_eq!(max, 2);
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let (_, _, variadic) = arity("def foo(a, *args, **kwargs):\n    pass", "test.py");
        assert!(variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        let (min, max, variadic) = arity(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        );
        assert_eq!(min, 2);
        assert_eq!(max, 3);
        assert!(!variadic);
    }

    #[test]
    fn test_extract_param_info_rust() {
        let (min, max, _) = arity(
            "fn foo(&self, a: i32, b: String) -> bool { true }",
            "test.rs",
        );
        assert_eq!(min, 2);
        assert_eq!(max, 2);
    }

    #[test]
    fn test_extract_param_info_go() {
        let (min, max, _) = arity(
            "func foo(a string, b int) error { return nil }",
            "test.go",
        );
        assert_eq!(min, 2);
        assert_eq!(max, 2);
    }

    // ── Tree-sitter call-site analysis ──────────────────────────────────

    #[test]
    fn test_count_call_args_ts() {
        let source = "function bar() { foo(1, 2, 3); }";
        assert_eq!(count_call_args_ts(source, "foo", "test.ts"), Some(3));
    }

    #[test]
    fn test_count_call_args_ts_method() {
        let source = "function bar() { obj.foo(1, 2); }";
        assert_eq!(count_call_args_ts(source, "foo", "test.ts"), Some(2));
    }
}