tagged_hybrid/
lib.rs

1use std::collections::{HashMap, HashSet};
2
3use heck::ToSnakeCase;
4use itertools::{izip, Itertools};
5use proc_macro::TokenStream;
6use proc_macro2::{Group, Ident, TokenStream as TokenStream2, TokenTree};
7use quote::{quote, ToTokens};
8use smallvec::{smallvec, SmallVec};
9use syn::{
10    Data, DeriveInput, Fields, FieldsNamed, GenericArgument, Lifetime, Path, PathArguments, Type,
11    TypeParamBound,
12};
13use tap::{Pipe, Tap};
14
15#[proc_macro_attribute]
16pub fn hybrid_tagged(attr: TokenStream, item: TokenStream) -> TokenStream {
17    hybrid_tagged_impl(attr.into(), item.into()).into()
18}
19
20fn hybrid_tagged_impl(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
21    let tagged_type: DeriveInput = syn::parse2(item).unwrap();
22    let Data::Enum(tagged_enum) = tagged_type.data else {
23        panic!("hybrid_tagged is meant to be invoked on an enum")
24    };
25
26    let args = attr_args(attr);
27
28    let common_fields = args
29            .get("fields")
30            .expect("Argument `fields` was not provided")
31            .pipe(|tokens| {
32                syn::parse2::<FieldsNamed>(tokens.to_token_stream())
33                    .expect("Fields should be written with the same notation as a struct declaration, inside curly braces")
34            });
35    let common_fields_inner = &common_fields.named;
36    let tag = args
37        .get("tag")
38        .expect("Argument `tag` was not provided")
39        .to_token_stream();
40    let variants = tagged_enum.variants;
41    let generics = tagged_type.generics;
42    // whichever variants we do not have object data to collect for
43    let empty_variants = variants
44        .iter()
45        .cloned()
46        .filter(|variant| matches!(variant.fields, Fields::Unit))
47        .map(|variant| variant.ident)
48        .collect::<HashSet<_>>();
49
50    let container_name = tagged_type.ident;
51    let module_name = Ident::new(
52        &format!("{}_data", container_name.to_string().to_snake_case()),
53        container_name.span(),
54    );
55    let data_enum_name = Ident::new(&format!("{container_name}Data"), container_name.span());
56
57    let original_attrs = tagged_type.attrs;
58
59    let visibility = tagged_type.vis;
60
61    let variant_lifetimes = variants
62        .iter()
63        .map(|variant| {
64            variant
65                .fields
66                .iter()
67                .flat_map(|fld| type_lifetimes(&fld.ty))
68                .collect::<HashSet<_>>()
69                .pipe(|lifetimes| {
70                    (lifetimes.len() > 0).then(|| {
71                        let it = lifetimes.iter();
72                        Some(quote!(<#(#it),*>))
73                    })
74                })
75        })
76        .collect_vec();
77
78    // takes the variants of the annotated enum and adds the common fields to each one
79    let raw_variants = variants.clone().tap_mut(|variants| {
80        variants
81            .iter_mut()
82            .zip(variant_lifetimes.iter())
83            .for_each(|(variant, lft)| {
84                let attrs = &variant.attrs;
85                let name = &variant.ident;
86                let common_fields = common_fields_inner.iter();
87                let borrow_attr = lft.is_some().then(|| quote!(#[serde(borrow)]));
88
89                *variant = if empty_variants.contains(&variant.ident) {
90                    syn::parse_quote!(
91                        #(#attrs)*
92                        #name {
93                            #(#common_fields),*
94                        }
95                    )
96                } else {
97                    syn::parse_quote!(
98                        #(#attrs)*
99                        #name {
100                            #borrow_attr data: #name #lft,
101                            #(#common_fields),*
102                        }
103                    )
104                };
105            })
106    });
107
108    let data_variants = variants.clone().tap_mut(|variants| {
109        variants.iter_mut().for_each(|variant| {
110            variant.attrs.clear();
111            match variant.fields {
112                Fields::Named(ref mut f) => {
113                    for field in &mut f.named {
114                        field.attrs.clear();
115                    }
116                }
117                _ => (),
118            }
119        })
120    });
121
122    let struct_attrs = args.get("struct_attrs").map(|tokens| {
123        syn::parse2::<Group>(tokens.into_token_stream())
124            .unwrap()
125            .stream()
126    });
127
128    // raw enum which serde directly translates from json
129    let raw_enum = quote!(
130        #[derive(serde::Serialize, serde::Deserialize)]
131        #[serde(tag=#tag)]
132        #(#original_attrs)*
133        enum Raw #generics {
134            #raw_variants
135        }
136    );
137
138    // data enum containing the specific data for each variant
139    let data_enum = quote!(
140        #[derive(Clone)]
141        #struct_attrs
142        #visibility enum #data_enum_name #generics {
143            #data_variants
144        }
145    );
146
147    // public-facing struct which takes the place of the annotated enum
148    let common_fields_visibility = common_fields_inner.clone().tap_mut(|fields| {
149        fields
150            .iter_mut()
151            .for_each(|field| field.vis = visibility.clone())
152    });
153    let public_struct = {
154        let borrow_attr = if generics.lifetimes().next().is_some() {
155            quote!(#[serde(borrow)])
156        } else {
157            quote!()
158        };
159        quote!(
160            #[derive(serde::Serialize, serde::Deserialize, Clone)]
161            #[serde(from = "Raw", into = "Raw")]
162            #struct_attrs
163            #visibility struct #container_name #generics {
164                #borrow_attr pub data: #data_enum_name #generics,
165                #common_fields_visibility
166            }
167        )
168    };
169
170    let common_fields_names = common_fields_inner
171        .iter()
172        .cloned()
173        .map(|field| field.ident.expect("Fields of this enum must be named"))
174        .collect_vec();
175    let common_fields_renamed = common_fields_names
176        .iter()
177        .cloned()
178        .map(|name| syn::parse_str::<Ident>(&format!("c_{name}")).unwrap())
179        .collect_vec(); // TODO: Make these names hygienic
180    let variant_fields_names = variants
181        .iter()
182        .map(|variant| {
183            variant
184                .fields
185                .iter()
186                .map(|field| field.ident.clone().unwrap())
187                .collect_vec()
188        })
189        .collect_vec();
190    let raw_fields_names = raw_variants
191        .iter()
192        .map(|variant| {
193            variant
194                .fields
195                .iter()
196                .map(|field| field.ident.clone().unwrap())
197                .collect_vec()
198        })
199        .collect_vec();
200
201    let variant_names = variants
202        .iter()
203        .cloned()
204        .map(|variant| variant.ident)
205        .collect_vec();
206
207    let variant_structs = variants.iter().zip(variant_lifetimes.iter()).map(|(variant, lft)| {
208            let name = &variant.ident;
209            let fields = &variant.fields;
210
211            if empty_variants.contains(&variant.ident) {
212                quote!( #[derive(serde::Serialize, serde::Deserialize)] #struct_attrs struct #name; )
213            } else {
214                quote!( #[derive(serde::Serialize, serde::Deserialize)] #struct_attrs struct #name #lft #fields )
215            }
216        });
217
218    let (convert_from_raw, convert_to_raw): (Vec<_>, Vec<_>) =
219        izip!(variant_fields_names, raw_fields_names)
220            .zip(variant_names)
221            .map(|((variant, _), variant_name)| {
222                if empty_variants.contains(&variant_name) {
223                    let from_raw = quote!(
224                        Raw :: #variant_name {
225                            #(#common_fields_names: #common_fields_renamed),*, ..
226                        } => Self {
227                            data: #data_enum_name :: #variant_name ,
228                            #(#common_fields_names: #common_fields_renamed),*
229                        }
230                    );
231
232                    let to_raw = quote!(
233                        #data_enum_name :: #variant_name => Self :: #variant_name {
234                            #(#common_fields_names: f. #common_fields_names)*,
235                        }
236                    );
237
238                    (from_raw, to_raw)
239                } else {
240                    let from_raw = quote!(
241                        Raw :: #variant_name {
242                            data: #variant_name {
243                                #(#variant),*
244                            },
245                            #(#common_fields_names: #common_fields_renamed),*
246                        } => Self {
247                            data: #data_enum_name :: #variant_name {
248                                #(#variant),*
249                            }, #(#common_fields_names: #common_fields_renamed),*
250                        }
251                    );
252
253                    let to_raw = quote!(
254                        #data_enum_name :: #variant_name {
255                            #(#variant),*
256                        } => Self :: #variant_name {
257                            data: #variant_name {
258                                #(#variant),*
259                            },
260                            #(#common_fields_names: f. #common_fields_names),*
261                        }
262                    );
263
264                    (from_raw, to_raw)
265                }
266            })
267            .unzip();
268
269    // From impls for converting to and from the public struct and private type
270    let convert_impls = quote!(
271        impl #generics From<Raw #generics> for #container_name #generics {
272            fn from(f: Raw #generics) -> Self {
273                match f {
274                    #(#convert_from_raw),*
275                }
276            }
277        }
278
279        impl #generics From<#container_name #generics > for Raw #generics {
280            fn from(f: #container_name #generics) -> Self {
281                match f.data {
282                    #(#convert_to_raw),*
283                }
284            }
285        }
286    );
287
288    // all put together
289    quote!(
290        #visibility use #module_name::{
291            #container_name,
292            #data_enum_name
293        };
294        mod #module_name {
295            use super::*;
296            #public_struct
297            #raw_enum
298            #data_enum
299
300            #(#variant_structs)*
301
302            #convert_impls
303        }
304    )
305}
306
307fn attr_args(attr: TokenStream2) -> HashMap<String, TokenTree> {
308    attr.into_iter()
309        .group_by(|tk| !matches!(tk, TokenTree::Punct(p) if p.as_char() == ','))
310        .into_iter()
311        .filter_map(|(cond, c)| cond.then(|| c))
312        .map(|mut triple| {
313            let ident = triple.next();
314            let eq_sign = triple.next();
315            let value = triple.next();
316
317            if !matches!(eq_sign, Some(TokenTree::Punct(eq_sign)) if eq_sign.as_char() == '=') {
318                panic!(r#"Attribute arguments should be in the form of `key = value`"#)
319            }
320
321            match (ident, value) {
322                (Some(TokenTree::Ident(ident)), Some(value)) => (ident.to_string(), value),
323                _ => panic!(r#"Attribute arguments should be in the form of `key = "value"`"#),
324            }
325        })
326        .collect()
327}
328
329/// This function *will* produce duplicate items! Don't forget to dedup before using!
330fn type_lifetimes(ty: &Type) -> SmallVec<[Lifetime; 8]> {
331    match ty {
332        Type::Array(a) => type_lifetimes(&*a.elem),
333        Type::Group(g) => type_lifetimes(&*g.elem),
334        Type::ImplTrait(t) => type_param_lifetimes(t.bounds.iter()),
335        Type::Paren(p) => type_lifetimes(&*p.elem),
336        Type::Path(p) => path_lifetimes(&p.path),
337        Type::Reference(r) => {
338            type_lifetimes(&*r.elem).tap_mut(|vec| vec.extend(r.lifetime.clone()))
339        }
340        Type::Slice(s) => type_lifetimes(&*s.elem),
341        Type::TraitObject(t) => type_param_lifetimes(t.bounds.iter()),
342        Type::Tuple(tup) => tup
343            .elems
344            .iter()
345            .flat_map(type_lifetimes)
346            .collect::<SmallVec<_>>(),
347        _ => smallvec![],
348    }
349}
350
351fn path_lifetimes(path: &Path) -> SmallVec<[Lifetime; 8]> {
352    path.segments
353        .iter()
354        .flat_map(|segment| {
355            if let PathArguments::AngleBracketed(ref args) = segment.arguments {
356                args.args
357                    .iter()
358                    .flat_map(|arg| match arg {
359                        GenericArgument::Lifetime(l) => smallvec![l.clone()],
360                        GenericArgument::Type(ty) => type_lifetimes(ty),
361                        GenericArgument::Constraint(con) => type_param_lifetimes(con.bounds.iter()),
362                        _ => smallvec![],
363                    })
364                    .collect_vec()
365            } else {
366                Vec::new()
367            }
368        })
369        .collect::<SmallVec<_>>()
370}
371
372fn type_param_lifetimes<'a>(
373    it: impl IntoIterator<Item = &'a TypeParamBound>,
374) -> SmallVec<[Lifetime; 8]> {
375    it.into_iter()
376        .flat_map(|bound| match bound {
377            TypeParamBound::Lifetime(lt) => smallvec![lt.clone()],
378            TypeParamBound::Trait(trt) => path_lifetimes(&trt.path),
379        })
380        .collect::<SmallVec<_>>()
381}
382
#[cfg(test)]
mod test {
    use crate::{hybrid_tagged_impl, type_lifetimes};
    use quote::quote;
    use syn::parse_quote;
    use tap::Tap;

    /// Smoke test: expands a representative enum and prints the generated code for manual
    /// inspection. The enum has struct variants, a unit variant, a borrowing variant and
    /// `struct_attrs`.
    #[test]
    fn test_hybrid_tagged_impl() {
        let attr = quote!(tag = "type", fields = {frame: Number, slack: Slack,}, struct_attrs = {
            #[derive(Debug)]
            #[serde(rename = "UPPERCASE")]
        });
        let item = quote!(
            #[derive(Debug)]
            #[serde(some_other_thing)]
            pub(super) enum Variations<'a> {
                A {
                    #[field_attribute]
                    task: T,
                    #[serde(borrow)]
                    time: U<'a>,
                },
                B {
                    hours: H,
                    intervals: I,
                },
                HasFrame {
                    frame: F,
                },
                C,
                // tuple variants such as `D(Wrong)` are not supported
            }
        );

        println!("{}", hybrid_tagged_impl(attr, item));
    }

    #[test]
    fn extract_lifetimes() {
        /// Whether a lifetime named `'name` occurs among `lifetimes`.
        fn mentions(lifetimes: &[syn::Lifetime], name: &str) -> bool {
            lifetimes
                .iter()
                .any(|lifetime| lifetime.ident.to_string() == name)
        }

        type_lifetimes(&parse_quote!(&'a Str<'b>)).tap(|found| {
            assert!(mentions(found, "a"));
            assert!(mentions(found, "b"));
        });

        type_lifetimes(&parse_quote!(impl Derive + Debug + Struct<'a> + 'b)).tap(|found| {
            assert!(mentions(found, "a"));
            assert!(mentions(found, "b"));
        });
    }
}