zyx_derive/
lib.rs

//! # zyx-derive
//!
//! This crate contains procedural macros for zyx.
//!
//! The `Module` derive macro implements `IntoIterator<Item = &Tensor>` (and
//! `IntoIterator<Item = &mut Tensor>`) for references to your module, so that you can use it
//! in backpropagation, update its parameters and save it to disk.
//! ```rust
//! use zyx_core::backend::Backend;
//! use zyx_core::tensor::Tensor;
//! use zyx_derive::Module;
//!
//! #[derive(Module)]
//! struct MyNet<B: Backend> {
//!     b: Tensor<B>,
//!     w: Tensor<B>,
//! }
//!
//! impl<B: Backend> MyNet<B> {
//!     fn forward(&self, x: &Tensor<B>) -> Tensor<B> {
//!         x.dot(&self.w) + &self.b
//!     }
//! }
//! ```
//!
//! For the README, a quick tutorial and the source code, please visit <https://www.github.com/zk4x/zyx>.
//!
//! For more details, there is a [book](https://www.github.com/zk4x/zyx/tree/main/zyx-book).
28#![no_std]
29#![forbid(unsafe_code)]
30#![forbid(rustdoc::broken_intra_doc_links)]
31#![forbid(rustdoc::private_intra_doc_links)]
32#![forbid(missing_docs)]
33#![forbid(rustdoc::missing_crate_level_docs)]
34//#![forbid(rustdoc::missing_doc_code_examples)]
35#![forbid(rustdoc::private_doc_tests)]
36#![forbid(rustdoc::invalid_codeblock_attributes)]
37#![forbid(rustdoc::invalid_html_tags)]
38#![forbid(rustdoc::invalid_rust_codeblocks)]
39#![forbid(rustdoc::bare_urls)]
40#![forbid(rustdoc::unescaped_backticks)]
41#![forbid(rustdoc::redundant_explicit_links)]
42
43extern crate proc_macro;
44use proc_macro::TokenStream;
45use quote::quote;
46use syn::{parse_macro_input, Data, DataStruct, DeriveInput};
47
48/// # Procedural macro Module
49///
50/// Implements IntoIterator<Item = &Tensor> and IntoIterator<Item = &mut Tensor> for your struct.
51///
52/// This allows saving, loading, backpropagation and updating your modules.
53#[proc_macro_derive(Module)]
54pub fn into_iterator_item_tensor(input: TokenStream) -> TokenStream {
55    let input = parse_macro_input!(input as DeriveInput);
56    let struct_name = &input.ident;
57    let mut field_iterators = quote! {
58        trait __MarkerTraitRef<'a, B: zyx_core::backend::Backend + 'a> {
59            fn __iterate_by_ref(&self, res: &mut Vec<&'a zyx_core::tensor::Tensor<B>>) {}
60        }
61
62        struct __MarkerStructRef<T: Copy>(T);
63
64        impl<'a, B: zyx_core::backend::Backend + 'a, T: IntoIterator<Item = &'a zyx_core::tensor::Tensor<B>> + Copy> __MarkerStructRef<T> {
65            fn __iterate_by_ref(&self, res: &mut Vec<&'a zyx_core::tensor::Tensor<B>>) {
66                res.extend(self.0.into_iter());
67            }
68        }
69
70        impl<'a, B: zyx_core::backend::Backend + 'a, T: Copy> __MarkerTraitRef<'a, B> for __MarkerStructRef<T>{}
71
72        let mut res = Vec::<&zyx_core::tensor::Tensor<_>>::new();
73    };
74
75    if let Data::Struct(DataStruct { fields, .. }) = &input.data {
76        for field in fields.iter() {
77            let field_name = match &field.ident {
78                Some(ident) => ident,
79                None => panic!("Unnamed fields are not supported"),
80            };
81            let field_ty: &syn::Type = &field.ty;
82            // TODO check if field is tensor, or implement IntoIterator<Item = &Tensor> for &Tensor
83            field_iterators = quote! {
84                #field_iterators
85                __MarkerStructRef::<&#field_ty>::__iterate_by_ref(&__MarkerStructRef(&self.#field_name), &mut res);
86            };
87        }
88    }
89
90    let expanded = quote! {
91        impl<'a, B: zyx_core::backend::Backend> IntoIterator for &'a #struct_name<B> {
92            type Item = &'a zyx_core::tensor::Tensor<B>;
93            type IntoIter = std::vec::IntoIter<&'a zyx_core::tensor::Tensor<B>>;
94
95            fn into_iter(self) -> Self::IntoIter {
96                #field_iterators
97                res.into_iter()
98            }
99        }
100    };
101
102    let mut field_iterators = quote! {
103        trait __MarkerTraitMut<'a, B: zyx_core::backend::Backend + 'a>: Sized {
104            fn __iterate_by_mut(mut self, res: &mut Vec<&'a mut zyx_core::tensor::Tensor<B>>) {}
105        }
106
107        struct __MarkerStructMut<T>(T);
108
109        impl<'a, B: zyx_core::backend::Backend + 'a, T: IntoIterator<Item = &'a mut zyx_core::tensor::Tensor<B>>> __MarkerStructMut<T> {
110            fn __iterate_by_mut(mut self, res: &mut Vec<&'a mut zyx_core::tensor::Tensor<B>>) {
111                res.extend(self.0.into_iter());
112            }
113        }
114
115        impl<'a, B: zyx_core::backend::Backend + 'a, T> __MarkerTraitMut<'a, B> for __MarkerStructMut<T>{}
116
117        let mut res = Vec::<&mut zyx_core::tensor::Tensor<_>>::new();
118    };
119
120    if let Data::Struct(DataStruct { fields, .. }) = &input.data {
121        for field in fields.iter() {
122            let field_name = match &field.ident {
123                Some(ident) => ident,
124                None => panic!("Unnamed fields are not supported"),
125            };
126            let field_ty: &syn::Type = &field.ty;
127            field_iterators = quote! {
128                #field_iterators
129                __MarkerStructMut::<&#field_ty>::__iterate_by_mut(__MarkerStructMut(&mut self.#field_name), &mut res);
130            };
131        }
132    }
133
134    let expanded = quote! {
135        #expanded
136
137        impl<'a, B: zyx_core::backend::Backend> IntoIterator for &'a mut #struct_name<B> {
138            type Item = &'a mut zyx_core::tensor::Tensor<B>;
139            type IntoIter = std::vec::IntoIter<&'a mut zyx_core::tensor::Tensor<B>>;
140
141            fn into_iter(self) -> Self::IntoIter {
142                #field_iterators
143                res.into_iter()
144            }
145        }
146    };
147
148    TokenStream::from(expanded)
149}