rust_diff_analyzer/analysis/ast_visitor.rs

1// SPDX-FileCopyrightText: 2025 RAprogramm <andrey.rozanov.vl@gmail.com>
2// SPDX-License-Identifier: MIT
3
4use proc_macro2::Span;
5use syn::{
6    Attribute, File, ImplItem, ItemConst, ItemEnum, ItemFn, ItemImpl, ItemMacro, ItemMod,
7    ItemStatic, ItemStruct, ItemTrait, ItemType, TraitItem, Visibility as SynVisibility,
8    spanned::Spanned, visit::Visit,
9};
10
11use crate::types::{LineSpan, SemanticUnit, SemanticUnitKind, Visibility};
12
13/// Visitor for extracting semantic units from Rust AST
14pub struct SemanticUnitVisitor {
15    units: Vec<SemanticUnit>,
16    in_test_module: bool,
17}
18
19impl SemanticUnitVisitor {
20    /// Creates a new semantic unit visitor
21    ///
22    /// # Returns
23    ///
24    /// A new SemanticUnitVisitor instance
25    ///
26    /// # Examples
27    ///
28    /// ```
29    /// use rust_diff_analyzer::analysis::ast_visitor::SemanticUnitVisitor;
30    ///
31    /// let visitor = SemanticUnitVisitor::new();
32    /// ```
33    pub fn new() -> Self {
34        Self {
35            units: Vec::new(),
36            in_test_module: false,
37        }
38    }
39
40    /// Extracts semantic units from a parsed AST
41    ///
42    /// # Arguments
43    ///
44    /// * `file` - Parsed syn File
45    ///
46    /// # Returns
47    ///
48    /// Vector of extracted semantic units
49    ///
50    /// # Examples
51    ///
52    /// ```
53    /// use rust_diff_analyzer::analysis::ast_visitor::SemanticUnitVisitor;
54    ///
55    /// let code = "fn main() {}";
56    /// let file = syn::parse_file(code).unwrap();
57    /// let units = SemanticUnitVisitor::extract(&file);
58    /// assert_eq!(units.len(), 1);
59    /// ```
60    pub fn extract(file: &File) -> Vec<SemanticUnit> {
61        let mut visitor = Self::new();
62        visitor.visit_file(file);
63        visitor.units
64    }
65
66    fn span_to_line_span(&self, span: Span) -> LineSpan {
67        let start = span.start();
68        let end = span.end();
69        LineSpan::new(start.line, end.line)
70    }
71
72    fn convert_visibility(&self, vis: &SynVisibility) -> Visibility {
73        match vis {
74            SynVisibility::Public(_) => Visibility::Public,
75            SynVisibility::Restricted(r) => {
76                if r.path.is_ident("crate") {
77                    Visibility::Crate
78                } else {
79                    Visibility::Restricted
80                }
81            }
82            SynVisibility::Inherited => Visibility::Private,
83        }
84    }
85
86    fn extract_attributes(&self, attrs: &[Attribute]) -> Vec<String> {
87        attrs
88            .iter()
89            .filter_map(|attr| attr.path().get_ident().map(|ident| ident.to_string()))
90            .collect()
91    }
92
93    fn has_test_attribute(&self, attrs: &[Attribute]) -> bool {
94        attrs.iter().any(|attr| {
95            let path = attr.path();
96            if path.is_ident("test") || path.is_ident("bench") {
97                return true;
98            }
99            if path.is_ident("cfg")
100                && let Ok(meta) = attr.meta.require_list()
101            {
102                let tokens = meta.tokens.to_string();
103                if tokens.contains("test") {
104                    return true;
105                }
106            }
107            false
108        })
109    }
110
111    fn is_test_module(&self, attrs: &[Attribute]) -> bool {
112        attrs.iter().any(|attr| {
113            if attr.path().is_ident("cfg")
114                && let Ok(meta) = attr.meta.require_list()
115            {
116                let tokens = meta.tokens.to_string();
117                return tokens.contains("test");
118            }
119            false
120        })
121    }
122
123    fn add_unit(
124        &mut self,
125        kind: SemanticUnitKind,
126        name: String,
127        visibility: Visibility,
128        span: Span,
129        attrs: &[Attribute],
130    ) {
131        let mut attributes = self.extract_attributes(attrs);
132
133        if self.in_test_module && !attributes.contains(&"cfg_test".to_string()) {
134            attributes.push("cfg_test".to_string());
135        }
136
137        if self.has_test_attribute(attrs) && !attributes.contains(&"test".to_string()) {
138            attributes.push("test".to_string());
139        }
140
141        let unit = SemanticUnit::new(
142            kind,
143            name,
144            visibility,
145            self.span_to_line_span(span),
146            attributes,
147        );
148        self.units.push(unit);
149    }
150}
151
152impl Default for SemanticUnitVisitor {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158impl<'ast> Visit<'ast> for SemanticUnitVisitor {
159    fn visit_item_fn(&mut self, node: &'ast ItemFn) {
160        self.add_unit(
161            SemanticUnitKind::Function,
162            node.sig.ident.to_string(),
163            self.convert_visibility(&node.vis),
164            node.span(),
165            &node.attrs,
166        );
167        syn::visit::visit_item_fn(self, node);
168    }
169
170    fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
171        self.add_unit(
172            SemanticUnitKind::Struct,
173            node.ident.to_string(),
174            self.convert_visibility(&node.vis),
175            node.span(),
176            &node.attrs,
177        );
178        syn::visit::visit_item_struct(self, node);
179    }
180
181    fn visit_item_enum(&mut self, node: &'ast ItemEnum) {
182        self.add_unit(
183            SemanticUnitKind::Enum,
184            node.ident.to_string(),
185            self.convert_visibility(&node.vis),
186            node.span(),
187            &node.attrs,
188        );
189        syn::visit::visit_item_enum(self, node);
190    }
191
192    fn visit_item_trait(&mut self, node: &'ast ItemTrait) {
193        self.add_unit(
194            SemanticUnitKind::Trait,
195            node.ident.to_string(),
196            self.convert_visibility(&node.vis),
197            node.span(),
198            &node.attrs,
199        );
200        syn::visit::visit_item_trait(self, node);
201    }
202
203    fn visit_item_impl(&mut self, node: &'ast ItemImpl) {
204        let name = if let Some((_, path, _)) = &node.trait_ {
205            format!(
206                "{} for {}",
207                path.segments
208                    .last()
209                    .map(|s| s.ident.to_string())
210                    .unwrap_or_default(),
211                type_to_string(&node.self_ty)
212            )
213        } else {
214            type_to_string(&node.self_ty)
215        };
216
217        self.add_unit(
218            SemanticUnitKind::Impl,
219            name,
220            Visibility::Private,
221            node.span(),
222            &node.attrs,
223        );
224
225        for item in &node.items {
226            match item {
227                ImplItem::Fn(method) => {
228                    self.add_unit(
229                        SemanticUnitKind::Function,
230                        method.sig.ident.to_string(),
231                        self.convert_visibility(&method.vis),
232                        method.span(),
233                        &method.attrs,
234                    );
235                }
236                ImplItem::Const(c) => {
237                    self.add_unit(
238                        SemanticUnitKind::Const,
239                        c.ident.to_string(),
240                        self.convert_visibility(&c.vis),
241                        c.span(),
242                        &c.attrs,
243                    );
244                }
245                ImplItem::Type(t) => {
246                    self.add_unit(
247                        SemanticUnitKind::TypeAlias,
248                        t.ident.to_string(),
249                        self.convert_visibility(&t.vis),
250                        t.span(),
251                        &t.attrs,
252                    );
253                }
254                _ => {}
255            }
256        }
257    }
258
259    fn visit_item_const(&mut self, node: &'ast ItemConst) {
260        self.add_unit(
261            SemanticUnitKind::Const,
262            node.ident.to_string(),
263            self.convert_visibility(&node.vis),
264            node.span(),
265            &node.attrs,
266        );
267    }
268
269    fn visit_item_static(&mut self, node: &'ast ItemStatic) {
270        self.add_unit(
271            SemanticUnitKind::Static,
272            node.ident.to_string(),
273            self.convert_visibility(&node.vis),
274            node.span(),
275            &node.attrs,
276        );
277    }
278
279    fn visit_item_type(&mut self, node: &'ast ItemType) {
280        self.add_unit(
281            SemanticUnitKind::TypeAlias,
282            node.ident.to_string(),
283            self.convert_visibility(&node.vis),
284            node.span(),
285            &node.attrs,
286        );
287    }
288
289    fn visit_item_macro(&mut self, node: &'ast ItemMacro) {
290        if let Some(ident) = &node.ident {
291            self.add_unit(
292                SemanticUnitKind::Macro,
293                ident.to_string(),
294                Visibility::Private,
295                node.span(),
296                &node.attrs,
297            );
298        }
299    }
300
301    fn visit_item_mod(&mut self, node: &'ast ItemMod) {
302        let is_test = self.is_test_module(&node.attrs) || node.ident == "tests";
303
304        self.add_unit(
305            SemanticUnitKind::Module,
306            node.ident.to_string(),
307            self.convert_visibility(&node.vis),
308            node.span(),
309            &node.attrs,
310        );
311
312        if let Some((_, items)) = &node.content {
313            let was_in_test = self.in_test_module;
314            self.in_test_module = is_test || was_in_test;
315
316            for item in items {
317                self.visit_item(item);
318            }
319
320            self.in_test_module = was_in_test;
321        }
322    }
323
324    fn visit_trait_item(&mut self, node: &'ast TraitItem) {
325        match node {
326            TraitItem::Fn(method) => {
327                self.add_unit(
328                    SemanticUnitKind::Function,
329                    method.sig.ident.to_string(),
330                    Visibility::Public,
331                    method.span(),
332                    &method.attrs,
333                );
334            }
335            TraitItem::Const(c) => {
336                self.add_unit(
337                    SemanticUnitKind::Const,
338                    c.ident.to_string(),
339                    Visibility::Public,
340                    c.span(),
341                    &c.attrs,
342                );
343            }
344            TraitItem::Type(t) => {
345                self.add_unit(
346                    SemanticUnitKind::TypeAlias,
347                    t.ident.to_string(),
348                    Visibility::Public,
349                    t.span(),
350                    &t.attrs,
351                );
352            }
353            _ => {}
354        }
355        syn::visit::visit_trait_item(self, node);
356    }
357}
358
359fn type_to_string(ty: &syn::Type) -> String {
360    match ty {
361        syn::Type::Path(p) => p
362            .path
363            .segments
364            .last()
365            .map(|s| s.ident.to_string())
366            .unwrap_or_else(|| "Unknown".to_string()),
367        _ => "Unknown".to_string(),
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as a Rust file and returns every unit the visitor extracts.
    fn units_of(code: &str) -> Vec<SemanticUnit> {
        let file = syn::parse_file(code).expect("parse failed");
        SemanticUnitVisitor::extract(&file)
    }

    /// Returns the first unit called `name`, panicking with "`name` not found" otherwise.
    fn named<'a>(units: &'a [SemanticUnit], name: &str) -> &'a SemanticUnit {
        units
            .iter()
            .find(|unit| unit.name == name)
            .unwrap_or_else(|| panic!("{name} not found"))
    }

    /// Returns whether some unit has the given name and a kind accepted by `kind_matches`.
    fn contains_unit(
        units: &[SemanticUnit],
        name: &str,
        kind_matches: impl Fn(&SemanticUnitKind) -> bool,
    ) -> bool {
        units
            .iter()
            .any(|unit| unit.name == name && kind_matches(&unit.kind))
    }

    #[test]
    fn test_extract_function() {
        let units = units_of("pub fn hello() {}");
        assert_eq!(units.len(), 1);

        let hello = &units[0];
        assert_eq!(hello.name, "hello");
        assert!(matches!(hello.kind, SemanticUnitKind::Function));
        assert!(matches!(hello.visibility, Visibility::Public));
    }

    #[test]
    fn test_extract_struct() {
        let units = units_of("struct Point { x: i32, y: i32 }");
        assert_eq!(units.len(), 1);

        let point = &units[0];
        assert_eq!(point.name, "Point");
        assert!(matches!(point.kind, SemanticUnitKind::Struct));
    }

    #[test]
    fn test_extract_test_function() {
        let units = units_of(
            r#"
            #[test]
            fn test_something() {}
        "#,
        );
        assert_eq!(units.len(), 1);
        assert!(units[0].has_attribute("test"));
    }

    #[test]
    fn test_extract_impl_block() {
        let units = units_of(
            r#"
            struct Foo;
            impl Foo {
                pub fn new() -> Self { Foo }
            }
        "#,
        );
        assert_eq!(units.len(), 3);
        assert!(contains_unit(&units, "Foo", |kind| matches!(
            kind,
            SemanticUnitKind::Struct
        )));
        assert!(contains_unit(&units, "Foo", |kind| matches!(
            kind,
            SemanticUnitKind::Impl
        )));
        assert!(contains_unit(&units, "new", |kind| matches!(
            kind,
            SemanticUnitKind::Function
        )));
    }

    #[test]
    fn test_extract_test_module() {
        let units = units_of(
            r#"
            fn production() {}

            #[cfg(test)]
            mod tests {
                fn helper() {}

                #[test]
                fn test_it() {}
            }
        "#,
        );

        assert!(!named(&units, "production").has_attribute("cfg_test"));
        assert!(named(&units, "helper").has_attribute("cfg_test"));

        let test_it = named(&units, "test_it");
        assert!(test_it.has_attribute("test"));
        assert!(test_it.has_attribute("cfg_test"));
    }
}