rust_automata_macros/
lib.rs

//! Attribute-style DSL for defining finite-state machines.
//!
//! See the [`rust-automata` crate](https://docs.rs/rust-automata/) for more details.
//!
//! Documentation features:
//! - `"mermaid"`: embed a clickable Mermaid state diagram.
//! - `"dsl"`: (re)generate the DSL description of the machine.
8
9#![recursion_limit = "256"]
10
11extern crate proc_macro;
12
13use proc_macro::TokenStream;
14use proc_macro2::TokenStream as TokenStream2;
15use quote::{format_ident, quote};
16use std::collections::BTreeSet;
17use syn::{parse_macro_input, DeriveInput, Ident, ItemStruct, Path};
18
19mod parser;
20use parser::{MachineAttr, Transition};
21
22mod annotations;
23
/// Name prefix that marks a transition handler.
///
/// A `handle_`-prefixed handler is called as `self.handle_x(state)`, or as
/// `self.handle_x(state, input)` when the transition declares an input. It must return the
/// next state, or a `(next_state, output)` tuple when the transition declares an output.
/// Any other handler name is treated as a plain callback `self.name()`. For such callbacks,
/// the next state and output are produced by the macro instead: a self-loop keeps the
/// current state, and otherwise `Default::default()` is used.
const HANDLE_PREFIX: &str = "handle_";
/// Name prefix that marks a guard predicate.
///
/// Inside a guard expression, a `guard_`-prefixed path is called as
/// `self.guard_x(&state)`, while any other path is called without arguments as
/// `self.name()`. The combined expression is used as a match guard, so it must evaluate
/// to `bool`.
const GUARD_PREFIX: &str = "guard_";
28
29mod util {
30    use super::*;
31    use heck::ToSnakeCase;
32
33    /// `CamelCase` → `snake_case` ident.
34    pub fn snake(id: &Ident) -> Ident {
35        Ident::new(&id.to_string().to_snake_case(), id.span())
36    }
37
38    /// Last segment of `syn::Path` → `snake_case` ident.
39    pub fn snake_path(p: &Path) -> Ident {
40        snake(last(p))
41    }
42
43    /// Last identifier of a path (`states::Open` → `Open`)
44    pub fn last(p: &Path) -> &Ident {
45        &p.segments.last().unwrap().ident
46    }
47
48    /// A unique string key for set membership (`states::Open`)
49    pub fn key(p: &Path) -> String {
50        p.segments
51            .iter()
52            .map(|s| s.ident.to_string())
53            .collect::<Vec<_>>()
54            .join("::")
55    }
56
57    /// Strip a trailing `Machine` from a type name if present.
58    pub fn strip_machine(id: &Ident) -> String {
59        let s = id.to_string();
60        s.strip_suffix("Machine").unwrap_or(&s).to_owned()
61    }
62
63    pub fn compile_error_if(condition: bool, message: &str) -> Option<TokenStream2> {
64        condition.then(|| quote! { compile_error!(#message); })
65    }
66}
67
68use util::*;
69
70mod building_blocks {
71    use super::*;
72    /// Generate a signature check for a transition.
73    pub fn make_handler_sig_check(tr: &Transition, machine_ident: &Ident) -> TokenStream2 {
74        match tr.handler {
75            Some(ref handler) if handler.to_string().starts_with(HANDLE_PREFIX) => {
76                let state_ty = &tr.from_state;
77                let to_ty = &tr.to_state;
78
79                match (tr.input.as_ref(), tr.output.as_ref()) {
80                    (Some(inp_ty), Some(out_ty)) => quote! {
81                        super::#machine_ident::#handler as fn(&mut super::#machine_ident, super::#state_ty, super::#inp_ty) -> (super::#to_ty, super::#out_ty);
82                    },
83                    (Some(inp_ty), None) => quote! {
84                        super::#machine_ident::#handler as fn(&mut super::#machine_ident, super::#state_ty, super::#inp_ty) -> super::#to_ty;
85                    },
86                    (None, Some(out_ty)) => quote! {
87                        super::#machine_ident::#handler as fn(&mut super::#machine_ident, super::#state_ty) -> (super::#to_ty, super::#out_ty);
88                    },
89                    (None, None) => quote! {
90                        super::#machine_ident::#handler as fn(&mut super::#machine_ident, super::#state_ty) -> super::#to_ty;
91                    },
92                }
93            }
94            _ => quote! {},
95        }
96    }
97
98    pub fn instantiate_vals(tr: &parser::Transition, state_var: &Ident) -> TokenStream2 {
99        let next_val = if key(&tr.from_state) == key(&tr.to_state) {
100            quote! { #state_var }
101        } else {
102            let to_path = &tr.to_state;
103            quote! { super::#to_path::default() }
104        };
105        let out_val = if tr.output.is_some() {
106            let out_path = tr.output.as_ref().unwrap();
107            quote! { super::#out_path::default() }
108        } else {
109            quote! { () }
110        };
111
112        quote! {
113            next_val = #next_val;
114            out_val = #out_val;
115        }
116    }
117
118    pub fn build_handler_code(
119        tr: &Transition,
120        state_var: &Ident,
121        input_var: &Ident,
122    ) -> (TokenStream2, TokenStream2) {
123        match &tr.handler {
124            Some(handler) if handler.to_string().starts_with(HANDLE_PREFIX) => {
125                let has_input = tr.input.is_some();
126                let has_output = tr.output.is_some();
127                let call = match (has_input, has_output) {
128                    (true, true) => {
129                        quote! { (next_val, out_val) = self.#handler(#state_var, #input_var); }
130                    }
131                    (true, false) => {
132                        quote! { next_val = self.#handler(#state_var, #input_var); out_val = (); }
133                    }
134                    (false, true) => quote! { (next_val, out_val) = self.#handler(#state_var); },
135                    (false, false) => {
136                        quote! { next_val = self.#handler(#state_var); out_val = (); }
137                    }
138                };
139                (call, quote! {})
140            }
141            Some(callback) => (
142                quote! { self.#callback(); },
143                instantiate_vals(tr, state_var),
144            ),
145            None => (quote! {}, instantiate_vals(tr, state_var)),
146        }
147    }
148
149    pub fn build_guard_code(tr: &Transition, state_var: &Ident) -> TokenStream2 {
150        match &tr.guard {
151            Some(expr) => {
152                fn transform_expr(expr: &syn::Expr, state_var: &Ident) -> TokenStream2 {
153                    match expr {
154                        syn::Expr::Path(expr_path) => {
155                            let ident = &expr_path.path;
156                            if key(ident).starts_with(GUARD_PREFIX) {
157                                quote! { (&self).#ident(&#state_var) }
158                            } else {
159                                quote! { (&self).#ident() }
160                            }
161                        }
162                        syn::Expr::Binary(binary) => {
163                            let left = transform_expr(&binary.left, state_var);
164                            let op = &binary.op;
165                            let right = transform_expr(&binary.right, state_var);
166                            quote! { #left #op #right }
167                        }
168                        syn::Expr::Unary(unary) => {
169                            let op = &unary.op;
170                            let expr = transform_expr(&unary.expr, state_var);
171                            quote! { #op #expr }
172                        }
173                        _ => panic!("Unsupported expression: {}", parser::token_to_string(expr)),
174                    }
175                }
176
177                let transformed = transform_expr(expr, state_var);
178                quote! { if #transformed }
179            }
180            // no guard
181            None => quote! {},
182        }
183    }
184
185    /// Validates a MachineAttr for correctness.
186    /// Returns a TokenStream2 containing any compile errors found.
187    pub fn validate_machine_attr(m: &MachineAttr) -> TokenStream2 {
188        let states_set: BTreeSet<String> = m.states.iter().map(key).collect();
189        let inputs_set: BTreeSet<String> = m.inputs.iter().map(key).collect();
190        let outputs_set: BTreeSet<String> = m.outputs.iter().map(key).collect();
191
192        if states_set.is_empty() {
193            return quote! { compile_error!("No states are defined"); };
194        }
195        let errors = m.transitions.iter().flat_map(|tr| {
196            let tr_descr = tr.to_string();
197            vec![
198                compile_error_if(
199                    !states_set.contains(&key(&tr.from_state)),
200                    &format!("Unknown state: {} in {}", key(&tr.from_state), tr_descr),
201                ),
202                compile_error_if(
203                    !states_set.contains(&key(&tr.to_state)),
204                    &format!("Unknown state: {} in {}", key(&tr.to_state), tr_descr),
205                ),
206                tr.input.as_ref().and_then(|i| {
207                    compile_error_if(
208                        !inputs_set.contains(&key(i)),
209                        &format!("Unknown input: {} in {}", key(i), tr_descr),
210                    )
211                }),
212                tr.output.as_ref().and_then(|o| {
213                    compile_error_if(
214                        !outputs_set.contains(&key(o)),
215                        &format!("Unknown output: {} in {}", key(o), tr_descr),
216                    )
217                }),
218                tr.handler.as_ref().and_then(|h| {
219                    compile_error_if(
220                        h.to_string().starts_with(GUARD_PREFIX),
221                        &format!("Handler cannot start with guard_ prefix: {}", h),
222                    )
223                }),
224            ]
225            .into_iter()
226            .flatten()
227        });
228        quote! { #(#errors)* }
229    }
230
231    // A helper function that maps an iterable collection of identifiers to our enum match arms.
232    pub fn generate_enum_matches(ids: &Vec<&Ident>) -> Vec<proc_macro2::TokenStream> {
233        ids.iter()
234            .enumerate()
235            .map(|(idx, id)| {
236                let idx = idx + 1;
237                quote! {
238                    Self::#id(_) => rust_automata::EnumId::new(#idx)
239                }
240            })
241            .collect()
242    }
243
244    pub fn build_getters(alphabet_paths: &[Path]) -> TokenStream2 {
245        let getters = alphabet_paths.iter().map(|p| {
246            let id = last(p);
247            let direct_fn = snake_path(p);
248            let is_fn = format_ident!("is_{}", direct_fn);
249            let maybe_fn = format_ident!("maybe_{}", direct_fn);
250            quote! {
251                pub fn #is_fn(&self) -> bool {
252                    matches!(self, Self::#id(_))
253                }
254                pub fn #maybe_fn(&self) -> Option<&super::#p> {
255                    if let Self::#id(o) = self { Some(o) } else { None }
256                }
257                pub fn #direct_fn(&self) -> &super::#p {
258                    self.#maybe_fn().expect(&format!("No such symbol like {}", stringify!(#direct_fn)))
259                }
260            }
261        });
262        quote! { #( #getters )* }
263    }
264
265    pub fn build_conversions(enum_ident: &Ident, alphabet_paths: &[Path]) -> TokenStream2 {
266        let conversions = alphabet_paths.iter().enumerate().map(|(idx, p)| {
267            let id = last(p);
268            quote! {
269                impl From<super::#p> for #enum_ident {
270                    fn from(i: super::#p) -> Self { Self::#id(i) }
271                }
272                impl rust_automata::Enumerated<#enum_ident> for super::#p {
273                    fn enum_id() -> rust_automata::EnumId<#enum_ident> {
274                        rust_automata::EnumId::new(#idx + 1)
275                    }
276                }
277                impl From<#enum_ident> for super::#p {
278                    fn from(o: #enum_ident) -> Self {
279                        match o {
280                            #enum_ident::#id(v) => v,
281                            _ => panic!("Invalid symbol requested from {}", stringify!(#p)),
282                        }
283                    }
284                }
285            }
286        });
287
288        quote! { #( #conversions )* }
289    }
290
291    pub fn build_alphabet(
292        derive_attr: &TokenStream2,
293        enum_ident: &Ident,
294        alphabet_paths: &Vec<Path>,
295    ) -> TokenStream2 {
296        let alphabet_ids: Vec<_> = alphabet_paths.iter().map(last).collect();
297        let enumerable_ids_alphabet = generate_enum_matches(&alphabet_ids);
298        let alphabet_getters = build_getters(alphabet_paths);
299        let alphabet_conversions = build_conversions(enum_ident, alphabet_paths);
300        quote! {
301            #derive_attr
302            pub enum #enum_ident {
303                Nothing(()),
304                #( #alphabet_ids ( super::#alphabet_paths ) ),*
305            }
306            impl rust_automata::Alphabet for #enum_ident {
307                fn nothing() -> Self { Self::Nothing(()) }
308                fn any(&self) -> bool { !matches!(self, Self::Nothing(_)) }
309            }
310            impl rust_automata::Enumerable<#enum_ident> for #enum_ident {
311                fn enum_id(&self) -> rust_automata::EnumId<#enum_ident> {
312                    match self {
313                        Self::Nothing(_) => rust_automata::EnumId::new(0),
314                        #( #enumerable_ids_alphabet ),*
315                    }
316                }
317            }
318            impl #enum_ident {
319                #alphabet_getters
320            }
321            #alphabet_conversions
322        }
323    }
324
325    pub fn build_set(
326        derive_attr: &TokenStream2,
327        enum_ident: &Ident,
328        state_paths: &Vec<Path>,
329    ) -> TokenStream2 {
330        let state_ids: Vec<_> = state_paths.iter().map(last).collect();
331        let enumerable_ids_states = generate_enum_matches(&state_ids);
332        let state_getters = build_getters(state_paths);
333        let state_conversions = build_conversions(enum_ident, state_paths);
334
335        quote! {
336            #derive_attr
337            pub enum #enum_ident {
338                Failure(()),
339                 #( #state_ids ( super::#state_paths ) ),*
340            }
341            impl rust_automata::StateTrait for #enum_ident {
342                fn failure() -> Self { Self::Failure(()) }
343                fn is_failure(&self) -> bool { matches!(self, Self::Failure(_)) }
344            }
345            impl rust_automata::Enumerable<#enum_ident> for #enum_ident {
346                fn enum_id(&self) -> rust_automata::EnumId<#enum_ident> {
347                    match self {
348                        Self::Failure(_) => rust_automata::EnumId::new(0),
349                        #( #enumerable_ids_states ),*
350                    }
351                }
352            }
353            impl #enum_ident {
354                #state_getters
355            }
356            #state_conversions
357        }
358    }
359
360    pub fn compute_symbol_index(
361        needle: Option<&syn::Path>,
362        symbols: &[syn::Path],
363        tr: &parser::Transition,
364    ) -> usize {
365        match needle {
366            Some(symbol) => {
367                1 + symbols
368                    .iter()
369                    .position(|p| key(p) == key(symbol))
370                    .unwrap_or_else(|| {
371                        panic!("Symbol {} not found in transition: {}", key(symbol), tr);
372                    })
373            }
374            None => 0,
375        }
376    }
377}
378
379/// The main macro for defining automata.
380///
381/// See [rust-automata](https://github.com/michalsustr/rust-automata) crate for more details.
382#[proc_macro_attribute]
383pub fn state_machine(attr: TokenStream, item: TokenStream) -> TokenStream {
384    use building_blocks::*;
385
386    // Parse attribute + struct
387    let m: MachineAttr = parse_macro_input!(attr as MachineAttr);
388    let errors = validate_machine_attr(&m);
389    if !errors.is_empty() {
390        return errors.into();
391    }
392
393    // Prepare all the identifiers and lists
394    let machine_ts: TokenStream2 = item.clone().into();
395    let machine: ItemStruct = parse_macro_input!(item as ItemStruct);
396    let machine_ident = machine.ident.clone();
397    let vis = machine.vis.clone();
398    let base = strip_machine(&machine_ident);
399    let internal_mod = format_ident!("internal_{}", base);
400    let state_enum_ident = format_ident!("{}State", base);
401    let input_enum_ident = format_ident!("{}Input", base);
402    let output_enum_ident = format_ident!("{}Output", base);
403    let initial_state_ident = &m.states.first().unwrap();
404    let nothing_ident = format_ident!("Nothing");
405    // pre‑compute frequently‑used lists
406    let state_paths = &m.states;
407    let input_paths = &m.inputs;
408    let output_paths = &m.outputs;
409    let derives = &m.derives;
410
411    // Simplify derive attribute generation
412    let (derive_attr, derive_struct) = if derives.is_empty() {
413        (quote!( #[derive(Display)] ), quote! {})
414    } else {
415        (
416            quote!( #[derive(Display, #( #derives ),* )] ),
417            quote! {#[derive(Default, #( #derives ),* )]},
418        )
419    };
420
421    let maybe_generate_structs = state_paths
422        .iter()
423        .chain(input_paths.iter())
424        .chain(output_paths.iter())
425        .filter_map(|p| {
426            if m.generate_structs {
427                Some(quote! {
428                    #derive_struct
429                    pub struct #p;
430                })
431            } else {
432                None
433            }
434        });
435
436    let transition_match_arms = m.transitions.iter().enumerate().map(|(idx, tr)| {
437        let from_id = last(&tr.from_state);
438        let to_id = last(&tr.to_state);
439        let inp_id = tr.input.as_ref().map(last).unwrap_or(&nothing_ident);
440        let out_id = tr.output.as_ref().map(last).unwrap_or(&nothing_ident);
441        let state_var = format_ident!("state{idx}");
442        let input_var = format_ident!("input{idx}");
443        let to_path = &tr.to_state;
444        let type_declaration = match tr.output {
445            Some(ref out_path) => quote! {
446                let next_val: super::#to_path;
447                let out_val: super::#out_path;
448            },
449            None => quote! {
450                let next_val: super::#to_path;
451                let out_val:  ();
452            },
453        };
454        let (transition_call, value_instantiation) = build_handler_code(tr, &state_var, &input_var);
455        let guard_call = build_guard_code(tr, &state_var);
456
457        quote! {
458            (Self::State::#from_id(#state_var), Self::Input::#inp_id(#input_var)) #guard_call => {
459                #type_declaration
460                #transition_call
461                #value_instantiation
462                (
463                    Self::State::#to_id(next_val),
464                    Self::Output::#out_id(out_val)
465                )
466            }
467        }
468    });
469    let can_transition_match_arms = m.transitions.iter().enumerate().map(|(idx, tr) | {
470        let from_id = last(&tr.from_state);
471        let state_var = format_ident!("state{idx}");
472        let input_idx: usize = compute_symbol_index(tr.input.as_ref(), input_paths, tr);
473        let output_idx: usize = compute_symbol_index(tr.output.as_ref(), output_paths, tr);
474        let guard_call = build_guard_code(tr, &state_var);
475        quote! {
476            (Self::State::#from_id(#state_var), #input_idx) #guard_call => Some(rust_automata::EnumId::new(#output_idx))
477        }
478    });
479
480    let input_alphabet = build_alphabet(&derive_attr, &input_enum_ident, input_paths);
481    let output_alphabet = build_alphabet(&derive_attr, &output_enum_ident, output_paths);
482    let state_set = build_set(&derive_attr, &state_enum_ident, state_paths);
483
484    let sig_checks = m
485        .transitions
486        .iter()
487        .map(|tr| make_handler_sig_check(tr, &machine_ident));
488
489    // ────────────────── annotations ──────────────────
490    let mermaid_attr = annotations::mermaid_attr(&m);
491    let dsl_attr = annotations::dsl_attr(&m);
492
493    // ────────────────── put everything together ──────────────────
494    let output = quote! {
495        #mermaid_attr
496        #dsl_attr
497        #machine_ts
498
499        #( #maybe_generate_structs )*
500
501        #[allow(non_snake_case)]
502        #[doc(hidden)]
503        #vis mod #internal_mod {
504            use rust_automata::*;
505
506            #state_set
507            #input_alphabet
508            #output_alphabet
509
510            impl rust_automata::StateMachineImpl for super::#machine_ident {
511                type Input  = #input_enum_ident;
512                type State  = #state_enum_ident;
513                type Output = #output_enum_ident;
514                type InitialState = super::#initial_state_ident;
515                fn transition(
516                    &mut self,
517                    mut state: rust_automata::Takeable<Self::State>,
518                    input: Self::Input,
519                ) -> (rust_automata::Takeable<Self::State>, Self::Output) {
520
521                    // Make nice error messages
522                    #( #sig_checks )*
523
524                    let out = state.borrow_result(|old_state| {
525                        match (old_state, input) {
526                            #( #transition_match_arms , )*
527                            (_, _) => { (Self::State::failure(), Self::Output::nothing()) }
528                        }
529                    });
530                    (state, out)
531                }
532
533                fn can_transition(&self, state: &Self::State, input: EnumId<Self::Input>) -> Option<EnumId<Self::Output>> {
534                    match (state, input.id) {
535                        #( #can_transition_match_arms , )*
536                        (_, _) => None,
537                    }
538                }
539            }
540        }
541    };
542
543    output.into()
544}
545
546/// A custom proc macro that implements `Display` for enums by extracting the enum variant name.
547///
548/// This macro will generate an implementation such that:
549/// - For a variant named `Foo`, `Display::fmt` will output `"Foo"`.
550///
551/// Intended only for the internal use with `rust-automata` crate.
552#[doc(hidden)]
553#[proc_macro_derive(Display)]
554pub fn display_derive(input: TokenStream) -> TokenStream {
555    let ast = parse_macro_input!(input as DeriveInput);
556    let name = ast.ident.clone();
557
558    // Ensure that the macro is only applied to enums.
559    let data_enum = match ast.data {
560        syn::Data::Enum(data_enum) => data_enum,
561        _ => {
562            return syn::Error::new_spanned(ast, "Display can only be derived for enums")
563                .to_compile_error()
564                .into();
565        }
566    };
567
568    // For each variant, create a match arm that writes the variant's name.
569    let arms = data_enum.variants.into_iter().map(|variant| {
570        let variant_ident = variant.ident;
571        let variant_str = variant_ident.to_string();
572        quote! {
573            Self::#variant_ident(_) => write!(f, "{}", #variant_str)
574        }
575    });
576
577    // Generate the complete implementation of Display.
578    let expanded = quote! {
579        impl std::fmt::Display for #name {
580            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581                match self {
582                    #(#arms),*
583                }
584            }
585        }
586    };
587
588    expanded.into()
589}