Skip to main content

zyx_derive/
lib.rs

1// Copyright (C) 2025 zk4x
2// SPDX-License-Identifier: LGPL-3.0-only
3
4//! # zyx-derive
5//!
6//! This crate contains procedural macros for zyx.
7//!
8//! The `Module` derive macro automatically implements `IntoIterator<Item = &Tensor>`
9//! for your module, so that you can use it in backpropagation and save it to disk.
10//! ```rust
11//! use zyx::Tensor;
12//! use zyx_derive::Module;
13//!
14//! #[derive(Module)]
15//! struct MyNet {
16//!     b: Tensor,
17//!     w: Tensor,
18//! }
19//!
20//! impl MyNet {
21//!     fn forward(&self, x: &Tensor) -> Tensor {
22//!         x.dot(&self.w).unwrap() + &self.b
23//!     }
24//! }
25//! ```
26//!
27//! For README, quick tutorial and source code, please visit `<https://www.github.com/zk4x/zyx>`.
28//!
29//! For more details, there is a [book](https://www.github.com/zk4x/zyx/tree/main/zyx-book).
30#![forbid(unsafe_code)]
31#![forbid(rustdoc::broken_intra_doc_links)]
32#![forbid(rustdoc::private_intra_doc_links)]
33#![forbid(missing_docs)]
34#![forbid(rustdoc::missing_crate_level_docs)]
35//#![forbid(rustdoc::missing_doc_code_examples)]
36#![forbid(rustdoc::private_doc_tests)]
37#![forbid(rustdoc::invalid_codeblock_attributes)]
38#![forbid(rustdoc::invalid_html_tags)]
39#![forbid(rustdoc::invalid_rust_codeblocks)]
40#![forbid(rustdoc::bare_urls)]
41#![forbid(rustdoc::unescaped_backticks)]
42#![forbid(rustdoc::redundant_explicit_links)]
43
44use proc_macro::TokenStream;
45use quote::quote;
46use syn::{parse_macro_input, Data, DataStruct, DeriveInput};
47
48/// # Procedural macro Module
49///
50/// Implements FromIterator<Item = (String, Tensor)> and Module for your struct.
51///
52/// This allows saving, loading, backpropagation and updating your modules.
53#[proc_macro_derive(Module)]
54pub fn derive_module(input: TokenStream) -> TokenStream {
55    let input = parse_macro_input!(input as DeriveInput);
56    let struct_name = &input.ident;
57
58    let mut field_iterators = quote! {
59        trait __MarkerTraitRef: Sized {
60            fn __iterate_by_ref(self, res: &mut Vec<(String, &zyx::Tensor)>, label: &str) {}
61        }
62
63        struct __MarkerStructRef<T>(T);
64
65        impl<'a, T: zyx::Module> __MarkerStructRef<&'a T> {
66            fn __iterate_by_ref(self, res: &mut Vec<(String, &'a zyx::Tensor)>, label: &str) {
67                res.extend(self.0.iter_tensors().map(|(k, t)|  (format!("{label}.{k}"), t)));
68            }
69        }
70
71        impl<'a, T> __MarkerTraitRef for __MarkerStructRef<&'a T>{}
72
73        let mut res = Vec::<(String, &zyx::Tensor)>::new();
74    };
75
76    if let Data::Struct(DataStruct { fields, .. }) = &input.data {
77        for field in fields.iter() {
78            let field_name = match &field.ident {
79                Some(ident) => ident,
80                None => panic!("Unnamed fields are not supported"),
81            };
82            let field_name_str = field_name.to_string();
83
84            let field_ty: &syn::Type = &field.ty;
85
86            use std::string::ToString;
87            if quote! { #field_ty }.to_string() == "Tensor" {
88                field_iterators = quote! {
89                    #field_iterators
90                    res.push((#field_name_str.to_string(), &self.#field_name));
91                }
92            } else if quote! { #field_ty }.to_string() == "Option < Tensor >" {
93                field_iterators = quote! {
94                    #field_iterators
95                    if let Some(tensor) = &self.#field_name {
96                        res.push((#field_name_str.to_string(), tensor));
97                    }
98                }
99            } else {
100                field_iterators = quote! {
101                    #field_iterators
102                    __MarkerStructRef::<&#field_ty>::__iterate_by_ref(__MarkerStructRef(&self.#field_name), &mut res, #field_name_str);
103                };
104            }
105        }
106    }
107
108    let mut mut_field_iterators = quote! {
109        trait __MarkerTraitRef: Sized {
110            fn __iterate_by_ref(mut self, res: &mut Vec<(String, &mut zyx::Tensor)>, label: &str) {}
111        }
112
113        struct __MarkerStructRef<T>(T);
114
115        impl<'a, T: zyx::Module> __MarkerStructRef<&'a mut T> {
116            fn __iterate_by_ref(mut self, res: &mut Vec<(String, &'a mut zyx::Tensor)>, label: &str) {
117                res.extend(self.0.iter_tensors_mut().map(|(k, t)|  (format!("{label}.{k}"), t)));
118            }
119        }
120
121        impl<'a, T> __MarkerTraitRef for __MarkerStructRef<&'a mut T>{}
122
123        let mut res = Vec::<(String, &mut zyx::Tensor)>::new();
124    };
125
126    if let Data::Struct(DataStruct { fields, .. }) = &input.data {
127        for field in fields.iter() {
128            let field_name = match &field.ident {
129                Some(ident) => ident,
130                None => panic!("Unnamed fields are not supported"),
131            };
132            let field_name_str = field_name.to_string();
133
134            let field_ty: &syn::Type = &field.ty;
135
136            use std::string::ToString;
137            if quote! { #field_ty }.to_string() == "Tensor" {
138                mut_field_iterators = quote! {
139                    #mut_field_iterators
140                    res.push((#field_name_str.to_string(), &mut self.#field_name));
141                }
142            } else if quote! { #field_ty }.to_string() == "Option < Tensor >" {
143                mut_field_iterators = quote! {
144                    #mut_field_iterators
145                    if let Some(tensor) = &mut self.#field_name {
146                        res.push((#field_name_str.to_string(), tensor));
147                    }
148                }
149            } else {
150                mut_field_iterators = quote! {
151                    #mut_field_iterators
152                    __MarkerStructRef::<&mut #field_ty>::__iterate_by_ref(__MarkerStructRef(&mut self.#field_name), &mut res, #field_name_str);
153                };
154            }
155        }
156    }
157
158    let expanded = quote! {
159        impl zyx::Module for #struct_name {
160            fn iter<'a>(&'a self) -> impl Iterator<Item = &'a zyx::Tensor> {
161                self.into_iter()
162            }
163
164            fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut zyx::Tensor> {
165                self.into_iter()
166            }
167
168            fn iter_tensors<'a>(&'a self) -> impl Iterator<Item = (String, &'a zyx::Tensor)> {
169                #field_iterators
170                res.into_iter()
171            }
172
173            fn iter_tensors_mut<'a>(&'a mut self) -> impl Iterator<Item = (String, &'a mut zyx::Tensor)> {
174                #mut_field_iterators
175                res.into_iter()
176            }
177        }
178    };
179
180    // Implementation of IntoIterator<Item = &Tensor>
181    let mut field_iterators = quote! {
182        trait __MarkerTraitRef<'a> {
183            fn __iterate_by_ref(&self, res: &mut Vec<&'a zyx::Tensor>) {}
184        }
185
186        struct __MarkerStructRef<T: Copy>(T);
187
188        impl<'a, T: IntoIterator<Item = &'a zyx::Tensor> + Copy> __MarkerStructRef<T> {
189            fn __iterate_by_ref(&self, res: &mut Vec<&'a zyx::Tensor>) {
190                res.extend(self.0.into_iter());
191            }
192        }
193
194        impl<'a, T: Copy> __MarkerTraitRef<'a> for __MarkerStructRef<T>{}
195
196        let mut res = Vec::<&zyx::Tensor>::new();
197    };
198
199    if let Data::Struct(DataStruct { fields, .. }) = &input.data {
200        for field in fields.iter() {
201            let field_name = match &field.ident {
202                Some(ident) => ident,
203                None => panic!("Unnamed fields are not supported"),
204            };
205            let field_ty: &syn::Type = &field.ty;
206            use std::string::ToString;
207            if quote! { #field_ty }.to_string() == "Tensor" {
208                field_iterators = quote! {
209                    #field_iterators
210                    res.push(&self.#field_name);
211                }
212            } else {
213                field_iterators = quote! {
214                    #field_iterators
215                    __MarkerStructRef::<&#field_ty>::__iterate_by_ref(&__MarkerStructRef(&self.#field_name), &mut res);
216                };
217            }
218        }
219    }
220
221    let expanded = quote! {
222        #expanded
223
224        impl<'a> IntoIterator for &'a #struct_name {
225            type Item = &'a zyx::Tensor;
226            type IntoIter = std::vec::IntoIter<&'a zyx::Tensor>;
227
228            fn into_iter(self) -> Self::IntoIter {
229                #field_iterators
230                res.into_iter()
231            }
232        }
233    };
234
235    // Implementation of IntoIterator<Item = &mut Tensor>
236    let mut field_iterators = quote! {
237        trait MarkerTraitMut<'a>: Sized {
238            fn iterate_by_mut(mut self, res: &mut Vec<&'a mut zyx::Tensor>) {}
239        }
240
241        struct MarkerStructMut<T>(T);
242
243        impl<'a, T: IntoIterator<Item = &'a mut zyx::Tensor>> MarkerStructMut<T> {
244            fn iterate_by_mut(mut self, res: &mut Vec<&'a mut zyx::Tensor>) {
245                res.extend(self.0.into_iter());
246            }
247        }
248
249        impl<'a, T> MarkerTraitMut<'a> for MarkerStructMut<T>{}
250
251        let mut res = Vec::<&mut zyx::Tensor>::new();
252    };
253
254    if let Data::Struct(DataStruct { fields, .. }) = &input.data {
255        for field in fields.iter() {
256            let field_name = match &field.ident {
257                Some(ident) => ident,
258                None => panic!("Unnamed fields are not supported"),
259            };
260            let field_ty: &syn::Type = &field.ty;
261            use std::string::ToString;
262            if quote! { #field_ty }.to_string() == "Tensor" {
263                field_iterators = quote! {
264                    #field_iterators
265                    res.push(&mut self.#field_name);
266                }
267            } else {
268                field_iterators = quote! {
269                    #field_iterators
270                    MarkerStructMut::<&mut #field_ty>::iterate_by_mut(MarkerStructMut(&mut self.#field_name), &mut res);
271                };
272            }
273        }
274    }
275
276    let expanded = quote! {
277        #expanded
278
279        impl<'a> IntoIterator for &'a mut #struct_name {
280            type Item = &'a mut zyx::Tensor;
281            type IntoIter = std::vec::IntoIter<&'a mut zyx::Tensor>;
282
283            fn into_iter(self) -> Self::IntoIter {
284                #field_iterators
285                res.into_iter()
286            }
287        }
288    };
289
290    TokenStream::from(expanded)
291}