// rusty_ast/text_visitor.rs

1use quote::ToTokens;
2use std::fs;
3use std::io;
4use std::path::Path;
5
6use syn::{File, visit::Visit};
7
/// A syntax-tree visitor that prints a Rust AST as indented, human-readable
/// text on standard output.
///
/// # Fields
/// * `indent`: usize - number of spaces placed in front of each printed line;
///   child nodes are printed two spaces deeper than their parent.
#[derive(Debug, Default)]
pub struct TextVisitor {
    indent: usize,
}

16/// # Methods
17/// * `new()`: creates a new TextVisitor
18/// * `print_indent()`: prints the current indentation level
19impl TextVisitor {
20    /// new
21    ///
22    /// # Arguments
23    /// * `()`
24    ///
25    /// # Returns
26    /// * `TextVisitor` - a new TextVisitor
27    pub fn new() -> Self {
28        TextVisitor { indent: 0 }
29    }
30
31    /// print_indent
32    ///
33    /// # Arguments
34    /// * `self`: &Self - the TextVisitor
35    ///
36    /// # Returns
37    /// * `String` - the current indentation level
38    fn print_indent(&self) -> String {
39        " ".repeat(self.indent)
40    }
41}
42
43/// implement Visit trait for AstText
44/// Visit trait is defined in syn::visit
45impl<'ast> syn::visit::Visit<'ast> for TextVisitor {
46    /// visit_item_fn is defined in syn::visit::Visit
47    /// visit_item_fn is called when a Rust function definition is visited
48    ///
49    /// # Arguments
50    /// * `node`: &'ast syn::ItemFn
51    ///
52    /// # Returns
53    /// * `()`
54    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
55        println!("{}Function: {}", self.print_indent(), node.sig.ident);
56        self.indent += 2;
57
58        if !node.sig.inputs.is_empty() {
59            println!("{}Parameters:", self.print_indent());
60            self.indent += 2;
61            for param in &node.sig.inputs {
62                match param {
63                    syn::FnArg::Typed(pat_type) => {
64                        if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
65                            println!(
66                                "{}Parameter: {} - Type: {}",
67                                self.print_indent(),
68                                pat_ident.ident,
69                                pat_type.ty.to_token_stream()
70                            );
71                        }
72                    }
73                    syn::FnArg::Receiver(receiver) => {
74                        println!(
75                            "{}Self receiver: {}",
76                            self.print_indent(),
77                            receiver.to_token_stream()
78                        );
79                    }
80                }
81            }
82            self.indent -= 2;
83        }
84
85        if let syn::ReturnType::Type(_, return_type) = &node.sig.output {
86            println!(
87                "{}Return type: {}",
88                self.print_indent(),
89                return_type.to_token_stream()
90            );
91        }
92
93        println!("{}Body:", self.print_indent());
94        self.indent += 2;
95        for stmt in &node.block.stmts {
96            self.visit_stmt(stmt);
97        }
98        self.indent -= 2;
99    }
100
101    /// visit_expr is defined in syn::visit::Visit
102    /// visit_expr is called when a Rust expression is visited
103    ///
104    /// # Arguments
105    /// * `node`: &'ast syn::Expr
106    ///
107    /// # Returns
108    /// * `()`
109    fn visit_expr(&mut self, node: &'ast syn::Expr) {
110        match node {
111            syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
112                syn::Lit::Int(lit_int) => {
113                    println!(
114                        "{}Integer literal: {}",
115                        self.print_indent(),
116                        lit_int.base10_digits()
117                    );
118                }
119                syn::Lit::Float(lit_float) => {
120                    println!(
121                        "{}Float literal: {}",
122                        self.print_indent(),
123                        lit_float.base10_digits()
124                    );
125                }
126                syn::Lit::Str(lit_str) => {
127                    println!(
128                        "{}String literal: \"{}\"",
129                        self.print_indent(),
130                        lit_str.value()
131                    );
132                }
133                syn::Lit::Bool(lit_bool) => {
134                    println!("{}Boolean literal: {}", self.print_indent(), lit_bool.value);
135                }
136                _ => {
137                    println!(
138                        "{}Other literal: {}",
139                        self.print_indent(),
140                        expr_lit.to_token_stream()
141                    );
142                }
143            },
144            syn::Expr::Binary(expr_bin) => {
145                let op = match expr_bin.op {
146                    syn::BinOp::Add(_) => "+",
147                    syn::BinOp::Sub(_) => "-",
148                    syn::BinOp::Mul(_) => "*",
149                    syn::BinOp::Div(_) => "/",
150                    syn::BinOp::Eq(_) => "==",
151                    syn::BinOp::Lt(_) => "<",
152                    syn::BinOp::Le(_) => "<=",
153                    syn::BinOp::Ne(_) => "!=",
154                    syn::BinOp::Ge(_) => ">=",
155                    syn::BinOp::Gt(_) => ">",
156                    _ => "other_operator",
157                };
158                println!("{}Binary expression: {}", self.print_indent(), op);
159
160                println!("{}Left:", self.print_indent());
161                self.indent += 2;
162                self.visit_expr(&expr_bin.left);
163                self.indent -= 2;
164
165                println!("{}Right:", self.print_indent());
166                self.indent += 2;
167                self.visit_expr(&expr_bin.right);
168                self.indent -= 2;
169            }
170            syn::Expr::Call(expr_call) => {
171                println!("{}Function call:", self.print_indent());
172
173                println!("{}Function:", self.print_indent());
174                self.indent += 2;
175                self.visit_expr(&expr_call.func);
176                self.indent -= 2;
177
178                if !expr_call.args.is_empty() {
179                    println!("{}Arguments:", self.print_indent());
180                    self.indent += 2;
181                    for arg in &expr_call.args {
182                        self.visit_expr(arg);
183                    }
184                    self.indent -= 2;
185                }
186            }
187            syn::Expr::Path(expr_path) => {
188                println!(
189                    "{}Identifier: {}",
190                    self.print_indent(),
191                    expr_path.to_token_stream()
192                );
193            }
194            syn::Expr::If(expr_if) => {
195                println!("{}If statement:", self.print_indent());
196
197                println!("{}Condition:", self.print_indent());
198                self.indent += 2;
199                self.visit_expr(&expr_if.cond);
200                self.indent -= 2;
201
202                println!("{}Then branch:", self.print_indent());
203                self.indent += 2;
204                for stmt in &expr_if.then_branch.stmts {
205                    self.visit_stmt(stmt);
206                }
207                self.indent -= 2;
208
209                if let Some((_, else_branch)) = &expr_if.else_branch {
210                    println!("{}Else branch:", self.print_indent());
211                    self.indent += 2;
212                    self.visit_expr(&else_branch);
213                    self.indent -= 2;
214                }
215            }
216            syn::Expr::Loop(expr_loop) => {
217                println!("{}Loop:", self.print_indent());
218                self.indent += 2;
219                for stmt in &expr_loop.body.stmts {
220                    self.visit_stmt(stmt);
221                }
222                self.indent -= 2;
223            }
224            syn::Expr::While(expr_while) => {
225                println!("{}While loop:", self.print_indent());
226
227                println!("{}Condition:", self.print_indent());
228                self.indent += 2;
229                self.visit_expr(&expr_while.cond);
230                self.indent -= 2;
231
232                println!("{}Body:", self.print_indent());
233                self.indent += 2;
234                for stmt in &expr_while.body.stmts {
235                    self.visit_stmt(stmt);
236                }
237                self.indent -= 2;
238            }
239            syn::Expr::Return(expr_return) => {
240                println!("{}Return statement:", self.print_indent());
241                if let Some(expr) = &expr_return.expr {
242                    self.indent += 2;
243                    self.visit_expr(expr);
244                    self.indent -= 2;
245                }
246            }
247            _ => {
248                println!(
249                    "{}Other expression: {}",
250                    self.print_indent(),
251                    node.to_token_stream()
252                );
253            }
254        }
255    }
256
257    /// visit_stmt is defined in syn::visit::Visit
258    /// visit_stmt is called when a Rust statement is visited
259    ///
260    /// # Arguments
261    /// * `node`: &'ast syn::Stmt
262    ///
263    /// # Returns
264    /// * `()`
265    fn visit_stmt(&mut self, node: &'ast syn::Stmt) {
266        match node {
267            syn::Stmt::Local(local) => {
268                println!("{}Variable declaration:", self.print_indent());
269                if let syn::Pat::Ident(pat_ident) = &local.pat {
270                    println!("{}Name: {}", self.print_indent(), pat_ident.ident);
271                }
272
273                if let Some(init) = &local.init {
274                    println!("{}Initializer:", self.print_indent());
275                    self.indent += 2;
276                    self.visit_expr(&init.expr);
277                    self.indent -= 2;
278                }
279            }
280
281            syn::Stmt::Expr(expr, _) => {
282                println!("{}Expression statement:", self.print_indent());
283                self.indent += 2;
284                self.visit_expr(expr);
285                self.indent -= 2;
286            }
287
288            syn::Stmt::Item(item) => match item {
289                syn::Item::Fn(item_fn) => {
290                    self.visit_item_fn(item_fn);
291                }
292                syn::Item::Struct(item_struct) => {
293                    println!("{}Struct: {}", self.print_indent(), item_struct.ident);
294                    if !item_struct.fields.is_empty() {
295                        println!("{}Fields:", self.print_indent());
296                        self.indent += 2;
297                        for field in &item_struct.fields {
298                            if let Some(ident) = &field.ident {
299                                println!(
300                                    "{}Field: {} - Type: {}",
301                                    self.print_indent(),
302                                    ident,
303                                    field.ty.to_token_stream()
304                                );
305                            } else {
306                                println!(
307                                    "{}Tuple field: {}",
308                                    self.print_indent(),
309                                    field.ty.to_token_stream()
310                                );
311                            }
312                        }
313                        self.indent -= 2;
314                    }
315                }
316                syn::Item::Enum(item_enum) => {
317                    println!("{}Enum: {}", self.print_indent(), item_enum.ident);
318                    if !item_enum.variants.is_empty() {
319                        println!("{}Variants:", self.print_indent());
320                        self.indent += 2;
321                        for variant in &item_enum.variants {
322                            println!("{}Variant: {}", self.print_indent(), variant.ident);
323                        }
324                        self.indent -= 2;
325                    }
326                }
327                _ => {
328                    println!(
329                        "{}Other item: {}",
330                        self.print_indent(),
331                        item.to_token_stream()
332                    );
333                }
334            },
335            // TODO: add other statement
336            _ => {
337                println!(
338                    "{}Other statement: {}",
339                    self.print_indent(),
340                    node.to_token_stream()
341                );
342            }
343        }
344    }
345}
346
347/// parse rust source code to ast
348///
349/// # Arguments
350/// * `source`: &str - rust source code
351///
352/// # Returns
353/// * `Result<syn::File, syn::Error>` - ast
354///
355/// # Errors
356/// * `syn::Error` - parse error
357pub fn parse_rust_source(source: &str) -> Result<syn::File, syn::Error> {
358    syn::parse_file(source)
359}
360
361/// Parse Rust source code from a file into an AST
362///
363/// # Arguments
364/// * `path`: impl AsRef<Path> - path to the rust source file
365///
366/// # Returns
367/// * `io::Result<syn::File>` - ast
368///
369/// # Errors
370/// * `io::Error` - file read error
371/// * `syn::Error` - parse error (wrapped in io::Error)
372pub fn parse_rust_file<P: AsRef<Path>>(path: P) -> io::Result<syn::File> {
373    let source = fs::read_to_string(path)?;
374    let syntax =
375        syn::parse_file(&source).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
376
377    Ok(syntax)
378}
379
380/// print ast
381///
382/// # Arguments
383/// * `file`: &File - ast
384///
385/// # Returns
386/// * `()`
387pub fn print_ast(file: &File) {
388    println!("AST for Rust code:");
389    let mut visitor = TextVisitor::new();
390    visitor.visit_file(file);
391}
392
/// Unit tests for the parsing entry points.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Returns the identifier of the first path segment of `ty`, or `None`
    /// when `ty` is not a path type.
    fn first_segment_name(ty: &syn::Type) -> Option<String> {
        match ty {
            syn::Type::Path(type_path) => Some(type_path.path.segments[0].ident.to_string()),
            _ => None,
        }
    }

    #[test]
    fn test_parse_rust_file() {
        let test_code = r#"
            fn test_function() {
                println!("Hello, world!");
            }
        "#;

        let mut temp = NamedTempFile::new().unwrap();
        temp.write_all(test_code.as_bytes()).unwrap();
        temp.flush().unwrap();

        let ast = parse_rust_file(temp.path()).unwrap();

        assert_eq!(ast.items.len(), 1);
        match &ast.items[0] {
            syn::Item::Fn(func) => assert_eq!(func.sig.ident.to_string(), "test_function"),
            _ => panic!("Parsed item is not a function"),
        }
    }

    #[test]
    fn test_parse_function() {
        let source = r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        let func = match &file.items[0] {
            syn::Item::Fn(func) => func,
            _ => panic!("Item is not a function"),
        };
        assert_eq!(func.sig.ident.to_string(), "add");
        // `a: i32` and `b: i32`
        assert_eq!(func.sig.inputs.len(), 2);

        let return_type = match &func.sig.output {
            syn::ReturnType::Type(_, ty) => ty,
            _ => panic!("Function has no return type"),
        };
        assert_eq!(
            first_segment_name(return_type).expect("Return type is not a path"),
            "i32"
        );

        // the body is the single tail expression `a + b`
        assert_eq!(func.block.stmts.len(), 1);
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
            struct Point {
                x: f64,
                y: f64,
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        let struct_item = match &file.items[0] {
            syn::Item::Struct(struct_item) => struct_item,
            _ => panic!("Item is not a struct"),
        };
        assert_eq!(struct_item.ident.to_string(), "Point");
        assert_eq!(struct_item.fields.iter().count(), 2);

        // both fields are named and typed `f64`, in declaration order
        let expected = [
            ("x", "Field x is not a path type"),
            ("y", "Field y is not a path type"),
        ];
        for (field, (name, not_path_msg)) in struct_item.fields.iter().zip(expected) {
            assert_eq!(field.ident.as_ref().unwrap().to_string(), name);
            assert_eq!(first_segment_name(&field.ty).expect(not_path_msg), "f64");
        }
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
            enum Direction {
                North,
                East,
                South,
                West,
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        let enum_item = match &file.items[0] {
            syn::Item::Enum(enum_item) => enum_item,
            _ => panic!("Item is not an enum"),
        };
        assert_eq!(enum_item.ident.to_string(), "Direction");
        assert_eq!(enum_item.variants.len(), 4);

        // variants keep their declaration order
        let names: Vec<String> = enum_item
            .variants
            .iter()
            .map(|variant| variant.ident.to_string())
            .collect();
        assert_eq!(names, ["North", "East", "South", "West"]);
    }

    #[test]
    fn test_parse_complex_expression() {
        let source = r#"
            fn complex_expr() {
                let result = (10 + 20) * 30 / (5 - 2);
                if result > 100 {
                    println!("Large result: {}", result);
                } else {
                    println!("Small result: {}", result);
                }
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        let func = match &file.items[0] {
            syn::Item::Fn(func) => func,
            _ => panic!("Item is not a function"),
        };
        assert_eq!(func.sig.ident.to_string(), "complex_expr");
        assert_eq!(func.block.stmts.len(), 2);

        // statement 1: `let result = ...;`
        let local = match &func.block.stmts[0] {
            syn::Stmt::Local(local) => local,
            _ => panic!("First statement is not a variable declaration"),
        };
        assert!(local.init.is_some());
        match &local.pat {
            syn::Pat::Ident(pat_ident) => assert_eq!(pat_ident.ident.to_string(), "result"),
            _ => panic!("Variable declaration pattern is not an identifier"),
        }

        // statement 2: `if result > 100 { .. } else { .. }`
        match &func.block.stmts[1] {
            syn::Stmt::Expr(syn::Expr::If(_), _) => {}
            syn::Stmt::Expr(_, _) => panic!("Second statement is not an if expression"),
            _ => panic!("Second statement is not an expression"),
        }
    }

    #[test]
    fn test_parse_invalid_code() {
        let source = r#"
            fn invalid_function( {
                let x = 10;
            }
        "#;

        assert!(
            parse_rust_source(source).is_err(),
            "Expected parse error for invalid code"
        );
    }

    #[test]
    fn test_parse_multiple_items() {
        let source = r#"
            fn function1() -> i32 { 42 }
            
            struct MyStruct {
                field: i32,
            }
            
            fn function2(s: MyStruct) -> i32 {
                s.field
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 3);

        match &file.items[0] {
            syn::Item::Fn(func) => assert_eq!(func.sig.ident.to_string(), "function1"),
            _ => panic!("First item is not a function"),
        }
        match &file.items[1] {
            syn::Item::Struct(struct_item) => {
                assert_eq!(struct_item.ident.to_string(), "MyStruct")
            }
            _ => panic!("Second item is not a struct"),
        }
        match &file.items[2] {
            syn::Item::Fn(func) => assert_eq!(func.sig.ident.to_string(), "function2"),
            _ => panic!("Third item is not a function"),
        }
    }
}