type_flow_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{parse_macro_input, Ident, Type, parse::Parse, parse::ParseStream, Token, /*GenericParam,*/ ItemStruct, FieldsNamed, Fields};
4
5struct ProcessorPipeline {
6    pipeline_name: Ident,
7    data_type: Type,
8    processors: Vec<Type>,
9}
10
11impl Parse for ProcessorPipeline {
12    fn parse(input: ParseStream) -> syn::Result<Self> {
13        let pipeline_name: Ident = input.parse()?;
14        input.parse::<Token![,]>()?;
15
16        let data_type: Type = input.parse()?;
17        input.parse::<Token![,]>()?;
18
19        let mut processors = Vec::new();
20
21        // Parse the remaining types as processor types
22        while !input.is_empty() {
23            let processor_type: Type = input.parse()?;
24            processors.push(processor_type);
25
26            if input.is_empty() {
27                break;
28            }
29            input.parse::<Token![,]>()?;
30        }
31
32        Ok(ProcessorPipeline {
33            pipeline_name,
34            data_type,
35            processors,
36        })
37    }
38}
39
40#[proc_macro]
41pub fn stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream {
42    let ProcessorPipeline { pipeline_name, data_type, processors } = parse_macro_input!(input as ProcessorPipeline);
43
44    let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
45        // Get a string representation of the type for the field name
46        let type_str = get_type_name(processor_type);
47        let field_name = format_ident!("processor_{}_{}", type_str, idx);
48        quote! { #field_name: #processor_type }
49    });
50
51    let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
52        let type_str = get_type_name(processor_type);
53        let param_name = format_ident!("processor_{}_{}", type_str, idx);
54        quote! { #param_name: #processor_type }
55    });
56
57    let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
58        let type_str = get_type_name(processor_type);
59        let field_name = format_ident!("processor_{}_{}", type_str, idx);
60        quote! { #field_name }
61    });
62
63    let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
64        let type_str = get_type_name(processor_type);
65        let field_name = format_ident!("processor_{}_{}", type_str, idx);
66        quote! { data = self.#field_name.process(data); }
67    });
68
69    let expanded = quote! {
70        pub struct #pipeline_name {
71            #(#struct_fields,)*
72        }
73        
74        impl #pipeline_name {
75            pub fn new(#(#constructor_params,)*) -> Self {
76                Self {
77                    #(#field_initializers,)*
78                }
79            }
80        }
81        
82        impl crate::StatefulProcessor<#data_type> for #pipeline_name {
83            fn process(&mut self, mut data: #data_type) -> #data_type {
84                #(#process_implementation)*
85                data
86            }
87        }
88    };
89
90    TokenStream::from(expanded)
91}
92
93struct ProcessorInplacePipeline {
94    pipeline_name: Ident,
95    data_type: Type,
96    error_type: Type,
97    processors: Vec<Type>,
98}
99
100impl Parse for ProcessorInplacePipeline {
101    fn parse(input: ParseStream) -> syn::Result<Self> {
102        let pipeline_name: Ident = input.parse()?;
103        input.parse::<Token![,]>()?;
104        let data_type: Type = input.parse()?;
105        input.parse::<Token![,]>()?;
106        let error_type: Type = input.parse()?;
107        input.parse::<Token![,]>()?;
108        let mut processors = Vec::new();
109        // Parse the remaining types as processor types
110        while !input.is_empty() {
111            let processor_type: Type = input.parse()?;
112            processors.push(processor_type);
113
114            if input.is_empty() {
115                break;
116            }
117            input.parse::<Token![,]>()?;
118        }
119        Ok(ProcessorInplacePipeline{
120            pipeline_name,
121            data_type,
122            error_type,
123            processors,
124        })
125    }
126}
127
128
129#[proc_macro_attribute]
130pub fn implement_processor_swapping(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
131    let input_struct_item = parse_macro_input!(item as ItemStruct); // Use ItemStruct for more direct access
132
133    let struct_name = &input_struct_item.ident;
134    let generics = &input_struct_item.generics;
135
136    //let type_params_with_bounds: Vec<&GenericParam> = generics.params.iter().collect();
137
138    // Extract only the identifiers of the type parameters (e.g., P0, P1, P2)
139    let type_param_idents: Vec<&Ident> = generics.type_params().map(|tp| &tp.ident).collect();
140
141    let fields = match &input_struct_item.fields {
142        Fields::Named(FieldsNamed { named, .. }) => named,
143        _ => panic!("#[implement_processor_swapping] only works on structs with named fields."),
144    };
145
146    // Assuming the fields (excluding _marker) correspond to type_params in order.
147    // The type_flow_inplace_stateful_processor_pipeline macro creates a `_marker: PhantomData`
148    // field first, then the processor fields.
149    let processor_field_idents: Vec<&Ident> = fields.iter()
150        .filter_map(|f| f.ident.as_ref())
151        .filter(|&id| id.to_string() != "_marker") // Skip _marker field
152        .collect();
153
154    if type_param_idents.len() != processor_field_idents.len() {
155        panic!(
156            "#[implement_processor_swapping] detected a mismatch between the number of generic type parameters ({}) and processor fields ({}). Expected them to be equal.",
157            type_param_idents.len(),
158            processor_field_idents.len()
159        );
160    }
161
162    let num_processors = type_param_idents.len();
163    if num_processors < 2 { // Swapping requires at least two processors
164        // Return the original struct definition along with a placeholder for the trait
165        let trait_def_placeholder = quote! {
166            pub trait SwapArbitraryProcessors<const I: usize, const J: usize> where Self: Sized {
167                type SwappedOutput;
168                fn swap_processors(self) -> Self::SwappedOutput;
169            }
170        };
171        let original_struct_tokens = quote! { #input_struct_item };
172        return quote! {
173            #original_struct_tokens
174            #trait_def_placeholder
175        }.into();
176    }
177
178    let mut generated_impls = Vec::new();
179    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
180
181    let swap_trait_name = format_ident!("SwapArbitraryProcessors");
182
183    for i in 0..num_processors {
184        for j in 0..num_processors {
185            if i == j {
186                continue; // No need to generate an impl for swapping a processor with itself
187            }
188
189            let mut swapped_generic_params_for_type = type_param_idents.clone();
190            swapped_generic_params_for_type.swap(i, j);
191
192            let field_i_name = processor_field_idents[i];
193            let field_j_name = processor_field_idents[j];
194
195            let mut current_field_initializers = Vec::new();
196            for k in 0..num_processors {
197                let current_field_name = processor_field_idents[k];
198                if k == i {
199                    // The k-th field in the new struct (which is field_i_name) gets the j-th processor's state from self
200                    current_field_initializers.push(quote! { #current_field_name: self.#field_j_name });
201                } else if k == j {
202                    // The k-th field in the new struct (which is field_j_name) gets the i-th processor's state from self
203                    current_field_initializers.push(quote! { #current_field_name: self.#field_i_name });
204                } else {
205                    current_field_initializers.push(quote! { #current_field_name: self.#current_field_name });
206                }
207            }
208
209            generated_impls.push(quote! {
210                impl #impl_generics #swap_trait_name<#i, #j> for #struct_name #ty_generics #where_clause {
211                    type SwappedOutput = #struct_name<#(#swapped_generic_params_for_type),*>;
212
213                    fn swap_processors(self) -> Self::SwappedOutput {
214                        #struct_name {
215                            _marker: std::marker::PhantomData,
216                            #( #current_field_initializers ),*
217                        }
218                    }
219                }
220            });
221        }
222    }
223
224    // The trait definition itself
225    // It's better to define it outside the loop to avoid redefinition if num_processors < 2,
226    // and ensure it's always present.
227    let trait_def = quote! {
228        pub trait #swap_trait_name<const I: usize, const J: usize>
229        where
230            Self: Sized,
231        {
232            type SwappedOutput;
233            fn swap_processors(self) -> Self::SwappedOutput;
234        }
235    };
236
237    let expanded = quote! {
238        #input_struct_item // The original struct definition
239
240        #trait_def
241
242        #(#generated_impls)*
243    };
244
245    proc_macro::TokenStream::from(expanded)
246}
247
248
249#[proc_macro]
250pub fn inplace_stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream { 
251    let ProcessorInplacePipeline { pipeline_name, data_type, error_type, processors } = parse_macro_input!(input as ProcessorInplacePipeline);
252    let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
253        // Get a string representation of the type for the field name
254        let type_str = get_type_name(processor_type);
255        let field_name = format_ident!("processor_{}_{}", type_str, idx);
256        quote! { #field_name: #processor_type }
257    });
258
259    let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
260        let type_str = get_type_name(processor_type);
261        let param_name = format_ident!("processor_{}_{}", type_str, idx);
262        quote! { #param_name: #processor_type }
263    });
264
265    let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
266        let type_str = get_type_name(processor_type);
267        let field_name = format_ident!("processor_{}_{}", type_str, idx);
268        quote! { #field_name }
269    });
270
271    let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
272        let type_str = get_type_name(processor_type);
273        let field_name = format_ident!("processor_{}_{}", type_str, idx);
274        quote! { self.#field_name.process(data)?; }
275    });
276
277    let expanded = quote! {
278        pub struct #pipeline_name {
279            #(#struct_fields,)*
280        }
281        
282        impl #pipeline_name {
283            pub fn new(#(#constructor_params,)*) -> Self {
284                Self {
285                    #(#field_initializers,)*
286                }
287            }
288        }
289        
290        impl crate::InPlaceStatefulProcessor<#data_type, #error_type> for #pipeline_name {
291            fn process(&mut self, data: &mut #data_type) -> Result<(), #error_type> {
292                #(#process_implementation)*
293                Ok(())
294            }
295        }
296    };
297    TokenStream::from(expanded)
298}
299
300// Helper function to extract a usable identifier from a Type
301fn get_type_name(ty: &syn::Type) -> String {
302    match ty {
303        syn::Type::Path(type_path) if !type_path.path.segments.is_empty() => {
304            // Get the last segment of the path (e.g., for std::string::String, get "String")
305            let segment = type_path.path.segments.last().unwrap();
306            let name = segment.ident.to_string();
307
308            // Filter out any characters that aren't valid in an identifier
309            name.chars()
310                .map(|c| if c.is_alphanumeric() { c } else { '_' })
311                .collect()
312        },
313        // Handle other types (arrays, references, etc.) by using a generic name
314        _ => "unknown_type".to_string(),
315    }
316}