Skip to main content

typhoon_syn/utils/
expr.rs

1use {
2    quote::{format_ident, quote, ToTokens},
3    syn::{
4        fold::{fold_expr, Fold},
5        parse::Parse,
6        parse_quote, Expr, Ident,
7    },
8};
9
10#[derive(Clone)]
11pub struct ContextExpr {
12    pub names: Vec<Ident>,
13    expr: Expr,
14}
15
16impl Parse for ContextExpr {
17    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
18        let expr = Expr::parse(input)?;
19
20        Ok(ContextExpr::from(expr.clone()))
21    }
22}
23
24impl ToTokens for ContextExpr {
25    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
26        let expr = &self.expr;
27        quote!(#expr).to_tokens(tokens)
28    }
29}
30
31impl From<Expr> for ContextExpr {
32    fn from(value: Expr) -> Self {
33        let mut names = Names::default();
34        let expr = names.fold_expr(value);
35        ContextExpr {
36            names: names.0,
37            expr,
38        }
39    }
40}
41
42#[derive(Default)]
43pub struct Names(Vec<Ident>);
44
45impl Fold for Names {
46    fn fold_expr(&mut self, i: syn::Expr) -> syn::Expr {
47        let Expr::Try(ref try_expr) = i else {
48            return fold_expr(self, i);
49        };
50
51        let Expr::MethodCall(ref method_call) = try_expr.expr.as_ref() else {
52            return fold_expr(self, i);
53        };
54
55        if method_call.method != "data" {
56            return fold_expr(self, i);
57        }
58
59        let Expr::Path(name) = method_call.receiver.as_ref() else {
60            return fold_expr(self, i);
61        };
62
63        let Some(name_ident) = name.path.get_ident().cloned() else {
64            return fold_expr(self, i);
65        };
66
67        let ident = format_ident!("{}_state", name_ident);
68
69        self.0.push(name_ident);
70
71        parse_quote!(#ident)
72    }
73}
74
#[cfg(test)]
mod from_expr_tests {
    use {
        super::*,
        quote::{quote, ToTokens},
        syn::parse_quote,
    };

    /// Converts `expr` into a `ContextExpr` and returns its rendered tokens
    /// together with the collected receiver names, both as strings so they can
    /// be compared directly.
    fn rewrite(expr: Expr) -> (String, Vec<String>) {
        let context_expr = ContextExpr::from(expr);
        let names = context_expr.names.iter().map(ToString::to_string).collect();
        (context_expr.to_token_stream().to_string(), names)
    }

    #[test]
    fn test_method_call_with_try() {
        // Pattern: counter.data()?.bump()
        let (tokens, names) = rewrite(parse_quote!(counter.data()?.bump()));
        assert_eq!(tokens, quote!(counter_state.bump()).to_string());
        assert_eq!(names, ["counter"]);
    }

    #[test]
    fn test_field_access_with_try() {
        // Pattern: counter.data()?.bump
        let (tokens, names) = rewrite(parse_quote!(counter.data()?.bump));
        assert_eq!(tokens, quote!(counter_state.bump).to_string());
        assert_eq!(names, ["counter"]);
    }

    #[test]
    fn test_other_expr() {
        // A different method name must not be rewritten.
        let (tokens, names) = rewrite(parse_quote!(counter.random()?.bump));
        assert_eq!(tokens, quote!(counter.random()?.bump).to_string());
        assert!(names.is_empty());
    }

    #[test]
    fn test_multiple_accessors() {
        // Accessors nested inside a larger expression are found and recorded
        // in visit order.
        let (tokens, names) = rewrite(parse_quote!(a.data()?.x + b.data()?.y));
        assert_eq!(tokens, quote!(a_state.x + b_state.y).to_string());
        assert_eq!(names, ["a", "b"]);
    }

    #[test]
    fn test_non_ident_receivers_untouched() {
        // Only a bare single-identifier receiver is rewritten.
        let exprs: [Expr; 2] = [
            parse_quote!(self.counter.data()?.bump),
            parse_quote!(accounts::counter.data()?.bump),
        ];
        for expr in exprs {
            let expected = expr.to_token_stream().to_string();
            let (tokens, names) = rewrite(expr);
            assert_eq!(tokens, expected);
            assert!(names.is_empty());
        }
    }

    #[test]
    fn test_data_without_try_untouched() {
        // The `?` is part of the pattern; a bare `.data()` call is kept.
        let (tokens, names) = rewrite(parse_quote!(counter.data().bump));
        assert_eq!(tokens, quote!(counter.data().bump).to_string());
        assert!(names.is_empty());
    }
}