sigma_enum_macros/
lib.rs

1use crate::attrs::extract_expansion;
2use crate::nice_type::Infallible;
3use crate::nice_type::NiceType;
4use attrs::ItemAttr;
5use heck::ToSnakeCase;
6use nice_type::NiceTypeLit;
7use proc_macro::TokenStream;
8use quote::ToTokens;
9use quote::TokenStreamExt;
10use quote::format_ident;
11use quote::quote;
12use std::collections::BTreeMap;
13use std::collections::BTreeSet;
14use syn::Attribute;
15use syn::Expr;
16use syn::Ident;
17use syn::LitStr;
18use syn::Token;
19use syn::Visibility;
20use syn::braced;
21use syn::parenthesized;
22use syn::parse::Parse;
23use syn::parse::ParseStream;
24use syn::parse_macro_input;
25use syn::spanned::Spanned;
26
27mod attrs;
28mod nice_type;
29
/// Prefix of every hidden helper macro generated for an enum
/// (`__INTERNAL_IDENT_<snake_enum>_<which><suffix>`), chosen to stay clear of user names.
const INTERNAL_IDENT: &str = "__INTERNAL_IDENT";
/// Type-pattern token standing for a catch-all `binding => body` arm in the
/// generated `*_match!` macro; the pattern helper expands it to `_`.
const INTERNAL_FULL_WILDCARD: &str = "__INTERNAL_FULL_WILDCARD";
32
33#[derive(Clone)]
34struct Variant {
35    ty: NiceType<Infallible>,
36    name: Ident,
37    attrs: Vec<Attribute>,
38}
39
40#[derive(Clone)]
41struct SigmaEnum {
42    visibility: Visibility,
43    name: Ident,
44    variants: Vec<Variant>,
45    subattrs: Vec<Attribute>,
46    attr: ItemAttr,
47}
48
49impl SigmaEnum {
50    fn macro_match_name(&self) -> Ident {
51        self.attr.macro_match.name.as_ref().map_or_else(
52            || format_ident!("{}_match", self.name.to_string().to_snake_case()),
53            |name| format_ident!("{}", name),
54        )
55    }
56
57    fn macro_construct_name(&self) -> Ident {
58        self.attr.macro_construct.name.as_ref().map_or_else(
59            || format_ident!("{}_construct", self.name.to_string().to_snake_case()),
60            |name| format_ident!("{}", name),
61        )
62    }
63
64    fn into_trait_name(&self) -> Ident {
65        self.attr.into_trait.name.as_ref().map_or_else(
66            || format_ident!("Into{}", self.name),
67            |name| format_ident!("{}", name),
68        )
69    }
70
71    fn into_method_name(&self) -> Ident {
72        self.attr.into_method.name.as_ref().map_or_else(
73            || format_ident!("into_{}", self.name.to_string().to_snake_case()),
74            |name| format_ident!("{}", name),
75        )
76    }
77
78    fn try_from_method_name(&self) -> Ident {
79        self.attr.try_from_method.name.as_ref().map_or_else(
80            || format_ident!("try_from_{}", self.name.to_string().to_snake_case()),
81            |name| format_ident!("{}", name),
82        )
83    }
84
85    fn try_from_owned_method_name(&self) -> Ident {
86        self.attr.try_from_owned_method.name.as_ref().map_or_else(
87            || format_ident!("try_from_owned_{}", self.name.to_string().to_snake_case()),
88            |name| format_ident!("{}", name),
89        )
90    }
91
92    fn try_from_mut_method_name(&self) -> Ident {
93        self.attr.try_from_mut_method.name.as_ref().map_or_else(
94            || format_ident!("try_from_mut_{}", self.name.to_string().to_snake_case()),
95            |name| format_ident!("{}", name),
96        )
97    }
98
99    fn extract_method_name(&self) -> Ident {
100        self.attr.extract_method.name.as_ref().map_or_else(
101            || format_ident!("extract"),
102            |name| format_ident!("{}", name),
103        )
104    }
105
106    fn extract_owned_method_name(&self) -> Ident {
107        self.attr.extract_owned_method.name.as_ref().map_or_else(
108            || format_ident!("extract_owned"),
109            |name| format_ident!("{}", name),
110        )
111    }
112
113    fn extract_mut_method_name(&self) -> Ident {
114        self.attr.extract_mut_method.name.as_ref().map_or_else(
115            || format_ident!("extract_mut"),
116            |name| format_ident!("{}", name),
117        )
118    }
119
120    fn try_from_error_name(&self) -> Ident {
121        self.attr.try_from_error.name.as_ref().map_or_else(
122            || format_ident!("TryFrom{}Error", self.name.to_string()),
123            |name| format_ident!("{}", name),
124        )
125    }
126
127    fn internal_name(&self, which: &str, suffix: &str) -> Ident {
128        format_ident!(
129            "{INTERNAL_IDENT}_{}_{}{}",
130            self.name.to_string().to_snake_case(),
131            which,
132            suffix
133        )
134    }
135
136    fn to_tokens_macros(&self, tokens: &mut proc_macro2::TokenStream, export: bool, suffix: &str) {
137        let SigmaEnum {
138            visibility: _,
139            name,
140            variants,
141            subattrs: _,
142            attr,
143        } = &self;
144
145        let path = if export {
146            quote! { $crate :: }
147        } else {
148            match &attr.path {
149                Some(path) => quote! { $ #path :: },
150                None => quote! {},
151            }
152        };
153
154        let variants_btree: BTreeMap<_, _> = variants
155            .iter()
156            .map(|var| (var.ty.clone(), var.name.clone()))
157            .collect();
158        let variant_pats: Vec<_> = variants.iter().map(|var| var.ty.clone()).collect();
159
160        let macro_match = format_ident!("{}{}", self.macro_match_name(), suffix);
161        let macro_construct = format_ident!("{}{}", self.macro_construct_name(), suffix);
162        let macro_match_body = self.internal_name("body", suffix);
163        let macro_match_process_body = self.internal_name("process_body", suffix);
164        let macro_process_type = self.internal_name("process_type", suffix);
165        let macro_match_variant = self.internal_name("variant", suffix);
166        let macro_match_pattern = self.internal_name("pattern", suffix);
167        let macro_construct_inner = self.internal_name("construct_inner", suffix);
168
169        let macro_match_docstring = self.attr.macro_match.docstring();
170        let macro_construct_docstring = self.attr.macro_construct.docstring();
171
172        // https://github.com/rust-lang/rust/pull/52234#issuecomment-1417098097
173        let macro_match_export;
174        let macro_construct_export;
175        let macro_match_body_export;
176        let macro_match_process_body_export;
177        let macro_process_type_export;
178        let macro_match_variant_export;
179        let macro_match_pattern_export;
180        let macro_construct_inner_export;
181        let macro_match_pub_use;
182        let macro_construct_pub_use;
183        let macro_match_body_pub_use;
184        let macro_match_process_body_pub_use;
185        let macro_process_type_pub_use;
186        let macro_match_variant_pub_use;
187        let macro_match_pattern_pub_use;
188        let macro_construct_inner_pub_use;
189        if export {
190            macro_match_export = quote! { #macro_match_docstring #[macro_export] };
191            macro_construct_export = quote! { #macro_construct_docstring #[macro_export] };
192            macro_match_body_export = quote! { #[macro_export] };
193            macro_match_process_body_export = quote! { #[macro_export] };
194            macro_process_type_export = quote! { #[macro_export] };
195            macro_match_variant_export = quote! { #[macro_export] };
196            macro_match_pattern_export = quote! { #[macro_export] };
197            macro_construct_inner_export = quote! { #[macro_export] };
198            macro_match_pub_use = quote! {};
199            macro_construct_pub_use = quote! {};
200            macro_match_body_pub_use = quote! {};
201            macro_match_process_body_pub_use = quote! {};
202            macro_process_type_pub_use = quote! {};
203            macro_match_variant_pub_use = quote! {};
204            macro_match_pattern_pub_use = quote! {};
205            macro_construct_inner_pub_use = quote! {};
206        } else {
207            macro_match_export = quote! { #macro_match_docstring };
208            macro_construct_export = quote! { #macro_construct_docstring };
209            macro_match_body_export = quote! {};
210            macro_match_process_body_export = quote! {};
211            macro_process_type_export = quote! {};
212            macro_match_variant_export = quote! {};
213            macro_match_pattern_export = quote! {};
214            macro_construct_inner_export = quote! {};
215            macro_match_pub_use = quote! { #macro_match_docstring pub(crate) use #macro_match; };
216            macro_construct_pub_use =
217                quote! { #macro_construct_docstring pub(crate) use #macro_construct; };
218            macro_match_body_pub_use = quote! { #[doc(hidden)] pub(crate) use #macro_match_body; };
219            macro_match_process_body_pub_use =
220                quote! { #[doc(hidden)] pub(crate) use #macro_match_process_body; };
221            macro_process_type_pub_use =
222                quote! { #[doc(hidden)] pub(crate) use #macro_process_type; };
223            macro_match_variant_pub_use =
224                quote! { #[doc(hidden)] pub(crate) use #macro_match_variant; };
225            macro_match_pattern_pub_use =
226                quote! { #[doc(hidden)] pub(crate) use #macro_match_pattern; };
227            macro_construct_inner_pub_use =
228                quote! { #[doc(hidden)] pub(crate) use #macro_construct_inner; };
229        }
230
231        let internal_full_wildcard = format_ident!("{INTERNAL_FULL_WILDCARD}");
232
233        let mut patterns_map = BTreeMap::new();
234        patterns_map.insert(NiceType::PatternIdent(()), Vec::new());
235        for ty in &variant_pats {
236            for pat in ty.patterns_matching() {
237                let matches = patterns_map.entry(pat).or_insert(Vec::new());
238                matches.push(ty);
239            }
240        }
241
242        let patterns: Vec<_> = patterns_map.keys().collect();
243        let pat_variants: Vec<_> = patterns_map.values().collect();
244        let pat_variant_names: Vec<Vec<_>> = pat_variants
245            .iter()
246            .map(|v| v.iter().map(|ty| variants_btree[ty].clone()).collect())
247            .collect();
248
249        let patterns_vars: Vec<_> = patterns.iter().map(|pat| pat.index_patterns()).collect();
250        let patterns_vars_assoc: Vec<Vec<Vec<_>>> = pat_variants
251            .iter()
252            .zip(&patterns_vars)
253            .map(|(v, pat)| {
254                v.iter()
255                    .map(|ty| {
256                        ty.matches_map(&pat)
257                            .into_iter()
258                            .filter_map(|(ident, (ty, location))| {
259                                let NiceType::Literal(lit) = ty else {
260                                    return None;
261                                };
262                                // try block. sad
263                                let generic_ty = (|| {
264                                    let (parent, i) = location?;
265                                    self.attr.generics.get(&parent)?[i].as_ref()
266                                })();
267                                Some((ident, lit, generic_ty))
268                            })
269                            .collect()
270                    })
271                    .collect()
272            })
273            .collect();
274        // for each pattern, for each variant it matches, get the type pattern variables
275        // and their literals and locations, and generate let statements for them
276        let const_let_statements: Vec<Vec<proc_macro2::TokenStream>> = patterns_vars_assoc
277            .iter()
278            .map(|v| {
279                v.iter()
280                    .map(|v| {
281                        v.iter()
282                            .map(|(ident, lit, generic_ty)| match generic_ty {
283                                Some(generic_ty) => quote! { const $ #ident : #generic_ty = #lit; },
284                                None => quote! { let $ #ident = #lit; },
285                            })
286                            .map(|q| quote! { #[allow(nonstandard_style)] #[allow(unused_variables)] #q })
287                            .collect()
288                    })
289                    .collect()
290            })
291            .collect();
292
293        let pat_vars_params_eqs: Vec<Vec<Vec<_>>> = patterns_vars_assoc
294            .iter()
295            .map(|v| {
296                v.iter()
297                    .map(|v| {
298                        v.iter()
299                            .map(|(ident, lit, _generic_ty)| quote! { $ #ident == #lit })
300                            .collect()
301                    })
302                    .collect()
303            })
304            .collect();
305
306        let (pat_vars_names, pat_vars_params): (Vec<_>, Vec<_>) = patterns_vars
307            .iter()
308            .map(|pat| match pat {
309                NiceType::Ident(name, params) => (format_ident!("{}", name), {
310                    let params: Vec<_> = params
311                        .iter()
312                        .map(|param| param.map_pattern(|p| quote! { ? $ #p :ident }))
313                        .collect();
314                    (!params.is_empty())
315                        .then_some(params)
316                        .into_iter()
317                        .collect::<Vec<_>>()
318                }),
319                NiceType::PatternIdent(_p) => (
320                    internal_full_wildcard.clone(),
321                    None.into_iter().collect::<Vec<_>>(),
322                ),
323                _ => panic!("not ident {:?}", pat),
324            })
325            .unzip();
326
327        tokens.append_all(quote! {
328            #macro_match_export
329            #[allow(unused_macros)]
330            macro_rules! #macro_match {
331                ( match $( $rest:tt )* ) => {
332                    #path #macro_match_body ! { (), ( $($rest)* ) }
333                };
334            }
335            #macro_match_pub_use
336        });
337
338        tokens.append_all(quote! {
339            #macro_match_body_export
340            #[doc(hidden)]
341            macro_rules! #macro_match_body {
342                ( $what:tt, ({
343                    $( $rest:tt )*
344                }) ) => {
345                    #path #macro_match_process_body !( $what, ( $($rest)* ), () )
346                };
347                ( ( $( $what:tt )* ), ( $next:tt $( $rest:tt )* ) ) => {
348                    #path #macro_match_body ! { ( $($what)* $next ), ( $($rest)* ) }
349                };
350            }
351            #macro_match_body_pub_use
352        });
353
354        tokens.append_all(quote! {
355            #macro_match_process_body_export
356            #[doc(hidden)]
357            macro_rules! #macro_match_process_body {
358                ( $what:tt, (), ( $( ( $ty:tt; $binding:pat => $body:expr ) )* ) ) => {
359                    {
360                        let what = $what;
361
362                        #[allow(unreachable_patterns)]
363                        match what {
364                            $( #path #macro_match_pattern !($ty) => (), )*
365                        }
366
367                        #[allow(unused_labels)]
368                        'ma: {
369                            $( #path #macro_match_variant !{$ty; what; 'ma; $binding => $body} )*
370                            ::std::unreachable!();
371                        }
372                    }
373                };
374                (
375                    $what:tt,
376                    ( $binding:ident => { $( $body:tt )* } $(,)? $( $rest:tt )* ),
377                    ( $( $matched:tt )* )
378                ) => {
379                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( (#internal_full_wildcard) ; $binding => { $( $body )* } ) ) )
380                };
381                (
382                    $what:tt,
383                    ( $binding:ident => $body:expr, $( $rest:tt )* ),
384                    ( $( $matched:tt )* )
385                ) => {
386                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( (#internal_full_wildcard) ; $binding => { $body } ) ) )
387                };
388                (
389                    $what:tt,
390                    ( $tyn:ident ( $binding:pat ) => { $( $body:tt )* } $(,)? $( $rest:tt )* ),
391                    ( $( $matched:tt )* )
392                ) => {
393                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn); $binding => { $($body)* } ) ) )
394                };
395                (
396                    $what:tt,
397                    ( $tyn:ident ( $binding:pat ) => $body:expr, $( $rest:tt )* ),
398                    ( $( $matched:tt )* )
399                ) => {
400                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn); $binding => { $body } ) ) )
401                };
402                (
403                    $what:tt,
404                    ( $tyn:ident ::< $( $rest:tt )* ),
405                    ( $( $matched:tt )* )
406                ) => {
407                    #path #macro_process_type !( (@match, $what, $tyn, ($( $matched )*)), ($( $rest )*), (<), (<) )
408                };
409            }
410            #macro_match_process_body_pub_use
411        });
412
413        tokens.append_all(quote! {
414            #macro_process_type_export
415            #[doc(hidden)]
416            macro_rules! #macro_process_type {
417                ( $bundle:tt, ($(,)? > $($rest:tt)*), ( $($params:tt)* ), (< $($counter:tt)*) ) => {
418                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* >), ($($counter)*) )
419                };
420                ( $bundle:tt, ($(,)? >> $($rest:tt)*), ( $($params:tt)* ), (< < $($counter:tt)*) ) => {
421                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* > >), ($($counter)*) )
422                };
423                ( $bundle:tt, ($(,)? > $($rest:tt)*), ( $($params:tt)* ), () ) => {
424                    ::std::compile_error!("imbalanced")
425                };
426                ( $bundle:tt, ($(,)? >> $($rest:tt)*), ( $($params:tt)* ), () ) => {
427                    ::std::compile_error!("imbalanced")
428                };
429                ( $bundle:tt, (< $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
430                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* <), (< $($counter)*) )
431                };
432                ( $bundle:tt, (<< $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
433                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* < <), (< < $($counter)*) )
434                };
435                ( (@match, $what:tt, $tyn:ident, ( $($matched:tt)* )), (( $binding:pat ) => { $( $body:tt )* } $(,)? $($rest:tt)*), ( $($params:tt)* ), () ) => {
436                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn :: $($params)+); $binding => { $($body)* } ) ) )
437                };
438                ( (@match, $what:tt, $tyn:ident, ( $($matched:tt)* )), (( $binding:pat ) => $body:expr, $($rest:tt)*), ( $($params:tt)* ), () ) => {
439                    #path #macro_match_process_body !( $what, ( $($rest)* ), ( $($matched)* ( ($tyn :: $($params)+); $binding => { $body } ) ) )
440                };
441                ( (@construct, $tyn:ident), (( $expr:expr )), ( $($params:tt)+ ), () ) => {
442                    #path #macro_construct_inner !( ($tyn :: $($params)+); ( $expr ) )
443                };
444                ( $bundle:tt, (( $($any:tt)* ) $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
445                    ::std::compile_error!("imbalanced or something")
446                };
447                ( $bundle:tt, ($thing:tt $($rest:tt)*), ( $($params:tt)* ), ( $($counter:tt)* ) ) => {
448                    #path #macro_process_type ! ( $bundle, ($($rest)*), ($($params)* $thing), ( $($counter)*) )
449                };
450            }
451            #macro_process_type_pub_use
452        });
453
454        tokens.append_all(quote! {
455            #macro_match_variant_export
456            #[doc(hidden)]
457            macro_rules! #macro_match_variant {
458                #( ( (#pat_vars_names #(::< #( #pat_vars_params ),* >)* ); $what:ident; $ma:lifetime; $binding:pat => $body:expr ) => {
459                    #( if let #path #name :: #pat_variant_names ($binding) = $what {
460                        #const_let_statements
461                        break $ma($body);
462                    } )*
463                }; )*
464            }
465            #macro_match_variant_pub_use
466        });
467
468        tokens.append_all(quote! {
469            #macro_match_pattern_export
470            #[doc(hidden)]
471            macro_rules! #macro_match_pattern {
472                ( ( #internal_full_wildcard ) ) => { _ };
473                #( ( ( #pat_vars_names #(::< #( #pat_vars_params ),* >)* ) ) => {
474                    #( #path #name :: #pat_variant_names (_) )|*
475                }; )*
476            }
477            #macro_match_pattern_pub_use
478        });
479
480        tokens.append_all(quote! {
481            #macro_construct_export
482            #[allow(unused_macros)]
483            #macro_construct_docstring
484            macro_rules! #macro_construct {
485                ( $tyn:ident ::< $($tt:tt)* ) => {
486                    #path #macro_process_type !( (@construct, $tyn), ($($tt)*), (<), (<) )
487                };
488                ( $tyn:ident ( $body:expr ) ) => {
489                    #path #macro_construct_inner !( ($tyn); ($body) )
490                };
491            }
492            #macro_construct_pub_use
493        });
494
495        tokens.append_all(quote! {
496            #macro_construct_inner_export
497            #[doc(hidden)]
498            macro_rules! #macro_construct_inner {
499                #( ( (#pat_vars_names #(::< #( #pat_vars_params ),* >)* ); $body:expr ) => {
500                    'ma: {
501                        #( if true #(&& #pat_vars_params_eqs)* {
502                            #const_let_statements
503                            break 'ma ::std::option::Option::Some(#path #name :: #pat_variant_names($body));
504                        } )*
505                        ::std::option::Option::None
506                    }
507                }; )*
508            }
509            #macro_construct_inner_pub_use
510        });
511    }
512
513    fn to_tokens_traits(&self, tokens: &mut proc_macro2::TokenStream) {
514        let SigmaEnum {
515            visibility,
516            name,
517            variants,
518            subattrs: _,
519            attr,
520        } = &self;
521
522        let variant_types: Vec<_> = variants
523            .iter()
524            .map(|var| var.ty.to_tokens_aliased(&attr.alias))
525            .collect();
526        let variant_names: Vec<_> = variants.iter().map(|var| var.name.clone()).collect();
527
528        let into_trait = self.into_trait_name();
529        let into_trait_sealed_mod = self.internal_name("into_trait_sealed_mod", "");
530        let into_method = self.into_method_name();
531        let try_from_method = self.try_from_method_name();
532        let try_from_owned_method = self.try_from_owned_method_name();
533        let try_from_mut_method = self.try_from_mut_method_name();
534        let extract_method = self.extract_method_name();
535        let extract_owned_method = self.extract_owned_method_name();
536        let extract_mut_method = self.extract_mut_method_name();
537        let try_from_error = self.try_from_error_name();
538
539        let into_trait_docstring = self.attr.into_trait.docstring();
540        let into_method_docstring = self.attr.into_method.docstring();
541        let try_from_method_docstring = self.attr.try_from_method.docstring();
542        let try_from_owned_method_docstring = self.attr.try_from_owned_method.docstring();
543        let try_from_mut_method_docstring = self.attr.try_from_mut_method.docstring();
544        let extract_method_docstring = self.attr.extract_method.docstring();
545        let extract_owned_method_docstring = self.attr.extract_owned_method.docstring();
546        let extract_mut_method_docstring = self.attr.extract_mut_method.docstring();
547        let try_from_error_docstring = self.attr.try_from_error.docstring();
548
549        let methods = quote! {
550            #into_method_docstring
551            fn #into_method (self) -> #name;
552            #try_from_method_docstring
553            fn #try_from_method (value: & #name) -> Option<&Self>;
554            #try_from_owned_method_docstring
555            fn #try_from_owned_method (value: #name) -> Option<Self>
556                where Self: ::core::marker::Sized;
557            #try_from_mut_method_docstring
558            fn #try_from_mut_method (value: &mut #name) -> Option<&mut Self>;
559        };
560
561        tokens.append_all(quote! {
562            #into_trait_docstring
563            pub trait #into_trait : #into_trait_sealed_mod ::Sealed {
564                #methods
565            }
566
567            mod #into_trait_sealed_mod {
568                pub trait Sealed {}
569            }
570
571            #(
572                #[automatically_derived]
573                impl #into_trait_sealed_mod ::Sealed for #variant_types {}
574            )*
575        });
576
577        tokens.append_all(quote! {
578            #(
579                #into_trait_docstring
580                #[automatically_derived]
581                impl #into_trait for #variant_types {
582                    fn #into_method (self) -> #name {
583                        #name :: #variant_names (self)
584                    }
585
586                    fn #try_from_method <'a>(value: &'a #name) -> ::std::option::Option<&'a Self> {
587                        if let #name :: #variant_names (out) = value {
588                            ::std::option::Option::Some(out)
589                        } else {
590                            ::std::option::Option::None
591                        }
592                    }
593
594                    fn #try_from_owned_method (value: #name) -> ::std::option::Option<Self>
595                        where Self: ::core::marker::Sized
596                    {
597                        if let #name :: #variant_names (out) = value {
598                            ::std::option::Option::Some(out)
599                        } else {
600                            ::std::option::Option::None
601                        }
602                    }
603
604                    fn #try_from_mut_method <'a>(value: &'a mut #name) -> ::std::option::Option<&'a mut Self> {
605                        if let #name :: #variant_names (out) = value {
606                            ::std::option::Option::Some(out)
607                        } else {
608                            ::std::option::Option::None
609                        }
610                    }
611                }
612
613                #[automatically_derived]
614                impl ::std::convert::From<#variant_types> for #name {
615                    fn from(value: #variant_types) -> Self {
616                        #into_trait :: #into_method (value)
617                    }
618                }
619
620                #[automatically_derived]
621                impl<'a> ::std::convert::TryFrom<&'a #name> for &'a #variant_types {
622                    type Error = #try_from_error;
623                    fn try_from(value: &'a #name) -> Result<&'a #variant_types, #try_from_error > {
624                       < #variant_types as #into_trait >:: #try_from_method (value).ok_or( #try_from_error )
625                    }
626                }
627
628                #[automatically_derived]
629                impl ::std::convert::TryFrom<#name> for #variant_types
630                        where Self: ::core::marker::Sized
631                {
632                    type Error = #try_from_error;
633                    fn try_from(value: #name) -> Result<#variant_types, #try_from_error > {
634                       < #variant_types as #into_trait >:: #try_from_owned_method (value).ok_or( #try_from_error )
635                    }
636                }
637            )*
638
639            impl #name {
640                #extract_method_docstring
641                #visibility fn #extract_method <T: #into_trait >(&self) -> Option<&T> {
642                    T:: #try_from_method (self)
643                }
644
645                #extract_owned_method_docstring
646                #visibility fn #extract_owned_method <T: #into_trait >(self) -> Option<T> {
647                    T:: #try_from_owned_method (self)
648                }
649
650                #extract_mut_method_docstring
651                #visibility fn #extract_mut_method <T: #into_trait >(&mut self) -> Option<&mut T> {
652                    T:: #try_from_mut_method (self)
653                }
654            }
655        });
656
657        tokens.append_all(quote! {
658            pub struct #try_from_error;
659
660            #[automatically_derived]
661            impl ::core::fmt::Debug for #try_from_error {
662                #[inline]
663                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
664                    ::core::fmt::Formatter::write_str(f, ::std::stringify!(#try_from_error))
665                }
666            }
667            #[automatically_derived]
668            impl ::core::clone::Clone for #try_from_error {
669                #[inline]
670                fn clone(&self) -> #try_from_error {
671                    *self
672                }
673            }
674            #[automatically_derived]
675            impl ::core::marker::Copy for #try_from_error {}
676            #[automatically_derived]
677            impl ::core::cmp::PartialEq for #try_from_error {
678                #[inline]
679                fn eq(&self, other: & #try_from_error) -> bool {
680                    true
681                }
682            }
683            #[automatically_derived]
684            impl ::core::cmp::Eq for #try_from_error {}
685            #[automatically_derived]
686            impl ::core::hash::Hash for #try_from_error {
687                #[inline]
688                fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {}
689            }
690            #[automatically_derived]
691            impl ::core::cmp::PartialOrd for #try_from_error {
692                #[inline]
693                fn partial_cmp(&self, other: & #try_from_error) -> ::core::option::Option<::core::cmp::Ordering> {
694                    ::core::option::Option::Some(::core::cmp::Ordering::Equal)
695                }
696            }
697            #[automatically_derived]
698            impl ::core::cmp::Ord for #try_from_error {
699                #[inline]
700                fn cmp(&self, other: & #try_from_error) -> ::core::cmp::Ordering {
701                    ::core::cmp::Ordering::Equal
702                }
703            }
704
705            #[automatically_derived]
706            impl ::std::fmt::Display for #try_from_error {
707                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
708                    f.write_str("attempted to extract value from a ")?;
709                    f.write_str(::std::stringify!( #name ))?;
710                    f.write_str(" holding a different type")?;
711                    ::std::fmt::Result::Ok(())
712                }
713            }
714
715           #try_from_error_docstring
716            #[automatically_derived]
717            impl ::std::error::Error for #try_from_error {}
718        });
719    }
720}
721
722impl ToTokens for SigmaEnum {
723    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
724        let SigmaEnum {
725            visibility,
726            name,
727            variants,
728            subattrs,
729            attr,
730        } = &self;
731
732        let variant_types: Vec<_> = variants
733            .iter()
734            .map(|var| var.ty.to_tokens_aliased(&attr.alias))
735            .collect();
736        let variant_names: Vec<_> = variants.iter().map(|var| var.name.clone()).collect();
737        let variant_attrs: Vec<_> = variants.iter().map(|var| var.attrs.clone()).collect();
738
739        tokens.append_all(quote! {
740            #(#subattrs)*
741            #visibility enum #name {
742                #(
743                    #(#variant_attrs)*
744                    #variant_names(#variant_types),
745                )*
746            }
747        });
748
749        match visibility {
750            Visibility::Public(_) => {
751                self.to_tokens_macros(tokens, true, "");
752                self.to_tokens_macros(tokens, false, "_crate");
753            }
754            _ => {
755                self.to_tokens_macros(tokens, false, "");
756            }
757        }
758        self.to_tokens_traits(tokens);
759    }
760}
761
762impl Parse for SigmaEnum {
763    fn parse(input: ParseStream) -> syn::Result<Self> {
764        let subattrs = Attribute::parse_outer(input)?;
765        let visibility: Visibility = input.parse()?;
766        let _: Token![enum] = input.parse()?;
767        let name: Ident = input.parse()?;
768        let content;
769        braced!(content in input);
770        let mut variants = Vec::new();
771        let mut variant_tys = BTreeSet::new();
772        let mut attrs = Vec::new();
773        while !content.is_empty() {
774            let mut expand = BTreeMap::new();
775            let mut rename = None;
776            if let Ok(attributes) = content.call(Attribute::parse_outer) {
777                for attr in &attributes {
778                    if attr.path().is_ident("sigma_enum") {
779                        attr.parse_nested_meta(|meta| {
780                            match meta.path.require_ident()?.to_string().as_str() {
781                                "expand" => {
782                                    meta.parse_nested_meta(|meta| {
783                                        let ident = meta.path.require_ident()?;
784                                        let value: Expr = meta.value()?.parse()?;
785                                        let value = extract_expansion(&value)?;
786                                        if expand.contains_key(ident) {
787                                            return Err(syn::Error::new(
788                                                meta.path.span(),
789                                                "duplicate expand attribute",
790                                            ));
791                                        }
792                                        expand.insert(ident.clone(), value);
793                                        Ok(())
794                                    })?;
795                                }
796                                "rename" => {
797                                    if rename.is_some() {
798                                        return Err(syn::Error::new(
799                                            meta.path.span(),
800                                            "duplicate rename attribute",
801                                        ));
802                                    }
803                                    let _: Token![=] = meta.input.parse()?;
804                                    if let Ok(ident) = meta.input.parse::<Ident>() {
805                                        rename = Some(ident.to_string());
806                                    } else if let Ok(template) = meta.input.parse::<LitStr>() {
807                                        rename = Some(template.value());
808                                    } else {
809                                        return Err(syn::Error::new(
810                                            meta.input.span(),
811                                            "invalid renaming template",
812                                        ));
813                                    }
814                                }
815                                _ => {
816                                    return Err(syn::Error::new(meta.path.span(), "invalid attr"));
817                                }
818                            }
819                            Ok(())
820                        })?;
821                    } else {
822                        attrs.push(attr.clone());
823                    }
824                }
825            }
826
827            // variant name
828            // we cannot have rename and variant name
829
830            let enum_var_name: Ident = content.parse()?;
831            let enum_var_name =
832                (!enum_var_name.to_string().starts_with("_")).then_some(enum_var_name);
833            if rename.is_some() && enum_var_name.is_some() {
834                return Err(syn::Error::new(
835                    enum_var_name.span(),
836                    "cannot use variant name and rename attribute",
837                ));
838            }
839            if !expand.is_empty() && enum_var_name.is_some() {
840                return Err(syn::Error::new(
841                    enum_var_name.span(),
842                    "cannot use variant name and expand attribute",
843                ));
844            }
845
846            let ty_paren;
847            parenthesized!(ty_paren in content);
848            let nice_type: NiceType<Infallible> = ty_paren.parse()?;
849            assert!(ty_paren.is_empty());
850            let _ = content.parse::<Token![,]>();
851
852            if rename.as_deref().is_some_and(|rename| {
853                !expand
854                    .keys()
855                    .all(|ident| rename.contains(&format!("{{{}}}", ident)))
856            }) {
857                return Err(syn::Error::new(
858                    enum_var_name.span(),
859                    "rename template does not have all metavariables",
860                ));
861            }
862
863            let cartesian: Vec<Vec<(Ident, NiceTypeLit)>> =
864                expand
865                    .into_iter()
866                    .fold(vec![Vec::new()], |accum, (ident, range)| {
867                        accum
868                            .into_iter()
869                            .flat_map(|a| {
870                                range.iter().map({
871                                    let ident = &ident;
872                                    move |r| {
873                                        let mut a = a.clone();
874                                        a.push((ident.clone(), r.clone()));
875                                        a
876                                    }
877                                })
878                            })
879                            .collect()
880                    });
881
882            for assignments in cartesian {
883                let mut var_type = nice_type.clone();
884                for (ident, r) in &assignments {
885                    var_type = var_type.replace_ident(&ident.to_string(), &r)
886                }
887                let name = match &rename {
888                    Some(template) => {
889                        let mut name = template.clone();
890                        for (var, val) in assignments {
891                            name =
892                                name.replace(&format!("{{{}}}", var), &val.variant_name_string());
893                        }
894                        if name.contains('{') {
895                            return Err(syn::Error::new(
896                                template.span(),
897                                "invalid metavariable in rename template",
898                            ));
899                        }
900                        format_ident!("{}", name)
901                    }
902                    None => match &enum_var_name {
903                        Some(enum_var_name) => enum_var_name.clone(),
904                        None => var_type.variant_name(),
905                    },
906                };
907                if !variant_tys.insert(var_type.clone()) {
908                    return Err(syn::Error::new(var_type.span(), "duplicate variant types"));
909                }
910                variants.push(Variant {
911                    ty: var_type,
912                    name,
913                    attrs: attrs.clone(),
914                });
915            }
916        }
917
918        Ok(SigmaEnum {
919            visibility,
920            name,
921            variants,
922            subattrs,
923            attr: ItemAttr::default(),
924        })
925    }
926}
927
928#[proc_macro_attribute]
929pub fn sigma_enum(attr: TokenStream, item: TokenStream) -> TokenStream {
930    // Parse the input tokens into a syntax tree
931    let mut sigma_enum = parse_macro_input!(item as SigmaEnum);
932    let attr = parse_macro_input!(attr as ItemAttr);
933    sigma_enum.attr = attr;
934
935    // panic!("{}", quote! { #sigma_enum });
936    quote! { #sigma_enum }.into()
937}