1use std::fs;
2use std::io;
3use std::path::Path;
4
5use syn::{File, visit::Visit};
6
7use crate::visitor::AstVisitor;
8
9pub fn parse_rust_source(source: &str) -> Result<syn::File, syn::Error> {
20 syn::parse_file(source)
21}
22
23pub fn parse_rust_file<P: AsRef<Path>>(path: P) -> io::Result<syn::File> {
35 let source = fs::read_to_string(path)?;
36 let syntax =
37 syn::parse_file(&source).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
38
39 Ok(syntax)
40}
41
42pub fn print_ast(file: &File) {
50 println!("AST for Rust code:");
51 let mut visitor = AstVisitor::new();
52 visitor.visit_file(file);
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `contents` to a new temporary file and returns the handle owning it.
    fn temp_file_with(contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(contents.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// Returns the identifier of the first path segment of `ty`, or `None` when
    /// `ty` is not a path type (or has no segments).
    fn path_type_name(ty: &syn::Type) -> Option<String> {
        if let syn::Type::Path(type_path) = ty {
            type_path
                .path
                .segments
                .first()
                .map(|segment| segment.ident.to_string())
        } else {
            None
        }
    }

    #[test]
    fn test_parse_rust_file() {
        let file = temp_file_with(
            r#"
            fn test_function() {
                println!("Hello, world!");
            }
            "#,
        );

        let ast = parse_rust_file(file.path()).unwrap();

        assert_eq!(ast.items.len(), 1);
        if let syn::Item::Fn(func) = &ast.items[0] {
            assert_eq!(func.sig.ident.to_string(), "test_function");
        } else {
            panic!("Parsed item is not a function");
        }
    }

    #[test]
    fn test_parse_rust_file_missing_file() {
        // Create a temp file, then delete it explicitly so its path is known not to exist.
        let temp_path = NamedTempFile::new().unwrap().into_temp_path();
        let missing = temp_path.to_path_buf();
        temp_path.close().unwrap();

        match parse_rust_file(&missing) {
            Ok(_) => panic!("Expected an error for a missing file"),
            Err(err) => assert_eq!(err.kind(), std::io::ErrorKind::NotFound),
        }
    }

    #[test]
    fn test_parse_rust_file_invalid_contents() {
        let file = temp_file_with("fn broken( {");

        match parse_rust_file(file.path()) {
            Ok(_) => panic!("Expected a parse error for invalid Rust source"),
            Err(err) => {
                assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
                // The underlying syn::Error must be preserved as the inner error.
                assert!(matches!(err.get_ref(), Some(inner) if inner.is::<syn::Error>()));
            }
        }
    }

    #[test]
    fn test_parse_function() {
        let source = r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 1);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "add");
            assert_eq!(func.sig.inputs.len(), 2);

            match &func.sig.output {
                syn::ReturnType::Type(_, return_type) => assert_eq!(
                    path_type_name(return_type).as_deref(),
                    Some("i32"),
                    "Return type is not the path type i32"
                ),
                syn::ReturnType::Default => panic!("Function has no return type"),
            }

            assert_eq!(func.block.stmts.len(), 1);
        } else {
            panic!("Item is not a function");
        }
    }

    #[test]
    fn test_parse_struct() {
        let source = r#"
            struct Point {
                x: f64,
                y: f64,
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 1);

        if let syn::Item::Struct(struct_item) = &file.items[0] {
            assert_eq!(struct_item.ident.to_string(), "Point");

            let fields: Vec<_> = struct_item.fields.iter().collect();
            assert_eq!(fields.len(), 2);

            assert_eq!(fields[0].ident.as_ref().unwrap().to_string(), "x");
            assert_eq!(
                path_type_name(&fields[0].ty).as_deref(),
                Some("f64"),
                "Field x is not the path type f64"
            );

            assert_eq!(fields[1].ident.as_ref().unwrap().to_string(), "y");
            assert_eq!(
                path_type_name(&fields[1].ty).as_deref(),
                Some("f64"),
                "Field y is not the path type f64"
            );
        } else {
            panic!("Item is not a struct");
        }
    }

    #[test]
    fn test_parse_enum() {
        let source = r#"
            enum Direction {
                North,
                East,
                South,
                West,
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 1);

        if let syn::Item::Enum(enum_item) = &file.items[0] {
            assert_eq!(enum_item.ident.to_string(), "Direction");
            assert_eq!(enum_item.variants.len(), 4);

            let variant_names: Vec<String> = enum_item
                .variants
                .iter()
                .map(|v| v.ident.to_string())
                .collect();

            assert_eq!(variant_names, vec!["North", "East", "South", "West"]);
        } else {
            panic!("Item is not an enum");
        }
    }

    #[test]
    fn test_parse_complex_expression() {
        let source = r#"
            fn complex_expr() {
                let result = (10 + 20) * 30 / (5 - 2);
                if result > 100 {
                    println!("Large result: {}", result);
                } else {
                    println!("Small result: {}", result);
                }
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 1);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "complex_expr");
            assert_eq!(func.block.stmts.len(), 2);

            if let syn::Stmt::Local(local) = &func.block.stmts[0] {
                assert!(local.init.is_some());

                if let syn::Pat::Ident(pat_ident) = &local.pat {
                    assert_eq!(pat_ident.ident.to_string(), "result");
                } else {
                    panic!("Variable declaration pattern is not an identifier");
                }
            } else {
                panic!("First statement is not a variable declaration");
            }

            match &func.block.stmts[1] {
                syn::Stmt::Expr(expr, _) => assert!(
                    matches!(expr, syn::Expr::If(_)),
                    "Second statement is not an if expression"
                ),
                _ => panic!("Second statement is not an expression"),
            }
        } else {
            panic!("Item is not a function");
        }
    }

    #[test]
    fn test_parse_invalid_code() {
        let source = r#"
            fn invalid_function( {
                let x = 10;
            }
        "#;

        let result = parse_rust_source(source);
        assert!(result.is_err(), "Expected parse error for invalid code");
    }

    #[test]
    fn test_parse_multiple_items() {
        let source = r#"
            fn function1() -> i32 { 42 }

            struct MyStruct {
                field: i32,
            }

            fn function2(s: MyStruct) -> i32 {
                s.field
            }
        "#;

        let file = parse_rust_source(source).unwrap();

        assert_eq!(file.items.len(), 3);

        if let syn::Item::Fn(func) = &file.items[0] {
            assert_eq!(func.sig.ident.to_string(), "function1");
        } else {
            panic!("First item is not a function");
        }

        if let syn::Item::Struct(struct_item) = &file.items[1] {
            assert_eq!(struct_item.ident.to_string(), "MyStruct");
        } else {
            panic!("Second item is not a struct");
        }

        if let syn::Item::Fn(func) = &file.items[2] {
            assert_eq!(func.sig.ident.to_string(), "function2");
        } else {
            panic!("Third item is not a function");
        }
    }
}