1use std::collections::{BTreeSet, HashMap};
2
3use itertools::Itertools;
4use proc_macro::TokenStream;
5use quote::TokenStreamExt;
6use syn::{
7 Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, Type, TypePath,
8 parse_macro_input, parse_quote,
9};
10
/// A field's candidate type tagged with the field's declaration index.
///
/// Vectors of these (one entry per field) are sorted and used as `HashMap`
/// keys to look up generated structs, so `sort_number` doubles as the sort key.
#[derive(Clone, PartialEq, Eq, Hash)]
struct ConversionSort {
    /// Index of the field in the deriving struct's declaration order.
    sort_number: usize,
    /// The candidate type emitted for that field.
    ty: ConversionType,
}
16impl Ord for ConversionSort {
17 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
18 self.sort_number.cmp(&other.sort_number)
19 }
20}
21impl PartialOrd for ConversionSort {
22 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
23 Some(self.cmp(other))
24 }
25}
26impl quote::ToTokens for ConversionSort {
27 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
28 self.ty.to_tokens(tokens)
29 }
30}
/// One candidate type for a field, as it appears in generated code.
#[derive(Clone, PartialEq, Eq, Hash)]
enum ConversionType {
    /// A concrete type: the field's declared type, or a `#[conversion(Ty)]`
    /// alternative.
    Type(syn::Type),
    /// From `#[conversion(T = path)]`: `path` is emitted as the type, and each
    /// ident in `generic_ident` is declared as a generic type parameter of the
    /// generated items that use this candidate.
    Generic {
        generic_ident: Vec<syn::Ident>,
        path: syn::Path,
    },
}
39impl syn::parse::Parse for ConversionType {
40 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
41 if input.peek(syn::Ident) && input.peek2(syn::Token![=]) {
42 let generic_ident = input.parse()?;
43 let _: syn::Token![=] = input.parse()?;
44 let path = input.parse()?;
45 Ok(ConversionType::Generic {
46 generic_ident: vec![generic_ident],
47 path,
48 })
49 } else {
50 input.parse().map(ConversionType::Type)
51 }
52 }
53}
54impl quote::ToTokens for ConversionType {
55 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
56 match self {
57 ConversionType::Type(ty) => {
58 tokens.append_all(ty.to_token_stream());
59 }
60 ConversionType::Generic {
61 generic_ident,
62 path,
63 } => {
64 tokens.append_all(quote::quote!(#path));
65 }
66 }
67 }
68}
69
70#[proc_macro_derive(StateFilterConversion, attributes(conversion))]
71pub fn state_filter_conversion(input: TokenStream) -> TokenStream {
72 let ast = parse_macro_input!(input as syn::DeriveInput);
73 let name = &ast.ident;
74 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
75 let generics: Vec<_> = ast
76 .generics
77 .type_params()
78 .into_iter()
79 .map(|ty| ty.ident.clone())
80 .collect();
81 let state_conversions = match &ast.data {
82 syn::Data::Struct(s) => {
83 let fields_count = s.fields.len();
84 let mut state_conversions = Vec::with_capacity(fields_count);
85 let iter: Vec<_> = s
86 .fields
87 .iter()
88 .enumerate()
89 .map(|(i, field)| {
90 let field_name = field.ident.as_ref().expect("expected a named field");
91 let mut all_conversion_fields = Vec::new();
92 all_conversion_fields.push((
93 field_name,
94 ConversionSort {
95 sort_number: i,
96 ty: ConversionType::Type(field.ty.clone()),
97 },
98 extract_generics_from_type(&field.ty),
99 ));
100 for attr in field
101 .attrs
102 .iter()
103 .filter(|attr| attr.path().is_ident("conversion"))
104 {
105 let f = attr
106 .parse_args::<ConversionType>()
107 .expect("expected a conversion type");
108 let generics = match &f {
109 ConversionType::Type(ty) => extract_generics_from_type(ty),
110 ConversionType::Generic { generic_ident, .. } => {
111 parse_quote!(<#(#generic_ident),*>)
112 }
113 };
114 all_conversion_fields.push((
115 field_name,
116 ConversionSort {
117 sort_number: i,
118 ty: f,
119 },
120 generics,
121 ));
122 }
123 all_conversion_fields
124 })
125 .collect();
126 let mut combination_names = HashMap::new();
127 let mut remainder_names = HashMap::new();
128 for (i, (field_names, mut field_types, field_generics)) in iter
129 .iter()
130 .multi_cartesian_product()
131 .map(|f| {
132 let mut field_names = Vec::with_capacity(f.len());
133 let mut field_types = Vec::with_capacity(f.len());
134 let mut generics = Vec::with_capacity(f.len());
135 for (field_name, field_type, field_generics) in f {
136 field_names.push(*field_name);
137 field_types.push(field_type.clone());
138 generics.push(field_generics);
139 }
140 (field_names, field_types, generics)
141 })
142 .enumerate()
143 {
144 let combination_struct_name =
145 quote::format_ident!("__StateValidationGeneration_{name}Combined_{i}");
146 let mut generics = Generics::default();
147 for g in field_generics {
148 generics = merge_generics(generics, g);
149 }
150 let q = quote::quote! {
151 pub struct #combination_struct_name #generics {
152 #(pub #field_names: #field_types),*
153 }
154 };
155 state_conversions.push(q);
156 field_types.sort();
157 combination_names.insert(field_types, combination_struct_name);
158 }
159 let mut i = 0;
160 for powerset in iter.iter().powerset() {
161 for (field_names, mut field_types, field_generics) in
162 powerset.into_iter().multi_cartesian_product().map(|f| {
163 let mut field_names = Vec::with_capacity(f.len());
164 let mut field_types = Vec::with_capacity(f.len());
165 let mut generics = Vec::with_capacity(f.len());
166 for (field_name, field_type, field_generics) in f {
167 field_names.push(*field_name);
168 field_types.push(field_type.clone());
169 generics.push(field_generics);
170 }
171 (field_names, field_types, generics)
172 })
173 {
174 let remainder_struct_name =
175 quote::format_ident!("__StateValidationGeneration_{name}Remainder_{i}");
176 let mut generics = Generics::default();
177 for g in field_generics {
178 generics = merge_generics(generics, g);
179 }
180 let q = quote::quote! {
181 pub struct #remainder_struct_name #generics {
182 #(#field_names: #field_types),*
183 }
184 };
185 state_conversions.push(q);
186 field_types.sort();
187 remainder_names.insert(field_types, remainder_struct_name);
188 i += 1;
189 }
190 }
191 create_original_conversion_combinations(
192 &mut state_conversions,
193 &combination_names,
194 &remainder_names,
195 name,
196 &s.fields,
197 generics,
198 );
199 let cartesian_product = iter.iter().multi_cartesian_product().map(|f| {
200 let mut field_names = Vec::with_capacity(f.len());
201 let mut field_types = Vec::with_capacity(f.len());
202 let mut generics = Vec::with_capacity(f.len());
203 for (field_name, field_type, field_generics) in f {
204 field_names.push(field_name);
205 field_types.push(field_type);
206 generics.push(field_generics);
207 }
208 (field_names, field_types, generics)
209 });
210 for (k, (field_names, field_types, field_generics)) in cartesian_product.enumerate() {
211 let mut all_field_generics = Generics::default();
212 for field_generics in field_generics.iter() {
213 all_field_generics = merge_generics(all_field_generics, field_generics);
214 }
215 let fields_name_type_generics: Vec<_> = field_names
216 .clone()
217 .into_iter()
218 .zip(field_types.clone().into_iter())
219 .zip(field_generics.clone().into_iter())
220 .collect();
221 for count in 0..=fields_count {
222 for f in fields_name_type_generics.iter().combinations(count) {
223 for (
224 current_field_names,
225 current_field_types,
226 current_field_generics,
227 other_field_names,
228 other_field_types,
229 other_field_generics,
230 ) in f.into_iter().permutations(count).map(|subset| {
231 let remainder: Vec<_> = fields_name_type_generics
232 .iter()
233 .filter(|((field_name_a, ..), ..)| {
234 !subset.iter().any(|((field_name_b, ..), ..)| {
235 field_name_a == field_name_b
236 })
237 })
238 .collect();
239 let mut current_field_names = Vec::with_capacity(subset.len());
240 let mut current_field_types = Vec::with_capacity(subset.len());
241 let mut current_field_generics = Vec::with_capacity(subset.len());
242 for ((field_name, field_type), generics) in subset {
243 current_field_names.push(**field_name);
244 current_field_types.push((*field_type).clone());
245 current_field_generics.push(generics);
246 }
247 let mut other_field_names = Vec::with_capacity(remainder.len());
248 let mut other_field_types = Vec::with_capacity(remainder.len());
249 let mut other_field_generics = Vec::with_capacity(remainder.len());
250 for ((field_name, field_type), generics) in remainder {
251 other_field_names.push(**field_name);
252 other_field_types.push((*field_type).clone());
253 other_field_generics.push(generics);
254 }
255 (
256 current_field_names,
257 current_field_types,
258 current_field_generics,
259 other_field_names,
260 other_field_types,
261 other_field_generics,
262 )
263 }) {
264 let combined_struct_name = combination_names
265 .get(
266 ¤t_field_types
267 .iter()
268 .chain(other_field_types.iter())
269 .cloned()
270 .sorted()
271 .collect::<Vec<_>>(),
272 )
273 .expect("0: expected a combined struct");
274 let remainder_struct_name = {
275 let mut other_field_types = other_field_types.clone();
276 other_field_types.sort();
277 remainder_names.get(&other_field_types).unwrap()
278 };
279 let mut o = Generics::default();
280 for other_field_generics in other_field_generics {
281 o = merge_generics(o, other_field_generics);
282 }
283 let other_field_generics = o;
284 let q = quote::quote! {
285 impl #all_field_generics state_validation::StateFilterInputCombination<(#(#current_field_types),*)> for #remainder_struct_name #other_field_generics {
286 type Combined = #combined_struct_name #all_field_generics;
287 fn combine(self, (#(#current_field_names),*): (#(#current_field_types),*)) -> Self::Combined {
288 #combined_struct_name {
289 #(#current_field_names,)*
290 #(#other_field_names: self.#other_field_names),*
291 }
292 }
293 }
294 impl #all_field_generics state_validation::StateFilterInputConversion<(#(#current_field_types),*)> for #combined_struct_name #all_field_generics {
295 type Remainder = #remainder_struct_name #other_field_generics;
296 fn split_take(self) -> ((#(#current_field_types),*), Self::Remainder) {
297 (
298 (#(self.#current_field_names),*),
299 #remainder_struct_name {
300 #(#other_field_names: self.#other_field_names),*
301 },
302 )
303 }
304 }
305 };
306 state_conversions.push(q);
307 }
308 }
309 }
310 }
311 state_conversions
312 }
313 _ => todo!(),
314 };
315 quote::quote! {
316 #(#state_conversions)*
318 }
319 .into()
320}
321
322fn create_original_conversion_combinations(
323 state_conversions: &mut Vec<proc_macro2::TokenStream>,
324 combination_names: &HashMap<Vec<ConversionSort>, Ident>,
325 remainder_names: &HashMap<Vec<ConversionSort>, Ident>,
326 name: &Ident,
327 fields: &syn::Fields,
328 all_field_generics: Vec<Ident>,
329) {
330 let fields: Vec<_> = fields
331 .iter()
332 .enumerate()
333 .map(|(i, field)| {
334 let field_name = field.ident.as_ref().expect("expected a named field");
335 (
336 field_name,
337 ConversionSort {
338 sort_number: i,
339 ty: ConversionType::Type(field.ty.clone()),
340 },
341 extract_generics_from_type(&field.ty),
342 )
343 })
344 .collect();
345 let mut all_field_generics: Generics = parse_quote!();
346 for (_, _, generics_b) in fields.iter() {
347 all_field_generics = merge_generics(all_field_generics, generics_b);
348 }
349 for k in 0..=fields.len() {
350 for combination in fields.iter().combinations(k) {
351 for (
352 current_field_names,
353 current_field_types,
354 current_field_generics,
355 other_field_names,
356 other_field_types,
357 other_field_generics,
358 ) in combination.into_iter().permutations(k).map(|subset| {
359 let remainder: Vec<_> = fields
360 .iter()
361 .filter(|(field_name_a, ..)| {
362 !subset
363 .iter()
364 .any(|(field_name_b, ..)| field_name_a == field_name_b)
365 })
366 .collect();
367 let mut current_field_names = Vec::with_capacity(subset.len());
368 let mut current_field_types = Vec::with_capacity(subset.len());
369 let mut current_field_generics = Vec::new();
370 for (field_name, field_type, generics) in subset {
371 current_field_names.push(*field_name);
372 current_field_types.push((*field_type).clone());
373 current_field_generics.push(generics);
374 }
375 let mut other_field_names = Vec::with_capacity(remainder.len());
376 let mut other_field_types = Vec::with_capacity(remainder.len());
377 let mut other_field_generics = Vec::new();
378 for (field_name, field_type, generics) in remainder {
379 other_field_names.push(*field_name);
380 other_field_types.push((*field_type).clone());
381 other_field_generics.push(generics);
382 }
383 (
384 current_field_names,
385 current_field_types,
386 current_field_generics,
387 other_field_names,
388 other_field_types,
389 other_field_generics,
390 )
391 }) {
392 let combined_struct_name = combination_names
393 .get(
394 ¤t_field_types
395 .iter()
396 .chain(other_field_types.iter())
397 .cloned()
398 .sorted()
399 .collect::<Vec<_>>(),
400 )
401 .expect("1: expected a combined struct");
402 let remainder_struct_name = {
403 let mut other_field_types = other_field_types.clone();
404 other_field_types.sort();
405 remainder_names
406 .get(&other_field_types)
407 .expect("expected a remainder struct")
408 };
409 let mut current_field_generic = Generics::default();
410 for current_generics in current_field_generics {
411 current_field_generic = merge_generics(current_field_generic, current_generics);
412 }
413 let current_field_generics = current_field_generic;
414 let mut other_field_generic = Generics::default();
415 for other_generics in other_field_generics {
416 other_field_generic = merge_generics(other_field_generic, other_generics);
417 }
418 let other_field_generics = other_field_generic;
419 let q = quote::quote! {
420 impl #all_field_generics state_validation::StateFilterInputConversion<(#(#current_field_types),*)> for #name #all_field_generics {
421 type Remainder = #remainder_struct_name #other_field_generics;
422 fn split_take(self) -> ((#(#current_field_types),*), Self::Remainder) {
423 (
424 (#(self.#current_field_names),*),
425 #remainder_struct_name {
426 #(#other_field_names: self.#other_field_names),*
427 },
428 )
429 }
430 }
431 };
432 state_conversions.push(q);
433 }
434 }
435 }
436}
437
438fn extract_generics_from_type(ty: &Type) -> Generics {
441 let mut type_params = BTreeSet::new();
442 let mut lifetime_params = BTreeSet::new();
443 let mut const_params = BTreeSet::new();
444
445 collect_generics(
446 ty,
447 &mut type_params,
448 &mut lifetime_params,
449 &mut const_params,
450 );
451
452 let mut generics = Generics::default();
453
454 for lt in lifetime_params {
455 generics
456 .params
457 .push(GenericParam::Lifetime(parse_quote!(#lt)));
458 }
459 for tp in type_params {
460 generics.params.push(GenericParam::Type(parse_quote!(#tp)));
461 }
462 for cp in const_params {
463 generics
464 .params
465 .push(GenericParam::Const(parse_quote!(const #cp: usize)));
466 }
467
468 generics
469}
470
471fn collect_generics(
472 ty: &Type,
473 type_params: &mut BTreeSet<syn::Ident>,
474 lifetime_params: &mut BTreeSet<Lifetime>,
475 const_params: &mut BTreeSet<syn::Ident>,
476) {
477 match ty {
478 Type::Path(TypePath { path, .. }) => {
479 for segment in &path.segments {
480 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
482 for arg in &args.args {
483 match arg {
484 GenericArgument::Type(inner_ty) => {
485 collect_generics(
486 inner_ty,
487 type_params,
488 lifetime_params,
489 const_params,
490 );
491 }
492 GenericArgument::Lifetime(lt) => {
493 lifetime_params.insert(lt.clone());
494 }
495 GenericArgument::Const(expr) => {
496 if let syn::Expr::Path(expr_path) = expr
497 && let Some(ident) = expr_path.path.get_ident()
498 {
499 const_params.insert(ident.clone());
500 }
501 }
502 _ => {}
503 }
504 }
505 }
506 }
507 }
508 Type::Reference(r) => {
509 if let Some(lt) = &r.lifetime {
510 lifetime_params.insert(lt.clone());
511 }
512 collect_generics(&r.elem, type_params, lifetime_params, const_params);
513 }
514 _ => {}
515 }
516}
517
518fn merge_generics(mut generics_a: Generics, generics_b: &Generics) -> Generics {
519 let mut existing = BTreeSet::new();
520 for param in &generics_a.params {
521 match param {
522 GenericParam::Type(tp) => {
523 existing.insert(tp.ident.to_string());
524 }
525 GenericParam::Lifetime(lt) => {
526 existing.insert(lt.lifetime.ident.to_string());
527 }
528 GenericParam::Const(cp) => {
529 existing.insert(cp.ident.to_string());
530 }
531 }
532 }
533
534 for param in &generics_b.params {
535 let name = match param {
536 GenericParam::Type(tp) => tp.ident.to_string(),
537 GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
538 GenericParam::Const(cp) => cp.ident.to_string(),
539 };
540 if !existing.contains(&name) {
541 generics_a.params.push(param.clone());
542 existing.insert(name);
543 }
544 }
545
546 match (&mut generics_a.where_clause, &generics_b.where_clause) {
547 (Some(a_wc), Some(b_wc)) => {
548 a_wc.predicates.extend(b_wc.predicates.clone());
549 }
550 (None, Some(b_wc)) => {
551 generics_a.where_clause = Some(b_wc.clone());
552 }
553 _ => {}
554 }
555
556 generics_a
557}