smyl_macros/
lib.rs

use proc_macro::{self, TokenStream};
use proc_macro2::{Ident, Literal, Span, TokenTree};
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{parse_macro_input, LitInt};

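/// A parsed `mat!` invocation: matrix elements in row-major order, with
/// elements separated by whitespace and rows separated by commas.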
struct Matrix {
    height: usize,
    width:  usize,
    data:   Vec<Literal>,
}

impl Parse for Matrix {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut data = Vec::new();
        let mut height = 0;
        let mut width = None;
        let mut last_span = input.span();

        let mut current_row_width = 0;

        while !input.is_empty() {
            let token = input.parse::<TokenTree>()?;
            last_span = token.span();
            let err = Err(syn::Error::new(last_span, "Expected a literal or a comma"));

            match token {
                // A comma closes the current row.
                TokenTree::Punct(punct) => {
                    if punct.as_char() != ',' {
                        return err;
                    }

                    height += 1;

                    match width {
                        Some(width) => {
                            if width != current_row_width {
                                return Err(syn::Error::new(
                                    last_span,
                                    "All rows must have the same number of elements",
                                ));
                            }
                        },
                        None => width = Some(current_row_width),
                    }

                    current_row_width = 0;
                },
                TokenTree::Literal(literal) => {
                    current_row_width += 1;
                    data.push(literal);
                },
                _ => return err,
            }
        }

        // The final row is not followed by a comma, so close it here.
        height += 1;

        match width {
            Some(width) => {
                if width != current_row_width {
                    return Err(syn::Error::new(
                        last_span,
                        "All rows must have the same number of elements",
                    ));
                }
            },
            None => width = Some(current_row_width),
        }

        Ok(Self {
            height,
            width: width.unwrap_or_default(),
            data,
        })
    }
}

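/// Builds a `::smyl::maths::matrix::Matrix` from a compact literal syntax:
/// elements are separated by whitespace and rows by commas.
///
/// For example, `mat![1.0 2.0, 3.0 4.0]` expands to
/// `Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0])`.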
#[proc_macro]
pub fn mat(input: TokenStream) -> TokenStream {
    let Matrix {
        width,
        height,
        data,
    } = parse_macro_input!(input as Matrix);

    let output = quote! {
        ::smyl::maths::matrix::Matrix::new(#height, #width, vec![#(#data),*])
    };

    output.into()
}

struct Sizes(Vec<usize>);

impl Parse for Sizes {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // A comma-separated list of integer literals, e.g. `2, 3, 1`.
        let punctuated = Punctuated::<LitInt, Comma>::parse_terminated(input)?;
        let mut sizes = Vec::new();

        for lit in punctuated {
            sizes.push(lit.base10_parse()?);
        }

        Ok(Self(sizes))
    }
}

/// The variables generated inside the `ann!` expansion, used to build
/// per-layer identifiers.
#[derive(Copy, Clone)]
enum Variable {
    SynapseLayer,
    OutputSignal,
    InputSignal,
    OutputGradient,
    LayerGradient,
}

impl Variable {
    /// Returns the identifier for this variable at layer index `i`,
    /// e.g. `Variable::SynapseLayer.new(0)` yields `synapse_layer_0`.
    /// There is a single input signal, so `InputSignal` ignores the index.
    pub fn new(self, i: usize) -> Ident {
        let name = match self {
            Variable::SynapseLayer => format!("synapse_layer_{i}"),
            Variable::OutputSignal => format!("output_signal_{i}"),
            Variable::InputSignal => "input_signal".to_string(),
            Variable::OutputGradient => format!("output_gradient_{i}"),
            Variable::LayerGradient => format!("layer_gradient_{i}"),
        };

        Ident::new(&name, Span::call_site())
    }
}

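/// Generates an anonymous feed-forward network type from a list of layer
/// sizes and evaluates to a freshly initialised instance of it.
///
/// `ann!(2, 3, 1)` builds a struct with one synapse layer per consecutive
/// pair of sizes (here 2x3 and 3x1), plus `forward`, `backward`, and
/// `apply_chunk_gradient` methods.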
#[proc_macro]
pub fn ann(input: TokenStream) -> TokenStream {
    let sizes = parse_macro_input!(input as Sizes).0;

    // A network needs at least an input size and one layer size; bail out
    // early rather than underflowing `sizes.len() - 1` below.
    if sizes.len() < 2 {
        return syn::Error::new(Span::call_site(), "ann! expects at least two layer sizes")
            .to_compile_error()
            .into();
    }

    let layer_count = sizes.len() - 1;
    let input_size = sizes[0];

    let input_signal = Variable::InputSignal.new(0);
    let mut synapse_layers = Vec::with_capacity(layer_count);
    let mut output_signals = Vec::with_capacity(layer_count);
    let mut output_gradients = Vec::with_capacity(layer_count);
    let mut layer_gradients = Vec::with_capacity(layer_count);

    let mut vars_definitions = quote! {
        #input_signal: Matrix<f64>,
    };
    let mut vars_initiations = quote! {
        #input_signal: Matrix::zero(1, #input_size),
    };

    for i in 0..layer_count {
        let previous_size = sizes[i];
        let size = sizes[i + 1];

        let synapse_layer = Variable::SynapseLayer.new(i);
        let output_signal = Variable::OutputSignal.new(i);
        let output_gradient = Variable::OutputGradient.new(i);
        let layer_gradient = Variable::LayerGradient.new(i);

        synapse_layers.push(synapse_layer.clone());
        output_signals.push(output_signal.clone());
        output_gradients.push(output_gradient.clone());
        layer_gradients.push(layer_gradient.clone());

        vars_initiations.extend(quote! {
            #synapse_layer: SynapseLayer::random(#previous_size, #size, &mut rng),
            #output_signal: Matrix::zero(1, #size),
            #output_gradient: Matrix::zero(1, #size),
            #layer_gradient: LayerGradient::zero(#previous_size, #size),
        });

        vars_definitions.extend(quote! {
            #synapse_layer: SynapseLayer<f64, Sigmoid>,
            #output_signal: Matrix<f64>,
            #output_gradient: Matrix<f64>,
            #layer_gradient: LayerGradient<f64>,
        });
    }

    // Each layer reads the previous layer's output (or the network input).
    let mut forward_steps = proc_macro2::TokenStream::new();
    for i in 0..layer_count {
        let input = if i > 0 {
            &output_signals[i - 1]
        } else {
            &input_signal
        };
        let synapses = &synapse_layers[i];
        let output = &output_signals[i];

        forward_steps.extend(quote! {
            self.#output = self.#synapses.forward(self.#input.clone());
        });
    }
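
    // For `ann!(2, 3, 1)` the loop above expands to, roughly:
    //
    //     self.output_signal_0 = self.synapse_layer_0.forward(self.input_signal.clone());
    //     self.output_signal_1 = self.synapse_layer_1.forward(self.output_signal_0.clone());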

    // Walk the layers in reverse, propagating each layer's output gradient
    // back to the previous layer; the first layer has no predecessor, so its
    // propagated gradient is discarded with `_`.
    let mut backward_steps = proc_macro2::TokenStream::new();
    for i in (0..layer_count).rev() {
        let input = if i > 0 {
            &output_signals[i - 1]
        } else {
            &input_signal
        };
        let synapses = &synapse_layers[i];
        let output = &output_signals[i];
        let output_gradient = &output_gradients[i];
        let layer_gradient = &layer_gradients[i];
        let gradient_destination = if i > 0 {
            let last_output_gradient = &output_gradients[i - 1];
            quote! {self.#last_output_gradient}
        } else {
            quote! {_}
        };

        backward_steps.extend(quote! {
            #gradient_destination = self.#synapses.backward(
                &self.#input,
                self.#output.clone(),
                self.#output_gradient.clone(),
                &mut self.#layer_gradient
            );
        });
    }
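
    // Again for `ann!(2, 3, 1)`, the loop above emits, roughly:
    //
    //     self.output_gradient_0 = self.synapse_layer_1.backward(
    //         &self.output_signal_0,
    //         self.output_signal_1.clone(),
    //         self.output_gradient_1.clone(),
    //         &mut self.layer_gradient_1
    //     );
    //     _ = self.synapse_layer_0.backward(
    //         &self.input_signal,
    //         self.output_signal_0.clone(),
    //         self.output_gradient_0.clone(),
    //         &mut self.layer_gradient_0
    //     );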

    let mut apply_chunk_gradient_steps = proc_macro2::TokenStream::new();
    for i in 0..layer_count {
        let synapses = &synapse_layers[i];
        let layer_gradient = &layer_gradients[i];

        apply_chunk_gradient_steps.extend(quote! {
            self.#synapses.apply_gradient(&self.#layer_gradient, learning_rate);
            self.#layer_gradient.empty();
        });
    }

    #[allow(unused_mut)]
    let mut other_implementations = quote! {};

    // `cfg_if` evaluates the feature when this macro crate is compiled, so
    // the `Serialize` impl is only generated if `smyl_macros` itself was
    // built with the "serde" feature.
    cfg_if::cfg_if! {
        if #[cfg(feature = "serde")] {
            let mut serialize_fields = quote! {};

            for synapse_layer in synapse_layers {
                serialize_fields.extend(quote! {
                    ann.serialize_element(&self.#synapse_layer)?;
                });
            }

            other_implementations.extend(quote! {
                extern crate serde;

                impl serde::Serialize for ANN {
                    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                    where
                        S: serde::Serializer,
                    {
                        use serde::ser::SerializeSeq;

                        // The network serializes as a sequence of its synapse layers.
                        let mut ann = serializer.serialize_seq(Some(#layer_count))?;
                        #serialize_fields

                        ann.end()
                    }
                }
            });
        }
    }
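
    // With the feature enabled, the generated network can then be fed to any
    // serde serializer. A hypothetical sketch (serde_json is not a dependency
    // of this crate):
    //
    //     let json = serde_json::to_string(&net).unwrap();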

    let last_output_signal = &output_signals[layer_count - 1];
    let last_output_gradient = &output_gradients[layer_count - 1];
    let output = quote! {
        {
            extern crate smyl;
            extern crate rand;

            use smyl::prelude::*;
            use rand::{thread_rng, Rng};

            #[derive(Debug, Clone)]
            struct ANN {
                #vars_definitions
            }

            impl ANN {
                pub fn forward(
                    &mut self,
                    input: Matrix<f64>
                ) -> Matrix<f64> {
                    self.#input_signal = input;

                    #forward_steps

                    self.#last_output_signal.clone()
                }

                pub fn backward(&mut self, expected: Matrix<f64>) {
                    self.#last_output_gradient = self.#last_output_signal.clone() - expected;

                    #backward_steps
                }

                pub fn apply_chunk_gradient(&mut self, learning_rate: f64) {
                    #apply_chunk_gradient_steps
                }
            }

            #other_implementations

            let mut rng = thread_rng();

            ANN {
                #vars_initiations
            }
        }
    };

    output.into()
}
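
// A minimal usage sketch of the two macros together (assumes the caller
// drives the training loop; untested):
//
//     let mut net = ann!(2, 3, 1);
//     let prediction = net.forward(mat![0.0 1.0]);
//     net.backward(mat![1.0]);
//     net.apply_chunk_gradient(0.1);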