1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToSnakeCase;
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use quote::TokenStreamExt;
7use syn::{
8 Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, Type, TypePath, ext,
9 parse_macro_input, parse_quote,
10};
11
12#[derive(Debug, Clone, Hash, PartialEq, Eq)]
13struct ConversionSort {
14 sort_number: usize,
15 ty: ConversionType,
16}
17impl Ord for ConversionSort {
18 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
19 self.sort_number.cmp(&other.sort_number)
20 }
21}
22impl PartialOrd for ConversionSort {
23 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
24 Some(self.cmp(other))
25 }
26}
27impl quote::ToTokens for ConversionSort {
28 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
29 self.ty.to_tokens(tokens)
30 }
31}
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
33enum ConversionType {
34 Type(syn::Type),
35 Generic {
36 generic_ident: Vec<syn::Ident>,
37 ty: syn::Type,
38 },
39}
40impl syn::parse::Parse for ConversionType {
41 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
42 if input.peek(syn::Ident) && input.peek2(syn::Token![=]) {
43 let generic_ident = input.parse()?;
44 let generic_ident = vec![generic_ident];
45 let _: syn::Token![=] = input.parse()?;
46 let ty = input.parse()?;
47 Ok(ConversionType::Generic { generic_ident, ty })
48 } else if input.peek(syn::Ident) && input.peek2(syn::Token![,]) {
49 let mut generic_ident = Vec::with_capacity(2);
50 loop {
51 let generic = input.parse::<syn::Ident>()?;
52 generic_ident.push(generic);
53 if input.parse::<syn::Token![,]>().is_err() {
54 break;
55 }
56 }
57 let _: syn::Token![=] = input.parse()?;
58 let ty = input.parse()?;
59 Ok(ConversionType::Generic { generic_ident, ty })
60 } else {
61 input.parse().map(ConversionType::Type)
62 }
63 }
64}
65impl quote::ToTokens for ConversionType {
66 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
67 match self {
68 ConversionType::Type(ty) => {
69 tokens.append_all(ty.to_token_stream());
70 }
71 ConversionType::Generic { ty, .. } => {
72 tokens.append_all(quote::quote!(#ty));
73 }
74 }
75 }
76}
77
78#[proc_macro_derive(StateFilterConversion, attributes(conversion))]
146pub fn state_filter_conversion(input: TokenStream) -> TokenStream {
147 let ast = parse_macro_input!(input as syn::DeriveInput);
148 let name = &ast.ident;
149 let state_conversions = match &ast.data {
150 syn::Data::Struct(s) => {
151 let fields_count = s.fields.len();
152 let mut state_conversions = Vec::with_capacity(fields_count);
153 let (iter, extra_fields_count) = {
154 let mut iter: Vec<_> = s
155 .fields
156 .iter()
157 .enumerate()
158 .map(|(i, field)| {
159 let field_name = field.ident.as_ref().expect("expected a named field");
160 let mut all_conversion_fields = Vec::new();
161 all_conversion_fields.push((
162 field_name.clone(),
163 ConversionSort {
164 sort_number: i,
165 ty: ConversionType::Type(field.ty.clone()),
166 },
167 extract_generics_from_type(&field.ty, &ast.generics),
168 ));
169 for attr in field
170 .attrs
171 .iter()
172 .filter(|attr| attr.path().is_ident("conversion"))
173 {
174 let f = attr
175 .parse_args::<ConversionType>()
176 .expect("expected a conversion type");
177 let generics = match &f {
178 ConversionType::Type(ty) => {
179 extract_generics_from_type(ty, &ast.generics)
180 }
181 ConversionType::Generic { generic_ident, .. } => {
182 parse_quote!(<#(#generic_ident),*>)
183 }
184 };
185 all_conversion_fields.push((
186 field_name.clone(),
187 ConversionSort {
188 sort_number: i,
189 ty: f,
190 },
191 generics,
192 ));
193 }
194 all_conversion_fields
195 })
196 .collect();
197 let extra_struct_fields: Vec<_> = ast
198 .attrs
199 .into_iter()
200 .filter(|attr| attr.path().is_ident("conversion"))
201 .enumerate()
202 .map(|(i, attr)| {
203 let f = attr
204 .parse_args::<ConversionType>()
205 .expect("expected a conversion type");
206 let (field_name, generics) = match &f {
207 ConversionType::Type(ty) => {
208 let ident = type_to_ident(ty);
209 (
210 quote::format_ident!("{}", ident.to_string().to_snake_case()),
211 extract_generics_from_type(ty, &ast.generics),
212 )
213 }
214 ConversionType::Generic { generic_ident, ty } => {
215 let ident = type_to_ident(ty);
216 (
217 quote::format_ident!("{}", ident.to_string().to_snake_case()),
218 parse_quote!(<#(#generic_ident),*>),
219 )
220 }
221 };
222 vec![(
224 field_name,
225 ConversionSort {
226 sort_number: i + iter.len(),
227 ty: f,
228 },
229 generics,
230 )]
231 })
232 .collect();
233 let extra_fields_count = extra_struct_fields.len();
234 iter.extend(extra_struct_fields);
235 (iter, extra_fields_count)
236 };
237 let mut combination_names = HashMap::new();
238 let mut remainder_names = HashMap::new();
239 let mut i = 0;
240 for powerset in iter.iter().powerset() {
241 for (field_names, mut field_types, field_generics) in
242 powerset.into_iter().multi_cartesian_product().map(|f| {
243 let mut field_names = Vec::with_capacity(f.len());
244 let mut field_types = Vec::with_capacity(f.len());
245 let mut generics = Vec::with_capacity(f.len());
246 for (field_name, field_type, field_generics) in f {
247 field_names.push(field_name);
248 field_types.push(field_type.clone());
249 generics.push(field_generics);
250 }
251 (field_names, field_types, generics)
252 })
253 {
254 let combination_struct_name =
255 quote::format_ident!("__StateValidationGeneration_{name}Combined_{i}");
256 let mut generics = Generics::default();
257 for g in field_generics {
258 generics = merge_generics(generics, g);
259 }
260 let q = quote::quote! {
261 pub struct #combination_struct_name #generics {
262 #(pub #field_names: #field_types),*
263 }
264 };
265 state_conversions.push(q);
266 field_types.sort();
267 combination_names.insert(field_types, combination_struct_name);
268 i += 1;
269 }
270 }
271 let mut i = 0;
272 for powerset in iter.iter().powerset() {
273 for (field_names, mut field_types, field_generics) in
274 powerset.into_iter().multi_cartesian_product().map(|f| {
275 let mut field_names = Vec::with_capacity(f.len());
276 let mut field_types = Vec::with_capacity(f.len());
277 let mut generics = Vec::with_capacity(f.len());
278 for (field_name, field_type, field_generics) in f {
279 field_names.push(field_name);
280 field_types.push(field_type.clone());
281 generics.push(field_generics);
282 }
283 (field_names, field_types, generics)
284 })
285 {
286 let remainder_struct_name =
287 quote::format_ident!("__StateValidationGeneration_{name}Remainder_{i}");
288 let mut generics = Generics::default();
289 for g in field_generics {
290 generics = merge_generics(generics, g);
291 }
292 let q = quote::quote! {
293 pub struct #remainder_struct_name #generics {
294 #(#field_names: #field_types),*
295 }
296 };
297 state_conversions.push(q);
298 field_types.sort();
299 remainder_names.insert(field_types, remainder_struct_name);
300 i += 1;
301 }
302 }
303 create_original_conversion_combinations(
304 &mut state_conversions,
305 &ast.generics,
306 &combination_names,
307 &remainder_names,
308 name,
309 &s.fields,
310 ast.generics.clone(),
311 );
312 let cartesian_product = iter.iter().multi_cartesian_product().map(|f| {
313 let mut field_names = Vec::with_capacity(f.len());
314 let mut field_types = Vec::with_capacity(f.len());
315 let mut generics = Vec::with_capacity(f.len());
316 for (field_name, field_type, field_generics) in f {
317 field_names.push(field_name);
318 field_types.push(field_type);
319 generics.push(field_generics);
320 }
321 (field_names, field_types, generics)
322 });
323 for (k, (field_names, field_types, field_generics)) in cartesian_product.enumerate() {
324 let mut all_field_generics = Generics::default();
325 for field_generics in field_generics.iter() {
326 all_field_generics = merge_generics(all_field_generics, field_generics);
327 }
328 let fields_name_type_generics: Vec<_> = field_names
329 .clone()
330 .into_iter()
331 .zip(field_types.clone().into_iter())
332 .zip(field_generics.clone().into_iter())
333 .collect();
334 for count in 0..=(fields_count + extra_fields_count) {
335 for f in fields_name_type_generics.iter().combinations(count) {
336 for (
337 current_field_names,
338 current_field_types,
339 current_field_generics,
340 other_field_names,
341 other_field_types,
342 other_field_generics,
343 ) in f.into_iter().permutations(count).map(|subset| {
344 let remainder: Vec<_> = fields_name_type_generics
345 .iter()
346 .filter(|((field_name_a, ..), ..)| {
347 !subset.iter().any(|((field_name_b, ..), ..)| {
348 field_name_a == field_name_b
349 })
350 })
351 .collect();
352 let mut current_field_names = Vec::with_capacity(subset.len());
353 let mut current_field_types = Vec::with_capacity(subset.len());
354 let mut current_field_generics = Vec::with_capacity(subset.len());
355 for ((field_name, field_type), generics) in subset {
356 current_field_names.push((*field_name).clone());
357 current_field_types.push((*field_type).clone());
358 current_field_generics.push((*generics).clone());
359 }
360 let mut other_field_names = Vec::with_capacity(remainder.len());
361 let mut other_field_types = Vec::with_capacity(remainder.len());
362 let mut other_field_generics = Vec::with_capacity(remainder.len());
363 for ((field_name, field_type), generics) in remainder {
364 other_field_names.push((*field_name).clone());
365 other_field_types.push((*field_type).clone());
366 other_field_generics.push((*generics).clone());
367 }
368 (
369 current_field_names,
370 current_field_types,
371 current_field_generics,
372 other_field_names,
373 other_field_types,
374 other_field_generics,
375 )
376 }) {
377 let r = current_field_types
378 .iter()
379 .chain(other_field_types.iter())
380 .cloned()
381 .sorted()
382 .collect::<Vec<_>>();
383 let combined_struct_name = combination_names
384 .get(&r)
385 .expect(&format!("0: expected a combined struct: {:?}", r));
386 let remainder_struct_name = {
387 let mut other_field_types = other_field_types.clone();
388 other_field_types.sort();
389 remainder_names.get(&other_field_types).unwrap()
390 };
391 let mut o = Generics::default();
392 for other_field_generics in other_field_generics {
393 o = merge_generics(o, &other_field_generics);
394 }
395 let other_field_generics = o;
396 let q = quote::quote! {
397 impl #all_field_generics state_validation::StateFilterInputCombination<(#(#current_field_types),*)> for #remainder_struct_name #other_field_generics {
398 type Combined = #combined_struct_name #all_field_generics;
399 fn combine(self, (#(#current_field_names),*): (#(#current_field_types),*)) -> Self::Combined {
400 #combined_struct_name {
401 #(#current_field_names,)*
402 #(#other_field_names: self.#other_field_names),*
403 }
404 }
405 }
406 impl #all_field_generics state_validation::StateFilterInputConversion<(#(#current_field_types),*)> for #combined_struct_name #all_field_generics {
407 type Remainder = #remainder_struct_name #other_field_generics;
408 fn split_take(self) -> ((#(#current_field_types),*), Self::Remainder) {
409 (
410 (#(self.#current_field_names),*),
411 #remainder_struct_name {
412 #(#other_field_names: self.#other_field_names),*
413 },
414 )
415 }
416 }
417 };
418 state_conversions.push(q);
419 }
420 }
421 }
422 }
423 state_conversions
424 }
425 _ => todo!(),
426 };
427 quote::quote! {
428 #(#state_conversions)*
429 }
430 .into()
431}
432
433fn create_original_conversion_combinations(
434 state_conversions: &mut Vec<proc_macro2::TokenStream>,
435 original_generics: &Generics,
436 combination_names: &HashMap<Vec<ConversionSort>, Ident>,
437 remainder_names: &HashMap<Vec<ConversionSort>, Ident>,
438 name: &Ident,
439 fields: &syn::Fields,
440 mut all_field_generics: Generics,
441) {
442 let fields: Vec<_> = fields
443 .iter()
444 .enumerate()
445 .map(|(i, field)| {
446 let field_name = field.ident.as_ref().expect("expected a named field");
447 (
448 field_name,
449 ConversionSort {
450 sort_number: i,
451 ty: ConversionType::Type(field.ty.clone()),
452 },
453 extract_generics_from_type(&field.ty, original_generics),
454 )
455 })
456 .collect();
457 for (_, _, generics_b) in fields.iter() {
458 all_field_generics = merge_generics(all_field_generics, generics_b);
459 }
460 for k in 0..=fields.len() {
461 for combination in fields.iter().combinations(k) {
462 for (
463 current_field_names,
464 current_field_types,
465 current_field_generics,
466 other_field_names,
467 other_field_types,
468 other_field_generics,
469 ) in combination.into_iter().permutations(k).map(|subset| {
470 let remainder: Vec<_> = fields
471 .iter()
472 .filter(|(field_name_a, ..)| {
473 !subset
474 .iter()
475 .any(|(field_name_b, ..)| field_name_a == field_name_b)
476 })
477 .collect();
478 let mut current_field_names = Vec::with_capacity(subset.len());
479 let mut current_field_types = Vec::with_capacity(subset.len());
480 let mut current_field_generics = Vec::new();
481 for (field_name, field_type, generics) in subset {
482 current_field_names.push((*field_name).clone());
483 current_field_types.push(field_type.clone());
484 current_field_generics.push(generics.clone());
485 }
486 let mut other_field_names = Vec::with_capacity(remainder.len());
487 let mut other_field_types = Vec::with_capacity(remainder.len());
488 let mut other_field_generics = Vec::new();
489 for (field_name, field_type, generics) in remainder {
490 other_field_names.push((*field_name).clone());
491 other_field_types.push(field_type.clone());
492 other_field_generics.push(generics.clone());
493 }
494 (
495 current_field_names,
496 current_field_types,
497 current_field_generics,
498 other_field_names,
499 other_field_types,
500 other_field_generics,
501 )
502 }) {
503 let r = current_field_types
504 .iter()
505 .chain(other_field_types.iter())
506 .cloned()
507 .sorted()
508 .collect::<Vec<_>>();
509 let combined_struct_name = combination_names.get(&r).expect(&format!(
510 "1: expected a combined struct: {:#?}\nCOMBINATION NAMES: {:#?}",
511 r, combination_names,
512 ));
513 let remainder_struct_name = {
514 let mut other_field_types = other_field_types.clone();
515 other_field_types.sort();
516 remainder_names
517 .get(&other_field_types)
518 .expect("expected a remainder struct")
519 };
520 let mut current_field_generic = Generics::default();
521 for current_generics in current_field_generics {
522 current_field_generic =
523 merge_generics(current_field_generic, ¤t_generics);
524 }
525 let current_field_generics = current_field_generic;
526 let mut other_field_generic = Generics::default();
527 for other_generics in other_field_generics {
528 other_field_generic = merge_generics(other_field_generic, &other_generics);
529 }
530 let other_field_generics = other_field_generic;
531 let q = quote::quote! {
532 impl #all_field_generics state_validation::StateFilterInputConversion<(#(#current_field_types),*)> for #name #all_field_generics {
533 type Remainder = #remainder_struct_name #other_field_generics;
534 fn split_take(self) -> ((#(#current_field_types),*), Self::Remainder) {
535 (
536 (#(self.#current_field_names),*),
537 #remainder_struct_name {
538 #(#other_field_names: self.#other_field_names),*
539 },
540 )
541 }
542 }
543 };
544 state_conversions.push(q);
545 }
546 }
547 }
548}
549
550fn extract_generics_from_type(ty: &Type, original_generics: &Generics) -> Generics {
553 let mut type_params = BTreeSet::new();
554 let mut lifetime_params = BTreeSet::new();
555 let mut const_params = BTreeSet::new();
556
557 collect_generics(
558 ty,
559 original_generics,
560 &mut type_params,
561 &mut lifetime_params,
562 &mut const_params,
563 );
564
565 let mut generics = Generics::default();
566
567 for lt in lifetime_params {
568 generics
569 .params
570 .push(GenericParam::Lifetime(parse_quote!(#lt)));
571 }
572 for tp in type_params {
573 generics.params.push(GenericParam::Type(parse_quote!(#tp)));
574 }
575 for cp in const_params {
576 generics
577 .params
578 .push(GenericParam::Const(parse_quote!(const #cp: usize)));
579 }
580
581 generics
582}
583
584fn collect_generics(
585 ty: &Type,
586 original_generics: &Generics,
587 type_params: &mut BTreeSet<syn::Ident>,
588 lifetime_params: &mut BTreeSet<Lifetime>,
589 const_params: &mut BTreeSet<syn::Ident>,
590) {
591 match ty {
592 Type::Path(TypePath { path, .. }) => {
593 for segment in &path.segments {
594 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
596 for arg in &args.args {
597 match arg {
598 GenericArgument::Type(inner_ty) => {
599 if let Type::Path(p) = inner_ty
600 && let Some(ident) = p.path.get_ident()
601 && original_generics.type_params().any(|ty| ty.ident == *ident)
602 {
603 type_params.insert(ident.clone());
604 }
605 collect_generics(
606 inner_ty,
607 original_generics,
608 type_params,
609 lifetime_params,
610 const_params,
611 );
612 }
613 GenericArgument::Lifetime(lt) => {
614 lifetime_params.insert(lt.clone());
615 }
616 GenericArgument::Const(expr) => {
617 if let syn::Expr::Path(expr_path) = expr
618 && let Some(ident) = expr_path.path.get_ident()
619 {
620 const_params.insert(ident.clone());
621 }
622 }
623 _ => {}
624 }
625 }
626 }
627 }
628 }
629 Type::Reference(r) => {
630 if let Some(lt) = &r.lifetime {
631 lifetime_params.insert(lt.clone());
632 }
633 collect_generics(
634 &r.elem,
635 original_generics,
636 type_params,
637 lifetime_params,
638 const_params,
639 );
640 }
641 _ => {}
642 }
643}
644
645fn merge_generics(mut generics_a: Generics, generics_b: &Generics) -> Generics {
646 let mut existing = BTreeSet::new();
647 for param in &generics_a.params {
648 match param {
649 GenericParam::Type(tp) => {
650 existing.insert(tp.ident.to_string());
651 }
652 GenericParam::Lifetime(lt) => {
653 existing.insert(lt.lifetime.ident.to_string());
654 }
655 GenericParam::Const(cp) => {
656 existing.insert(cp.ident.to_string());
657 }
658 }
659 }
660
661 for param in &generics_b.params {
662 let name = match param {
663 GenericParam::Type(tp) => tp.ident.to_string(),
664 GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
665 GenericParam::Const(cp) => cp.ident.to_string(),
666 };
667 if !existing.contains(&name) {
668 generics_a.params.push(param.clone());
669 existing.insert(name);
670 }
671 }
672
673 match (&mut generics_a.where_clause, &generics_b.where_clause) {
674 (Some(a_wc), Some(b_wc)) => {
675 a_wc.predicates.extend(b_wc.predicates.clone());
676 }
677 (None, Some(b_wc)) => {
678 generics_a.where_clause = Some(b_wc.clone());
679 }
680 _ => {}
681 }
682
683 generics_a
684}
685
686fn type_to_ident(ty: &Type) -> &Ident {
687 match ty {
688 Type::Path(type_path) => type_path
689 .path
690 .segments
691 .last()
692 .map(|seg| &seg.ident)
693 .unwrap(),
694 _ => unimplemented!(),
695 }
696}