1use oxc_ast::ast::*;
2use oxc_span::Span;
3
4use crate::parser::parse_and_convert_to_tree;
5use crate::tsed::{calculate_tsed, TSEDOptions};
6
/// Cross-file matches as `(file of func1, result, file of func2)` tuples,
/// as produced by `find_similar_functions_across_files`.
type CrossFileSimilarityResult = Vec<(String, SimilarityResult, String)>;
8
/// A pair of functions together with their computed similarity score.
#[derive(Debug, Clone)]
pub struct SimilarityResult {
    /// First function of the pair.
    pub func1: FunctionDefinition,
    /// Second function of the pair.
    pub func2: FunctionDefinition,
    /// Similarity score supplied by the caller of `SimilarityResult::new`.
    pub similarity: f64,
    /// Line count of the shorter of the two functions (set by `SimilarityResult::new`);
    /// used as the primary sort key when ranking results.
    pub impact: u32,
}
16
17impl SimilarityResult {
18 pub fn new(func1: FunctionDefinition, func2: FunctionDefinition, similarity: f64) -> Self {
19 let impact = func1.line_count().min(func2.line_count());
21 SimilarityResult { func1, func2, similarity, impact }
22 }
23}
24
/// Metadata describing one function-like item found in a source file.
#[derive(Debug, Clone)]
pub struct FunctionDefinition {
    /// Declared name; the extractor uses `"anonymous"` for method keys it cannot
    /// name and `"default"` for unnamed default-exported functions.
    pub name: String,
    /// Kind of function-like item (function, method, arrow, constructor).
    pub function_type: FunctionType,
    /// Names of parameters that are plain identifiers; other patterns are omitted.
    pub parameters: Vec<String>,
    /// Byte span into the source text. NOTE: the extractor in this file stores the
    /// span of the whole function/method/arrow node here, not only its body.
    pub body_span: Span,
    /// 1-based line on which the span starts.
    pub start_line: u32,
    /// 1-based line on which the span ends.
    pub end_line: u32,
    /// Name of the enclosing class for methods and constructors, if the class is named.
    pub class_name: Option<String>,
    /// Name of the enclosing function (or `Class.method`) when this item is nested.
    pub parent_function: Option<String>,
    /// Size of the tree parsed from the function text, when it could be computed.
    pub node_count: Option<u32>,
}
37
38impl FunctionDefinition {
39 pub fn line_count(&self) -> u32 {
40 self.end_line - self.start_line + 1
41 }
42
43 pub fn is_parent_child_relationship(&self, other: &FunctionDefinition) -> bool {
45 let other_inside_self = self.start_line <= other.start_line
47 && self.end_line >= other.end_line
48 && self.body_span.start < other.body_span.start
49 && self.body_span.end > other.body_span.end;
50
51 let self_inside_other = other.start_line <= self.start_line
53 && other.end_line >= self.end_line
54 && other.body_span.start < self.body_span.start
55 && other.body_span.end > self.body_span.end;
56
57 other_inside_self || self_inside_other
58 }
59}
60
/// Kind of function-like item recorded in a `FunctionDefinition`.
#[derive(Debug, Clone, PartialEq)]
pub enum FunctionType {
    /// A `function` declaration (including default-exported functions).
    Function,
    /// A class method other than the constructor.
    Method,
    /// An arrow function bound to a variable.
    Arrow,
    /// A class constructor.
    Constructor,
}
68
69pub fn extract_functions(
71 filename: &str,
72 source_text: &str,
73) -> Result<Vec<FunctionDefinition>, String> {
74 use oxc_allocator::Allocator;
75 use oxc_parser::Parser;
76 use oxc_span::SourceType;
77
78 let allocator = Allocator::default();
79 let source_type = SourceType::from_path(filename).unwrap_or(SourceType::tsx());
80 let ret = Parser::new(&allocator, source_text, source_type).parse();
81
82 if !ret.errors.is_empty() {
83 return Err(format!("Parse errors: {:?}", ret.errors));
84 }
85
86 let mut functions = Vec::new();
87 let mut context = ExtractionContext {
88 functions: &mut functions,
89 source_text,
90 class_name: None,
91 parent_function: None,
92 };
93
94 extract_from_program(&ret.program, &mut context);
95 Ok(functions)
96}
97
/// Mutable state threaded through the AST walk.
struct ExtractionContext<'a> {
    /// Output list that every discovered function is appended to.
    functions: &'a mut Vec<FunctionDefinition>,
    /// Full source text, used for line-number lookup and node counting.
    source_text: &'a str,
    /// Name of the class currently being walked, if any.
    class_name: Option<String>,
    /// Name of the function whose body is currently being walked, if any.
    parent_function: Option<String>,
}
104
105fn extract_from_program(program: &Program, ctx: &mut ExtractionContext) {
106 for stmt in &program.body {
107 extract_from_statement(stmt, ctx);
108 }
109}
110
111fn extract_from_statement(stmt: &Statement, ctx: &mut ExtractionContext) {
112 match stmt {
113 Statement::FunctionDeclaration(func) => {
114 if let Some(name) = &func.id {
115 let func_name = name.name.to_string();
116 let params = extract_parameters(&func.params);
117 ctx.functions.push(FunctionDefinition {
118 name: func_name.clone(),
119 function_type: FunctionType::Function,
120 parameters: params,
121 body_span: func.span,
122 start_line: get_line_number(func.span.start, ctx.source_text),
123 end_line: get_line_number(func.span.end, ctx.source_text),
124 class_name: None,
125 parent_function: ctx.parent_function.clone(),
126 node_count: count_function_nodes(func.span, ctx.source_text),
127 });
128
129 if let Some(body) = &func.body {
131 let saved_parent = ctx.parent_function.clone();
132 ctx.parent_function = Some(func_name);
133 extract_from_function_body(body, ctx);
134 ctx.parent_function = saved_parent;
135 }
136 }
137 }
138 Statement::ClassDeclaration(class) => {
139 let class_name = class.id.as_ref().map(|id| id.name.to_string());
140 let saved_class_name = ctx.class_name.clone();
141 ctx.class_name = class_name.clone();
142
143 for element in &class.body.body {
144 if let ClassElement::MethodDefinition(method) = element {
145 let method_name = match &method.key {
146 PropertyKey::StaticIdentifier(ident) => ident.name.to_string(),
147 PropertyKey::PrivateIdentifier(ident) => format!("#{}", ident.name),
148 _ => "anonymous".to_string(),
149 };
150
151 let params = extract_parameters(&method.value.params);
152 let function_type = if method.kind == MethodDefinitionKind::Constructor {
153 FunctionType::Constructor
154 } else {
155 FunctionType::Method
156 };
157
158 let method_full_name = if let Some(ref class) = class_name {
159 format!("{}.{}", class, method_name)
160 } else {
161 method_name.clone()
162 };
163
164 ctx.functions.push(FunctionDefinition {
165 name: method_name.clone(),
166 function_type,
167 parameters: params,
168 body_span: method.span,
169 start_line: get_line_number(method.span.start, ctx.source_text),
170 end_line: get_line_number(method.span.end, ctx.source_text),
171 class_name: class_name.clone(),
172 parent_function: ctx.parent_function.clone(),
173 node_count: count_function_nodes(method.span, ctx.source_text),
174 });
175
176 if let Some(body) = &method.value.body {
178 let saved_parent = ctx.parent_function.clone();
179 ctx.parent_function = Some(method_full_name);
180 extract_from_function_body(body, ctx);
181 ctx.parent_function = saved_parent;
182 }
183 }
184 }
185
186 ctx.class_name = saved_class_name;
187 }
188 Statement::VariableDeclaration(var_decl) => {
189 for decl in &var_decl.declarations {
190 if let Some(Expression::ArrowFunctionExpression(arrow)) = &decl.init {
191 if let BindingPatternKind::BindingIdentifier(ident) = &decl.id.kind {
192 let params = extract_parameters(&arrow.params);
193 let arrow_name = ident.name.to_string();
194 ctx.functions.push(FunctionDefinition {
195 name: arrow_name.clone(),
196 function_type: FunctionType::Arrow,
197 parameters: params,
198 body_span: arrow.span,
199 start_line: get_line_number(arrow.span.start, ctx.source_text),
200 end_line: get_line_number(arrow.span.end, ctx.source_text),
201 class_name: None,
202 parent_function: ctx.parent_function.clone(),
203 node_count: count_function_nodes(arrow.span, ctx.source_text),
204 });
205
206 if !arrow.expression {
208 let saved_parent = ctx.parent_function.clone();
209 ctx.parent_function = Some(arrow_name);
210 extract_from_function_body(&arrow.body, ctx);
211 ctx.parent_function = saved_parent;
212 }
213 }
214 }
215 }
216 }
217 Statement::ExportNamedDeclaration(export) => {
218 if let Some(decl) = &export.declaration {
219 extract_from_declaration(decl, ctx);
220 }
221 }
222 Statement::ExportDefaultDeclaration(export) => {
223 if let ExportDefaultDeclarationKind::FunctionDeclaration(func) = &export.declaration {
224 let name = func
225 .id
226 .as_ref()
227 .map(|id| id.name.to_string())
228 .unwrap_or_else(|| "default".to_string());
229 let params = extract_parameters(&func.params);
230 let func_name = name.clone();
231 ctx.functions.push(FunctionDefinition {
232 name: func_name.clone(),
233 function_type: FunctionType::Function,
234 parameters: params,
235 body_span: func.span,
236 start_line: get_line_number(func.span.start, ctx.source_text),
237 end_line: get_line_number(func.span.end, ctx.source_text),
238 class_name: None,
239 parent_function: ctx.parent_function.clone(),
240 node_count: count_function_nodes(func.span, ctx.source_text),
241 });
242
243 if let Some(body) = &func.body {
245 let saved_parent = ctx.parent_function.clone();
246 ctx.parent_function = Some(func_name);
247 extract_from_function_body(body, ctx);
248 ctx.parent_function = saved_parent;
249 }
250 }
251 }
252 _ => {}
253 }
254}
255
256fn extract_from_declaration(decl: &Declaration, ctx: &mut ExtractionContext) {
257 match decl {
258 Declaration::FunctionDeclaration(func) => {
259 if let Some(name) = &func.id {
260 let func_name = name.name.to_string();
261 let params = extract_parameters(&func.params);
262 ctx.functions.push(FunctionDefinition {
263 name: func_name.clone(),
264 function_type: FunctionType::Function,
265 parameters: params,
266 body_span: func.span,
267 start_line: get_line_number(func.span.start, ctx.source_text),
268 end_line: get_line_number(func.span.end, ctx.source_text),
269 class_name: None,
270 parent_function: ctx.parent_function.clone(),
271 node_count: count_function_nodes(func.span, ctx.source_text),
272 });
273
274 if let Some(body) = &func.body {
276 let saved_parent = ctx.parent_function.clone();
277 ctx.parent_function = Some(func_name);
278 extract_from_function_body(body, ctx);
279 ctx.parent_function = saved_parent;
280 }
281 }
282 }
283 Declaration::ClassDeclaration(class) => {
284 let class_name = class.id.as_ref().map(|id| id.name.to_string());
285 let saved_class_name = ctx.class_name.clone();
286 ctx.class_name = class_name.clone();
287
288 for element in &class.body.body {
289 if let ClassElement::MethodDefinition(method) = element {
290 let method_name = match &method.key {
291 PropertyKey::StaticIdentifier(ident) => ident.name.to_string(),
292 PropertyKey::PrivateIdentifier(ident) => format!("#{}", ident.name),
293 _ => "anonymous".to_string(),
294 };
295
296 let params = extract_parameters(&method.value.params);
297 let function_type = if method.kind == MethodDefinitionKind::Constructor {
298 FunctionType::Constructor
299 } else {
300 FunctionType::Method
301 };
302
303 let method_full_name = if let Some(ref class) = class_name {
304 format!("{}.{}", class, method_name)
305 } else {
306 method_name.clone()
307 };
308
309 ctx.functions.push(FunctionDefinition {
310 name: method_name.clone(),
311 function_type,
312 parameters: params,
313 body_span: method.span,
314 start_line: get_line_number(method.span.start, ctx.source_text),
315 end_line: get_line_number(method.span.end, ctx.source_text),
316 class_name: class_name.clone(),
317 parent_function: ctx.parent_function.clone(),
318 node_count: count_function_nodes(method.span, ctx.source_text),
319 });
320
321 if let Some(body) = &method.value.body {
323 let saved_parent = ctx.parent_function.clone();
324 ctx.parent_function = Some(method_full_name);
325 extract_from_function_body(body, ctx);
326 ctx.parent_function = saved_parent;
327 }
328 }
329 }
330
331 ctx.class_name = saved_class_name;
332 }
333 Declaration::VariableDeclaration(var) => {
334 for decl in &var.declarations {
335 if let Some(Expression::ArrowFunctionExpression(arrow)) = &decl.init {
336 if let BindingPatternKind::BindingIdentifier(ident) = &decl.id.kind {
337 let params = extract_parameters(&arrow.params);
338 let arrow_name = ident.name.to_string();
339 ctx.functions.push(FunctionDefinition {
340 name: arrow_name.clone(),
341 function_type: FunctionType::Arrow,
342 parameters: params,
343 body_span: arrow.span,
344 start_line: get_line_number(arrow.span.start, ctx.source_text),
345 end_line: get_line_number(arrow.span.end, ctx.source_text),
346 class_name: None,
347 parent_function: ctx.parent_function.clone(),
348 node_count: count_function_nodes(arrow.span, ctx.source_text),
349 });
350
351 if !arrow.expression {
353 let saved_parent = ctx.parent_function.clone();
354 ctx.parent_function = Some(arrow_name);
355 extract_from_function_body(&arrow.body, ctx);
356 ctx.parent_function = saved_parent;
357 }
358 }
359 }
360 }
361 }
362 _ => {}
363 }
364}
365
366fn extract_parameters(params: &oxc_ast::ast::FormalParameters) -> Vec<String> {
367 params
368 .items
369 .iter()
370 .filter_map(|param| match ¶m.pattern.kind {
371 BindingPatternKind::BindingIdentifier(ident) => Some(ident.name.to_string()),
372 _ => None,
373 })
374 .collect()
375}
376
377fn extract_from_function_body(body: &FunctionBody, ctx: &mut ExtractionContext) {
378 for stmt in &body.statements {
379 extract_from_statement(stmt, ctx);
380 }
381}
382
/// Converts a byte `offset` into `source_text` to a 1-based line number.
///
/// The result is one plus the number of `'\n'` bytes located strictly before
/// `offset`; an offset past the end of the text yields the last line.
fn get_line_number(offset: u32, source_text: &str) -> u32 {
    // A newline is always the single byte 0x0A in UTF-8, so counting bytes
    // gives the same answer as walking characters.
    let newlines_before = source_text
        .bytes()
        .take(offset as usize)
        .filter(|&byte| byte == b'\n')
        .count();

    newlines_before as u32 + 1
}
399
400pub fn compare_functions(
402 func1: &FunctionDefinition,
403 func2: &FunctionDefinition,
404 source1: &str,
405 source2: &str,
406 options: &TSEDOptions,
407) -> Result<f64, String> {
408 let body1 = extract_body_text(func1, source1);
410 let body2 = extract_body_text(func2, source2);
411
412 let tree1 = parse_and_convert_to_tree("func1.ts", &body1)?;
414 let tree2 = parse_and_convert_to_tree("func2.ts", &body2)?;
415
416 let mut similarity = calculate_tsed(&tree1, &tree2, options);
417
418 if options.size_penalty {
420 let avg_lines = (func1.line_count() + func2.line_count()) as f64 / 2.0;
421 if avg_lines < 10.0 {
422 let penalty = avg_lines / 10.0;
424 similarity *= penalty;
425 }
426 }
427
428 Ok(similarity)
429}
430
431fn extract_body_text(func: &FunctionDefinition, source: &str) -> String {
432 let start = func.body_span.start as usize;
433 let end = func.body_span.end as usize;
434 source[start..end].to_string()
435}
436
437fn count_function_nodes(body_span: Span, source_text: &str) -> Option<u32> {
439 let start = body_span.start as usize;
440 let end = body_span.end as usize;
441 if start >= end || end > source_text.len() {
442 return None;
443 }
444
445 let body_text = &source_text[start..end];
446
447 match parse_and_convert_to_tree("temp.ts", body_text) {
450 Ok(tree) => Some(tree.get_subtree_size() as u32),
451 Err(_) => {
452 let wrapped = if body_text.starts_with("constructor") {
455 format!("class C {{ {} }}", body_text)
456 } else if body_text.contains("(") && body_text.contains(")") && body_text.contains("{")
457 {
458 if body_text.trim().starts_with(|c: char| c.is_alphabetic() || c == '_' || c == '#')
460 {
461 format!("class C {{ {} }}", body_text)
463 } else {
464 format!("const x = {}", body_text)
466 }
467 } else {
468 body_text.to_string()
470 };
471
472 match parse_and_convert_to_tree("temp.ts", &wrapped) {
473 Ok(tree) => {
474 let base_nodes = if wrapped.starts_with("class C") {
476 3 } else if wrapped.starts_with("const x") {
478 2 } else {
480 0
481 };
482 Some((tree.get_subtree_size().saturating_sub(base_nodes)) as u32)
483 }
484 Err(_) => {
485 let node_count =
488 body_text.matches(['{', '}', '(', ')', ';']).count() as u32 + 1;
489 Some(node_count.max(1))
490 }
491 }
492 }
493 }
494}
495
496pub fn find_similar_functions_in_file(
498 filename: &str,
499 source_text: &str,
500 threshold: f64,
501 options: &TSEDOptions,
502) -> Result<Vec<SimilarityResult>, String> {
503 let functions = extract_functions(filename, source_text)?;
504 let mut similar_pairs = Vec::new();
505
506 for i in 0..functions.len() {
508 for j in (i + 1)..functions.len() {
509 if let Some(min_tokens) = options.min_tokens {
511 let tokens_i = functions[i].node_count.unwrap_or(0);
513 let tokens_j = functions[j].node_count.unwrap_or(0);
514 if tokens_i < min_tokens || tokens_j < min_tokens {
515 continue;
516 }
517 } else {
518 if functions[i].line_count() < options.min_lines
520 || functions[j].line_count() < options.min_lines
521 {
522 continue;
523 }
524 }
525
526 if functions[i].is_parent_child_relationship(&functions[j]) {
528 continue;
529 }
530
531 let similarity =
532 compare_functions(&functions[i], &functions[j], source_text, source_text, options)?;
533
534 if similarity >= threshold {
535 similar_pairs.push(SimilarityResult::new(
536 functions[i].clone(),
537 functions[j].clone(),
538 similarity,
539 ));
540 }
541 }
542 }
543
544 similar_pairs.sort_by(|a, b| {
546 b.impact
547 .cmp(&a.impact)
548 .then(b.similarity.partial_cmp(&a.similarity).unwrap_or(std::cmp::Ordering::Equal))
549 });
550
551 Ok(similar_pairs)
552}
553
554pub fn find_similar_functions_across_files(
556 files: &[(String, String)], threshold: f64,
558 options: &TSEDOptions,
559) -> Result<CrossFileSimilarityResult, String> {
560 let mut all_functions = Vec::new();
561
562 for (filename, source) in files {
564 let functions = extract_functions(filename, source)?;
565 for func in functions {
566 all_functions.push((filename.clone(), source.clone(), func));
567 }
568 }
569
570 let mut similar_pairs = Vec::new();
571
572 for i in 0..all_functions.len() {
574 for j in (i + 1)..all_functions.len() {
575 let (first_file, source1, func1) = &all_functions[i];
576 let (second_file, source2, func2) = &all_functions[j];
577
578 if first_file == second_file {
580 continue;
581 }
582
583 if let Some(min_tokens) = options.min_tokens {
585 let tokens1 = func1.node_count.unwrap_or(0);
587 let tokens2 = func2.node_count.unwrap_or(0);
588 if tokens1 < min_tokens || tokens2 < min_tokens {
589 continue;
590 }
591 } else {
592 if func1.line_count() < options.min_lines || func2.line_count() < options.min_lines
594 {
595 continue;
596 }
597 }
598
599 if func1.is_parent_child_relationship(func2) {
601 continue;
602 }
603
604 let similarity = compare_functions(func1, func2, source1, source2, options)?;
605
606 if similarity >= threshold {
607 similar_pairs.push((
608 first_file.clone(),
609 SimilarityResult::new(func1.clone(), func2.clone(), similarity),
610 second_file.clone(),
611 ));
612 }
613 }
614 }
615
616 similar_pairs.sort_by(|a, b| {
618 b.1.impact
619 .cmp(&a.1.impact)
620 .then(b.1.similarity.partial_cmp(&a.1.similarity).unwrap_or(std::cmp::Ordering::Equal))
621 });
622
623 Ok(similar_pairs)
624}
625
// Unit tests for function extraction and similarity search.
#[cfg(test)]
mod tests {
    use super::*;

