sigma_enum_macros/
lib.rs

1use crate::attrs::extract_expansion;
2use crate::nice_type::Infallible;
3use crate::nice_type::NiceType;
4use attrs::ItemAttr;
5use heck::ToSnakeCase;
6use nice_type::NiceTypeLit;
7use proc_macro::TokenStream;
8use quote::ToTokens;
9use quote::TokenStreamExt;
10use quote::format_ident;
11use quote::quote;
12use std::collections::BTreeMap;
13use std::collections::BTreeSet;
14use syn::Attribute;
15use syn::Expr;
16use syn::Ident;
17use syn::LitStr;
18use syn::Token;
19use syn::Visibility;
20use syn::braced;
21use syn::parenthesized;
22use syn::parse::Parse;
23use syn::parse::ParseStream;
24use syn::parse_macro_input;
25use syn::spanned::Spanned;
26
27mod attrs;
28mod nice_type;
29
/// Prefix shared by every hidden helper macro the expansion generates
/// (see `SigmaEnum::internal_name`).
const INTERNAL_IDENT: &str = "__INTERNAL_IDENT";
/// Identifier standing in for the type of a catch-all `binding => body` arm
/// in the generated `match` macro.
const INTERNAL_FULL_WILDCARD: &str = "__INTERNAL_FULL_WILDCARD";
32
33#[derive(Clone)]
34struct Variant {
35    ty: NiceType<Infallible>,
36    name: Ident,
37    attrs: Vec<Attribute>,
38}
39
40#[derive(Clone)]
41struct SigmaEnum {
42    visibility: Visibility,
43    name: Ident,
44    variants: Vec<Variant>,
45    subattrs: Vec<Attribute>,
46    attr: ItemAttr,
47}
48
49impl SigmaEnum {
50    fn macro_match_name(&self) -> Ident {
51        self.attr.macro_match.name.as_ref().map_or_else(
52            || format_ident!("{}_match", self.name.to_string().to_snake_case()),
53            |name| format_ident!("{}", name),
54        )
55    }
56
57    fn macro_construct_name(&self) -> Ident {
58        self.attr.macro_construct.name.as_ref().map_or_else(
59            || format_ident!("{}_construct", self.name.to_string().to_snake_case()),
60            |name| format_ident!("{}", name),
61        )
62    }
63
64    fn into_trait_name(&self) -> Ident {
65        self.attr.into_trait.name.as_ref().map_or_else(
66            || format_ident!("Into{}", self.name),
67            |name| format_ident!("{}", name),
68        )
69    }
70
71    fn into_method_name(&self) -> Ident {
72        self.attr.into_method.name.as_ref().map_or_else(
73            || format_ident!("into_{}", self.name.to_string().to_snake_case()),
74            |name| format_ident!("{}", name),
75        )
76    }
77
78    fn try_from_method_name(&self) -> Ident {
79        self.attr.try_from_method.name.as_ref().map_or_else(
80            || format_ident!("try_from_{}", self.name.to_string().to_snake_case()),
81            |name| format_ident!("{}", name),
82        )
83    }
84
85    fn try_from_owned_method_name(&self) -> Ident {
86        self.attr.try_from_owned_method.name.as_ref().map_or_else(
87            || format_ident!("try_from_owned_{}", self.name.to_string().to_snake_case()),
88            |name| format_ident!("{}", name),
89        )
90    }
91
92    fn try_from_mut_method_name(&self) -> Ident {
93        self.attr.try_from_mut_method.name.as_ref().map_or_else(
94            || format_ident!("try_from_mut_{}", self.name.to_string().to_snake_case()),
95            |name| format_ident!("{}", name),
96        )
97    }
98
99    fn extract_method_name(&self) -> Ident {
100        self.attr.extract_method.name.as_ref().map_or_else(
101            || format_ident!("extract"),
102            |name| format_ident!("{}", name),
103        )
104    }
105
106    fn extract_owned_method_name(&self) -> Ident {
107        self.attr.extract_owned_method.name.as_ref().map_or_else(
108            || format_ident!("extract_owned"),
109            |name| format_ident!("{}", name),
110        )
111    }
112
113    fn extract_mut_method_name(&self) -> Ident {
114        self.attr.extract_mut_method.name.as_ref().map_or_else(
115            || format_ident!("extract_mut"),
116            |name| format_ident!("{}", name),
117        )
118    }
119
120    fn try_from_error_name(&self) -> Ident {
121        self.attr.try_from_error.name.as_ref().map_or_else(
122            || format_ident!("TryFrom{}Error", self.name.to_string()),
123            |name| format_ident!("{}", name),
124        )
125    }
126
127    fn internal_name(&self, which: &str, suffix: &str) -> Ident {
128        format_ident!(
129            "{INTERNAL_IDENT}_{}_{}{}",
130            self.name.to_string().to_snake_case(),
131            which,
132            suffix
133        )
134    }
135
136    fn to_tokens_macros(&self, tokens: &mut proc_macro2::TokenStream, export: bool, suffix: &str) {
137        let SigmaEnum {
138            visibility: _,
139            name,
140            variants,
141            subattrs: _,
142            attr,
143        } = &self;
144
145        let path = match &attr.path {
146            Some(path) => quote! { $ #path :: },
147            None => quote! {},
148        };
149
150        let variants_btree: BTreeMap<_, _> = variants
151            .iter()
152            .map(|var| (var.ty.clone(), var.name.clone()))
153            .collect();
154        let variant_pats: Vec<_> = variants.iter().map(|var| var.ty.clone()).collect();
155
156        let macro_match = format_ident!("{}{}", self.macro_match_name(), suffix);
157        let macro_construct = format_ident!("{}{}", self.macro_construct_name(), suffix);
158        let macro_match_body = self.internal_name("body", suffix);
159        let macro_match_process_body = self.internal_name("process_body", suffix);
160        let macro_process_type = self.internal_name("process_type", suffix);
161        let macro_match_variant = self.internal_name("variant", suffix);
162        let macro_match_pattern = self.internal_name("pattern", suffix);
163        let macro_construct_inner = self.internal_name("construct_inner", suffix);
164
165        let macro_match_docstring = self.attr.macro_match.docstring();
166        let macro_construct_docstring = self.attr.macro_construct.docstring();
167
168        // https://github.com/rust-lang/rust/pull/52234#issuecomment-1417098097
169        let macro_match_export;
170        let macro_construct_export;
171        let macro_match_body_export;
172        let macro_match_process_body_export;
173        let macro_process_type_export;
174        let macro_match_variant_export;
175        let macro_match_pattern_export;
176        let macro_construct_inner_export;
177        let macro_match_pub_use;
178        let macro_construct_pub_use;
179        let macro_match_body_pub_use;
180        let macro_match_process_body_pub_use;
181        let macro_process_type_pub_use;
182        let macro_match_variant_pub_use;
183        let macro_match_pattern_pub_use;
184        let macro_construct_inner_pub_use;
185        if export {
186            macro_match_export = quote! { #macro_match_docstring #[macro_export] };
187            macro_construct_export = quote! { #macro_construct_docstring #[macro_export] };
188            macro_match_body_export = quote! { #[macro_export] };
189            macro_match_process_body_export = quote! { #[macro_export] };
190            macro_process_type_export = quote! { #[macro_export] };
191            macro_match_variant_export = quote! { #[macro_export] };
192            macro_match_pattern_export = quote! { #[macro_export] };
193            macro_construct_inner_export = quote! { #[macro_export] };
194            macro_match_pub_use = quote! {};
195            macro_construct_pub_use = quote! {};
196            macro_match_body_pub_use = quote! {};
197            macro_match_process_body_pub_use = quote! {};
198            macro_process_type_pub_use = quote! {};
199            macro_match_variant_pub_use = quote! {};
200            macro_match_pattern_pub_use = quote! {};
201            macro_construct_inner_pub_use = quote! {};
202        } else {
203            macro_match_export = quote! { #macro_match_docstring };
204            macro_construct_export = quote! { #macro_construct_docstring };
205            macro_match_body_export = quote! {};
206            macro_match_process_body_export = quote! {};
207            macro_process_type_export = quote! {};
208            macro_match_variant_export = quote! {};
209            macro_match_pattern_export = quote! {};
210            macro_construct_inner_export = quote! {};
211            macro_match_pub_use = quote! { #macro_match_docstring pub(crate) use #macro_match; };
212            macro_construct_pub_use =
213                quote! { #macro_construct_docstring pub(crate) use #macro_construct; };
214            macro_match_body_pub_use = quote! { #[doc(hidden)] pub(crate) use #macro_match_body; };
215            macro_match_process_body_pub_use =
216                quote! { #[doc(hidden)] pub(crate) use #macro_match_process_body; };
217            macro_process_type_pub_use =
218                quote! { #[doc(hidden)] pub(crate) use #macro_process_type; };
219            macro_match_variant_pub_use =
220                quote! { #[doc(hidden)] pub(crate) use #macro_match_variant; };
221            macro_match_pattern_pub_use =
222                quote! { #[doc(hidden)] pub(crate) use #macro_match_pattern; };
223            macro_construct_inner_pub_use =
224                quote! { #[doc(hidden)] pub(crate) use #macro_construct_inner; };
225        }
226
227        let internal_full_wildcard = format_ident!("{INTERNAL_FULL_WILDCARD}");
228
229        let mut patterns_map = BTreeMap::new();
230        patterns_map.insert(NiceType::PatternIdent(()), Vec::new());
231        for ty in &variant_pats {
232            for pat in ty.patterns_matching() {
233                let matches = patterns_map.entry(pat).or_insert(Vec::new());
234                matches.push(ty);
235            }
236        }
237
238        let patterns: Vec<_> = patterns_map.keys().collect();
239        let pat_variants: Vec<_> = patterns_map.values().collect();
240        let pat_variant_names: Vec<Vec<_>> = pat_variants
241            .iter()
242            .map(|v| v.iter().map(|ty| variants_btree[ty].clone()).collect())
243            .collect();
244
245        let patterns_vars: Vec<_> = patterns.iter().map(|pat| pat.index_patterns()).collect();
246        let patterns_vars_assoc: Vec<Vec<Vec<_>>> = pat_variants
247            .iter()
248            .zip(&patterns_vars)
249            .map(|(v, pat)| {
250                v.iter()
251                    .map(|ty| {
252                        ty.matches_map(&pat)
253                            .into_iter()
254                            .filter_map(|(ident, (ty, location))| {
255                                let NiceType::Literal(lit) = ty else {
256                                    return None;
257                                };
258                                // try block. sad
259                                let generic_ty = (|| {
260                                    let (parent, i) = location?;
261                                    self.attr.generics.get(&parent)?[i].as_ref()
262                                })();
263                                Some((ident, lit, generic_ty))
264                            })
265                            .collect()
266                    })
267                    .collect()
268            })
269            .collect();
270        // for each pattern, for each variant it matches, get the type pattern variables
271        // and their literals and locations, and generate let statements for them
272        let const_let_statements: Vec<Vec<proc_macro2::TokenStream>> = patterns_vars_assoc
273            .iter()
274            .map(|v| {
275                v.iter()
276                    .map(|v| {
277                        v.iter()
278                            .map(|(ident, lit, generic_ty)| match generic_ty {
279                                Some(generic_ty) => quote! { const $ #ident : #generic_ty = #lit; },
280                                None => quote! { let $ #ident = #lit; },
281                            })
282                            .map(|q| quote! { #[allow(nonstandard_style)] #[allow(unused_variables)] #q })
283                            .collect()
284                    })
285                    .collect()
286            })
287            .collect();
288
289        let pat_vars_params_eqs: Vec<Vec<Vec<_>>> = patterns_vars_assoc
290            .iter()
291            .map(|v| {
292                v.iter()
293                    .map(|v| {
294                        v.iter()
295                            .map(|(ident, lit, _generic_ty)| quote! { $ #ident == #lit })
296                            .collect()
297                    })
298                    .collect()
299            })
300            .collect();
301
302        let (pat_vars_names, pat_vars_params): (Vec<_>, Vec<_>) = patterns_vars
303            .iter()
304            .map(|pat| match pat {
305                NiceType::Ident(name, params) => (format_ident!("{}", name), {
306                    let params: Vec<_> = params
307                        .iter()
308                        .map(|param| param.map_pattern(|p| quote! { ? $ #p :ident }))
309                        .collect();
310                    (!params.is_empty())
311                        .then_some(params)
312                        .into_iter()
313                        .collect::<Vec<_>>()
314                }),
315                NiceType::PatternIdent(_p) => (
316                    internal_full_wildcard.clone(),
317                    None.into_iter().collect::<Vec<_>>(),
318                ),
319                _ => panic!("not ident {:?}", pat),
320            })
321            .unzip();
322
323        tokens.append_all(quote! {
324            #macro_match_export
325            #[allow(unused_macros)]
326            macro_rules! #macro_match {
327                ( match $( $rest:tt )* ) => {
328                    #path #macro_match_body ! { (), ( $($rest)* ) }
329                };
330            }
331            #macro_match_pub_use
332        });
333
334        tokens.append_all(quote! {
335            #macro_match_body_export
336            #[doc(hidden)]
337            macro_rules! #macro_match_body {
338                ( $what:tt, ({
339                    $( $rest:tt )*
340                }) ) => {
341                    #path #macro_match_process_body !( $what, ( $($rest)* ), () )
342                };
343                ( ( $( $what:tt )* ), ( $next:tt $( $rest:tt )* ) ) => {
344                    #path #macro_match_body ! { ( $($what)* $next ), ( $($rest)* ) }
345                };
346            }
347            #macro_match_body_pub_use
348        });
349
350        tokens.append_all(quote! {
351            #macro_match_process_body_export
352            #[doc(hidden)]
353            macro_rules! #macro_match_process_body {
354                ( $what:tt, (), ( $( ( $ty:tt; $binding:pat => $body:expr ) )* ) ) => {
355                    {
356                        let what = $what;
357
358                        #[allow(unreachable_patterns)]
359                        match what {
360                            $( #path #macro_match_pattern !($ty) => (), )*
361                        }
362
363                        #[allow(unused_labels)]
364                        'ma: {
365                            $( #path #macro_match_variant !{$ty; what; 'ma; $binding => $body} )*
366                            ::std::unreachable!();
367                        }
368                    }
369                };
370                (
371                    $what:tt,
372                    ( $binding:ident => { $( $body:tt )* } $(,)? $( $rest:tt )* ),
373                    ( $( $matched:tt )* )
374                ) => {
375                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( (#internal_full_wildcard) ; $binding => { $( $body )* } ) ) )
376                };
377                (
378                    $what:tt,
379                    ( $binding:ident => $body:expr, $( $rest:tt )* ),
380                    ( $( $matched:tt )* )
381                ) => {
382                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( (#internal_full_wildcard) ; $binding => { $body } ) ) )
383                };
384                (
385                    $what:tt,
386                    ( $tyn:ident ( $binding:pat ) => { $( $body:tt )* } $(,)? $( $rest:tt )* ),
387                    ( $( $matched:tt )* )
388                ) => {
389                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn); $binding => { $($body)* } ) ) )
390                };
391                (
392                    $what:tt,
393                    ( $tyn:ident ( $binding:pat ) => $body:expr, $( $rest:tt )* ),
394                    ( $( $matched:tt )* )
395                ) => {
396                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn); $binding => { $body } ) ) )
397                };
398                (
399                    $what:tt,
400                    ( $tyn:ident ::< $( $rest:tt )* ),
401                    ( $( $matched:tt )* )
402                ) => {
403                    #path #macro_process_type !( (@match, $what, $tyn, ($( $matched )*)), ($( $rest )*), (<), (<) )
404                };
405            }
406            #macro_match_process_body_pub_use
407        });
408
409        tokens.append_all(quote! {
410            #macro_process_type_export
411            #[doc(hidden)]
412            macro_rules! #macro_process_type {
413                ( $bundle:tt, ($(,)? > $($rest:tt)*), ( $($params:tt)* ), (< $($counter:tt)*) ) => {
414                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* >), ($($counter)*) )
415                };
416                ( $bundle:tt, ($(,)? >> $($rest:tt)*), ( $($params:tt)* ), (< < $($counter:tt)*) ) => {
417                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* > >), ($($counter)*) )
418                };
419                ( $bundle:tt, ($(,)? > $($rest:tt)*), ( $($params:tt)* ), () ) => {
420                    ::std::compile_error!("imbalanced")
421                };
422                ( $bundle:tt, ($(,)? >> $($rest:tt)*), ( $($params:tt)* ), () ) => {
423                    ::std::compile_error!("imbalanced")
424                };
425                ( $bundle:tt, (< $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
426                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* <), (< $($counter)*) )
427                };
428                ( $bundle:tt, (<< $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
429                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* < <), (< < $($counter)*) )
430                };
431                ( (@match, $what:tt, $tyn:ident, ( $($matched:tt)* )), (( $binding:pat ) => { $( $body:tt )* } $(,)? $($rest:tt)*), ( $($params:tt)* ), () ) => {
432                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn :: $($params)+); $binding => { $($body)* } ) ) )
433                };
434                ( (@match, $what:tt, $tyn:ident, ( $($matched:tt)* )), (( $binding:pat ) => $body:expr, $($rest:tt)*), ( $($params:tt)* ), () ) => {
435                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn :: $($params)+); $binding => { $body } ) ) )
436                };
437                ( (@construct, $tyn:ident), (( $expr:expr )), ( $($params:tt)+ ), () ) => {
438                    #path #macro_construct_inner !( ($tyn :: $($params)+); ( $expr ) )
439                };
440                ( $bundle:tt, (( $($any:tt)* ) $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
441                    ::std::compile_error!("imbalanced or something")
442                };
443                ( $bundle:tt, ($thing:tt $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
444                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* $thing), ( $($counter)*) )
445                };
446            }
447            #macro_process_type_pub_use
448        });
449
450        tokens.append_all(quote! {
451            #macro_match_variant_export
452            #[doc(hidden)]
453            macro_rules! #macro_match_variant {
454                #( ( (#pat_vars_names #(::< #( #pat_vars_params ),* >)* ); $what:ident; $ma:lifetime; $binding:pat => $body:expr ) => {
455                    #( if let #path #name :: #pat_variant_names ($binding) = $what {
456                        #const_let_statements
457                        break $ma($body);
458                    } )*
459                }; )*
460            }
461            #macro_match_variant_pub_use
462        });
463
464        tokens.append_all(quote! {
465            #macro_match_pattern_export
466            #[doc(hidden)]
467            macro_rules! #macro_match_pattern {
468                #( ( ( #pat_vars_names #(::< #( #pat_vars_params ),* >)* ) ) => {
469                    #( #path #name :: #pat_variant_names (_) )|*
470                }; )*
471            }
472            #macro_match_pattern_pub_use
473        });
474
475        tokens.append_all(quote! {
476            #macro_construct_export
477            #[allow(unused_macros)]
478            #macro_construct_docstring
479            macro_rules! #macro_construct {
480                ( $tyn:ident ::< $($tt:tt)* ) => {
481                    #path #macro_process_type !( (@construct, $tyn), ($($tt)*), (<), (<) )
482                };
483                ( $tyn:ident ( $body:expr ) ) => {
484                    #path #macro_construct_inner !( ($tyn); ($body) )
485                };
486            }
487            #macro_construct_pub_use
488        });
489
490        tokens.append_all(quote! {
491            #macro_construct_inner_export
492            #[doc(hidden)]
493            macro_rules! #macro_construct_inner {
494                #( ( (#pat_vars_names #(::< #( #pat_vars_params ),* >)* ); $body:expr ) => {
495                    'ma: {
496                        #( if true #(&& #pat_vars_params_eqs)* {
497                            #const_let_statements
498                            break 'ma ::std::option::Option::Some(#path #name :: #pat_variant_names($body));
499                        } )*
500                        ::std::option::Option::None
501                    }
502                }; )*
503            }
504            #macro_construct_inner_pub_use
505        });
506    }
507
508    fn to_tokens_traits(&self, tokens: &mut proc_macro2::TokenStream) {
509        let SigmaEnum {
510            visibility,
511            name,
512            variants,
513            subattrs: _,
514            attr,
515        } = &self;
516
517        let variant_types: Vec<_> = variants
518            .iter()
519            .map(|var| var.ty.to_tokens_aliased(&attr.alias))
520            .collect();
521        let variant_names: Vec<_> = variants.iter().map(|var| var.name.clone()).collect();
522
523        let into_trait = self.into_trait_name();
524        let into_trait_sealed_mod = self.internal_name("into_trait_sealed_mod", "");
525        let into_method = self.into_method_name();
526        let try_from_method = self.try_from_method_name();
527        let try_from_owned_method = self.try_from_owned_method_name();
528        let try_from_mut_method = self.try_from_mut_method_name();
529        let extract_method = self.extract_method_name();
530        let extract_owned_method = self.extract_owned_method_name();
531        let extract_mut_method = self.extract_mut_method_name();
532        let try_from_error = self.try_from_error_name();
533
534        let into_trait_docstring = self.attr.into_trait.docstring();
535        let into_method_docstring = self.attr.into_method.docstring();
536        let try_from_method_docstring = self.attr.try_from_method.docstring();
537        let try_from_owned_method_docstring = self.attr.try_from_owned_method.docstring();
538        let try_from_mut_method_docstring = self.attr.try_from_mut_method.docstring();
539        let extract_method_docstring = self.attr.extract_method.docstring();
540        let extract_owned_method_docstring = self.attr.extract_owned_method.docstring();
541        let extract_mut_method_docstring = self.attr.extract_mut_method.docstring();
542        let try_from_error_docstring = self.attr.try_from_error.docstring();
543
544        let methods = quote! {
545            #into_method_docstring
546            fn #into_method (self) -> #name;
547            #try_from_method_docstring
548            fn #try_from_method (value: & #name) -> Option<&Self>;
549            #try_from_owned_method_docstring
550            fn #try_from_owned_method (value: #name) -> Option<Self>
551                where Self: ::core::marker::Sized;
552            #try_from_mut_method_docstring
553            fn #try_from_mut_method (value: &mut #name) -> Option<&mut Self>;
554        };
555
556        tokens.append_all(quote! {
557            #into_trait_docstring
558            pub trait #into_trait : #into_trait_sealed_mod ::Sealed {
559                #methods
560            }
561
562            mod #into_trait_sealed_mod {
563                pub trait Sealed {}
564            }
565
566            #(
567                #[automatically_derived]
568                impl #into_trait_sealed_mod ::Sealed for #variant_types {}
569            )*
570        });
571
572        tokens.append_all(quote! {
573            #(
574                #into_trait_docstring
575                #[automatically_derived]
576                impl #into_trait for #variant_types {
577                    fn #into_method (self) -> #name {
578                        #name :: #variant_names (self)
579                    }
580
581                    fn #try_from_method <'a>(value: &'a #name) -> ::std::option::Option<&'a Self> {
582                        if let #name :: #variant_names (out) = value {
583                            ::std::option::Option::Some(out)
584                        } else {
585                            ::std::option::Option::None
586                        }
587                    }
588
589                    fn #try_from_owned_method (value: #name) -> ::std::option::Option<Self>
590                        where Self: ::core::marker::Sized
591                    {
592                        if let #name :: #variant_names (out) = value {
593                            ::std::option::Option::Some(out)
594                        } else {
595                            ::std::option::Option::None
596                        }
597                    }
598
599                    fn #try_from_mut_method <'a>(value: &'a mut #name) -> ::std::option::Option<&'a mut Self> {
600                        if let #name :: #variant_names (out) = value {
601                            ::std::option::Option::Some(out)
602                        } else {
603                            ::std::option::Option::None
604                        }
605                    }
606                }
607
608                #[automatically_derived]
609                impl ::std::convert::From<#variant_types> for #name {
610                    fn from(value: #variant_types) -> Self {
611                        #into_trait :: #into_method (value)
612                    }
613                }
614
615                #[automatically_derived]
616                impl<'a> ::std::convert::TryFrom<&'a #name> for &'a #variant_types {
617                    type Error = #try_from_error;
618                    fn try_from(value: &'a #name) -> Result<&'a #variant_types, #try_from_error > {
619                       < #variant_types as #into_trait >:: #try_from_method (value).ok_or( #try_from_error )
620                    }
621                }
622
623                #[automatically_derived]
624                impl ::std::convert::TryFrom<#name> for #variant_types
625                        where Self: ::core::marker::Sized
626                {
627                    type Error = #try_from_error;
628                    fn try_from(value: #name) -> Result<#variant_types, #try_from_error > {
629                       < #variant_types as #into_trait >:: #try_from_owned_method (value).ok_or( #try_from_error )
630                    }
631                }
632            )*
633
634            impl #name {
635                #extract_method_docstring
636                #visibility fn #extract_method <T: #into_trait >(&self) -> Option<&T> {
637                    T:: #try_from_method (self)
638                }
639
640                #extract_owned_method_docstring
641                #visibility fn #extract_owned_method <T: #into_trait >(self) -> Option<T> {
642                    T:: #try_from_owned_method (self)
643                }
644
645                #extract_mut_method_docstring
646                #visibility fn #extract_mut_method <T: #into_trait >(&mut self) -> Option<&mut T> {
647                    T:: #try_from_mut_method (self)
648                }
649            }
650        });
651
652        tokens.append_all(quote! {
653            pub struct #try_from_error;
654
655            #[automatically_derived]
656            impl ::core::fmt::Debug for #try_from_error {
657                #[inline]
658                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
659                    ::core::fmt::Formatter::write_str(f, ::std::stringify!(#try_from_error))
660                }
661            }
662            #[automatically_derived]
663            impl ::core::clone::Clone for #try_from_error {
664                #[inline]
665                fn clone(&self) -> #try_from_error {
666                    *self
667                }
668            }
669            #[automatically_derived]
670            impl ::core::marker::Copy for #try_from_error {}
671            #[automatically_derived]
672            impl ::core::cmp::PartialEq for #try_from_error {
673                #[inline]
674                fn eq(&self, other: & #try_from_error) -> bool {
675                    true
676                }
677            }
678            #[automatically_derived]
679            impl ::core::cmp::Eq for #try_from_error {}
680            #[automatically_derived]
681            impl ::core::hash::Hash for #try_from_error {
682                #[inline]
683                fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {}
684            }
685            #[automatically_derived]
686            impl ::core::cmp::PartialOrd for #try_from_error {
687                #[inline]
688                fn partial_cmp(&self, other: & #try_from_error) -> ::core::option::Option<::core::cmp::Ordering> {
689                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
690                }
691            }
692            #[automatically_derived]
693            impl ::core::cmp::Ord for #try_from_error {
694                #[inline]
695                fn cmp(&self, other: & #try_from_error) -> ::core::cmp::Ordering {
696                    ::core::cmp::Ordering::Equal
697                }
698            }
699
700            #[automatically_derived]
701            impl ::std::fmt::Display for #try_from_error {
702                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
703                    f.write_str("attempted to extract value from a ")?;
704                    f.write_str(::std::stringify!( #name ))?;
705                    f.write_str(" holding a different type")?;
706                    ::std::fmt::Result::Ok(())
707                }
708            }
709
710           #try_from_error_docstring
711            #[automatically_derived]
712            impl ::std::error::Error for #try_from_error {}
713        });
714    }
715}
716
717impl ToTokens for SigmaEnum {
718    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
719        let SigmaEnum {
720            visibility,
721            name,
722            variants,
723            subattrs,
724            attr,
725        } = &self;
726
727        let variant_types: Vec<_> = variants
728            .iter()
729            .map(|var| var.ty.to_tokens_aliased(&attr.alias))
730            .collect();
731        let variant_names: Vec<_> = variants.iter().map(|var| var.name.clone()).collect();
732        let variant_attrs: Vec<_> = variants.iter().map(|var| var.attrs.clone()).collect();
733
734        tokens.append_all(quote! {
735            #(#subattrs)*
736            #visibility enum #name {
737                #(
738                    #(#variant_attrs)*
739                    #variant_names(#variant_types),
740                )*
741            }
742        });
743
744        match visibility {
745            Visibility::Public(_) => {
746                self.to_tokens_macros(tokens, true, "");
747                self.to_tokens_macros(tokens, false, "_crate");
748            }
749            _ => {
750                self.to_tokens_macros(tokens, false, "");
751            }
752        }
753        self.to_tokens_traits(tokens);
754    }
755}
756
757impl Parse for SigmaEnum {
758    fn parse(input: ParseStream) -> syn::Result<Self> {
759        let subattrs = Attribute::parse_outer(input)?;
760        let visibility: Visibility = input.parse()?;
761        let _: Token![enum] = input.parse()?;
762        let name: Ident = input.parse()?;
763        let content;
764        braced!(content in input);
765        let mut variants = Vec::new();
766        let mut variant_tys = BTreeSet::new();
767        let mut attrs = Vec::new();
768        while !content.is_empty() {
769            let mut expand = BTreeMap::new();
770            let mut rename = None;
771            if let Ok(attributes) = content.call(Attribute::parse_outer) {
772                for attr in &attributes {
773                    if attr.path().is_ident("sigma_enum") {
774                        attr.parse_nested_meta(|meta| {
775                            match meta.path.require_ident()?.to_string().as_str() {
776                                "expand" => {
777                                    meta.parse_nested_meta(|meta| {
778                                        let ident = meta.path.require_ident()?;
779                                        let value: Expr = meta.value()?.parse()?;
780                                        let value = extract_expansion(&value)?;
781                                        if expand.contains_key(ident) {
782                                            return Err(syn::Error::new(
783                                                meta.path.span(),
784                                                "duplicate expand attribute",
785                                            ));
786                                        }
787                                        expand.insert(ident.clone(), value);
788                                        Ok(())
789                                    })?;
790                                }
791                                "rename" => {
792                                    if rename.is_some() {
793                                        return Err(syn::Error::new(
794                                            meta.path.span(),
795                                            "duplicate rename attribute",
796                                        ));
797                                    }
798                                    let _: Token![=] = meta.input.parse()?;
799                                    if let Ok(ident) = meta.input.parse::<Ident>() {
800                                        rename = Some(ident.to_string());
801                                    } else if let Ok(template) = meta.input.parse::<LitStr>() {
802                                        rename = Some(template.value());
803                                    } else {
804                                        return Err(syn::Error::new(
805                                            meta.input.span(),
806                                            "invalid renaming template",
807                                        ));
808                                    }
809                                }
810                                _ => {
811                                    return Err(syn::Error::new(meta.path.span(), "invalid attr"));
812                                }
813                            }
814                            Ok(())
815                        })?;
816                    } else {
817                        attrs.push(attr.clone());
818                    }
819                }
820            }
821
822            // variant name
823            // we cannot have rename and variant name
824
825            let enum_var_name: Ident = content.parse()?;
826            let enum_var_name =
827                (!enum_var_name.to_string().starts_with("_")).then_some(enum_var_name);
828            if rename.is_some() && enum_var_name.is_some() {
829                return Err(syn::Error::new(
830                    enum_var_name.span(),
831                    "cannot use variant name and rename attribute",
832                ));
833            }
834            if !expand.is_empty() && enum_var_name.is_some() {
835                return Err(syn::Error::new(
836                    enum_var_name.span(),
837                    "cannot use variant name and expand attribute",
838                ));
839            }
840
841            let ty_paren;
842            parenthesized!(ty_paren in content);
843            let nice_type: NiceType<Infallible> = ty_paren.parse()?;
844            assert!(ty_paren.is_empty());
845            let _ = content.parse::<Token![,]>();
846
847            if rename.as_deref().is_some_and(|rename| {
848                !expand
849                    .keys()
850                    .all(|ident| rename.contains(&format!("{{{}}}", ident)))
851            }) {
852                return Err(syn::Error::new(
853                    enum_var_name.span(),
854                    "rename template does not have all metavariables",
855                ));
856            }
857
858            let cartesian: Vec<Vec<(Ident, NiceTypeLit)>> =
859                expand
860                    .into_iter()
861                    .fold(vec![Vec::new()], |accum, (ident, range)| {
862                        accum
863                            .into_iter()
864                            .flat_map(|a| {
865                                range.iter().map({
866                                    let ident = &ident;
867                                    move |r| {
868                                        let mut a = a.clone();
869                                        a.push((ident.clone(), r.clone()));
870                                        a
871                                    }
872                                })
873                            })
874                            .collect()
875                    });
876
877            for assignments in cartesian {
878                let mut var_type = nice_type.clone();
879                for (ident, r) in &assignments {
880                    var_type = var_type.replace_ident(&ident.to_string(), &r)
881                }
882                let name = match &rename {
883                    Some(template) => {
884                        let mut name = template.clone();
885                        for (var, val) in assignments {
886                            name =
887                                name.replace(&format!("{{{}}}", var), &val.variant_name_string());
888                        }
889                        if name.contains('{') {
890                            return Err(syn::Error::new(
891                                template.span(),
892                                "invalid metavariable in rename template",
893                            ));
894                        }
895                        format_ident!("{}", name)
896                    }
897                    None => match &enum_var_name {
898                        Some(enum_var_name) => enum_var_name.clone(),
899                        None => var_type.variant_name(),
900                    },
901                };
902                if !variant_tys.insert(var_type.clone()) {
903                    return Err(syn::Error::new(var_type.span(), "duplicate variant types"));
904                }
905                variants.push(Variant {
906                    ty: var_type,
907                    name,
908                    attrs: attrs.clone(),
909                });
910            }
911        }
912
913        Ok(SigmaEnum {
914            visibility,
915            name,
916            variants,
917            subattrs,
918            attr: ItemAttr::default(),
919        })
920    }
921}
922
923#[proc_macro_attribute]
924pub fn sigma_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
925    // Parse the input tokens into a syntax tree
926    let mut sigma_enum = parse_macro_input!(item as SigmaEnum);
927    let attr = parse_macro_input!(attr as ItemAttr);
928    sigma_enum.attr = attr;
929
930    // panic!("{}", quote! { #sigma_enum });
931    quote! { #sigma_enum }.into()
932}