1use quote::ToTokens;
2use std::fs;
3use std::io;
4use std::path::Path;
5
6use syn::{File, visit::Visit};
7
/// Visitor that prints an indented, human-readable outline of a syn syntax tree.
///
/// The printing itself lives in the `syn::visit::Visit` implementation; this type only tracks
/// how deep the current line of output should be indented.
#[derive(Debug, Clone)]
pub struct TextVisitor {
    /// Number of spaces placed in front of each printed line.
    indent: usize,
}

impl Default for TextVisitor {
    /// Same as [`TextVisitor::new`]: a visitor starting at indentation 0.
    fn default() -> Self {
        Self::new()
    }
}

impl TextVisitor {
    /// Creates a visitor whose output starts at column 0.
    #[must_use]
    pub const fn new() -> Self {
        TextVisitor { indent: 0 }
    }

    /// Returns the prefix for the current depth: `indent` space characters.
    fn print_indent(&self) -> String {
        " ".repeat(self.indent)
    }
}
48
49impl<'ast> syn::visit::Visit<'ast> for TextVisitor {
52 fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
61 println!("{}Function: {}", self.print_indent(), node.sig.ident);
62 self.indent += 2;
63
64 if !node.sig.inputs.is_empty() {
65 println!("{}Parameters:", self.print_indent());
66 self.indent += 2;
67 for param in &node.sig.inputs {
68 match param {
69 syn::FnArg::Typed(pat_type) => {
70 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
71 println!(
72 "{}Parameter: {} - Type: {}",
73 self.print_indent(),
74 pat_ident.ident,
75 pat_type.ty.to_token_stream()
76 );
77 }
78 }
79 syn::FnArg::Receiver(receiver) => {
80 println!(
81 "{}Self receiver: {}",
82 self.print_indent(),
83 receiver.to_token_stream()
84 );
85 }
86 }
87 }
88 self.indent -= 2;
89 }
90
91 if let syn::ReturnType::Type(_, return_type) = &node.sig.output {
92 println!(
93 "{}Return type: {}",
94 self.print_indent(),
95 return_type.to_token_stream()
96 );
97 }
98
99 println!("{}Body:", self.print_indent());
100 self.indent += 2;
101 for stmt in &node.block.stmts {
102 self.visit_stmt(stmt);
103 }
104 self.indent -= 2;
105 }
106
107 fn visit_expr(&mut self, node: &'ast syn::Expr) {
116 match node {
117 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
118 syn::Lit::Int(lit_int) => {
119 println!(
120 "{}Integer literal: {}",
121 self.print_indent(),
122 lit_int.base10_digits()
123 );
124 }
125 syn::Lit::Float(lit_float) => {
126 println!(
127 "{}Float literal: {}",
128 self.print_indent(),
129 lit_float.base10_digits()
130 );
131 }
132 syn::Lit::Str(lit_str) => {
133 println!(
134 "{}String literal: \"{}\"",
135 self.print_indent(),
136 lit_str.value()
137 );
138 }
139 syn::Lit::Bool(lit_bool) => {
140 println!("{}Boolean literal: {}", self.print_indent(), lit_bool.value);
141 }
142 _ => {
143 println!(
144 "{}Other literal: {}",
145 self.print_indent(),
146 expr_lit.to_token_stream()
147 );
148 }
149 },
150 syn::Expr::Binary(expr_bin) => {
151 let op = match expr_bin.op {
152 syn::BinOp::Add(_) => "+",
153 syn::BinOp::Sub(_) => "-",
154 syn::BinOp::Mul(_) => "*",
155 syn::BinOp::Div(_) => "/",
156 syn::BinOp::Eq(_) => "==",
157 syn::BinOp::Lt(_) => "<",
158 syn::BinOp::Le(_) => "<=",
159 syn::BinOp::Ne(_) => "!=",
160 syn::BinOp::Ge(_) => ">=",
161 syn::BinOp::Gt(_) => ">",
162 _ => "other_operator",
163 };
164 println!("{}Binary expression: {}", self.print_indent(), op);
165
166 println!("{}Left:", self.print_indent());
167 self.indent += 2;
168 self.visit_expr(&expr_bin.left);
169 self.indent -= 2;
170
171 println!("{}Right:", self.print_indent());
172 self.indent += 2;
173 self.visit_expr(&expr_bin.right);
174 self.indent -= 2;
175 }
176 syn::Expr::Call(expr_call) => {
177 println!("{}Function call:", self.print_indent());
178
179 println!("{}Function:", self.print_indent());
180 self.indent += 2;
181 self.visit_expr(&expr_call.func);
182 self.indent -= 2;
183
184 if !expr_call.args.is_empty() {
185 println!("{}Arguments:", self.print_indent());
186 self.indent += 2;
187 for arg in &expr_call.args {
188 self.visit_expr(arg);
189 }
190 self.indent -= 2;
191 }
192 }
193 syn::Expr::Path(expr_path) => {
194 println!(
195 "{}Identifier: {}",
196 self.print_indent(),
197 expr_path.to_token_stream()
198 );
199 }
200 syn::Expr::If(expr_if) => {
201 println!("{}If statement:", self.print_indent());
202
203 println!("{}Condition:", self.print_indent());
204 self.indent += 2;
205 self.visit_expr(&expr_if.cond);
206 self.indent -= 2;
207
208 println!("{}Then branch:", self.print_indent());
209 self.indent += 2;
210 for stmt in &expr_if.then_branch.stmts {
211 self.visit_stmt(stmt);
212 }
213 self.indent -= 2;
214
215 if let Some((_, else_branch)) = &expr_if.else_branch {
216 println!("{}Else branch:", self.print_indent());
217 self.indent += 2;
218 self.visit_expr(else_branch);
219 self.indent -= 2;
220 }
221 }
222 syn::Expr::Loop(expr_loop) => {
223 println!("{}Loop:", self.print_indent());
224 self.indent += 2;
225 for stmt in &expr_loop.body.stmts {
226 self.visit_stmt(stmt);
227 }
228 self.indent -= 2;
229 }
230 syn::Expr::While(expr_while) => {
231 println!("{}While loop:", self.print_indent());
232
233 println!("{}Condition:", self.print_indent());
234 self.indent += 2;
235 self.visit_expr(&expr_while.cond);
236 self.indent -= 2;
237
238 println!("{}Body:", self.print_indent());
239 self.indent += 2;
240 for stmt in &expr_while.body.stmts {
241 self.visit_stmt(stmt);
242 }
243 self.indent -= 2;
244 }
245 syn::Expr::Return(expr_return) => {
246 println!("{}Return statement:", self.print_indent());
247 if let Some(expr) = &expr_return.expr {
248 self.indent += 2;
249 self.visit_expr(expr);
250 self.indent -= 2;
251 }
252 }
253 _ => {
254 println!(
255 "{}Other expression: {}",
256 self.print_indent(),
257 node.to_token_stream()
258 );
259 }
260 }
261 }
262
263 fn visit_stmt(&mut self, node: &'ast syn::Stmt) {
272 match node {
273 syn::Stmt::Local(local) => {
274 println!("{}Variable declaration:", self.print_indent());
275 if let syn::Pat::Ident(pat_ident) = &local.pat {
276 println!("{}Name: {}", self.print_indent(), pat_ident.ident);
277 }
278
279 if let Some(init) = &local.init {
280 println!("{}Initializer:", self.print_indent());
281 self.indent += 2;
282 self.visit_expr(&init.expr);
283 self.indent -= 2;
284 }
285 }
286
287 syn::Stmt::Expr(expr, _) => {
288 println!("{}Expression statement:", self.print_indent());
289 self.indent += 2;
290 self.visit_expr(expr);
291 self.indent -= 2;
292 }
293
294 syn::Stmt::Item(item) => match item {
295 syn::Item::Fn(item_fn) => {
296 self.visit_item_fn(item_fn);
297 }
298 syn::Item::Struct(item_struct) => {
299 println!("{}Struct: {}", self.print_indent(), item_struct.ident);
300 if !item_struct.fields.is_empty() {
301 println!("{}Fields:", self.print_indent());
302 self.indent += 2;
303 for field in &item_struct.fields {
304 if let Some(ident) = &field.ident {
305 println!(
306 "{}Field: {} - Type: {}",
307 self.print_indent(),
308 ident,
309 field.ty.to_token_stream()
310 );
311 } else {
312 println!(
313 "{}Tuple field: {}",
314 self.print_indent(),
315 field.ty.to_token_stream()
316 );
317 }
318 }
319 self.indent -= 2;
320 }
321 }
322 syn::Item::Enum(item_enum) => {
323 println!("{}Enum: {}", self.print_indent(), item_enum.ident);
324 if !item_enum.variants.is_empty() {
325 println!("{}Variants:", self.print_indent());
326 self.indent += 2;
327 for variant in &item_enum.variants {
328 println!("{}Variant: {}", self.print_indent(), variant.ident);
329 }
330 self.indent -= 2;
331 }
332 }
333 _ => {
334 println!(
335 "{}Other item: {}",
336 self.print_indent(),
337 item.to_token_stream()
338 );
339 }
340 },
341 _ => {
343 println!(
344 "{}Other statement: {}",
345 self.print_indent(),
346 node.to_token_stream()
347 );
348 }
349 }
350 }
351}
352
353pub fn parse_rust_source(source: &str) -> Result<syn::File, syn::Error> {
364 syn::parse_file(source)
365}
366
367pub fn parse_rust_file<P: AsRef<Path>>(path: P) -> io::Result<syn::File> {
379 let source = fs::read_to_string(path)?;
380 let syntax =
381 syn::parse_file(&source).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
382
383 Ok(syntax)
384}
385
386pub fn print_ast(file: &File) {
394 println!("AST for Rust code:");
395 let mut visitor = TextVisitor::new();
396 visitor.visit_file(file);
397}
398
/// Unit tests for the parsing entry points and the text visitor.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_parse_rust_file() {
        let mut file = NamedTempFile::new().unwrap();
        let test_code = r#"
            fn test_function() {
                println!("Hello, world!");
            }
        "#;

        file.write_all(test_code.as_bytes()).unwrap();
        file.flush().unwrap();

        let ast = parse_rust_file(file.path()).unwrap();

        assert_eq!(ast.items.len(), 1);
        if let syn::Item::Fn(func) = &ast.items[0] {
            assert_eq!(func.sig.ident.to_string(), "test_function");
        } else {
            panic!("Parsed item is not a function");
        }
    }

    #[test]
    fn test_parse_rust_file_error_kinds() {
        // Syntax errors are reported as InvalidData.
        let mut invalid = NamedTempFile::new().unwrap();
        invalid.write_all(b"fn broken( {").unwrap();
        invalid.flush().unwrap();
        let err = parse_rust_file(invalid.path()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);

        // Read failures are passed through unchanged; dropping the temp file deletes it.
        let missing = {
            let file = NamedTempFile::new().unwrap();
            file.path().to_path_buf()
        };
        let err = parse_rust_file(&missing).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }

    #[test]
    fn test_parse_function() {
        let source = r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 1);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "add");
            assert_eq!(func.sig.inputs.len(), 2);

            if let syn::ReturnType::Type(_, return_type) = &func.sig.output {
                if let syn::Type::Path(type_path) = &**return_type {
                    let path_segment = &type_path.path.segments[0];
                    assert_eq!(path_segment.ident.to_string(), "i32");
                } else {
                    panic!("Return type is not a path");
                }
            } else {
                panic!("Function has no return type");
            }

            assert_eq!(func.block.stmts.len(), 1);
        } else {
            panic!("Item is not a function");
        }
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
            struct Point {
                x: f64,
                y: f64,
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 1);

        if let syn::Item::Struct(struct_item) = &file.items[0] {
            assert_eq!(struct_item.ident.to_string(), "Point");
            assert_eq!(struct_item.fields.len(), 2);

            let fields: Vec<_> = struct_item.fields.iter().collect();

            assert_eq!(fields[0].ident.as_ref().unwrap().to_string(), "x");
            if let syn::Type::Path(type_path) = &fields[0].ty {
                let path_segment = &type_path.path.segments[0];
                assert_eq!(path_segment.ident.to_string(), "f64");
            } else {
                panic!("Field x is not a path type");
            }

            assert_eq!(fields[1].ident.as_ref().unwrap().to_string(), "y");
            if let syn::Type::Path(type_path) = &fields[1].ty {
                let path_segment = &type_path.path.segments[0];
                assert_eq!(path_segment.ident.to_string(), "f64");
            } else {
                panic!("Field y is not a path type");
            }
        } else {
            panic!("Item is not a struct");
        }
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
            enum Direction {
                North,
                East,
                South,
                West,
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 1);

        if let syn::Item::Enum(enum_item) = &file.items[0] {
            assert_eq!(enum_item.ident.to_string(), "Direction");
            assert_eq!(enum_item.variants.len(), 4);

            let variant_names: Vec<String> = enum_item
                .variants
                .iter()
                .map(|v| v.ident.to_string())
                .collect();

            assert_eq!(variant_names, vec!["North", "East", "South", "West"]);
        } else {
            panic!("Item is not an enum");
        }
    }

    #[test]
    fn test_parse_complex_expression() {
        let source = r#"
            fn complex_expr() {
                let result = (10 + 20) * 30 / (5 - 2);
                if result > 100 {
                    println!("Large result: {}", result);
                } else {
                    println!("Small result: {}", result);
                }
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 1);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "complex_expr");
            assert_eq!(func.block.stmts.len(), 2);

            if let syn::Stmt::Local(local) = &func.block.stmts[0] {
                assert!(local.init.is_some());

                if let syn::Pat::Ident(pat_ident) = &local.pat {
                    assert_eq!(pat_ident.ident.to_string(), "result");
                } else {
                    panic!("Variable declaration pattern is not an identifier");
                }
            } else {
                panic!("First statement is not a variable declaration");
            }

            if let syn::Stmt::Expr(expr, _) = &func.block.stmts[1] {
                assert!(
                    matches!(expr, syn::Expr::If(_)),
                    "Second statement is not an if expression"
                );
            } else {
                panic!("Second statement is not an expression");
            }
        } else {
            panic!("Item is not a function");
        }
    }

    #[test]
    fn test_parse_invalid_code() {
        let source = r#"
            fn invalid_function( {
                let x = 10;
            }
        "#;

        let result = parse_rust_source(source);
        assert!(result.is_err(), "Expected parse error for invalid code");
    }

    #[test]
    fn test_parse_multiple_items() {
        let source = r#"
            fn function1() -> i32 { 42 }

            struct MyStruct {
                field: i32,
            }

            fn function2(s: MyStruct) -> i32 {
                s.field
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 3);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "function1");
        } else {
            panic!("First item is not a function");
        }

        if let syn::Item::Struct(struct_item) = &file.items[1] {
            assert_eq!(struct_item.ident.to_string(), "MyStruct");
        } else {
            panic!("Second item is not a struct");
        }

        if let syn::Item::Fn(func) = &file.items[2] {
            assert_eq!(func.sig.ident.to_string(), "function2");
        } else {
            panic!("Third item is not a function");
        }
    }

    #[test]
    fn test_print_ast_smoke() {
        // `print_ast` only writes to stdout. This walks a varied input (items, statements and
        // expressions of several kinds) to make sure the visitor never panics on it.
        let source = r#"
            struct Point { x: f64, y: f64 }
            struct Pair(i32, i32);
            enum Direction { North, South }

            fn classify(n: i32, label: &str) -> bool {
                let doubled = n * 2;
                let mut count: u32 = 0;
                while count < 3 {
                    count += 1;
                }
                loop {
                    break;
                }
                if doubled >= 10 {
                    return true;
                } else {
                    helper(label, 1.5, "text");
                }
                false
            }
        "#;

        let file = parse_rust_source(source).unwrap();
        print_ast(&file);
    }
}