strander_derive/
lib.rs

1use proc_macro::TokenStream;
2use syn::{Data, Expr, Lit, Meta, MetaNameValue};
3
4use quote::{ToTokens, format_ident, quote};
5
6// TODO: struct attrs to control generated trait/struct names
7// TODO: input generics
8// TODO: enums?
9// TODO: tuple structs?
10
11#[proc_macro_derive(Strand, attributes(strand))]
12pub fn derive_strand(item: TokenStream) -> TokenStream {
13    derive_strand_impl(item, true)
14}
15
16fn derive_strand_impl(item: TokenStream, impl_strand: bool) -> TokenStream {
17    let input = syn::parse_macro_input!(item as syn::DeriveInput);
18
19    let struct_name = &input.ident;
20    let vis = &input.vis;
21    let distr_trait = format_ident!("{}Distribution", struct_name);
22    let distr_struct = format_ident!("{}Distr", struct_name);
23
24    match &input.data {
25        Data::Struct(syn::DataStruct { fields, .. }) => {
26            let mut distr_generics = quote! {};
27            let mut distr_generic_defaults = quote! {};
28            let mut distr_where_clause = quote! {
29                where
30            };
31            let mut distr_trait_method_defs = quote! {};
32            let mut distr_trait_method_impls = quote! {};
33            let mut distr_struct_fields = quote! {};
34            let mut distr_field_constructors = quote! {};
35            let mut distr_field_samplers = quote! {};
36
37            for (i, field) in fields.iter().enumerate() {
38                let field_name = field.ident.as_ref().unwrap();
39                let field_type = &field.ty;
40                let field_param = format_ident!("T_{}", i);
41                let method_name = format_ident!("with_{}", field_name);
42                let field_trait = quote! { ::strander::rand::distr::Distribution<#field_type> };
43                let method_signature = quote!{ fn #method_name(self, #field_name: impl #field_trait) -> impl #distr_trait };
44
45                let mut other_fields = quote!{};
46
47                for other in fields.iter().filter(|f| f.ident != field.ident) {
48                    let other_name = other.ident.as_ref().unwrap();
49                    other_fields.extend(quote! { #other_name : self.#other_name, });
50                }
51
52                distr_generics.extend(quote! { #field_param , });
53                distr_generic_defaults.extend(quote! { #field_param = (), });
54                distr_struct_fields.extend(quote! { #field_name: #field_param, });
55                distr_where_clause.extend(quote! { #field_param : #field_trait , });
56
57                distr_trait_method_defs.extend(quote! { #method_signature ; });
58                distr_trait_method_impls.extend(quote! { #method_signature {
59                    #distr_struct {
60                        #field_name,
61                        #other_fields
62                    }
63                }});
64                let mut constructor = quote! { <#field_type as ::strander::Strand>::strand() };
65                for attr in field.attrs.iter().map(|a| &a.meta) {
66                    match attr {
67                        Meta::NameValue(MetaNameValue{ path, value, .. }) => {
68                            if path
69                                .get_ident()
70                                .map(|i| i.to_string())
71                                .as_ref()
72                                .map(|s| s.as_str()) != Some("strand") {
73                                continue;
74                            }
75                            if let Expr::Lit(expr) = &value {
76                                if let Lit::Str(lit_str) = &expr.lit {
77                                    constructor = lit_str.parse::<Expr>().expect("a valid rust expression").into_token_stream();
78                                }
79                            }
80                        },
81                        _ => continue,
82                    }
83                }
84                distr_field_constructors.extend(quote! { #field_name: #constructor, });
85                distr_field_samplers.extend(quote! { #field_name: <#field_param as #field_trait>::sample(&self.#field_name, rng), })
86
87            }
88
89            let mut strand_impl = quote!{};
90            if impl_strand {
91                strand_impl = quote! {
92                    #[allow(refining_impl_trait)]
93                    impl ::strander::Strand for #struct_name {
94                        fn strand() -> impl #distr_trait {
95                            #distr_struct::new()
96                        }
97                    }
98                }
99            };
100
101            quote! {
102                #vis trait #distr_trait: ::strander::rand::distr::Distribution<#struct_name> {
103                    #distr_trait_method_defs
104                }
105
106                #vis struct #distr_struct <#distr_generic_defaults> {
107                    #distr_struct_fields
108                }
109
110                impl<#distr_generics> ::strander::rand::distr::Distribution<#struct_name> for #distr_struct <#distr_generics>
111                    #distr_where_clause
112                {
113                    fn sample<R: ::strander::rand::Rng + ?Sized>(&self, rng: &mut R) -> #struct_name {
114                        use ::strander::rand::distr::Distribution;
115                        #struct_name {
116                            #distr_field_samplers
117                        }
118                    }
119                }
120
121                impl<#distr_generics> #distr_trait for #distr_struct <#distr_generics>
122                    #distr_where_clause
123                {
124                    #distr_trait_method_impls
125                }
126
127                impl #distr_struct {
128                    pub fn new() -> impl #distr_trait {
129                        #distr_struct {
130                            #distr_field_constructors
131                        }
132                    }
133                }
134
135                #strand_impl
136            }
137        }
138        _ => unimplemented!(),
139    }
140    .into()
141}
142
143#[proc_macro_attribute]
144pub fn strand_remote(_args: TokenStream, item: TokenStream) -> TokenStream {
145    derive_strand_impl(item, false)
146}
147
#[cfg(test)]
mod tests {
    // No unit tests yet. NOTE(review): the macros take `proc_macro::TokenStream`, whose API
    // panics outside a macro invocation. The generated code is probably best exercised from an
    // integration-test crate that applies `#[derive(Strand)]`.
}