Skip to main content

similarity_core/
class_comparator.rs

1use crate::class_extractor::{ClassDefinition, ClassMethod, ClassProperty};
2use std::collections::HashMap;
3
/// A [`ClassDefinition`] reshaped for structural comparison; produced by `normalize_class`.
///
/// Members are stored in maps keyed by member name so two classes can be matched name-by-name.
#[derive(Debug, Clone)]
pub struct NormalizedClass {
    /// Class name, copied verbatim from the definition.
    pub name: String,
    /// Properties keyed by property name.
    pub properties: HashMap<String, ClassProperty>,
    /// Methods keyed by method name, with parameter and return types normalized.
    pub methods: HashMap<String, ClassMethod>,
    /// Constructor parameters rendered as `(a, b, ...)`; `()` when there are none.
    pub constructor_signature: String,
    /// Value of the class's `extends` clause, if any.
    pub extends: Option<String>,
    /// Entries of the class's `implements` clause.
    pub implements: Vec<String>,
}
13
/// Outcome of comparing two classes with `compare_classes`.
#[derive(Debug, Clone)]
pub struct ClassComparisonResult {
    /// Overall score: `0.3 * naming_similarity + 0.7 * structural_similarity`.
    pub similarity: f64,
    /// Fraction of the union of property and method names that are present in both classes
    /// with an identical type (properties) or signature (methods).
    pub structural_similarity: f64,
    /// Edit-distance-based similarity of the two class names.
    pub naming_similarity: f64,
    /// Member-level differences between the two classes.
    pub differences: ClassDifferences,
}
21
/// Member-level differences between a first and a second class.
///
/// "Missing" means present in the first class but absent from the second; "extra" means
/// present only in the second class.
#[derive(Debug, Clone)]
pub struct ClassDifferences {
    /// Property names found only in the first class.
    pub missing_properties: Vec<String>,
    /// Property names found only in the second class.
    pub extra_properties: Vec<String>,
    /// Method names found only in the first class.
    pub missing_methods: Vec<String>,
    /// Method names found only in the second class.
    pub extra_methods: Vec<String>,
    /// Properties present in both classes whose type annotations differ.
    pub property_type_mismatches: Vec<PropertyMismatch>,
    /// Methods present in both classes whose `(params) => return` signatures differ.
    pub method_signature_mismatches: Vec<MethodMismatch>,
}
31
/// A property that exists in both compared classes but with different type annotations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PropertyMismatch {
    /// Property name shared by both classes.
    pub name: String,
    /// Type annotation in the first class.
    pub type1: String,
    /// Type annotation in the second class.
    pub type2: String,
}
38
/// A method that exists in both compared classes but with different normalized signatures.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MethodMismatch {
    /// Method name shared by both classes.
    pub name: String,
    /// Signature in the first class, rendered as `(params) => return`.
    pub signature1: String,
    /// Signature in the second class, rendered as `(params) => return`.
    pub signature2: String,
}
45
/// Two classes whose overall similarity met the search threshold.
#[derive(Debug, Clone)]
pub struct SimilarClassPair {
    /// First class of the pair (the one passed first to `compare_classes`).
    pub class1: ClassDefinition,
    /// Second class of the pair.
    pub class2: ClassDefinition,
    /// Result of `compare_classes(class1, class2)`.
    pub result: ClassComparisonResult,
}
52
53pub fn normalize_class(class: &ClassDefinition) -> NormalizedClass {
54    let mut properties = HashMap::new();
55    for prop in &class.properties {
56        properties.insert(prop.name.clone(), prop.clone());
57    }
58
59    let mut methods = HashMap::new();
60    for method in &class.methods {
61        // Normalize method signature
62        let normalized_method = ClassMethod {
63            name: method.name.clone(),
64            parameters: normalize_parameters(&method.parameters),
65            return_type: normalize_type(&method.return_type),
66            is_static: method.is_static,
67            is_private: method.is_private,
68            is_async: method.is_async,
69            is_generator: method.is_generator,
70            kind: method.kind.clone(),
71        };
72        methods.insert(method.name.clone(), normalized_method);
73    }
74
75    let constructor_signature = if class.constructor_params.is_empty() {
76        "()".to_string()
77    } else {
78        format!("({})", class.constructor_params.join(", "))
79    };
80
81    NormalizedClass {
82        name: class.name.clone(),
83        properties,
84        methods,
85        constructor_signature,
86        extends: class.extends.clone(),
87        implements: class.implements.clone(),
88    }
89}
90
91fn normalize_parameters(params: &[String]) -> Vec<String> {
92    params.iter().map(|p| normalize_type(p)).collect()
93}
94
/// Produces a canonical spelling of a type annotation so equivalent spellings compare equal.
///
/// - All whitespace is removed (not only spaces), so multi-line annotations match their
///   single-line forms.
/// - `Array<T>` is rewritten to `[T]`, including nested uses. Only the `>` that closes an
///   `Array<` becomes `]`. Other generics such as `Promise<T>` keep their `>`, and the `>` of a
///   function-type arrow `=>` is left intact.
/// - `Array<` is only recognized at the start of an identifier, so `ReadonlyArray<T>` is kept.
fn normalize_type(type_str: &str) -> String {
    const ARRAY_OPEN: &str = "Array<";

    let compact: String = type_str.chars().filter(|c| !c.is_whitespace()).collect();
    let mut out = String::with_capacity(compact.len());
    // One entry per `<` still open; `true` when that `<` came from an `Array<` rewritten to `[`.
    let mut open_generics: Vec<bool> = Vec::new();
    let mut rest = compact.as_str();

    while let Some(c) = rest.chars().next() {
        let at_word_start = !matches!(
            out.chars().next_back(),
            Some(prev) if prev.is_alphanumeric() || prev == '_' || prev == '$'
        );
        if at_word_start {
            if let Some(after) = rest.strip_prefix(ARRAY_OPEN) {
                out.push('[');
                open_generics.push(true);
                rest = after;
                continue;
            }
        }

        match c {
            '<' => {
                open_generics.push(false);
                out.push('<');
            }
            // A `>` right after `=` is a function-type arrow, not a generic closer.
            '>' if !out.ends_with('=') => {
                let closes_array = open_generics.pop().unwrap_or(false);
                out.push(if closes_array { ']' } else { '>' });
            }
            other => out.push(other),
        }
        rest = &rest[c.len_utf8()..];
    }

    out
}
99
100pub fn compare_classes(
101    class1: &ClassDefinition,
102    class2: &ClassDefinition,
103) -> ClassComparisonResult {
104    let norm1 = normalize_class(class1);
105    let norm2 = normalize_class(class2);
106
107    // Calculate naming similarity
108    let naming_similarity = calculate_name_similarity(&class1.name, &class2.name);
109
110    // Calculate structural similarity
111    let (structural_similarity, differences) = calculate_structural_similarity(&norm1, &norm2);
112
113    // Combined similarity (weighted average)
114    let similarity = 0.3 * naming_similarity + 0.7 * structural_similarity;
115
116    ClassComparisonResult { similarity, structural_similarity, naming_similarity, differences }
117}
118
119fn calculate_name_similarity(name1: &str, name2: &str) -> f64 {
120    if name1 == name2 {
121        return 1.0;
122    }
123
124    // Calculate Levenshtein distance
125    let distance = levenshtein_distance(name1, name2);
126    let max_len = name1.len().max(name2.len()) as f64;
127
128    if max_len > 0.0 {
129        1.0 - (distance as f64 / max_len)
130    } else {
131        1.0
132    }
133}
134
135fn calculate_structural_similarity(
136    class1: &NormalizedClass,
137    class2: &NormalizedClass,
138) -> (f64, ClassDifferences) {
139    let mut missing_properties = Vec::new();
140    let mut extra_properties = Vec::new();
141    let mut property_type_mismatches = Vec::new();
142
143    // Check properties
144    let mut property_matches = 0;
145    let mut property_total = 0;
146
147    for (name, prop1) in &class1.properties {
148        property_total += 1;
149        if let Some(prop2) = class2.properties.get(name) {
150            if prop1.type_annotation == prop2.type_annotation {
151                property_matches += 1;
152            } else {
153                property_type_mismatches.push(PropertyMismatch {
154                    name: name.clone(),
155                    type1: prop1.type_annotation.clone(),
156                    type2: prop2.type_annotation.clone(),
157                });
158            }
159        } else {
160            missing_properties.push(name.clone());
161        }
162    }
163
164    for name in class2.properties.keys() {
165        if !class1.properties.contains_key(name) {
166            extra_properties.push(name.clone());
167            property_total += 1;
168        }
169    }
170
171    // Check methods
172    let mut missing_methods = Vec::new();
173    let mut extra_methods = Vec::new();
174    let mut method_signature_mismatches = Vec::new();
175
176    let mut method_matches = 0;
177    let mut method_total = 0;
178
179    for (name, method1) in &class1.methods {
180        method_total += 1;
181        if let Some(method2) = class2.methods.get(name) {
182            let sig1 = format!("({}) => {}", method1.parameters.join(", "), method1.return_type);
183            let sig2 = format!("({}) => {}", method2.parameters.join(", "), method2.return_type);
184
185            if sig1 == sig2 {
186                method_matches += 1;
187            } else {
188                method_signature_mismatches.push(MethodMismatch {
189                    name: name.clone(),
190                    signature1: sig1,
191                    signature2: sig2,
192                });
193            }
194        } else {
195            missing_methods.push(name.clone());
196        }
197    }
198
199    for name in class2.methods.keys() {
200        if !class1.methods.contains_key(name) {
201            extra_methods.push(name.clone());
202            method_total += 1;
203        }
204    }
205
206    // Calculate overall structural similarity
207    let total_elements = property_total + method_total;
208    let matched_elements = property_matches + method_matches;
209
210    let structural_similarity =
211        if total_elements > 0 { matched_elements as f64 / total_elements as f64 } else { 1.0 };
212
213    let differences = ClassDifferences {
214        missing_properties,
215        extra_properties,
216        missing_methods,
217        extra_methods,
218        property_type_mismatches,
219        method_signature_mismatches,
220    };
221
222    (structural_similarity, differences)
223}
224
/// Levenshtein edit distance between `s1` and `s2`, counted in `char`s (Unicode scalar values).
///
/// This is the minimum number of single-character insertions, deletions and substitutions that
/// turn one string into the other. It runs in O(|s1|·|s2|) time and keeps only two DP rows,
/// sized by the shorter string; the distance is symmetric, so the operands may be swapped.
fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    // Work on chars, not bytes: sizing the table with `str::len` while iterating `chars()`
    // leaves cells unfilled for non-ASCII input.
    let a: Vec<char> = s1.chars().collect();
    let b: Vec<char> = s2.chars().collect();
    let (outer, inner) = if a.len() >= b.len() { (&a, &b) } else { (&b, &a) };

    // `prev[j]` = distance between the first `i` outer chars and the first `j` inner chars.
    let mut prev: Vec<usize> = (0..=inner.len()).collect();
    let mut curr = vec![0; inner.len() + 1];

    for (i, &oc) in outer.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &ic) in inner.iter().enumerate() {
            let substitution_cost = usize::from(oc != ic);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + substitution_cost); // substitution / match
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[inner.len()]
}
252
253pub fn find_similar_classes(classes: &[ClassDefinition], threshold: f64) -> Vec<SimilarClassPair> {
254    let mut similar_pairs = Vec::new();
255
256    for i in 0..classes.len() {
257        for j in i + 1..classes.len() {
258            let result = compare_classes(&classes[i], &classes[j]);
259
260            if result.similarity >= threshold {
261                similar_pairs.push(SimilarClassPair {
262                    class1: classes[i].clone(),
263                    class2: classes[j].clone(),
264                    result,
265                });
266            }
267        }
268    }
269
270    // Sort by similarity (highest first)
271    similar_pairs.sort_by(|a, b| {
272        b.result.similarity.partial_cmp(&a.result.similarity).unwrap_or(std::cmp::Ordering::Equal)
273    });
274
275    similar_pairs
276}
277
278pub fn find_similar_classes_across_files(
279    files: &[(String, String)],
280    threshold: f64,
281) -> Vec<SimilarClassPair> {
282    let mut all_classes = Vec::new();
283
284    for (file_path, content) in files {
285        if let Ok(classes) = crate::class_extractor::extract_classes_from_code(content, file_path) {
286            all_classes.extend(classes);
287        }
288    }
289
290    find_similar_classes(&all_classes, threshold)
291}