1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{parse_macro_input, Ident, Type, parse::Parse, parse::ParseStream, Token, ItemStruct, FieldsNamed, Fields};
4
5struct ProcessorPipeline {
6 pipeline_name: Ident,
7 data_type: Type,
8 processors: Vec<Type>,
9}
10
11impl Parse for ProcessorPipeline {
12 fn parse(input: ParseStream) -> syn::Result<Self> {
13 let pipeline_name: Ident = input.parse()?;
14 input.parse::<Token![,]>()?;
15 let data_type: Type = input.parse()?;
16 input.parse::<Token![,]>()?;
17 let mut processors = Vec::new();
18 while !input.is_empty() {
20 let processor_type: Type = input.parse()?;
21 processors.push(processor_type);
22 if input.is_empty() {
23 break;
24 }
25 input.parse::<Token![,]>()?;
26 }
27 Ok(ProcessorPipeline {
28 pipeline_name,
29 data_type,
30 processors,
31 })
32 }
33}
34
35#[proc_macro]
36pub fn stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream {
37 let ProcessorPipeline { pipeline_name, data_type, processors } = parse_macro_input!(input as ProcessorPipeline);
38 let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
39 let type_str = get_type_name(processor_type);
41 let field_name = format_ident!("processor_{}_{}", type_str, idx);
42 quote! { #field_name: #processor_type }
43 });
44 let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
45 let type_str = get_type_name(processor_type);
46 let param_name = format_ident!("processor_{}_{}", type_str, idx);
47 quote! { #param_name: #processor_type }
48 });
49 let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
50 let type_str = get_type_name(processor_type);
51 let field_name = format_ident!("processor_{}_{}", type_str, idx);
52 quote! { #field_name }
53 });
54 let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
55 let type_str = get_type_name(processor_type);
56 let field_name = format_ident!("processor_{}_{}", type_str, idx);
57 quote! { data = self.#field_name.process(data); }
58 });
59 let expanded = quote! {
60 pub struct #pipeline_name {
61 #(#struct_fields,)*
62 }
63 impl #pipeline_name {
64 pub fn new(#(#constructor_params,)*) -> Self {
65 Self {
66 #(#field_initializers,)*
67 }
68 }
69 }
70 impl StatefulProcessor<#data_type> for #pipeline_name {
71 fn process(&mut self, mut data: #data_type) -> #data_type {
72 #(#process_implementation)*
73 data
74 }
75 }
76 };
77 TokenStream::from(expanded)
78}
79
80struct ProcessorInplacePipeline {
81 pipeline_name: Ident,
82 data_type: Type,
83 error_type: Type,
84 processors: Vec<Type>,
85}
86
87impl Parse for ProcessorInplacePipeline {
88 fn parse(input: ParseStream) -> syn::Result<Self> {
89 let pipeline_name: Ident = input.parse()?;
90 input.parse::<Token![,]>()?;
91 let data_type: Type = input.parse()?;
92 input.parse::<Token![,]>()?;
93 let error_type: Type = input.parse()?;
94 input.parse::<Token![,]>()?;
95 let mut processors = Vec::new();
96 while !input.is_empty() {
98 let processor_type: Type = input.parse()?;
99 processors.push(processor_type);
100
101 if input.is_empty() {
102 break;
103 }
104 input.parse::<Token![,]>()?;
105 }
106 Ok(ProcessorInplacePipeline{
107 pipeline_name,
108 data_type,
109 error_type,
110 processors,
111 })
112 }
113}
114
115
116#[proc_macro_attribute]
117pub fn implement_processor_swapping(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
118 let input_struct_item = parse_macro_input!(item as ItemStruct);
119 let struct_name = &input_struct_item.ident;
120 let generics = &input_struct_item.generics;
121 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
122 let type_param_idents: Vec<&Ident> = generics.type_params().map(|tp| &tp.ident).collect();
123 let fields = match &input_struct_item.fields {
124 Fields::Named(FieldsNamed { named, .. }) => named,
125 _ => panic!("#[implement_processor_swapping] only works on structs with named fields."),
126 };
127 let processor_field_idents: Vec<&Ident> = fields.iter()
128 .filter_map(|f| f.ident.as_ref())
129 .filter(|&id| id.to_string() != "_marker")
130 .collect();
131 if type_param_idents.len() != processor_field_idents.len() {
132 panic!(
133 "#[implement_processor_swapping] detected a mismatch between the number of generic type parameters ({}) and processor fields ({}). Expected them to be equal.",
134 type_param_idents.len(),
135 processor_field_idents.len()
136 );
137 }
138 let mut trait_impls = Vec::new();
140 if !type_param_idents.is_empty() {
141 let mut rev_types = type_param_idents.clone();
143 rev_types.reverse();
144 let mut rev_source_fields = processor_field_idents.clone();
145 rev_source_fields.reverse();
146 let rev_assignments = processor_field_idents.iter().zip(rev_source_fields.iter()).map(|(dest, src)| {
147 quote!{ #dest: self.#src }
148 });
149 trait_impls.push(quote!{
150 impl #impl_generics type_flow_traits::Reverse for #struct_name #ty_generics #where_clause {
151 type Output = #struct_name<#(#rev_types),*>;
152 fn reverse(self) -> Self::Output {
153 Self::Output {
154 _marker: std::marker::PhantomData,
155 #(#rev_assignments),*
156 }
157 }
158 }
159 });
160 let mut sl_types = type_param_idents.clone();
162 sl_types.rotate_left(1);
163 let mut sl_source_fields = processor_field_idents.clone();
164 sl_source_fields.rotate_left(1);
165 let sl_assignments = processor_field_idents.iter().zip(sl_source_fields.iter()).map(|(dest, src)| {
166 quote!{ #dest: self.#src }
167 });
168 trait_impls.push(quote!{
169 impl #impl_generics type_flow_traits::ShiftLeft for #struct_name #ty_generics #where_clause {
170 type ShiftedLeft = #struct_name<#(#sl_types),*>;
171 fn shift_left(self) -> Self::ShiftedLeft {
172 Self::ShiftedLeft {
173 _marker: std::marker::PhantomData,
174 #(#sl_assignments),*
175 }
176 }
177 }
178 });
179 let mut sr_types = type_param_idents.clone();
181 sr_types.rotate_right(1);
182 let mut sr_source_fields = processor_field_idents.clone();
183 sr_source_fields.rotate_right(1);
184 let sr_assignments = processor_field_idents.iter().zip(sr_source_fields.iter()).map(|(dest, src)| {
185 quote!{ #dest: self.#src }
186 });
187 trait_impls.push(quote!{
188 impl #impl_generics type_flow_traits::ShiftRight for #struct_name #ty_generics #where_clause {
189 type ShiftedRight = #struct_name<#(#sr_types),*>;
190 fn shift_right(self) -> Self::ShiftedRight {
191 Self::ShiftedRight {
192 _marker: std::marker::PhantomData,
193 #(#sr_assignments),*
194 }
195 }
196 }
197 });
198 let mut sse_types = type_param_idents.clone();
200 let sse_length = sse_types.len();
201 if sse_length > 1 { sse_types.swap(0, sse_length - 1); }
202 let mut sse_source_fields = processor_field_idents.clone();
203 let sse_source_length = sse_source_fields.len();
204 if sse_source_length > 1 { sse_source_fields.swap(0, sse_source_length - 1); }
205 let sse_assignments = processor_field_idents.iter().zip(sse_source_fields.iter()).map(|(dest, src)| {
206 quote!{ #dest: self.#src }
207 });
208 trait_impls.push(quote!{
209 impl #impl_generics type_flow_traits::SwapStartEnd for #struct_name #ty_generics #where_clause {
210 type Output = #struct_name<#(#sse_types),*>;
211 fn swap(self) -> Self::Output {
212 Self::Output {
213 _marker: std::marker::PhantomData,
214 #(#sse_assignments),*
215 }
216 }
217 }
218 });
219 }
220 let num_processors = type_param_idents.len();
221 let mut generated_impls = Vec::new();
222 for i in 0..num_processors {
223 for j in 0..num_processors {
225 if i == j { continue; }
226 let mut swapped_generic_params_for_type = type_param_idents.clone();
227 swapped_generic_params_for_type.swap(i, j);
228 let field_i_name = processor_field_idents[i];
229 let field_j_name = processor_field_idents[j];
230 let mut current_field_initializers = Vec::new();
231 for k in 0..num_processors {
232 let current_field_name = processor_field_idents[k];
233 if k == i {
234 current_field_initializers.push(quote! { #current_field_name: self.#field_j_name });
235 } else if k == j {
236 current_field_initializers.push(quote! { #current_field_name: self.#field_i_name });
237 } else {
238 current_field_initializers.push(quote! { #current_field_name: self.#current_field_name });
239 }
240 }
241 generated_impls.push(quote! {
242 impl #impl_generics type_flow_traits::SwapArbitraryProcessors<#i, #j> for #struct_name #ty_generics #where_clause {
243 type SwappedOutput = #struct_name<#(#swapped_generic_params_for_type),*>;
244 fn swap_processors(self) -> Self::SwappedOutput {
245 #struct_name {
246 _marker: std::marker::PhantomData,
247 #( #current_field_initializers ),*
248 }
249 }
250 }
251 });
252 }
253 }
258 let expanded = quote! {
259 #input_struct_item
260 #(#generated_impls)*
261 #(#trait_impls)*
262 };
263 proc_macro::TokenStream::from(expanded)
264}
265
266struct TypeFLowInplaceStatefulProcessorPipeline {
267 pipeline_name: Ident,
268 data_type: Type,
269 error_type: Type,
270 number_of_processors: usize,
271}
272impl Parse for TypeFLowInplaceStatefulProcessorPipeline {
273 fn parse(input: ParseStream) -> syn::Result<Self> {
274 let pipeline_name: Ident = input.parse()?;
275 input.parse::<Token![,]>()?;
276 let data_type: Type = input.parse()?;
277 input.parse::<Token![,]>()?;
278 let error_type: Type = input.parse()?;
279 input.parse::<Token![,]>()?;
280 let lit_int: syn::LitInt = input.parse()?;
281 let number_of_processors = lit_int.base10_parse::<usize>()?;
282 Ok(TypeFLowInplaceStatefulProcessorPipeline{
283 pipeline_name,
284 data_type,
285 error_type,
286 number_of_processors,
287 })
288 }
289}
290
291#[proc_macro]
292pub fn type_flow_inplace_stateful_processor_pipeline_by_count(input: TokenStream) -> TokenStream {
293 let TypeFLowInplaceStatefulProcessorPipeline { pipeline_name, data_type, error_type, number_of_processors } = parse_macro_input!(input as TypeFLowInplaceStatefulProcessorPipeline);
294 let generic_params: Vec<_> = (0..number_of_processors)
295 .map(|i| format_ident!("P{}", i))
296 .collect();
297 let field_names: Vec<_> = (0..number_of_processors)
298 .map(|i| format_ident!("p{}", i))
299 .collect();
300 let fields_with_types: Vec<_> = field_names.iter().zip(generic_params.iter())
301 .map(|(field, param)| quote! { #field: #param })
302 .collect();
303 let processor_args = fields_with_types.into_iter().reduce(|acc, item| {
304 quote! { #acc, #item }
305 });
306 let expanded = quote! {
307 type_flow_macros::type_flow_inplace_stateful_processor_pipeline!(
308 #pipeline_name,
309 #data_type,
310 #error_type,
311 #processor_args
312 );
313 };
314 TokenStream::from(expanded)
315}
316
317
318#[proc_macro]
319pub fn inplace_stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream {
320 let ProcessorInplacePipeline { pipeline_name, data_type, error_type, processors } = parse_macro_input!(input as ProcessorInplacePipeline);
321 let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
322 let type_str = get_type_name(processor_type);
323 let field_name = format_ident!("processor_{}_{}", type_str, idx);
324 quote! { #field_name: #processor_type }
325 });
326 let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
327 let type_str = get_type_name(processor_type);
328 let param_name = format_ident!("processor_{}_{}", type_str, idx);
329 quote! { #param_name: #processor_type }
330 });
331 let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
332 let type_str = get_type_name(processor_type);
333 let field_name = format_ident!("processor_{}_{}", type_str, idx);
334 quote! { #field_name }
335 });
336 let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
337 let type_str = get_type_name(processor_type);
338 let field_name = format_ident!("processor_{}_{}", type_str, idx);
339 quote! { self.#field_name.process(data)?; }
340 });
341 let expanded = quote! {
342 pub struct #pipeline_name {
343 #(#struct_fields,)*
344 }
345 impl #pipeline_name {
346 pub fn new(#(#constructor_params,)*) -> Self {
347 Self {
348 #(#field_initializers,)*
349 }
350 }
351 }
352 impl crate::InPlaceStatefulProcessor<#data_type, #error_type> for #pipeline_name {
353 fn process(&mut self, data: &mut #data_type) -> Result<(), #error_type> {
354 #(#process_implementation)*
355 Ok(())
356 }
357 }
358 };
359 TokenStream::from(expanded)
360}
361
362fn get_type_name(ty: &syn::Type) -> String {
364 match ty {
365 syn::Type::Path(type_path) if !type_path.path.segments.is_empty() => {
366 let segment = type_path.path.segments.last().unwrap();
368 let name = segment.ident.to_string();
369
370 name.chars()
372 .map(|c| if c.is_alphanumeric() { c } else { '_' })
373 .collect()
374 },
375 _ => "unknown_type".to_string(),
377 }
378}