// sem_core/parser/verify.rs

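//! Contract verification: check that callers pass a number of arguments the
//! callee accepts. Two strategies are provided: a text-based check that parses
//! the callee's first line and scans callers for `name(...)` occurrences
//! (`verify_contracts`, `verify_contracts_with_graph`), and a tree-sitter based
//! analysis that understands optional, default and variadic parameters
//! (`find_arity_mismatches`, `find_broken_callers`).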
3
4use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call site whose argument count differs from the callee's declared parameter
/// count, as reported by the text-based [`verify_contracts`] checks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContractViolation {
    /// Name of the called function or method.
    pub entity_name: String,
    /// File in which the callee is declared.
    pub file_path: String,
    /// Parameter count parsed from the callee's signature line.
    pub expected_params: usize,
    /// Name of the entity that contains the offending call.
    pub caller_name: String,
    /// File in which the caller is declared.
    pub caller_file: String,
    /// Number of arguments passed at the call site.
    pub actual_args: usize,
}
21
/// Result of tree-sitter based parameter analysis.
///
/// A call with `n` arguments is compatible when `min_params <= n <= max_params`.
/// When `is_variadic` is set, the upper bound is not meaningful. Receivers such as
/// Python `self`/`cls` and Rust `self`/`&self` are not counted.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParamInfo {
    /// Parameters without a default value; these must always be supplied.
    pub min_params: usize,
    /// All positional parameters, including defaulted or optional ones.
    pub max_params: usize,
    /// `true` when the signature accepts extra arguments (`*args`, `...rest`, ...).
    pub is_variadic: bool,
}
29
/// Arity mismatch found across the dependency graph by the tree-sitter analysis
/// ([`find_arity_mismatches`], [`find_broken_callers`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArityMismatch {
    /// Name of the entity that contains the call.
    pub caller_entity: String,
    /// Name of the called function or method.
    pub callee_entity: String,
    /// Minimum number of arguments the callee accepts.
    pub expected_min: usize,
    /// Maximum number of arguments the callee accepts.
    pub expected_max: usize,
    /// Number of arguments passed at this call site.
    pub actual_args: usize,
    /// File containing the call site (the caller's file).
    pub file_path: String,
    /// Line of the call site: the caller's `start_line` plus the call's row
    /// offset inside the caller's content.
    pub line: usize,
    /// Whether the callee is variadic. Reports produced in this module always set
    /// `false`, because variadic callees are skipped.
    pub is_variadic: bool,
}
42
/// One call site found by the tree-sitter scan: how many arguments it passes
/// and on which row of the scanned snippet it starts.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
struct CallArgCount {
    /// Arguments at the call site (comment nodes are not counted).
    actual_args: usize,
    /// Zero-based row of the call, relative to the start of the scanned content.
    line_offset: usize,
}
48
49/// Verify function call contracts across the codebase.
50pub fn verify_contracts(
51    root: &Path,
52    file_paths: &[String],
53    registry: &ParserRegistry,
54    target_file: Option<&str>,
55) -> Vec<ContractViolation> {
56    let (graph, _) = EntityGraph::build(root, file_paths, registry);
57
58    let mut content_map: HashMap<String, String> = HashMap::new();
59    for fp in file_paths {
60        let full = root.join(fp);
61        let content = match std::fs::read_to_string(&full) {
62            Ok(c) => c,
63            Err(_) => continue,
64        };
65        for entity in registry.extract_entities(fp, &content) {
66            content_map.insert(entity.id.clone(), entity.content.clone());
67        }
68    }
69
70    let mut violations = Vec::new();
71
72    for edge in &graph.edges {
73        if edge.ref_type != RefType::Calls {
74            continue;
75        }
76
77        let callee = match graph.entities.get(&edge.to_entity) {
78            Some(e) => e,
79            None => continue,
80        };
81
82        if let Some(tf) = target_file {
83            if callee.file_path != tf {
84                continue;
85            }
86        }
87
88        if !matches!(
89            callee.entity_type.as_str(),
90            "function" | "method" | "arrow_function"
91        ) {
92            continue;
93        }
94
95        let callee_content = match content_map.get(&edge.to_entity) {
96            Some(c) => c,
97            None => continue,
98        };
99
100        let caller = match graph.entities.get(&edge.from_entity) {
101            Some(e) => e,
102            None => continue,
103        };
104
105        let caller_content = match content_map.get(&edge.from_entity) {
106            Some(c) => c,
107            None => continue,
108        };
109
110        let expected = extract_param_count(callee_content);
111        if expected == 0 {
112            continue;
113        }
114
115        for actual in count_all_call_args(caller_content, &callee.name) {
116            if actual != expected {
117                violations.push(ContractViolation {
118                    entity_name: callee.name.clone(),
119                    file_path: callee.file_path.clone(),
120                    expected_params: expected,
121                    caller_name: caller.name.clone(),
122                    caller_file: caller.file_path.clone(),
123                    actual_args: actual,
124                });
125            }
126        }
127    }
128
129    violations
130}
131
132/// Like `verify_contracts`, but accepts a pre-built graph + entities.
133pub fn verify_contracts_with_graph(
134    graph: &EntityGraph,
135    all_entities: &[SemanticEntity],
136    target_file: Option<&str>,
137) -> Vec<ContractViolation> {
138    let content_map: HashMap<String, String> = all_entities
139        .iter()
140        .map(|e| (e.id.clone(), e.content.clone()))
141        .collect();
142
143    let mut violations = Vec::new();
144
145    for edge in &graph.edges {
146        if edge.ref_type != RefType::Calls {
147            continue;
148        }
149
150        let callee = match graph.entities.get(&edge.to_entity) {
151            Some(e) => e,
152            None => continue,
153        };
154
155        if let Some(tf) = target_file {
156            if callee.file_path != tf {
157                continue;
158            }
159        }
160
161        if !matches!(
162            callee.entity_type.as_str(),
163            "function" | "method" | "arrow_function"
164        ) {
165            continue;
166        }
167
168        let callee_content = match content_map.get(&edge.to_entity) {
169            Some(c) => c,
170            None => continue,
171        };
172
173        let caller = match graph.entities.get(&edge.from_entity) {
174            Some(e) => e,
175            None => continue,
176        };
177
178        let caller_content = match content_map.get(&edge.from_entity) {
179            Some(c) => c,
180            None => continue,
181        };
182
183        let expected = extract_param_count(callee_content);
184        if expected == 0 {
185            continue;
186        }
187
188        for actual in count_all_call_args(caller_content, &callee.name) {
189            if actual != expected {
190                violations.push(ContractViolation {
191                    entity_name: callee.name.clone(),
192                    file_path: callee.file_path.clone(),
193                    expected_params: expected,
194                    caller_name: caller.name.clone(),
195                    caller_file: caller.file_path.clone(),
196                    actual_args: actual,
197                });
198            }
199        }
200    }
201
202    violations
203}
204
205// ─── Tree-sitter based arity analysis ───────────────────────────────────────
206
/// Map a dotted file extension (e.g. `".py"`) to the language key used by the
/// parameter extractor. Returns `"unknown"` for anything not listed.
fn lang_from_ext(ext: &str) -> &'static str {
    const TABLE: &[(&str, &str)] = &[
        (".py", "python"),
        (".pyi", "python"),
        (".ts", "typescript"),
        (".tsx", "typescript"),
        (".mts", "typescript"),
        (".cts", "typescript"),
        (".js", "javascript"),
        (".jsx", "javascript"),
        (".mjs", "javascript"),
        (".cjs", "javascript"),
        (".rs", "rust"),
        (".go", "go"),
    ];
    TABLE
        .iter()
        .find(|(known_ext, _)| *known_ext == ext)
        .map(|(_, lang)| *lang)
        .unwrap_or("unknown")
}
217
218/// Extract parameter info from entity content using tree-sitter.
219pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
220    let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
221    let lang = lang_from_ext(ext);
222    if lang == "unknown" {
223        return None;
224    }
225    let config = get_language_config(ext)?;
226    let language = (config.get_language)()?;
227
228    let mut parser = tree_sitter::Parser::new();
229    let _ = parser.set_language(&language);
230    let tree = parser.parse(content.as_bytes(), None)?;
231
232    extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
233}
234
235fn extract_param_info_from_node(
236    root: tree_sitter::Node,
237    source: &[u8],
238    lang: &str,
239) -> Option<ParamInfo> {
240    // Find the first function-like node
241    let func_node = find_first_function(root)?;
242    let params_node = func_node.child_by_field_name("parameters")?;
243
244    let mut min_params = 0usize;
245    let mut max_params = 0usize;
246    let mut is_variadic = false;
247
248    let mut cursor = params_node.walk();
249    for child in params_node.named_children(&mut cursor) {
250        let kind = child.kind();
251        match lang {
252            "python" => {
253                if kind == "identifier" {
254                    let name = child.utf8_text(source).unwrap_or("");
255                    if name == "self" || name == "cls" {
256                        continue;
257                    }
258                    min_params += 1;
259                    max_params += 1;
260                } else if kind == "typed_parameter" {
261                    let name = child
262                        .child_by_field_name("name")
263                        .or_else(|| child.named_child(0))
264                        .and_then(|n| n.utf8_text(source).ok())
265                        .unwrap_or("");
266                    if name == "self" || name == "cls" {
267                        continue;
268                    }
269                    min_params += 1;
270                    max_params += 1;
271                } else if kind == "default_parameter" || kind == "typed_default_parameter" {
272                    max_params += 1;
273                } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
274                    is_variadic = true;
275                }
276            }
277            "typescript" => {
278                if kind == "required_parameter" {
279                    max_params += 1;
280                    if !has_js_ts_default_value(child) {
281                        min_params += 1;
282                    }
283                } else if kind == "optional_parameter" {
284                    max_params += 1;
285                } else if kind == "rest_pattern" {
286                    is_variadic = true;
287                }
288            }
289            "javascript" => {
290                if kind == "rest_pattern" {
291                    is_variadic = true;
292                } else if matches!(kind, "identifier" | "formal_parameter" | "assignment_pattern") {
293                    max_params += 1;
294                    if !has_js_ts_default_value(child) {
295                        min_params += 1;
296                    }
297                }
298            }
299            "rust" => {
300                if kind == "parameter" {
301                    let pat = child
302                        .child_by_field_name("pattern")
303                        .and_then(|n| n.utf8_text(source).ok())
304                        .unwrap_or("");
305                    // Skip self/&self/&mut self
306                    let base = pat.trim_start_matches('&').trim();
307                    let base = base.strip_prefix("mut ").unwrap_or(base).trim();
308                    if base == "self" {
309                        continue;
310                    }
311                    min_params += 1;
312                    max_params += 1;
313                } else if kind == "self_parameter" {
314                    continue;
315                }
316            }
317            "go" => {
318                if kind == "parameter_declaration" {
319                    let type_node = child.child_by_field_name("type");
320                    let type_text = type_node.and_then(|n| n.utf8_text(source).ok()).unwrap_or("");
321                    let param_text = child.utf8_text(source).unwrap_or("");
322                    if type_text.starts_with("...") || param_text.contains("...") {
323                        is_variadic = true;
324                    } else {
325                        let count = count_go_parameter_declaration_arity(child);
326                        min_params += count;
327                        max_params += count;
328                    }
329                }
330            }
331            _ => {}
332        }
333    }
334
335    Some(ParamInfo {
336        min_params,
337        max_params,
338        is_variadic,
339    })
340}
341
342fn has_js_ts_default_value(node: tree_sitter::Node) -> bool {
343    let mut cursor = node.walk();
344    let has_assignment_child = node
345        .named_children(&mut cursor)
346        .any(|child| child.kind() == "assignment_pattern");
347    node.kind() == "assignment_pattern"
348        || node.child_by_field_name("value").is_some()
349        || has_assignment_child
350}
351
352fn count_go_parameter_declaration_arity(node: tree_sitter::Node) -> usize {
353    let mut name_cursor = node.walk();
354    let field_names = node
355        .children_by_field_name("name", &mut name_cursor)
356        .count();
357    if field_names > 0 {
358        return field_names;
359    }
360
361    let type_range = match node.child_by_field_name("type") {
362        Some(type_node) => (type_node.start_byte(), type_node.end_byte()),
363        None => return 1,
364    };
365    let mut cursor = node.walk();
366    let identifier_names = node
367        .named_children(&mut cursor)
368        .filter(|child| {
369            child.kind() == "identifier" && type_range != (child.start_byte(), child.end_byte())
370        })
371        .count();
372    if identifier_names > 0 {
373        identifier_names
374    } else {
375        1
376    }
377}
378
379fn find_first_function(root: tree_sitter::Node) -> Option<tree_sitter::Node> {
380    let mut worklist = vec![root];
381    while let Some(node) = worklist.pop() {
382        let kind = node.kind();
383        if matches!(
384            kind,
385            "function_definition"
386                | "function_item"
387                | "function_declaration"
388                | "method_definition"
389                | "method_declaration"
390                | "arrow_function"
391        ) {
392            return Some(node);
393        }
394        let mut cursor = node.walk();
395        let children: Vec<_> = node.named_children(&mut cursor).collect();
396        for child in children.into_iter().rev() {
397            worklist.push(child);
398        }
399    }
400    None
401}
402
403/// Count call arguments at a specific call site using tree-sitter.
404pub fn count_call_args_ts(
405    caller_content: &str,
406    callee_name: &str,
407    file_path: &str,
408) -> Option<usize> {
409    count_call_arg_sites_ts(caller_content, callee_name, file_path)
410        .into_iter()
411        .next()
412        .map(|site| site.actual_args)
413}
414
415fn count_call_arg_sites_ts(
416    caller_content: &str,
417    callee_name: &str,
418    file_path: &str,
419) -> Vec<CallArgCount> {
420    let ext = match file_path.rfind('.').map(|i| &file_path[i..]) {
421        Some(ext) => ext,
422        None => return Vec::new(),
423    };
424    let config = match get_language_config(ext) {
425        Some(config) => config,
426        None => return Vec::new(),
427    };
428    let language = match (config.get_language)() {
429        Some(language) => language,
430        None => return Vec::new(),
431    };
432
433    let mut parser = tree_sitter::Parser::new();
434    let _ = parser.set_language(&language);
435    let tree = match parser.parse(caller_content.as_bytes(), None) {
436        Some(tree) => tree,
437        None => return Vec::new(),
438    };
439
440    find_call_arg_counts(tree.root_node(), caller_content.as_bytes(), callee_name)
441}
442
443fn find_call_arg_counts(
444    root: tree_sitter::Node,
445    source: &[u8],
446    callee_name: &str,
447) -> Vec<CallArgCount> {
448    let mut sites = Vec::new();
449    let mut worklist = vec![root];
450    while let Some(node) = worklist.pop() {
451        let kind = node.kind();
452
453        if kind == "call" || kind == "call_expression" {
454            if let Some(func) = node.child_by_field_name("function") {
455                let func_name = match func.kind() {
456                    "identifier" => func.utf8_text(source).unwrap_or(""),
457                    "attribute" | "member_expression" | "field_expression" => func
458                        .child_by_field_name("attribute")
459                        .or_else(|| func.child_by_field_name("property"))
460                        .or_else(|| func.child_by_field_name("field"))
461                        .and_then(|n| n.utf8_text(source).ok())
462                        .unwrap_or(""),
463                    "selector_expression" => func
464                        .child_by_field_name("field")
465                        .and_then(|n| n.utf8_text(source).ok())
466                        .unwrap_or(""),
467                    "scoped_identifier" => {
468                        let text = func.utf8_text(source).unwrap_or("");
469                        text.rsplit("::").next().unwrap_or("")
470                    }
471                    _ => "",
472                };
473
474                if func_name == callee_name {
475                    if let Some(args) = node.child_by_field_name("arguments") {
476                        let mut actual_args = 0;
477                        let mut cursor = args.walk();
478                        for child in args.named_children(&mut cursor) {
479                            // Skip comment nodes
480                            if !child.kind().contains("comment") {
481                                actual_args += 1;
482                            }
483                        }
484                        sites.push(CallArgCount {
485                            actual_args,
486                            line_offset: node.start_position().row,
487                        });
488                    }
489                }
490            }
491        }
492
493        let mut cursor = node.walk();
494        let children: Vec<_> = node.named_children(&mut cursor).collect();
495        for child in children.into_iter().rev() {
496            worklist.push(child);
497        }
498    }
499    sites
500}
501
/// Callee names that arity checking skips. These are constructors, built-in hooks
/// and very common verbs, for which a name-only match cannot reliably identify
/// the callee.
const AMBIGUOUS_NAMES: &[&str] = &[
    // constructors / initialisers and conversion hooks
    "new",
    "constructor",
    "toString",
    "valueOf",
    "init",
    "__init__",
    // generic verbs shared by many unrelated APIs
    "apply",
    "call",
    "bind",
    "get",
    "set",
    "run",
    "execute",
    "create",
];
507
508fn is_test_or_fixture_path(path: &str) -> bool {
509    crate::parser::test_detect::is_test_path(path)
510}
511
512/// Find arity mismatches across all Calls edges in the graph.
513pub fn find_arity_mismatches(
514    graph: &EntityGraph,
515    all_entities: &[SemanticEntity],
516) -> Vec<ArityMismatch> {
517    let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
518        .iter()
519        .map(|e| (e.id.as_str(), e))
520        .collect();
521
522    // Build name → count map to detect ambiguous names
523    let mut name_counts: HashMap<&str, usize> = HashMap::new();
524    for e in all_entities {
525        if matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function") {
526            *name_counts.entry(&e.name).or_insert(0) += 1;
527        }
528    }
529
530    // Cache param info per callee entity
531    let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
532
533    let mut mismatches = Vec::new();
534
535    for edge in &graph.edges {
536        if edge.ref_type != RefType::Calls {
537            continue;
538        }
539
540        let callee_info = match graph.entities.get(&edge.to_entity) {
541            Some(e) => e,
542            None => continue,
543        };
544
545        if !matches!(
546            callee_info.entity_type.as_str(),
547            "function" | "method" | "arrow_function"
548        ) {
549            continue;
550        }
551
552        // Skip ambiguous/common names where name-only matching is unreliable
553        if AMBIGUOUS_NAMES.contains(&callee_info.name.as_str()) {
554            continue;
555        }
556
557        // Skip callee names shared by multiple entities (overloads, trait impls)
558        if name_counts.get(callee_info.name.as_str()).copied().unwrap_or(0) > 1 {
559            continue;
560        }
561
562        // Skip test/fixture files
563        if is_test_or_fixture_path(&callee_info.file_path) {
564            continue;
565        }
566
567        let callee = match entity_by_id.get(edge.to_entity.as_str()) {
568            Some(e) => *e,
569            None => continue,
570        };
571
572        let caller = match entity_by_id.get(edge.from_entity.as_str()) {
573            Some(e) => *e,
574            None => continue,
575        };
576
577        // Skip callers in test/fixture files
578        if is_test_or_fixture_path(&caller.file_path) {
579            continue;
580        }
581
582        // Get callee param info (cached)
583        let param_info = param_cache
584            .entry(callee.id.clone())
585            .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
586            .clone();
587
588        let param_info = match param_info {
589            Some(pi) => pi,
590            None => continue,
591        };
592
593        // Skip variadic functions
594        if param_info.is_variadic {
595            continue;
596        }
597
598        for call_site in count_call_arg_sites_ts(&caller.content, &callee.name, &caller.file_path) {
599            if call_site.actual_args < param_info.min_params
600                || call_site.actual_args > param_info.max_params
601            {
602                mismatches.push(ArityMismatch {
603                    caller_entity: caller.name.clone(),
604                    callee_entity: callee.name.clone(),
605                    expected_min: param_info.min_params,
606                    expected_max: param_info.max_params,
607                    actual_args: call_site.actual_args,
608                    file_path: caller.file_path.clone(),
609                    line: caller.start_line + call_site.line_offset,
610                    is_variadic: false,
611                });
612            }
613        }
614    }
615
616    mismatches
617}
618
619/// Find callers broken by signature changes between old and new entities.
620/// Compares param counts of functions that exist in both old and new,
621/// then checks if any callers in new_graph pass the wrong arg count.
622pub fn find_broken_callers(
623    old_entities: &[SemanticEntity],
624    new_graph: &EntityGraph,
625    new_entities: &[SemanticEntity],
626) -> Vec<ArityMismatch> {
627    // Build old param info map: entity_id -> ParamInfo
628    let old_params: HashMap<String, Option<ParamInfo>> = old_entities
629        .iter()
630        .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
631        .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
632        .collect();
633
634    // Build new entity lookup
635    let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
636        .iter()
637        .map(|e| (e.id.as_str(), e))
638        .collect();
639
640    // Find entities whose param counts changed
641    let mut changed_entities: Vec<&str> = Vec::new();
642    for new_entity in new_entities {
643        if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
644            continue;
645        }
646        let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
647            Some(pi) => pi,
648            None => continue,
649        };
650        if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
651            if old_info.min_params != new_info.min_params
652                || old_info.max_params != new_info.max_params
653            {
654                changed_entities.push(&new_entity.id);
655            }
656        }
657    }
658
659    if changed_entities.is_empty() {
660        return Vec::new();
661    }
662
663    // Check all callers of changed entities
664    let mut mismatches = Vec::new();
665
666    for edge in &new_graph.edges {
667        if edge.ref_type != RefType::Calls {
668            continue;
669        }
670        if !changed_entities.contains(&edge.to_entity.as_str()) {
671            continue;
672        }
673
674        let callee = match new_by_id.get(edge.to_entity.as_str()) {
675            Some(e) => *e,
676            None => continue,
677        };
678        let caller = match new_by_id.get(edge.from_entity.as_str()) {
679            Some(e) => *e,
680            None => continue,
681        };
682
683        let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
684            Some(pi) => pi,
685            None => continue,
686        };
687
688        if new_info.is_variadic {
689            continue;
690        }
691
692        for call_site in count_call_arg_sites_ts(&caller.content, &callee.name, &caller.file_path) {
693            if call_site.actual_args < new_info.min_params
694                || call_site.actual_args > new_info.max_params
695            {
696                mismatches.push(ArityMismatch {
697                    caller_entity: caller.name.clone(),
698                    callee_entity: callee.name.clone(),
699                    expected_min: new_info.min_params,
700                    expected_max: new_info.max_params,
701                    actual_args: call_site.actual_args,
702                    file_path: caller.file_path.clone(),
703                    line: caller.start_line + call_site.line_offset,
704                    is_variadic: false,
705                });
706            }
707        }
708    }
709
710    mismatches
711}
712
713// ─── String-based helpers (kept for backward compatibility) ──────────────────
714
715/// Extract param count from the first line of a function/method.
716fn extract_param_count(content: &str) -> usize {
717    let first_line = content.lines().next().unwrap_or("");
718
719    let open = match first_line.find('(') {
720        Some(i) => i,
721        None => return 0,
722    };
723
724    let after_open = &first_line[open + 1..];
725    let close = match find_matching_paren(after_open) {
726        Some(i) => i,
727        None => return 0,
728    };
729
730    let params_str = after_open[..close].trim();
731    if params_str.is_empty() {
732        return 0;
733    }
734
735    count_top_level_commas(params_str) + 1
736}
737
/// Argument count of the first `callee_name(...)` call in `content`, if any.
#[cfg(test)]
fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
    let all = count_all_call_args(content, callee_name);
    all.first().copied()
}
743
744fn count_all_call_args(content: &str, callee_name: &str) -> Vec<usize> {
745    let bytes = content.as_bytes();
746    let name_bytes = callee_name.as_bytes();
747    let mut search_start = 0;
748    let mut counts = Vec::new();
749
750    while let Some(rel_pos) = content[search_start..].find(callee_name) {
751        let pos = search_start + rel_pos;
752        let after = pos + name_bytes.len();
753
754        let is_boundary = pos == 0 || {
755            let prev = bytes[pos - 1];
756            !prev.is_ascii_alphanumeric() && prev != b'_'
757        };
758
759        let mut next_search_start = pos + 1;
760        if is_boundary && after < bytes.len() && bytes[after] == b'(' {
761            let args_start_index = after + 1;
762            let args_start = &content[args_start_index..];
763            if let Some(close) = find_matching_paren(args_start) {
764                let args_str = args_start[..close].trim();
765                if args_str.is_empty() {
766                    counts.push(0);
767                } else {
768                    counts.push(count_top_level_commas(args_str) + 1);
769                }
770                next_search_start = args_start_index + close + 1;
771            } else {
772                next_search_start = after;
773            }
774        }
775
776        search_start = next_search_start;
777        while search_start < content.len() && !content.is_char_boundary(search_start) {
778            search_start += 1;
779        }
780    }
781
782    counts
783}
784
/// Byte index of the `)` that closes an already-open parenthesis at the start of `s`.
///
/// Nested `(`...`)` pairs inside `s` are skipped. Returns `None` if the
/// parenthesis is never closed.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut nesting = 0usize;
    s.char_indices().find_map(|(idx, c)| match c {
        '(' => {
            nesting += 1;
            None
        }
        ')' if nesting == 0 => Some(idx),
        ')' => {
            nesting -= 1;
            None
        }
        _ => None,
    })
}
801
/// Count the commas in `s` that are not nested inside `()`, `[]`, `{}` or a
/// generic `<...>`.
///
/// `<` and `>` are ambiguous: they can be generics, comparisons, `->` or `=>`.
/// Angle brackets are therefore handled by these rules:
/// - A `<` opens a generic scope only when it directly follows an identifier
///   character or `:` (turbofish), and is followed by something other than
///   whitespace, `=` or `<`.
/// - A `>` closes a scope only while one is open, and never when it is part of
///   `->` or `=>`.
/// - If a generic-looking `<` is never closed (`x<10, y`), it was a comparison
///   after all. The count is then redone with angle brackets treated as plain
///   characters.
fn count_top_level_commas(s: &str) -> usize {
    let chars: Vec<char> = s.chars().collect();
    let (count, unmatched_angle) = scan_top_level_commas(&chars, true);
    if unmatched_angle {
        scan_top_level_commas(&chars, false).0
    } else {
        count
    }
}

/// One pass of `count_top_level_commas`. Returns the comma count and whether a
/// generic-looking `<` was left unclosed. With `track_angles == false`, `<`/`>`
/// are ignored.
fn scan_top_level_commas(chars: &[char], track_angles: bool) -> (usize, bool) {
    let mut open: Vec<char> = Vec::new();
    let mut count = 0;
    let mut unmatched_angle = false;

    for (i, &ch) in chars.iter().enumerate() {
        let prev = if i > 0 { Some(chars[i - 1]) } else { None };
        let next = chars.get(i + 1).copied();
        match ch {
            '(' | '[' | '{' => open.push(ch),
            ')' | ']' | '}' => {
                let opener = match ch {
                    ')' => '(',
                    ']' => '[',
                    _ => '{',
                };
                // Unwind to the matching opener; any `<` skipped over was never closed.
                while let Some(top) = open.pop() {
                    if top == opener {
                        break;
                    }
                    if top == '<' {
                        unmatched_angle = true;
                    }
                }
            }
            '<' if track_angles && looks_like_generic_open(prev, next) => open.push('<'),
            '>' if track_angles
                && open.last() == Some(&'<')
                && !matches!(prev, Some('-') | Some('=')) =>
            {
                open.pop();
            }
            ',' if open.is_empty() => count += 1,
            _ => {}
        }
    }

    if open.contains(&'<') {
        unmatched_angle = true;
    }
    (count, unmatched_angle)
}

/// Heuristic: does a `<` between `prev` and `next` open a generic argument list?
fn looks_like_generic_open(prev: Option<char>, next: Option<char>) -> bool {
    let after_name = matches!(prev, Some(c) if c.is_alphanumeric() || c == '_' || c == ':');
    let before_arg = matches!(next, Some(c) if !c.is_whitespace() && c != '=' && c != '<');
    after_name && before_arg
}
815
// Unit tests: the text-based helpers first, then the tree-sitter based
// parameter extraction and call-site counting.
#[cfg(test)]
mod tests {
    use super::*;

    // ── Text-based helpers ────────────────────────────────────────────────

    #[test]
    fn test_extract_param_count_basic() {
        assert_eq!(extract_param_count("function foo(a, b, c) {"), 3);
        assert_eq!(extract_param_count("function foo() {"), 0);
        // The text-based counter does not strip receivers: `self` is counted.
        assert_eq!(extract_param_count("def bar(self, x):"), 2);
        assert_eq!(extract_param_count("fn baz(a: i32) -> bool {"), 1);
    }

    #[test]
    fn test_extract_param_count_nested() {
        // Commas inside nested parentheses are not top-level separators.
        assert_eq!(extract_param_count("function foo(a, fn(x, y), c) {"), 3);
    }

    #[test]
    fn test_count_call_args() {
        assert_eq!(count_call_args("let x = foo(1, 2, 3);", "foo"), Some(3));
        assert_eq!(count_call_args("foo()", "foo"), Some(0));
        assert_eq!(count_call_args("bar(1)", "foo"), None);
        assert_eq!(count_call_args("foo(a, b)", "foo"), Some(2));
    }

    #[test]
    fn test_count_all_call_args() {
        assert_eq!(count_all_call_args("foo(1, 2); foo(1);", "foo"), vec![2, 1]);
    }

    #[test]
    fn test_count_all_call_args_resumes_after_unclosed_candidate() {
        // The first `foo(` never closes; the scan must still find the second call.
        assert_eq!(count_all_call_args("foo(\nfoo(1, 2)", "foo"), vec![2]);
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        // Multi-byte characters around the call must not break the scan.
        assert_eq!(count_call_args("let café = foo(1, 2);", "foo"), Some(2));
        assert_eq!(count_call_args("let É = 1; bar(x)", "bar"), Some(1));
        assert_eq!(count_call_args("// 日本語コメント\nfoo(a, b, c)", "foo"), Some(3));
    }

    // ── Tree-sitter parameter extraction ──────────────────────────────────

    #[test]
    fn test_extract_param_info_python() {
        let info = extract_param_info_ts(
            "def foo(a, b, c=3):\n    pass",
            "test.py",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 3);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_python_self() {
        // `self` is a receiver, not a caller-supplied argument.
        let info = extract_param_info_ts(
            "def foo(self, a, b):\n    pass",
            "test.py",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let info = extract_param_info_ts(
            "def foo(a, *args, **kwargs):\n    pass",
            "test.py",
        )
        .unwrap();
        assert!(info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        // `c?: boolean` is optional: it raises the maximum but not the minimum.
        let info = extract_param_info_ts(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 3);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript_default_parameter() {
        let info = extract_param_info_ts(
            "function foo(a: number, b = 1): number { return a + b; }",
            "test.ts",
        )
        .unwrap();
        assert_eq!(info.min_params, 1);
        assert_eq!(info.max_params, 2);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_javascript_default_parameter() {
        let info =
            extract_param_info_ts("function foo(a, b = 1) { return a + b; }", "test.js").unwrap();
        assert_eq!(info.min_params, 1);
        assert_eq!(info.max_params, 2);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_javascript_required_parameters() {
        let info = extract_param_info_ts("function foo(a, b) { return a + b; }", "test.js")
            .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript_arrow_default_parameter() {
        let info = extract_param_info_ts(
            "const foo = (a: number, b = 1): number => a + b;",
            "test.ts",
        )
        .unwrap();
        assert_eq!(info.min_params, 1);
        assert_eq!(info.max_params, 2);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_rust() {
        // `&self` is a receiver and is not counted.
        let info = extract_param_info_ts(
            "fn foo(&self, a: i32, b: String) -> bool { true }",
            "test.rs",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    #[test]
    fn test_extract_param_info_go() {
        let info = extract_param_info_ts(
            "func foo(a string, b int) error { return nil }",
            "test.go",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    #[test]
    fn test_extract_param_info_go_grouped_params() {
        // `a, b int` declares two parameters in one declaration.
        let info = extract_param_info_ts(
            "func foo(a, b int, c string) int { return a + b }",
            "test.go",
        )
        .unwrap();
        assert_eq!(info.min_params, 3);
        assert_eq!(info.max_params, 3);
    }

    #[test]
    fn test_extract_param_info_go_unnamed_params() {
        // Type-only parameters still count one each.
        let info = extract_param_info_ts(
            "func foo(int, string) bool { return true }",
            "test.go",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    // ── Tree-sitter call-site counting ────────────────────────────────────

    #[test]
    fn test_count_call_args_ts() {
        let count = count_call_args_ts(
            "function bar() { foo(1, 2, 3); }",
            "foo",
            "test.ts",
        );
        assert_eq!(count, Some(3));
    }

    #[test]
    fn test_count_call_args_ts_method() {
        // Member calls match on the property name.
        let count = count_call_args_ts(
            "function bar() { obj.foo(1, 2); }",
            "foo",
            "test.ts",
        );
        assert_eq!(count, Some(2));
    }

    #[test]
    fn test_count_call_arg_sites_ts_repeated_calls() {
        // Every call site is reported, each with its zero-based row offset.
        let sites =
            count_call_arg_sites_ts("def bar():\n    foo(1, 2)\n    foo(1)\n", "foo", "test.py");
        assert_eq!(
            sites,
            vec![
                CallArgCount {
                    actual_args: 2,
                    line_offset: 1,
                },
                CallArgCount {
                    actual_args: 1,
                    line_offset: 2,
                },
            ]
        );
    }
}