1use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call whose argument count differs from the callee's text-derived
/// parameter count, as reported by [`verify_contracts`] and
/// [`verify_contracts_with_graph`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContractViolation {
    /// Name of the called function or method.
    pub entity_name: String,
    /// File that defines the callee.
    pub file_path: String,
    /// Parameter count parsed from the callee's signature.
    pub expected_params: usize,
    /// Name of the calling entity.
    pub caller_name: String,
    /// File that defines the caller.
    pub caller_file: String,
    /// Number of arguments found at the call inside the caller.
    pub actual_args: usize,
}
21
/// Parameter arity of a function, as extracted by [`extract_param_info_ts`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamInfo {
    /// Number of parameters a caller must supply (receivers such as `self`
    /// are not counted).
    pub min_params: usize,
    /// Number of parameters a caller may supply, including optional or
    /// defaulted ones but excluding any variadic tail.
    pub max_params: usize,
    /// `true` when the signature also accepts a variable number of extra
    /// arguments (e.g. `*args`, `...rest`).
    pub is_variadic: bool,
}
29
/// A call whose argument count falls outside the range accepted by the
/// callee, as reported by [`find_arity_mismatches`] and
/// [`find_broken_callers`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArityMismatch {
    /// Name of the calling entity.
    pub caller_entity: String,
    /// Name of the called function or method.
    pub callee_entity: String,
    /// Minimum number of arguments the callee requires.
    pub expected_min: usize,
    /// Maximum number of arguments the callee accepts, excluding any
    /// variadic tail.
    pub expected_max: usize,
    /// Number of arguments found at the call.
    pub actual_args: usize,
    /// File that contains the caller.
    pub file_path: String,
    /// Start line of the caller entity (not of the call expression itself).
    pub line: usize,
    /// Whether the callee accepts a variable number of arguments.
    pub is_variadic: bool,
}
42
43pub fn verify_contracts(
45 root: &Path,
46 file_paths: &[String],
47 registry: &ParserRegistry,
48 target_file: Option<&str>,
49) -> Vec<ContractViolation> {
50 let (graph, _) = EntityGraph::build(root, file_paths, registry);
51
52 let mut content_map: HashMap<String, String> = HashMap::new();
53 for fp in file_paths {
54 let full = root.join(fp);
55 let content = match std::fs::read_to_string(&full) {
56 Ok(c) => c,
57 Err(_) => continue,
58 };
59 for entity in registry.extract_entities(fp, &content) {
60 content_map.insert(entity.id.clone(), entity.content.clone());
61 }
62 }
63
64 let mut violations = Vec::new();
65
66 for edge in &graph.edges {
67 if edge.ref_type != RefType::Calls {
68 continue;
69 }
70
71 let callee = match graph.entities.get(&edge.to_entity) {
72 Some(e) => e,
73 None => continue,
74 };
75
76 if let Some(tf) = target_file {
77 if callee.file_path != tf {
78 continue;
79 }
80 }
81
82 if !matches!(
83 callee.entity_type.as_str(),
84 "function" | "method" | "arrow_function"
85 ) {
86 continue;
87 }
88
89 let callee_content = match content_map.get(&edge.to_entity) {
90 Some(c) => c,
91 None => continue,
92 };
93
94 let caller = match graph.entities.get(&edge.from_entity) {
95 Some(e) => e,
96 None => continue,
97 };
98
99 let caller_content = match content_map.get(&edge.from_entity) {
100 Some(c) => c,
101 None => continue,
102 };
103
104 let expected = extract_param_count(callee_content);
105 if expected == 0 {
106 continue;
107 }
108
109 if let Some(actual) = count_call_args(caller_content, &callee.name) {
110 if actual != expected {
111 violations.push(ContractViolation {
112 entity_name: callee.name.clone(),
113 file_path: callee.file_path.clone(),
114 expected_params: expected,
115 caller_name: caller.name.clone(),
116 caller_file: caller.file_path.clone(),
117 actual_args: actual,
118 });
119 }
120 }
121 }
122
123 violations
124}
125
126pub fn verify_contracts_with_graph(
128 graph: &EntityGraph,
129 all_entities: &[SemanticEntity],
130 target_file: Option<&str>,
131) -> Vec<ContractViolation> {
132 let content_map: HashMap<String, String> = all_entities
133 .iter()
134 .map(|e| (e.id.clone(), e.content.clone()))
135 .collect();
136
137 let mut violations = Vec::new();
138
139 for edge in &graph.edges {
140 if edge.ref_type != RefType::Calls {
141 continue;
142 }
143
144 let callee = match graph.entities.get(&edge.to_entity) {
145 Some(e) => e,
146 None => continue,
147 };
148
149 if let Some(tf) = target_file {
150 if callee.file_path != tf {
151 continue;
152 }
153 }
154
155 if !matches!(
156 callee.entity_type.as_str(),
157 "function" | "method" | "arrow_function"
158 ) {
159 continue;
160 }
161
162 let callee_content = match content_map.get(&edge.to_entity) {
163 Some(c) => c,
164 None => continue,
165 };
166
167 let caller = match graph.entities.get(&edge.from_entity) {
168 Some(e) => e,
169 None => continue,
170 };
171
172 let caller_content = match content_map.get(&edge.from_entity) {
173 Some(c) => c,
174 None => continue,
175 };
176
177 let expected = extract_param_count(callee_content);
178 if expected == 0 {
179 continue;
180 }
181
182 if let Some(actual) = count_call_args(caller_content, &callee.name) {
183 if actual != expected {
184 violations.push(ContractViolation {
185 entity_name: callee.name.clone(),
186 file_path: callee.file_path.clone(),
187 expected_params: expected,
188 caller_name: caller.name.clone(),
189 caller_file: caller.file_path.clone(),
190 actual_args: actual,
191 });
192 }
193 }
194 }
195
196 violations
197}
198
/// Maps a file extension, including its leading dot (e.g. `".rs"`), to the
/// language family understood by the parameter extractor.
///
/// JavaScript extensions share the `"typescript"` family. Anything not in the
/// table, including an extension without the dot, yields `"unknown"`.
fn lang_from_ext(ext: &str) -> &'static str {
    const FAMILIES: &[(&[&str], &str)] = &[
        (&[".py", ".pyi"], "python"),
        (
            &[".ts", ".tsx", ".mts", ".cts", ".js", ".jsx", ".mjs", ".cjs"],
            "typescript",
        ),
        (&[".rs"], "rust"),
        (&[".go"], "go"),
    ];

    FAMILIES
        .iter()
        .find(|(extensions, _)| extensions.iter().any(|&candidate| candidate == ext))
        .map_or("unknown", |&(_, family)| family)
}
211
212pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
214 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
215 let lang = lang_from_ext(ext);
216 if lang == "unknown" {
217 return None;
218 }
219 let config = get_language_config(ext)?;
220 let language = (config.get_language)()?;
221
222 let mut parser = tree_sitter::Parser::new();
223 let _ = parser.set_language(&language);
224 let tree = parser.parse(content.as_bytes(), None)?;
225
226 extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
227}
228
229fn extract_param_info_from_node(
230 root: tree_sitter::Node,
231 source: &[u8],
232 lang: &str,
233) -> Option<ParamInfo> {
234 let func_node = find_first_function(root)?;
236 let params_node = func_node.child_by_field_name("parameters")?;
237
238 let mut min_params = 0usize;
239 let mut max_params = 0usize;
240 let mut is_variadic = false;
241
242 let mut cursor = params_node.walk();
243 for child in params_node.named_children(&mut cursor) {
244 let kind = child.kind();
245 match lang {
246 "python" => {
247 if kind == "identifier" {
248 let name = child.utf8_text(source).unwrap_or("");
249 if name == "self" || name == "cls" {
250 continue;
251 }
252 min_params += 1;
253 max_params += 1;
254 } else if kind == "typed_parameter" {
255 let name = child
256 .child_by_field_name("name")
257 .or_else(|| child.named_child(0))
258 .and_then(|n| n.utf8_text(source).ok())
259 .unwrap_or("");
260 if name == "self" || name == "cls" {
261 continue;
262 }
263 min_params += 1;
264 max_params += 1;
265 } else if kind == "default_parameter" || kind == "typed_default_parameter" {
266 max_params += 1;
267 } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
268 is_variadic = true;
269 }
270 }
271 "typescript" => {
272 if kind == "required_parameter" {
273 min_params += 1;
274 max_params += 1;
275 } else if kind == "optional_parameter" {
276 max_params += 1;
277 } else if kind == "rest_pattern" {
278 is_variadic = true;
279 }
280 }
281 "rust" => {
282 if kind == "parameter" {
283 let pat = child
284 .child_by_field_name("pattern")
285 .and_then(|n| n.utf8_text(source).ok())
286 .unwrap_or("");
287 let base = pat.trim_start_matches('&').trim();
289 let base = base.strip_prefix("mut ").unwrap_or(base).trim();
290 if base == "self" {
291 continue;
292 }
293 min_params += 1;
294 max_params += 1;
295 } else if kind == "self_parameter" {
296 continue;
297 }
298 }
299 "go" => {
300 if kind == "parameter_declaration" {
301 let type_text = child
303 .child_by_field_name("type")
304 .and_then(|n| n.utf8_text(source).ok())
305 .unwrap_or("");
306 if type_text.starts_with("...") {
307 is_variadic = true;
308 } else {
309 min_params += 1;
310 max_params += 1;
311 }
312 }
313 }
314 _ => {}
315 }
316 }
317
318 Some(ParamInfo {
319 min_params,
320 max_params,
321 is_variadic,
322 })
323}
324
325fn find_first_function(root: tree_sitter::Node) -> Option<tree_sitter::Node> {
326 let mut worklist = vec![root];
327 while let Some(node) = worklist.pop() {
328 let kind = node.kind();
329 if matches!(
330 kind,
331 "function_definition"
332 | "function_item"
333 | "function_declaration"
334 | "method_definition"
335 | "method_declaration"
336 | "arrow_function"
337 ) {
338 return Some(node);
339 }
340 let mut cursor = node.walk();
341 let children: Vec<_> = node.named_children(&mut cursor).collect();
342 for child in children.into_iter().rev() {
343 worklist.push(child);
344 }
345 }
346 None
347}
348
349pub fn count_call_args_ts(
351 caller_content: &str,
352 callee_name: &str,
353 file_path: &str,
354) -> Option<usize> {
355 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
356 let config = get_language_config(ext)?;
357 let language = (config.get_language)()?;
358
359 let mut parser = tree_sitter::Parser::new();
360 let _ = parser.set_language(&language);
361 let tree = parser.parse(caller_content.as_bytes(), None)?;
362
363 find_call_arg_count(tree.root_node(), caller_content.as_bytes(), callee_name)
364}
365
366fn find_call_arg_count(
367 root: tree_sitter::Node,
368 source: &[u8],
369 callee_name: &str,
370) -> Option<usize> {
371 let mut worklist = vec![root];
372 while let Some(node) = worklist.pop() {
373 let kind = node.kind();
374
375 if kind == "call" || kind == "call_expression" {
376 if let Some(func) = node.child_by_field_name("function") {
377 let func_name = match func.kind() {
378 "identifier" => func.utf8_text(source).unwrap_or(""),
379 "attribute" | "member_expression" | "field_expression" => func
380 .child_by_field_name("attribute")
381 .or_else(|| func.child_by_field_name("property"))
382 .or_else(|| func.child_by_field_name("field"))
383 .and_then(|n| n.utf8_text(source).ok())
384 .unwrap_or(""),
385 "selector_expression" => func
386 .child_by_field_name("field")
387 .and_then(|n| n.utf8_text(source).ok())
388 .unwrap_or(""),
389 "scoped_identifier" => {
390 let text = func.utf8_text(source).unwrap_or("");
391 text.rsplit("::").next().unwrap_or("")
392 }
393 _ => "",
394 };
395
396 if func_name == callee_name {
397 if let Some(args) = node.child_by_field_name("arguments") {
398 let mut count = 0;
399 let mut cursor = args.walk();
400 for child in args.named_children(&mut cursor) {
401 if !child.kind().contains("comment") {
403 count += 1;
404 }
405 }
406 return Some(count);
407 }
408 }
409 }
410 }
411
412 let mut cursor = node.walk();
413 let children: Vec<_> = node.named_children(&mut cursor).collect();
414 for child in children.into_iter().rev() {
415 worklist.push(child);
416 }
417 }
418 None
419}
420
421pub fn find_arity_mismatches(
423 graph: &EntityGraph,
424 all_entities: &[SemanticEntity],
425) -> Vec<ArityMismatch> {
426 let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
427 .iter()
428 .map(|e| (e.id.as_str(), e))
429 .collect();
430
431 let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
433
434 let mut mismatches = Vec::new();
435
436 for edge in &graph.edges {
437 if edge.ref_type != RefType::Calls {
438 continue;
439 }
440
441 let callee_info = match graph.entities.get(&edge.to_entity) {
442 Some(e) => e,
443 None => continue,
444 };
445
446 if !matches!(
447 callee_info.entity_type.as_str(),
448 "function" | "method" | "arrow_function"
449 ) {
450 continue;
451 }
452
453 let callee = match entity_by_id.get(edge.to_entity.as_str()) {
454 Some(e) => *e,
455 None => continue,
456 };
457
458 let caller = match entity_by_id.get(edge.from_entity.as_str()) {
459 Some(e) => *e,
460 None => continue,
461 };
462
463 let param_info = param_cache
465 .entry(callee.id.clone())
466 .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
467 .clone();
468
469 let param_info = match param_info {
470 Some(pi) => pi,
471 None => continue,
472 };
473
474 if param_info.is_variadic {
476 continue;
477 }
478
479 let actual = match count_call_args_ts(
481 &caller.content,
482 &callee.name,
483 &caller.file_path,
484 ) {
485 Some(a) => a,
486 None => continue,
487 };
488
489 if actual < param_info.min_params || actual > param_info.max_params {
490 mismatches.push(ArityMismatch {
491 caller_entity: caller.name.clone(),
492 callee_entity: callee.name.clone(),
493 expected_min: param_info.min_params,
494 expected_max: param_info.max_params,
495 actual_args: actual,
496 file_path: caller.file_path.clone(),
497 line: caller.start_line,
498 is_variadic: false,
499 });
500 }
501 }
502
503 mismatches
504}
505
506pub fn find_broken_callers(
510 old_entities: &[SemanticEntity],
511 new_graph: &EntityGraph,
512 new_entities: &[SemanticEntity],
513) -> Vec<ArityMismatch> {
514 let old_params: HashMap<String, Option<ParamInfo>> = old_entities
516 .iter()
517 .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
518 .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
519 .collect();
520
521 let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
523 .iter()
524 .map(|e| (e.id.as_str(), e))
525 .collect();
526
527 let mut changed_entities: Vec<&str> = Vec::new();
529 for new_entity in new_entities {
530 if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
531 continue;
532 }
533 let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
534 Some(pi) => pi,
535 None => continue,
536 };
537 if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
538 if old_info.min_params != new_info.min_params
539 || old_info.max_params != new_info.max_params
540 {
541 changed_entities.push(&new_entity.id);
542 }
543 }
544 }
545
546 if changed_entities.is_empty() {
547 return Vec::new();
548 }
549
550 let mut mismatches = Vec::new();
552
553 for edge in &new_graph.edges {
554 if edge.ref_type != RefType::Calls {
555 continue;
556 }
557 if !changed_entities.contains(&edge.to_entity.as_str()) {
558 continue;
559 }
560
561 let callee = match new_by_id.get(edge.to_entity.as_str()) {
562 Some(e) => *e,
563 None => continue,
564 };
565 let caller = match new_by_id.get(edge.from_entity.as_str()) {
566 Some(e) => *e,
567 None => continue,
568 };
569
570 let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
571 Some(pi) => pi,
572 None => continue,
573 };
574
575 if new_info.is_variadic {
576 continue;
577 }
578
579 let actual = match count_call_args_ts(&caller.content, &callee.name, &caller.file_path) {
580 Some(a) => a,
581 None => continue,
582 };
583
584 if actual < new_info.min_params || actual > new_info.max_params {
585 mismatches.push(ArityMismatch {
586 caller_entity: caller.name.clone(),
587 callee_entity: callee.name.clone(),
588 expected_min: new_info.min_params,
589 expected_max: new_info.max_params,
590 actual_args: actual,
591 file_path: caller.file_path.clone(),
592 line: caller.start_line,
593 is_variadic: false,
594 });
595 }
596 }
597
598 mismatches
599}
600
601fn extract_param_count(content: &str) -> usize {
605 let first_line = content.lines().next().unwrap_or("");
606
607 let open = match first_line.find('(') {
608 Some(i) => i,
609 None => return 0,
610 };
611
612 let after_open = &first_line[open + 1..];
613 let close = match find_matching_paren(after_open) {
614 Some(i) => i,
615 None => return 0,
616 };
617
618 let params_str = after_open[..close].trim();
619 if params_str.is_empty() {
620 return 0;
621 }
622
623 count_top_level_commas(params_str) + 1
624}
625
626fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
628 let bytes = content.as_bytes();
629 let name_bytes = callee_name.as_bytes();
630 let mut search_start = 0;
631
632 while let Some(rel_pos) = content[search_start..].find(callee_name) {
633 let pos = search_start + rel_pos;
634 let after = pos + name_bytes.len();
635
636 let is_boundary = pos == 0 || {
637 let prev = bytes[pos - 1];
638 !prev.is_ascii_alphanumeric() && prev != b'_'
639 };
640
641 if is_boundary && after < bytes.len() && bytes[after] == b'(' {
642 let args_start = &content[after + 1..];
643 if let Some(close) = find_matching_paren(args_start) {
644 let args_str = args_start[..close].trim();
645 if args_str.is_empty() {
646 return Some(0);
647 }
648 return Some(count_top_level_commas(args_str) + 1);
649 }
650 }
651
652 search_start = pos + 1;
653 while search_start < content.len() && !content.is_char_boundary(search_start) {
654 search_start += 1;
655 }
656 }
657
658 None
659}
660
/// Feeds `visit` each `(byte_offset, char)` of `s` that lies outside a string
/// literal. Scanning stops as soon as `visit` returns `false`.
///
/// The following are skipped entirely, quotes included:
/// * double-quoted strings `"..."`, honouring backslash escapes;
/// * backtick strings (JavaScript templates, Go raw strings), with no escape
///   processing;
/// * the char literals `'"'` and `'\"'`, so that a quoted double quote does
///   not open a string.
///
/// Other single-quoted text is *not* treated as a string. In Rust, `'` also
/// starts lifetimes (`&'a str`), which would otherwise swallow the rest of a
/// signature.
fn scan_code_chars(s: &str, mut visit: impl FnMut(usize, char) -> bool) {
    let mut chars = s.char_indices();
    while let Some((offset, ch)) = chars.next() {
        match ch {
            '"' => {
                let mut escaped = false;
                for (_, c) in chars.by_ref() {
                    if escaped {
                        escaped = false;
                    } else if c == '\\' {
                        escaped = true;
                    } else if c == '"' {
                        break;
                    }
                }
            }
            '`' => {
                for (_, c) in chars.by_ref() {
                    if c == '`' {
                        break;
                    }
                }
            }
            // `'` is one byte, so `offset + 1` is always a char boundary.
            '\'' if s[offset + 1..].starts_with("\"'") => {
                chars.nth(1);
            }
            '\'' if s[offset + 1..].starts_with("\\\"'") => {
                chars.nth(2);
            }
            _ => {
                if !visit(offset, ch) {
                    return;
                }
            }
        }
    }
}

