1use std::collections::HashMap;
5use std::path::Path;
6
7use crate::model::entity::SemanticEntity;
8use crate::parser::graph::{EntityGraph, RefType};
9use crate::parser::plugins::code::languages::get_language_config;
10use crate::parser::registry::ParserRegistry;
11
/// A call site whose argument count disagrees with the callee's parameter
/// count, as reported by the text-based checks in [`verify_contracts`] and
/// [`verify_contracts_with_graph`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContractViolation {
    /// Name of the called entity.
    pub entity_name: String,
    /// File that defines the called entity.
    pub file_path: String,
    /// Parameter count read from the callee's signature text.
    pub expected_params: usize,
    /// Name of the entity that contains the call.
    pub caller_name: String,
    /// File that defines the calling entity.
    pub caller_file: String,
    /// Number of arguments found at the call site.
    pub actual_args: usize,
}
21
/// Parameter arity of a function signature, extracted from its syntax tree.
///
/// Receiver parameters (`self`, `cls`, Rust `self`) are not counted because
/// callers never pass them as explicit arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamInfo {
    /// Number of parameters every call must supply.
    pub min_params: usize,
    /// Number of non-variadic parameters (required plus optional/defaulted).
    pub max_params: usize,
    /// True when the signature also accepts a variable number of arguments
    /// (`*args`/`**kwargs`, `...rest`, Go `...T`); arity checks skip such
    /// signatures.
    pub is_variadic: bool,
}
29
/// A call whose argument count falls outside the callee's accepted
/// `[expected_min, expected_max]` range, as reported by
/// [`find_arity_mismatches`] and [`find_broken_callers`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArityMismatch {
    /// Name of the entity that contains the call.
    pub caller_entity: String,
    /// Name of the called entity.
    pub callee_entity: String,
    /// Minimum number of arguments the callee accepts.
    pub expected_min: usize,
    /// Maximum number of arguments the callee accepts.
    pub expected_max: usize,
    /// Number of arguments found at the call site.
    pub actual_args: usize,
    /// File that contains the caller.
    pub file_path: String,
    /// Start line of the calling entity (not of the call expression itself).
    pub line: usize,
    /// Whether the callee is variadic. Variadic callees are skipped by the
    /// current checks, so reported mismatches carry `false`.
    pub is_variadic: bool,
}
42
43pub fn verify_contracts(
45 root: &Path,
46 file_paths: &[String],
47 registry: &ParserRegistry,
48 target_file: Option<&str>,
49) -> Vec<ContractViolation> {
50 let (graph, _) = EntityGraph::build(root, file_paths, registry);
51
52 let mut content_map: HashMap<String, String> = HashMap::new();
53 for fp in file_paths {
54 let full = root.join(fp);
55 let content = match std::fs::read_to_string(&full) {
56 Ok(c) => c,
57 Err(_) => continue,
58 };
59 let plugin = match registry.get_plugin_with_content(fp, &content) {
60 Some(p) => p,
61 None => continue,
62 };
63 for entity in plugin.extract_entities(&content, fp) {
64 content_map.insert(entity.id.clone(), entity.content.clone());
65 }
66 }
67
68 let mut violations = Vec::new();
69
70 for edge in &graph.edges {
71 if edge.ref_type != RefType::Calls {
72 continue;
73 }
74
75 let callee = match graph.entities.get(&edge.to_entity) {
76 Some(e) => e,
77 None => continue,
78 };
79
80 if let Some(tf) = target_file {
81 if callee.file_path != tf {
82 continue;
83 }
84 }
85
86 if !matches!(
87 callee.entity_type.as_str(),
88 "function" | "method" | "arrow_function"
89 ) {
90 continue;
91 }
92
93 let callee_content = match content_map.get(&edge.to_entity) {
94 Some(c) => c,
95 None => continue,
96 };
97
98 let caller = match graph.entities.get(&edge.from_entity) {
99 Some(e) => e,
100 None => continue,
101 };
102
103 let caller_content = match content_map.get(&edge.from_entity) {
104 Some(c) => c,
105 None => continue,
106 };
107
108 let expected = extract_param_count(callee_content);
109 if expected == 0 {
110 continue;
111 }
112
113 if let Some(actual) = count_call_args(caller_content, &callee.name) {
114 if actual != expected {
115 violations.push(ContractViolation {
116 entity_name: callee.name.clone(),
117 file_path: callee.file_path.clone(),
118 expected_params: expected,
119 caller_name: caller.name.clone(),
120 caller_file: caller.file_path.clone(),
121 actual_args: actual,
122 });
123 }
124 }
125 }
126
127 violations
128}
129
130pub fn verify_contracts_with_graph(
132 graph: &EntityGraph,
133 all_entities: &[SemanticEntity],
134 target_file: Option<&str>,
135) -> Vec<ContractViolation> {
136 let content_map: HashMap<String, String> = all_entities
137 .iter()
138 .map(|e| (e.id.clone(), e.content.clone()))
139 .collect();
140
141 let mut violations = Vec::new();
142
143 for edge in &graph.edges {
144 if edge.ref_type != RefType::Calls {
145 continue;
146 }
147
148 let callee = match graph.entities.get(&edge.to_entity) {
149 Some(e) => e,
150 None => continue,
151 };
152
153 if let Some(tf) = target_file {
154 if callee.file_path != tf {
155 continue;
156 }
157 }
158
159 if !matches!(
160 callee.entity_type.as_str(),
161 "function" | "method" | "arrow_function"
162 ) {
163 continue;
164 }
165
166 let callee_content = match content_map.get(&edge.to_entity) {
167 Some(c) => c,
168 None => continue,
169 };
170
171 let caller = match graph.entities.get(&edge.from_entity) {
172 Some(e) => e,
173 None => continue,
174 };
175
176 let caller_content = match content_map.get(&edge.from_entity) {
177 Some(c) => c,
178 None => continue,
179 };
180
181 let expected = extract_param_count(callee_content);
182 if expected == 0 {
183 continue;
184 }
185
186 if let Some(actual) = count_call_args(caller_content, &callee.name) {
187 if actual != expected {
188 violations.push(ContractViolation {
189 entity_name: callee.name.clone(),
190 file_path: callee.file_path.clone(),
191 expected_params: expected,
192 caller_name: caller.name.clone(),
193 caller_file: caller.file_path.clone(),
194 actual_args: actual,
195 });
196 }
197 }
198 }
199
200 violations
201}
202
/// Maps a file extension (including its leading dot, e.g. `".py"`) to the
/// language family whose parameter rules [`extract_param_info_from_node`]
/// applies.
///
/// JavaScript extensions share the `"typescript"` family. Anything not listed
/// yields `"unknown"`.
fn lang_from_ext(ext: &str) -> &'static str {
    const FAMILIES: &[(&[&str], &str)] = &[
        (&[".py", ".pyi"], "python"),
        (
            &[".ts", ".tsx", ".mts", ".cts", ".js", ".jsx", ".mjs", ".cjs"],
            "typescript",
        ),
        (&[".rs"], "rust"),
        (&[".go"], "go"),
    ];

    FAMILIES
        .iter()
        .find(|(extensions, _)| extensions.contains(&ext))
        .map_or("unknown", |&(_, family)| family)
}
215
216pub fn extract_param_info_ts(content: &str, file_path: &str) -> Option<ParamInfo> {
218 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
219 let lang = lang_from_ext(ext);
220 if lang == "unknown" {
221 return None;
222 }
223 let config = get_language_config(ext)?;
224 let language = (config.get_language)()?;
225
226 let mut parser = tree_sitter::Parser::new();
227 let _ = parser.set_language(&language);
228 let tree = parser.parse(content.as_bytes(), None)?;
229
230 extract_param_info_from_node(tree.root_node(), content.as_bytes(), lang)
231}
232
233fn extract_param_info_from_node(
234 root: tree_sitter::Node,
235 source: &[u8],
236 lang: &str,
237) -> Option<ParamInfo> {
238 let func_node = find_first_function(root)?;
240 let params_node = func_node.child_by_field_name("parameters")?;
241
242 let mut min_params = 0usize;
243 let mut max_params = 0usize;
244 let mut is_variadic = false;
245
246 let mut cursor = params_node.walk();
247 for child in params_node.named_children(&mut cursor) {
248 let kind = child.kind();
249 match lang {
250 "python" => {
251 if kind == "identifier" {
252 let name = child.utf8_text(source).unwrap_or("");
253 if name == "self" || name == "cls" {
254 continue;
255 }
256 min_params += 1;
257 max_params += 1;
258 } else if kind == "typed_parameter" {
259 let name = child
260 .child_by_field_name("name")
261 .or_else(|| child.named_child(0))
262 .and_then(|n| n.utf8_text(source).ok())
263 .unwrap_or("");
264 if name == "self" || name == "cls" {
265 continue;
266 }
267 min_params += 1;
268 max_params += 1;
269 } else if kind == "default_parameter" || kind == "typed_default_parameter" {
270 max_params += 1;
271 } else if kind == "list_splat_pattern" || kind == "dictionary_splat_pattern" {
272 is_variadic = true;
273 }
274 }
275 "typescript" => {
276 if kind == "required_parameter" {
277 min_params += 1;
278 max_params += 1;
279 } else if kind == "optional_parameter" {
280 max_params += 1;
281 } else if kind == "rest_pattern" {
282 is_variadic = true;
283 }
284 }
285 "rust" => {
286 if kind == "parameter" {
287 let pat = child
288 .child_by_field_name("pattern")
289 .and_then(|n| n.utf8_text(source).ok())
290 .unwrap_or("");
291 let base = pat.trim_start_matches('&').trim();
293 let base = base.strip_prefix("mut ").unwrap_or(base).trim();
294 if base == "self" {
295 continue;
296 }
297 min_params += 1;
298 max_params += 1;
299 } else if kind == "self_parameter" {
300 continue;
301 }
302 }
303 "go" => {
304 if kind == "parameter_declaration" {
305 let type_text = child
307 .child_by_field_name("type")
308 .and_then(|n| n.utf8_text(source).ok())
309 .unwrap_or("");
310 if type_text.starts_with("...") {
311 is_variadic = true;
312 } else {
313 min_params += 1;
314 max_params += 1;
315 }
316 }
317 }
318 _ => {}
319 }
320 }
321
322 Some(ParamInfo {
323 min_params,
324 max_params,
325 is_variadic,
326 })
327}
328
329fn find_first_function(node: tree_sitter::Node) -> Option<tree_sitter::Node> {
330 let kind = node.kind();
331 if matches!(
332 kind,
333 "function_definition"
334 | "function_item"
335 | "function_declaration"
336 | "method_definition"
337 | "method_declaration"
338 | "arrow_function"
339 ) {
340 return Some(node);
341 }
342 let mut cursor = node.walk();
343 for child in node.named_children(&mut cursor) {
344 if let Some(f) = find_first_function(child) {
345 return Some(f);
346 }
347 }
348 None
349}
350
351pub fn count_call_args_ts(
353 caller_content: &str,
354 callee_name: &str,
355 file_path: &str,
356) -> Option<usize> {
357 let ext = file_path.rfind('.').map(|i| &file_path[i..])?;
358 let config = get_language_config(ext)?;
359 let language = (config.get_language)()?;
360
361 let mut parser = tree_sitter::Parser::new();
362 let _ = parser.set_language(&language);
363 let tree = parser.parse(caller_content.as_bytes(), None)?;
364
365 find_call_arg_count(tree.root_node(), caller_content.as_bytes(), callee_name)
366}
367
368fn find_call_arg_count(
369 node: tree_sitter::Node,
370 source: &[u8],
371 callee_name: &str,
372) -> Option<usize> {
373 let kind = node.kind();
374
375 if kind == "call" || kind == "call_expression" {
376 let func = node.child_by_field_name("function")?;
377 let func_name = match func.kind() {
378 "identifier" => func.utf8_text(source).unwrap_or(""),
379 "attribute" | "member_expression" | "field_expression" => func
380 .child_by_field_name("attribute")
381 .or_else(|| func.child_by_field_name("property"))
382 .or_else(|| func.child_by_field_name("field"))
383 .and_then(|n| n.utf8_text(source).ok())
384 .unwrap_or(""),
385 "selector_expression" => func
386 .child_by_field_name("field")
387 .and_then(|n| n.utf8_text(source).ok())
388 .unwrap_or(""),
389 "scoped_identifier" => {
390 let text = func.utf8_text(source).unwrap_or("");
391 text.rsplit("::").next().unwrap_or("")
392 }
393 _ => "",
394 };
395
396 if func_name == callee_name {
397 let args = node.child_by_field_name("arguments")?;
398 let mut count = 0;
399 let mut cursor = args.walk();
400 for child in args.named_children(&mut cursor) {
401 if !child.kind().contains("comment") {
403 count += 1;
404 }
405 }
406 return Some(count);
407 }
408 }
409
410 let mut cursor = node.walk();
411 for child in node.named_children(&mut cursor) {
412 if let Some(count) = find_call_arg_count(child, source, callee_name) {
413 return Some(count);
414 }
415 }
416 None
417}
418
419pub fn find_arity_mismatches(
421 graph: &EntityGraph,
422 all_entities: &[SemanticEntity],
423) -> Vec<ArityMismatch> {
424 let entity_by_id: HashMap<&str, &SemanticEntity> = all_entities
425 .iter()
426 .map(|e| (e.id.as_str(), e))
427 .collect();
428
429 let mut param_cache: HashMap<String, Option<ParamInfo>> = HashMap::new();
431
432 let mut mismatches = Vec::new();
433
434 for edge in &graph.edges {
435 if edge.ref_type != RefType::Calls {
436 continue;
437 }
438
439 let callee_info = match graph.entities.get(&edge.to_entity) {
440 Some(e) => e,
441 None => continue,
442 };
443
444 if !matches!(
445 callee_info.entity_type.as_str(),
446 "function" | "method" | "arrow_function"
447 ) {
448 continue;
449 }
450
451 let callee = match entity_by_id.get(edge.to_entity.as_str()) {
452 Some(e) => *e,
453 None => continue,
454 };
455
456 let caller = match entity_by_id.get(edge.from_entity.as_str()) {
457 Some(e) => *e,
458 None => continue,
459 };
460
461 let param_info = param_cache
463 .entry(callee.id.clone())
464 .or_insert_with(|| extract_param_info_ts(&callee.content, &callee.file_path))
465 .clone();
466
467 let param_info = match param_info {
468 Some(pi) => pi,
469 None => continue,
470 };
471
472 if param_info.is_variadic {
474 continue;
475 }
476
477 let actual = match count_call_args_ts(
479 &caller.content,
480 &callee.name,
481 &caller.file_path,
482 ) {
483 Some(a) => a,
484 None => continue,
485 };
486
487 if actual < param_info.min_params || actual > param_info.max_params {
488 mismatches.push(ArityMismatch {
489 caller_entity: caller.name.clone(),
490 callee_entity: callee.name.clone(),
491 expected_min: param_info.min_params,
492 expected_max: param_info.max_params,
493 actual_args: actual,
494 file_path: caller.file_path.clone(),
495 line: caller.start_line,
496 is_variadic: false,
497 });
498 }
499 }
500
501 mismatches
502}
503
504pub fn find_broken_callers(
508 old_entities: &[SemanticEntity],
509 new_graph: &EntityGraph,
510 new_entities: &[SemanticEntity],
511) -> Vec<ArityMismatch> {
512 let old_params: HashMap<String, Option<ParamInfo>> = old_entities
514 .iter()
515 .filter(|e| matches!(e.entity_type.as_str(), "function" | "method" | "arrow_function"))
516 .map(|e| (e.id.clone(), extract_param_info_ts(&e.content, &e.file_path)))
517 .collect();
518
519 let new_by_id: HashMap<&str, &SemanticEntity> = new_entities
521 .iter()
522 .map(|e| (e.id.as_str(), e))
523 .collect();
524
525 let mut changed_entities: Vec<&str> = Vec::new();
527 for new_entity in new_entities {
528 if !matches!(new_entity.entity_type.as_str(), "function" | "method" | "arrow_function") {
529 continue;
530 }
531 let new_info = match extract_param_info_ts(&new_entity.content, &new_entity.file_path) {
532 Some(pi) => pi,
533 None => continue,
534 };
535 if let Some(Some(old_info)) = old_params.get(&new_entity.id) {
536 if old_info.min_params != new_info.min_params
537 || old_info.max_params != new_info.max_params
538 {
539 changed_entities.push(&new_entity.id);
540 }
541 }
542 }
543
544 if changed_entities.is_empty() {
545 return Vec::new();
546 }
547
548 let mut mismatches = Vec::new();
550
551 for edge in &new_graph.edges {
552 if edge.ref_type != RefType::Calls {
553 continue;
554 }
555 if !changed_entities.contains(&edge.to_entity.as_str()) {
556 continue;
557 }
558
559 let callee = match new_by_id.get(edge.to_entity.as_str()) {
560 Some(e) => *e,
561 None => continue,
562 };
563 let caller = match new_by_id.get(edge.from_entity.as_str()) {
564 Some(e) => *e,
565 None => continue,
566 };
567
568 let new_info = match extract_param_info_ts(&callee.content, &callee.file_path) {
569 Some(pi) => pi,
570 None => continue,
571 };
572
573 if new_info.is_variadic {
574 continue;
575 }
576
577 let actual = match count_call_args_ts(&caller.content, &callee.name, &caller.file_path) {
578 Some(a) => a,
579 None => continue,
580 };
581
582 if actual < new_info.min_params || actual > new_info.max_params {
583 mismatches.push(ArityMismatch {
584 caller_entity: caller.name.clone(),
585 callee_entity: callee.name.clone(),
586 expected_min: new_info.min_params,
587 expected_max: new_info.max_params,
588 actual_args: actual,
589 file_path: caller.file_path.clone(),
590 line: caller.start_line,
591 is_variadic: false,
592 });
593 }
594 }
595
596 mismatches
597}
598
599fn extract_param_count(content: &str) -> usize {
603 let first_line = content.lines().next().unwrap_or("");
604
605 let open = match first_line.find('(') {
606 Some(i) => i,
607 None => return 0,
608 };
609
610 let after_open = &first_line[open + 1..];
611 let close = match find_matching_paren(after_open) {
612 Some(i) => i,
613 None => return 0,
614 };
615
616 let params_str = after_open[..close].trim();
617 if params_str.is_empty() {
618 return 0;
619 }
620
621 count_top_level_commas(params_str) + 1
622}
623
624fn count_call_args(content: &str, callee_name: &str) -> Option<usize> {
626 let bytes = content.as_bytes();
627 let name_bytes = callee_name.as_bytes();
628 let mut search_start = 0;
629
630 while let Some(rel_pos) = content[search_start..].find(callee_name) {
631 let pos = search_start + rel_pos;
632 let after = pos + name_bytes.len();
633
634 let is_boundary = pos == 0 || {
635 let prev = bytes[pos - 1];
636 !prev.is_ascii_alphanumeric() && prev != b'_'
637 };
638
639 if is_boundary && after < bytes.len() && bytes[after] == b'(' {
640 let args_start = &content[after + 1..];
641 if let Some(close) = find_matching_paren(args_start) {
642 let args_str = args_start[..close].trim();
643 if args_str.is_empty() {
644 return Some(0);
645 }
646 return Some(count_top_level_commas(args_str) + 1);
647 }
648 }
649
650 search_start = pos + 1;
651 while search_start < content.len() && !content.is_char_boundary(search_start) {
652 search_start += 1;
653 }
654 }
655
656 None
657}
658
/// Returns the byte index in `s` of the `)` that closes an already-opened `(`.
/// `s` must start just after that `(`.
///
/// Nested parentheses are balanced. Parentheses inside double-quoted or
/// backtick-quoted string literals (with backslash escapes) are ignored, so
/// `")", x)` resolves to the final `)`. Returns `None` if no closing
/// parenthesis is found.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut escaped = false;

    for (i, ch) in s.char_indices() {
        if let Some(q) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
            continue;
        }
        match ch {
            '"' | '`' => quote = Some(ch),
            '(' => depth += 1,
            ')' if depth == 0 => return Some(i),
            ')' => depth -= 1,
            _ => {}
        }
    }
    None
}
675
/// Counts commas in `s` that are not nested inside brackets, generic angle
/// brackets, or string literals. Every top-level comma is counted, including
/// a trailing one.
///
/// A `<` opens a generic level only when it directly follows an identifier
/// character or `:` (as in `Vec<T>` or `::<T>`). A `>` closes one only when
/// such a level is open and it is not part of `->` or `=>`. Comparisons such
/// as `a > b`, and arrows in arguments, therefore do not hide later commas.
/// A closing `)`/`]`/`}` also discards any `<` left open inside it, and
/// unmatched closers are ignored instead of driving the depth negative.
/// Double-quoted and backtick strings (with backslash escapes) are skipped.
fn count_top_level_commas(s: &str) -> usize {
    let mut open: Vec<char> = Vec::new();
    let mut quote: Option<char> = None;
    let mut escaped = false;
    let mut prev = '\0';
    let mut count = 0;

    for ch in s.chars() {
        if let Some(q) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
        } else {
            match ch {
                '"' | '`' => quote = Some(ch),
                '(' | '[' | '{' => open.push(ch),
                '<' if prev.is_alphanumeric() || prev == '_' || prev == ':' => open.push('<'),
                '>' if open.last() == Some(&'<') && prev != '-' && prev != '=' => {
                    open.pop();
                }
                ')' | ']' | '}' => {
                    // Pop through any unclosed `<` (a comparison, not a
                    // generic) down to the bracket this closes.
                    while let Some(opener) = open.pop() {
                        if opener != '<' {
                            break;
                        }
                    }
                }
                ',' if open.is_empty() => count += 1,
                _ => {}
            }
        }
        prev = ch;
    }

    count
}
689
// Unit tests for the signature / call-site heuristics in this module: the
// text-based `extract_param_count` / `count_call_args` pair and the
// tree-sitter based `extract_param_info_ts` / `count_call_args_ts` pair.
#[cfg(test)]
mod tests {
    use super::*;

