rust_diff_analyzer/analysis/ast_visitor.rs

1// SPDX-FileCopyrightText: 2025 RAprogramm <andrey.rozanov.vl@gmail.com>
2// SPDX-License-Identifier: MIT
3
4use proc_macro2::Span;
5use syn::{
6    Attribute, File, ImplItem, ItemConst, ItemEnum, ItemFn, ItemImpl, ItemMacro, ItemMod,
7    ItemStatic, ItemStruct, ItemTrait, ItemType, TraitItem, Visibility as SynVisibility,
8    spanned::Spanned, visit::Visit,
9};
10
11use crate::types::{LineSpan, SemanticUnit, SemanticUnitKind, Visibility};
12
13/// Visitor for extracting semantic units from Rust AST
14pub struct SemanticUnitVisitor {
15    units: Vec<SemanticUnit>,
16    in_test_module: bool,
17    current_impl_name: Option<String>,
18}
19
20impl SemanticUnitVisitor {
21    /// Creates a new semantic unit visitor
22    ///
23    /// # Returns
24    ///
25    /// A new SemanticUnitVisitor instance
26    ///
27    /// # Examples
28    ///
29    /// ```
30    /// use rust_diff_analyzer::analysis::ast_visitor::SemanticUnitVisitor;
31    ///
32    /// let visitor = SemanticUnitVisitor::new();
33    /// ```
34    pub fn new() -> Self {
35        Self {
36            units: Vec::new(),
37            in_test_module: false,
38            current_impl_name: None,
39        }
40    }
41
42    /// Extracts semantic units from a parsed AST
43    ///
44    /// # Arguments
45    ///
46    /// * `file` - Parsed syn File
47    ///
48    /// # Returns
49    ///
50    /// Vector of extracted semantic units
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use rust_diff_analyzer::analysis::ast_visitor::SemanticUnitVisitor;
56    ///
57    /// let code = "fn main() {}";
58    /// let file = syn::parse_file(code).unwrap();
59    /// let units = SemanticUnitVisitor::extract(&file);
60    /// assert_eq!(units.len(), 1);
61    /// ```
62    pub fn extract(file: &File) -> Vec<SemanticUnit> {
63        let mut visitor = Self::new();
64        visitor.visit_file(file);
65        visitor.units
66    }
67
68    fn span_to_line_span(&self, span: Span) -> LineSpan {
69        let start = span.start();
70        let end = span.end();
71        LineSpan::new(start.line, end.line)
72    }
73
74    fn convert_visibility(&self, vis: &SynVisibility) -> Visibility {
75        match vis {
76            SynVisibility::Public(_) => Visibility::Public,
77            SynVisibility::Restricted(r) => {
78                if r.path.is_ident("crate") {
79                    Visibility::Crate
80                } else {
81                    Visibility::Restricted
82                }
83            }
84            SynVisibility::Inherited => Visibility::Private,
85        }
86    }
87
88    fn extract_attributes(&self, attrs: &[Attribute]) -> Vec<String> {
89        attrs
90            .iter()
91            .filter_map(|attr| attr.path().get_ident().map(|ident| ident.to_string()))
92            .collect()
93    }
94
95    fn has_test_attribute(&self, attrs: &[Attribute]) -> bool {
96        attrs.iter().any(|attr| {
97            let path = attr.path();
98            if path.is_ident("test") || path.is_ident("bench") {
99                return true;
100            }
101            if path.is_ident("cfg")
102                && let Ok(meta) = attr.meta.require_list()
103            {
104                let tokens = meta.tokens.to_string();
105                if tokens.contains("test") {
106                    return true;
107                }
108            }
109            false
110        })
111    }
112
113    fn is_test_module(&self, attrs: &[Attribute]) -> bool {
114        attrs.iter().any(|attr| {
115            if attr.path().is_ident("cfg")
116                && let Ok(meta) = attr.meta.require_list()
117            {
118                let tokens = meta.tokens.to_string();
119                return tokens.contains("test");
120            }
121            false
122        })
123    }
124
125    fn add_unit(
126        &mut self,
127        kind: SemanticUnitKind,
128        name: String,
129        visibility: Visibility,
130        span: Span,
131        attrs: &[Attribute],
132    ) {
133        let mut attributes = self.extract_attributes(attrs);
134
135        if self.in_test_module && !attributes.contains(&"cfg_test".to_string()) {
136            attributes.push("cfg_test".to_string());
137        }
138
139        if self.has_test_attribute(attrs) && !attributes.contains(&"test".to_string()) {
140            attributes.push("test".to_string());
141        }
142
143        let unit = match &self.current_impl_name {
144            Some(impl_name) => SemanticUnit::with_impl(
145                kind,
146                name,
147                impl_name.clone(),
148                visibility,
149                self.span_to_line_span(span),
150                attributes,
151            ),
152            None => SemanticUnit::new(
153                kind,
154                name,
155                visibility,
156                self.span_to_line_span(span),
157                attributes,
158            ),
159        };
160        self.units.push(unit);
161    }
162}
163
164impl Default for SemanticUnitVisitor {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
170impl<'ast> Visit<'ast> for SemanticUnitVisitor {
171    fn visit_item_fn(&mut self, node: &'ast ItemFn) {
172        self.add_unit(
173            SemanticUnitKind::Function,
174            node.sig.ident.to_string(),
175            self.convert_visibility(&node.vis),
176            node.span(),
177            &node.attrs,
178        );
179        syn::visit::visit_item_fn(self, node);
180    }
181
182    fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
183        self.add_unit(
184            SemanticUnitKind::Struct,
185            node.ident.to_string(),
186            self.convert_visibility(&node.vis),
187            node.span(),
188            &node.attrs,
189        );
190        syn::visit::visit_item_struct(self, node);
191    }
192
193    fn visit_item_enum(&mut self, node: &'ast ItemEnum) {
194        self.add_unit(
195            SemanticUnitKind::Enum,
196            node.ident.to_string(),
197            self.convert_visibility(&node.vis),
198            node.span(),
199            &node.attrs,
200        );
201        syn::visit::visit_item_enum(self, node);
202    }
203
204    fn visit_item_trait(&mut self, node: &'ast ItemTrait) {
205        self.add_unit(
206            SemanticUnitKind::Trait,
207            node.ident.to_string(),
208            self.convert_visibility(&node.vis),
209            node.span(),
210            &node.attrs,
211        );
212        syn::visit::visit_item_trait(self, node);
213    }
214
215    fn visit_item_impl(&mut self, node: &'ast ItemImpl) {
216        let impl_name = if let Some((_, path, _)) = &node.trait_ {
217            format!(
218                "{} for {}",
219                path.segments
220                    .last()
221                    .map(|s| s.ident.to_string())
222                    .unwrap_or_default(),
223                type_to_string(&node.self_ty)
224            )
225        } else {
226            type_to_string(&node.self_ty)
227        };
228
229        self.add_unit(
230            SemanticUnitKind::Impl,
231            impl_name.clone(),
232            Visibility::Private,
233            node.span(),
234            &node.attrs,
235        );
236
237        let previous_impl_name = self.current_impl_name.take();
238        self.current_impl_name = Some(impl_name);
239
240        for item in &node.items {
241            match item {
242                ImplItem::Fn(method) => {
243                    self.add_unit(
244                        SemanticUnitKind::Function,
245                        method.sig.ident.to_string(),
246                        self.convert_visibility(&method.vis),
247                        method.span(),
248                        &method.attrs,
249                    );
250                }
251                ImplItem::Const(c) => {
252                    self.add_unit(
253                        SemanticUnitKind::Const,
254                        c.ident.to_string(),
255                        self.convert_visibility(&c.vis),
256                        c.span(),
257                        &c.attrs,
258                    );
259                }
260                ImplItem::Type(t) => {
261                    self.add_unit(
262                        SemanticUnitKind::TypeAlias,
263                        t.ident.to_string(),
264                        self.convert_visibility(&t.vis),
265                        t.span(),
266                        &t.attrs,
267                    );
268                }
269                _ => {}
270            }
271        }
272
273        self.current_impl_name = previous_impl_name;
274    }
275
276    fn visit_item_const(&mut self, node: &'ast ItemConst) {
277        self.add_unit(
278            SemanticUnitKind::Const,
279            node.ident.to_string(),
280            self.convert_visibility(&node.vis),
281            node.span(),
282            &node.attrs,
283        );
284    }
285
286    fn visit_item_static(&mut self, node: &'ast ItemStatic) {
287        self.add_unit(
288            SemanticUnitKind::Static,
289            node.ident.to_string(),
290            self.convert_visibility(&node.vis),
291            node.span(),
292            &node.attrs,
293        );
294    }
295
296    fn visit_item_type(&mut self, node: &'ast ItemType) {
297        self.add_unit(
298            SemanticUnitKind::TypeAlias,
299            node.ident.to_string(),
300            self.convert_visibility(&node.vis),
301            node.span(),
302            &node.attrs,
303        );
304    }
305
306    fn visit_item_macro(&mut self, node: &'ast ItemMacro) {
307        if let Some(ident) = &node.ident {
308            self.add_unit(
309                SemanticUnitKind::Macro,
310                ident.to_string(),
311                Visibility::Private,
312                node.span(),
313                &node.attrs,
314            );
315        }
316    }
317
318    fn visit_item_mod(&mut self, node: &'ast ItemMod) {
319        let is_test = self.is_test_module(&node.attrs) || node.ident == "tests";
320
321        self.add_unit(
322            SemanticUnitKind::Module,
323            node.ident.to_string(),
324            self.convert_visibility(&node.vis),
325            node.span(),
326            &node.attrs,
327        );
328
329        if let Some((_, items)) = &node.content {
330            let was_in_test = self.in_test_module;
331            self.in_test_module = is_test || was_in_test;
332
333            for item in items {
334                self.visit_item(item);
335            }
336
337            self.in_test_module = was_in_test;
338        }
339    }
340
341    fn visit_trait_item(&mut self, node: &'ast TraitItem) {
342        match node {
343            TraitItem::Fn(method) => {
344                self.add_unit(
345                    SemanticUnitKind::Function,
346                    method.sig.ident.to_string(),
347                    Visibility::Public,
348                    method.span(),
349                    &method.attrs,
350                );
351            }
352            TraitItem::Const(c) => {
353                self.add_unit(
354                    SemanticUnitKind::Const,
355                    c.ident.to_string(),
356                    Visibility::Public,
357                    c.span(),
358                    &c.attrs,
359                );
360            }
361            TraitItem::Type(t) => {
362                self.add_unit(
363                    SemanticUnitKind::TypeAlias,
364                    t.ident.to_string(),
365                    Visibility::Public,
366                    t.span(),
367                    &t.attrs,
368                );
369            }
370            _ => {}
371        }
372        syn::visit::visit_trait_item(self, node);
373    }
374}
375
376fn type_to_string(ty: &syn::Type) -> String {
377    match ty {
378        syn::Type::Path(p) => p
379            .path
380            .segments
381            .last()
382            .map(|s| s.ident.to_string())
383            .unwrap_or_else(|| "Unknown".to_string()),
384        _ => "Unknown".to_string(),
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as a Rust file and returns every unit the visitor records.
    fn units_for(code: &str) -> Vec<SemanticUnit> {
        let file = syn::parse_file(code).expect("parse failed");
        SemanticUnitVisitor::extract(&file)
    }

    /// Returns the first unit called `name`, panicking with `"<name> not found"` otherwise.
    fn unit_named<'a>(units: &'a [SemanticUnit], name: &str) -> &'a SemanticUnit {
        units
            .iter()
            .find(|unit| unit.name == name)
            .unwrap_or_else(|| panic!("{name} not found"))
    }

    #[test]
    fn test_extract_function() {
        let units = units_for("pub fn hello() {}");
        assert_eq!(units.len(), 1);

        let hello = &units[0];
        assert_eq!(hello.name, "hello");
        assert!(matches!(hello.kind, SemanticUnitKind::Function));
        assert!(matches!(hello.visibility, Visibility::Public));
    }

    #[test]
    fn test_extract_struct() {
        let units = units_for("struct Point { x: i32, y: i32 }");
        assert_eq!(units.len(), 1);

        let point = &units[0];
        assert_eq!(point.name, "Point");
        assert!(matches!(point.kind, SemanticUnitKind::Struct));
    }

    #[test]
    fn test_extract_test_function() {
        let units = units_for(
            r#"
            #[test]
            fn test_something() {}
        "#,
        );
        assert_eq!(units.len(), 1);
        assert!(units[0].has_attribute("test"));
    }

    #[test]
    fn test_extract_impl_block() {
        let units = units_for(
            r#"
            struct Foo;
            impl Foo {
                pub fn new() -> Self { Foo }
            }
        "#,
        );
        assert_eq!(units.len(), 3);

        let has_unit = |name: &str, is_kind: fn(&SemanticUnitKind) -> bool| {
            units.iter().any(|unit| unit.name == name && is_kind(&unit.kind))
        };
        assert!(has_unit("Foo", |kind| matches!(kind, SemanticUnitKind::Struct)));
        assert!(has_unit("Foo", |kind| matches!(kind, SemanticUnitKind::Impl)));
        assert!(has_unit("new", |kind| matches!(kind, SemanticUnitKind::Function)));
    }

    #[test]
    fn test_extract_test_module() {
        let units = units_for(
            r#"
            fn production() {}

            #[cfg(test)]
            mod tests {
                fn helper() {}

                #[test]
                fn test_it() {}
            }
        "#,
        );

        assert!(!unit_named(&units, "production").has_attribute("cfg_test"));
        assert!(unit_named(&units, "helper").has_attribute("cfg_test"));

        let test_it = unit_named(&units, "test_it");
        assert!(test_it.has_attribute("test"));
        assert!(test_it.has_attribute("cfg_test"));
    }
}