Skip to main content

sem_core/parser/
verify.rs

//! Contract verification: check that callers pass the correct number of
//! arguments to callees. The `verify_contracts*` entry points count with
//! string heuristics; the arity-mismatch functions use the tree-sitter AST
//! for param/arg counting.
3
4use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call whose argument count differs from the callee's declared parameter
/// count, as reported by the string-based `verify_contracts*` functions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContractViolation {
    /// Name of the called entity.
    pub entity_name: String,
    /// File in which the called entity is defined.
    pub file_path: String,
    /// Number of parameters found in the callee's signature.
    pub expected_params: usize,
    /// Name of the calling entity.
    pub caller_name: String,
    /// File in which the calling entity is defined.
    pub caller_file: String,
    /// Number of arguments found at the call site.
    pub actual_args: usize,
}
21
/// Result of tree-sitter based parameter analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamInfo {
    /// Number of parameters a caller must always supply.
    pub min_params: usize,
    /// Required parameters plus optional/defaulted ones.
    pub max_params: usize,
    /// Whether the signature accepts an open-ended number of arguments.
    pub is_variadic: bool,
}
29
/// Arity mismatch found across the dependency graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArityMismatch {
    /// Name of the calling entity.
    pub caller_entity: String,
    /// Name of the called entity.
    pub callee_entity: String,
    /// Fewest arguments the callee accepts.
    pub expected_min: usize,
    /// Most arguments the callee accepts.
    pub expected_max: usize,
    /// Number of arguments found at the call site.
    pub actual_args: usize,
    /// File containing the calling entity.
    pub file_path: String,
    /// Start line of the calling entity (not of the individual call).
    pub line: usize,
    /// Whether the callee is variadic.
    pub is_variadic: bool,
}
42
43/// Verify function call contracts across the codebase.
44pub fn verify_contracts(
45    root: &Path,
46    file_paths: &[String],
47    registry: &ParserRegistry,
48    target_file: Option<&str>,
49) -> Vec<ContractViolation> {
50    let (graph, _) = EntityGraph::build(root, file_paths, registry);
51
52    let mut content_map: HashMap<String, String> = HashMap::new();
53    for fp in file_paths {
54        let full = root.join(fp);
55        let content = match std::fs::read_to_string(&full) {
56            Ok(c) => c,
57            Err(_) => continue,
58        };
59        for entity in registry.extract_entities(fp, &content) {
60            content_map.insert(entity.id.clone(), entity.content.clone());
61        }
62    }
63
64    let mut violations = Vec::new();
65
66    for edge in &graph.edges {
67        if edge.ref_type != RefType::Calls {
68            continue;
69        }
70
71        let callee = match graph.entities.get(&edge.to_entity) {
72            Some(e) => e,
73            None => continue,
74        };
75
76        if let Some(tf) = target_file {
77            if callee.file_path != tf {
78                continue;
79            }
80        }
81
82        if !matches!(
83            callee.entity_type.as_str(),
84            "function" | "method" | "arrow_function"
85        ) {
86            continue;
87        }
88
89        let callee_content = match content_map.get(&edge.to_entity) {
90            Some(c) => c,
91            None => continue,
92        };
93
94        let caller = match graph.entities.get(&edge.from_entity) {
95            Some(e) => e,
96            None => continue,
97        };
98
99        let caller_content = match content_map.get(&edge.from_entity) {
100            Some(c) => c,
101            None => continue,
102        };
103
104        let expected = extract_param_count(callee_content);
105        if expected == 0 {
106            continue;
107        }
108
109        if let Some(actual) = count_call_args(caller_content, &callee.name) {
110            if actual != expected {
111                violations.push(ContractViolation {
112                    entity_name: callee.name.clone(),
113                    file_path: callee.file_path.clone(),
114                    expected_params: expected,
115                    caller_name: caller.name.clone(),
116                    caller_file: caller.file_path.clone(),
117                    actual_args: actual,
118                });
119            }
120        }
121    }
122
123    violations
124}
125
126/// Like `verify_contracts`, but accepts a pre-built graph + entities.
127pub fn verify_contracts_with_graph(
128    graph: &EntityGraph,
129    all_entities: &[SemanticEntity],
130    target_file: Option<&str>,
131) -> Vec<ContractViolation> {
132    let content_map: HashMap<String, String> = all_entities
133        .iter()
134        .map(|e| (e.id.clone(), e.content.clone()))
135        .collect();
136
137    let mut violations = Vec::new();
138
139    for edge in &graph.edges {
140        if edge.ref_type != RefType::Calls {
141            continue;
142        }
143
144        let callee = match graph.entities.get(&edge.to_entity) {
145            Some(e) => e,
146            None => continue,
147        };
148
149        if let Some(tf) = target_file {
150            if callee.file_path != tf {
151                continue;
152            }
153        }
154
155        if !matches!(
156            callee.entity_type.as_str(),
157            "function" | "method" | "arrow_function"
158        ) {
159            continue;
160        }
161
162        let callee_content = match content_map.get(&edge.to_entity) {
163            Some(c) => c,
164            None => continue,
165        };
166
167        let caller = match graph.entities.get(&edge.from_entity) {
168            Some(e) => e,
169            None => continue,
170        };
171
172        let caller_content = match content_map.get(&edge.from_entity) {
173            Some(c) => c,
174            None => continue,
175        };
176
177        let expected = extract_param_count(callee_content);
178        if expected == 0 {
179            continue;
180        }
181
182        if let Some(actual) = count_call_args(caller_content, &callee.name) {
183            if actual != expected {
184                violations.push(ContractViolation {
185                    entity_name: callee.name.clone(),
186                    file_path: callee.file_path.clone(),
187                    expected_params: expected,
188                    caller_name: caller.name.clone(),
189                    caller_file: caller.file_path.clone(),
190                    actual_args: actual,
191                });
192            }
193        }
194    }
195
196    violations
197}
198
199// ─── Tree-sitter based arity analysis ───────────────────────────────────────
200
/// Map a dotted file extension (e.g. `".py"`) to the language tag used by the
/// parameter analysis, or `"unknown"` when the extension is not handled.
/// JavaScript extensions share the `"typescript"` tag.
fn lang_from_ext(ext: &str) -> &'static str {
    const PYTHON_EXTS: [&str; 2] = [".py", ".pyi"];
    const SCRIPT_EXTS: [&str; 8] = [
        ".ts", ".tsx", ".mts", ".cts", ".js", ".jsx", ".mjs", ".cjs",
    ];

    if PYTHON_EXTS.contains(&ext) {
        "python"
    } else if SCRIPT_EXTS.contains(&ext) {
        "typescript"
    } else if ext == ".rs" {
        "rust"
    } else if ext == ".go" {
        "go"
    } else {
        "unknown"
    }
}
211
212/// Extract parameter info from entity content using tree-sitter.
213pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
214    let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
215    let lang = lang_from_ext(ext);
216    if lang == "unknown" {
217        return None;
218    }
219    let config = get_language_config(ext)?;
220    let language = (config.get_language)()?;
221
222    let mut parser = tree_sitter::Parser::new();
223    let _ = parser.set_language(&language);
224    let tree = parser.parse(content.as_bytes(), None)?;
225
226    extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
227}
228
229fn extract_param_info_from_node(
230    root: tree_sitter::Node,
231    source: &[u8],
232    lang: &str,
233) -> Option<ParamInfo> {
234    // Find the first function-like node
235    let func_node = find_first_function(root)?;
236    let params_node = func_node.child_by_field_name("parameters")?;
237
238    let mut min_params = 0usize;
239    let mut max_params = 0usize;
240    let mut is_variadic = false;
241
242    let mut cursor = params_node.walk();
243    for child in params_node.named_children(&mut cursor) {
244        let kind = child.kind();
245        match lang {
246            "python" => {
247                if kind == "identifier" {
248                    let name = child.utf8_text(source).unwrap_or("");
249                    if name == "self" || name == "cls" {
250                        continue;
251                    }
252                    min_params += 1;
253                    max_params += 1;
254                } else if kind == "typed_parameter" {
255                    let name = child
256                        .child_by_field_name("name")
257                        .or_else(|| child.named_child(0))
258                        .and_then(|n| n.utf8_text(source).ok())
259                        .unwrap_or("");
260                    if name == "self" || name == "cls" {
261                        continue;
262                    }
263                    min_params += 1;
264                    max_params += 1;
265                } else if kind == "default_parameter" || kind == "typed_default_parameter" {
266                    max_params += 1;
267                } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
268                    is_variadic = true;
269                }
270            }
271            "typescript" => {
272                if kind == "required_parameter" {
273                    min_params += 1;
274                    max_params += 1;
275                } else if kind == "optional_parameter" {
276                    max_params += 1;
277                } else if kind == "rest_pattern" {
278                    is_variadic = true;
279                }
280            }
281            "rust" => {
282                if kind == "parameter" {
283                    let pat = child
284                        .child_by_field_name("pattern")
285                        .and_then(|n| n.utf8_text(source).ok())
286                        .unwrap_or("");
287                    // Skip self/&self/&mut self
288                    let base = pat.trim_start_matches('&').trim();
289                    let base = base.strip_prefix("mut ").unwrap_or(base).trim();
290                    if base == "self" {
291                        continue;
292                    }
293                    min_params += 1;
294                    max_params += 1;
295                } else if kind == "self_parameter" {
296                    continue;
297                }
298            }
299            "go" => {
300                if kind == "parameter_declaration" {
301                    // Check for variadic: ...Type
302                    let type_text = child
303                        .child_by_field_name("type")
304                        .and_then(|n| n.utf8_text(source).ok())
305                        .unwrap_or("");
306                    if type_text.starts_with("...") {
307                        is_variadic = true;
308                    } else {
309                        min_params += 1;
310                        max_params += 1;
311                    }
312                }
313            }
314            _ => {}
315        }
316    }
317
318    Some(ParamInfo {
319        min_params,
320        max_params,
321        is_variadic,
322    })
323}
324
325fn find_first_function(root: tree_sitter::Node) -> Option<tree_sitter::Node> {
326    let mut worklist = vec![root];
327    while let Some(node) = worklist.pop() {
328        let kind = node.kind();
329        if matches!(
330            kind,
331            "function_definition"
332                | "function_item"
333                | "function_declaration"
334                | "method_definition"
335                | "method_declaration"
336                | "arrow_function"
337        ) {
338            return Some(node);
339        }
340        let mut cursor = node.walk();
341        let children: Vec<_> = node.named_children(&mut cursor).collect();
342        for child in children.into_iter().rev() {
343            worklist.push(child);
344        }
345    }
346    None
347}
348
349/// Count call arguments at a specific call site using tree-sitter.
350pub fn count_call_args_ts(
351    caller_content: &str,
352    callee_name: &str,
353    file_path: &str,
354) -> Option<usize> {
355    let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
356    let config = get_language_config(ext)?;
357    let language = (config.get_language)()?;
358
359    let mut parser = tree_sitter::Parser::new();
360    let _ = parser.set_language(&language);
361    let tree = parser.parse(caller_content.as_bytes(), None)?;
362
363    find_call_arg_count(tree.root_node(), caller_content.as_bytes(), callee_name)
364}
365
366fn find_call_arg_count(
367    root: tree_sitter::Node,
368    source: &[u8],
369    callee_name: &str,
370) -> Option<usize> {
371    let mut worklist = vec![root];
372    while let Some(node) = worklist.pop() {
373        let kind = node.kind();
374
375        if kind == "call" || kind == "call_expression" {
376            if let Some(func) = node.child_by_field_name("function") {
377                let func_name = match func.kind() {
378                    "identifier" => func.utf8_text(source).unwrap_or(""),
379                    "attribute" | "member_expression" | "field_expression" => func
380                        .child_by_field_name("attribute")
381                        .or_else(|| func.child_by_field_name("property"))
382                        .or_else(|| func.child_by_field_name("field"))
383                        .and_then(|n| n.utf8_text(source).ok())
384                        .unwrap_or(""),
385                    "selector_expression" => func
386                        .child_by_field_name("field")
387                        .and_then(|n| n.utf8_text(source).ok())
388                        .unwrap_or(""),
389                    "scoped_identifier" => {
390                        let text = func.utf8_text(source).unwrap_or("");
391                        text.rsplit("::").next().unwrap_or("")
392                    }
393                    _ => "",
394                };
395
396                if func_name == callee_name {
397                    if let Some(args) = node.child_by_field_name("arguments") {
398                        let mut count = 0;
399                        let mut cursor = args.walk();
400                        for child in args.named_children(&mut cursor) {
401                            // Skip comment nodes
402                            if !child.kind().contains("comment") {
403                                count += 1;
404                            }
405                        }
406                        return Some(count);
407                    }
408                }
409            }
410        }
411
412        let mut cursor = node.walk();
413        let children: Vec<_> = node.named_children(&mut cursor).collect();
414        for child in children.into_iter().rev() {
415            worklist.push(child);
416        }
417    }
418    None
419}
420
421/// Find arity mismatches across all Calls edges in the graph.
422pub fn find_arity_mismatches(
423    graph: &EntityGraph,
424    all_entities: &[SemanticEntity],
425) -> Vec<ArityMismatch> {
426    let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
427        .iter()
428        .map(|e| (e.id.as_str(), e))
429        .collect();
430
431    // Cache param info per callee entity
432    let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
433
434    let mut mismatches = Vec::new();
435
436    for edge in &graph.edges {
437        if edge.ref_type != RefType::Calls {
438            continue;
439        }
440
441        let callee_info = match graph.entities.get(&edge.to_entity) {
442            Some(e) => e,
443            None => continue,
444        };
445
446        if !matches!(
447            callee_info.entity_type.as_str(),
448            "function" | "method" | "arrow_function"
449        ) {
450            continue;
451        }
452
453        let callee = match entity_by_id.get(edge.to_entity.as_str()) {
454            Some(e) => *e,
455            None => continue,
456        };
457
458        let caller = match entity_by_id.get(edge.from_entity.as_str()) {
459            Some(e) => *e,
460            None => continue,
461        };
462
463        // Get callee param info (cached)
464        let param_info = param_cache
465            .entry(callee.id.clone())
466            .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
467            .clone();
468
469        let param_info = match param_info {
470            Some(pi) => pi,
471            None => continue,
472        };
473
474        // Skip variadic functions
475        if param_info.is_variadic {
476            continue;
477        }
478
479        // Count call args using tree-sitter
480        let actual = match count_call_args_ts(
481            &caller.content,
482            &callee.name,
483            &caller.file_path,
484        ) {
485            Some(a) => a,
486            None => continue,
487        };
488
489        if actual < param_info.min_params || actual > param_info.max_params {
490            mismatches.push(ArityMismatch {
491                caller_entity: caller.name.clone(),
492                callee_entity: callee.name.clone(),
493                expected_min: param_info.min_params,
494                expected_max: param_info.max_params,
495                actual_args: actual,
496                file_path: caller.file_path.clone(),
497                line: caller.start_line,
498                is_variadic: false,
499            });
500        }
501    }
502
503    mismatches
504}
505
506/// Find callers broken by signature changes between old and new entities.
507/// Compares param counts of functions that exist in both old and new,
508/// then checks if any callers in new_graph pass the wrong arg count.
509pub fn find_broken_callers(
510    old_entities: &[SemanticEntity],
511    new_graph: &EntityGraph,
512    new_entities: &[SemanticEntity],
513) -> Vec<ArityMismatch> {
514    // Build old param info map: entity_id -> ParamInfo
515    let old_params: HashMap<String, Option<ParamInfo>> = old_entities
516        .iter()
517        .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
518        .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
519        .collect();
520
521    // Build new entity lookup
522    let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
523        .iter()
524        .map(|e| (e.id.as_str(), e))
525        .collect();
526
527    // Find entities whose param counts changed
528    let mut changed_entities: Vec<&str> = Vec::new();
529    for new_entity in new_entities {
530        if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
531            continue;
532        }
533        let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
534            Some(pi) => pi,
535            None => continue,
536        };
537        if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
538            if old_info.min_params != new_info.min_params
539                || old_info.max_params != new_info.max_params
540            {
541                changed_entities.push(&new_entity.id);
542            }
543        }
544    }
545
546    if changed_entities.is_empty() {
547        return Vec::new();
548    }
549
550    // Check all callers of changed entities
551    let mut mismatches = Vec::new();
552
553    for edge in &new_graph.edges {
554        if edge.ref_type != RefType::Calls {
555            continue;
556        }
557        if !changed_entities.contains(&edge.to_entity.as_str()) {
558            continue;
559        }
560
561        let callee = match new_by_id.get(edge.to_entity.as_str()) {
562            Some(e) => *e,
563            None => continue,
564        };
565        let caller = match new_by_id.get(edge.from_entity.as_str()) {
566            Some(e) => *e,
567            None => continue,
568        };
569
570        let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
571            Some(pi) => pi,
572            None => continue,
573        };
574
575        if new_info.is_variadic {
576            continue;
577        }
578
579        let actual = match count_call_args_ts(&caller.content, &callee.name, &caller.file_path) {
580            Some(a) => a,
581            None => continue,
582        };
583
584        if actual < new_info.min_params || actual > new_info.max_params {
585            mismatches.push(ArityMismatch {
586                caller_entity: caller.name.clone(),
587                callee_entity: callee.name.clone(),
588                expected_min: new_info.min_params,
589                expected_max: new_info.max_params,
590                actual_args: actual,
591                file_path: caller.file_path.clone(),
592                line: caller.start_line,
593                is_variadic: false,
594            });
595        }
596    }
597
598    mismatches
599}
600
601// ─── String-based helpers (kept for backward compatibility) ──────────────────
602
/// Extract param count from the first line of a function/method.
///
/// Looks at the text between the first `(` on the first line and its matching
/// `)`. Returns 0 when there is no such pair on that line (for example a
/// signature split over several lines) as well as for an empty list. Every
/// listed name is counted, including receivers such as `self`.
fn extract_param_count(content: &str) -> usize {
    let first_line = content.lines().next().unwrap_or("");

    first_line
        .find('(')
        .map(|open| &first_line[open + 1..])
        .and_then(|rest| find_matching_paren(rest).map(|close| count_list_items(&rest[..close])))
        .unwrap_or(0)
}

