strander_derive/
lib.rs

1use proc_macro::TokenStream;
2use syn::{Data, Expr, Lit, Meta, MetaNameValue};
3
4use quote::{ToTokens, format_ident, quote};
5
6// TODO: field attrs to control default distributions
7// TODO: struct attrs to control generated trait/struct names
8// TODO: visibility
9// TODO: input generics
10// TODO: enums?
11// TODO: tuple structs?
12
13#[proc_macro_derive(Strand, attributes(strand))]
14pub fn derive_strand(item: TokenStream) -> TokenStream {
15    derive_strand_impl(item, true)
16}
17
18fn derive_strand_impl(item: TokenStream, impl_strand: bool) -> TokenStream {
19    let input = syn::parse_macro_input!(item as syn::DeriveInput);
20
21    let struct_name = &input.ident;
22    let vis = &input.vis;
23    let distr_trait = format_ident!("{}Distribution", struct_name);
24    let distr_struct = format_ident!("{}Distr", struct_name);
25
26    match &input.data {
27        Data::Struct(syn::DataStruct { fields, .. }) => {
28            let mut distr_generics = quote! {};
29            let mut distr_generic_defaults = quote! {};
30            let mut distr_where_clause = quote! {
31                where
32            };
33            let mut distr_trait_method_defs = quote! {};
34            let mut distr_trait_method_impls = quote! {};
35            let mut distr_struct_fields = quote! {};
36            let mut distr_field_constructors = quote! {};
37            let mut distr_field_samplers = quote! {};
38
39            for (i, field) in fields.iter().enumerate() {
40                let field_name = field.ident.as_ref().unwrap();
41                let field_type = &field.ty;
42                let field_param = format_ident!("T_{}", i);
43                let method_name = format_ident!("with_{}", field_name);
44                let field_trait = quote! { ::strander::rand::distr::Distribution<#field_type> };
45                let method_signature = quote!{ fn #method_name(self, #field_name: impl #field_trait) -> impl #distr_trait };
46
47                let mut other_fields = quote!{};
48
49                for other in fields.iter().filter(|f| f.ident != field.ident) {
50                    let other_name = other.ident.as_ref().unwrap();
51                    other_fields.extend(quote! { #other_name : self.#other_name, });
52                }
53
54                distr_generics.extend(quote! { #field_param , });
55                distr_generic_defaults.extend(quote! { #field_param = (), });
56                distr_struct_fields.extend(quote! { #field_name: #field_param, });
57                distr_where_clause.extend(quote! { #field_param : #field_trait , });
58
59                distr_trait_method_defs.extend(quote! { #method_signature ; });
60                distr_trait_method_impls.extend(quote! { #method_signature {
61                    #distr_struct {
62                        #field_name,
63                        #other_fields
64                    }
65                }});
66                let mut constructor = quote! { <#field_type as ::strander::Strand>::strand() };
67                for attr in field.attrs.iter().map(|a| &a.meta) {
68                    match attr {
69                        Meta::NameValue(MetaNameValue{ value, .. }) => {
70                            if let Expr::Lit(expr) = &value {
71                                if let Lit::Str(lit_str) = &expr.lit {
72                                    constructor = lit_str.parse::<Expr>().expect("a valid rust expression").into_token_stream();
73                                }
74                            }
75                        },
76                        other => panic!("unsupported attribute: {:?}", other),
77                    }
78                }
79                distr_field_constructors.extend(quote! { #field_name: #constructor, });
80                distr_field_samplers.extend(quote! { #field_name: <#field_param as #field_trait>::sample(&self.#field_name, rng), })
81
82            }
83
84            let mut strand_impl = quote!{};
85            if impl_strand {
86                strand_impl = quote! {
87                    #[allow(refining_impl_trait)]
88                    impl ::strander::Strand for #struct_name {
89                        fn strand() -> impl #distr_trait {
90                            #distr_struct::new()
91                        }
92                    }
93                }
94            };
95
96            quote! {
97                #vis trait #distr_trait: ::strander::rand::distr::Distribution<#struct_name> {
98                    #distr_trait_method_defs
99                }
100
101                #vis struct #distr_struct <#distr_generic_defaults> {
102                    #distr_struct_fields
103                }
104
105                impl<#distr_generics> ::strander::rand::distr::Distribution<#struct_name> for #distr_struct <#distr_generics>
106                    #distr_where_clause
107                {
108                    fn sample<R: ::strander::rand::Rng + ?Sized>(&self, rng: &mut R) -> #struct_name {
109                        use ::strander::rand::distr::Distribution;
110                        #struct_name {
111                            #distr_field_samplers
112                        }
113                    }
114                }
115
116                impl<#distr_generics> #distr_trait for #distr_struct <#distr_generics>
117                    #distr_where_clause
118                {
119                    #distr_trait_method_impls
120                }
121
122                impl #distr_struct {
123                    pub fn new() -> impl #distr_trait {
124                        #distr_struct {
125                            #distr_field_constructors
126                        }
127                    }
128                }
129
130                #strand_impl
131            }
132        }
133        _ => unimplemented!(),
134    }
135    .into()
136}
137
138#[proc_macro_attribute]
139pub fn strand_remote(_args: TokenStream, item: TokenStream) -> TokenStream {
140    derive_strand_impl(item, false)
141}
142
#[cfg(test)]
mod tests {
    // Placeholder: no unit tests are defined yet.
}