1use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call whose textual argument count differs from the callee's textual parameter count.
///
/// Produced by `verify_contracts` and `verify_contracts_with_graph`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContractViolation {
    /// Name of the called entity.
    pub entity_name: String,
    /// File path of the called entity.
    pub file_path: String,
    /// Number of parameters counted in the callee's signature.
    pub expected_params: usize,
    /// Name of the entity that contains the call.
    pub caller_name: String,
    /// File path of the calling entity.
    pub caller_file: String,
    /// Number of arguments counted at the call site.
    pub actual_args: usize,
}
21
/// Accepted argument-count range of a function signature.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamInfo {
    /// Number of parameters that must always be supplied.
    pub min_params: usize,
    /// Number of parameters including optional/defaulted ones.
    pub max_params: usize,
    /// `true` when the signature has a rest/splat/variadic parameter.
    pub is_variadic: bool,
}
29
/// A call whose argument count falls outside the callee's accepted parameter range.
///
/// Produced by `find_arity_mismatches` and `find_broken_callers`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArityMismatch {
    /// Name of the entity that contains the call.
    pub caller_entity: String,
    /// Name of the called entity.
    pub callee_entity: String,
    /// Minimum number of arguments the callee accepts.
    pub expected_min: usize,
    /// Maximum number of arguments the callee accepts.
    pub expected_max: usize,
    /// Number of arguments counted at the call site.
    pub actual_args: usize,
    /// File path of the calling entity.
    pub file_path: String,
    /// Start line of the calling entity (not of the individual call expression).
    pub line: usize,
    /// Whether the callee is variadic; both producers in this file skip variadic
    /// callees and therefore always set this to `false`.
    pub is_variadic: bool,
}
42
43pub fn verify_contracts(
45 root: &Path,
46 file_paths: &[String],
47 registry: &ParserRegistry,
48 target_file: Option<&str>,
49) -> Vec<ContractViolation> {
50 let (graph, _) = EntityGraph::build(root, file_paths, registry);
51
52 let mut content_map: HashMap<String, String> = HashMap::new();
53 for fp in file_paths {
54 let full = root.join(fp);
55 let content = match std::fs::read_to_string(&full) {
56 Ok(c) => c,
57 Err(_) => continue,
58 };
59 for entity in registry.extract_entities(fp, &content) {
60 content_map.insert(entity.id.clone(), entity.content.clone());
61 }
62 }
63
64 let mut violations = Vec::new();
65
66 for edge in &graph.edges {
67 if edge.ref_type != RefType::Calls {
68 continue;
69 }
70
71 let callee = match graph.entities.get(&edge.to_entity) {
72 Some(e) => e,
73 None => continue,
74 };
75
76 if let Some(tf) = target_file {
77 if callee.file_path != tf {
78 continue;
79 }
80 }
81
82 if !matches!(
83 callee.entity_type.as_str(),
84 "function" | "method" | "arrow_function"
85 ) {
86 continue;
87 }
88
89 let callee_content = match content_map.get(&edge.to_entity) {
90 Some(c) => c,
91 None => continue,
92 };
93
94 let caller = match graph.entities.get(&edge.from_entity) {
95 Some(e) => e,
96 None => continue,
97 };
98
99 let caller_content = match content_map.get(&edge.from_entity) {
100 Some(c) => c,
101 None => continue,
102 };
103
104 let expected = extract_param_count(callee_content);
105 if expected == 0 {
106 continue;
107 }
108
109 if let Some(actual) = count_call_args(caller_content, &callee.name) {
110 if actual != expected {
111 violations.push(ContractViolation {
112 entity_name: callee.name.clone(),
113 file_path: callee.file_path.clone(),
114 expected_params: expected,
115 caller_name: caller.name.clone(),
116 caller_file: caller.file_path.clone(),
117 actual_args: actual,
118 });
119 }
120 }
121 }
122
123 violations
124}
125
126pub fn verify_contracts_with_graph(
128 graph: &EntityGraph,
129 all_entities: &[SemanticEntity],
130 target_file: Option<&str>,
131) -> Vec<ContractViolation> {
132 let content_map: HashMap<String, String> = all_entities
133 .iter()
134 .map(|e| (e.id.clone(), e.content.clone()))
135 .collect();
136
137 let mut violations = Vec::new();
138
139 for edge in &graph.edges {
140 if edge.ref_type != RefType::Calls {
141 continue;
142 }
143
144 let callee = match graph.entities.get(&edge.to_entity) {
145 Some(e) => e,
146 None => continue,
147 };
148
149 if let Some(tf) = target_file {
150 if callee.file_path != tf {
151 continue;
152 }
153 }
154
155 if !matches!(
156 callee.entity_type.as_str(),
157 "function" | "method" | "arrow_function"
158 ) {
159 continue;
160 }
161
162 let callee_content = match content_map.get(&edge.to_entity) {
163 Some(c) => c,
164 None => continue,
165 };
166
167 let caller = match graph.entities.get(&edge.from_entity) {
168 Some(e) => e,
169 None => continue,
170 };
171
172 let caller_content = match content_map.get(&edge.from_entity) {
173 Some(c) => c,
174 None => continue,
175 };
176
177 let expected = extract_param_count(callee_content);
178 if expected == 0 {
179 continue;
180 }
181
182 if let Some(actual) = count_call_args(caller_content, &callee.name) {
183 if actual != expected {
184 violations.push(ContractViolation {
185 entity_name: callee.name.clone(),
186 file_path: callee.file_path.clone(),
187 expected_params: expected,
188 caller_name: caller.name.clone(),
189 caller_file: caller.file_path.clone(),
190 actual_args: actual,
191 });
192 }
193 }
194 }
195
196 violations
197}
198
/// Maps a file extension (including the leading dot) to the language key used by
/// `extract_param_info_from_node`; unrecognised extensions yield `"unknown"`.
///
/// JavaScript extensions share the `"typescript"` key.
fn lang_from_ext(ext: &str) -> &'static str {
    const LANGUAGES: [(&[&str], &str); 4] = [
        (&[".py", ".pyi"], "python"),
        (
            &[".ts", ".tsx", ".mts", ".cts", ".js", ".jsx", ".mjs", ".cjs"],
            "typescript",
        ),
        (&[".rs"], "rust"),
        (&[".go"], "go"),
    ];

    LANGUAGES
        .iter()
        .find(|(extensions, _)| extensions.iter().any(|known| *known == ext))
        .map_or("unknown", |(_, lang)| *lang)
}
211
212pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
214 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
215 let lang = lang_from_ext(ext);
216 if lang == "unknown" {
217 return None;
218 }
219 let config = get_language_config(ext)?;
220 let language = (config.get_language)()?;
221
222 let mut parser = tree_sitter::Parser::new();
223 let _ = parser.set_language(&language);
224 let tree = parser.parse(content.as_bytes(), None)?;
225
226 extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
227}
228
229fn extract_param_info_from_node(
230 root: tree_sitter::Node,
231 source: &[u8],
232 lang: &str,
233) -> Option<ParamInfo> {
234 let func_node = find_first_function(root)?;
236 let params_node = func_node.child_by_field_name("parameters")?;
237
238 let mut min_params = 0usize;
239 let mut max_params = 0usize;
240 let mut is_variadic = false;
241
242 let mut cursor = params_node.walk();
243 for child in params_node.named_children(&mut cursor) {
244 let kind = child.kind();
245 match lang {
246 "python" => {
247 if kind == "identifier" {
248 let name = child.utf8_text(source).unwrap_or("");
249 if name == "self" || name == "cls" {
250 continue;
251 }
252 min_params += 1;
253 max_params += 1;
254 } else if kind == "typed_parameter" {
255 let name = child
256 .child_by_field_name("name")
257 .or_else(|| child.named_child(0))
258 .and_then(|n| n.utf8_text(source).ok())
259 .unwrap_or("");
260 if name == "self" || name == "cls" {
261 continue;
262 }
263 min_params += 1;
264 max_params += 1;
265 } else if kind == "default_parameter" || kind == "typed_default_parameter" {
266 max_params += 1;
267 } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
268 is_variadic = true;
269 }
270 }
271 "typescript" => {
272 if kind == "required_parameter" {
273 min_params += 1;
274 max_params += 1;
275 } else if kind == "optional_parameter" {
276 max_params += 1;
277 } else if kind == "rest_pattern" {
278 is_variadic = true;
279 }
280 }
281 "rust" => {
282 if kind == "parameter" {
283 let pat = child
284 .child_by_field_name("pattern")
285 .and_then(|n| n.utf8_text(source).ok())
286 .unwrap_or("");
287 let base = pat.trim_start_matches('&').trim();
289 let base = base.strip_prefix("mut ").unwrap_or(base).trim();
290 if base == "self" {
291 continue;
292 }
293 min_params += 1;
294 max_params += 1;
295 } else if kind == "self_parameter" {
296 continue;
297 }
298 }
299 "go" => {
300 if kind == "parameter_declaration" {
301 let type_text = child
303 .child_by_field_name("type")
304 .and_then(|n| n.utf8_text(source).ok())
305 .unwrap_or("");
306 if type_text.starts_with("...") {
307 is_variadic = true;
308 } else {
309 min_params += 1;
310 max_params += 1;
311 }
312 }
313 }
314 _ => {}
315 }
316 }
317
318 Some(ParamInfo {
319 min_params,
320 max_params,
321 is_variadic,
322 })
323}
324
325fn find_first_function(node: tree_sitter::Node) -> Option<tree_sitter::Node> {
326 let kind = node.kind();
327 if matches!(
328 kind,
329 "function_definition"
330 | "function_item"
331 | "function_declaration"
332 | "method_definition"
333 | "method_declaration"
334 | "arrow_function"
335 ) {
336 return Some(node);
337 }
338 let mut cursor = node.walk();
339 for child in node.named_children(&mut cursor) {
340 if let Some(f) = find_first_function(child) {
341 return Some(f);
342 }
343 }
344 None
345}
346
347pub fn count_call_args_ts(
349 caller_content: &str,
350 callee_name: &str,
351 file_path: &str,
352) -> Option<usize> {
353 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
354 let config = get_language_config(ext)?;
355 let language = (config.get_language)()?;
356
357 let mut parser = tree_sitter::Parser::new();
358 let _ = parser.set_language(&language);
359 let tree = parser.parse(caller_content.as_bytes(), None)?;
360
361 find_call_arg_count(tree.root_node(), caller_content.as_bytes(), callee_name)
362}
363
364fn find_call_arg_count(
365 node: tree_sitter::Node,
366 source: &[u8],
367 callee_name: &str,
368) -> Option<usize> {
369 let kind = node.kind();
370
371 if kind == "call" || kind == "call_expression" {
372 let func = node.child_by_field_name("function")?;
373 let func_name = match func.kind() {
374 "identifier" => func.utf8_text(source).unwrap_or(""),
375 "attribute" | "member_expression" | "field_expression" => func
376 .child_by_field_name("attribute")
377 .or_else(|| func.child_by_field_name("property"))
378 .or_else(|| func.child_by_field_name("field"))
379 .and_then(|n| n.utf8_text(source).ok())
380 .unwrap_or(""),
381 "selector_expression" => func
382 .child_by_field_name("field")
383 .and_then(|n| n.utf8_text(source).ok())
384 .unwrap_or(""),
385 "scoped_identifier" => {
386 let text = func.utf8_text(source).unwrap_or("");
387 text.rsplit("::").next().unwrap_or("")
388 }
389 _ => "",
390 };
391
392 if func_name == callee_name {
393 let args = node.child_by_field_name("arguments")?;
394 let mut count = 0;
395 let mut cursor = args.walk();
396 for child in args.named_children(&mut cursor) {
397 if !child.kind().contains("comment") {
399 count += 1;
400 }
401 }
402 return Some(count);
403 }
404 }
405
406 let mut cursor = node.walk();
407 for child in node.named_children(&mut cursor) {
408 if let Some(count) = find_call_arg_count(child, source, callee_name) {
409 return Some(count);
410 }
411 }
412 None
413}
414
415pub fn find_arity_mismatches(
417 graph: &EntityGraph,
418 all_entities: &[SemanticEntity],
419) -> Vec<ArityMismatch> {
420 let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
421 .iter()
422 .map(|e| (e.id.as_str(), e))
423 .collect();
424
425 let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
427
428 let mut mismatches = Vec::new();
429
430 for edge in &graph.edges {
431 if edge.ref_type != RefType::Calls {
432 continue;
433 }
434
435 let callee_info = match graph.entities.get(&edge.to_entity) {
436 Some(e) => e,
437 None => continue,
438 };
439
440 if !matches!(
441 callee_info.entity_type.as_str(),
442 "function" | "method" | "arrow_function"
443 ) {
444 continue;
445 }
446
447 let callee = match entity_by_id.get(edge.to_entity.as_str()) {
448 Some(e) => *e,
449 None => continue,
450 };
451
452 let caller = match entity_by_id.get(edge.from_entity.as_str()) {
453 Some(e) => *e,
454 None => continue,
455 };
456
457 let param_info = param_cache
459 .entry(callee.id.clone())
460 .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
461 .clone();
462
463 let param_info = match param_info {
464 Some(pi) => pi,
465 None => continue,
466 };
467
468 if param_info.is_variadic {
470 continue;
471 }
472
473 let actual = match count_call_args_ts(
475 &caller.content,
476 &callee.name,
477 &caller.file_path,
478 ) {
479 Some(a) => a,
480 None => continue,
481 };
482
483 if actual < param_info.min_params || actual > param_info.max_params {
484 mismatches.push(ArityMismatch {
485 caller_entity: caller.name.clone(),
486 callee_entity: callee.name.clone(),
487 expected_min: param_info.min_params,
488 expected_max: param_info.max_params,
489 actual_args: actual,
490 file_path: caller.file_path.clone(),
491 line: caller.start_line,
492 is_variadic: false,
493 });
494 }
495 }
496
497 mismatches
498}
499
500pub fn find_broken_callers(
504 old_entities: &[SemanticEntity],
505 new_graph: &EntityGraph,
506 new_entities: &[SemanticEntity],
507) -> Vec<ArityMismatch> {
508 let old_params: HashMap<String, Option<ParamInfo>> = old_entities
510 .iter()
511 .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
512 .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
513 .collect();
514
515 let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
517 .iter()
518 .map(|e| (e.id.as_str(), e))
519 .collect();
520
521 let mut changed_entities: Vec<&str> = Vec::new();
523 for new_entity in new_entities {
524 if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
525 continue;
526 }
527 let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
528 Some(pi) => pi,
529 None => continue,
530 };
531 if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
532 if old_info.min_params != new_info.min_params
533 || old_info.max_params != new_info.max_params
534 {
535 changed_entities.push(&new_entity.id);
536 }
537 }
538 }
539
540 if changed_entities.is_empty() {
541 return Vec::new();
542 }
543
544 let mut mismatches = Vec::new();
546
547 for edge in &new_graph.edges {
548 if edge.ref_type != RefType::Calls {
549 continue;
550 }
551 if !changed_entities.contains(&edge.to_entity.as_str()) {
552 continue;
553 }
554
555 let callee = match new_by_id.get(edge.to_entity.as_str()) {
556 Some(e) => *e,
557 None => continue,
558 };
559 let caller = match new_by_id.get(edge.from_entity.as_str()) {
560 Some(e) => *e,
561 None => continue,
562 };
563
564 let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
565 Some(pi) => pi,
566 None => continue,
567 };
568
569 if new_info.is_variadic {
570 continue;
571 }
572
573 let actual = match count_call_args_ts(&caller.content, &callee.name, &caller.file_path) {
574 Some(a) => a,
575 None => continue,
576 };
577
578 if actual < new_info.min_params || actual > new_info.max_params {
579 mismatches.push(ArityMismatch {
580 caller_entity: caller.name.clone(),
581 callee_entity: callee.name.clone(),
582 expected_min: new_info.min_params,
583 expected_max: new_info.max_params,
584 actual_args: actual,
585 file_path: caller.file_path.clone(),
586 line: caller.start_line,
587 is_variadic: false,
588 });
589 }
590 }
591
592 mismatches
593}
594
595fn extract_param_count(content: &str) -> usize {
599 let first_line = content.lines().next().unwrap_or("");
600
601 let open = match first_line.find('(') {
602 Some(i) => i,
603 None => return 0,
604 };
605
606 let after_open = &first_line[open + 1..];
607 let close = match find_matching_paren(after_open) {
608 Some(i) => i,
609 None => return 0,
610 };
611
612 let params_str = after_open[..close].trim();
613 if params_str.is_empty() {
614 return 0;
615 }
616
617 count_top_level_commas(params_str) + 1
618}
619
620fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
622 let bytes = content.as_bytes();
623 let name_bytes = callee_name.as_bytes();
624 let mut search_start = 0;
625
626 while let Some(rel_pos) = content[search_start..].find(callee_name) {
627 let pos = search_start + rel_pos;
628 let after = pos + name_bytes.len();
629
630 let is_boundary = pos == 0 || {
631 let prev = bytes[pos - 1];
632 !prev.is_ascii_alphanumeric() && prev != b'_'
633 };
634
635 if is_boundary && after < bytes.len() && bytes[after] == b'(' {
636 let args_start = &content[after + 1..];
637 if let Some(close) = find_matching_paren(args_start) {
638 let args_str = args_start[..close].trim();
639 if args_str.is_empty() {
640 return Some(0);
641 }
642 return Some(count_top_level_commas(args_str) + 1);
643 }
644 }
645
646 search_start = pos + 1;
647 while search_start < content.len() && !content.is_char_boundary(search_start) {
648 search_start += 1;
649 }
650 }
651
652 None
653}
654
/// Given the text that follows an opening `(`, returns the byte offset of the `)` that
/// closes it, skipping over balanced nested parentheses. Returns `None` if it is never
/// closed.
fn find_matching_paren(s: &str) -> Option<usize> {
    // Number of nested '(' currently open inside the outer pair.
    let mut open = 0usize;
    for (offset, c) in s.char_indices() {
        if c == '(' {
            open += 1;
        } else if c == ')' {
            match open.checked_sub(1) {
                Some(remaining) => open = remaining,
                // Nothing nested is open: this ')' closes the outer pair.
                None => return Some(offset),
            }
        }
    }
    None
}
671
/// Counts the commas in `s` that are not nested inside `()`, `[]`, `{}` or `<>`.
///
/// This is a textual heuristic: string literals are not recognised. A `>` that is
/// part of an arrow (`->`, `=>`) is ignored, and a closer with no matching opener
/// (for example the `>` in `a > b`) does not affect nesting, so commas that follow
/// it are still counted.
fn count_top_level_commas(s: &str) -> usize {
    let mut depth = 0usize;
    let mut count = 0;
    let mut prev: Option<char> = None;

    for ch in s.chars() {
        match ch {
            '(' | '[' | '{' | '<' => depth += 1,
            // Saturate so an unmatched closer cannot push the depth below zero.
            ')' | ']' | '}' => depth = depth.saturating_sub(1),
            // `->` and `=>` are arrows, not the end of a generic argument list.
            '>' if !matches!(prev, Some('-' | '=')) => depth = depth.saturating_sub(1),
            ',' if depth == 0 => count += 1,
            _ => {}
        }
        prev = Some(ch);
    }

    count
}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts parameter info for `src` as the language implied by `path`,
    /// panicking when nothing can be extracted.
    fn param_info(src: &str, path: &str) -> ParamInfo {
        extract_param_info_ts(src, path).unwrap()
    }

    // ---- textual heuristics ----

    #[test]
    fn test_extract_param_count_basic() {
        let cases: [(&str, usize); 4] = [
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ];
        for &(signature, expected) in cases.iter() {
            assert_eq!(extract_param_count(signature), expected);
        }
    }

    #[test]
    fn test_extract_param_count_nested() {
        // Commas inside the nested parentheses must not be counted.
        assert_eq!(extract_param_count("function foo(a, fn(x, y), c) {"), 3);
    }

    #[test]
    fn test_count_call_args() {
        let cases: [(&str, &str, Option<usize>); 4] = [
            ("let x = foo(1, 2, 3);", "foo", Some(3)),
            ("foo()", "foo", Some(0)),
            ("bar(1)", "foo", None),
            ("foo(a, b)", "foo", Some(2)),
        ];
        for &(source, callee, expected) in cases.iter() {
            assert_eq!(count_call_args(source, callee), expected);
        }
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        // Multi-byte characters before the call must not break the scan.
        let cases: [(&str, &str, Option<usize>); 3] = [
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ];
        for &(source, callee, expected) in cases.iter() {
            assert_eq!(count_call_args(source, callee), expected);
        }
    }

    // ---- tree-sitter parameter extraction ----

    #[test]
    fn test_extract_param_info_python() {
        let info = param_info("def foo(a, b, c=3):\n pass", "test.py");
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 3);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_python_self() {
        let info = param_info("def foo(self, a, b):\n pass", "test.py");
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let info = param_info("def foo(a, *args, **kwargs):\n pass", "test.py");
        assert!(info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        let info = param_info(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        );
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 3);
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_rust() {
        let info = param_info(
            "fn foo(&self, a: i32, b: String) -> bool { true }",
            "test.rs",
        );
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    #[test]
    fn test_extract_param_info_go() {
        let info = param_info(
            "func foo(a string, b int) error { return nil }",
            "test.go",
        );
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    // ---- tree-sitter call-site counting ----

    #[test]
    fn test_count_call_args_ts() {
        assert_eq!(
            count_call_args_ts("function bar() { foo(1, 2, 3); }", "foo", "test.ts"),
            Some(3)
        );
    }

    #[test]
    fn test_count_call_args_ts_method() {
        assert_eq!(
            count_call_args_ts("function bar() { obj.foo(1, 2); }", "foo", "test.ts"),
            Some(2)
        );
    }
}