/// Count arguments at a call site: find `callee_name(...)` in content and count args.
///
/// The name must not be preceded by an ASCII identifier character and must be
/// followed immediately by `(`. The first occurrence with a matching `)` wins.
/// Returns `None` if `callee_name` is empty or no such occurrence exists.
fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
    // An empty needle matches at every offset, including `content.len()`,
    // which would push the search start past the end of the string.
    if callee_name.is_empty() {
        return None;
    }

    let bytes = content.as_bytes();
    let name_bytes = callee_name.as_bytes();
    let mut search_start = 0;

    while let Some(rel_pos) = content[search_start..].find(callee_name) {
        let pos = search_start + rel_pos;
        let after = pos + name_bytes.len();

        let is_boundary = pos == 0 || {
            let prev = bytes[pos - 1];
            !prev.is_ascii_alphanumeric() && prev != b'_'
        };

        if is_boundary && after < bytes.len() && bytes[after] == b'(' {
            let args_start = &content[after + 1..];
            if let Some(close) = find_matching_paren(args_start) {
                return Some(count_list_items(&args_start[..close]));
            }
        }

        // Resume one character later, staying on a UTF-8 boundary.
        search_start = pos + 1;
        while search_start < content.len() && !content.is_char_boundary(search_start) {
            search_start += 1;
        }
    }

    None
}

