Skip to main content

sem_core/parser/
verify.rs

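//! Contract verification: check that callers pass a compatible number of
//! arguments to callees. The `find_*` analyses use tree-sitter ASTs for
//! parameter and argument counting; `verify_contracts` and
//! `verify_contracts_with_graph` use simpler string-based heuristics.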
3
4use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call whose argument count differs from the parameter count parsed from
/// the callee's signature line (string-based contract check).
#[derive(Debug, Clone)]
pub struct ContractViolation {
    /// Name of the function or method being called.
    pub entity_name: String,
    /// File that defines the callee.
    pub file_path: String,
    /// Parameter count parsed from the callee's first line.
    pub expected_params: usize,
    /// Name of the entity that contains the offending call.
    pub caller_name: String,
    /// File that contains the caller.
    pub caller_file: String,
    /// Number of arguments passed at the call site.
    pub actual_args: usize,
}
21
/// Arity summary of a function signature, produced by tree-sitter analysis.
///
/// `min_params` counts the parameters a caller must supply. `max_params` also
/// counts parameters that are optional or carry a default. Receiver parameters
/// (`self`, `cls`, Rust `self` forms) are not counted. `is_variadic` is set
/// when the signature accepts a variable number of arguments (rest/splat/`...`).
#[derive(Debug, Clone)]
pub struct ParamInfo {
    /// Fewest arguments a call must pass.
    pub min_params: usize,
    /// Most arguments a call may pass (ignoring variadics).
    pub max_params: usize,
    /// Whether the signature accepts extra arguments.
    pub is_variadic: bool,
}
29
/// A call site whose argument count falls outside the callee's accepted range,
/// found by walking `Calls` edges of the dependency graph.
#[derive(Debug, Clone)]
pub struct ArityMismatch {
    /// Name of the calling entity.
    pub caller_entity: String,
    /// Name of the called function or method.
    pub callee_entity: String,
    /// Fewest arguments the callee accepts.
    pub expected_min: usize,
    /// Most arguments the callee accepts.
    pub expected_max: usize,
    /// Arguments passed at this call site.
    pub actual_args: usize,
    /// File containing the call (the caller's file).
    pub file_path: String,
    /// Line of the call: the caller's start line plus the call's row offset
    /// within the caller's content.
    pub line: usize,
    /// Whether the callee is variadic. The producers in this module skip
    /// variadic callees, so they always set this to `false`.
    pub is_variadic: bool,
}
42
/// One call site found in parsed caller content.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CallArgCount {
    /// Number of non-comment argument nodes at the call.
    actual_args: usize,
    /// Zero-based tree-sitter row of the call within the parsed content.
    line_offset: usize,
}
48
49/// Verify function call contracts across the codebase.
50pub fn verify_contracts(
51    root: &Path,
52    file_paths: &[String],
53    registry: &ParserRegistry,
54    target_file: Option<&str>,
55) -> Vec<ContractViolation> {
56    let (graph, _) = EntityGraph::build(root, file_paths, registry);
57
58    let mut content_map: HashMap<String, String> = HashMap::new();
59    for fp in file_paths {
60        let full = root.join(fp);
61        let content = match std::fs::read_to_string(&full) {
62            Ok(c) => c,
63            Err(_) => continue,
64        };
65        for entity in registry.extract_entities(fp, &content) {
66            content_map.insert(entity.id.clone(), entity.content.clone());
67        }
68    }
69
70    let mut violations = Vec::new();
71
72    for edge in &graph.edges {
73        if edge.ref_type != RefType::Calls {
74            continue;
75        }
76
77        let callee = match graph.entities.get(&edge.to_entity) {
78            Some(e) => e,
79            None => continue,
80        };
81
82        if let Some(tf) = target_file {
83            if callee.file_path != tf {
84                continue;
85            }
86        }
87
88        if !matches!(
89            callee.entity_type.as_str(),
90            "function" | "method" | "arrow_function"
91        ) {
92            continue;
93        }
94
95        let callee_content = match content_map.get(&edge.to_entity) {
96            Some(c) => c,
97            None => continue,
98        };
99
100        let caller = match graph.entities.get(&edge.from_entity) {
101            Some(e) => e,
102            None => continue,
103        };
104
105        let caller_content = match content_map.get(&edge.from_entity) {
106            Some(c) => c,
107            None => continue,
108        };
109
110        let expected = extract_param_count(callee_content);
111        if expected == 0 {
112            continue;
113        }
114
115        for actual in count_all_call_args(caller_content, &callee.name) {
116            if actual != expected {
117                violations.push(ContractViolation {
118                    entity_name: callee.name.clone(),
119                    file_path: callee.file_path.clone(),
120                    expected_params: expected,
121                    caller_name: caller.name.clone(),
122                    caller_file: caller.file_path.clone(),
123                    actual_args: actual,
124                });
125            }
126        }
127    }
128
129    violations
130}
131
132/// Like `verify_contracts`, but accepts a pre-built graph + entities.
133pub fn verify_contracts_with_graph(
134    graph: &EntityGraph,
135    all_entities: &[SemanticEntity],
136    target_file: Option<&str>,
137) -> Vec<ContractViolation> {
138    let content_map: HashMap<String, String> = all_entities
139        .iter()
140        .map(|e| (e.id.clone(), e.content.clone()))
141        .collect();
142
143    let mut violations = Vec::new();
144
145    for edge in &graph.edges {
146        if edge.ref_type != RefType::Calls {
147            continue;
148        }
149
150        let callee = match graph.entities.get(&edge.to_entity) {
151            Some(e) => e,
152            None => continue,
153        };
154
155        if let Some(tf) = target_file {
156            if callee.file_path != tf {
157                continue;
158            }
159        }
160
161        if !matches!(
162            callee.entity_type.as_str(),
163            "function" | "method" | "arrow_function"
164        ) {
165            continue;
166        }
167
168        let callee_content = match content_map.get(&edge.to_entity) {
169            Some(c) => c,
170            None => continue,
171        };
172
173        let caller = match graph.entities.get(&edge.from_entity) {
174            Some(e) => e,
175            None => continue,
176        };
177
178        let caller_content = match content_map.get(&edge.from_entity) {
179            Some(c) => c,
180            None => continue,
181        };
182
183        let expected = extract_param_count(callee_content);
184        if expected == 0 {
185            continue;
186        }
187
188        for actual in count_all_call_args(caller_content, &callee.name) {
189            if actual != expected {
190                violations.push(ContractViolation {
191                    entity_name: callee.name.clone(),
192                    file_path: callee.file_path.clone(),
193                    expected_params: expected,
194                    caller_name: caller.name.clone(),
195                    caller_file: caller.file_path.clone(),
196                    actual_args: actual,
197                });
198            }
199        }
200    }
201
202    violations
203}
204
// ─── Tree-sitter based arity analysis ───────────────────────────────────────
206
/// Map a file extension (with its leading dot, e.g. `".py"`) to the language
/// key understood by the parameter analysis. Returns `"unknown"` for any
/// unsupported extension.
fn lang_from_ext(ext: &str) -> &'static str {
    if matches!(ext, ".py" | ".pyi") {
        "python"
    } else if matches!(ext, ".ts" | ".tsx" | ".mts" | ".cts") {
        "typescript"
    } else if matches!(ext, ".js" | ".jsx" | ".mjs" | ".cjs") {
        "javascript"
    } else if ext == ".rs" {
        "rust"
    } else if ext == ".go" {
        "go"
    } else {
        "unknown"
    }
}
217
218/// Extract parameter info from entity content using tree-sitter.
219pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
220    let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
221    let lang = lang_from_ext(ext);
222    if lang == "unknown" {
223        return None;
224    }
225    let config = get_language_config(ext)?;
226    let language = (config.get_language)()?;
227
228    let mut parser = tree_sitter::Parser::new();
229    let _ = parser.set_language(&language);
230    let tree = parser.parse(content.as_bytes(), None)?;
231
232    extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
233}
234
235fn extract_param_info_from_node(
236    root: tree_sitter::Node,
237    source: &[u8],
238    lang: &str,
239) -> Option<ParamInfo> {
240    // Find the first function-like node
241    let func_node = find_first_function(root)?;
242    let params_node = func_node.child_by_field_name("parameters")?;
243
244    let mut min_params = 0usize;
245    let mut max_params = 0usize;
246    let mut is_variadic = false;
247
248    let mut cursor = params_node.walk();
249    for child in params_node.named_children(&mut cursor) {
250        let kind = child.kind();
251        match lang {
252            "python" => {
253                if kind == "identifier" {
254                    let name = child.utf8_text(source).unwrap_or("");
255                    if name == "self" || name == "cls" {
256                        continue;
257                    }
258                    min_params += 1;
259                    max_params += 1;
260                } else if kind == "typed_parameter" {
261                    let name = child
262                        .child_by_field_name("name")
263                        .or_else(|| child.named_child(0))
264                        .and_then(|n| n.utf8_text(source).ok())
265                        .unwrap_or("");
266                    if name == "self" || name == "cls" {
267                        continue;
268                    }
269                    min_params += 1;
270                    max_params += 1;
271                } else if kind == "default_parameter" || kind == "typed_default_parameter" {
272                    max_params += 1;
273                } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
274                    is_variadic = true;
275                }
276            }
277            "typescript" => {
278                if kind == "required_parameter" {
279                    max_params += 1;
280                    if !has_js_ts_default_value(child) {
281                        min_params += 1;
282                    }
283                } else if kind == "optional_parameter" {
284                    max_params += 1;
285                } else if kind == "rest_pattern" {
286                    is_variadic = true;
287                }
288            }
289            "javascript" => {
290                if kind == "rest_pattern" {
291                    is_variadic = true;
292                } else if matches!(kind, "identifier" | "formal_parameter" | "assignment_pattern") {
293                    max_params += 1;
294                    if !has_js_ts_default_value(child) {
295                        min_params += 1;
296                    }
297                }
298            }
299            "rust" => {
300                if kind == "parameter" {
301                    let pat = child
302                        .child_by_field_name("pattern")
303                        .and_then(|n| n.utf8_text(source).ok())
304                        .unwrap_or("");
305                    // Skip self/&self/&mut self
306                    let base = pat.trim_start_matches('&').trim();
307                    let base = base.strip_prefix("mut ").unwrap_or(base).trim();
308                    if base == "self" {
309                        continue;
310                    }
311                    min_params += 1;
312                    max_params += 1;
313                } else if kind == "self_parameter" {
314                    continue;
315                }
316            }
317            "go" => {
318                if kind == "parameter_declaration" {
319                    let type_node = child.child_by_field_name("type");
320                    let type_text = type_node.and_then(|n| n.utf8_text(source).ok()).unwrap_or("");
321                    let param_text = child.utf8_text(source).unwrap_or("");
322                    if type_text.starts_with("...") || param_text.contains("...") {
323                        is_variadic = true;
324                    } else {
325                        let count = count_go_parameter_declaration_arity(child);
326                        min_params += count;
327                        max_params += count;
328                    }
329                }
330            }
331            _ => {}
332        }
333    }
334
335    Some(ParamInfo {
336        min_params,
337        max_params,
338        is_variadic,
339    })
340}
341
342fn has_js_ts_default_value(node: tree_sitter::Node) -> bool {
343    let mut cursor = node.walk();
344    let has_assignment_child = node
345        .named_children(&mut cursor)
346        .any(|child| child.kind() == "assignment_pattern");
347    node.kind() == "assignment_pattern"
348        || node.child_by_field_name("value").is_some()
349        || has_assignment_child
350}
351
352fn count_go_parameter_declaration_arity(node: tree_sitter::Node) -> usize {
353    let mut name_cursor = node.walk();
354    let field_names = node
355        .children_by_field_name("name", &mut name_cursor)
356        .count();
357    if field_names > 0 {
358        return field_names;
359    }
360
361    let type_range = match node.child_by_field_name("type") {
362        Some(type_node) => (type_node.start_byte(), type_node.end_byte()),
363        None => return 1,
364    };
365    let mut cursor = node.walk();
366    let identifier_names = node
367        .named_children(&mut cursor)
368        .filter(|child| {
369            child.kind() == "identifier" && type_range != (child.start_byte(), child.end_byte())
370        })
371        .count();
372    if identifier_names > 0 {
373        identifier_names
374    } else {
375        1
376    }
377}
378
379fn find_first_function(root: tree_sitter::Node) -> Option<tree_sitter::Node> {
380    let mut worklist = vec![root];
381    while let Some(node) = worklist.pop() {
382        let kind = node.kind();
383        if matches!(
384            kind,
385            "function_definition"
386                | "function_item"
387                | "function_declaration"
388                | "method_definition"
389                | "method_declaration"
390                | "arrow_function"
391        ) {
392            return Some(node);
393        }
394        let mut cursor = node.walk();
395        let children: Vec<_> = node.named_children(&mut cursor).collect();
396        for child in children.into_iter().rev() {
397            worklist.push(child);
398        }
399    }
400    None
401}
402
403/// Count call arguments at a specific call site using tree-sitter.
404pub fn count_call_args_ts(
405    caller_content: &str,
406    callee_name: &str,
407    file_path: &str,
408) -> Option<usize> {
409    count_call_arg_sites_ts(caller_content, callee_name, file_path)
410        .into_iter()
411        .next()
412        .map(|site| site.actual_args)
413}
414
415fn count_call_arg_sites_ts(
416    caller_content: &str,
417    callee_name: &str,
418    file_path: &str,
419) -> Vec<CallArgCount> {
420    let ext = match file_path.rfind('.').map(|i| &file_path[i..]) {
421        Some(ext) => ext,
422        None => return Vec::new(),
423    };
424    let config = match get_language_config(ext) {
425        Some(config) => config,
426        None => return Vec::new(),
427    };
428    let language = match (config.get_language)() {
429        Some(language) => language,
430        None => return Vec::new(),
431    };
432
433    let mut parser = tree_sitter::Parser::new();
434    let _ = parser.set_language(&language);
435    let tree = match parser.parse(caller_content.as_bytes(), None) {
436        Some(tree) => tree,
437        None => return Vec::new(),
438    };
439
440    find_call_arg_counts(tree.root_node(), caller_content.as_bytes(), callee_name)
441}
442
443fn find_call_arg_counts(
444    root: tree_sitter::Node,
445    source: &[u8],
446    callee_name: &str,
447) -> Vec<CallArgCount> {
448    let mut sites = Vec::new();
449    let mut worklist = vec![root];
450    while let Some(node) = worklist.pop() {
451        let kind = node.kind();
452
453        if kind == "call" || kind == "call_expression" {
454            if let Some(func) = node.child_by_field_name("function") {
455                let func_name = match func.kind() {
456                    "identifier" => func.utf8_text(source).unwrap_or(""),
457                    "attribute" | "member_expression" | "field_expression" => func
458                        .child_by_field_name("attribute")
459                        .or_else(|| func.child_by_field_name("property"))
460                        .or_else(|| func.child_by_field_name("field"))
461                        .and_then(|n| n.utf8_text(source).ok())
462                        .unwrap_or(""),
463                    "selector_expression" => func
464                        .child_by_field_name("field")
465                        .and_then(|n| n.utf8_text(source).ok())
466                        .unwrap_or(""),
467                    "scoped_identifier" => {
468                        let text = func.utf8_text(source).unwrap_or("");
469                        text.rsplit("::").next().unwrap_or("")
470                    }
471                    _ => "",
472                };
473
474                if func_name == callee_name {
475                    if let Some(args) = node.child_by_field_name("arguments") {
476                        let mut actual_args = 0;
477                        let mut cursor = args.walk();
478                        for child in args.named_children(&mut cursor) {
479                            // Skip comment nodes
480                            if !child.kind().contains("comment") {
481                                actual_args += 1;
482                            }
483                        }
484                        sites.push(CallArgCount {
485                            actual_args,
486                            line_offset: node.start_position().row,
487                        });
488                    }
489                }
490            }
491        }
492
493        let mut cursor = node.walk();
494        let children: Vec<_> = node.named_children(&mut cursor).collect();
495        for child in children.into_iter().rev() {
496            worklist.push(child);
497        }
498    }
499    sites
500}
501
/// Callee names excluded from arity checking. These are constructors,
/// builtin-style helpers, and generic verbs that are too common for name-only
/// call matching to pick out the right callee.
const AMBIGUOUS_NAMES: &[&str] = &[
    "new",
    "constructor",
    "toString",
    "valueOf",
    "init",
    "__init__",
    "apply",
    "call",
    "bind",
    "get",
    "set",
    "run",
    "execute",
    "create",
];
507
/// Path components that indicate test/fixture files (not production source).
const TEST_PATH_MARKERS: &[&str] = &[
    "test", "tests", "spec", "specs", "fixtures", "fixture",
    "benchmarks", "benchmark", "__tests__", "__mocks__",
];

/// Returns `true` when any component of `path` exactly equals one of
/// `TEST_PATH_MARKERS`.
///
/// Both `/` and `\` are treated as separators, so Windows-style relative
/// paths are classified the same way as POSIX ones.
fn is_test_or_fixture_path(path: &str) -> bool {
    path.split(|c: char| c == '/' || c == '\\')
        .any(|component| TEST_PATH_MARKERS.contains(&component))
}
517
518/// Find arity mismatches across all Calls edges in the graph.
519pub fn find_arity_mismatches(
520    graph: &EntityGraph,
521    all_entities: &[SemanticEntity],
522) -> Vec<ArityMismatch> {
523    let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
524        .iter()
525        .map(|e| (e.id.as_str(), e))
526        .collect();
527
528    // Build name → count map to detect ambiguous names
529    let mut name_counts: HashMap<&str, usize> = HashMap::new();
530    for e in all_entities {
531        if matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function") {
532            *name_counts.entry(&e.name).or_insert(0) += 1;
533        }
534    }
535
536    // Cache param info per callee entity
537    let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
538
539    let mut mismatches = Vec::new();
540
541    for edge in &graph.edges {
542        if edge.ref_type != RefType::Calls {
543            continue;
544        }
545
546        let callee_info = match graph.entities.get(&edge.to_entity) {
547            Some(e) => e,
548            None => continue,
549        };
550
551        if !matches!(
552            callee_info.entity_type.as_str(),
553            "function" | "method" | "arrow_function"
554        ) {
555            continue;
556        }
557
558        // Skip ambiguous/common names where name-only matching is unreliable
559        if AMBIGUOUS_NAMES.contains(&callee_info.name.as_str()) {
560            continue;
561        }
562
563        // Skip callee names shared by multiple entities (overloads, trait impls)
564        if name_counts.get(callee_info.name.as_str()).copied().unwrap_or(0) > 1 {
565            continue;
566        }
567
568        // Skip test/fixture files
569        if is_test_or_fixture_path(&callee_info.file_path) {
570            continue;
571        }
572
573        let callee = match entity_by_id.get(edge.to_entity.as_str()) {
574            Some(e) => *e,
575            None => continue,
576        };
577
578        let caller = match entity_by_id.get(edge.from_entity.as_str()) {
579            Some(e) => *e,
580            None => continue,
581        };
582
583        // Skip callers in test/fixture files
584        if is_test_or_fixture_path(&caller.file_path) {
585            continue;
586        }
587
588        // Get callee param info (cached)
589        let param_info = param_cache
590            .entry(callee.id.clone())
591            .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
592            .clone();
593
594        let param_info = match param_info {
595            Some(pi) => pi,
596            None => continue,
597        };
598
599        // Skip variadic functions
600        if param_info.is_variadic {
601            continue;
602        }
603
604        for call_site in count_call_arg_sites_ts(&caller.content, &callee.name, &caller.file_path) {
605            if call_site.actual_args < param_info.min_params
606                || call_site.actual_args > param_info.max_params
607            {
608                mismatches.push(ArityMismatch {
609                    caller_entity: caller.name.clone(),
610                    callee_entity: callee.name.clone(),
611                    expected_min: param_info.min_params,
612                    expected_max: param_info.max_params,
613                    actual_args: call_site.actual_args,
614                    file_path: caller.file_path.clone(),
615                    line: caller.start_line + call_site.line_offset,
616                    is_variadic: false,
617                });
618            }
619        }
620    }
621
622    mismatches
623}
624
625/// Find callers broken by signature changes between old and new entities.
626/// Compares param counts of functions that exist in both old and new,
627/// then checks if any callers in new_graph pass the wrong arg count.
628pub fn find_broken_callers(
629    old_entities: &[SemanticEntity],
630    new_graph: &EntityGraph,
631    new_entities: &[SemanticEntity],
632) -> Vec<ArityMismatch> {
633    // Build old param info map: entity_id -> ParamInfo
634    let old_params: HashMap<String, Option<ParamInfo>> = old_entities
635        .iter()
636        .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
637        .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
638        .collect();
639
640    // Build new entity lookup
641    let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
642        .iter()
643        .map(|e| (e.id.as_str(), e))
644        .collect();
645
646    // Find entities whose param counts changed
647    let mut changed_entities: Vec<&str> = Vec::new();
648    for new_entity in new_entities {
649        if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
650            continue;
651        }
652        let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
653            Some(pi) => pi,
654            None => continue,
655        };
656        if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
657            if old_info.min_params != new_info.min_params
658                || old_info.max_params != new_info.max_params
659            {
660                changed_entities.push(&new_entity.id);
661            }
662        }
663    }
664
665    if changed_entities.is_empty() {
666        return Vec::new();
667    }
668
669    // Check all callers of changed entities
670    let mut mismatches = Vec::new();
671
672    for edge in &new_graph.edges {
673        if edge.ref_type != RefType::Calls {
674            continue;
675        }
676        if !changed_entities.contains(&edge.to_entity.as_str()) {
677            continue;
678        }
679
680        let callee = match new_by_id.get(edge.to_entity.as_str()) {
681            Some(e) => *e,
682            None => continue,
683        };
684        let caller = match new_by_id.get(edge.from_entity.as_str()) {
685            Some(e) => *e,
686            None => continue,
687        };
688
689        let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
690            Some(pi) => pi,
691            None => continue,
692        };
693
694        if new_info.is_variadic {
695            continue;
696        }
697
698        for call_site in count_call_arg_sites_ts(&caller.content, &callee.name, &caller.file_path) {
699            if call_site.actual_args < new_info.min_params
700                || call_site.actual_args > new_info.max_params
701            {
702                mismatches.push(ArityMismatch {
703                    caller_entity: caller.name.clone(),
704                    callee_entity: callee.name.clone(),
705                    expected_min: new_info.min_params,
706                    expected_max: new_info.max_params,
707                    actual_args: call_site.actual_args,
708                    file_path: caller.file_path.clone(),
709                    line: caller.start_line + call_site.line_offset,
710                    is_variadic: false,
711                });
712            }
713        }
714    }
715
716    mismatches
717}
718
// ─── String-based helpers (used by the `verify_contracts*` entry points) ────
720
721/// Extract param count from the first line of a function/method.
722fn extract_param_count(content: &str) -> usize {
723    let first_line = content.lines().next().unwrap_or("");
724
725    let open = match first_line.find('(') {
726        Some(i) => i,
727        None => return 0,
728    };
729
730    let after_open = &first_line[open + 1..];
731    let close = match find_matching_paren(after_open) {
732        Some(i) => i,
733        None => return 0,
734    };
735
736    let params_str = after_open[..close].trim();
737    if params_str.is_empty() {
738        return 0;
739    }
740
741    count_top_level_commas(params_str) + 1
742}
743
/// Count arguments at a call site: find `callee_name(...)` in content and count args.
///
/// Test-only convenience that returns the count for the first matching call.
#[cfg(test)]
fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
    let every_call = count_all_call_args(content, callee_name);
    every_call.first().copied()
}
749
750fn count_all_call_args(content: &str, callee_name: &str) -> Vec<usize> {
751    let bytes = content.as_bytes();
752    let name_bytes = callee_name.as_bytes();
753    let mut search_start = 0;
754    let mut counts = Vec::new();
755
756    while let Some(rel_pos) = content[search_start..].find(callee_name) {
757        let pos = search_start + rel_pos;
758        let after = pos + name_bytes.len();
759
760        let is_boundary = pos == 0 || {
761            let prev = bytes[pos - 1];
762            !prev.is_ascii_alphanumeric() && prev != b'_'
763        };
764
765        let mut next_search_start = pos + 1;
766        if is_boundary && after < bytes.len() && bytes[after] == b'(' {
767            let args_start_index = after + 1;
768            let args_start = &content[args_start_index..];
769            if let Some(close) = find_matching_paren(args_start) {
770                let args_str = args_start[..close].trim();
771                if args_str.is_empty() {
772                    counts.push(0);
773                } else {
774                    counts.push(count_top_level_commas(args_str) + 1);
775                }
776                next_search_start = args_start_index + close + 1;
777            } else {
778                next_search_start = after;
779            }
780        }
781
782        search_start = next_search_start;
783        while search_start < content.len() && !content.is_char_boundary(search_start) {
784            search_start += 1;
785        }
786    }
787
788    counts
789}
790
/// Byte index in `s` of the `)` that closes an already-opened `(`. Nested
/// parenthesis pairs are skipped. Returns `None` if the paren is never closed.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut open = 0usize;
    for (idx, c) in s.char_indices() {
        if c == '(' {
            open += 1;
        } else if c == ')' {
            if open == 0 {
                return Some(idx);
            }
            open -= 1;
        }
    }
    None
}
807
/// Count commas in `s` that are not nested inside `()`, `[]`, `{}`, `<>`, or a
/// double-quoted string literal (backslash escapes honoured).
///
/// The `>` of `->` and `=>` is treated as an arrow rather than a closing angle
/// bracket. Depth never goes below zero, so a stray closer cannot hide the
/// commas after it. A bare `<` (less-than) is still assumed to open a
/// generic list.
fn count_top_level_commas(s: &str) -> usize {
    let mut depth = 0usize;
    let mut count = 0;
    let mut prev = '\0';
    let mut in_string = false;
    let mut escaped = false;
    for ch in s.chars() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
        } else {
            match ch {
                '"' => in_string = true,
                '(' | '[' | '{' | '<' => depth += 1,
                // Arrow tokens (`->`, `=>`) are not closing brackets.
                '>' if prev == '-' || prev == '=' => {}
                ')' | ']' | '}' | '>' => depth = depth.saturating_sub(1),
                ',' if depth == 0 => count += 1,
                _ => {}
            }
        }
        prev = ch;
    }
    count
}
821
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyze `src` as if it came from `path`; panics when no signature is found.
    fn params_of(src: &str, path: &str) -> ParamInfo {
        extract_param_info_ts(src, path).unwrap()
    }

    /// Assert the `(min_params, max_params)` pair of a summary.
    fn assert_arity(info: &ParamInfo, min: usize, max: usize) {
        assert_eq!(info.min_params, min);
        assert_eq!(info.max_params, max);
    }

    #[test]
    fn test_extract_param_count_basic() {
        let cases: [(&str, usize); 4] = [
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ];
        for &(line, expected) in cases.iter() {
            assert_eq!(extract_param_count(line), expected, "signature: {}", line);
        }
    }

    #[test]
    fn test_extract_param_count_nested() {
        let nested = "function foo(a, fn(x, y), c) {";
        assert_eq!(extract_param_count(nested), 3);
    }

    #[test]
    fn test_count_call_args() {
        let cases: [(&str, &str, Option<usize>); 4] = [
            ("let x = foo(1, 2, 3);", "foo", Some(3)),
            ("foo()", "foo", Some(0)),
            ("bar(1)", "foo", None),
            ("foo(a, b)", "foo", Some(2)),
        ];
        for &(src, name, expected) in cases.iter() {
            assert_eq!(count_call_args(src, name), expected, "source: {}", src);
        }
    }

    #[test]
    fn test_count_all_call_args() {
        let counts = count_all_call_args("foo(1, 2); foo(1);", "foo");
        assert_eq!(counts, [2, 1]);
    }

    #[test]
    fn test_count_all_call_args_resumes_after_unclosed_candidate() {
        let counts = count_all_call_args("foo(\nfoo(1, 2)", "foo");
        assert_eq!(counts, [2]);
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        let cases: [(&str, &str, Option<usize>); 3] = [
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ];
        for &(src, name, expected) in cases.iter() {
            assert_eq!(count_call_args(src, name), expected, "source: {}", src);
        }
    }

    #[test]
    fn test_extract_param_info_python() {
        let info = params_of("def foo(a, b, c=3):\n    pass", "test.py");
        assert_arity(&info, 2, 3);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_python_self() {
        let info = params_of("def foo(self, a, b):\n    pass", "test.py");
        assert_arity(&info, 2, 2);
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let info = params_of("def foo(a, *args, **kwargs):\n    pass", "test.py");
        assert!(info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        let info = params_of(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        );
        assert_arity(&info, 2, 3);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript_default_parameter() {
        let info = params_of(
            "function foo(a: number, b = 1): number { return a + b; }",
            "test.ts",
        );
        assert_arity(&info, 1, 2);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_javascript_default_parameter() {
        let info = params_of("function foo(a, b = 1) { return a + b; }", "test.js");
        assert_arity(&info, 1, 2);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_javascript_required_parameters() {
        let info = params_of("function foo(a, b) { return a + b; }", "test.js");
        assert_arity(&info, 2, 2);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript_arrow_default_parameter() {
        let info = params_of("const foo = (a: number, b = 1): number => a + b;", "test.ts");
        assert_arity(&info, 1, 2);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_rust() {
        let info = params_of(
            "fn foo(&self, a: i32, b: String) -> bool { true }",
            "test.rs",
        );
        assert_arity(&info, 2, 2);
    }

    #[test]
    fn test_extract_param_info_go() {
        let info = params_of("func foo(a string, b int) error { return nil }", "test.go");
        assert_arity(&info, 2, 2);
    }

    #[test]
    fn test_extract_param_info_go_grouped_params() {
        let info = params_of(
            "func foo(a, b int, c string) int { return a + b }",
            "test.go",
        );
        assert_arity(&info, 3, 3);
    }

    #[test]
    fn test_extract_param_info_go_unnamed_params() {
        let info = params_of("func foo(int, string) bool { return true }", "test.go");
        assert_arity(&info, 2, 2);
    }

    #[test]
    fn test_count_call_args_ts() {
        let src = "function bar() { foo(1, 2, 3); }";
        assert_eq!(count_call_args_ts(src, "foo", "test.ts"), Some(3));
    }

    #[test]
    fn test_count_call_args_ts_method() {
        let src = "function bar() { obj.foo(1, 2); }";
        assert_eq!(count_call_args_ts(src, "foo", "test.ts"), Some(2));
    }

    #[test]
    fn test_count_call_arg_sites_ts_repeated_calls() {
        let sites =
            count_call_arg_sites_ts("def bar():\n    foo(1, 2)\n    foo(1)\n", "foo", "test.py");
        let expected: Vec<CallArgCount> = [(2, 1), (1, 2)]
            .iter()
            .map(|&(actual_args, line_offset)| CallArgCount {
                actual_args,
                line_offset,
            })
            .collect();
        assert_eq!(sites, expected);
    }
}