1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_macro_input, parse_quote,
7 token::Comma,
8 Attribute, Data, DeriveInput, Field, Ident, Token, Type,
9};
10
11#[proc_macro_derive(Parameters, attributes(zenu, parameters))]
12pub fn zenu_derive_parameters(input: TokenStream) -> TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
14
15 let parameters_impl = impl_parameters(&input);
16
17 TokenStream::from(parameters_impl)
18}
19
20fn impl_parameters(input: &DeriveInput) -> TokenStream2 {
21 let name = &input.ident;
22 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
23
24 let fields = match &input.data {
25 Data::Struct(data) => &data.fields,
26 _ => panic!("ZenuModel only supports structs"),
27 };
28
29 let fields = fields.iter().filter(|field| !has_zenu_skip_attr(field));
30
31 let weights_code = fields.clone().map(|field| {
32 let field_name = &field.ident;
33 quote! {
34 for (name, variable) in &self.#field_name.weights() {
35 let name = format!("{}.{}", stringify!(#field_name), name);
36 params.insert(name.clone(), variable.clone());
37 }
38 }
39 });
40
41 let biases_code = fields.clone().map(|field| {
42 let field_name = &field.ident;
43 quote! {
44 for (name, variable) in &self.#field_name.biases() {
45 let name = format!("{}.{}", stringify!(#field_name), name);
46 params.insert(name.clone(), variable.clone());
47 }
48 }
49 });
50
51 let (num_type, device_type) = parse_parameters_attr(&input.attrs);
52
53 quote!(
54 impl #impl_generics ::zenu::layer::Parameters #ty_generics for #name #ty_generics #where_clause {
55 fn weights(&self) -> std::collections::HashMap<String, ::zenu::autograd::Variable<#num_type, #device_type>> {
56 let mut params = std::collections::HashMap::new();
57 #(
58 #weights_code
59 )*
60 params
61 }
62
63 fn biases(&self) -> std::collections::HashMap<String, ::zenu::autograd::Variable<#num_type, #device_type>> {
64 let mut params = std::collections::HashMap::new();
65 #(
66 #biases_code
67 )*
68 params
69 }
70 }
71 )
72}
73
74fn has_zenu_skip_attr(field: &Field) -> bool {
75 field
76 .attrs
77 .iter()
78 .any(|attr| attr.path.is_ident("zenu") && attr.tokens.to_string().contains("skip"))
79}
80
81fn parse_parameters_attr(attrs: &[Attribute]) -> (Type, Type) {
82 let mut num_type: Type = parse_quote!(f32);
83 let mut device_type: Type = parse_quote!(Cpu);
84
85 for attr in attrs {
86 if attr.path.is_ident("parameters") {
87 let args = syn::parse2::<ParametersArgs>(attr.tokens.clone())
88 .expect("Failed to parse parameters attribute");
89 if let Some(ty) = args.num {
90 num_type = ty;
91 }
92 if let Some(ty) = args.device {
93 device_type = ty;
94 }
95 }
96 }
97
98 (num_type, device_type)
99}
100
101struct ParametersArgs {
102 num: Option<Type>,
103 device: Option<Type>,
104}
105
106impl Parse for ParametersArgs {
107 fn parse(input: ParseStream) -> syn::Result<Self> {
108 let content;
109 syn::parenthesized!(content in input);
110
111 let mut num = None;
112 let mut device = None;
113
114 while !content.is_empty() {
115 let ident: Ident = content.parse()?;
116 let _: Token![=] = content.parse()?;
117 let ty: Type = content.parse()?;
118
119 if ident == "num" {
120 num = Some(ty);
121 } else if ident == "device" {
122 device = Some(ty);
123 }
129
130 if content.peek(Comma) {
131 let _: Comma = content.parse()?;
132 } else {
133 break;
134 }
135 }
136
137 Ok(ParametersArgs { num, device })
138 }
139}