/// Number of items in a comma-separated parameter/argument list (the text
/// between the parentheses). A single trailing comma does not add an item.
fn count_list_items(list: &str) -> usize {
    let list = list.trim();
    let list = list.strip_suffix(',').map(str::trim_end).unwrap_or(list);
    if list.is_empty() {
        0
    } else {
        count_top_level_commas(list) + 1
    }
}

/// Byte offset of the `)` that closes a parenthesis opened just before `s`,
/// skipping over balanced nested pairs. `None` if it is never closed.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    for (i, ch) in s.char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => {
                if depth == 0 {
                    return Some(i);
                }
                depth -= 1;
            }
            _ => {}
        }
    }
    None
}

/// Count commas that are not nested inside `()`, `[]`, `{}` or `<>`.
///
/// The `>` of an arrow token (`=>`, `->`) is not treated as a closing bracket,
/// and a closer with nothing open (for example a `>` comparison) is ignored
/// rather than hiding the commas that follow it.
fn count_top_level_commas(s: &str) -> usize {
    let mut depth = 0usize;
    let mut count = 0;
    let mut prev: Option<char> = None;
    for ch in s.chars() {
        match ch {
            '(' | '[' | '{' | '<' => depth += 1,
            '>' if matches!(prev, Some('-') | Some('=')) => {}
            ')' | ']' | '}' | '>' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => count += 1,
            _ => {}
        }
        prev = Some(ch);
    }
    count
}
691
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` as the file `path` and return `(min_params, max_params)`.
    fn arity_range(source: &str, path: &str) -> (usize, usize) {
        let info = extract_param_info_ts(source, path).unwrap();
        (info.min_params, info.max_params)
    }

    // ── String-based helpers ────────────────────────────────────────────

    #[test]
    fn test_extract_param_count_basic() {
        let cases = [
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ];
        for (signature, expected) in cases {
            assert_eq!(extract_param_count(signature), expected);
        }
    }

    #[test]
    fn test_extract_param_count_nested() {
        // The comma inside the nested parentheses must not be counted.
        let signature = "function foo(a, fn(x, y), c) {";
        assert_eq!(extract_param_count(signature), 3);
    }

    #[test]
    fn test_count_call_args() {
        let cases = [
            ("let x = foo(1, 2, 3);", "foo", Some(3)),
            ("foo()", "foo", Some(0)),
            ("bar(1)", "foo", None),
            ("foo(a, b)", "foo", Some(2)),
        ];
        for (code, callee, expected) in cases {
            assert_eq!(count_call_args(code, callee), expected);
        }
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        // Multi-byte characters before the call must not break the scan.
        let cases = [
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ];
        for (code, callee, expected) in cases {
            assert_eq!(count_call_args(code, callee), expected);
        }
    }

    // ── Tree-sitter parameter analysis ──────────────────────────────────

    #[test]
    fn test_extract_param_info_python() {
        let parsed = extract_param_info_ts("def foo(a, b, c=3):\n    pass", "test.py").unwrap();
        assert_eq!((parsed.min_params, parsed.max_params), (2, 3));
        assert!(!parsed.is_variadic);
    }

    #[test]
    fn test_extract_param_info_python_self() {
        let range = arity_range("def foo(self, a, b):\n    pass", "test.py");
        assert_eq!(range, (2, 2));
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let parsed =
            extract_param_info_ts("def foo(a, *args, **kwargs):\n    pass", "test.py").unwrap();
        assert!(parsed.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        let source = "function foo(a: number, b: string, c?: boolean): void {}";
        let parsed = extract_param_info_ts(source, "test.ts").unwrap();
        assert_eq!((parsed.min_params, parsed.max_params), (2, 3));
        assert!(!parsed.is_variadic);
    }

    #[test]
    fn test_extract_param_info_rust() {
        let range = arity_range("fn foo(&self, a: i32, b: String) -> bool { true }", "test.rs");
        assert_eq!(range, (2, 2));
    }

    #[test]
    fn test_extract_param_info_go() {
        let range = arity_range("func foo(a string, b int) error { return nil }", "test.go");
        assert_eq!(range, (2, 2));
    }

    // ── Tree-sitter call-site analysis ──────────────────────────────────

    #[test]
    fn test_count_call_args_ts() {
        let found = count_call_args_ts("function bar() { foo(1, 2, 3); }", "foo", "test.ts");
        assert_eq!(found, Some(3));
    }

    #[test]
    fn test_count_call_args_ts_method() {
        let found = count_call_args_ts("function bar() { obj.foo(1, 2); }", "foo", "test.ts");
        assert_eq!(found, Some(2));
    }
}