Skip to main content

umili_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5
6/// Derive `Observe` trait for a struct.
7/// 
8/// ## Example
9/// 
10/// ```
11/// use serde::Serialize;
12/// use umili::Observe;
13/// 
14/// // It is commonly used with `Serialize`, `Clone` and `PartialEq` traits.
15/// #[derive(Serialize, Clone, PartialEq, Observe)]
16/// struct Point {
17///    x: f64,
18///    y: f64,
19/// }
20/// ```
21#[proc_macro_derive(Observe)]
22pub fn derive_observe(input: TokenStream) -> TokenStream {
23    let derive: syn::DeriveInput = syn::parse_macro_input!(input);
24    let ident = &derive.ident;
25    let (impl_generics, type_generics, where_clause) = derive.generics.split_for_impl();
26    let ident_ob = format_ident!("{}Ob", ident);
27    let mut type_fields = vec![];
28    let mut inst_fields = vec![];
29    match &derive.data {
30        syn::Data::Struct(syn::DataStruct {
31            fields: syn::Fields::Named(syn::FieldsNamed { named, .. }),
32            ..
33        }) => {
34            for name in named {
35                let ident = name.ident.as_ref().unwrap();
36                let ty = &name.ty;
37                type_fields.push(quote! {
38                    pub #ident: ::umili::Ob<'i, #ty>,
39                });
40                inst_fields.push(quote! {
41                    #ident: ::umili::Ob {
42                        value: &mut self.#ident,
43                        ctx: ctx.extend(stringify!(#ident)),
44                    },
45                });
46            }
47        },
48        _ => unimplemented!("not implemented"),
49    };
50    quote! {
51        #[automatically_derived]
52        impl #impl_generics Observe for #ident #type_generics #where_clause {
53            type Target<'i> = #ident_ob<'i>;
54
55            fn observe(&mut self, ctx: &::umili::Context) -> Self::Target<'_> {
56                #ident_ob {
57                    #(#inst_fields)*
58                }
59            }
60        }
61
62        pub struct #ident_ob<'i> {
63            #(#type_fields)*
64        }
65    }.into()
66}
67
68/// Observe the side effects of a closure.
69/// 
70/// ## Example
71/// 
72/// ```
73/// use serde::Serialize;
74/// use umili::{observe, Observe};
75/// 
76/// #[derive(Serialize, Clone, PartialEq, Observe)]
77/// struct Point {
78///   x: f64,
79///   y: f64,
80/// }
81/// 
82/// let mut point = Point { x: 1.0, y: 2.0 };
83/// observe!(|mut point| {
84///    point.x += 1.0;
85///    point.y += 1.0;
86/// }).unwrap();
87/// ```
88#[proc_macro]
89pub fn observe(input: TokenStream) -> TokenStream {
90    let input: syn::Expr = syn::parse_macro_input!(input);
91    let syn::Expr::Closure(mut closure) = input else {
92        panic!("expect a closure expression")
93    };
94    if closure.inputs.len() != 1 {
95        panic!("expect a closure with one argument")
96    }
97    let syn::Pat::Ident(syn::PatIdent { ident, .. }) = &closure.inputs[0] else {
98        panic!("expect a closure with one argument")
99    };
100    let body = &mut closure.body;
101    let body_shadow = body.to_token_stream();
102    subst_expr(body, ident);
103    quote! {
104        {
105            use ::std::ops::*;
106            let _ = || #body_shadow;
107            let ctx = ::umili::Context::new();
108            let mut #ident = #ident.observe(&ctx);
109            #body;
110            ctx.collect()
111        }
112    }.into()
113}
114
115fn subst_expr_field(expr_field: &mut syn::ExprField, ident: &syn::Ident, inner: bool) -> Option<syn::Expr> {
116    // erase span info from expr_field
117    let member = format_ident!("{}", expr_field.member.to_token_stream().to_string());
118    let method = match inner {
119        true => format_ident!("borrow"),
120        false => format_ident!("borrow_mut"),
121    };
122    match &mut *expr_field.base {
123        syn::Expr::Path(expr_path) => {
124            if expr_path.to_token_stream().to_string() == ident.to_string() {
125                return Some(syn::parse_quote! {
126                    #ident.#member.#method()
127                });
128            }
129        },
130        syn::Expr::Field(expr_field) => {
131            if let Some(new_expr) = subst_expr_field(expr_field, ident, true) {
132                return Some(syn::parse_quote! {
133                    #new_expr.#member.#method()
134                });
135            }
136        },
137        _ => subst_expr(&mut expr_field.base, ident),
138    }
139    None
140}
141
142fn subst_expr(expr: &mut syn::Expr, ident: &syn::Ident) {
143    match expr {
144        syn::Expr::Array(expr_array) => {
145            for expr in expr_array.elems.iter_mut() {
146                subst_expr(expr, ident);
147            }
148        },
149        syn::Expr::Assign(expr_assign) => {
150            subst_expr(&mut expr_assign.left, ident);
151            subst_expr(&mut expr_assign.right, ident);
152        },
153        syn::Expr::Async(expr_async) => {
154            subst_block(&mut expr_async.block, ident);
155        },
156        syn::Expr::Await(expr_await) => {
157            subst_expr(&mut expr_await.base, ident);
158        },
159        syn::Expr::Binary(expr_binary) => {
160            subst_expr(&mut expr_binary.left, ident);
161            subst_expr(&mut expr_binary.right, ident);
162            match &expr_binary.op {
163                syn::BinOp::AddAssign(..) => {
164                    let left = &expr_binary.left;
165                    let right = &expr_binary.right;
166                    *expr = syn::parse_quote! {
167                        #left.add_assign(#right)
168                    }
169                },
170                _ => {},
171            }
172        },
173        syn::Expr::Block(expr_block) => {
174            subst_block(&mut expr_block.block, ident);
175        },
176        syn::Expr::Break(..) => {},
177        syn::Expr::Call(expr_call) => {
178            subst_expr(&mut expr_call.func, ident);
179            for expr in expr_call.args.iter_mut() {
180                subst_expr(expr, ident);
181            }
182        },
183        syn::Expr::Cast(expr_cast) => {
184            subst_expr(&mut expr_cast.expr, ident);
185        },
186        syn::Expr::Closure(expr_closure) => {
187            subst_expr(&mut expr_closure.body, ident);
188        },
189        syn::Expr::Const(expr_const) => {
190            subst_block(&mut expr_const.block, ident);
191        },
192        syn::Expr::Continue(..) => {},
193        syn::Expr::Field(expr_field) => {
194            if let Some(new_expr) = subst_expr_field(expr_field, ident, false) {
195                *expr = new_expr;
196            }
197        },
198        syn::Expr::ForLoop(expr_for_loop) => {
199            subst_expr(&mut expr_for_loop.expr, ident);
200            subst_block(&mut expr_for_loop.body, ident);
201        },
202        syn::Expr::Group(expr_group) => {
203            subst_expr(&mut expr_group.expr, ident);
204        },
205        syn::Expr::If(expr_if) => {
206            subst_expr(&mut expr_if.cond, ident);
207            subst_block(&mut expr_if.then_branch, ident);
208            if let Some((_, expr)) = &mut expr_if.else_branch {
209                subst_expr(expr, ident);
210            }
211        },
212        syn::Expr::Index(expr_index) => {
213            subst_expr(&mut expr_index.expr, ident);
214            subst_expr(&mut expr_index.index, ident);
215        },
216        syn::Expr::Let(expr_let) => {
217            subst_expr(&mut expr_let.expr, ident);
218        },
219        syn::Expr::Lit(..) => {},
220        syn::Expr::Loop(expr_loop) => {
221            subst_block(&mut expr_loop.body, ident);
222        },
223        syn::Expr::Macro(..) => {},
224        syn::Expr::Match(expr_match) => {
225            subst_expr(&mut expr_match.expr, ident);
226            for arm in expr_match.arms.iter_mut() {
227                subst_expr(&mut arm.body, ident);
228            }
229        },
230        syn::Expr::MethodCall(expr_method_call) => {
231            subst_expr(&mut expr_method_call.receiver, ident);
232            for expr in expr_method_call.args.iter_mut() {
233                subst_expr(expr, ident);
234            }
235        },
236        syn::Expr::Paren(expr_paren) => {
237            subst_expr(&mut expr_paren.expr, ident);
238        },
239        syn::Expr::Path(..) => {},
240        syn::Expr::Range(expr_range) => {
241            if let Some(expr) = &mut expr_range.start {
242                subst_expr(expr, ident);
243            }
244            if let Some(expr) = &mut expr_range.end {
245                subst_expr(expr, ident);
246            }
247        },
248        syn::Expr::RawAddr(expr_raw_addr) => {
249            subst_expr(&mut expr_raw_addr.expr, ident);
250        },
251        syn::Expr::Reference(expr_reference) => {
252            subst_expr(&mut expr_reference.expr, ident);
253        },
254        syn::Expr::Repeat(expr_repeat) => {
255            subst_expr(&mut expr_repeat.expr, ident);
256            subst_expr(&mut expr_repeat.len, ident);
257        },
258        syn::Expr::Return(expr_return) => {
259            if let Some(expr) = &mut expr_return.expr {
260                subst_expr(expr, ident);
261            }
262        },
263        syn::Expr::Struct(expr_struct) => {
264            for field in expr_struct.fields.iter_mut() {
265                subst_expr(&mut field.expr, ident);
266            }
267        },
268        syn::Expr::Try(expr_try) => {
269            subst_expr(&mut expr_try.expr, ident);
270        },
271        syn::Expr::TryBlock(expr_try_block) => {
272            subst_block(&mut expr_try_block.block, ident);
273        },
274        syn::Expr::Tuple(expr_tuple) => {
275            for expr in expr_tuple.elems.iter_mut() {
276                subst_expr(expr, ident);
277            }
278        },
279        syn::Expr::Unary(expr_unary) => {
280            subst_expr(&mut expr_unary.expr, ident);
281        },
282        syn::Expr::Unsafe(expr_unsafe) => {
283            subst_block(&mut expr_unsafe.block, ident);
284        },
285        syn::Expr::Verbatim(..) => {},
286        syn::Expr::While(expr_while) => {
287            subst_expr(&mut expr_while.cond, ident);
288            subst_block(&mut expr_while.body, ident);
289        },
290        syn::Expr::Yield(expr_yield) => {
291            if let Some(expr) = &mut expr_yield.expr {
292                subst_expr(expr, ident);
293            }
294        },
295        _ => unimplemented!("unimplemented expr: {}", expr.to_token_stream()),
296    }
297}
298
299fn subst_block(block: &mut syn::Block, ident: &syn::Ident) {
300    for stmt in block.stmts.iter_mut() {
301        subst_stmt(stmt, ident);
302    }
303}
304
305fn subst_stmt(stmt: &mut syn::Stmt, ident: &syn::Ident) {
306    match stmt {
307        syn::Stmt::Local(local) => {
308            if let Some(local_init) = &mut local.init {
309                subst_expr(&mut local_init.expr, ident);
310                if let Some((_, expr)) = &mut local_init.diverge {
311                    subst_expr(expr, ident);
312                }
313            }
314        },
315        syn::Stmt::Expr(expr, ..) => {
316            subst_expr(expr, ident);
317        },
318        syn::Stmt::Macro(..) => {},
319        _ => unimplemented!("unimplemented stmt: {}", stmt.to_token_stream()),
320    }
321}