1use quote::ToTokens;
2use std::fs;
3use std::io;
4use std::path::Path;
5
6use syn::{File, visit::Visit};
7
/// Syntax-tree visitor that prints an indented, human-readable outline of the
/// nodes it visits to stdout.
///
/// `Default` yields the same state as `TextVisitor::new()` (indentation zero).
#[derive(Debug, Default)]
pub struct TextVisitor {
    /// Current indentation, in columns, prepended to every printed line.
    indent: usize,
}
15
16impl TextVisitor {
20 pub fn new() -> Self {
28 TextVisitor { indent: 0 }
29 }
30
31 fn print_indent(&self) -> String {
39 " ".repeat(self.indent)
40 }
41}
42
43impl<'ast> syn::visit::Visit<'ast> for TextVisitor {
46 fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
55 println!("{}Function: {}", self.print_indent(), node.sig.ident);
56 self.indent += 2;
57
58 if !node.sig.inputs.is_empty() {
59 println!("{}Parameters:", self.print_indent());
60 self.indent += 2;
61 for param in &node.sig.inputs {
62 match param {
63 syn::FnArg::Typed(pat_type) => {
64 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
65 println!(
66 "{}Parameter: {} - Type: {}",
67 self.print_indent(),
68 pat_ident.ident,
69 pat_type.ty.to_token_stream()
70 );
71 }
72 }
73 syn::FnArg::Receiver(receiver) => {
74 println!(
75 "{}Self receiver: {}",
76 self.print_indent(),
77 receiver.to_token_stream()
78 );
79 }
80 }
81 }
82 self.indent -= 2;
83 }
84
85 if let syn::ReturnType::Type(_, return_type) = &node.sig.output {
86 println!(
87 "{}Return type: {}",
88 self.print_indent(),
89 return_type.to_token_stream()
90 );
91 }
92
93 println!("{}Body:", self.print_indent());
94 self.indent += 2;
95 for stmt in &node.block.stmts {
96 self.visit_stmt(stmt);
97 }
98 self.indent -= 4;
99 }
100
101 fn visit_expr(&mut self, node: &'ast syn::Expr) {
110 match node {
111 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
112 syn::Lit::Int(lit_int) => {
113 println!(
114 "{}Integer literal: {}",
115 self.print_indent(),
116 lit_int.base10_digits()
117 );
118 }
119 syn::Lit::Float(lit_float) => {
120 println!(
121 "{}Float literal: {}",
122 self.print_indent(),
123 lit_float.base10_digits()
124 );
125 }
126 syn::Lit::Str(lit_str) => {
127 println!(
128 "{}String literal: \"{}\"",
129 self.print_indent(),
130 lit_str.value()
131 );
132 }
133 syn::Lit::Bool(lit_bool) => {
134 println!("{}Boolean literal: {}", self.print_indent(), lit_bool.value);
135 }
136 _ => {
137 println!(
138 "{}Other literal: {}",
139 self.print_indent(),
140 expr_lit.to_token_stream()
141 );
142 }
143 },
144 syn::Expr::Binary(expr_bin) => {
145 let op = match expr_bin.op {
146 syn::BinOp::Add(_) => "+",
147 syn::BinOp::Sub(_) => "-",
148 syn::BinOp::Mul(_) => "*",
149 syn::BinOp::Div(_) => "/",
150 syn::BinOp::Eq(_) => "==",
151 syn::BinOp::Lt(_) => "<",
152 syn::BinOp::Le(_) => "<=",
153 syn::BinOp::Ne(_) => "!=",
154 syn::BinOp::Ge(_) => ">=",
155 syn::BinOp::Gt(_) => ">",
156 _ => "other_operator",
157 };
158 println!("{}Binary expression: {}", self.print_indent(), op);
159
160 println!("{}Left:", self.print_indent());
161 self.indent += 2;
162 self.visit_expr(&expr_bin.left);
163 self.indent -= 2;
164
165 println!("{}Right:", self.print_indent());
166 self.indent += 2;
167 self.visit_expr(&expr_bin.right);
168 self.indent -= 2;
169 }
170 syn::Expr::Call(expr_call) => {
171 println!("{}Function call:", self.print_indent());
172
173 println!("{}Function:", self.print_indent());
174 self.indent += 2;
175 self.visit_expr(&expr_call.func);
176 self.indent -= 2;
177
178 if !expr_call.args.is_empty() {
179 println!("{}Arguments:", self.print_indent());
180 self.indent += 2;
181 for arg in &expr_call.args {
182 self.visit_expr(arg);
183 }
184 self.indent -= 2;
185 }
186 }
187 syn::Expr::Path(expr_path) => {
188 println!(
189 "{}Identifier: {}",
190 self.print_indent(),
191 expr_path.to_token_stream()
192 );
193 }
194 syn::Expr::If(expr_if) => {
195 println!("{}If statement:", self.print_indent());
196
197 println!("{}Condition:", self.print_indent());
198 self.indent += 2;
199 self.visit_expr(&expr_if.cond);
200 self.indent -= 2;
201
202 println!("{}Then branch:", self.print_indent());
203 self.indent += 2;
204 for stmt in &expr_if.then_branch.stmts {
205 self.visit_stmt(stmt);
206 }
207 self.indent -= 2;
208
209 if let Some((_, else_branch)) = &expr_if.else_branch {
210 println!("{}Else branch:", self.print_indent());
211 self.indent += 2;
212 self.visit_expr(&else_branch);
213 self.indent -= 2;
214 }
215 }
216 syn::Expr::Loop(expr_loop) => {
217 println!("{}Loop:", self.print_indent());
218 self.indent += 2;
219 for stmt in &expr_loop.body.stmts {
220 self.visit_stmt(stmt);
221 }
222 self.indent -= 2;
223 }
224 syn::Expr::While(expr_while) => {
225 println!("{}While loop:", self.print_indent());
226
227 println!("{}Condition:", self.print_indent());
228 self.indent += 2;
229 self.visit_expr(&expr_while.cond);
230 self.indent -= 2;
231
232 println!("{}Body:", self.print_indent());
233 self.indent += 2;
234 for stmt in &expr_while.body.stmts {
235 self.visit_stmt(stmt);
236 }
237 self.indent -= 2;
238 }
239 syn::Expr::Return(expr_return) => {
240 println!("{}Return statement:", self.print_indent());
241 if let Some(expr) = &expr_return.expr {
242 self.indent += 2;
243 self.visit_expr(expr);
244 self.indent -= 2;
245 }
246 }
247 _ => {
248 println!(
249 "{}Other expression: {}",
250 self.print_indent(),
251 node.to_token_stream()
252 );
253 }
254 }
255 }
256
257 fn visit_stmt(&mut self, node: &'ast syn::Stmt) {
266 match node {
267 syn::Stmt::Local(local) => {
268 println!("{}Variable declaration:", self.print_indent());
269 if let syn::Pat::Ident(pat_ident) = &local.pat {
270 println!("{}Name: {}", self.print_indent(), pat_ident.ident);
271 }
272
273 if let Some(init) = &local.init {
274 println!("{}Initializer:", self.print_indent());
275 self.indent += 2;
276 self.visit_expr(&init.expr);
277 self.indent -= 2;
278 }
279 }
280
281 syn::Stmt::Expr(expr, _) => {
282 println!("{}Expression statement:", self.print_indent());
283 self.indent += 2;
284 self.visit_expr(expr);
285 self.indent -= 2;
286 }
287
288 syn::Stmt::Item(item) => match item {
289 syn::Item::Fn(item_fn) => {
290 self.visit_item_fn(item_fn);
291 }
292 syn::Item::Struct(item_struct) => {
293 println!("{}Struct: {}", self.print_indent(), item_struct.ident);
294 if !item_struct.fields.is_empty() {
295 println!("{}Fields:", self.print_indent());
296 self.indent += 2;
297 for field in &item_struct.fields {
298 if let Some(ident) = &field.ident {
299 println!(
300 "{}Field: {} - Type: {}",
301 self.print_indent(),
302 ident,
303 field.ty.to_token_stream()
304 );
305 } else {
306 println!(
307 "{}Tuple field: {}",
308 self.print_indent(),
309 field.ty.to_token_stream()
310 );
311 }
312 }
313 self.indent -= 2;
314 }
315 }
316 syn::Item::Enum(item_enum) => {
317 println!("{}Enum: {}", self.print_indent(), item_enum.ident);
318 if !item_enum.variants.is_empty() {
319 println!("{}Variants:", self.print_indent());
320 self.indent += 2;
321 for variant in &item_enum.variants {
322 println!("{}Variant: {}", self.print_indent(), variant.ident);
323 }
324 self.indent -= 2;
325 }
326 }
327 _ => {
328 println!(
329 "{}Other item: {}",
330 self.print_indent(),
331 item.to_token_stream()
332 );
333 }
334 },
335 _ => {
337 println!(
338 "{}Other statement: {}",
339 self.print_indent(),
340 node.to_token_stream()
341 );
342 }
343 }
344 }
345}
346
347pub fn parse_rust_source(source: &str) -> Result<syn::File, syn::Error> {
358 syn::parse_file(source)
359}
360
361pub fn parse_rust_file<P: AsRef<Path>>(path: P) -> io::Result<syn::File> {
373 let source = fs::read_to_string(path)?;
374 let syntax =
375 syn::parse_file(&source).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
376
377 Ok(syntax)
378}
379
380pub fn print_ast(file: &File) {
388 println!("AST for Rust code:");
389 let mut visitor = TextVisitor::new();
390 visitor.visit_file(file);
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `contents` to a fresh temporary file and returns its handle.
    /// The file is removed when the handle is dropped.
    fn temp_source_file(contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(contents.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_parse_rust_file() {
        let file = temp_source_file(
            r#"
            fn test_function() {
                println!("Hello, world!");
            }
            "#,
        );

        let ast = parse_rust_file(file.path()).unwrap();

        assert_eq!(ast.items.len(), 1);
        if let syn::Item::Fn(func) = &ast.items[0] {
            assert_eq!(func.sig.ident.to_string(), "test_function");
        } else {
            panic!("Parsed item is not a function");
        }
    }

    #[test]
    fn test_parse_rust_file_invalid_source_is_invalid_data() {
        let file = temp_source_file("fn invalid_function( {");

        let err = parse_rust_file(file.path()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_parse_function() {
        let source = r#"
        fn add(a: i32, b: i32) -> i32 {
            a + b
        }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "add");
            assert_eq!(func.sig.inputs.len(), 2);

            if let syn::ReturnType::Type(_, return_type) = &func.sig.output {
                if let syn::Type::Path(type_path) = &**return_type {
                    let path_segment = &type_path.path.segments[0];
                    assert_eq!(path_segment.ident.to_string(), "i32");
                } else {
                    panic!("Return type is not a path");
                }
            } else {
                panic!("Function has no return type");
            }

            assert_eq!(func.block.stmts.len(), 1);
        } else {
            panic!("Item is not a function");
        }
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
        struct Point {
            x: f64,
            y: f64,
        }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        if let syn::Item::Struct(struct_item) = &file.items[0] {
            assert_eq!(struct_item.ident.to_string(), "Point");
            assert_eq!(struct_item.fields.len(), 2);

            let fields: Vec<_> = struct_item.fields.iter().collect();

            assert_eq!(fields[0].ident.as_ref().unwrap().to_string(), "x");
            if let syn::Type::Path(type_path) = &fields[0].ty {
                let path_segment = &type_path.path.segments[0];
                assert_eq!(path_segment.ident.to_string(), "f64");
            } else {
                panic!("Field x is not a path type");
            }

            assert_eq!(fields[1].ident.as_ref().unwrap().to_string(), "y");
            if let syn::Type::Path(type_path) = &fields[1].ty {
                let path_segment = &type_path.path.segments[0];
                assert_eq!(path_segment.ident.to_string(), "f64");
            } else {
                panic!("Field y is not a path type");
            }
        } else {
            panic!("Item is not a struct");
        }
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
        enum Direction {
            North,
            East,
            South,
            West,
        }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        if let syn::Item::Enum(enum_item) = &file.items[0] {
            assert_eq!(enum_item.ident.to_string(), "Direction");
            assert_eq!(enum_item.variants.len(), 4);

            let variant_names: Vec<String> = enum_item
                .variants
                .iter()
                .map(|v| v.ident.to_string())
                .collect();

            assert_eq!(variant_names, vec!["North", "East", "South", "West"]);
        } else {
            panic!("Item is not an enum");
        }
    }

    #[test]
    fn test_parse_complex_expression() {
        let source = r#"
        fn complex_expr() {
            let result = (10 + 20) * 30 / (5 - 2);
            if result > 100 {
                println!("Large result: {}", result);
            } else {
                println!("Small result: {}", result);
            }
        }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 1);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "complex_expr");
            assert_eq!(func.block.stmts.len(), 2);

            if let syn::Stmt::Local(local) = &func.block.stmts[0] {
                assert!(local.init.is_some());

                if let syn::Pat::Ident(pat_ident) = &local.pat {
                    assert_eq!(pat_ident.ident.to_string(), "result");
                } else {
                    panic!("Variable declaration pattern is not an identifier");
                }
            } else {
                panic!("First statement is not a variable declaration");
            }

            if let syn::Stmt::Expr(expr, _) = &func.block.stmts[1] {
                assert!(
                    matches!(expr, syn::Expr::If(_)),
                    "Second statement is not an if expression"
                );
            } else {
                panic!("Second statement is not an expression");
            }
        } else {
            panic!("Item is not a function");
        }
    }

    #[test]
    fn test_parse_invalid_code() {
        let source = r#"
        fn invalid_function( {
            let x = 10;
        }
        "#;

        let result = parse_rust_source(source);
        assert!(result.is_err(), "Expected parse error for invalid code");
    }

    #[test]
    fn test_parse_multiple_items() {
        let source = r#"
        fn function1() -> i32 { 42 }

        struct MyStruct {
            field: i32,
        }

        fn function2(s: MyStruct) -> i32 {
            s.field
        }
        "#;

        let file = parse_rust_source(source).unwrap();
        assert_eq!(file.items.len(), 3);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "function1");
        } else {
            panic!("First item is not a function");
        }

        if let syn::Item::Struct(struct_item) = &file.items[1] {
            assert_eq!(struct_item.ident.to_string(), "MyStruct");
        } else {
            panic!("Second item is not a struct");
        }

        if let syn::Item::Fn(func) = &file.items[2] {
            assert_eq!(func.sig.ident.to_string(), "function2");
        } else {
            panic!("Third item is not a function");
        }
    }

    /// Smoke test for the printer: walks a file that exercises every handled
    /// node kind. It checks that nothing panics (e.g. an indent underflow) and
    /// that the indentation level returns to zero.
    #[test]
    fn test_print_ast_restores_indentation() {
        let source = r#"
        struct Point { x: f64, y: f64 }
        struct Pair(i32, i32);
        enum Direction { North, South }

        fn walk(p: Point, (dx, dy): (f64, f64)) -> f64 {
            let scale: f64 = 2.0;
            let total = (p.x + dx) * scale % 10.0;
            for i in 0..3 {
                if total > 1.0 && i != 2 {
                    return total;
                } else if i == 0 {
                    continue;
                } else {
                    break;
                }
            }
            while dy < 0.0 {
                loop {
                    break;
                }
            }
            total
        }
        "#;

        let file = parse_rust_source(source).unwrap();
        print_ast(&file);

        let mut visitor = TextVisitor::new();
        visitor.visit_file(&file);
        assert_eq!(visitor.indent, 0);
    }
}