1use proc_macro::{self, TokenStream};
2use proc_macro2::{Ident, Literal, Span, TokenTree};
3use quote::quote;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{parse_macro_input, LitInt};
8
9struct Matrix {
10 height: usize,
11 width: usize,
12 data: Vec<Literal>,
13}
14
15impl Parse for Matrix {
16 fn parse(input: ParseStream) -> syn::Result<Self> {
17 let mut data = Vec::new();
18 let mut height = 0;
19 let mut width = None;
20 let mut last_span = input.span();
21
22 let mut current_row_width = 0;
23
24 while !input.is_empty() {
25 let token = input.parse::<TokenTree>()?;
26 last_span = token.span();
27 let err = Err(syn::Error::new(last_span, "Expected a literal or a comma"));
28
29 match token {
30 TokenTree::Punct(punct) => {
31 if punct.as_char() != ',' {
32 return err;
33 }
34
35 height += 1;
36
37 match width {
38 Some(width) => {
39 if width != current_row_width {
40 return Err(syn::Error::new(
41 last_span,
42 "All rows must have the same number of elements",
43 ));
44 }
45 },
46 None => width = Some(current_row_width),
47 }
48
49 current_row_width = 0;
50 },
51 TokenTree::Literal(literal) => {
52 current_row_width += 1;
53 data.push(literal);
54 },
55 _ => return err,
56 }
57 }
58
59 height += 1;
60
61 match width {
62 Some(width) => {
63 if width != current_row_width {
64 return Err(syn::Error::new(
65 last_span,
66 "All rows must have the same number of elements",
67 ));
68 }
69 },
70 None => width = Some(current_row_width),
71 }
72
73 Ok(Self {
74 height,
75 width: width.unwrap_or_default(),
76 data,
77 })
78 }
79}
80
81#[proc_macro]
82pub fn mat(input: TokenStream) -> TokenStream {
83 let Matrix {
84 width,
85 height,
86 data,
87 } = parse_macro_input!(input as Matrix);
88
89 let output = quote! {
90 ::smyl::maths::matrix::Matrix::new(#height, #width, vec![#(#data), *])
91 };
92
93 output.into()
94}
95
96struct Sizes(Vec<usize>);
97
98impl Parse for Sizes {
99 fn parse(input: ParseStream) -> syn::Result<Self> {
100 let punctuated = Punctuated::<LitInt, Comma>::parse_terminated(input)?;
101 let mut sizes = Vec::new();
102
103 for lit in punctuated {
104 sizes.push(lit.base10_parse()?);
105 }
106
107 Ok(Self(sizes))
108 }
109}
110
111#[derive(Copy, Clone)]
112enum Variable {
113 SynapseLayer,
114 OutputSignal,
115 InputSignal,
116 OutputGradient,
117 LayerGradient,
118}
119
120impl Variable {
121 pub fn new(self, i: usize) -> Ident {
122 let name = match self {
123 Variable::SynapseLayer => format!("synapse_layer_{i}"),
124 Variable::OutputSignal => format!("output_signal_{i}"),
125 Variable::InputSignal => "input_signal".to_string(),
126 Variable::OutputGradient => format!("output_gradient_{i}"),
127 Variable::LayerGradient => format!("layer_gradient_{i}"),
128 };
129
130 Ident::new(&name, Span::call_site())
131 }
132}
133
134#[proc_macro]
135pub fn ann(input: TokenStream) -> TokenStream {
136 let sizes = parse_macro_input!(input as Sizes).0;
137
138 let layer_count = sizes.len() - 1;
139 let input_size = sizes[0];
140
141 let input_signal = Variable::InputSignal.new(0);
142 let mut synapse_layers = Vec::with_capacity(layer_count);
143 let mut output_signals = Vec::with_capacity(layer_count);
144 let mut output_gradients = Vec::with_capacity(layer_count);
145 let mut layer_gradients = Vec::with_capacity(layer_count);
146
147 let mut vars_definitions = quote! {
148 #input_signal: Matrix<f64>,
149 };
150 let mut vars_initiations = quote! {
151 #input_signal: Matrix::zero(1, #input_size),
152 };
153
154 for i in 0..layer_count {
155 let previous_size = sizes[i];
156 let size = sizes[i + 1];
157
158 let synapse_layer = Variable::SynapseLayer.new(i);
159 let output_signal = Variable::OutputSignal.new(i);
160 let output_gradient = Variable::OutputGradient.new(i);
161 let layer_gradient = Variable::LayerGradient.new(i);
162
163 synapse_layers.insert(i, synapse_layer.clone());
164 output_signals.insert(i, output_signal.clone());
165 output_gradients.insert(i, output_gradient.clone());
166 layer_gradients.insert(i, layer_gradient.clone());
167
168 vars_initiations.extend(quote! {
169 #synapse_layer: SynapseLayer::random(#previous_size, #size, &mut rng),
170 #output_signal: Matrix::zero(1, #size),
171 #output_gradient: Matrix::zero(1, #size),
172 #layer_gradient: LayerGradient::zero(#previous_size, #size),
173 });
174
175 vars_definitions.extend(quote! {
176 #synapse_layer: SynapseLayer<f64, Sigmoid>,
177 #output_signal: Matrix<f64>,
178 #output_gradient: Matrix<f64>,
179 #layer_gradient: LayerGradient<f64>,
180 });
181 }
182
183 let mut forward_steps = proc_macro2::TokenStream::new();
184 for i in 0..layer_count {
185 let input = if i > 0 {
186 &output_signals[i - 1]
187 } else {
188 &input_signal
189 };
190 let synapses = &synapse_layers[i];
191 let output = &output_signals[i];
192
193 forward_steps.extend(quote! {
194 self.#output = self.#synapses.forward(self.#input.clone());
195 });
196 }
197
198 let mut backward_steps = proc_macro2::TokenStream::new();
199 for i in (0..layer_count).rev() {
200 let input = if i > 0 {
201 &output_signals[i - 1]
202 } else {
203 &input_signal
204 };
205 let synapses = &synapse_layers[i];
206 let output = &output_signals[i];
207 let output_gradient = &output_gradients[i];
208 let layer_gradient = &layer_gradients[i];
209 let gradient_destination = if i > 0 {
210 let last_output_gradient = &output_gradients[i - 1];
211 quote! {self.#last_output_gradient}
212 } else {
213 quote! {_}
214 };
215
216 backward_steps.extend(quote! {
217 #gradient_destination = self.#synapses.backward(
218 &self.#input,
219 self.#output.clone(),
220 self.#output_gradient.clone(),
221 &mut self.#layer_gradient
222 );
223 });
224 }
225
226 let mut apply_chunk_gradient_steps = proc_macro2::TokenStream::new();
227 for i in 0..layer_count {
228 let synapses = &synapse_layers[i];
229 let layer_gradient = &layer_gradients[i];
230
231 apply_chunk_gradient_steps.extend(quote! {
232 self.#synapses.apply_gradient(&self.#layer_gradient, learning_rate);
233 self.#layer_gradient.empty();
234 });
235 }
236
237 #[allow(unused_mut)]
238 let mut other_implementations = quote! {};
239
240 cfg_if::cfg_if! {
241 if #[cfg(feature="serde")] {
242 let mut serialize_fields = quote! {};
243
244 for synapse_layer in synapse_layers {
245 serialize_fields.extend(quote! {
246 ann.serialize_element(&self.#synapse_layer)?;
247 });
248 }
249
250 other_implementations.extend(quote! {
251 extern crate serde;
252
253 impl serde::Serialize for ANN {
254 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
255 where
256 S: serde::Serializer,
257 {
258 use serde::ser::SerializeSeq;
259
260 let mut ann = serializer.serialize_seq(Some(#layer_count))?;
261 #serialize_fields
262
263 ann.end()
264 }
265 }
266 });
267 }
268 }
269
270 let last_output_signal = &output_signals[layer_count - 1];
271 let last_output_gradient = &output_gradients[layer_count - 1];
272 let output = quote! {
273 {
274 extern crate smyl;
275 extern crate rand;
276
277 use smyl::prelude::*;
278 use rand::{thread_rng, Rng};
279
280 #[derive(Debug, Clone)]
281 struct ANN {
282 #vars_definitions
283 }
284
285 impl ANN {
286 pub fn forward(
287 &mut self,
288 input: Matrix<f64>
289 ) -> Matrix<f64> {
290 self.#input_signal = input;
291
292 #forward_steps
293
294 self.#last_output_signal.clone()
295 }
296
297 pub fn backward(&mut self, expected: Matrix<f64>) {
298 self.#last_output_gradient = self.#last_output_signal.clone() - expected;
299
300 #backward_steps
301 }
302
303 pub fn apply_chunk_gradient(&mut self, learning_rate: f64) {
304 #apply_chunk_gradient_steps
305 }
306
307 }
308
309 #other_implementations
310
311 let mut rng = thread_rng();
312
313 ANN {
314 #vars_initiations
315 }
316 }
317 };
318
319 output.into()
320}