1use proc_macro::TokenStream;
2use quote::{quote, format_ident};
3use syn::{parse_macro_input, Ident, Type, parse::Parse, parse::ParseStream, Token, ItemStruct, FieldsNamed, Fields};
4
5struct ProcessorPipeline {
6 pipeline_name: Ident,
7 data_type: Type,
8 processors: Vec<Type>,
9}
10
11impl Parse for ProcessorPipeline {
12 fn parse(input: ParseStream) -> syn::Result<Self> {
13 let pipeline_name: Ident = input.parse()?;
14 input.parse::<Token![,]>()?;
15 let data_type: Type = input.parse()?;
16 input.parse::<Token![,]>()?;
17 let mut processors = Vec::new();
18 while !input.is_empty() {
20 let processor_type: Type = input.parse()?;
21 processors.push(processor_type);
22 if input.is_empty() {
23 break;
24 }
25 input.parse::<Token![,]>()?;
26 }
27 Ok(ProcessorPipeline {
28 pipeline_name,
29 data_type,
30 processors,
31 })
32 }
33}
34
35#[proc_macro]
36pub fn stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream {
37 let ProcessorPipeline { pipeline_name, data_type, processors } = parse_macro_input!(input as ProcessorPipeline);
38 let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
39 let type_str = get_type_name(processor_type);
41 let field_name = format_ident!("processor_{}_{}", type_str, idx);
42 quote! { #field_name: #processor_type }
43 });
44 let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
45 let type_str = get_type_name(processor_type);
46 let param_name = format_ident!("processor_{}_{}", type_str, idx);
47 quote! { #param_name: #processor_type }
48 });
49 let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
50 let type_str = get_type_name(processor_type);
51 let field_name = format_ident!("processor_{}_{}", type_str, idx);
52 quote! { #field_name }
53 });
54 let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
55 let type_str = get_type_name(processor_type);
56 let field_name = format_ident!("processor_{}_{}", type_str, idx);
57 quote! { data = self.#field_name.process(data); }
58 });
59 let expanded = quote! {
60 pub struct #pipeline_name {
61 #(#struct_fields,)*
62 }
63 impl #pipeline_name {
64 pub fn new(#(#constructor_params,)*) -> Self {
65 Self {
66 #(#field_initializers,)*
67 }
68 }
69 }
70 impl crate::StatefulProcessor<#data_type> for #pipeline_name {
71 fn process(&mut self, mut data: #data_type) -> #data_type {
72 #(#process_implementation)*
73 data
74 }
75 }
76 };
77 TokenStream::from(expanded)
78}
79
80struct ProcessorInplacePipeline {
81 pipeline_name: Ident,
82 data_type: Type,
83 error_type: Type,
84 processors: Vec<Type>,
85}
86
87impl Parse for ProcessorInplacePipeline {
88 fn parse(input: ParseStream) -> syn::Result<Self> {
89 let pipeline_name: Ident = input.parse()?;
90 input.parse::<Token![,]>()?;
91 let data_type: Type = input.parse()?;
92 input.parse::<Token![,]>()?;
93 let error_type: Type = input.parse()?;
94 input.parse::<Token![,]>()?;
95 let mut processors = Vec::new();
96 while !input.is_empty() {
98 let processor_type: Type = input.parse()?;
99 processors.push(processor_type);
100
101 if input.is_empty() {
102 break;
103 }
104 input.parse::<Token![,]>()?;
105 }
106 Ok(ProcessorInplacePipeline{
107 pipeline_name,
108 data_type,
109 error_type,
110 processors,
111 })
112 }
113}
114
115
116#[proc_macro_attribute]
117pub fn implement_processor_swapping(_attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
118 let input_struct_item = parse_macro_input!(item as ItemStruct);
119 let struct_name = &input_struct_item.ident;
120 let generics = &input_struct_item.generics;
121 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
122 let type_param_idents: Vec<&Ident> = generics.type_params().map(|tp| &tp.ident).collect();
123 let fields = match &input_struct_item.fields {
124 Fields::Named(FieldsNamed { named, .. }) => named,
125 _ => panic!("#[implement_processor_swapping] only works on structs with named fields."),
126 };
127 let processor_field_idents: Vec<&Ident> = fields.iter()
128 .filter_map(|f| f.ident.as_ref())
129 .filter(|&id| id.to_string() != "_marker")
130 .collect();
131 if type_param_idents.len() != processor_field_idents.len() {
132 panic!(
133 "#[implement_processor_swapping] detected a mismatch between the number of generic type parameters ({}) and processor fields ({}). Expected them to be equal.",
134 type_param_idents.len(),
135 processor_field_idents.len()
136 );
137 }
138 let mut trait_impls = Vec::new();
140 if !type_param_idents.is_empty() {
141 let mut rev_types = type_param_idents.clone();
143 rev_types.reverse();
144 let mut rev_source_fields = processor_field_idents.clone();
145 rev_source_fields.reverse();
146 let rev_assignments = processor_field_idents.iter().zip(rev_source_fields.iter()).map(|(dest, src)| {
147 quote!{ #dest: self.#src }
148 });
149 trait_impls.push(quote!{
150 impl #impl_generics type_flow_traits::Reverse for #struct_name #ty_generics #where_clause {
151 type Output = #struct_name<#(#rev_types),*>;
152 fn reverse(self) -> Self::Output {
153 Self::Output {
154 _marker: std::marker::PhantomData,
155 #(#rev_assignments),*
156 }
157 }
158 }
159 });
160 let mut sl_types = type_param_idents.clone();
162 sl_types.rotate_left(1);
163 let mut sl_source_fields = processor_field_idents.clone();
164 sl_source_fields.rotate_left(1);
165 let sl_assignments = processor_field_idents.iter().zip(sl_source_fields.iter()).map(|(dest, src)| {
166 quote!{ #dest: self.#src }
167 });
168 trait_impls.push(quote!{
169 impl #impl_generics type_flow_traits::ShiftLeft for #struct_name #ty_generics #where_clause {
170 type ShiftedLeft = #struct_name<#(#sl_types),*>;
171 fn shift_left(self) -> Self::ShiftedLeft {
172 Self::ShiftedLeft {
173 _marker: std::marker::PhantomData,
174 #(#sl_assignments),*
175 }
176 }
177 }
178 });
179 let mut sr_types = type_param_idents.clone();
181 sr_types.rotate_right(1);
182 let mut sr_source_fields = processor_field_idents.clone();
183 sr_source_fields.rotate_right(1);
184 let sr_assignments = processor_field_idents.iter().zip(sr_source_fields.iter()).map(|(dest, src)| {
185 quote!{ #dest: self.#src }
186 });
187 trait_impls.push(quote!{
188 impl #impl_generics type_flow_traits::ShiftRight for #struct_name #ty_generics #where_clause {
189 type ShiftedRight = #struct_name<#(#sr_types),*>;
190 fn shift_right(self) -> Self::ShiftedRight {
191 Self::ShiftedRight {
192 _marker: std::marker::PhantomData,
193 #(#sr_assignments),*
194 }
195 }
196 }
197 });
198 let mut sse_types = type_param_idents.clone();
200 let sse_length = sse_types.len();
201 if sse_length > 1 { sse_types.swap(0, sse_length - 1); }
202 let mut sse_source_fields = processor_field_idents.clone();
203 let sse_source_length = sse_source_fields.len();
204 if sse_source_length > 1 { sse_source_fields.swap(0, sse_source_length - 1); }
205 let sse_assignments = processor_field_idents.iter().zip(sse_source_fields.iter()).map(|(dest, src)| {
206 quote!{ #dest: self.#src }
207 });
208 trait_impls.push(quote!{
209 impl #impl_generics type_flow_traits::SwapStartEnd for #struct_name #ty_generics #where_clause {
210 type Output = #struct_name<#(#sse_types),*>;
211 fn swap(self) -> Self::Output {
212 Self::Output {
213 _marker: std::marker::PhantomData,
214 #(#sse_assignments),*
215 }
216 }
217 }
218 });
219 }
220 let num_processors = type_param_idents.len();
221 if num_processors < 2 {
222 let trait_def_placeholder = quote! {
223 pub trait SwapArbitraryProcessors<const I: usize, const J: usize> where Self: Sized {
224 type SwappedOutput;
225 fn swap_processors(self) -> Self::SwappedOutput;
226 }
227 };
228 return quote! {
229 #input_struct_item
230 #trait_def_placeholder
231 #(#trait_impls)*
232 }.into();
233 }
234 let mut generated_impls = Vec::new();
235 let swap_trait_name = format_ident!("SwapArbitraryProcessors");
236 for i in 0..num_processors {
237 for j in 0..num_processors {
238 if i == j { continue; }
239 let mut swapped_generic_params_for_type = type_param_idents.clone();
240 swapped_generic_params_for_type.swap(i, j);
241 let field_i_name = processor_field_idents[i];
242 let field_j_name = processor_field_idents[j];
243 let mut current_field_initializers = Vec::new();
244 for k in 0..num_processors {
245 let current_field_name = processor_field_idents[k];
246 if k == i {
247 current_field_initializers.push(quote! { #current_field_name: self.#field_j_name });
248 } else if k == j {
249 current_field_initializers.push(quote! { #current_field_name: self.#field_i_name });
250 } else {
251 current_field_initializers.push(quote! { #current_field_name: self.#current_field_name });
252 }
253 }
254 generated_impls.push(quote! {
255 impl #impl_generics #swap_trait_name<#i, #j> for #struct_name #ty_generics #where_clause {
256 type SwappedOutput = #struct_name<#(#swapped_generic_params_for_type),*>;
257 fn swap_processors(self) -> Self::SwappedOutput {
258 #struct_name {
259 _marker: std::marker::PhantomData,
260 #( #current_field_initializers ),*
261 }
262 }
263 }
264 });
265 }
266 }
267
268 let trait_def = quote! {
269 pub trait #swap_trait_name<const I: usize, const J: usize>
270 where
271 Self: Sized,
272 {
273 type SwappedOutput;
274 fn swap_processors(self) -> Self::SwappedOutput;
275 }
276 };
277
278 let expanded = quote! {
279 #input_struct_item
280 #trait_def
281 #(#generated_impls)*
282 #(#trait_impls)*
283 };
284 proc_macro::TokenStream::from(expanded)
285}
286
287#[proc_macro]
288pub fn inplace_stateful_processor_pipeline_with_index(input: TokenStream) -> TokenStream {
289 let ProcessorInplacePipeline { pipeline_name, data_type, error_type, processors } = parse_macro_input!(input as ProcessorInplacePipeline);
290 let struct_fields = processors.iter().enumerate().map(|(idx, processor_type)| {
291 let type_str = get_type_name(processor_type);
293 let field_name = format_ident!("processor_{}_{}", type_str, idx);
294 quote! { #field_name: #processor_type }
295 });
296 let constructor_params = processors.iter().enumerate().map(|(idx, processor_type)| {
297 let type_str = get_type_name(processor_type);
298 let param_name = format_ident!("processor_{}_{}", type_str, idx);
299 quote! { #param_name: #processor_type }
300 });
301 let field_initializers = processors.iter().enumerate().map(|(idx, processor_type)| {
302 let type_str = get_type_name(processor_type);
303 let field_name = format_ident!("processor_{}_{}", type_str, idx);
304 quote! { #field_name }
305 });
306 let process_implementation = processors.iter().enumerate().map(|(idx, processor_type)| {
307 let type_str = get_type_name(processor_type);
308 let field_name = format_ident!("processor_{}_{}", type_str, idx);
309 quote! { self.#field_name.process(data)?; }
310 });
311 let expanded = quote! {
312 pub struct #pipeline_name {
313 #(#struct_fields,)*
314 }
315 impl #pipeline_name {
316 pub fn new(#(#constructor_params,)*) -> Self {
317 Self {
318 #(#field_initializers,)*
319 }
320 }
321 }
322 impl crate::InPlaceStatefulProcessor<#data_type, #error_type> for #pipeline_name {
323 fn process(&mut self, data: &mut #data_type) -> Result<(), #error_type> {
324 #(#process_implementation)*
325 Ok(())
326 }
327 }
328 };
329 TokenStream::from(expanded)
330}
331
332fn get_type_name(ty: &syn::Type) -> String {
334 match ty {
335 syn::Type::Path(type_path) if !type_path.path.segments.is_empty() => {
336 let segment = type_path.path.segments.last().unwrap();
338 let name = segment.ident.to_string();
339
340 name.chars()
342 .map(|c| if c.is_alphanumeric() { c } else { '_' })
343 .collect()
344 },
345 _ => "unknown_type".to_string(),
347 }
348}