type_flow_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{parse_macro_input, Ident, Type, parse::Parse, parse::ParseStream, Token, /*GenericParam,*/ ItemStruct, FieldsNamed, Fields};
4
5struct ProcessorPipeline {
6    pipeline_name: Ident,
7    data_type: Type,
8    processors: Vec<Type>,
9}
10
11impl Parse for ProcessorPipeline {
12    fn parse(input: ParseStream) -> syn::Result<Self> {
13        let pipeline_name: Ident = input.parse()?;
14        input.parse::<Token![,]>()?;
15        let data_type: Type = input.parse()?;
16        input.parse::<Token![,]>()?;
17        let mut processors = Vec::new();
18        // Parse the remaining types as processor types
19        while !input.is_empty() {
20            let processor_type: Type = input.parse()?;
21            processors.push(processor_type);
22            if input.is_empty() {
23                break;
24            }
25            input.parse::<Token![,]>()?;
26        }
27        Ok(ProcessorPipeline {
28            pipeline_name,
29            data_type,
30            processors,
31        })
32    }
33}
34
35#[proc_macro]
36pub fn stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream {
37    let ProcessorPipeline { pipeline_name, data_type, processors } = parse_macro_input!(input as ProcessorPipeline);
38    let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
39        // Get a string representation of the type for the field name
40        let type_str = get_type_name(processor_type);
41        let field_name = format_ident!("processor_{}_{}", type_str, idx);
42        quote! { #field_name: #processor_type }
43    });
44    let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
45        let type_str = get_type_name(processor_type);
46        let param_name = format_ident!("processor_{}_{}", type_str, idx);
47        quote! { #param_name: #processor_type }
48    });
49    let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
50        let type_str = get_type_name(processor_type);
51        let field_name = format_ident!("processor_{}_{}", type_str, idx);
52        quote! { #field_name }
53    });
54    let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
55        let type_str = get_type_name(processor_type);
56        let field_name = format_ident!("processor_{}_{}", type_str, idx);
57        quote! { data = self.#field_name.process(data); }
58    });
59    let expanded = quote! {
60        pub struct #pipeline_name {
61            #(#struct_fields,)*
62        }
63        impl #pipeline_name {
64            pub fn new(#(#constructor_params,)*) -> Self {
65                Self {
66                    #(#field_initializers,)*
67                }
68            }
69        }
70        impl crate::StatefulProcessor<#data_type> for #pipeline_name {
71            fn process(&mut self, mut data: #data_type) -> #data_type {
72                #(#process_implementation)*
73                data
74            }
75        }
76    };
77    TokenStream::from(expanded)
78}
79
80struct ProcessorInplacePipeline {
81    pipeline_name: Ident,
82    data_type: Type,
83    error_type: Type,
84    processors: Vec<Type>,
85}
86
87impl Parse for ProcessorInplacePipeline {
88    fn parse(input: ParseStream) -> syn::Result<Self> {
89        let pipeline_name: Ident = input.parse()?;
90        input.parse::<Token![,]>()?;
91        let data_type: Type = input.parse()?;
92        input.parse::<Token![,]>()?;
93        let error_type: Type = input.parse()?;
94        input.parse::<Token![,]>()?;
95        let mut processors = Vec::new();
96        // Parse the remaining types as processor types
97        while !input.is_empty() {
98            let processor_type: Type = input.parse()?;
99            processors.push(processor_type);
100
101            if input.is_empty() {
102                break;
103            }
104            input.parse::<Token![,]>()?;
105        }
106        Ok(ProcessorInplacePipeline{
107            pipeline_name,
108            data_type,
109            error_type,
110            processors,
111        })
112    }
113}
114
115
116#[proc_macro_attribute]
117pub fn implement_processor_swapping(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
118    let input_struct_item = parse_macro_input!(item as ItemStruct);
119    let struct_name = &input_struct_item.ident;
120    let generics = &input_struct_item.generics;
121    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
122    let type_param_idents: Vec<&Ident> = generics.type_params().map(|tp| &tp.ident).collect();
123    let fields = match &input_struct_item.fields {
124        Fields::Named(FieldsNamed { named, .. }) => named,
125        _ => panic!("#[implement_processor_swapping] only works on structs with named fields."),
126    };
127    let processor_field_idents: Vec<&Ident> = fields.iter()
128        .filter_map(|f| f.ident.as_ref())
129        .filter(|&id| id.to_string() != "_marker")
130        .collect();
131    if type_param_idents.len() != processor_field_idents.len() {
132        panic!(
133            "#[implement_processor_swapping] detected a mismatch between the number of generic type parameters ({}) and processor fields ({}). Expected them to be equal.",
134            type_param_idents.len(),
135            processor_field_idents.len()
136        );
137    }
138    // --- Trait Impls ---
139    let mut trait_impls = Vec::new();
140    if !type_param_idents.is_empty() {
141        // Reverse
142        let mut rev_types = type_param_idents.clone();
143        rev_types.reverse();
144        let mut rev_source_fields = processor_field_idents.clone();
145        rev_source_fields.reverse();
146        let rev_assignments = processor_field_idents.iter().zip(rev_source_fields.iter()).map(|(dest, src)| {
147            quote!{ #dest: self.#src }
148        });
149        trait_impls.push(quote!{
150            impl #impl_generics type_flow_traits::Reverse for #struct_name #ty_generics #where_clause {
151                type Output = #struct_name<#(#rev_types),*>;
152                fn reverse(self) -> Self::Output {
153                    Self::Output {
154                        _marker: std::marker::PhantomData,
155                        #(#rev_assignments),*
156                    }
157                }
158            }
159        });
160        // ShiftLeft
161        let mut sl_types = type_param_idents.clone();
162        sl_types.rotate_left(1);
163        let mut sl_source_fields = processor_field_idents.clone();
164        sl_source_fields.rotate_left(1);
165        let sl_assignments = processor_field_idents.iter().zip(sl_source_fields.iter()).map(|(dest, src)| {
166            quote!{ #dest: self.#src }
167        });
168        trait_impls.push(quote!{
169            impl #impl_generics type_flow_traits::ShiftLeft for #struct_name #ty_generics #where_clause {
170                type ShiftedLeft = #struct_name<#(#sl_types),*>;
171                fn shift_left(self) -> Self::ShiftedLeft {
172                    Self::ShiftedLeft {
173                        _marker: std::marker::PhantomData,
174                        #(#sl_assignments),*
175                    }
176                }
177            }
178        });
179        // ShiftRight
180        let mut sr_types = type_param_idents.clone();
181        sr_types.rotate_right(1);
182        let mut sr_source_fields = processor_field_idents.clone();
183        sr_source_fields.rotate_right(1);
184        let sr_assignments = processor_field_idents.iter().zip(sr_source_fields.iter()).map(|(dest, src)| {
185            quote!{ #dest: self.#src }
186        });
187        trait_impls.push(quote!{
188            impl #impl_generics type_flow_traits::ShiftRight for #struct_name #ty_generics #where_clause {
189                type ShiftedRight = #struct_name<#(#sr_types),*>;
190                fn shift_right(self) -> Self::ShiftedRight {
191                    Self::ShiftedRight {
192                        _marker: std::marker::PhantomData,
193                        #(#sr_assignments),*
194                    }
195                }
196            }
197        });
198        // SwapStartEnd
199        let mut sse_types = type_param_idents.clone();
200        let sse_length = sse_types.len();
201        if sse_length > 1 { sse_types.swap(0, sse_length - 1); }
202        let mut sse_source_fields = processor_field_idents.clone();
203        let sse_source_length = sse_source_fields.len();
204        if sse_source_length > 1 { sse_source_fields.swap(0, sse_source_length - 1); }
205        let sse_assignments = processor_field_idents.iter().zip(sse_source_fields.iter()).map(|(dest, src)| {
206            quote!{ #dest: self.#src }
207        });
208        trait_impls.push(quote!{
209            impl #impl_generics type_flow_traits::SwapStartEnd for #struct_name #ty_generics #where_clause {
210                type Output = #struct_name<#(#sse_types),*>;
211                fn swap(self) -> Self::Output {
212                    Self::Output {
213                        _marker: std::marker::PhantomData,
214                        #(#sse_assignments),*
215                    }
216                }
217            }
218        });
219    }
220    let num_processors = type_param_idents.len();
221    if num_processors < 2 {
222        let trait_def_placeholder = quote! {
223            pub trait SwapArbitraryProcessors<const I: usize, const J: usize> where Self: Sized {
224                type SwappedOutput;
225                fn swap_processors(self) -> Self::SwappedOutput;
226            }
227        };
228        return quote! {
229            #input_struct_item
230            #trait_def_placeholder
231            #(#trait_impls)*
232        }.into();
233    }
234    let mut generated_impls = Vec::new();
235    let swap_trait_name = format_ident!("SwapArbitraryProcessors");
236    for i in 0..num_processors {
237        for j in 0..num_processors {
238            if i == j { continue; }
239            let mut swapped_generic_params_for_type = type_param_idents.clone();
240            swapped_generic_params_for_type.swap(i, j);
241            let field_i_name = processor_field_idents[i];
242            let field_j_name = processor_field_idents[j];
243            let mut current_field_initializers = Vec::new();
244            for k in 0..num_processors {
245                let current_field_name = processor_field_idents[k];
246                if k == i {
247                    current_field_initializers.push(quote! { #current_field_name: self.#field_j_name });
248                } else if k == j {
249                    current_field_initializers.push(quote! { #current_field_name: self.#field_i_name });
250                } else {
251                    current_field_initializers.push(quote! { #current_field_name: self.#current_field_name });
252                }
253            }
254            generated_impls.push(quote! {
255                impl #impl_generics #swap_trait_name<#i, #j> for #struct_name #ty_generics #where_clause {
256                    type SwappedOutput = #struct_name<#(#swapped_generic_params_for_type),*>;
257                    fn swap_processors(self) -> Self::SwappedOutput {
258                        #struct_name {
259                            _marker: std::marker::PhantomData,
260                            #( #current_field_initializers ),*
261                        }
262                    }
263                }
264            });
265        }
266    }
267
268    let trait_def = quote! {
269        pub trait #swap_trait_name<const I: usize, const J: usize>
270        where
271            Self: Sized,
272        {
273            type SwappedOutput;
274            fn swap_processors(self) -> Self::SwappedOutput;
275        }
276    };
277
278    let expanded = quote! {
279        #input_struct_item
280        #trait_def
281        #(#generated_impls)*
282        #(#trait_impls)*
283    };
284    proc_macro::TokenStream::from(expanded)
285}
286
287#[proc_macro]
288pub fn inplace_stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream { 
289    let ProcessorInplacePipeline { pipeline_name, data_type, error_type, processors } = parse_macro_input!(input as ProcessorInplacePipeline);
290    let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
291        // Get a string representation of the type for the field name
292        let type_str = get_type_name(processor_type);
293        let field_name = format_ident!("processor_{}_{}", type_str, idx);
294        quote! { #field_name: #processor_type }
295    });
296    let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
297        let type_str = get_type_name(processor_type);
298        let param_name = format_ident!("processor_{}_{}", type_str, idx);
299        quote! { #param_name: #processor_type }
300    });
301    let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
302        let type_str = get_type_name(processor_type);
303        let field_name = format_ident!("processor_{}_{}", type_str, idx);
304        quote! { #field_name }
305    });
306    let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
307        let type_str = get_type_name(processor_type);
308        let field_name = format_ident!("processor_{}_{}", type_str, idx);
309        quote! { self.#field_name.process(data)?; }
310    });
311    let expanded = quote! {
312        pub struct #pipeline_name {
313            #(#struct_fields,)*
314        }
315        impl #pipeline_name {
316            pub fn new(#(#constructor_params,)*) -> Self {
317                Self {
318                    #(#field_initializers,)*
319                }
320            }
321        }
322        impl crate::InPlaceStatefulProcessor<#data_type, #error_type> for #pipeline_name {
323            fn process(&mut self, data: &mut #data_type) -> Result<(), #error_type> {
324                #(#process_implementation)*
325                Ok(())
326            }
327        }
328    };
329    TokenStream::from(expanded)
330}
331
332// Helper function to extract a usable identifier from a Type
333fn get_type_name(ty: &syn::Type) -> String {
334    match ty {
335        syn::Type::Path(type_path) if !type_path.path.segments.is_empty() => {
336            // Get the last segment of the path (e.g., for std::string::String, get "String")
337            let segment = type_path.path.segments.last().unwrap();
338            let name = segment.ident.to_string();
339
340            // Filter out any characters that aren't valid in an identifier
341            name.chars()
342                .map(|c| if c.is_alphanumeric() { c } else { '_' })
343                .collect()
344        },
345        // Handle other types (arrays, references, etc.) by using a generic name
346        _ => "unknown_type".to_string(),
347    }
348}