// similarity_core/class_extractor.rs

1use oxc_allocator::Allocator;
2use oxc_ast::ast::{ClassElement, MethodDefinitionKind, Statement};
3use oxc_parser::Parser;
4use oxc_span::SourceType;
5
6use crate::ignore_directive::has_similarity_ignore_directive;
7
8#[derive(Debug, Clone)]
9pub struct ClassDefinition {
10    pub name: String,
11    pub properties: Vec<ClassProperty>,
12    pub methods: Vec<ClassMethod>,
13    pub constructor_params: Vec<String>,
14    pub extends: Option<String>,
15    pub implements: Vec<String>,
16    pub start_line: usize,
17    pub end_line: usize,
18    pub file_path: String,
19    pub is_abstract: bool,
20    pub has_ignore_directive: bool,
21}
22
/// One property member of an extracted class.
#[derive(Debug, Clone)]
pub struct ClassProperty {
    /// Property name taken from its identifier or string-literal key.
    pub name: String,
    /// Rendered type annotation; `"any"` when the property has none.
    pub type_annotation: String,
    /// Whether the property is declared `static`.
    pub is_static: bool,
    /// Private-visibility flag as recorded by the extractor.
    pub is_private: bool,
    /// Whether the property is declared `readonly`.
    pub is_readonly: bool,
    /// Whether the property is marked optional (`?`).
    pub is_optional: bool,
}
32
33#[derive(Debug, Clone)]
34pub struct ClassMethod {
35    pub name: String,
36    pub parameters: Vec<String>,
37    pub return_type: String,
38    pub is_static: bool,
39    pub is_private: bool,
40    pub is_async: bool,
41    pub is_generator: bool,
42    pub kind: MethodKind,
43}
44
/// Flavour of a method-like class member.
///
/// The enum is fieldless, so it is cheap to copy and supports total equality;
/// it therefore derives `Copy` and `Eq` in addition to `Clone`/`PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MethodKind {
    /// An ordinary method.
    Method,
    /// A `get` accessor.
    Getter,
    /// A `set` accessor.
    Setter,
    /// The class constructor.
    Constructor,
}
52
/// Holds one source file together with a precomputed line-start table so that
/// byte offsets can be translated into line numbers.
struct ClassExtractor {
    /// Full text of the file being analysed.
    source_text: String,
    /// Path of the file; copied into every extracted `ClassDefinition`.
    file_path: String,
    /// Byte offset at which each line of `source_text` begins (first entry is 0).
    line_offsets: Vec<usize>,
}
58
59impl ClassExtractor {
60    fn new(source_text: String, file_path: String) -> Self {
61        let line_offsets = Self::calculate_line_offsets(&source_text);
62        Self { source_text, file_path, line_offsets }
63    }
64
65    fn calculate_line_offsets(source: &str) -> Vec<usize> {
66        let mut offsets = vec![0];
67        for (i, ch) in source.char_indices() {
68            if ch == '\n' {
69                offsets.push(i + 1);
70            }
71        }
72        offsets
73    }
74
75    fn get_line_number(&self, offset: usize) -> usize {
76        match self.line_offsets.binary_search(&offset) {
77            Ok(line) => line + 1,
78            Err(line) => line,
79        }
80    }
81
82    fn extract_type_string(&self, type_annotation: &oxc_ast::ast::TSTypeAnnotation) -> String {
83        use oxc_ast::ast::TSType;
84
85        match &type_annotation.type_annotation {
86            TSType::TSStringKeyword(_) => "string".to_string(),
87            TSType::TSNumberKeyword(_) => "number".to_string(),
88            TSType::TSBooleanKeyword(_) => "boolean".to_string(),
89            TSType::TSAnyKeyword(_) => "any".to_string(),
90            TSType::TSUnknownKeyword(_) => "unknown".to_string(),
91            TSType::TSNeverKeyword(_) => "never".to_string(),
92            TSType::TSVoidKeyword(_) => "void".to_string(),
93            TSType::TSUndefinedKeyword(_) => "undefined".to_string(),
94            TSType::TSNullKeyword(_) => "null".to_string(),
95            TSType::TSArrayType(array) => {
96                format!("{}[]", self.extract_type_string_from_ts_type(&array.element_type))
97            }
98            TSType::TSTypeReference(type_ref) => match &type_ref.type_name {
99                oxc_ast::ast::TSTypeName::IdentifierReference(ident) => {
100                    let base = ident.name.as_str();
101                    if let Some(params) = &type_ref.type_arguments {
102                        let param_strings: Vec<String> = params
103                            .params
104                            .iter()
105                            .map(|p| self.extract_type_string_from_ts_type(p))
106                            .collect();
107                        format!("{}<{}>", base, param_strings.join(", "))
108                    } else {
109                        base.to_string()
110                    }
111                }
112                _ => "unknown".to_string(),
113            },
114            TSType::TSUnionType(union) => {
115                let types: Vec<String> =
116                    union.types.iter().map(|t| self.extract_type_string_from_ts_type(t)).collect();
117                types.join(" | ")
118            }
119            TSType::TSIntersectionType(intersection) => {
120                let types: Vec<String> = intersection
121                    .types
122                    .iter()
123                    .map(|t| self.extract_type_string_from_ts_type(t))
124                    .collect();
125                types.join(" & ")
126            }
127            TSType::TSFunctionType(func) => {
128                let params = self.extract_function_params(&func.params);
129                let return_type =
130                    self.extract_type_string_from_ts_type(&func.return_type.type_annotation);
131                format!("({}) => {}", params, return_type)
132            }
133            TSType::TSTypeLiteral(literal) => {
134                let props: Vec<String> = literal
135                    .members
136                    .iter()
137                    .filter_map(|member| {
138                        if let oxc_ast::ast::TSSignature::TSPropertySignature(prop) = member {
139                            let name = match &prop.key {
140                                oxc_ast::ast::PropertyKey::StaticIdentifier(ident) => {
141                                    ident.name.as_str().to_string()
142                                }
143                                oxc_ast::ast::PropertyKey::StringLiteral(str_lit) => {
144                                    str_lit.value.as_str().to_string()
145                                }
146                                _ => return None,
147                            };
148                            let type_str = prop
149                                .type_annotation
150                                .as_ref()
151                                .map(|ta| self.extract_type_string(ta))
152                                .unwrap_or_else(|| "any".to_string());
153                            let optional = if prop.optional { "?" } else { "" };
154                            Some(format!("{}{}: {}", name, optional, type_str))
155                        } else {
156                            None
157                        }
158                    })
159                    .collect();
160                format!("{{ {} }}", props.join(", "))
161            }
162            _ => "any".to_string(),
163        }
164    }
165
166    fn extract_type_string_from_ts_type(&self, ts_type: &oxc_ast::ast::TSType) -> String {
167        use oxc_ast::ast::TSType;
168
169        match ts_type {
170            TSType::TSStringKeyword(_) => "string".to_string(),
171            TSType::TSNumberKeyword(_) => "number".to_string(),
172            TSType::TSBooleanKeyword(_) => "boolean".to_string(),
173            TSType::TSAnyKeyword(_) => "any".to_string(),
174            TSType::TSUnknownKeyword(_) => "unknown".to_string(),
175            TSType::TSNeverKeyword(_) => "never".to_string(),
176            TSType::TSVoidKeyword(_) => "void".to_string(),
177            TSType::TSUndefinedKeyword(_) => "undefined".to_string(),
178            TSType::TSNullKeyword(_) => "null".to_string(),
179            TSType::TSArrayType(array) => {
180                format!("{}[]", self.extract_type_string_from_ts_type(&array.element_type))
181            }
182            TSType::TSTypeReference(type_ref) => match &type_ref.type_name {
183                oxc_ast::ast::TSTypeName::IdentifierReference(ident) => {
184                    ident.name.as_str().to_string()
185                }
186                _ => "unknown".to_string(),
187            },
188            TSType::TSUnionType(union) => {
189                let types: Vec<String> =
190                    union.types.iter().map(|t| self.extract_type_string_from_ts_type(t)).collect();
191                types.join(" | ")
192            }
193            TSType::TSIntersectionType(intersection) => {
194                let types: Vec<String> = intersection
195                    .types
196                    .iter()
197                    .map(|t| self.extract_type_string_from_ts_type(t))
198                    .collect();
199                types.join(" & ")
200            }
201            TSType::TSFunctionType(func) => {
202                let params = self.extract_function_params(&func.params);
203                let return_type =
204                    self.extract_type_string_from_ts_type(&func.return_type.type_annotation);
205                format!("({}) => {}", params, return_type)
206            }
207            TSType::TSTypeLiteral(literal) => {
208                let props: Vec<String> = literal
209                    .members
210                    .iter()
211                    .filter_map(|member| {
212                        if let oxc_ast::ast::TSSignature::TSPropertySignature(prop) = member {
213                            let name = match &prop.key {
214                                oxc_ast::ast::PropertyKey::StaticIdentifier(ident) => {
215                                    ident.name.as_str().to_string()
216                                }
217                                oxc_ast::ast::PropertyKey::StringLiteral(str_lit) => {
218                                    str_lit.value.as_str().to_string()
219                                }
220                                _ => return None,
221                            };
222                            let type_str = prop
223                                .type_annotation
224                                .as_ref()
225                                .map(|ta| self.extract_type_string(ta))
226                                .unwrap_or_else(|| "any".to_string());
227                            let optional = if prop.optional { "?" } else { "" };
228                            Some(format!("{}{}: {}", name, optional, type_str))
229                        } else {
230                            None
231                        }
232                    })
233                    .collect();
234                format!("{{ {} }}", props.join(", "))
235            }
236            _ => "any".to_string(),
237        }
238    }
239
240    fn extract_function_params(&self, params: &oxc_ast::ast::FormalParameters) -> String {
241        let param_strings: Vec<String> = params
242            .items
243            .iter()
244            .map(|param| {
245                let name = match &param.pattern {
246                    oxc_ast::ast::BindingPattern::BindingIdentifier(ident) => ident.name.as_str(),
247                    _ => "param",
248                };
249                let type_str = param
250                    .type_annotation
251                    .as_ref()
252                    .map(|ta| self.extract_type_string(ta))
253                    .unwrap_or_else(|| "any".to_string());
254                format!("{}: {}", name, type_str)
255            })
256            .collect();
257        param_strings.join(", ")
258    }
259
260    fn extract_class(&self, class: &oxc_ast::ast::Class) -> ClassDefinition {
261        let name = class
262            .id
263            .as_ref()
264            .map(|id| id.name.as_str().to_string())
265            .unwrap_or_else(|| "AnonymousClass".to_string());
266
267        let start_line = self.get_line_number(class.span.start as usize);
268        let end_line = self.get_line_number(class.span.end as usize);
269
270        let extends = class.super_class.as_ref().and_then(|super_class| {
271            if let oxc_ast::ast::Expression::Identifier(ident) = super_class {
272                Some(ident.name.as_str().to_string())
273            } else {
274                None
275            }
276        });
277
278        let implements = class
279            .implements
280            .iter()
281            .filter_map(|impl_clause| match &impl_clause.expression {
282                oxc_ast::ast::TSTypeName::IdentifierReference(ident) => {
283                    Some(ident.name.as_str().to_string())
284                }
285                _ => None,
286            })
287            .collect();
288
289        let mut properties = Vec::new();
290        let mut methods = Vec::new();
291        let mut constructor_params = Vec::new();
292
293        for element in &class.body.body {
294            match element {
295                ClassElement::PropertyDefinition(prop) => {
296                    let name = match &prop.key {
297                        oxc_ast::ast::PropertyKey::StaticIdentifier(ident) => {
298                            ident.name.as_str().to_string()
299                        }
300                        oxc_ast::ast::PropertyKey::StringLiteral(str_lit) => {
301                            str_lit.value.as_str().to_string()
302                        }
303                        _ => continue,
304                    };
305
306                    let type_annotation = prop
307                        .type_annotation
308                        .as_ref()
309                        .map(|ta| self.extract_type_string(ta))
310                        .unwrap_or_else(|| "any".to_string());
311
312                    properties.push(ClassProperty {
313                        name,
314                        type_annotation,
315                        is_static: prop.r#static,
316                        is_private: false, // PropertyDefinitionType doesn't have TSPrivateProperty
317                        is_readonly: prop.readonly,
318                        is_optional: prop.optional,
319                    });
320                }
321                ClassElement::MethodDefinition(method) => {
322                    let name = match &method.key {
323                        oxc_ast::ast::PropertyKey::StaticIdentifier(ident) => {
324                            ident.name.as_str().to_string()
325                        }
326                        oxc_ast::ast::PropertyKey::StringLiteral(str_lit) => {
327                            str_lit.value.as_str().to_string()
328                        }
329                        _ => continue,
330                    };
331
332                    let kind = match method.kind {
333                        MethodDefinitionKind::Constructor => {
334                            // Extract constructor parameters
335                            constructor_params = method
336                                .value
337                                .params
338                                .items
339                                .iter()
340                                .map(|param| {
341                                    let param_name = match &param.pattern {
342                                        oxc_ast::ast::BindingPattern::BindingIdentifier(ident) => {
343                                            ident.name.as_str()
344                                        }
345                                        _ => "param",
346                                    };
347                                    let type_str = param
348                                        .type_annotation
349                                        .as_ref()
350                                        .map(|ta| self.extract_type_string(ta))
351                                        .unwrap_or_else(|| "any".to_string());
352                                    format!("{}: {}", param_name, type_str)
353                                })
354                                .collect();
355                            MethodKind::Constructor
356                        }
357                        MethodDefinitionKind::Method => MethodKind::Method,
358                        MethodDefinitionKind::Get => MethodKind::Getter,
359                        MethodDefinitionKind::Set => MethodKind::Setter,
360                    };
361
362                    if kind != MethodKind::Constructor {
363                        let parameters = self.extract_function_params(&method.value.params);
364                        let return_type = method
365                            .value
366                            .return_type
367                            .as_ref()
368                            .map(|rt| self.extract_type_string_from_ts_type(&rt.type_annotation))
369                            .unwrap_or_else(|| "void".to_string());
370
371                        methods.push(ClassMethod {
372                            name,
373                            parameters: vec![parameters],
374                            return_type,
375                            is_static: method.r#static,
376                            is_private: false, // Would need to check for private keyword
377                            is_async: method.value.r#async,
378                            is_generator: method.value.generator,
379                            kind,
380                        });
381                    }
382                }
383                _ => {}
384            }
385        }
386
387        ClassDefinition {
388            name,
389            properties,
390            methods,
391            constructor_params,
392            extends,
393            implements,
394            start_line,
395            end_line,
396            file_path: self.file_path.clone(),
397            is_abstract: class.r#abstract,
398            has_ignore_directive: has_similarity_ignore_directive(&self.source_text, start_line),
399        }
400    }
401
402    pub fn extract_classes(&self) -> Result<Vec<ClassDefinition>, String> {
403        let allocator = Allocator::default();
404        let source_type = SourceType::from_path(&self.file_path).unwrap_or(SourceType::tsx());
405        let ret = Parser::new(&allocator, &self.source_text, source_type).parse();
406
407        if !ret.errors.is_empty() {
408            let error_messages: Vec<String> =
409                ret.errors.iter().map(|e| format!("{:?}", e)).collect();
410            return Err(format!("Parse errors: {}", error_messages.join(", ")));
411        }
412
413        let mut classes = Vec::new();
414
415        // Walk through all statements and find classes
416        for statement in &ret.program.body {
417            match statement {
418                Statement::ExportDefaultDeclaration(export) => {
419                    if let oxc_ast::ast::ExportDefaultDeclarationKind::ClassDeclaration(class) =
420                        &export.declaration
421                    {
422                        classes.push(self.extract_class(class));
423                    }
424                }
425                Statement::ExportNamedDeclaration(export) => {
426                    if let Some(oxc_ast::ast::Declaration::ClassDeclaration(class)) =
427                        &export.declaration
428                    {
429                        classes.push(self.extract_class(class));
430                    }
431                }
432                Statement::ClassDeclaration(class) => {
433                    classes.push(self.extract_class(class));
434                }
435                _ => {}
436            }
437        }
438
439        Ok(classes)
440    }
441}
442
443pub fn extract_classes_from_code(
444    code: &str,
445    file_path: &str,
446) -> Result<Vec<ClassDefinition>, String> {
447    let extractor = ClassExtractor::new(code.to_string(), file_path.to_string());
448    extractor.extract_classes()
449}
450
451pub fn extract_classes_from_files(files: &[(String, String)]) -> Vec<ClassDefinition> {
452    let mut all_classes = Vec::new();
453
454    for (file_path, content) in files {
455        match extract_classes_from_code(content, file_path) {
456            Ok(classes) => all_classes.extend(classes),
457            Err(e) => eprintln!("Error extracting classes from {}: {}", file_path, e),
458        }
459    }
460
461    all_classes
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// A class preceded by a `// similarity-ignore` comment is flagged, while
    /// a class without one is not.
    #[test]
    fn test_extract_classes_marks_similarity_ignore_directives() {
        let source = r#"
class ActiveService {
    run(): void {}
}

// similarity-ignore
class IgnoredService {
    run(): void {}
}
"#;

        let classes = extract_classes_from_code(source, "test.ts").unwrap();

        // Looks a class up by name and reports its ignore flag.
        let ignore_flag_of = |wanted: &str| {
            classes.iter().find(|class| class.name == wanted).unwrap().has_ignore_directive
        };

        assert!(!ignore_flag_of("ActiveService"));
        assert!(ignore_flag_of("IgnoredService"));
    }
}