rustalize/
lib.rs

1use std::str::FromStr;
2
/// Root of a parsed item: one trait, struct or enum definition.
///
/// `Struct` also describes the anonymous field list carried by an enum variant
/// (see [`VariantNode`]).
#[derive(Debug, PartialEq)]
pub enum AstNode {
    /// A `pub trait` and its method signatures.
    Trait(TraitNode),
    /// A `pub struct`, or the anonymous payload of an enum variant.
    Struct(StructNode),
    /// A `pub enum` and its variants.
    Enum(EnumNode),
}
9
/// A parsed trait definition.
#[derive(Debug, PartialEq)]
pub struct TraitNode {
    /// The trait's name.
    pub name: String,
    /// The method signatures declared in the trait body, in source order.
    pub methods: Vec<MethodNode>,
}
15
/// A parsed struct definition, or the payload of an enum variant.
#[derive(Debug, PartialEq)]
pub struct StructNode {
    /// The struct's name; empty (`""`) when the node describes an enum variant's payload.
    pub name: String,
    /// The fields in source order; positional (tuple) fields are named `"0"`, `"1"`, ….
    pub fields: Vec<FieldNode>,
}
21
/// A parsed enum definition.
#[derive(Debug, PartialEq)]
pub struct EnumNode {
    /// The enum's name.
    pub name: String,
    /// The enum's variants in source order.
    pub variants: Vec<VariantNode>,
}
27
/// One method signature from a trait body.
#[derive(Debug, PartialEq)]
pub struct MethodNode {
    /// The method name (the token following `fn`).
    pub name: String,
    /// Parameters in declaration order, including any `self` receiver.
    pub params: Vec<ParamNode>,
    /// The type after `->`, or `None` when the signature declares no return type.
    pub return_type: Option<Box<TypeNode>>,
}
34
/// One parameter of a method signature.
#[derive(Debug, PartialEq)]
pub struct ParamNode {
    /// The parameter name as written; receivers keep their sigil, e.g. `"&self"`.
    pub name: String,
    /// The declared type; a `&self` receiver is recorded as a reference to `Simple("self")`.
    pub param_type: Box<TypeNode>,
}
40
/// One field of a struct, or of an enum variant's payload.
#[derive(Debug, PartialEq)]
pub struct FieldNode {
    /// The field name; positional (tuple) fields use their index, e.g. `"0"`.
    pub name: String,
    /// The field's declared type.
    pub field_type: Box<TypeNode>,
}
46
/// One variant of an enum.
#[derive(Debug, PartialEq)]
pub struct VariantNode {
    /// The variant name.
    pub name: String,
    /// The variant's payload, if any; `None` for unit variants such as `Red`.
    /// Tuple and struct-like payloads are expected to be an anonymous
    /// [`AstNode::Struct`] (empty name), as the tests in this file assume.
    pub associated_data: Option<Box<AstNode>>,
}
52
/// A simplified Rust type expression.
#[derive(Debug, PartialEq)]
pub enum TypeNode {
    /// Any type kept verbatim as text, e.g. `u8`, `String` or `self`.
    Simple(String),
    /// A reference `&T` to the inner type.
    Reference(Box<TypeNode>),
    /// `Name<Args…>`; slices `[T]` are encoded with the pseudo-name `"[]"` and one argument.
    Generic { name: String, args: Vec<TypeNode> },
}
59
/// Stateless parser for simple `pub trait` / `pub struct` / `pub enum` definitions;
/// all functionality lives in associated functions such as [`Parser::parse`].
pub struct Parser;
61
62impl Parser {
63    pub fn parse(input: &str) -> Result<AstNode, String> {
64        let input = input.trim();
65        if input.starts_with("pub trait") {
66            Parser::parse_trait(input)
67        } else if input.starts_with("pub struct") {
68            Parser::parse_struct(input)
69        } else if input.starts_with("pub enum") {
70            Parser::parse_enum(input)
71        } else {
72            Err("Unsupported or invalid Rust construct".to_string())
73        }
74    }
75
76    fn parse_trait(input: &str) -> Result<AstNode, String> {
77        let trait_name = input
78            .split_whitespace()
79            .nth(2)
80            .ok_or("Invalid trait definition")?
81            .to_string();
82
83        let body_start = input.find('{').ok_or("Missing trait body")?;
84        let body_end = input.rfind('}').ok_or("Missing closing brace")?;
85        if body_end <= body_start {
86            return Err("Invalid trait body".to_string());
87        }
88        let body_content = &input[body_start + 1..body_end].trim();
89
90        let method_strings: Vec<&str> = body_content
91            .split(';')
92            .map(|s| s.trim())
93            .filter(|s| !s.is_empty())
94            .collect();
95
96        let mut methods = Vec::new();
97        for method_str in method_strings {
98            methods.push(Self::parse_method(method_str)?);
99        }
100
101        Ok(AstNode::Trait(TraitNode {
102            name: trait_name,
103            methods,
104        }))
105    }
106
107    fn parse_struct(input: &str) -> Result<AstNode, String> {
108        let struct_name = input
109            .split_whitespace()
110            .nth(2)
111            .ok_or("Invalid struct definition")?
112            .to_string();
113
114        let body_start = input.find('{').ok_or("Missing struct body")?;
115        let body_end = input.rfind('}').ok_or("Missing closing brace")?;
116        if body_end <= body_start {
117            return Err("Invalid struct body".to_string());
118        }
119        let body_content = &input[body_start + 1..body_end].trim();
120
121        let fields = body_content
122            .split(',')
123            .map(|field_str| {
124                let parts: Vec<&str> = field_str.split(':').collect();
125                if parts.len() != 2 {
126                    return Err("Invalid field format".to_string());
127                }
128                Ok(FieldNode {
129                    name: parts[0].trim().to_string(),
130                    field_type: Box::new(Self::parse_type(parts[1].trim())?),
131                })
132            })
133            .collect::<Result<Vec<FieldNode>, String>>()?;
134
135        Ok(AstNode::Struct(StructNode {
136            name: struct_name,
137            fields,
138        }))
139    }
140
141    fn parse_enum(input: &str) -> Result<AstNode, String> {
142        let enum_name = input
143            .split_whitespace()
144            .nth(2)
145            .ok_or("Invalid enum definition")?
146            .to_string();
147
148        let body_start = input.find('{').ok_or("Missing enum body")?;
149        let body_end = input.rfind('}').ok_or("Missing closing brace")?;
150        if body_end <= body_start {
151            return Err("Invalid enum body".to_string());
152        }
153        let body_content = &input[body_start + 1..body_end].trim();
154
155        let variant_strings: Vec<&str> = body_content
156            .split(',')
157            .map(|s| s.trim())
158            .filter(|s| !s.is_empty())
159            .collect();
160
161        let mut variants = Vec::new();
162        for variant_str in variant_strings {
163            if variant_str.contains('(') && variant_str.contains(')') {
164                // Variant with associated data
165                let name = variant_str.split('(').next().unwrap().trim().to_string();
166                let data_str = variant_str.split('(').nth(1).unwrap().trim_end_matches(')');
167                // For simplicity, assume associated data is a struct
168                let associated_ast = Parser::parse(data_str)?;
169                variants.push(VariantNode {
170                    name,
171                    associated_data: Some(Box::new(associated_ast)),
172                });
173            } else {
174                // Simple variant
175                variants.push(VariantNode {
176                    name: variant_str.to_string(),
177                    associated_data: None,
178                });
179            }
180        }
181
182        Ok(AstNode::Enum(EnumNode {
183            name: enum_name,
184            variants,
185        }))
186    }
187
188    fn parse_method(input: &str) -> Result<MethodNode, String> {
189        let input = input.trim();
190        let parts: Vec<&str> = input.split(&['(', ')']).collect();
191        if parts.len() < 2 {
192            return Err("Invalid method format".to_string());
193        }
194
195        let name = parts[0]
196            .split_whitespace()
197            .nth(1)
198            .ok_or("Invalid method name")?
199            .to_string();
200
201        let params = Self::parse_params(parts[1])?;
202
203        let return_type = if input.contains("->") {
204            let return_str = input
205                .split("->")
206                .nth(1)
207                .unwrap()
208                .trim()
209                .trim_end_matches(';')
210                .to_string();
211            Some(Box::new(Self::parse_type(&return_str)?))
212        } else {
213            None
214        };
215
216        Ok(MethodNode {
217            name,
218            params,
219            return_type,
220        })
221    }
222
223    fn parse_params(input: &str) -> Result<Vec<ParamNode>, String> {
224        if input.trim().is_empty() {
225            return Ok(Vec::new());
226        }
227
228        input
229            .split(',')
230            .map(|param| {
231                let param = param.trim();
232                if param == "&self" {
233                    Ok(ParamNode {
234                        name: "&self".to_string(),
235                        param_type: Box::new(TypeNode::Reference(Box::new(TypeNode::Simple("self".to_string())))),
236                    })
237                } else if param == "self" {
238                    Ok(ParamNode {
239                        name: "self".to_string(),
240                        param_type: Box::new(TypeNode::Simple("self".to_string())),
241                    })
242                } else {
243                    let parts: Vec<&str> = param.split(':').collect();
244                    if parts.len() != 2 {
245                        return Err("Invalid parameter format".to_string());
246                    }
247                    Ok(ParamNode {
248                        name: parts[0].trim().to_string(),
249                        param_type: Box::new(Self::parse_type(parts[1].trim())?),
250                    })
251                }
252            })
253            .collect()
254    }
255
256    fn parse_type(input: &str) -> Result<TypeNode, String> {
257        if input.starts_with('&') {
258            let inner = input.trim_start_matches('&').trim();
259            let inner_type = Self::parse_type(inner)?;
260            Ok(TypeNode::Reference(Box::new(inner_type)))
261        } else if input.starts_with('[') && input.ends_with(']') {
262            let inner_str = &input[1..input.len()-1].trim();
263            let inner_type = Self::parse_type(inner_str)?;
264            Ok(TypeNode::Generic {
265                name: "[]".to_string(),
266                args: vec![inner_type],
267            })
268        } else if input.contains('<') && input.contains('>') {
269            let name = input.split('<').next().unwrap().trim().to_string();
270            let args_str = input
271                .split('<')
272                .nth(1)
273                .unwrap()
274                .trim_end_matches('>')
275                .trim();
276            let args: Result<Vec<TypeNode>, String> = args_str
277                .split(',')
278                .map(|arg| Self::parse_type(arg.trim()))
279                .collect();
280            Ok(TypeNode::Generic { name, args: args? })
281        } else {
282            Ok(TypeNode::Simple(input.to_string()))
283        }
284    }
285
286    fn parse_tuple_variant(input: &str) -> Result<AstNode, String> {
287        let fields: Vec<FieldNode> = input
288            .split(',')
289            .map(|s| s.trim())
290            .filter(|s| !s.is_empty())
291            .enumerate()
292            .map(|(i, s)| -> Result<FieldNode, String> {
293                Ok(FieldNode {
294                    name: format!("{}", i),
295                    field_type: Box::new(Self::parse_type(s)?),
296                })
297            })
298            .collect::<Result<Vec<FieldNode>, String>>()?;
299
300        Ok(AstNode::Struct(StructNode {
301            name: "".to_string(),
302            fields,
303        }))
304    }
305}
306
307impl FromStr for AstNode {
308    type Err = String;
309
310    fn from_str(s: &str) -> Result<Self, Self::Err> {
311        Parser::parse(s)
312    }
313}
314
315// Tree Display Implementation with Recursive Traversal
316impl AstNode {
317    pub fn display_tree(&self) {
318        self.display_tree_internal("");
319    }
320
321    fn display_tree_internal(&self, prefix: &str) {
322        match self {
323            AstNode::Trait(trait_node) => {
324                println!("{}- Trait: {}", prefix, trait_node.name);
325                let len = trait_node.methods.len();
326                for (i, method) in trait_node.methods.iter().enumerate() {
327                    let is_last = i == len - 1;
328                    let branch = if is_last { "└──" } else { "├──" };
329                    let new_prefix = format!("{}{} ", prefix, branch);
330                    method.display_tree_internal(&new_prefix, is_last);
331                }
332            }
333            AstNode::Struct(struct_node) => {
334                println!("{}- Struct: {}", prefix, struct_node.name);
335                let len = struct_node.fields.len();
336                for (i, field) in struct_node.fields.iter().enumerate() {
337                    let is_last = i == len - 1;
338                    let branch = if is_last { "└──" } else { "├──" };
339                    let new_prefix = format!("{}{} ", prefix, branch);
340                    field.display_tree_internal(&new_prefix, is_last);
341                }
342            }
343            AstNode::Enum(enum_node) => {
344                println!("{}- Enum: {}", prefix, enum_node.name);
345                let len = enum_node.variants.len();
346                for (i, variant) in enum_node.variants.iter().enumerate() {
347                    let is_last = i == len - 1;
348                    let branch = if is_last { "└──" } else { "├──" };
349                    let new_prefix = format!("{}{} ", prefix, branch);
350                    variant.display_tree_internal(&new_prefix, is_last);
351                }
352            }
353        }
354    }
355}
356
357impl MethodNode {
358    fn display_tree_internal(&self, prefix: &str, is_last: bool) {
359        let _ = is_last;
360        println!("{}Method: {}", prefix, self.name);
361        let len = self.params.len();
362        for (i, param) in self.params.iter().enumerate() {
363            let is_last_param = i == len - 1;
364            let branch = if is_last_param {
365                "└──"
366            } else {
367                "├──"
368            };
369            let param_prefix = format!("{}{} ", prefix, branch);
370            param.display_tree_internal(&param_prefix, is_last_param);
371        }
372        if let Some(return_type) = &self.return_type {
373            let branch = if len == 0 { "└──" } else { "├──" };
374            let return_prefix = format!("{}{} ", prefix, branch);
375            println!("{}Return Type: {}", return_prefix, return_type.display());
376        }
377    }
378}
379
380impl FieldNode {
381    fn display_tree_internal(&self, prefix: &str, _is_last: bool) {
382        println!(
383            "{}Field: {}: {}",
384            prefix,
385            self.name,
386            self.field_type.display()
387        );
388    }
389}
390
391impl VariantNode {
392    fn display_tree_internal(&self, prefix: &str, _is_last: bool) {
393        println!("{}Variant: {}", prefix, self.name);
394        if let Some(associated_data) = &self.associated_data {
395            // Recursively display the associated AstNode
396            associated_data.display_tree_internal(&format!("{}    ", prefix));
397        }
398    }
399}
400
401impl TypeNode {
402    fn display(&self) -> String {
403        match self {
404            TypeNode::Simple(name) => name.clone(),
405            TypeNode::Reference(inner) => format!("&{}", inner.display()),
406            TypeNode::Generic { name, args } => {
407                let args_display: Vec<String> = args.iter().map(|arg| arg.display()).collect();
408                format!("{}<{}>", name, args_display.join(", "))
409            }
410        }
411    }
412}
413
414impl ParamNode {
415    fn display_tree_internal(&self, prefix: &str, _is_last: bool) {
416        println!(
417            "{}Param: {}: {}",
418            prefix,
419            self.name,
420            self.param_type.display()
421        );
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// A trait with two bodyless methods parses the same way regardless of layout:
    /// multi-line, surrounded by blank lines, or on a single line.
    #[test]
    fn test_parse_visualizer_trait() {
        let inputs = vec![
            r#"pub trait Visualizer {
                fn visualize(&self, data: &[u8]);
                fn process(&self, input: &str) -> String;
            }"#,
            r#"
            pub trait Visualizer {
                fn visualize(&self, data: &[u8]);
                fn process(&self, input: &str) -> String;
            }
            "#,
            "pub trait Visualizer { fn visualize(&self, data: &[u8]); fn process(&self, input: &str) -> String; }",
        ];

        for input in inputs {
            let expected = AstNode::Trait(TraitNode {
                name: "Visualizer".to_string(),
                methods: vec![
                    MethodNode {
                        name: "visualize".to_string(),
                        params: vec![
                            ParamNode {
                                name: "&self".to_string(),
                                param_type: Box::new(TypeNode::Reference(Box::new(TypeNode::Simple("self".to_string())))),
                            },
                            ParamNode {
                                name: "data".to_string(),
                                // `&[u8]` is a reference to a slice: the slice node (encoded as the
                                // pseudo-generic "[]") is wrapped in `Reference`, exactly like the
                                // `&str` parameter of `process` below.
                                param_type: Box::new(TypeNode::Reference(Box::new(TypeNode::Generic {
                                    name: "[]".to_string(),
                                    args: vec![TypeNode::Simple("u8".to_string())],
                                }))),
                            },
                        ],
                        return_type: None,
                    },
                    MethodNode {
                        name: "process".to_string(),
                        params: vec![
                            ParamNode {
                                name: "&self".to_string(),
                                param_type: Box::new(TypeNode::Reference(Box::new(TypeNode::Simple("self".to_string())))),
                            },
                            ParamNode {
                                name: "input".to_string(),
                                param_type: Box::new(TypeNode::Reference(Box::new(TypeNode::Simple("str".to_string())))),
                            },
                        ],
                        return_type: Some(Box::new(TypeNode::Simple("String".to_string()))),
                    },
                ],
            });

            assert_eq!(input.parse::<AstNode>().unwrap(), expected);
        }
    }

    /// Named fields, including the trailing comma that idiomatic Rust places after the last one.
    #[test]
    fn test_parse_struct() {
        let input = r#"
            pub struct Point {
                x: f64,
                y: f64,
                label: String,
            }
        "#;

        let expected = AstNode::Struct(StructNode {
            name: "Point".to_string(),
            fields: vec![
                FieldNode {
                    name: "x".to_string(),
                    field_type: Box::new(TypeNode::Simple("f64".to_string())),
                },
                FieldNode {
                    name: "y".to_string(),
                    field_type: Box::new(TypeNode::Simple("f64".to_string())),
                },
                FieldNode {
                    name: "label".to_string(),
                    field_type: Box::new(TypeNode::Simple("String".to_string())),
                },
            ],
        });

        assert_eq!(input.parse::<AstNode>().unwrap(), expected);
    }

    /// An enum made only of unit variants.
    #[test]
    fn test_parse_enum() {
        let input = r#"
            pub enum Color {
                Red,
                Green,
                Blue,
            }
        "#;

        let expected = AstNode::Enum(EnumNode {
            name: "Color".to_string(),
            variants: vec![
                VariantNode {
                    name: "Red".to_string(),
                    associated_data: None,
                },
                VariantNode {
                    name: "Green".to_string(),
                    associated_data: None,
                },
                VariantNode {
                    name: "Blue".to_string(),
                    associated_data: None,
                },
            ],
        });

        assert_eq!(input.parse::<AstNode>().unwrap(), expected);
    }

    /// Unit, struct-like and tuple-like variants; payloads are anonymous structs and tuple
    /// fields are named by position.
    #[test]
    fn test_parse_enum_with_associated_data() {
        let input = r#"
            pub enum Message {
                Quit,
                Move { x: i32, y: i32 },
                Write(String),
                ChangeColor(i32, i32, i32),
            }
        "#;

        let expected = AstNode::Enum(EnumNode {
            name: "Message".to_string(),
            variants: vec![
                VariantNode {
                    name: "Quit".to_string(),
                    associated_data: None,
                },
                VariantNode {
                    name: "Move".to_string(),
                    associated_data: Some(Box::new(AstNode::Struct(StructNode {
                        name: "".to_string(), // Anonymous struct
                        fields: vec![
                            FieldNode {
                                name: "x".to_string(),
                                field_type: Box::new(TypeNode::Simple("i32".to_string())),
                            },
                            FieldNode {
                                name: "y".to_string(),
                                field_type: Box::new(TypeNode::Simple("i32".to_string())),
                            },
                        ],
                    }))),
                },
                VariantNode {
                    name: "Write".to_string(),
                    associated_data: Some(Box::new(AstNode::Struct(StructNode {
                        name: "".to_string(), // Tuple struct equivalent
                        fields: vec![FieldNode {
                            name: "0".to_string(),
                            field_type: Box::new(TypeNode::Simple("String".to_string())),
                        }],
                    }))),
                },
                VariantNode {
                    name: "ChangeColor".to_string(),
                    associated_data: Some(Box::new(AstNode::Struct(StructNode {
                        name: "".to_string(), // Tuple struct equivalent
                        fields: vec![
                            FieldNode {
                                name: "0".to_string(),
                                field_type: Box::new(TypeNode::Simple("i32".to_string())),
                            },
                            FieldNode {
                                name: "1".to_string(),
                                field_type: Box::new(TypeNode::Simple("i32".to_string())),
                            },
                            FieldNode {
                                name: "2".to_string(),
                                field_type: Box::new(TypeNode::Simple("i32".to_string())),
                            },
                        ],
                    }))),
                },
            ],
        });

        assert_eq!(input.parse::<AstNode>().unwrap(), expected);
    }

    /// Free functions are not a supported construct.
    #[test]
    fn test_invalid_input() {
        let input = "fn standalone_function() {}";
        assert!(input.parse::<AstNode>().is_err());
    }

    /// A field without a `:` separator is rejected.
    #[test]
    fn test_parse_struct_with_invalid_field() {
        let input = r#"
            pub struct InvalidStruct {
                x f64, // Missing colon
                y: f64,
            }
        "#;

        assert!(input.parse::<AstNode>().is_err());
    }

    /// A parameter list missing a comma between parameters is rejected.
    #[test]
    fn test_parse_trait_with_invalid_method() {
        let input = r#"
            pub trait InvalidTrait {
                fn invalid_method(&self data: &[u8]); // Missing comma
            }
        "#;

        assert!(input.parse::<AstNode>().is_err());
    }
}