    /// Extraction finds declarations, arrow functions, class members and
    /// exported functions, and fills in their metadata.
    #[test]
    fn test_extract_functions() {
        let code = r"
            function add(a: number, b: number): number {
                return a + b;
            }

            const multiply = (x: number, y: number) => x * y;

            class Calculator {
                constructor(private initial: number) {}

                add(value: number): number {
                    return this.initial + value;
                }

                subtract(value: number): number {
                    return this.initial - value;
                }
            }

            export function divide(a: number, b: number): number {
                return a / b;
            }
        ";

        let functions = extract_functions("test.ts", code).unwrap();

        // add, multiply, constructor, Calculator.add, subtract, divide.
        assert_eq!(functions.len(), 6);

        let names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect();
        assert!(names.contains(&"add"));
        assert!(names.contains(&"multiply"));
        assert!(names.contains(&"constructor"));
        assert!(names.contains(&"subtract"));
        assert!(names.contains(&"divide"));

        // `add` exists both as a free function and as a method; pick the free one.
        let add_func =
            functions.iter().find(|f| f.name == "add" && f.class_name.is_none()).unwrap();
        assert_eq!(add_func.function_type, FunctionType::Function);
        assert_eq!(add_func.parameters, vec!["a", "b"]);

        let multiply_func = functions.iter().find(|f| f.name == "multiply").unwrap();
        assert_eq!(multiply_func.function_type, FunctionType::Arrow);

        let constructor = functions.iter().find(|f| f.name == "constructor").unwrap();
        assert_eq!(constructor.function_type, FunctionType::Constructor);
        assert_eq!(constructor.class_name, Some("Calculator".to_string()));

        // Every extracted function must carry a positive node count.
        for func in &functions {
            assert!(
                func.node_count.is_some(),
                "Function {} should have node_count populated",
                func.name
            );
            assert!(
                func.node_count.unwrap() > 0,
                "Function {} should have positive node_count",
                func.name
            );
        }
    }

