1use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call site whose argument count differs from the parameter count read
/// from the callee's signature, as reported by `verify_contracts` and
/// `verify_contracts_with_graph`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContractViolation {
    /// Name of the called function or method.
    pub entity_name: String,
    /// File that defines the callee.
    pub file_path: String,
    /// Parameter count read from the callee's signature.
    pub expected_params: usize,
    /// Name of the calling entity.
    pub caller_name: String,
    /// File that contains the caller.
    pub caller_file: String,
    /// Number of arguments found at the call site.
    pub actual_args: usize,
}
21
/// Arity of a function signature, derived from its tree-sitter parse.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamInfo {
    /// Number of arguments a caller must supply (receivers such as `self` excluded).
    pub min_params: usize,
    /// Number of arguments a caller may supply, including optional/defaulted ones.
    pub max_params: usize,
    /// True when the signature accepts extra arguments (`*args`, `...rest`, ...).
    pub is_variadic: bool,
}
29
/// A call whose argument count falls outside the callee's accepted range,
/// as reported by `find_arity_mismatches` and `find_broken_callers`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArityMismatch {
    /// Name of the calling entity.
    pub caller_entity: String,
    /// Name of the called function or method.
    pub callee_entity: String,
    /// Fewest arguments the callee accepts.
    pub expected_min: usize,
    /// Most arguments the callee accepts.
    pub expected_max: usize,
    /// Number of arguments found at the call site.
    pub actual_args: usize,
    /// File that contains the caller.
    pub file_path: String,
    /// Start line of the calling entity (not of the call expression itself).
    pub line: usize,
    /// Whether the callee is variadic. The reporting functions skip variadic
    /// callees, so they currently always set this to `false`.
    pub is_variadic: bool,
}
42
43pub fn verify_contracts(
45 root: &Path,
46 file_paths: &[String],
47 registry: &ParserRegistry,
48 target_file: Option<&str>,
49) -> Vec<ContractViolation> {
50 let (graph, _) = EntityGraph::build(root, file_paths, registry);
51
52 let mut content_map: HashMap<String, String> = HashMap::new();
53 for fp in file_paths {
54 let full = root.join(fp);
55 let content = match std::fs::read_to_string(&full) {
56 Ok(c) => c,
57 Err(_) => continue,
58 };
59 for entity in registry.extract_entities(fp, &content) {
60 content_map.insert(entity.id.clone(), entity.content.clone());
61 }
62 }
63
64 let mut violations = Vec::new();
65
66 for edge in &graph.edges {
67 if edge.ref_type != RefType::Calls {
68 continue;
69 }
70
71 let callee = match graph.entities.get(&edge.to_entity) {
72 Some(e) => e,
73 None => continue,
74 };
75
76 if let Some(tf) = target_file {
77 if callee.file_path != tf {
78 continue;
79 }
80 }
81
82 if !matches!(
83 callee.entity_type.as_str(),
84 "function" | "method" | "arrow_function"
85 ) {
86 continue;
87 }
88
89 let callee_content = match content_map.get(&edge.to_entity) {
90 Some(c) => c,
91 None => continue,
92 };
93
94 let caller = match graph.entities.get(&edge.from_entity) {
95 Some(e) => e,
96 None => continue,
97 };
98
99 let caller_content = match content_map.get(&edge.from_entity) {
100 Some(c) => c,
101 None => continue,
102 };
103
104 let expected = extract_param_count(callee_content);
105 if expected == 0 {
106 continue;
107 }
108
109 if let Some(actual) = count_call_args(caller_content, &callee.name) {
110 if actual != expected {
111 violations.push(ContractViolation {
112 entity_name: callee.name.clone(),
113 file_path: callee.file_path.clone(),
114 expected_params: expected,
115 caller_name: caller.name.clone(),
116 caller_file: caller.file_path.clone(),
117 actual_args: actual,
118 });
119 }
120 }
121 }
122
123 violations
124}
125
126pub fn verify_contracts_with_graph(
128 graph: &EntityGraph,
129 all_entities: &[SemanticEntity],
130 target_file: Option<&str>,
131) -> Vec<ContractViolation> {
132 let content_map: HashMap<String, String> = all_entities
133 .iter()
134 .map(|e| (e.id.clone(), e.content.clone()))
135 .collect();
136
137 let mut violations = Vec::new();
138
139 for edge in &graph.edges {
140 if edge.ref_type != RefType::Calls {
141 continue;
142 }
143
144 let callee = match graph.entities.get(&edge.to_entity) {
145 Some(e) => e,
146 None => continue,
147 };
148
149 if let Some(tf) = target_file {
150 if callee.file_path != tf {
151 continue;
152 }
153 }
154
155 if !matches!(
156 callee.entity_type.as_str(),
157 "function" | "method" | "arrow_function"
158 ) {
159 continue;
160 }
161
162 let callee_content = match content_map.get(&edge.to_entity) {
163 Some(c) => c,
164 None => continue,
165 };
166
167 let caller = match graph.entities.get(&edge.from_entity) {
168 Some(e) => e,
169 None => continue,
170 };
171
172 let caller_content = match content_map.get(&edge.from_entity) {
173 Some(c) => c,
174 None => continue,
175 };
176
177 let expected = extract_param_count(callee_content);
178 if expected == 0 {
179 continue;
180 }
181
182 if let Some(actual) = count_call_args(caller_content, &callee.name) {
183 if actual != expected {
184 violations.push(ContractViolation {
185 entity_name: callee.name.clone(),
186 file_path: callee.file_path.clone(),
187 expected_params: expected,
188 caller_name: caller.name.clone(),
189 caller_file: caller.file_path.clone(),
190 actual_args: actual,
191 });
192 }
193 }
194 }
195
196 violations
197}
198
/// Maps a file extension, including its leading dot (e.g. `".py"`), to the
/// language family whose parameter node kinds `extract_param_info_from_node`
/// understands. JavaScript extensions share the `"typescript"` family.
/// Anything unrecognised maps to `"unknown"`.
fn lang_from_ext(ext: &str) -> &'static str {
    const FAMILIES: &[(&[&str], &str)] = &[
        (&[".py", ".pyi"], "python"),
        (
            &[".ts", ".tsx", ".mts", ".cts", ".js", ".jsx", ".mjs", ".cjs"],
            "typescript",
        ),
        (&[".rs"], "rust"),
        (&[".go"], "go"),
    ];

    FAMILIES
        .iter()
        .find(|(extensions, _)| extensions.contains(&ext))
        .map(|&(_, family)| family)
        .unwrap_or("unknown")
}
211
212pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
214 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
215 let lang = lang_from_ext(ext);
216 if lang == "unknown" {
217 return None;
218 }
219 let config = get_language_config(ext)?;
220 let language = (config.get_language)()?;
221
222 let mut parser = tree_sitter::Parser::new();
223 let _ = parser.set_language(&language);
224 let tree = parser.parse(content.as_bytes(), None)?;
225
226 extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
227}
228
229fn extract_param_info_from_node(
230 root: tree_sitter::Node,
231 source: &[u8],
232 lang: &str,
233) -> Option<ParamInfo> {
234 let func_node = find_first_function(root)?;
236 let params_node = func_node.child_by_field_name("parameters")?;
237
238 let mut min_params = 0usize;
239 let mut max_params = 0usize;
240 let mut is_variadic = false;
241
242 let mut cursor = params_node.walk();
243 for child in params_node.named_children(&mut cursor) {
244 let kind = child.kind();
245 match lang {
246 "python" => {
247 if kind == "identifier" {
248 let name = child.utf8_text(source).unwrap_or("");
249 if name == "self" || name == "cls" {
250 continue;
251 }
252 min_params += 1;
253 max_params += 1;
254 } else if kind == "typed_parameter" {
255 let name = child
256 .child_by_field_name("name")
257 .or_else(|| child.named_child(0))
258 .and_then(|n| n.utf8_text(source).ok())
259 .unwrap_or("");
260 if name == "self" || name == "cls" {
261 continue;
262 }
263 min_params += 1;
264 max_params += 1;
265 } else if kind == "default_parameter" || kind == "typed_default_parameter" {
266 max_params += 1;
267 } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
268 is_variadic = true;
269 }
270 }
271 "typescript" => {
272 if kind == "required_parameter" {
273 min_params += 1;
274 max_params += 1;
275 } else if kind == "optional_parameter" {
276 max_params += 1;
277 } else if kind == "rest_pattern" {
278 is_variadic = true;
279 }
280 }
281 "rust" => {
282 if kind == "parameter" {
283 let pat = child
284 .child_by_field_name("pattern")
285 .and_then(|n| n.utf8_text(source).ok())
286 .unwrap_or("");
287 let base = pat.trim_start_matches('&').trim();
289 let base = base.strip_prefix("mut ").unwrap_or(base).trim();
290 if base == "self" {
291 continue;
292 }
293 min_params += 1;
294 max_params += 1;
295 } else if kind == "self_parameter" {
296 continue;
297 }
298 }
299 "go" => {
300 if kind == "parameter_declaration" {
301 let type_text = child
303 .child_by_field_name("type")
304 .and_then(|n| n.utf8_text(source).ok())
305 .unwrap_or("");
306 if type_text.starts_with("...") {
307 is_variadic = true;
308 } else {
309 min_params += 1;
310 max_params += 1;
311 }
312 }
313 }
314 _ => {}
315 }
316 }
317
318 Some(ParamInfo {
319 min_params,
320 max_params,
321 is_variadic,
322 })
323}
324
325fn find_first_function(root: tree_sitter::Node) -> Option<tree_sitter::Node> {
326 let mut worklist = vec![root];
327 while let Some(node) = worklist.pop() {
328 let kind = node.kind();
329 if matches!(
330 kind,
331 "function_definition"
332 | "function_item"
333 | "function_declaration"
334 | "method_definition"
335 | "method_declaration"
336 | "arrow_function"
337 ) {
338 return Some(node);
339 }
340 let mut cursor = node.walk();
341 let children: Vec<_> = node.named_children(&mut cursor).collect();
342 for child in children.into_iter().rev() {
343 worklist.push(child);
344 }
345 }
346 None
347}
348
349pub fn count_call_args_ts(
351 caller_content: &str,
352 callee_name: &str,
353 file_path: &str,
354) -> Option<usize> {
355 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
356 let config = get_language_config(ext)?;
357 let language = (config.get_language)()?;
358
359 let mut parser = tree_sitter::Parser::new();
360 let _ = parser.set_language(&language);
361 let tree = parser.parse(caller_content.as_bytes(), None)?;
362
363 find_call_arg_count(tree.root_node(), caller_content.as_bytes(), callee_name)
364}
365
366fn find_call_arg_count(
367 root: tree_sitter::Node,
368 source: &[u8],
369 callee_name: &str,
370) -> Option<usize> {
371 let mut worklist = vec![root];
372 while let Some(node) = worklist.pop() {
373 let kind = node.kind();
374
375 if kind == "call" || kind == "call_expression" {
376 if let Some(func) = node.child_by_field_name("function") {
377 let func_name = match func.kind() {
378 "identifier" => func.utf8_text(source).unwrap_or(""),
379 "attribute" | "member_expression" | "field_expression" => func
380 .child_by_field_name("attribute")
381 .or_else(|| func.child_by_field_name("property"))
382 .or_else(|| func.child_by_field_name("field"))
383 .and_then(|n| n.utf8_text(source).ok())
384 .unwrap_or(""),
385 "selector_expression" => func
386 .child_by_field_name("field")
387 .and_then(|n| n.utf8_text(source).ok())
388 .unwrap_or(""),
389 "scoped_identifier" => {
390 let text = func.utf8_text(source).unwrap_or("");
391 text.rsplit("::").next().unwrap_or("")
392 }
393 _ => "",
394 };
395
396 if func_name == callee_name {
397 if let Some(args) = node.child_by_field_name("arguments") {
398 let mut count = 0;
399 let mut cursor = args.walk();
400 for child in args.named_children(&mut cursor) {
401 if !child.kind().contains("comment") {
403 count += 1;
404 }
405 }
406 return Some(count);
407 }
408 }
409 }
410 }
411
412 let mut cursor = node.walk();
413 let children: Vec<_> = node.named_children(&mut cursor).collect();
414 for child in children.into_iter().rev() {
415 worklist.push(child);
416 }
417 }
418 None
419}
420
/// Callee names too generic to attribute reliably. Calls whose callee carries
/// one of these names are skipped by `find_arity_mismatches`.
const AMBIGUOUS_NAMES: &[&str] = &[
    // Constructors, initialisers and conversion hooks.
    "new",
    "constructor",
    "toString",
    "valueOf",
    "init",
    "__init__",
    // Common verbs defined by many unrelated types.
    "apply",
    "call",
    "bind",
    "get",
    "set",
    "run",
    "execute",
    "create",
];
426
/// Path components that mark a file as test, fixture or benchmark code.
/// They are matched against whole components by `is_test_or_fixture_path`.
const TEST_PATH_MARKERS: &[&str] = &[
    // Test suites.
    "test",
    "tests",
    "spec",
    "specs",
    // Fixture data.
    "fixtures",
    "fixture",
    // Benchmarks.
    "benchmarks",
    "benchmark",
    // JavaScript tooling conventions.
    "__tests__",
    "__mocks__",
];
432
433fn is_test_or_fixture_path(path: &str) -> bool {
434 path.split('/').any(|component| TEST_PATH_MARKERS.contains(&component))
435}
436
437pub fn find_arity_mismatches(
439 graph: &EntityGraph,
440 all_entities: &[SemanticEntity],
441) -> Vec<ArityMismatch> {
442 let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
443 .iter()
444 .map(|e| (e.id.as_str(), e))
445 .collect();
446
447 let mut name_counts: HashMap<&str, usize> = HashMap::new();
449 for e in all_entities {
450 if matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function") {
451 *name_counts.entry(&e.name).or_insert(0) += 1;
452 }
453 }
454
455 let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
457
458 let mut mismatches = Vec::new();
459
460 for edge in &graph.edges {
461 if edge.ref_type != RefType::Calls {
462 continue;
463 }
464
465 let callee_info = match graph.entities.get(&edge.to_entity) {
466 Some(e) => e,
467 None => continue,
468 };
469
470 if !matches!(
471 callee_info.entity_type.as_str(),
472 "function" | "method" | "arrow_function"
473 ) {
474 continue;
475 }
476
477 if AMBIGUOUS_NAMES.contains(&callee_info.name.as_str()) {
479 continue;
480 }
481
482 if name_counts.get(callee_info.name.as_str()).copied().unwrap_or(0) > 1 {
484 continue;
485 }
486
487 if is_test_or_fixture_path(&callee_info.file_path) {
489 continue;
490 }
491
492 let callee = match entity_by_id.get(edge.to_entity.as_str()) {
493 Some(e) => *e,
494 None => continue,
495 };
496
497 let caller = match entity_by_id.get(edge.from_entity.as_str()) {
498 Some(e) => *e,
499 None => continue,
500 };
501
502 if is_test_or_fixture_path(&caller.file_path) {
504 continue;
505 }
506
507 let param_info = param_cache
509 .entry(callee.id.clone())
510 .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
511 .clone();
512
513 let param_info = match param_info {
514 Some(pi) => pi,
515 None => continue,
516 };
517
518 if param_info.is_variadic {
520 continue;
521 }
522
523 let actual = match count_call_args_ts(
525 &caller.content,
526 &callee.name,
527 &caller.file_path,
528 ) {
529 Some(a) => a,
530 None => continue,
531 };
532
533 if actual < param_info.min_params || actual > param_info.max_params {
534 mismatches.push(ArityMismatch {
535 caller_entity: caller.name.clone(),
536 callee_entity: callee.name.clone(),
537 expected_min: param_info.min_params,
538 expected_max: param_info.max_params,
539 actual_args: actual,
540 file_path: caller.file_path.clone(),
541 line: caller.start_line,
542 is_variadic: false,
543 });
544 }
545 }
546
547 mismatches
548}
549
550pub fn find_broken_callers(
554 old_entities: &[SemanticEntity],
555 new_graph: &EntityGraph,
556 new_entities: &[SemanticEntity],
557) -> Vec<ArityMismatch> {
558 let old_params: HashMap<String, Option<ParamInfo>> = old_entities
560 .iter()
561 .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
562 .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
563 .collect();
564
565 let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
567 .iter()
568 .map(|e| (e.id.as_str(), e))
569 .collect();
570
571 let mut changed_entities: Vec<&str> = Vec::new();
573 for new_entity in new_entities {
574 if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
575 continue;
576 }
577 let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
578 Some(pi) => pi,
579 None => continue,
580 };
581 if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
582 if old_info.min_params != new_info.min_params
583 || old_info.max_params != new_info.max_params
584 {
585 changed_entities.push(&new_entity.id);
586 }
587 }
588 }
589
590 if changed_entities.is_empty() {
591 return Vec::new();
592 }
593
594 let mut mismatches = Vec::new();
596
597 for edge in &new_graph.edges {
598 if edge.ref_type != RefType::Calls {
599 continue;
600 }
601 if !changed_entities.contains(&edge.to_entity.as_str()) {
602 continue;
603 }
604
605 let callee = match new_by_id.get(edge.to_entity.as_str()) {
606 Some(e) => *e,
607 None => continue,
608 };
609 let caller = match new_by_id.get(edge.from_entity.as_str()) {
610 Some(e) => *e,
611 None => continue,
612 };
613
614 let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
615 Some(pi) => pi,
616 None => continue,
617 };
618
619 if new_info.is_variadic {
620 continue;
621 }
622
623 let actual = match count_call_args_ts(&caller.content, &callee.name, &caller.file_path) {
624 Some(a) => a,
625 None => continue,
626 };
627
628 if actual < new_info.min_params || actual > new_info.max_params {
629 mismatches.push(ArityMismatch {
630 caller_entity: caller.name.clone(),
631 callee_entity: callee.name.clone(),
632 expected_min: new_info.min_params,
633 expected_max: new_info.max_params,
634 actual_args: actual,
635 file_path: caller.file_path.clone(),
636 line: caller.start_line,
637 is_variadic: false,
638 });
639 }
640 }
641
642 mismatches
643}
644
645fn extract_param_count(content: &str) -> usize {
649 let first_line = content.lines().next().unwrap_or("");
650
651 let open = match first_line.find('(') {
652 Some(i) => i,
653 None => return 0,
654 };
655
656 let after_open = &first_line[open + 1..];
657 let close = match find_matching_paren(after_open) {
658 Some(i) => i,
659 None => return 0,
660 };
661
662 let params_str = after_open[..close].trim();
663 if params_str.is_empty() {
664 return 0;
665 }
666
667 count_top_level_commas(params_str) + 1
668}
669
670fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
672 let bytes = content.as_bytes();
673 let name_bytes = callee_name.as_bytes();
674 let mut search_start = 0;
675
676 while let Some(rel_pos) = content[search_start..].find(callee_name) {
677 let pos = search_start + rel_pos;
678 let after = pos + name_bytes.len();
679
680 let is_boundary = pos == 0 || {
681 let prev = bytes[pos - 1];
682 !prev.is_ascii_alphanumeric() && prev != b'_'
683 };
684
685 if is_boundary && after < bytes.len() && bytes[after] == b'(' {
686 let args_start = &content[after + 1..];
687 if let Some(close) = find_matching_paren(args_start) {
688 let args_str = args_start[..close].trim();
689 if args_str.is_empty() {
690 return Some(0);
691 }
692 return Some(count_top_level_commas(args_str) + 1);
693 }
694 }
695
696 search_start = pos + 1;
697 while search_start < content.len() && !content.is_char_boundary(search_start) {
698 search_start += 1;
699 }
700 }
701
702 None
703}
704
/// Given the text immediately after an opening `(`, returns the byte offset
/// of the `)` that closes it, or `None` if the parentheses never balance.
///
/// Parentheses inside double-quoted or backtick-quoted string literals are
/// ignored, and a backslash escapes the next character inside a literal.
/// Single quotes are deliberately not treated as delimiters, because `'` also
/// starts Rust lifetimes (`&'a str`).
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut escaped = false;

    for (i, ch) in s.char_indices() {
        if let Some(q) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
            continue;
        }
        match ch {
            '"' | '`' => quote = Some(ch),
            '(' => depth += 1,
            ')' if depth == 0 => return Some(i),
            ')' => depth -= 1,
            _ => {}
        }
    }
    None
}
721
/// Counts the commas in `s` that sit outside every bracket pair, generic
/// angle-bracket list and string literal. For a well-formed parameter or
/// argument list this is the number of items minus one.
///
/// `(`/`[`/`{` always nest. `<` opens a generic list only when it directly
/// follows an identifier character or `:` and is not followed by whitespace,
/// `=` or `<` (so `HashMap<K, V>` and `Vec::<T>` nest, but `a < b` does not).
/// `>` closes one only while a generic list is open and it is not part of
/// `->` or `=>`. Depths never go negative.
///
/// Double-quoted and backtick-quoted literals are skipped, with backslash
/// escapes. Single quotes are not, because `'` also starts Rust lifetimes.
fn count_top_level_commas(s: &str) -> usize {
    let chars: Vec<char> = s.chars().collect();
    let mut bracket_depth = 0usize;
    let mut angle_depth = 0usize;
    let mut quote: Option<char> = None;
    let mut escaped = false;
    let mut commas = 0;

    for (i, &ch) in chars.iter().enumerate() {
        if let Some(q) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
            continue;
        }

        let prev = if i == 0 { None } else { Some(chars[i - 1]) };
        let next = chars.get(i + 1).copied();
        match ch {
            '"' | '`' => quote = Some(ch),
            '(' | '[' | '{' => bracket_depth += 1,
            ')' | ']' | '}' => bracket_depth = bracket_depth.saturating_sub(1),
            '<' if matches!(prev, Some(p) if p.is_alphanumeric() || p == '_' || p == ':')
                && matches!(next, Some(n) if !n.is_whitespace() && n != '=' && n != '<') =>
            {
                angle_depth += 1
            }
            '>' if angle_depth > 0 && !matches!(prev, Some('-') | Some('=')) => {
                angle_depth -= 1
            }
            ',' if bracket_depth == 0 && angle_depth == 0 => commas += 1,
            _ => {}
        }
    }
    commas
}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as if it were the file `path` and unwraps the arity.
    fn arity(src: &str, path: &str) -> ParamInfo {
        extract_param_info_ts(src, path).unwrap()
    }

    #[test]
    fn test_extract_param_count_basic() {
        let cases: [(&str, usize); 4] = [
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ];
        for &(signature, expected) in cases.iter() {
            assert_eq!(extract_param_count(signature), expected);
        }
    }

    #[test]
    fn test_extract_param_count_nested() {
        let signature = "function foo(a, fn(x, y), c) {";
        assert_eq!(extract_param_count(signature), 3);
    }

    #[test]
    fn test_count_call_args() {
        let cases: [(&str, &str, Option<usize>); 4] = [
            ("let x = foo(1, 2, 3);", "foo", Some(3)),
            ("foo()", "foo", Some(0)),
            ("bar(1)", "foo", None),
            ("foo(a, b)", "foo", Some(2)),
        ];
        for &(content, name, expected) in cases.iter() {
            assert_eq!(count_call_args(content, name), expected);
        }
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        let cases: [(&str, &str, Option<usize>); 3] = [
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ];
        for &(content, name, expected) in cases.iter() {
            assert_eq!(count_call_args(content, name), expected);
        }
    }

    #[test]
    fn test_extract_param_info_python() {
        let info = arity("def foo(a, b, c=3):\n pass", "test.py");
        assert_eq!(
            (info.min_params, info.max_params, info.is_variadic),
            (2, 3, false)
        );
    }

    #[test]
    fn test_extract_param_info_python_self() {
        let info = arity("def foo(self, a, b):\n pass", "test.py");
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let info = arity("def foo(a, *args, **kwargs):\n pass", "test.py");
        assert!(info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        let info = arity(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        );
        assert_eq!(
            (info.min_params, info.max_params, info.is_variadic),
            (2, 3, false)
        );
    }

    #[test]
    fn test_extract_param_info_rust() {
        let info = arity("fn foo(&self, a: i32, b: String) -> bool { true }", "test.rs");
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_extract_param_info_go() {
        let info = arity("func foo(a string, b int) error { return nil }", "test.go");
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_count_call_args_ts() {
        let source = "function bar() { foo(1, 2, 3); }";
        assert_eq!(count_call_args_ts(source, "foo", "test.ts"), Some(3));
    }

    #[test]
    fn test_count_call_args_ts_method() {
        let source = "function bar() { obj.foo(1, 2); }";
        assert_eq!(count_call_args_ts(source, "foo", "test.ts"), Some(2));
    }
}