1use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call whose argument count differs from the callee's parameter count, as
/// reported by the text-based checks in `verify_contracts` and
/// `verify_contracts_with_graph`.
#[derive(Clone, Debug)]
pub struct ContractViolation {
    /// Name of the called entity.
    pub entity_name: String,
    /// File that defines the called entity.
    pub file_path: String,
    /// Parameter count read from the callee's signature.
    pub expected_params: usize,
    /// Name of the entity that contains the offending call.
    pub caller_name: String,
    /// File that contains the offending call.
    pub caller_file: String,
    /// Number of arguments passed at the call site.
    pub actual_args: usize,
}
21
/// Arity bounds of a function signature, derived from its tree-sitter parse
/// by `extract_param_info_ts`.
#[derive(Clone, Debug)]
pub struct ParamInfo {
    /// Fewest arguments a call must supply (receivers such as `self` excluded).
    pub min_params: usize,
    /// Most fixed arguments a call may supply; variadic parameters add nothing here.
    pub max_params: usize,
    /// Whether the signature also accepts a variable number of arguments.
    pub is_variadic: bool,
}
29
/// A call whose argument count falls outside the callee's
/// `expected_min..=expected_max` range, as reported by `find_arity_mismatches`
/// and `find_broken_callers`.
#[derive(Clone, Debug)]
pub struct ArityMismatch {
    /// Name of the entity that contains the call.
    pub caller_entity: String,
    /// Name of the called entity.
    pub callee_entity: String,
    /// Fewest arguments the callee accepts.
    pub expected_min: usize,
    /// Most arguments the callee accepts.
    pub expected_max: usize,
    /// Number of arguments passed at the call site.
    pub actual_args: usize,
    /// File containing the call (the caller's file).
    pub file_path: String,
    /// Line of the call: the caller's `start_line` plus the call's row offset
    /// inside the caller.
    pub line: usize,
    /// Set to `false` by both producers, which skip variadic callees.
    pub is_variadic: bool,
}
42
/// One call site found by the tree-sitter call scanner.
#[derive(Clone, Debug, Eq, PartialEq)]
struct CallArgCount {
    /// Number of arguments at the call site.
    actual_args: usize,
    /// Zero-based row of the call within the parsed snippet.
    line_offset: usize,
}
48
49pub fn verify_contracts(
51 root: &Path,
52 file_paths: &[String],
53 registry: &ParserRegistry,
54 target_file: Option<&str>,
55) -> Vec<ContractViolation> {
56 let (graph, _) = EntityGraph::build(root, file_paths, registry);
57
58 let mut content_map: HashMap<String, String> = HashMap::new();
59 for fp in file_paths {
60 let full = root.join(fp);
61 let content = match std::fs::read_to_string(&full) {
62 Ok(c) => c,
63 Err(_) => continue,
64 };
65 for entity in registry.extract_entities(fp, &content) {
66 content_map.insert(entity.id.clone(), entity.content.clone());
67 }
68 }
69
70 let mut violations = Vec::new();
71
72 for edge in &graph.edges {
73 if edge.ref_type != RefType::Calls {
74 continue;
75 }
76
77 let callee = match graph.entities.get(&edge.to_entity) {
78 Some(e) => e,
79 None => continue,
80 };
81
82 if let Some(tf) = target_file {
83 if callee.file_path != tf {
84 continue;
85 }
86 }
87
88 if !matches!(
89 callee.entity_type.as_str(),
90 "function" | "method" | "arrow_function"
91 ) {
92 continue;
93 }
94
95 let callee_content = match content_map.get(&edge.to_entity) {
96 Some(c) => c,
97 None => continue,
98 };
99
100 let caller = match graph.entities.get(&edge.from_entity) {
101 Some(e) => e,
102 None => continue,
103 };
104
105 let caller_content = match content_map.get(&edge.from_entity) {
106 Some(c) => c,
107 None => continue,
108 };
109
110 let expected = extract_param_count(callee_content);
111 if expected == 0 {
112 continue;
113 }
114
115 for actual in count_all_call_args(caller_content, &callee.name) {
116 if actual != expected {
117 violations.push(ContractViolation {
118 entity_name: callee.name.clone(),
119 file_path: callee.file_path.clone(),
120 expected_params: expected,
121 caller_name: caller.name.clone(),
122 caller_file: caller.file_path.clone(),
123 actual_args: actual,
124 });
125 }
126 }
127 }
128
129 violations
130}
131
132pub fn verify_contracts_with_graph(
134 graph: &EntityGraph,
135 all_entities: &[SemanticEntity],
136 target_file: Option<&str>,
137) -> Vec<ContractViolation> {
138 let content_map: HashMap<String, String> = all_entities
139 .iter()
140 .map(|e| (e.id.clone(), e.content.clone()))
141 .collect();
142
143 let mut violations = Vec::new();
144
145 for edge in &graph.edges {
146 if edge.ref_type != RefType::Calls {
147 continue;
148 }
149
150 let callee = match graph.entities.get(&edge.to_entity) {
151 Some(e) => e,
152 None => continue,
153 };
154
155 if let Some(tf) = target_file {
156 if callee.file_path != tf {
157 continue;
158 }
159 }
160
161 if !matches!(
162 callee.entity_type.as_str(),
163 "function" | "method" | "arrow_function"
164 ) {
165 continue;
166 }
167
168 let callee_content = match content_map.get(&edge.to_entity) {
169 Some(c) => c,
170 None => continue,
171 };
172
173 let caller = match graph.entities.get(&edge.from_entity) {
174 Some(e) => e,
175 None => continue,
176 };
177
178 let caller_content = match content_map.get(&edge.from_entity) {
179 Some(c) => c,
180 None => continue,
181 };
182
183 let expected = extract_param_count(callee_content);
184 if expected == 0 {
185 continue;
186 }
187
188 for actual in count_all_call_args(caller_content, &callee.name) {
189 if actual != expected {
190 violations.push(ContractViolation {
191 entity_name: callee.name.clone(),
192 file_path: callee.file_path.clone(),
193 expected_params: expected,
194 caller_name: caller.name.clone(),
195 caller_file: caller.file_path.clone(),
196 actual_args: actual,
197 });
198 }
199 }
200 }
201
202 violations
203}
204
/// Maps a file extension, including its leading dot (e.g. `".py"`), to the
/// language name understood by `extract_param_info_from_node`. Unsupported
/// extensions map to `"unknown"`.
fn lang_from_ext(ext: &str) -> &'static str {
    const LANGUAGES: &[(&str, &[&str])] = &[
        ("python", &[".py", ".pyi"]),
        ("typescript", &[".ts", ".tsx", ".mts", ".cts"]),
        ("javascript", &[".js", ".jsx", ".mjs", ".cjs"]),
        ("rust", &[".rs"]),
        ("go", &[".go"]),
    ];

    LANGUAGES
        .iter()
        .find(|(_, exts)| exts.contains(&ext))
        .map(|(lang, _)| *lang)
        .unwrap_or("unknown")
}
217
218pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
220 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
221 let lang = lang_from_ext(ext);
222 if lang == "unknown" {
223 return None;
224 }
225 let config = get_language_config(ext)?;
226 let language = (config.get_language)()?;
227
228 let mut parser = tree_sitter::Parser::new();
229 let _ = parser.set_language(&language);
230 let tree = parser.parse(content.as_bytes(), None)?;
231
232 extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
233}
234
235fn extract_param_info_from_node(
236 root: tree_sitter::Node,
237 source: &[u8],
238 lang: &str,
239) -> Option<ParamInfo> {
240 let func_node = find_first_function(root)?;
242 let params_node = func_node.child_by_field_name("parameters")?;
243
244 let mut min_params = 0usize;
245 let mut max_params = 0usize;
246 let mut is_variadic = false;
247
248 let mut cursor = params_node.walk();
249 for child in params_node.named_children(&mut cursor) {
250 let kind = child.kind();
251 match lang {
252 "python" => {
253 if kind == "identifier" {
254 let name = child.utf8_text(source).unwrap_or("");
255 if name == "self" || name == "cls" {
256 continue;
257 }
258 min_params += 1;
259 max_params += 1;
260 } else if kind == "typed_parameter" {
261 let name = child
262 .child_by_field_name("name")
263 .or_else(|| child.named_child(0))
264 .and_then(|n| n.utf8_text(source).ok())
265 .unwrap_or("");
266 if name == "self" || name == "cls" {
267 continue;
268 }
269 min_params += 1;
270 max_params += 1;
271 } else if kind == "default_parameter" || kind == "typed_default_parameter" {
272 max_params += 1;
273 } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
274 is_variadic = true;
275 }
276 }
277 "typescript" => {
278 if kind == "required_parameter" {
279 max_params += 1;
280 if !has_js_ts_default_value(child) {
281 min_params += 1;
282 }
283 } else if kind == "optional_parameter" {
284 max_params += 1;
285 } else if kind == "rest_pattern" {
286 is_variadic = true;
287 }
288 }
289 "javascript" => {
290 if kind == "rest_pattern" {
291 is_variadic = true;
292 } else if matches!(kind, "identifier" | "formal_parameter" | "assignment_pattern") {
293 max_params += 1;
294 if !has_js_ts_default_value(child) {
295 min_params += 1;
296 }
297 }
298 }
299 "rust" => {
300 if kind == "parameter" {
301 let pat = child
302 .child_by_field_name("pattern")
303 .and_then(|n| n.utf8_text(source).ok())
304 .unwrap_or("");
305 let base = pat.trim_start_matches('&').trim();
307 let base = base.strip_prefix("mut ").unwrap_or(base).trim();
308 if base == "self" {
309 continue;
310 }
311 min_params += 1;
312 max_params += 1;
313 } else if kind == "self_parameter" {
314 continue;
315 }
316 }
317 "go" => {
318 if kind == "parameter_declaration" {
319 let type_node = child.child_by_field_name("type");
320 let type_text = type_node.and_then(|n| n.utf8_text(source).ok()).unwrap_or("");
321 let param_text = child.utf8_text(source).unwrap_or("");
322 if type_text.starts_with("...") || param_text.contains("...") {
323 is_variadic = true;
324 } else {
325 let count = count_go_parameter_declaration_arity(child);
326 min_params += count;
327 max_params += count;
328 }
329 }
330 }
331 _ => {}
332 }
333 }
334
335 Some(ParamInfo {
336 min_params,
337 max_params,
338 is_variadic,
339 })
340}
341
342fn has_js_ts_default_value(node: tree_sitter::Node) -> bool {
343 let mut cursor = node.walk();
344 let has_assignment_child = node
345 .named_children(&mut cursor)
346 .any(|child| child.kind() == "assignment_pattern");
347 node.kind() == "assignment_pattern"
348 || node.child_by_field_name("value").is_some()
349 || has_assignment_child
350}
351
352fn count_go_parameter_declaration_arity(node: tree_sitter::Node) -> usize {
353 let mut name_cursor = node.walk();
354 let field_names = node
355 .children_by_field_name("name", &mut name_cursor)
356 .count();
357 if field_names > 0 {
358 return field_names;
359 }
360
361 let type_range = match node.child_by_field_name("type") {
362 Some(type_node) => (type_node.start_byte(), type_node.end_byte()),
363 None => return 1,
364 };
365 let mut cursor = node.walk();
366 let identifier_names = node
367 .named_children(&mut cursor)
368 .filter(|child| {
369 child.kind() == "identifier" && type_range != (child.start_byte(), child.end_byte())
370 })
371 .count();
372 if identifier_names > 0 {
373 identifier_names
374 } else {
375 1
376 }
377}
378
379fn find_first_function(root: tree_sitter::Node) -> Option<tree_sitter::Node> {
380 let mut worklist = vec![root];
381 while let Some(node) = worklist.pop() {
382 let kind = node.kind();
383 if matches!(
384 kind,
385 "function_definition"
386 | "function_item"
387 | "function_declaration"
388 | "method_definition"
389 | "method_declaration"
390 | "arrow_function"
391 ) {
392 return Some(node);
393 }
394 let mut cursor = node.walk();
395 let children: Vec<_> = node.named_children(&mut cursor).collect();
396 for child in children.into_iter().rev() {
397 worklist.push(child);
398 }
399 }
400 None
401}
402
403pub fn count_call_args_ts(
405 caller_content: &str,
406 callee_name: &str,
407 file_path: &str,
408) -> Option<usize> {
409 count_call_arg_sites_ts(caller_content, callee_name, file_path)
410 .into_iter()
411 .next()
412 .map(|site| site.actual_args)
413}
414
415fn count_call_arg_sites_ts(
416 caller_content: &str,
417 callee_name: &str,
418 file_path: &str,
419) -> Vec<CallArgCount> {
420 let ext = match file_path.rfind('.').map(|i| &file_path[i..]) {
421 Some(ext) => ext,
422 None => return Vec::new(),
423 };
424 let config = match get_language_config(ext) {
425 Some(config) => config,
426 None => return Vec::new(),
427 };
428 let language = match (config.get_language)() {
429 Some(language) => language,
430 None => return Vec::new(),
431 };
432
433 let mut parser = tree_sitter::Parser::new();
434 let _ = parser.set_language(&language);
435 let tree = match parser.parse(caller_content.as_bytes(), None) {
436 Some(tree) => tree,
437 None => return Vec::new(),
438 };
439
440 find_call_arg_counts(tree.root_node(), caller_content.as_bytes(), callee_name)
441}
442
443fn find_call_arg_counts(
444 root: tree_sitter::Node,
445 source: &[u8],
446 callee_name: &str,
447) -> Vec<CallArgCount> {
448 let mut sites = Vec::new();
449 let mut worklist = vec![root];
450 while let Some(node) = worklist.pop() {
451 let kind = node.kind();
452
453 if kind == "call" || kind == "call_expression" {
454 if let Some(func) = node.child_by_field_name("function") {
455 let func_name = match func.kind() {
456 "identifier" => func.utf8_text(source).unwrap_or(""),
457 "attribute" | "member_expression" | "field_expression" => func
458 .child_by_field_name("attribute")
459 .or_else(|| func.child_by_field_name("property"))
460 .or_else(|| func.child_by_field_name("field"))
461 .and_then(|n| n.utf8_text(source).ok())
462 .unwrap_or(""),
463 "selector_expression" => func
464 .child_by_field_name("field")
465 .and_then(|n| n.utf8_text(source).ok())
466 .unwrap_or(""),
467 "scoped_identifier" => {
468 let text = func.utf8_text(source).unwrap_or("");
469 text.rsplit("::").next().unwrap_or("")
470 }
471 _ => "",
472 };
473
474 if func_name == callee_name {
475 if let Some(args) = node.child_by_field_name("arguments") {
476 let mut actual_args = 0;
477 let mut cursor = args.walk();
478 for child in args.named_children(&mut cursor) {
479 if !child.kind().contains("comment") {
481 actual_args += 1;
482 }
483 }
484 sites.push(CallArgCount {
485 actual_args,
486 line_offset: node.start_position().row,
487 });
488 }
489 }
490 }
491 }
492
493 let mut cursor = node.walk();
494 let children: Vec<_> = node.named_children(&mut cursor).collect();
495 for child in children.into_iter().rev() {
496 worklist.push(child);
497 }
498 }
499 sites
500}
501
/// Callee names too generic to attribute a call to one definition by name
/// alone; `find_arity_mismatches` never reports calls to them.
const AMBIGUOUS_NAMES: &[&str] = &[
    "new",
    "constructor",
    "toString",
    "valueOf",
    "init",
    "__init__",
    "apply",
    "call",
    "bind",
    "get",
    "set",
    "run",
    "execute",
    "create",
];
507
/// Path components that mark test, fixture, mock, or benchmark code. A path
/// matches only when one of its whole components equals a marker.
const TEST_PATH_MARKERS: &[&str] = &[
    "test",
    "tests",
    "spec",
    "specs",
    "fixtures",
    "fixture",
    "benchmarks",
    "benchmark",
    "__tests__",
    "__mocks__",
];
513
514fn is_test_or_fixture_path(path: &str) -> bool {
515 path.split('/').any(|component| TEST_PATH_MARKERS.contains(&component))
516}
517
518pub fn find_arity_mismatches(
520 graph: &EntityGraph,
521 all_entities: &[SemanticEntity],
522) -> Vec<ArityMismatch> {
523 let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
524 .iter()
525 .map(|e| (e.id.as_str(), e))
526 .collect();
527
528 let mut name_counts: HashMap<&str, usize> = HashMap::new();
530 for e in all_entities {
531 if matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function") {
532 *name_counts.entry(&e.name).or_insert(0) += 1;
533 }
534 }
535
536 let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
538
539 let mut mismatches = Vec::new();
540
541 for edge in &graph.edges {
542 if edge.ref_type != RefType::Calls {
543 continue;
544 }
545
546 let callee_info = match graph.entities.get(&edge.to_entity) {
547 Some(e) => e,
548 None => continue,
549 };
550
551 if !matches!(
552 callee_info.entity_type.as_str(),
553 "function" | "method" | "arrow_function"
554 ) {
555 continue;
556 }
557
558 if AMBIGUOUS_NAMES.contains(&callee_info.name.as_str()) {
560 continue;
561 }
562
563 if name_counts.get(callee_info.name.as_str()).copied().unwrap_or(0) > 1 {
565 continue;
566 }
567
568 if is_test_or_fixture_path(&callee_info.file_path) {
570 continue;
571 }
572
573 let callee = match entity_by_id.get(edge.to_entity.as_str()) {
574 Some(e) => *e,
575 None => continue,
576 };
577
578 let caller = match entity_by_id.get(edge.from_entity.as_str()) {
579 Some(e) => *e,
580 None => continue,
581 };
582
583 if is_test_or_fixture_path(&caller.file_path) {
585 continue;
586 }
587
588 let param_info = param_cache
590 .entry(callee.id.clone())
591 .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
592 .clone();
593
594 let param_info = match param_info {
595 Some(pi) => pi,
596 None => continue,
597 };
598
599 if param_info.is_variadic {
601 continue;
602 }
603
604 for call_site in count_call_arg_sites_ts(&caller.content, &callee.name, &caller.file_path) {
605 if call_site.actual_args < param_info.min_params
606 || call_site.actual_args > param_info.max_params
607 {
608 mismatches.push(ArityMismatch {
609 caller_entity: caller.name.clone(),
610 callee_entity: callee.name.clone(),
611 expected_min: param_info.min_params,
612 expected_max: param_info.max_params,
613 actual_args: call_site.actual_args,
614 file_path: caller.file_path.clone(),
615 line: caller.start_line + call_site.line_offset,
616 is_variadic: false,
617 });
618 }
619 }
620 }
621
622 mismatches
623}
624
625pub fn find_broken_callers(
629 old_entities: &[SemanticEntity],
630 new_graph: &EntityGraph,
631 new_entities: &[SemanticEntity],
632) -> Vec<ArityMismatch> {
633 let old_params: HashMap<String, Option<ParamInfo>> = old_entities
635 .iter()
636 .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
637 .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
638 .collect();
639
640 let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
642 .iter()
643 .map(|e| (e.id.as_str(), e))
644 .collect();
645
646 let mut changed_entities: Vec<&str> = Vec::new();
648 for new_entity in new_entities {
649 if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
650 continue;
651 }
652 let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
653 Some(pi) => pi,
654 None => continue,
655 };
656 if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
657 if old_info.min_params != new_info.min_params
658 || old_info.max_params != new_info.max_params
659 {
660 changed_entities.push(&new_entity.id);
661 }
662 }
663 }
664
665 if changed_entities.is_empty() {
666 return Vec::new();
667 }
668
669 let mut mismatches = Vec::new();
671
672 for edge in &new_graph.edges {
673 if edge.ref_type != RefType::Calls {
674 continue;
675 }
676 if !changed_entities.contains(&edge.to_entity.as_str()) {
677 continue;
678 }
679
680 let callee = match new_by_id.get(edge.to_entity.as_str()) {
681 Some(e) => *e,
682 None => continue,
683 };
684 let caller = match new_by_id.get(edge.from_entity.as_str()) {
685 Some(e) => *e,
686 None => continue,
687 };
688
689 let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
690 Some(pi) => pi,
691 None => continue,
692 };
693
694 if new_info.is_variadic {
695 continue;
696 }
697
698 for call_site in count_call_arg_sites_ts(&caller.content, &callee.name, &caller.file_path) {
699 if call_site.actual_args < new_info.min_params
700 || call_site.actual_args > new_info.max_params
701 {
702 mismatches.push(ArityMismatch {
703 caller_entity: caller.name.clone(),
704 callee_entity: callee.name.clone(),
705 expected_min: new_info.min_params,
706 expected_max: new_info.max_params,
707 actual_args: call_site.actual_args,
708 file_path: caller.file_path.clone(),
709 line: caller.start_line + call_site.line_offset,
710 is_variadic: false,
711 });
712 }
713 }
714 }
715
716 mismatches
717}
718
719fn extract_param_count(content: &str) -> usize {
723 let first_line = content.lines().next().unwrap_or("");
724
725 let open = match first_line.find('(') {
726 Some(i) => i,
727 None => return 0,
728 };
729
730 let after_open = &first_line[open + 1..];
731 let close = match find_matching_paren(after_open) {
732 Some(i) => i,
733 None => return 0,
734 };
735
736 let params_str = after_open[..close].trim();
737 if params_str.is_empty() {
738 return 0;
739 }
740
741 count_top_level_commas(params_str) + 1
742}
743
/// Test helper: argument count of the first `callee_name(...)` call in
/// `content`, or `None` when there is no such call.
#[cfg(test)]
fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
    count_all_call_args(content, callee_name).first().copied()
}
749
750fn count_all_call_args(content: &str, callee_name: &str) -> Vec<usize> {
751 let bytes = content.as_bytes();
752 let name_bytes = callee_name.as_bytes();
753 let mut search_start = 0;
754 let mut counts = Vec::new();
755
756 while let Some(rel_pos) = content[search_start..].find(callee_name) {
757 let pos = search_start + rel_pos;
758 let after = pos + name_bytes.len();
759
760 let is_boundary = pos == 0 || {
761 let prev = bytes[pos - 1];
762 !prev.is_ascii_alphanumeric() && prev != b'_'
763 };
764
765 let mut next_search_start = pos + 1;
766 if is_boundary && after < bytes.len() && bytes[after] == b'(' {
767 let args_start_index = after + 1;
768 let args_start = &content[args_start_index..];
769 if let Some(close) = find_matching_paren(args_start) {
770 let args_str = args_start[..close].trim();
771 if args_str.is_empty() {
772 counts.push(0);
773 } else {
774 counts.push(count_top_level_commas(args_str) + 1);
775 }
776 next_search_start = args_start_index + close + 1;
777 } else {
778 next_search_start = after;
779 }
780 }
781
782 search_start = next_search_start;
783 while search_start < content.len() && !content.is_char_boundary(search_start) {
784 search_start += 1;
785 }
786 }
787
788 counts
789}
790
/// Given the text immediately after an opening `(`, returns the byte offset
/// of the `)` that closes it, or `None` if it is never closed.
///
/// Only parentheses are tracked; brackets, braces and string literals are not
/// considered.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut nesting = 0usize;
    s.char_indices().find_map(|(offset, c)| match c {
        '(' => {
            nesting += 1;
            None
        }
        ')' if nesting == 0 => Some(offset),
        ')' => {
            nesting -= 1;
            None
        }
        _ => None,
    })
}
807
/// Counts the commas that separate the top-level elements of a comma-separated
/// list, such as the text between a call's or signature's outer parentheses.
///
/// Callers add 1 to get the element count, so these rules keep that correct:
///
/// * commas nested inside `()`, `[]`, `{}` or a generic `<...>` list are ignored;
/// * commas inside `"..."` or `` `...` `` literals (with `\` escapes) are ignored;
/// * a trailing comma after the last element is not a separator. rustfmt,
///   Prettier and Black emit these for multi-line lists;
/// * `<` only opens a generic list when it directly follows an identifier
///   character or `:` (`Vec<T>`, `::<T>`), so comparisons like `a < b` are
///   not mistaken for brackets;
/// * `>` in `->` / `=>` (closures, arrow functions, return types), and any `>`
///   without an open `<`, are ignored, so they cannot unbalance the nesting.
///
/// Single-quoted literals are deliberately not treated as strings because `'`
/// also starts Rust lifetimes.
fn count_top_level_commas(s: &str) -> usize {
    let mut open: Vec<char> = Vec::new();
    let mut separators = 0;
    // A top-level comma only becomes a separator once another element follows it.
    let mut pending_comma = false;
    let mut prev = '\0';
    let mut chars = s.chars();

    while let Some(ch) = chars.next() {
        if ch.is_whitespace() {
            prev = ch;
            continue;
        }
        if ch == ',' && open.is_empty() {
            if pending_comma {
                separators += 1;
            }
            pending_comma = true;
            prev = ch;
            continue;
        }
        if pending_comma {
            separators += 1;
            pending_comma = false;
        }

        match ch {
            '"' | '`' => {
                // Skip the literal body, honouring backslash escapes.
                let mut escaped = false;
                for c in chars.by_ref() {
                    if escaped {
                        escaped = false;
                    } else if c == '\\' {
                        escaped = true;
                    } else if c == ch {
                        break;
                    }
                }
            }
            '(' | '[' | '{' => open.push(ch),
            '<' if prev.is_alphanumeric() || prev == '_' || prev == ':' => open.push(ch),
            '>' if prev != '-' && prev != '=' => {
                if open.last() == Some(&'<') {
                    open.pop();
                }
            }
            ')' | ']' | '}' => {
                // Discard `<` that never closed (e.g. an unspaced `a<b`) so it
                // cannot hide the rest of the list.
                while open.last() == Some(&'<') {
                    open.pop();
                }
                open.pop();
            }
            _ => {}
        }
        prev = ch;
    }

    separators
}
821
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the arity of a one-function snippet, failing the test if no
    /// signature can be found.
    fn param_info(src: &str, path: &str) -> ParamInfo {
        extract_param_info_ts(src, path).expect("snippet should contain a parsable function")
    }

    #[test]
    fn test_extract_param_count_basic() {
        let cases: [(&str, usize); 4] = [
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ];
        for &(signature, expected) in cases.iter() {
            assert_eq!(extract_param_count(signature), expected, "signature: {}", signature);
        }
    }

    #[test]
    fn test_extract_param_count_nested() {
        let signature = "function foo(a, fn(x, y), c) {";
        assert_eq!(extract_param_count(signature), 3);
    }

    #[test]
    fn test_count_call_args() {
        let cases: [(&str, Option<usize>); 4] = [
            ("let x = foo(1, 2, 3);", Some(3)),
            ("foo()", Some(0)),
            ("bar(1)", None),
            ("foo(a, b)", Some(2)),
        ];
        for &(src, expected) in cases.iter() {
            assert_eq!(count_call_args(src, "foo"), expected, "source: {}", src);
        }
    }

    #[test]
    fn test_count_all_call_args() {
        let counts = count_all_call_args("foo(1, 2); foo(1);", "foo");
        assert_eq!(counts, vec![2, 1]);
    }

    #[test]
    fn test_count_all_call_args_resumes_after_unclosed_candidate() {
        let counts = count_all_call_args("foo(\nfoo(1, 2)", "foo");
        assert_eq!(counts, vec![2]);
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        let cases: [(&str, &str, Option<usize>); 3] = [
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ];
        for &(src, callee, expected) in cases.iter() {
            assert_eq!(count_call_args(src, callee), expected, "source: {}", src);
        }
    }

    #[test]
    fn test_extract_param_info_python() {
        let info = param_info("def foo(a, b, c=3):\n pass", "test.py");
        assert_eq!(
            (info.min_params, info.max_params, info.is_variadic),
            (2, 3, false)
        );
    }

    #[test]
    fn test_extract_param_info_python_self() {
        let info = param_info("def foo(self, a, b):\n pass", "test.py");
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let info = param_info("def foo(a, *args, **kwargs):\n pass", "test.py");
        assert!(info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        let info = param_info(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        );
        assert_eq!(
            (info.min_params, info.max_params, info.is_variadic),
            (2, 3, false)
        );
    }

    #[test]
    fn test_extract_param_info_typescript_default_parameter() {
        let info = param_info(
            "function foo(a: number, b = 1): number { return a + b; }",
            "test.ts",
        );
        assert_eq!(
            (info.min_params, info.max_params, info.is_variadic),
            (1, 2, false)
        );
    }

    #[test]
    fn test_extract_param_info_javascript_default_parameter() {
        let info = param_info("function foo(a, b = 1) { return a + b; }", "test.js");
        assert_eq!(
            (info.min_params, info.max_params, info.is_variadic),
            (1, 2, false)
        );
    }

    #[test]
    fn test_extract_param_info_javascript_required_parameters() {
        let info = param_info("function foo(a, b) { return a + b; }", "test.js");
        assert_eq!(
            (info.min_params, info.max_params, info.is_variadic),
            (2, 2, false)
        );
    }

    #[test]
    fn test_extract_param_info_typescript_arrow_default_parameter() {
        let info = param_info("const foo = (a: number, b = 1): number => a + b;", "test.ts");
        assert_eq!(
            (info.min_params, info.max_params, info.is_variadic),
            (1, 2, false)
        );
    }

    #[test]
    fn test_extract_param_info_rust() {
        let info = param_info("fn foo(&self, a: i32, b: String) -> bool { true }", "test.rs");
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_extract_param_info_go() {
        let info = param_info("func foo(a string, b int) error { return nil }", "test.go");
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_extract_param_info_go_grouped_params() {
        let info = param_info("func foo(a, b int, c string) int { return a + b }", "test.go");
        assert_eq!((info.min_params, info.max_params), (3, 3));
    }

    #[test]
    fn test_extract_param_info_go_unnamed_params() {
        let info = param_info("func foo(int, string) bool { return true }", "test.go");
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_count_call_args_ts() {
        let src = "function bar() { foo(1, 2, 3); }";
        assert_eq!(count_call_args_ts(src, "foo", "test.ts"), Some(3));
    }

    #[test]
    fn test_count_call_args_ts_method() {
        let src = "function bar() { obj.foo(1, 2); }";
        assert_eq!(count_call_args_ts(src, "foo", "test.ts"), Some(2));
    }

    #[test]
    fn test_count_call_arg_sites_ts_repeated_calls() {
        let src = "def bar():\n foo(1, 2)\n foo(1)\n";
        let expected = vec![
            CallArgCount {
                actual_args: 2,
                line_offset: 1,
            },
            CallArgCount {
                actual_args: 1,
                line_offset: 2,
            },
        ];
        assert_eq!(count_call_arg_sites_ts(src, "foo", "test.py"), expected);
    }
}