/// Returns the byte offset in `s` of the `)` that closes a parenthesis opened
/// just before `s`, or `None` if it is never closed.
///
/// Nested `(`/`)` pairs are balanced, and parentheses inside string literals
/// are ignored.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut close = None;
    scan_code_chars(s, |offset, ch| {
        match ch {
            '(' => depth += 1,
            ')' if depth == 0 => {
                close = Some(offset);
                return false;
            }
            ')' => depth -= 1,
            _ => {}
        }
        true
    });
    close
}

/// Counts the commas in `s` that separate top-level items. A comma is not
/// top-level when it is inside `()`, `[]`, `{}`, generic angle brackets, or a
/// string literal.
///
/// `<` opens an angle-bracket group only when it directly follows an
/// identifier character or `:` (`Vec<T>`, `::<T>`). Spaced comparisons such as
/// `a < b` therefore do not hide later commas. `>` closes a group only when
/// one is open and the `>` is not part of `->` or `=>`.
fn count_top_level_commas(s: &str) -> usize {
    let mut depth = 0i32;
    let mut angle_depth = 0usize;
    let mut prev = ' ';
    let mut count = 0;
    scan_code_chars(s, |_, ch| {
        match ch {
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth -= 1,
            '<' if prev.is_alphanumeric() || prev == '_' || prev == ':' => angle_depth += 1,
            '>' if angle_depth > 0 && prev != '-' && prev != '=' => angle_depth -= 1,
            ',' if depth == 0 && angle_depth == 0 => count += 1,
            _ => {}
        }
        prev = ch;
        true
    });
    count
}
691
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extract_param_info_ts` on `src` as if it came from `path`,
    /// panicking when no arity can be extracted.
    fn parse_params(src: &str, path: &str) -> ParamInfo {
        extract_param_info_ts(src, path).unwrap()
    }

    #[test]
    fn test_extract_param_count_basic() {
        let cases = [
            ("function foo(a, b, c) {", 3),
            ("function foo() {", 0),
            ("def bar(self, x):", 2),
            ("fn baz(a: i32) -> bool {", 1),
        ];
        for (signature, expected) in cases {
            assert_eq!(extract_param_count(signature), expected, "{signature}");
        }
    }

    #[test]
    fn test_extract_param_count_nested() {
        let signature = "function foo(a, fn(x, y), c) {";
        assert_eq!(extract_param_count(signature), 3);
    }

    #[test]
    fn test_count_call_args() {
        let cases = [
            ("let x = foo(1, 2, 3);", "foo", Some(3)),
            ("foo()", "foo", Some(0)),
            ("bar(1)", "foo", None),
            ("foo(a, b)", "foo", Some(2)),
        ];
        for (content, callee, expected) in cases {
            assert_eq!(count_call_args(content, callee), expected, "{content}");
        }
    }

    #[test]
    fn test_count_call_args_multibyte_utf8() {
        let cases = [
            ("let café = foo(1, 2);", "foo", Some(2)),
            ("let É = 1; bar(x)", "bar", Some(1)),
            ("// 日本語コメント\nfoo(a, b, c)", "foo", Some(3)),
        ];
        for (content, callee, expected) in cases {
            assert_eq!(count_call_args(content, callee), expected, "{content}");
        }
    }

    #[test]
    fn test_extract_param_info_python() {
        let info = parse_params("def foo(a, b, c=3):\n pass", "test.py");
        assert_eq!((info.min_params, info.max_params), (2, 3));
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_python_self() {
        let info = parse_params("def foo(self, a, b):\n pass", "test.py");
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_extract_param_info_python_variadic() {
        let info = parse_params("def foo(a, *args, **kwargs):\n pass", "test.py");
        assert!(info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_typescript() {
        let info = parse_params(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        );
        assert_eq!((info.min_params, info.max_params), (2, 3));
        assert!(!info.is_variadic);
    }

    #[test]
    fn test_extract_param_info_rust() {
        let info = parse_params(
            "fn foo(&self, a: i32, b: String) -> bool { true }",
            "test.rs",
        );
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_extract_param_info_go() {
        let info = parse_params(
            "func foo(a string, b int) error { return nil }",
            "test.go",
        );
        assert_eq!((info.min_params, info.max_params), (2, 2));
    }

    #[test]
    fn test_count_call_args_ts() {
        let caller = "function bar() { foo(1, 2, 3); }";
        assert_eq!(count_call_args_ts(caller, "foo", "test.ts"), Some(3));
    }

    #[test]
    fn test_count_call_args_ts_method() {
        let caller = "function bar() { obj.foo(1, 2); }";
        assert_eq!(count_call_args_ts(caller, "foo", "test.ts"), Some(2));
    }
}