// state_validation_derive/lib.rs

1use std::collections::{BTreeSet, HashMap};
2
3use itertools::Itertools;
4use proc_macro::TokenStream;
5use quote::TokenStreamExt;
6use syn::{
7    Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, Type, TypePath,
8    parse_macro_input, parse_quote,
9};
10
11#[derive(Clone, PartialEq, Eq, Hash)]
12struct ConversionSort {
13    sort_number: usize,
14    ty: ConversionType,
15}
16impl Ord for ConversionSort {
17    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
18        self.sort_number.cmp(&other.sort_number)
19    }
20}
21impl PartialOrd for ConversionSort {
22    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
23        Some(self.cmp(other))
24    }
25}
26impl quote::ToTokens for ConversionSort {
27    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
28        self.ty.to_tokens(tokens)
29    }
30}
31#[derive(Clone, PartialEq, Eq, Hash)]
32enum ConversionType {
33    Type(syn::Type),
34    Generic {
35        generic_ident: Vec<syn::Ident>,
36        path: syn::Path,
37    },
38}
39impl syn::parse::Parse for ConversionType {
40    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
41        if input.peek(syn::Ident) && input.peek2(syn::Token![=]) {
42            let generic_ident = input.parse()?;
43            let _: syn::Token![=] = input.parse()?;
44            let path = input.parse()?;
45            Ok(ConversionType::Generic {
46                generic_ident: vec![generic_ident],
47                path,
48            })
49        } else {
50            input.parse().map(ConversionType::Type)
51        }
52    }
53}
54impl quote::ToTokens for ConversionType {
55    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
56        match self {
57            ConversionType::Type(ty) => {
58                tokens.append_all(ty.to_token_stream());
59            }
60            ConversionType::Generic {
61                generic_ident,
62                path,
63            } => {
64                tokens.append_all(quote::quote!(#path));
65            }
66        }
67    }
68}
69
70#[proc_macro_derive(StateFilterConversion, attributes(conversion))]
71pub fn state_filter_conversion(input: TokenStream) -> TokenStream {
72    let ast = parse_macro_input!(input as syn::DeriveInput);
73    let name = &ast.ident;
74    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
75    let generics: Vec<_> = ast
76        .generics
77        .type_params()
78        .into_iter()
79        .map(|ty| ty.ident.clone())
80        .collect();
81    let state_conversions = match &ast.data {
82        syn::Data::Struct(s) => {
83            let fields_count = s.fields.len();
84            let mut state_conversions = Vec::with_capacity(fields_count);
85            let iter: Vec<_> = s
86                .fields
87                .iter()
88                .enumerate()
89                .map(|(i, field)| {
90                    let field_name = field.ident.as_ref().expect("expected a named field");
91                    let mut all_conversion_fields = Vec::new();
92                    all_conversion_fields.push((
93                        field_name,
94                        ConversionSort {
95                            sort_number: i,
96                            ty: ConversionType::Type(field.ty.clone()),
97                        },
98                        extract_generics_from_type(&field.ty),
99                    ));
100                    for attr in field
101                        .attrs
102                        .iter()
103                        .filter(|attr| attr.path().is_ident("conversion"))
104                    {
105                        let f = attr
106                            .parse_args::<ConversionType>()
107                            .expect("expected a conversion type");
108                        let generics = match &f {
109                            ConversionType::Type(ty) => extract_generics_from_type(ty),
110                            ConversionType::Generic { generic_ident, .. } => {
111                                parse_quote!(<#(#generic_ident),*>)
112                            }
113                        };
114                        all_conversion_fields.push((
115                            field_name,
116                            ConversionSort {
117                                sort_number: i,
118                                ty: f,
119                            },
120                            generics,
121                        ));
122                    }
123                    all_conversion_fields
124                })
125                .collect();
126            let mut combination_names = HashMap::new();
127            let mut remainder_names = HashMap::new();
128            for (i, (field_names, mut field_types, field_generics)) in iter
129                .iter()
130                .multi_cartesian_product()
131                .map(|f| {
132                    let mut field_names = Vec::with_capacity(f.len());
133                    let mut field_types = Vec::with_capacity(f.len());
134                    let mut generics = Vec::with_capacity(f.len());
135                    for (field_name, field_type, field_generics) in f {
136                        field_names.push(*field_name);
137                        field_types.push(field_type.clone());
138                        generics.push(field_generics);
139                    }
140                    (field_names, field_types, generics)
141                })
142                .enumerate()
143            {
144                let combination_struct_name =
145                    quote::format_ident!("__StateValidationGeneration_{name}Combined_{i}");
146                let mut generics = Generics::default();
147                for g in field_generics {
148                    generics = merge_generics(generics, g);
149                }
150                let q = quote::quote! {
151                    pub struct #combination_struct_name #generics {
152                        #(pub #field_names: #field_types),*
153                    }
154                };
155                state_conversions.push(q);
156                field_types.sort();
157                combination_names.insert(field_types, combination_struct_name);
158            }
159            let mut i = 0;
160            for powerset in iter.iter().powerset() {
161                for (field_names, mut field_types, field_generics) in
162                    powerset.into_iter().multi_cartesian_product().map(|f| {
163                        let mut field_names = Vec::with_capacity(f.len());
164                        let mut field_types = Vec::with_capacity(f.len());
165                        let mut generics = Vec::with_capacity(f.len());
166                        for (field_name, field_type, field_generics) in f {
167                            field_names.push(*field_name);
168                            field_types.push(field_type.clone());
169                            generics.push(field_generics);
170                        }
171                        (field_names, field_types, generics)
172                    })
173                {
174                    let remainder_struct_name =
175                        quote::format_ident!("__StateValidationGeneration_{name}Remainder_{i}");
176                    let mut generics = Generics::default();
177                    for g in field_generics {
178                        generics = merge_generics(generics, g);
179                    }
180                    let q = quote::quote! {
181                        pub struct #remainder_struct_name #generics {
182                            #(#field_names: #field_types),*
183                        }
184                    };
185                    state_conversions.push(q);
186                    field_types.sort();
187                    remainder_names.insert(field_types, remainder_struct_name);
188                    i += 1;
189                }
190            }
191            create_original_conversion_combinations(
192                &mut state_conversions,
193                &combination_names,
194                &remainder_names,
195                name,
196                &s.fields,
197                generics,
198            );
199            let cartesian_product = iter.iter().multi_cartesian_product().map(|f| {
200                let mut field_names = Vec::with_capacity(f.len());
201                let mut field_types = Vec::with_capacity(f.len());
202                let mut generics = Vec::with_capacity(f.len());
203                for (field_name, field_type, field_generics) in f {
204                    field_names.push(field_name);
205                    field_types.push(field_type);
206                    generics.push(field_generics);
207                }
208                (field_names, field_types, generics)
209            });
210            for (k, (field_names, field_types, field_generics)) in cartesian_product.enumerate() {
211                let mut all_field_generics = Generics::default();
212                for field_generics in field_generics.iter() {
213                    all_field_generics = merge_generics(all_field_generics, field_generics);
214                }
215                let fields_name_type_generics: Vec<_> = field_names
216                    .clone()
217                    .into_iter()
218                    .zip(field_types.clone().into_iter())
219                    .zip(field_generics.clone().into_iter())
220                    .collect();
221                for count in 0..=fields_count {
222                    for f in fields_name_type_generics.iter().combinations(count) {
223                        for (
224                            current_field_names,
225                            current_field_types,
226                            current_field_generics,
227                            other_field_names,
228                            other_field_types,
229                            other_field_generics,
230                        ) in f.into_iter().permutations(count).map(|subset| {
231                            let remainder: Vec<_> = fields_name_type_generics
232                                .iter()
233                                .filter(|((field_name_a, ..), ..)| {
234                                    !subset.iter().any(|((field_name_b, ..), ..)| {
235                                        field_name_a == field_name_b
236                                    })
237                                })
238                                .collect();
239                            let mut current_field_names = Vec::with_capacity(subset.len());
240                            let mut current_field_types = Vec::with_capacity(subset.len());
241                            let mut current_field_generics = Vec::with_capacity(subset.len());
242                            for ((field_name, field_type), generics) in subset {
243                                current_field_names.push(**field_name);
244                                current_field_types.push((*field_type).clone());
245                                current_field_generics.push(generics);
246                            }
247                            let mut other_field_names = Vec::with_capacity(remainder.len());
248                            let mut other_field_types = Vec::with_capacity(remainder.len());
249                            let mut other_field_generics = Vec::with_capacity(remainder.len());
250                            for ((field_name, field_type), generics) in remainder {
251                                other_field_names.push(**field_name);
252                                other_field_types.push((*field_type).clone());
253                                other_field_generics.push(generics);
254                            }
255                            (
256                                current_field_names,
257                                current_field_types,
258                                current_field_generics,
259                                other_field_names,
260                                other_field_types,
261                                other_field_generics,
262                            )
263                        }) {
264                            let combined_struct_name = combination_names
265                                .get(
266                                    &current_field_types
267                                        .iter()
268                                        .chain(other_field_types.iter())
269                                        .cloned()
270                                        .sorted()
271                                        .collect::<Vec<_>>(),
272                                )
273                                .expect("0: expected a combined struct");
274                            let remainder_struct_name = {
275                                let mut other_field_types = other_field_types.clone();
276                                other_field_types.sort();
277                                remainder_names.get(&other_field_types).unwrap()
278                            };
279                            let mut o = Generics::default();
280                            for other_field_generics in other_field_generics {
281                                o = merge_generics(o, other_field_generics);
282                            }
283                            let other_field_generics = o;
284                            let q = quote::quote! {
285                                impl #all_field_generics state_validation::StateFilterInputCombination<(#(#current_field_types),*)> for #remainder_struct_name #other_field_generics {
286                                    type Combined = #combined_struct_name #all_field_generics;
287                                    fn combine(self, (#(#current_field_names),*): (#(#current_field_types),*)) -> Self::Combined {
288                                        #combined_struct_name {
289                                            #(#current_field_names,)*
290                                            #(#other_field_names: self.#other_field_names),*
291                                        }
292                                    }
293                                }
294                                impl #all_field_generics state_validation::StateFilterInputConversion<(#(#current_field_types),*)> for #combined_struct_name #all_field_generics {
295                                    type Remainder = #remainder_struct_name #other_field_generics;
296                                    fn split_take(self) -> ((#(#current_field_types),*), Self::Remainder) {
297                                        (
298                                            (#(self.#current_field_names),*),
299                                            #remainder_struct_name {
300                                                #(#other_field_names: self.#other_field_names),*
301                                            },
302                                        )
303                                    }
304                                }
305                            };
306                            state_conversions.push(q);
307                        }
308                    }
309                }
310            }
311            state_conversions
312        }
313        _ => todo!(),
314    };
315    quote::quote! {
316        //impl #impl_generics state_validation::StateFilterInput for #name #ty_generics #where_clause {}
317        #(#state_conversions)*
318    }
319    .into()
320}
321
322fn create_original_conversion_combinations(
323    state_conversions: &mut Vec<proc_macro2::TokenStream>,
324    combination_names: &HashMap<Vec<ConversionSort>, Ident>,
325    remainder_names: &HashMap<Vec<ConversionSort>, Ident>,
326    name: &Ident,
327    fields: &syn::Fields,
328    all_field_generics: Vec<Ident>,
329) {
330    let fields: Vec<_> = fields
331        .iter()
332        .enumerate()
333        .map(|(i, field)| {
334            let field_name = field.ident.as_ref().expect("expected a named field");
335            (
336                field_name,
337                ConversionSort {
338                    sort_number: i,
339                    ty: ConversionType::Type(field.ty.clone()),
340                },
341                extract_generics_from_type(&field.ty),
342            )
343        })
344        .collect();
345    let mut all_field_generics: Generics = parse_quote!();
346    for (_, _, generics_b) in fields.iter() {
347        all_field_generics = merge_generics(all_field_generics, generics_b);
348    }
349    for k in 0..=fields.len() {
350        for combination in fields.iter().combinations(k) {
351            for (
352                current_field_names,
353                current_field_types,
354                current_field_generics,
355                other_field_names,
356                other_field_types,
357                other_field_generics,
358            ) in combination.into_iter().permutations(k).map(|subset| {
359                let remainder: Vec<_> = fields
360                    .iter()
361                    .filter(|(field_name_a, ..)| {
362                        !subset
363                            .iter()
364                            .any(|(field_name_b, ..)| field_name_a == field_name_b)
365                    })
366                    .collect();
367                let mut current_field_names = Vec::with_capacity(subset.len());
368                let mut current_field_types = Vec::with_capacity(subset.len());
369                let mut current_field_generics = Vec::new();
370                for (field_name, field_type, generics) in subset {
371                    current_field_names.push(*field_name);
372                    current_field_types.push((*field_type).clone());
373                    current_field_generics.push(generics);
374                }
375                let mut other_field_names = Vec::with_capacity(remainder.len());
376                let mut other_field_types = Vec::with_capacity(remainder.len());
377                let mut other_field_generics = Vec::new();
378                for (field_name, field_type, generics) in remainder {
379                    other_field_names.push(*field_name);
380                    other_field_types.push((*field_type).clone());
381                    other_field_generics.push(generics);
382                }
383                (
384                    current_field_names,
385                    current_field_types,
386                    current_field_generics,
387                    other_field_names,
388                    other_field_types,
389                    other_field_generics,
390                )
391            }) {
392                let combined_struct_name = combination_names
393                    .get(
394                        &current_field_types
395                            .iter()
396                            .chain(other_field_types.iter())
397                            .cloned()
398                            .sorted()
399                            .collect::<Vec<_>>(),
400                    )
401                    .expect("1: expected a combined struct");
402                let remainder_struct_name = {
403                    let mut other_field_types = other_field_types.clone();
404                    other_field_types.sort();
405                    remainder_names
406                        .get(&other_field_types)
407                        .expect("expected a remainder struct")
408                };
409                let mut current_field_generic = Generics::default();
410                for current_generics in current_field_generics {
411                    current_field_generic = merge_generics(current_field_generic, current_generics);
412                }
413                let current_field_generics = current_field_generic;
414                let mut other_field_generic = Generics::default();
415                for other_generics in other_field_generics {
416                    other_field_generic = merge_generics(other_field_generic, other_generics);
417                }
418                let other_field_generics = other_field_generic;
419                let q = quote::quote! {
420                    impl #all_field_generics state_validation::StateFilterInputConversion<(#(#current_field_types),*)> for #name #all_field_generics {
421                        type Remainder = #remainder_struct_name #other_field_generics;
422                        fn split_take(self) -> ((#(#current_field_types),*), Self::Remainder) {
423                            (
424                                (#(self.#current_field_names),*),
425                                #remainder_struct_name {
426                                    #(#other_field_names: self.#other_field_names),*
427                                },
428                            )
429                        }
430                    }
431                };
432                state_conversions.push(q);
433            }
434        }
435    }
436}
437
438// UTILITY //
439
440fn extract_generics_from_type(ty: &Type) -> Generics {
441    let mut type_params = BTreeSet::new();
442    let mut lifetime_params = BTreeSet::new();
443    let mut const_params = BTreeSet::new();
444
445    collect_generics(
446        ty,
447        &mut type_params,
448        &mut lifetime_params,
449        &mut const_params,
450    );
451
452    let mut generics = Generics::default();
453
454    for lt in lifetime_params {
455        generics
456            .params
457            .push(GenericParam::Lifetime(parse_quote!(#lt)));
458    }
459    for tp in type_params {
460        generics.params.push(GenericParam::Type(parse_quote!(#tp)));
461    }
462    for cp in const_params {
463        generics
464            .params
465            .push(GenericParam::Const(parse_quote!(const #cp: usize)));
466    }
467
468    generics
469}
470
471fn collect_generics(
472    ty: &Type,
473    type_params: &mut BTreeSet<syn::Ident>,
474    lifetime_params: &mut BTreeSet<Lifetime>,
475    const_params: &mut BTreeSet<syn::Ident>,
476) {
477    match ty {
478        Type::Path(TypePath { path, .. }) => {
479            for segment in &path.segments {
480                // Extract angle bracketed generics
481                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
482                    for arg in &args.args {
483                        match arg {
484                            GenericArgument::Type(inner_ty) => {
485                                collect_generics(
486                                    inner_ty,
487                                    type_params,
488                                    lifetime_params,
489                                    const_params,
490                                );
491                            }
492                            GenericArgument::Lifetime(lt) => {
493                                lifetime_params.insert(lt.clone());
494                            }
495                            GenericArgument::Const(expr) => {
496                                if let syn::Expr::Path(expr_path) = expr
497                                    && let Some(ident) = expr_path.path.get_ident()
498                                {
499                                    const_params.insert(ident.clone());
500                                }
501                            }
502                            _ => {}
503                        }
504                    }
505                }
506            }
507        }
508        Type::Reference(r) => {
509            if let Some(lt) = &r.lifetime {
510                lifetime_params.insert(lt.clone());
511            }
512            collect_generics(&r.elem, type_params, lifetime_params, const_params);
513        }
514        _ => {}
515    }
516}
517
518fn merge_generics(mut generics_a: Generics, generics_b: &Generics) -> Generics {
519    let mut existing = BTreeSet::new();
520    for param in &generics_a.params {
521        match param {
522            GenericParam::Type(tp) => {
523                existing.insert(tp.ident.to_string());
524            }
525            GenericParam::Lifetime(lt) => {
526                existing.insert(lt.lifetime.ident.to_string());
527            }
528            GenericParam::Const(cp) => {
529                existing.insert(cp.ident.to_string());
530            }
531        }
532    }
533
534    for param in &generics_b.params {
535        let name = match param {
536            GenericParam::Type(tp) => tp.ident.to_string(),
537            GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
538            GenericParam::Const(cp) => cp.ident.to_string(),
539        };
540        if !existing.contains(&name) {
541            generics_a.params.push(param.clone());
542            existing.insert(name);
543        }
544    }
545
546    match (&mut generics_a.where_clause, &generics_b.where_clause) {
547        (Some(a_wc), Some(b_wc)) => {
548            a_wc.predicates.extend(b_wc.predicates.clone());
549        }
550        (None, Some(b_wc)) => {
551            generics_a.where_clause = Some(b_wc.clone());
552        }
553        _ => {}
554    }
555
556    generics_a
557}