1use proc_macro::TokenStream;
2use syn::{Data, Expr, Lit, Meta, MetaNameValue};
3
4use quote::{ToTokens, format_ident, quote};
5
6#[proc_macro_derive(Strand, attributes(strand))]
14pub fn derive_strand(item: TokenStream) -> TokenStream {
15 derive_strand_impl(item, true)
16}
17
18fn derive_strand_impl(item: TokenStream, impl_strand: bool) -> TokenStream {
19 let input = syn::parse_macro_input!(item as syn::DeriveInput);
20
21 let struct_name = &input.ident;
22 let vis = &input.vis;
23 let distr_trait = format_ident!("{}Distribution", struct_name);
24 let distr_struct = format_ident!("{}Distr", struct_name);
25
26 match &input.data {
27 Data::Struct(syn::DataStruct { fields, .. }) => {
28 let mut distr_generics = quote! {};
29 let mut distr_generic_defaults = quote! {};
30 let mut distr_where_clause = quote! {
31 where
32 };
33 let mut distr_trait_method_defs = quote! {};
34 let mut distr_trait_method_impls = quote! {};
35 let mut distr_struct_fields = quote! {};
36 let mut distr_field_constructors = quote! {};
37 let mut distr_field_samplers = quote! {};
38
39 for (i, field) in fields.iter().enumerate() {
40 let field_name = field.ident.as_ref().unwrap();
41 let field_type = &field.ty;
42 let field_param = format_ident!("T_{}", i);
43 let method_name = format_ident!("with_{}", field_name);
44 let field_trait = quote! { ::strander::rand::distr::Distribution<#field_type> };
45 let method_signature = quote!{ fn #method_name(self, #field_name: impl #field_trait) -> impl #distr_trait };
46
47 let mut other_fields = quote!{};
48
49 for other in fields.iter().filter(|f| f.ident != field.ident) {
50 let other_name = other.ident.as_ref().unwrap();
51 other_fields.extend(quote! { #other_name : self.#other_name, });
52 }
53
54 distr_generics.extend(quote! { #field_param , });
55 distr_generic_defaults.extend(quote! { #field_param = (), });
56 distr_struct_fields.extend(quote! { #field_name: #field_param, });
57 distr_where_clause.extend(quote! { #field_param : #field_trait , });
58
59 distr_trait_method_defs.extend(quote! { #method_signature ; });
60 distr_trait_method_impls.extend(quote! { #method_signature {
61 #distr_struct {
62 #field_name,
63 #other_fields
64 }
65 }});
66 let mut constructor = quote! { <#field_type as ::strander::Strand>::strand() };
67 for attr in field.attrs.iter().map(|a| &a.meta) {
68 match attr {
69 Meta::NameValue(MetaNameValue{ value, .. }) => {
70 if let Expr::Lit(expr) = &value {
71 if let Lit::Str(lit_str) = &expr.lit {
72 constructor = lit_str.parse::<Expr>().expect("a valid rust expression").into_token_stream();
73 }
74 }
75 },
76 other => panic!("unsupported attribute: {:?}", other),
77 }
78 }
79 distr_field_constructors.extend(quote! { #field_name: #constructor, });
80 distr_field_samplers.extend(quote! { #field_name: <#field_param as #field_trait>::sample(&self.#field_name, rng), })
81
82 }
83
84 let mut strand_impl = quote!{};
85 if impl_strand {
86 strand_impl = quote! {
87 #[allow(refining_impl_trait)]
88 impl ::strander::Strand for #struct_name {
89 fn strand() -> impl #distr_trait {
90 #distr_struct::new()
91 }
92 }
93 }
94 };
95
96 quote! {
97 #vis trait #distr_trait: ::strander::rand::distr::Distribution<#struct_name> {
98 #distr_trait_method_defs
99 }
100
101 #vis struct #distr_struct <#distr_generic_defaults> {
102 #distr_struct_fields
103 }
104
105 impl<#distr_generics> ::strander::rand::distr::Distribution<#struct_name> for #distr_struct <#distr_generics>
106 #distr_where_clause
107 {
108 fn sample<R: ::strander::rand::Rng + ?Sized>(&self, rng: &mut R) -> #struct_name {
109 use ::strander::rand::distr::Distribution;
110 #struct_name {
111 #distr_field_samplers
112 }
113 }
114 }
115
116 impl<#distr_generics> #distr_trait for #distr_struct <#distr_generics>
117 #distr_where_clause
118 {
119 #distr_trait_method_impls
120 }
121
122 impl #distr_struct {
123 pub fn new() -> impl #distr_trait {
124 #distr_struct {
125 #distr_field_constructors
126 }
127 }
128 }
129
130 #strand_impl
131 }
132 }
133 _ => unimplemented!(),
134 }
135 .into()
136}
137
138#[proc_macro_attribute]
139pub fn strand_remote(_args: TokenStream, item: TokenStream) -> TokenStream {
140 derive_strand_impl(item, false)
141}
142
#[cfg(test)]
mod tests {
    // No unit tests yet.
    // NOTE(review): `proc_macro` APIs only work inside a macro expansion, so
    // expansion tests likely belong in an integration-test crate that applies
    // `#[derive(Strand)]` — suggestion.
}