rusty_ast/
text_visitor.rs

1use quote::ToTokens;
2use std::fs;
3use std::io;
4use std::path::Path;
5
6use syn::{File, visit::Visit};
7
/// A [`syn::visit::Visit`] implementation that prints a Rust syntax tree to stdout as an
/// indented, human-readable outline.
///
/// # Fields
/// * `indent`: usize - number of spaces prepended to each printed line; the children of a
///   node are printed two spaces deeper than the node itself.
#[derive(Debug, Clone)]
pub struct TextVisitor {
    indent: usize,
}
15
16/// # Methods
17/// * `new()`: creates a new TextVisitor
18/// * `print_indent()`: prints the current indentation level
19impl Default for TextVisitor {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl TextVisitor {
26    /// new
27    ///
28    /// # Arguments
29    /// * `()`
30    ///
31    /// # Returns
32    /// * `TextVisitor` - a new TextVisitor
33    pub fn new() -> Self {
34        TextVisitor { indent: 0 }
35    }
36
37    /// print_indent
38    ///
39    /// # Arguments
40    /// * `self`: &Self - the TextVisitor
41    ///
42    /// # Returns
43    /// * `String` - the current indentation level
44    fn print_indent(&self) -> String {
45        " ".repeat(self.indent)
46    }
47}
48
49/// implement Visit trait for AstText
50/// Visit trait is defined in syn::visit
51impl<'ast> syn::visit::Visit<'ast> for TextVisitor {
52    /// visit_item_fn is defined in syn::visit::Visit
53    /// visit_item_fn is called when a Rust function definition is visited
54    ///
55    /// # Arguments
56    /// * `node`: &'ast syn::ItemFn
57    ///
58    /// # Returns
59    /// * `()`
60    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
61        println!("{}Function: {}", self.print_indent(), node.sig.ident);
62        self.indent += 2;
63
64        if !node.sig.inputs.is_empty() {
65            println!("{}Parameters:", self.print_indent());
66            self.indent += 2;
67            for param in &node.sig.inputs {
68                match param {
69                    syn::FnArg::Typed(pat_type) => {
70                        if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
71                            println!(
72                                "{}Parameter: {} - Type: {}",
73                                self.print_indent(),
74                                pat_ident.ident,
75                                pat_type.ty.to_token_stream()
76                            );
77                        }
78                    }
79                    syn::FnArg::Receiver(receiver) => {
80                        println!(
81                            "{}Self receiver: {}",
82                            self.print_indent(),
83                            receiver.to_token_stream()
84                        );
85                    }
86                }
87            }
88            self.indent -= 2;
89        }
90
91        if let syn::ReturnType::Type(_, return_type) = &node.sig.output {
92            println!(
93                "{}Return type: {}",
94                self.print_indent(),
95                return_type.to_token_stream()
96            );
97        }
98
99        println!("{}Body:", self.print_indent());
100        self.indent += 2;
101        for stmt in &node.block.stmts {
102            self.visit_stmt(stmt);
103        }
104        self.indent -= 2;
105    }
106
107    /// visit_expr is defined in syn::visit::Visit
108    /// visit_expr is called when a Rust expression is visited
109    ///
110    /// # Arguments
111    /// * `node`: &'ast syn::Expr
112    ///
113    /// # Returns
114    /// * `()`
115    fn visit_expr(&mut self, node: &'ast syn::Expr) {
116        match node {
117            syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
118                syn::Lit::Int(lit_int) => {
119                    println!(
120                        "{}Integer literal: {}",
121                        self.print_indent(),
122                        lit_int.base10_digits()
123                    );
124                }
125                syn::Lit::Float(lit_float) => {
126                    println!(
127                        "{}Float literal: {}",
128                        self.print_indent(),
129                        lit_float.base10_digits()
130                    );
131                }
132                syn::Lit::Str(lit_str) => {
133                    println!(
134                        "{}String literal: \"{}\"",
135                        self.print_indent(),
136                        lit_str.value()
137                    );
138                }
139                syn::Lit::Bool(lit_bool) => {
140                    println!("{}Boolean literal: {}", self.print_indent(), lit_bool.value);
141                }
142                _ => {
143                    println!(
144                        "{}Other literal: {}",
145                        self.print_indent(),
146                        expr_lit.to_token_stream()
147                    );
148                }
149            },
150            syn::Expr::Binary(expr_bin) => {
151                let op = match expr_bin.op {
152                    syn::BinOp::Add(_) => "+",
153                    syn::BinOp::Sub(_) => "-",
154                    syn::BinOp::Mul(_) => "*",
155                    syn::BinOp::Div(_) => "/",
156                    syn::BinOp::Eq(_) => "==",
157                    syn::BinOp::Lt(_) => "<",
158                    syn::BinOp::Le(_) => "<=",
159                    syn::BinOp::Ne(_) => "!=",
160                    syn::BinOp::Ge(_) => ">=",
161                    syn::BinOp::Gt(_) => ">",
162                    _ => "other_operator",
163                };
164                println!("{}Binary expression: {}", self.print_indent(), op);
165
166                println!("{}Left:", self.print_indent());
167                self.indent += 2;
168                self.visit_expr(&expr_bin.left);
169                self.indent -= 2;
170
171                println!("{}Right:", self.print_indent());
172                self.indent += 2;
173                self.visit_expr(&expr_bin.right);
174                self.indent -= 2;
175            }
176            syn::Expr::Call(expr_call) => {
177                println!("{}Function call:", self.print_indent());
178
179                println!("{}Function:", self.print_indent());
180                self.indent += 2;
181                self.visit_expr(&expr_call.func);
182                self.indent -= 2;
183
184                if !expr_call.args.is_empty() {
185                    println!("{}Arguments:", self.print_indent());
186                    self.indent += 2;
187                    for arg in &expr_call.args {
188                        self.visit_expr(arg);
189                    }
190                    self.indent -= 2;
191                }
192            }
193            syn::Expr::Path(expr_path) => {
194                println!(
195                    "{}Identifier: {}",
196                    self.print_indent(),
197                    expr_path.to_token_stream()
198                );
199            }
200            syn::Expr::If(expr_if) => {
201                println!("{}If statement:", self.print_indent());
202
203                println!("{}Condition:", self.print_indent());
204                self.indent += 2;
205                self.visit_expr(&expr_if.cond);
206                self.indent -= 2;
207
208                println!("{}Then branch:", self.print_indent());
209                self.indent += 2;
210                for stmt in &expr_if.then_branch.stmts {
211                    self.visit_stmt(stmt);
212                }
213                self.indent -= 2;
214
215                if let Some((_, else_branch)) = &expr_if.else_branch {
216                    println!("{}Else branch:", self.print_indent());
217                    self.indent += 2;
218                    self.visit_expr(else_branch);
219                    self.indent -= 2;
220                }
221            }
222            syn::Expr::Loop(expr_loop) => {
223                println!("{}Loop:", self.print_indent());
224                self.indent += 2;
225                for stmt in &expr_loop.body.stmts {
226                    self.visit_stmt(stmt);
227                }
228                self.indent -= 2;
229            }
230            syn::Expr::While(expr_while) => {
231                println!("{}While loop:", self.print_indent());
232
233                println!("{}Condition:", self.print_indent());
234                self.indent += 2;
235                self.visit_expr(&expr_while.cond);
236                self.indent -= 2;
237
238                println!("{}Body:", self.print_indent());
239                self.indent += 2;
240                for stmt in &expr_while.body.stmts {
241                    self.visit_stmt(stmt);
242                }
243                self.indent -= 2;
244            }
245            syn::Expr::Return(expr_return) => {
246                println!("{}Return statement:", self.print_indent());
247                if let Some(expr) = &expr_return.expr {
248                    self.indent += 2;
249                    self.visit_expr(expr);
250                    self.indent -= 2;
251                }
252            }
253            _ => {
254                println!(
255                    "{}Other expression: {}",
256                    self.print_indent(),
257                    node.to_token_stream()
258                );
259            }
260        }
261    }
262
263    /// visit_stmt is defined in syn::visit::Visit
264    /// visit_stmt is called when a Rust statement is visited
265    ///
266    /// # Arguments
267    /// * `node`: &'ast syn::Stmt
268    ///
269    /// # Returns
270    /// * `()`
271    fn visit_stmt(&mut self, node: &'ast syn::Stmt) {
272        match node {
273            syn::Stmt::Local(local) => {
274                println!("{}Variable declaration:", self.print_indent());
275                if let syn::Pat::Ident(pat_ident) = &local.pat {
276                    println!("{}Name: {}", self.print_indent(), pat_ident.ident);
277                }
278
279                if let Some(init) = &local.init {
280                    println!("{}Initializer:", self.print_indent());
281                    self.indent += 2;
282                    self.visit_expr(&init.expr);
283                    self.indent -= 2;
284                }
285            }
286
287            syn::Stmt::Expr(expr, _) => {
288                println!("{}Expression statement:", self.print_indent());
289                self.indent += 2;
290                self.visit_expr(expr);
291                self.indent -= 2;
292            }
293
294            syn::Stmt::Item(item) => match item {
295                syn::Item::Fn(item_fn) => {
296                    self.visit_item_fn(item_fn);
297                }
298                syn::Item::Struct(item_struct) => {
299                    println!("{}Struct: {}", self.print_indent(), item_struct.ident);
300                    if !item_struct.fields.is_empty() {
301                        println!("{}Fields:", self.print_indent());
302                        self.indent += 2;
303                        for field in &item_struct.fields {
304                            if let Some(ident) = &field.ident {
305                                println!(
306                                    "{}Field: {} - Type: {}",
307                                    self.print_indent(),
308                                    ident,
309                                    field.ty.to_token_stream()
310                                );
311                            } else {
312                                println!(
313                                    "{}Tuple field: {}",
314                                    self.print_indent(),
315                                    field.ty.to_token_stream()
316                                );
317                            }
318                        }
319                        self.indent -= 2;
320                    }
321                }
322                syn::Item::Enum(item_enum) => {
323                    println!("{}Enum: {}", self.print_indent(), item_enum.ident);
324                    if !item_enum.variants.is_empty() {
325                        println!("{}Variants:", self.print_indent());
326                        self.indent += 2;
327                        for variant in &item_enum.variants {
328                            println!("{}Variant: {}", self.print_indent(), variant.ident);
329                        }
330                        self.indent -= 2;
331                    }
332                }
333                _ => {
334                    println!(
335                        "{}Other item: {}",
336                        self.print_indent(),
337                        item.to_token_stream()
338                    );
339                }
340            },
341            // TODO: add other statement
342            _ => {
343                println!(
344                    "{}Other statement: {}",
345                    self.print_indent(),
346                    node.to_token_stream()
347                );
348            }
349        }
350    }
351}
352
353/// parse rust source code to ast
354///
355/// # Arguments
356/// * `source`: &str - rust source code
357///
358/// # Returns
359/// * `Result<syn::File, syn::Error>` - ast
360///
361/// # Errors
362/// * `syn::Error` - parse error
363pub fn parse_rust_source(source: &str) -> Result<syn::File, syn::Error> {
364    syn::parse_file(source)
365}
366
367/// Parse Rust source code from a file into an AST
368///
369/// # Arguments
370/// * `path`: impl AsRef<Path> - path to the rust source file
371///
372/// # Returns
373/// * `io::Result<syn::File>` - ast
374///
375/// # Errors
376/// * `io::Error` - file read error
377/// * `syn::Error` - parse error (wrapped in io::Error)
378pub fn parse_rust_file<P: AsRef<Path>>(path: P) -> io::Result<syn::File> {
379    let source = fs::read_to_string(path)?;
380    let syntax =
381        syn::parse_file(&source).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
382
383    Ok(syntax)
384}
385
386/// print ast
387///
388/// # Arguments
389/// * `file`: &File - ast
390///
391/// # Returns
392/// * `()`
393pub fn print_ast(file: &File) {
394    println!("AST for Rust code:");
395    let mut visitor = TextVisitor::new();
396    visitor.visit_file(file);
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `code` to a fresh temporary file and returns its handle; the file is deleted
    /// when the handle is dropped.
    fn temp_rust_file(code: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(code.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// Returns the first path segment of `ty` (e.g. `"i32"`), panicking with a message that
    /// names `what` when `ty` is not a path type.
    fn path_type_ident(ty: &syn::Type, what: &str) -> String {
        match ty {
            syn::Type::Path(type_path) => type_path.path.segments[0].ident.to_string(),
            _ => panic!("{} is not a path type", what),
        }
    }

    #[test]
    fn test_parse_rust_file() {
        let file = temp_rust_file(
            r#"
            fn test_function() {
                println!("Hello, world!");
            }
        "#,
        );

        let ast = parse_rust_file(file.path()).unwrap();

        assert_eq!(ast.items.len(), 1);
        match &ast.items[0] {
            syn::Item::Fn(func) => assert_eq!(func.sig.ident.to_string(), "test_function"),
            _ => panic!("Parsed item is not a function"),
        }
    }

    #[test]
    fn test_parse_rust_file_invalid_code_is_invalid_data() {
        let file = temp_rust_file("fn broken( {");

        let err = parse_rust_file(file.path()).unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        // The original syn::Error must remain reachable as the inner error.
        assert!(err
            .get_ref()
            .and_then(|inner| inner.downcast_ref::<syn::Error>())
            .is_some());
    }

    #[test]
    fn test_parse_rust_file_missing_file() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path().to_path_buf();
        drop(file); // deletes the temporary file

        let err = parse_rust_file(&path).unwrap_err();

        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }

    #[test]
    fn test_parse_function() {
        let source = r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        let func = match &file.items[0] {
            syn::Item::Fn(func) => func,
            _ => panic!("Item is not a function"),
        };
        assert_eq!(func.sig.ident.to_string(), "add");
        assert_eq!(func.sig.inputs.len(), 2);

        match &func.sig.output {
            syn::ReturnType::Type(_, return_type) => {
                assert_eq!(path_type_ident(return_type, "Return type"), "i32");
            }
            syn::ReturnType::Default => panic!("Function has no return type"),
        }

        // The body is the single tail expression `a + b`.
        assert_eq!(func.block.stmts.len(), 1);
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
            struct Point {
                x: f64,
                y: f64,
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        let struct_item = match &file.items[0] {
            syn::Item::Struct(struct_item) => struct_item,
            _ => panic!("Item is not a struct"),
        };
        assert_eq!(struct_item.ident.to_string(), "Point");

        let fields: Vec<_> = struct_item.fields.iter().collect();
        assert_eq!(fields.len(), 2);

        for (field, expected_name) in fields.iter().zip(["x", "y"]) {
            assert_eq!(field.ident.as_ref().unwrap().to_string(), expected_name);
            assert_eq!(
                path_type_ident(&field.ty, &format!("Field {}", expected_name)),
                "f64"
            );
        }
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
            enum Direction {
                North,
                East,
                South,
                West,
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        let enum_item = match &file.items[0] {
            syn::Item::Enum(enum_item) => enum_item,
            _ => panic!("Item is not an enum"),
        };
        assert_eq!(enum_item.ident.to_string(), "Direction");

        let variant_names: Vec<String> = enum_item
            .variants
            .iter()
            .map(|variant| variant.ident.to_string())
            .collect();
        assert_eq!(variant_names, vec!["North", "East", "South", "West"]);
    }

    #[test]
    fn test_parse_complex_expression() {
        let source = r#"
            fn complex_expr() {
                let result = (10 + 20) * 30 / (5 - 2);
                if result > 100 {
                    println!("Large result: {}", result);
                } else {
                    println!("Small result: {}", result);
                }
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        let func = match &file.items[0] {
            syn::Item::Fn(func) => func,
            _ => panic!("Item is not a function"),
        };
        assert_eq!(func.sig.ident.to_string(), "complex_expr");
        assert_eq!(func.block.stmts.len(), 2);

        // First statement: `let result = ...;`
        match &func.block.stmts[0] {
            syn::Stmt::Local(local) => {
                assert!(local.init.is_some());
                match &local.pat {
                    syn::Pat::Ident(pat_ident) => {
                        assert_eq!(pat_ident.ident.to_string(), "result");
                    }
                    _ => panic!("Variable declaration pattern is not an identifier"),
                }
            }
            _ => panic!("First statement is not a variable declaration"),
        }

        // Second statement: the `if`/`else` expression.
        assert!(
            matches!(
                &func.block.stmts[1],
                syn::Stmt::Expr(syn::Expr::If(_), _)
            ),
            "Second statement is not an if expression"
        );
    }

    #[test]
    fn test_parse_invalid_code() {
        let source = r#"
            fn invalid_function( {
                let x = 10;
            }
        "#;

        let result = parse_rust_source(source);
        assert!(result.is_err(), "Expected parse error for invalid code");
    }

    #[test]
    fn test_parse_multiple_items() {
        let source = r#"
            fn function1() -> i32 { 42 }

            struct MyStruct {
                field: i32,
            }

            fn function2(s: MyStruct) -> i32 {
                s.field
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        let kinds_and_names: Vec<(&str, String)> = file
            .items
            .iter()
            .map(|item| match item {
                syn::Item::Fn(func) => ("fn", func.sig.ident.to_string()),
                syn::Item::Struct(struct_item) => ("struct", struct_item.ident.to_string()),
                _ => ("other", String::new()),
            })
            .collect();

        assert_eq!(
            kinds_and_names,
            vec![
                ("fn", "function1".to_string()),
                ("struct", "MyStruct".to_string()),
                ("fn", "function2".to_string()),
            ]
        );
    }

    #[test]
    fn test_print_ast_handles_many_constructs() {
        // Smoke test: printing must not panic for any of these constructs.
        let source = r#"
            struct Point { x: f64, y: f64 }
            struct Pair(i32, i32);
            enum Direction { North, South }

            impl Point {
                fn norm(&self) -> f64 { (self.x * self.x + self.y * self.y).sqrt() }
            }

            fn demo(n: i32, _: bool) -> i32 {
                let mut total: i32 = 0;
                for i in 0..n { total += i % 3; }
                while total > 100 { total -= 1; }
                if total == 0 { return 1; } else if total < 0 { return -1; } else { total += 1; }
                loop { break; }
                total
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        print_ast(&file);
    }
}