    /// A function with branching has a larger node count than a trivial one.
    #[test]
    fn test_node_count_calculation() {
        let code = r#"
            function simple() {
                return 42;
            }

            function complex(a: number, b: number): number {
                if (a > b) {
                    return a - b;
                } else {
                    return a + b;
                }
            }
        "#;

        let functions = extract_functions("test.ts", code).unwrap();

        let simple = functions.iter().find(|f| f.name == "simple").unwrap();
        let complex = functions.iter().find(|f| f.name == "complex").unwrap();

        println!("Simple function node count: {:?}", simple.node_count);
        println!("Complex function node count: {:?}", complex.node_count);

        assert!(simple.node_count.is_some());
        assert!(complex.node_count.is_some());
        assert!(simple.node_count.unwrap() < complex.node_count.unwrap());
    }

    /// Structurally identical functions in one file are reported as similar.
    #[test]
    fn test_find_similar_functions_in_file() {
        let code = r"
            function calculateSum(a: number, b: number): number {
                return a + b;
            }

            function addNumbers(x: number, y: number): number {
                return x + y;
            }

            function multiply(a: number, b: number): number {
                return a * b;
            }

            function computeSum(first: number, second: number): number {
                return first + second;
            }
        ";

        let mut options = TSEDOptions::default();
        // Cheap renames, no size penalty, and a one-line minimum so these
        // three-line fixtures are not filtered out.
        options.apted_options.rename_cost = 0.3;
        options.size_penalty = false;
        options.min_lines = 1;
        let similar_pairs = find_similar_functions_in_file("test.ts", code, 0.7, &options).unwrap();

        assert!(
            similar_pairs.len() >= 2,
            "Expected at least 2 similar pairs, found {}",
            similar_pairs.len()
        );

        // Count pairs that involve any of the three addition functions.
        let sum_pairs = similar_pairs
            .iter()
            .filter(|result| {
                (result.func1.name.contains("Sum") || result.func2.name.contains("Sum"))
                    || (result.func1.name == "addNumbers" || result.func2.name == "addNumbers")
            })
            .count();
        assert!(sum_pairs >= 3, "Expected at least 3 pairs involving sum functions");
    }