    // Text heuristic: parameters on the first line, split on top-level
    // commas. Python's `self` IS counted here (unlike the tree-sitter path).
    #[test]
    fn test_extract_param_count_basic() {
        assert_eq!(extract_param_count("function foo(a, b, c) {"), 3);
        assert_eq!(extract_param_count("function foo() {"), 0);
        assert_eq!(extract_param_count("def bar(self, x):"), 2);
        assert_eq!(extract_param_count("fn baz(a: i32) -> bool {"), 1);
    }

    // Commas inside a nested parenthesised group are not top-level.
    #[test]
    fn test_extract_param_count_nested() {
        assert_eq!(extract_param_count("function foo(a, fn(x, y), c) {"), 3);
    }

    // Argument count of the first `name(` call, or None when absent.
    #[test]
    fn test_count_call_args() {
        assert_eq!(count_call_args("let x = foo(1, 2, 3);", "foo"), Some(3));
        assert_eq!(count_call_args("foo()", "foo"), Some(0));
        assert_eq!(count_call_args("bar(1)", "foo"), None);
        assert_eq!(count_call_args("foo(a, b)", "foo"), Some(2));
    }

    // Multi-byte text before the call must not break byte/char slicing.
    #[test]
    fn test_count_call_args_multibyte_utf8() {
        assert_eq!(count_call_args("let café = foo(1, 2);", "foo"), Some(2));
        assert_eq!(count_call_args("let É = 1; bar(x)", "bar"), Some(1));
        assert_eq!(count_call_args("// 日本語コメント\nfoo(a, b, c)", "foo"), Some(3));
    }

