1use proc_macro::TokenStream;
2use syn::{Data, Expr, Lit, Meta, MetaNameValue};
3
4use quote::{ToTokens, format_ident, quote};
5
6#[proc_macro_derive(Strand, attributes(strand))]
12pub fn derive_strand(item: TokenStream) -> TokenStream {
13 derive_strand_impl(item, true)
14}
15
16fn derive_strand_impl(item: TokenStream, impl_strand: bool) -> TokenStream {
17 let input = syn::parse_macro_input!(item as syn::DeriveInput);
18
19 let struct_name = &input.ident;
20 let vis = &input.vis;
21 let distr_trait = format_ident!("{}Distribution", struct_name);
22 let distr_struct = format_ident!("{}Distr", struct_name);
23
24 match &input.data {
25 Data::Struct(syn::DataStruct { fields, .. }) => {
26 let mut distr_generics = quote! {};
27 let mut distr_generic_defaults = quote! {};
28 let mut distr_where_clause = quote! {
29 where
30 };
31 let mut distr_trait_method_defs = quote! {};
32 let mut distr_trait_method_impls = quote! {};
33 let mut distr_struct_fields = quote! {};
34 let mut distr_field_constructors = quote! {};
35 let mut distr_field_samplers = quote! {};
36
37 for (i, field) in fields.iter().enumerate() {
38 let field_name = field.ident.as_ref().unwrap();
39 let field_type = &field.ty;
40 let field_param = format_ident!("T_{}", i);
41 let method_name = format_ident!("with_{}", field_name);
42 let field_trait = quote! { ::strander::rand::distr::Distribution<#field_type> };
43 let method_signature = quote!{ fn #method_name(self, #field_name: impl #field_trait) -> impl #distr_trait };
44
45 let mut other_fields = quote!{};
46
47 for other in fields.iter().filter(|f| f.ident != field.ident) {
48 let other_name = other.ident.as_ref().unwrap();
49 other_fields.extend(quote! { #other_name : self.#other_name, });
50 }
51
52 distr_generics.extend(quote! { #field_param , });
53 distr_generic_defaults.extend(quote! { #field_param = (), });
54 distr_struct_fields.extend(quote! { #field_name: #field_param, });
55 distr_where_clause.extend(quote! { #field_param : #field_trait , });
56
57 distr_trait_method_defs.extend(quote! { #method_signature ; });
58 distr_trait_method_impls.extend(quote! { #method_signature {
59 #distr_struct {
60 #field_name,
61 #other_fields
62 }
63 }});
64 let mut constructor = quote! { <#field_type as ::strander::Strand>::strand() };
65 for attr in field.attrs.iter().map(|a| &a.meta) {
66 match attr {
67 Meta::NameValue(MetaNameValue{ path, value, .. }) => {
68 if path
69 .get_ident()
70 .map(|i| i.to_string())
71 .as_ref()
72 .map(|s| s.as_str()) != Some("strand") {
73 continue;
74 }
75 if let Expr::Lit(expr) = &value {
76 if let Lit::Str(lit_str) = &expr.lit {
77 constructor = lit_str.parse::<Expr>().expect("a valid rust expression").into_token_stream();
78 }
79 }
80 },
81 _ => continue,
82 }
83 }
84 distr_field_constructors.extend(quote! { #field_name: #constructor, });
85 distr_field_samplers.extend(quote! { #field_name: <#field_param as #field_trait>::sample(&self.#field_name, rng), })
86
87 }
88
89 let mut strand_impl = quote!{};
90 if impl_strand {
91 strand_impl = quote! {
92 #[allow(refining_impl_trait)]
93 impl ::strander::Strand for #struct_name {
94 fn strand() -> impl #distr_trait {
95 #distr_struct::new()
96 }
97 }
98 }
99 };
100
101 quote! {
102 #vis trait #distr_trait: ::strander::rand::distr::Distribution<#struct_name> {
103 #distr_trait_method_defs
104 }
105
106 #vis struct #distr_struct <#distr_generic_defaults> {
107 #distr_struct_fields
108 }
109
110 impl<#distr_generics> ::strander::rand::distr::Distribution<#struct_name> for #distr_struct <#distr_generics>
111 #distr_where_clause
112 {
113 fn sample<R: ::strander::rand::Rng + ?Sized>(&self, rng: &mut R) -> #struct_name {
114 use ::strander::rand::distr::Distribution;
115 #struct_name {
116 #distr_field_samplers
117 }
118 }
119 }
120
121 impl<#distr_generics> #distr_trait for #distr_struct <#distr_generics>
122 #distr_where_clause
123 {
124 #distr_trait_method_impls
125 }
126
127 impl #distr_struct {
128 pub fn new() -> impl #distr_trait {
129 #distr_struct {
130 #distr_field_constructors
131 }
132 }
133 }
134
135 #strand_impl
136 }
137 }
138 _ => unimplemented!(),
139 }
140 .into()
141}
142
143#[proc_macro_attribute]
144pub fn strand_remote(_args: TokenStream, item: TokenStream) -> TokenStream {
145 derive_strand_impl(item, false)
146}
147
// Unit tests for this crate; currently empty.
// NOTE(review): the macro entry points take `proc_macro::TokenStream`, which is presumably
// only usable during macro expansion, so expansion behavior is likely tested elsewhere
// (e.g. integration tests in a consumer crate) — TODO confirm.
#[cfg(test)]
mod tests {}