1#![recursion_limit = "256"]
10
11extern crate proc_macro;
12
13use proc_macro::TokenStream;
14use proc_macro2::TokenStream as TokenStream2;
15use quote::{format_ident, quote};
16use std::collections::BTreeSet;
17use syn::{parse_macro_input, DeriveInput, Ident, ItemStruct, Path};
18
19mod parser;
20use parser::{MachineAttr, Transition};
21
22mod annotations;
23
/// Name prefix marking a transition handler that is called with the current
/// state (and the input, when the transition has one) and returns the next
/// state (paired with the output, when the transition has one).
const HANDLE_PREFIX: &str = "handle_";
/// Name prefix marking a guard that is called with a reference to the current
/// state; guards without this prefix are called with no arguments.
const GUARD_PREFIX: &str = "guard_";
28
29mod util {
30 use super::*;
31 use heck::ToSnakeCase;
32
33 pub fn snake(id: &Ident) -> Ident {
35 Ident::new(&id.to_string().to_snake_case(), id.span())
36 }
37
38 pub fn snake_path(p: &Path) -> Ident {
40 snake(last(p))
41 }
42
43 pub fn last(p: &Path) -> &Ident {
45 &p.segments.last().unwrap().ident
46 }
47
48 pub fn key(p: &Path) -> String {
50 p.segments
51 .iter()
52 .map(|s| s.ident.to_string())
53 .collect::<Vec<_>>()
54 .join("::")
55 }
56
57 pub fn strip_machine(id: &Ident) -> String {
59 let s = id.to_string();
60 s.strip_suffix("Machine").unwrap_or(&s).to_owned()
61 }
62
63 pub fn compile_error_if(condition: bool, message: &str) -> Option<TokenStream2> {
64 condition.then(|| quote! { compile_error!(#message); })
65 }
66}
67
68use util::*;
69
70mod building_blocks {
71 use super::*;
72 pub fn make_handler_sig_check(tr: &Transition, machine_ident: &Ident) -> TokenStream2 {
74 match tr.handler {
75 Some(ref handler) if handler.to_string().starts_with(HANDLE_PREFIX) => {
76 let state_ty = &tr.from_state;
77 let to_ty = &tr.to_state;
78
79 match (tr.input.as_ref(), tr.output.as_ref()) {
80 (Some(inp_ty), Some(out_ty)) => quote! {
81 super::#machine_ident::#handler as fn(&mut super::#machine_ident, super::#state_ty, super::#inp_ty) -> (super::#to_ty, super::#out_ty);
82 },
83 (Some(inp_ty), None) => quote! {
84 super::#machine_ident::#handler as fn(&mut super::#machine_ident, super::#state_ty, super::#inp_ty) -> super::#to_ty;
85 },
86 (None, Some(out_ty)) => quote! {
87 super::#machine_ident::#handler as fn(&mut super::#machine_ident, super::#state_ty) -> (super::#to_ty, super::#out_ty);
88 },
89 (None, None) => quote! {
90 super::#machine_ident::#handler as fn(&mut super::#machine_ident, super::#state_ty) -> super::#to_ty;
91 },
92 }
93 }
94 _ => quote! {},
95 }
96 }
97
98 pub fn instantiate_vals(tr: &parser::Transition, state_var: &Ident) -> TokenStream2 {
99 let next_val = if key(&tr.from_state) == key(&tr.to_state) {
100 quote! { #state_var }
101 } else {
102 let to_path = &tr.to_state;
103 quote! { super::#to_path::default() }
104 };
105 let out_val = if tr.output.is_some() {
106 let out_path = tr.output.as_ref().unwrap();
107 quote! { super::#out_path::default() }
108 } else {
109 quote! { () }
110 };
111
112 quote! {
113 next_val = #next_val;
114 out_val = #out_val;
115 }
116 }
117
118 pub fn build_handler_code(
119 tr: &Transition,
120 state_var: &Ident,
121 input_var: &Ident,
122 ) -> (TokenStream2, TokenStream2) {
123 match &tr.handler {
124 Some(handler) if handler.to_string().starts_with(HANDLE_PREFIX) => {
125 let has_input = tr.input.is_some();
126 let has_output = tr.output.is_some();
127 let call = match (has_input, has_output) {
128 (true, true) => {
129 quote! { (next_val, out_val) = self.#handler(#state_var, #input_var); }
130 }
131 (true, false) => {
132 quote! { next_val = self.#handler(#state_var, #input_var); out_val = (); }
133 }
134 (false, true) => quote! { (next_val, out_val) = self.#handler(#state_var); },
135 (false, false) => {
136 quote! { next_val = self.#handler(#state_var); out_val = (); }
137 }
138 };
139 (call, quote! {})
140 }
141 Some(callback) => (
142 quote! { self.#callback(); },
143 instantiate_vals(tr, state_var),
144 ),
145 None => (quote! {}, instantiate_vals(tr, state_var)),
146 }
147 }
148
149 pub fn build_guard_code(tr: &Transition, state_var: &Ident) -> TokenStream2 {
150 match &tr.guard {
151 Some(expr) => {
152 fn transform_expr(expr: &syn::Expr, state_var: &Ident) -> TokenStream2 {
153 match expr {
154 syn::Expr::Path(expr_path) => {
155 let ident = &expr_path.path;
156 if key(ident).starts_with(GUARD_PREFIX) {
157 quote! { (&self).#ident(&#state_var) }
158 } else {
159 quote! { (&self).#ident() }
160 }
161 }
162 syn::Expr::Binary(binary) => {
163 let left = transform_expr(&binary.left, state_var);
164 let op = &binary.op;
165 let right = transform_expr(&binary.right, state_var);
166 quote! { #left #op #right }
167 }
168 syn::Expr::Unary(unary) => {
169 let op = &unary.op;
170 let expr = transform_expr(&unary.expr, state_var);
171 quote! { #op #expr }
172 }
173 _ => panic!("Unsupported expression: {}", parser::token_to_string(expr)),
174 }
175 }
176
177 let transformed = transform_expr(expr, state_var);
178 quote! { if #transformed }
179 }
180 None => quote! {},
182 }
183 }
184
185 pub fn validate_machine_attr(m: &MachineAttr) -> TokenStream2 {
188 let states_set: BTreeSet<String> = m.states.iter().map(key).collect();
189 let inputs_set: BTreeSet<String> = m.inputs.iter().map(key).collect();
190 let outputs_set: BTreeSet<String> = m.outputs.iter().map(key).collect();
191
192 if states_set.is_empty() {
193 return quote! { compile_error!("No states are defined"); };
194 }
195 let errors = m.transitions.iter().flat_map(|tr| {
196 let tr_descr = tr.to_string();
197 vec![
198 compile_error_if(
199 !states_set.contains(&key(&tr.from_state)),
200 &format!("Unknown state: {} in {}", key(&tr.from_state), tr_descr),
201 ),
202 compile_error_if(
203 !states_set.contains(&key(&tr.to_state)),
204 &format!("Unknown state: {} in {}", key(&tr.to_state), tr_descr),
205 ),
206 tr.input.as_ref().and_then(|i| {
207 compile_error_if(
208 !inputs_set.contains(&key(i)),
209 &format!("Unknown input: {} in {}", key(i), tr_descr),
210 )
211 }),
212 tr.output.as_ref().and_then(|o| {
213 compile_error_if(
214 !outputs_set.contains(&key(o)),
215 &format!("Unknown output: {} in {}", key(o), tr_descr),
216 )
217 }),
218 tr.handler.as_ref().and_then(|h| {
219 compile_error_if(
220 h.to_string().starts_with(GUARD_PREFIX),
221 &format!("Handler cannot start with guard_ prefix: {}", h),
222 )
223 }),
224 ]
225 .into_iter()
226 .flatten()
227 });
228 quote! { #(#errors)* }
229 }
230
231 pub fn generate_enum_matches(ids: &Vec<&Ident>) -> Vec<proc_macro2::TokenStream> {
233 ids.iter()
234 .enumerate()
235 .map(|(idx, id)| {
236 let idx = idx + 1;
237 quote! {
238 Self::#id(_) => rust_automata::EnumId::new(#idx)
239 }
240 })
241 .collect()
242 }
243
244 pub fn build_getters(alphabet_paths: &[Path]) -> TokenStream2 {
245 let getters = alphabet_paths.iter().map(|p| {
246 let id = last(p);
247 let direct_fn = snake_path(p);
248 let is_fn = format_ident!("is_{}", direct_fn);
249 let maybe_fn = format_ident!("maybe_{}", direct_fn);
250 quote! {
251 pub fn #is_fn(&self) -> bool {
252 matches!(self, Self::#id(_))
253 }
254 pub fn #maybe_fn(&self) -> Option<&super::#p> {
255 if let Self::#id(o) = self { Some(o) } else { None }
256 }
257 pub fn #direct_fn(&self) -> &super::#p {
258 self.#maybe_fn().expect(&format!("No such symbol like {}", stringify!(#direct_fn)))
259 }
260 }
261 });
262 quote! { #( #getters )* }
263 }
264
265 pub fn build_conversions(enum_ident: &Ident, alphabet_paths: &[Path]) -> TokenStream2 {
266 let conversions = alphabet_paths.iter().enumerate().map(|(idx, p)| {
267 let id = last(p);
268 quote! {
269 impl From<super::#p> for #enum_ident {
270 fn from(i: super::#p) -> Self { Self::#id(i) }
271 }
272 impl rust_automata::Enumerated<#enum_ident> for super::#p {
273 fn enum_id() -> rust_automata::EnumId<#enum_ident> {
274 rust_automata::EnumId::new(#idx + 1)
275 }
276 }
277 impl From<#enum_ident> for super::#p {
278 fn from(o: #enum_ident) -> Self {
279 match o {
280 #enum_ident::#id(v) => v,
281 _ => panic!("Invalid symbol requested from {}", stringify!(#p)),
282 }
283 }
284 }
285 }
286 });
287
288 quote! { #( #conversions )* }
289 }
290
291 pub fn build_alphabet(
292 derive_attr: &TokenStream2,
293 enum_ident: &Ident,
294 alphabet_paths: &Vec<Path>,
295 ) -> TokenStream2 {
296 let alphabet_ids: Vec<_> = alphabet_paths.iter().map(last).collect();
297 let enumerable_ids_alphabet = generate_enum_matches(&alphabet_ids);
298 let alphabet_getters = build_getters(alphabet_paths);
299 let alphabet_conversions = build_conversions(enum_ident, alphabet_paths);
300 quote! {
301 #derive_attr
302 pub enum #enum_ident {
303 Nothing(()),
304 #( #alphabet_ids ( super::#alphabet_paths ) ),*
305 }
306 impl rust_automata::Alphabet for #enum_ident {
307 fn nothing() -> Self { Self::Nothing(()) }
308 fn any(&self) -> bool { !matches!(self, Self::Nothing(_)) }
309 }
310 impl rust_automata::Enumerable<#enum_ident> for #enum_ident {
311 fn enum_id(&self) -> rust_automata::EnumId<#enum_ident> {
312 match self {
313 Self::Nothing(_) => rust_automata::EnumId::new(0),
314 #( #enumerable_ids_alphabet ),*
315 }
316 }
317 }
318 impl #enum_ident {
319 #alphabet_getters
320 }
321 #alphabet_conversions
322 }
323 }
324
325 pub fn build_set(
326 derive_attr: &TokenStream2,
327 enum_ident: &Ident,
328 state_paths: &Vec<Path>,
329 ) -> TokenStream2 {
330 let state_ids: Vec<_> = state_paths.iter().map(last).collect();
331 let enumerable_ids_states = generate_enum_matches(&state_ids);
332 let state_getters = build_getters(state_paths);
333 let state_conversions = build_conversions(enum_ident, state_paths);
334
335 quote! {
336 #derive_attr
337 pub enum #enum_ident {
338 Failure(()),
339 #( #state_ids ( super::#state_paths ) ),*
340 }
341 impl rust_automata::StateTrait for #enum_ident {
342 fn failure() -> Self { Self::Failure(()) }
343 fn is_failure(&self) -> bool { matches!(self, Self::Failure(_)) }
344 }
345 impl rust_automata::Enumerable<#enum_ident> for #enum_ident {
346 fn enum_id(&self) -> rust_automata::EnumId<#enum_ident> {
347 match self {
348 Self::Failure(_) => rust_automata::EnumId::new(0),
349 #( #enumerable_ids_states ),*
350 }
351 }
352 }
353 impl #enum_ident {
354 #state_getters
355 }
356 #state_conversions
357 }
358 }
359
360 pub fn compute_symbol_index(
361 needle: Option<&syn::Path>,
362 symbols: &[syn::Path],
363 tr: &parser::Transition,
364 ) -> usize {
365 match needle {
366 Some(symbol) => {
367 1 + symbols
368 .iter()
369 .position(|p| key(p) == key(symbol))
370 .unwrap_or_else(|| {
371 panic!("Symbol {} not found in transition: {}", key(symbol), tr);
372 })
373 }
374 None => 0,
375 }
376 }
377}
378
379#[proc_macro_attribute]
383pub fn state_machine(attr: TokenStream, item: TokenStream) -> TokenStream {
384 use building_blocks::*;
385
386 let m: MachineAttr = parse_macro_input!(attr as MachineAttr);
388 let errors = validate_machine_attr(&m);
389 if !errors.is_empty() {
390 return errors.into();
391 }
392
393 let machine_ts: TokenStream2 = item.clone().into();
395 let machine: ItemStruct = parse_macro_input!(item as ItemStruct);
396 let machine_ident = machine.ident.clone();
397 let vis = machine.vis.clone();
398 let base = strip_machine(&machine_ident);
399 let internal_mod = format_ident!("internal_{}", base);
400 let state_enum_ident = format_ident!("{}State", base);
401 let input_enum_ident = format_ident!("{}Input", base);
402 let output_enum_ident = format_ident!("{}Output", base);
403 let initial_state_ident = &m.states.first().unwrap();
404 let nothing_ident = format_ident!("Nothing");
405 let state_paths = &m.states;
407 let input_paths = &m.inputs;
408 let output_paths = &m.outputs;
409 let derives = &m.derives;
410
411 let (derive_attr, derive_struct) = if derives.is_empty() {
413 (quote!( #[derive(Display)] ), quote! {})
414 } else {
415 (
416 quote!( #[derive(Display, #( #derives ),* )] ),
417 quote! {#[derive(Default, #( #derives ),* )]},
418 )
419 };
420
421 let maybe_generate_structs = state_paths
422 .iter()
423 .chain(input_paths.iter())
424 .chain(output_paths.iter())
425 .filter_map(|p| {
426 if m.generate_structs {
427 Some(quote! {
428 #derive_struct
429 pub struct #p;
430 })
431 } else {
432 None
433 }
434 });
435
436 let transition_match_arms = m.transitions.iter().enumerate().map(|(idx, tr)| {
437 let from_id = last(&tr.from_state);
438 let to_id = last(&tr.to_state);
439 let inp_id = tr.input.as_ref().map(last).unwrap_or(¬hing_ident);
440 let out_id = tr.output.as_ref().map(last).unwrap_or(¬hing_ident);
441 let state_var = format_ident!("state{idx}");
442 let input_var = format_ident!("input{idx}");
443 let to_path = &tr.to_state;
444 let type_declaration = match tr.output {
445 Some(ref out_path) => quote! {
446 let next_val: super::#to_path;
447 let out_val: super::#out_path;
448 },
449 None => quote! {
450 let next_val: super::#to_path;
451 let out_val: ();
452 },
453 };
454 let (transition_call, value_instantiation) = build_handler_code(tr, &state_var, &input_var);
455 let guard_call = build_guard_code(tr, &state_var);
456
457 quote! {
458 (Self::State::#from_id(#state_var), Self::Input::#inp_id(#input_var)) #guard_call => {
459 #type_declaration
460 #transition_call
461 #value_instantiation
462 (
463 Self::State::#to_id(next_val),
464 Self::Output::#out_id(out_val)
465 )
466 }
467 }
468 });
469 let can_transition_match_arms = m.transitions.iter().enumerate().map(|(idx, tr) | {
470 let from_id = last(&tr.from_state);
471 let state_var = format_ident!("state{idx}");
472 let input_idx: usize = compute_symbol_index(tr.input.as_ref(), input_paths, tr);
473 let output_idx: usize = compute_symbol_index(tr.output.as_ref(), output_paths, tr);
474 let guard_call = build_guard_code(tr, &state_var);
475 quote! {
476 (Self::State::#from_id(#state_var), #input_idx) #guard_call => Some(rust_automata::EnumId::new(#output_idx))
477 }
478 });
479
480 let input_alphabet = build_alphabet(&derive_attr, &input_enum_ident, input_paths);
481 let output_alphabet = build_alphabet(&derive_attr, &output_enum_ident, output_paths);
482 let state_set = build_set(&derive_attr, &state_enum_ident, state_paths);
483
484 let sig_checks = m
485 .transitions
486 .iter()
487 .map(|tr| make_handler_sig_check(tr, &machine_ident));
488
489 let mermaid_attr = annotations::mermaid_attr(&m);
491 let dsl_attr = annotations::dsl_attr(&m);
492
493 let output = quote! {
495 #mermaid_attr
496 #dsl_attr
497 #machine_ts
498
499 #( #maybe_generate_structs )*
500
501 #[allow(non_snake_case)]
502 #[doc(hidden)]
503 #vis mod #internal_mod {
504 use rust_automata::*;
505
506 #state_set
507 #input_alphabet
508 #output_alphabet
509
510 impl rust_automata::StateMachineImpl for super::#machine_ident {
511 type Input = #input_enum_ident;
512 type State = #state_enum_ident;
513 type Output = #output_enum_ident;
514 type InitialState = super::#initial_state_ident;
515 fn transition(
516 &mut self,
517 mut state: rust_automata::Takeable<Self::State>,
518 input: Self::Input,
519 ) -> (rust_automata::Takeable<Self::State>, Self::Output) {
520
521 #( #sig_checks )*
523
524 let out = state.borrow_result(|old_state| {
525 match (old_state, input) {
526 #( #transition_match_arms , )*
527 (_, _) => { (Self::State::failure(), Self::Output::nothing()) }
528 }
529 });
530 (state, out)
531 }
532
533 fn can_transition(&self, state: &Self::State, input: EnumId<Self::Input>) -> Option<EnumId<Self::Output>> {
534 match (state, input.id) {
535 #( #can_transition_match_arms , )*
536 (_, _) => None,
537 }
538 }
539 }
540 }
541 };
542
543 output.into()
544}
545
546#[doc(hidden)]
553#[proc_macro_derive(Display)]
554pub fn display_derive(input: TokenStream) -> TokenStream {
555 let ast = parse_macro_input!(input as DeriveInput);
556 let name = ast.ident.clone();
557
558 let data_enum = match ast.data {
560 syn::Data::Enum(data_enum) => data_enum,
561 _ => {
562 return syn::Error::new_spanned(ast, "Display can only be derived for enums")
563 .to_compile_error()
564 .into();
565 }
566 };
567
568 let arms = data_enum.variants.into_iter().map(|variant| {
570 let variant_ident = variant.ident;
571 let variant_str = variant_ident.to_string();
572 quote! {
573 Self::#variant_ident(_) => write!(f, "{}", #variant_str)
574 }
575 });
576
577 let expanded = quote! {
579 impl std::fmt::Display for #name {
580 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 match self {
582 #(#arms),*
583 }
584 }
585 }
586 };
587
588 expanded.into()
589}