    // A defaulted parameter raises the maximum but not the minimum.
    #[test]
    fn test_extract_param_info_python() {
        let info = extract_param_info_ts(
            "def foo(a, b, c=3):\n pass",
            "test.py",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 3);
        assert!(!info.is_variadic);
    }

    // `self` is excluded from the tree-sitter count.
    #[test]
    fn test_extract_param_info_python_self() {
        let info = extract_param_info_ts(
            "def foo(self, a, b):\n pass",
            "test.py",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    // `*args` / `**kwargs` mark the signature variadic.
    #[test]
    fn test_extract_param_info_python_variadic() {
        let info = extract_param_info_ts(
            "def foo(a, *args, **kwargs):\n pass",
            "test.py",
        )
        .unwrap();
        assert!(info.is_variadic);
    }

    // An optional `c?: T` parameter counts toward the maximum only.
    #[test]
    fn test_extract_param_info_typescript() {
        let info = extract_param_info_ts(
            "function foo(a: number, b: string, c?: boolean): void {}",
            "test.ts",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 3);
        assert!(!info.is_variadic);
    }

    // The `&self` receiver is excluded.
    #[test]
    fn test_extract_param_info_rust() {
        let info = extract_param_info_ts(
            "fn foo(&self, a: i32, b: String) -> bool { true }",
            "test.rs",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    // One Go parameter per declaration in this input.
    #[test]
    fn test_extract_param_info_go() {
        let info = extract_param_info_ts(
            "func foo(a string, b int) error { return nil }",
            "test.go",
        )
        .unwrap();
        assert_eq!(info.min_params, 2);
        assert_eq!(info.max_params, 2);
    }

    // Plain call found inside another function's body.
    #[test]
    fn test_count_call_args_ts() {
        let count = count_call_args_ts(
            "function bar() { foo(1, 2, 3); }",
            "foo",
            "test.ts",
        );
        assert_eq!(count, Some(3));
    }

    // Method call matched through its member-expression property name.
    #[test]
    fn test_count_call_args_ts_method() {
        let count = count_call_args_ts(
            "function bar() { obj.foo(1, 2); }",
            "foo",
            "test.ts",
        );
        assert_eq!(count, Some(2));
    }
}
808}