// type_flow_proc_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{parse_macro_input, Ident, Type, parse::Parse, parse::ParseStream, Token, ItemStruct, FieldsNamed, Fields};
4
5struct ProcessorPipeline {
6    pipeline_name: Ident,
7    data_type: Type,
8    processors: Vec<Type>,
9}
10
11impl Parse for ProcessorPipeline {
12    fn parse(input: ParseStream) -> syn::Result<Self> {
13        let pipeline_name: Ident = input.parse()?;
14        input.parse::<Token![,]>()?;
15        let data_type: Type = input.parse()?;
16        input.parse::<Token![,]>()?;
17        let mut processors = Vec::new();
18        // Parse the remaining types as processor types
19        while !input.is_empty() {
20            let processor_type: Type = input.parse()?;
21            processors.push(processor_type);
22            if input.is_empty() {
23                break;
24            }
25            input.parse::<Token![,]>()?;
26        }
27        Ok(ProcessorPipeline {
28            pipeline_name,
29            data_type,
30            processors,
31        })
32    }
33}
34
35#[proc_macro]
36pub fn stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream {
37    let ProcessorPipeline { pipeline_name, data_type, processors } = parse_macro_input!(input as ProcessorPipeline);
38    let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
39        // Get a string representation of the type for the field name
40        let type_str = get_type_name(processor_type);
41        let field_name = format_ident!("processor_{}_{}", type_str, idx);
42        quote! { #field_name: #processor_type }
43    });
44    let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
45        let type_str = get_type_name(processor_type);
46        let param_name = format_ident!("processor_{}_{}", type_str, idx);
47        quote! { #param_name: #processor_type }
48    });
49    let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
50        let type_str = get_type_name(processor_type);
51        let field_name = format_ident!("processor_{}_{}", type_str, idx);
52        quote! { #field_name }
53    });
54    let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
55        let type_str = get_type_name(processor_type);
56        let field_name = format_ident!("processor_{}_{}", type_str, idx);
57        quote! { data = self.#field_name.process(data); }
58    });
59    let expanded = quote! {
60        pub struct #pipeline_name {
61            #(#struct_fields,)*
62        }
63        impl #pipeline_name {
64            pub fn new(#(#constructor_params,)*) -> Self {
65                Self {
66                    #(#field_initializers,)*
67                }
68            }
69        }
70        impl StatefulProcessor<#data_type> for #pipeline_name {
71            fn process(&mut self, mut data: #data_type) -> #data_type {
72                #(#process_implementation)*
73                data
74            }
75        }
76    };
77    TokenStream::from(expanded)
78}
79
80struct ProcessorInplacePipeline {
81    pipeline_name: Ident,
82    data_type: Type,
83    error_type: Type,
84    processors: Vec<Type>,
85}
86
87impl Parse for ProcessorInplacePipeline {
88    fn parse(input: ParseStream) -> syn::Result<Self> {
89        let pipeline_name: Ident = input.parse()?;
90        input.parse::<Token![,]>()?;
91        let data_type: Type = input.parse()?;
92        input.parse::<Token![,]>()?;
93        let error_type: Type = input.parse()?;
94        input.parse::<Token![,]>()?;
95        let mut processors = Vec::new();
96        // Parse the remaining types as processor types
97        while !input.is_empty() {
98            let processor_type: Type = input.parse()?;
99            processors.push(processor_type);
100
101            if input.is_empty() {
102                break;
103            }
104            input.parse::<Token![,]>()?;
105        }
106        Ok(ProcessorInplacePipeline{
107            pipeline_name,
108            data_type,
109            error_type,
110            processors,
111        })
112    }
113}
114
115
116#[proc_macro_attribute]
117pub fn implement_processor_swapping(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
118    let input_struct_item = parse_macro_input!(item as ItemStruct);
119    let struct_name = &input_struct_item.ident;
120    let generics = &input_struct_item.generics;
121    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
122    let type_param_idents: Vec<&Ident> = generics.type_params().map(|tp| &tp.ident).collect();
123    let fields = match &input_struct_item.fields {
124        Fields::Named(FieldsNamed { named, .. }) => named,
125        _ => panic!("#[implement_processor_swapping] only works on structs with named fields."),
126    };
127    let processor_field_idents: Vec<&Ident> = fields.iter()
128        .filter_map(|f| f.ident.as_ref())
129        .filter(|&id| id.to_string() != "_marker")
130        .collect();
131    if type_param_idents.len() != processor_field_idents.len() {
132        panic!(
133            "#[implement_processor_swapping] detected a mismatch between the number of generic type parameters ({}) and processor fields ({}). Expected them to be equal.",
134            type_param_idents.len(),
135            processor_field_idents.len()
136        );
137    }
138    // --- Trait Impls ---
139    let mut trait_impls = Vec::new();
140    if !type_param_idents.is_empty() {
141        // Reverse
142        let mut rev_types = type_param_idents.clone();
143        rev_types.reverse();
144        let mut rev_source_fields = processor_field_idents.clone();
145        rev_source_fields.reverse();
146        let rev_assignments = processor_field_idents.iter().zip(rev_source_fields.iter()).map(|(dest, src)| {
147            quote!{ #dest: self.#src }
148        });
149        trait_impls.push(quote!{
150            impl #impl_generics type_flow_traits::Reverse for #struct_name #ty_generics #where_clause {
151                type Output = #struct_name<#(#rev_types),*>;
152                fn reverse(self) -> Self::Output {
153                    Self::Output {
154                        _marker: std::marker::PhantomData,
155                        #(#rev_assignments),*
156                    }
157                }
158            }
159        });
160        // ShiftLeft
161        let mut sl_types = type_param_idents.clone();
162        sl_types.rotate_left(1);
163        let mut sl_source_fields = processor_field_idents.clone();
164        sl_source_fields.rotate_left(1);
165        let sl_assignments = processor_field_idents.iter().zip(sl_source_fields.iter()).map(|(dest, src)| {
166            quote!{ #dest: self.#src }
167        });
168        trait_impls.push(quote!{
169            impl #impl_generics type_flow_traits::ShiftLeft for #struct_name #ty_generics #where_clause {
170                type ShiftedLeft = #struct_name<#(#sl_types),*>;
171                fn shift_left(self) -> Self::ShiftedLeft {
172                    Self::ShiftedLeft {
173                        _marker: std::marker::PhantomData,
174                        #(#sl_assignments),*
175                    }
176                }
177            }
178        });
179        // ShiftRight
180        let mut sr_types = type_param_idents.clone();
181        sr_types.rotate_right(1);
182        let mut sr_source_fields = processor_field_idents.clone();
183        sr_source_fields.rotate_right(1);
184        let sr_assignments = processor_field_idents.iter().zip(sr_source_fields.iter()).map(|(dest, src)| {
185            quote!{ #dest: self.#src }
186        });
187        trait_impls.push(quote!{
188            impl #impl_generics type_flow_traits::ShiftRight for #struct_name #ty_generics #where_clause {
189                type ShiftedRight = #struct_name<#(#sr_types),*>;
190                fn shift_right(self) -> Self::ShiftedRight {
191                    Self::ShiftedRight {
192                        _marker: std::marker::PhantomData,
193                        #(#sr_assignments),*
194                    }
195                }
196            }
197        });
198        // SwapStartEnd
199        let mut sse_types = type_param_idents.clone();
200        let sse_length = sse_types.len();
201        if sse_length > 1 { sse_types.swap(0, sse_length - 1); }
202        let mut sse_source_fields = processor_field_idents.clone();
203        let sse_source_length = sse_source_fields.len();
204        if sse_source_length > 1 { sse_source_fields.swap(0, sse_source_length - 1); }
205        let sse_assignments = processor_field_idents.iter().zip(sse_source_fields.iter()).map(|(dest, src)| {
206            quote!{ #dest: self.#src }
207        });
208        trait_impls.push(quote!{
209            impl #impl_generics type_flow_traits::SwapStartEnd for #struct_name #ty_generics #where_clause {
210                type Output = #struct_name<#(#sse_types),*>;
211                fn swap(self) -> Self::Output {
212                    Self::Output {
213                        _marker: std::marker::PhantomData,
214                        #(#sse_assignments),*
215                    }
216                }
217            }
218        });
219    }
220    let num_processors = type_param_idents.len();
221    let mut generated_impls = Vec::new();
222    for i in 0..num_processors {
223        // SwapArbitraryProcessors
224        for j in 0..num_processors {
225            if i == j { continue; }
226            let mut swapped_generic_params_for_type = type_param_idents.clone();
227            swapped_generic_params_for_type.swap(i, j);
228            let field_i_name = processor_field_idents[i];
229            let field_j_name = processor_field_idents[j];
230            let mut current_field_initializers = Vec::new();
231            for k in 0..num_processors {
232                let current_field_name = processor_field_idents[k];
233                if k == i {
234                    current_field_initializers.push(quote! { #current_field_name: self.#field_j_name });
235                } else if k == j {
236                    current_field_initializers.push(quote! { #current_field_name: self.#field_i_name });
237                } else {
238                    current_field_initializers.push(quote! { #current_field_name: self.#current_field_name });
239                }
240            }
241            generated_impls.push(quote! {
242                impl #impl_generics type_flow_traits::SwapArbitraryProcessors<#i, #j> for #struct_name #ty_generics #where_clause {
243                    type SwappedOutput = #struct_name<#(#swapped_generic_params_for_type),*>;
244                    fn swap_processors(self) -> Self::SwappedOutput {
245                        #struct_name {
246                            _marker: std::marker::PhantomData,
247                            #( #current_field_initializers ),*
248                        }
249                    }
250                }
251            });
252        }
253        // PivotSpin
254        // PivotSwap
255        // InterleavePivotSpin
256        // InterleavePivotSwap
257    }
258    let expanded = quote! {
259        #input_struct_item
260        #(#generated_impls)*
261        #(#trait_impls)*
262    };
263    proc_macro::TokenStream::from(expanded)
264}
265
266struct TypeFLowInplaceStatefulProcessorPipeline {
267    pipeline_name: Ident,
268    data_type: Type,
269    error_type: Type,
270    number_of_processors: usize,
271}
272impl Parse for TypeFLowInplaceStatefulProcessorPipeline {
273    fn parse(input: ParseStream) -> syn::Result<Self> {
274        let pipeline_name: Ident = input.parse()?;
275        input.parse::<Token![,]>()?;
276        let data_type: Type = input.parse()?;
277        input.parse::<Token![,]>()?;
278        let error_type: Type = input.parse()?;
279        input.parse::<Token![,]>()?;
280        let lit_int: syn::LitInt = input.parse()?;
281        let number_of_processors = lit_int.base10_parse::<usize>()?;
282        Ok(TypeFLowInplaceStatefulProcessorPipeline{
283            pipeline_name,
284            data_type,
285            error_type,
286            number_of_processors,
287        })
288    }
289}
290
291#[proc_macro]
292pub fn type_flow_inplace_stateful_processor_pipeline_by_count(input: TokenStream) -> TokenStream {
293    let TypeFLowInplaceStatefulProcessorPipeline { pipeline_name, data_type, error_type, number_of_processors } = parse_macro_input!(input as TypeFLowInplaceStatefulProcessorPipeline);
294    let generic_params: Vec<_> = (0..number_of_processors)
295        .map(|i| format_ident!("P{}", i))
296        .collect();
297    let field_names: Vec<_> = (0..number_of_processors)
298        .map(|i| format_ident!("p{}", i))
299        .collect();
300    let fields_with_types: Vec<_> = field_names.iter().zip(generic_params.iter())
301        .map(|(field, param)| quote! { #field: #param })
302        .collect();
303    let processor_args = fields_with_types.into_iter().reduce(|acc, item| {
304        quote! { #acc, #item }
305    });
306    let expanded = quote! {
307        type_flow_macros::type_flow_inplace_stateful_processor_pipeline!(
308            #pipeline_name,
309            #data_type,
310            #error_type,
311            #processor_args
312        );
313    };
314    TokenStream::from(expanded)
315}
316
317
318#[proc_macro]
319pub fn inplace_stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream { 
320    let ProcessorInplacePipeline { pipeline_name, data_type, error_type, processors } = parse_macro_input!(input as ProcessorInplacePipeline);
321    let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
322        let type_str = get_type_name(processor_type);
323        let field_name = format_ident!("processor_{}_{}", type_str, idx);
324        quote! { #field_name: #processor_type }
325    });
326    let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
327        let type_str = get_type_name(processor_type);
328        let param_name = format_ident!("processor_{}_{}", type_str, idx);
329        quote! { #param_name: #processor_type }
330    });
331    let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
332        let type_str = get_type_name(processor_type);
333        let field_name = format_ident!("processor_{}_{}", type_str, idx);
334        quote! { #field_name }
335    });
336    let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
337        let type_str = get_type_name(processor_type);
338        let field_name = format_ident!("processor_{}_{}", type_str, idx);
339        quote! { self.#field_name.process(data)?; }
340    });
341    let expanded = quote! {
342        pub struct #pipeline_name {
343            #(#struct_fields,)*
344        }
345        impl #pipeline_name {
346            pub fn new(#(#constructor_params,)*) -> Self {
347                Self {
348                    #(#field_initializers,)*
349                }
350            }
351        }
352        impl crate::InPlaceStatefulProcessor<#data_type, #error_type> for #pipeline_name {
353            fn process(&mut self, data: &mut #data_type) -> Result<(), #error_type> {
354                #(#process_implementation)*
355                Ok(())
356            }
357        }
358    };
359    TokenStream::from(expanded)
360}
361
362// Helper function to extract a usable identifier from a Type
363fn get_type_name(ty: &syn::Type) -> String {
364    match ty {
365        syn::Type::Path(type_path) if !type_path.path.segments.is_empty() => {
366            // Get the last segment of the path (e.g., for std::string::String, get "String")
367            let segment = type_path.path.segments.last().unwrap();
368            let name = segment.ident.to_string();
369
370            // Filter out any characters that aren't valid in an identifier
371            name.chars()
372                .map(|c| if c.is_alphanumeric() { c } else { '_' })
373                .collect()
374        },
375        // Handle other types (arrays, references, etc.) by using a generic name
376        _ => "unknown_type".to_string(),
377    }
378}