1#![no_std]
29#![forbid(unsafe_code)]
30#![forbid(rustdoc::broken_intra_doc_links)]
31#![forbid(rustdoc::private_intra_doc_links)]
32#![forbid(missing_docs)]
33#![forbid(rustdoc::missing_crate_level_docs)]
34#![forbid(rustdoc::private_doc_tests)]
36#![forbid(rustdoc::invalid_codeblock_attributes)]
37#![forbid(rustdoc::invalid_html_tags)]
38#![forbid(rustdoc::invalid_rust_codeblocks)]
39#![forbid(rustdoc::bare_urls)]
40#![forbid(rustdoc::unescaped_backticks)]
41#![forbid(rustdoc::redundant_explicit_links)]
42
43extern crate proc_macro;
44use proc_macro::TokenStream;
45use quote::quote;
46use syn::{parse_macro_input, Data, DataStruct, DeriveInput};
47
48#[proc_macro_derive(Module)]
54pub fn into_iterator_item_tensor(input: TokenStream) -> TokenStream {
55 let input = parse_macro_input!(input as DeriveInput);
56 let struct_name = &input.ident;
57 let mut field_iterators = quote! {
58 trait __MarkerTraitRef<'a, B: zyx_core::backend::Backend + 'a> {
59 fn __iterate_by_ref(&self, res: &mut Vec<&'a zyx_core::tensor::Tensor<B>>) {}
60 }
61
62 struct __MarkerStructRef<T: Copy>(T);
63
64 impl<'a, B: zyx_core::backend::Backend + 'a, T: IntoIterator<Item = &'a zyx_core::tensor::Tensor<B>> + Copy> __MarkerStructRef<T> {
65 fn __iterate_by_ref(&self, res: &mut Vec<&'a zyx_core::tensor::Tensor<B>>) {
66 res.extend(self.0.into_iter());
67 }
68 }
69
70 impl<'a, B: zyx_core::backend::Backend + 'a, T: Copy> __MarkerTraitRef<'a, B> for __MarkerStructRef<T>{}
71
72 let mut res = Vec::<&zyx_core::tensor::Tensor<_>>::new();
73 };
74
75 if let Data::Struct(DataStruct { fields, .. }) = &input.data {
76 for field in fields.iter() {
77 let field_name = match &field.ident {
78 Some(ident) => ident,
79 None => panic!("Unnamed fields are not supported"),
80 };
81 let field_ty: &syn::Type = &field.ty;
82 field_iterators = quote! {
84 #field_iterators
85 __MarkerStructRef::<&#field_ty>::__iterate_by_ref(&__MarkerStructRef(&self.#field_name), &mut res);
86 };
87 }
88 }
89
90 let expanded = quote! {
91 impl<'a, B: zyx_core::backend::Backend> IntoIterator for &'a #struct_name<B> {
92 type Item = &'a zyx_core::tensor::Tensor<B>;
93 type IntoIter = std::vec::IntoIter<&'a zyx_core::tensor::Tensor<B>>;
94
95 fn into_iter(self) -> Self::IntoIter {
96 #field_iterators
97 res.into_iter()
98 }
99 }
100 };
101
102 let mut field_iterators = quote! {
103 trait __MarkerTraitMut<'a, B: zyx_core::backend::Backend + 'a>: Sized {
104 fn __iterate_by_mut(mut self, res: &mut Vec<&'a mut zyx_core::tensor::Tensor<B>>) {}
105 }
106
107 struct __MarkerStructMut<T>(T);
108
109 impl<'a, B: zyx_core::backend::Backend + 'a, T: IntoIterator<Item = &'a mut zyx_core::tensor::Tensor<B>>> __MarkerStructMut<T> {
110 fn __iterate_by_mut(mut self, res: &mut Vec<&'a mut zyx_core::tensor::Tensor<B>>) {
111 res.extend(self.0.into_iter());
112 }
113 }
114
115 impl<'a, B: zyx_core::backend::Backend + 'a, T> __MarkerTraitMut<'a, B> for __MarkerStructMut<T>{}
116
117 let mut res = Vec::<&mut zyx_core::tensor::Tensor<_>>::new();
118 };
119
120 if let Data::Struct(DataStruct { fields, .. }) = &input.data {
121 for field in fields.iter() {
122 let field_name = match &field.ident {
123 Some(ident) => ident,
124 None => panic!("Unnamed fields are not supported"),
125 };
126 let field_ty: &syn::Type = &field.ty;
127 field_iterators = quote! {
128 #field_iterators
129 __MarkerStructMut::<&#field_ty>::__iterate_by_mut(__MarkerStructMut(&mut self.#field_name), &mut res);
130 };
131 }
132 }
133
134 let expanded = quote! {
135 #expanded
136
137 impl<'a, B: zyx_core::backend::Backend> IntoIterator for &'a mut #struct_name<B> {
138 type Item = &'a mut zyx_core::tensor::Tensor<B>;
139 type IntoIter = std::vec::IntoIter<&'a mut zyx_core::tensor::Tensor<B>>;
140
141 fn into_iter(self) -> Self::IntoIter {
142 #field_iterators
143 res.into_iter()
144 }
145 }
146 };
147
148 TokenStream::from(expanded)
149}