Skip to main content

ryo_source/
ast.rs

1//! Core RustAST implementation.
2
3use std::path::Path;
4
5use proc_macro2::TokenStream;
6use quote::ToTokens;
7use syn::visit::Visit;
8use syn::visit_mut::VisitMut;
9
10use crate::error::SourceResult;
11use crate::ops;
12use crate::visitor::{IdentifierCollector, ImportCollector};
13
14/// Rust source code with parsed AST.
15///
16/// This is the main entry point for AST manipulation.
17/// It wraps a `syn::File` and provides high-level operations.
18#[derive(Debug, Clone)]
19pub struct RustAST {
20    /// The parsed AST.
21    file: syn::File,
22    /// Original source (optional, for span information).
23    source: Option<String>,
24}
25
26impl RustAST {
27    /// Parse Rust source code into an AST.
28    pub fn parse(source: &str) -> SourceResult<Self> {
29        let file = syn::parse_file(source)?;
30        Ok(Self {
31            file,
32            source: Some(source.to_string()),
33        })
34    }
35
36    /// Load and parse a Rust file.
37    pub fn from_file(path: &Path) -> SourceResult<Self> {
38        let source = std::fs::read_to_string(path)?;
39        Self::parse(&source)
40    }
41
42    /// Get a reference to the underlying syn::File.
43    pub fn file(&self) -> &syn::File {
44        &self.file
45    }
46
47    /// Get a mutable reference to the underlying syn::File.
48    pub fn file_mut(&mut self) -> &mut syn::File {
49        &mut self.file
50    }
51
52    /// Get the original source if available.
53    pub fn source(&self) -> Option<&str> {
54        self.source.as_deref()
55    }
56
57    /// Convert the AST to a pretty-printed source code.
58    pub fn to_string_pretty(&self) -> String {
59        prettyplease::unparse(&self.file)
60    }
61
62    /// Get the token stream.
63    pub fn to_token_stream(&self) -> TokenStream {
64        self.file.to_token_stream()
65    }
66
67    // ==================== Analysis ====================
68
69    /// Collect all use statements.
70    pub fn collect_imports(&self) -> Vec<&syn::ItemUse> {
71        let mut collector = ImportCollector::new();
72        collector.visit_file(&self.file);
73        collector.imports
74    }
75
76    /// Collect all identifiers used in the code (excluding imports).
77    pub fn collect_used_identifiers(&self) -> std::collections::HashSet<String> {
78        let mut collector = IdentifierCollector::new();
79        collector.visit_file(&self.file);
80        collector.identifiers
81    }
82
83    /// Find unused imports.
84    pub fn find_unused_imports(&self) -> Vec<UnusedImport> {
85        ops::RemoveUnusedImports::detect(self)
86    }
87
88    // ==================== Transformations ====================
89
90    /// Remove all unused imports. Returns the removed imports.
91    pub fn remove_unused_imports(&mut self) -> Vec<UnusedImport> {
92        ops::RemoveUnusedImports::apply(self)
93    }
94
95    /// Apply a custom visitor mutation.
96    pub fn visit_mut<V: VisitMut>(&mut self, visitor: &mut V) {
97        visitor.visit_file_mut(&mut self.file);
98    }
99
100    /// Apply a custom visitor (read-only).
101    pub fn visit<'a, V: Visit<'a>>(&'a self, visitor: &mut V) {
102        visitor.visit_file(&self.file);
103    }
104
105    // ==================== Item Access ====================
106
107    /// Get all items in the file.
108    pub fn items(&self) -> &[syn::Item] {
109        &self.file.items
110    }
111
112    /// Get mutable access to all items.
113    pub fn items_mut(&mut self) -> &mut Vec<syn::Item> {
114        &mut self.file.items
115    }
116
117    /// Filter items by type.
118    pub fn filter_items<F>(&self, predicate: F) -> Vec<&syn::Item>
119    where
120        F: Fn(&syn::Item) -> bool,
121    {
122        self.file.items.iter().filter(|i| predicate(i)).collect()
123    }
124
125    /// Get all functions.
126    pub fn functions(&self) -> Vec<&syn::ItemFn> {
127        self.file
128            .items
129            .iter()
130            .filter_map(|item| {
131                if let syn::Item::Fn(f) = item {
132                    Some(f)
133                } else {
134                    None
135                }
136            })
137            .collect()
138    }
139
140    /// Get all structs.
141    pub fn structs(&self) -> Vec<&syn::ItemStruct> {
142        self.file
143            .items
144            .iter()
145            .filter_map(|item| {
146                if let syn::Item::Struct(s) = item {
147                    Some(s)
148                } else {
149                    None
150                }
151            })
152            .collect()
153    }
154
155    /// Get all impl blocks.
156    pub fn impls(&self) -> Vec<&syn::ItemImpl> {
157        self.file
158            .items
159            .iter()
160            .filter_map(|item| {
161                if let syn::Item::Impl(i) = item {
162                    Some(i)
163                } else {
164                    None
165                }
166            })
167            .collect()
168    }
169}
170
171impl std::fmt::Display for RustAST {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        write!(f, "{}", self.to_string_pretty())
174    }
175}
176
/// Information about an unused import.
///
/// Two values are equal when both `path` and `name` match; the type is also
/// hashable, so results can be compared with `assert_eq!` or deduplicated in
/// a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UnusedImport {
    /// The imported path (e.g., "std::io").
    pub path: String,
    /// The name that would be used in code (e.g., "io").
    pub name: String,
}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone `fn main` yields exactly one top-level function.
    #[test]
    fn test_parse_simple() {
        let parsed = RustAST::parse("fn main() {}").unwrap();
        let function_count = parsed.functions().len();
        assert_eq!(function_count, 1);
    }

    /// Both `use` statements are picked up by `collect_imports`.
    #[test]
    fn test_parse_with_imports() {
        let code = "use std::io;\nuse std::fs;\nfn main() {}";
        let parsed = RustAST::parse(code).unwrap();
        let import_count = parsed.collect_imports().len();
        assert_eq!(import_count, 2);
    }

    /// The `Display` output still contains the function that was parsed.
    #[test]
    fn test_to_string() {
        let ast = RustAST::parse("fn main() {}").unwrap();
        let output = ast.to_string();
        assert!(output.contains("fn main"));
    }

    /// Path segments, locals and macro names all show up as used identifiers.
    #[test]
    fn test_collect_identifiers() {
        let code = r#"
            use std::io;
            fn main() {
                let x = io::stdin();
                println!("{}", x);
            }
            "#;
        let ast = RustAST::parse(code).unwrap();

        let idents = ast.collect_used_identifiers();
        assert!(idents.contains("io"));
        assert!(idents.contains("x"));
        assert!(idents.contains("println"));
    }
}