    /// Functions with the same shape but different names are matched across files.
    #[test]
    fn test_find_similar_functions_across_files() {
        let file1 = (
            "file1.ts".to_string(),
            r#"
            export function processUser(user: User): void {
                validateUser(user);
                saveUser(user);
                notifyUser(user);
            }

            function validateUser(user: User): boolean {
                return user.name.length > 0 && user.email.includes('@');
            }
        "#
            .to_string(),
        );

        let file2 = (
            "file2.ts".to_string(),
            r#"
            export function handleUser(u: User): void {
                checkUser(u);
                storeUser(u);
                alertUser(u);
            }

            function checkUser(u: User): boolean {
                return u.name.length > 0 && u.email.includes('@');
            }
        "#
            .to_string(),
        );

        let mut options = TSEDOptions::default();
        options.apted_options.rename_cost = 0.3;
        // No size penalty and a one-line minimum so the short fixtures are compared.
        options.size_penalty = false;
        options.min_lines = 1;
        let similar_pairs =
            find_similar_functions_across_files(&[file1, file2], 0.7, &options).unwrap();

        assert!(similar_pairs.len() >= 2);

        // processUser <-> handleUser must be reported, in either order.
        let process_handle = similar_pairs.iter().find(|(_, result, _)| {
            (result.func1.name == "processUser" && result.func2.name == "handleUser")
                || (result.func1.name == "handleUser" && result.func2.name == "processUser")
        });
        assert!(process_handle.is_some());

        // validateUser <-> checkUser must be reported, in either order.
        let validate_check = similar_pairs.iter().find(|(_, result, _)| {
            (result.func1.name == "validateUser" && result.func2.name == "checkUser")
                || (result.func1.name == "checkUser" && result.func2.name == "validateUser")
        });
        assert!(validate_check.is_some());
    }
}