1use std::collections::{HashMap, HashSet};
2
3use heck::ToSnakeCase;
4use itertools::{izip, Itertools};
5use proc_macro::TokenStream;
6use proc_macro2::{Group, Ident, TokenStream as TokenStream2, TokenTree};
7use quote::{quote, ToTokens};
8use smallvec::{smallvec, SmallVec};
9use syn::{
10 Data, DeriveInput, Fields, FieldsNamed, GenericArgument, Lifetime, Path, PathArguments, Type,
11 TypeParamBound,
12};
13use tap::{Pipe, Tap};
14
15#[proc_macro_attribute]
16pub fn hybrid_tagged(attr: TokenStream, item: TokenStream) -> TokenStream {
17 hybrid_tagged_impl(attr.into(), item.into()).into()
18}
19
20fn hybrid_tagged_impl(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
21 let tagged_type: DeriveInput = syn::parse2(item).unwrap();
22 let Data::Enum(tagged_enum) = tagged_type.data else {
23 panic!("hybrid_tagged is meant to be invoked on an enum")
24 };
25
26 let args = attr_args(attr);
27
28 let common_fields = args
29 .get("fields")
30 .expect("Argument `fields` was not provided")
31 .pipe(|tokens| {
32 syn::parse2::<FieldsNamed>(tokens.to_token_stream())
33 .expect("Fields should be written with the same notation as a struct declaration, inside curly braces")
34 });
35 let common_fields_inner = &common_fields.named;
36 let tag = args
37 .get("tag")
38 .expect("Argument `tag` was not provided")
39 .to_token_stream();
40 let variants = tagged_enum.variants;
41 let generics = tagged_type.generics;
42 let empty_variants = variants
44 .iter()
45 .cloned()
46 .filter(|variant| matches!(variant.fields, Fields::Unit))
47 .map(|variant| variant.ident)
48 .collect::<HashSet<_>>();
49
50 let container_name = tagged_type.ident;
51 let module_name = Ident::new(
52 &format!("{}_data", container_name.to_string().to_snake_case()),
53 container_name.span(),
54 );
55 let data_enum_name = Ident::new(&format!("{container_name}Data"), container_name.span());
56
57 let original_attrs = tagged_type.attrs;
58
59 let visibility = tagged_type.vis;
60
61 let variant_lifetimes = variants
62 .iter()
63 .map(|variant| {
64 variant
65 .fields
66 .iter()
67 .flat_map(|fld| type_lifetimes(&fld.ty))
68 .collect::<HashSet<_>>()
69 .pipe(|lifetimes| {
70 (lifetimes.len() > 0).then(|| {
71 let it = lifetimes.iter();
72 Some(quote!(<#(#it),*>))
73 })
74 })
75 })
76 .collect_vec();
77
78 let raw_variants = variants.clone().tap_mut(|variants| {
80 variants
81 .iter_mut()
82 .zip(variant_lifetimes.iter())
83 .for_each(|(variant, lft)| {
84 let attrs = &variant.attrs;
85 let name = &variant.ident;
86 let common_fields = common_fields_inner.iter();
87 let borrow_attr = lft.is_some().then(|| quote!(#[serde(borrow)]));
88
89 *variant = if empty_variants.contains(&variant.ident) {
90 syn::parse_quote!(
91 #(#attrs)*
92 #name {
93 #(#common_fields),*
94 }
95 )
96 } else {
97 syn::parse_quote!(
98 #(#attrs)*
99 #name {
100 #borrow_attr data: #name #lft,
101 #(#common_fields),*
102 }
103 )
104 };
105 })
106 });
107
108 let data_variants = variants.clone().tap_mut(|variants| {
109 variants.iter_mut().for_each(|variant| {
110 variant.attrs.clear();
111 match variant.fields {
112 Fields::Named(ref mut f) => {
113 for field in &mut f.named {
114 field.attrs.clear();
115 }
116 }
117 _ => (),
118 }
119 })
120 });
121
122 let struct_attrs = args.get("struct_attrs").map(|tokens| {
123 syn::parse2::<Group>(tokens.into_token_stream())
124 .unwrap()
125 .stream()
126 });
127
128 let raw_enum = quote!(
130 #[derive(serde::Serialize, serde::Deserialize)]
131 #[serde(tag=#tag)]
132 #(#original_attrs)*
133 enum Raw #generics {
134 #raw_variants
135 }
136 );
137
138 let data_enum = quote!(
140 #[derive(Clone)]
141 #struct_attrs
142 #visibility enum #data_enum_name #generics {
143 #data_variants
144 }
145 );
146
147 let common_fields_visibility = common_fields_inner.clone().tap_mut(|fields| {
149 fields
150 .iter_mut()
151 .for_each(|field| field.vis = visibility.clone())
152 });
153 let public_struct = {
154 let borrow_attr = if generics.lifetimes().next().is_some() {
155 quote!(#[serde(borrow)])
156 } else {
157 quote!()
158 };
159 quote!(
160 #[derive(serde::Serialize, serde::Deserialize, Clone)]
161 #[serde(from = "Raw", into = "Raw")]
162 #struct_attrs
163 #visibility struct #container_name #generics {
164 #borrow_attr pub data: #data_enum_name #generics,
165 #common_fields_visibility
166 }
167 )
168 };
169
170 let common_fields_names = common_fields_inner
171 .iter()
172 .cloned()
173 .map(|field| field.ident.expect("Fields of this enum must be named"))
174 .collect_vec();
175 let common_fields_renamed = common_fields_names
176 .iter()
177 .cloned()
178 .map(|name| syn::parse_str::<Ident>(&format!("c_{name}")).unwrap())
179 .collect_vec(); let variant_fields_names = variants
181 .iter()
182 .map(|variant| {
183 variant
184 .fields
185 .iter()
186 .map(|field| field.ident.clone().unwrap())
187 .collect_vec()
188 })
189 .collect_vec();
190 let raw_fields_names = raw_variants
191 .iter()
192 .map(|variant| {
193 variant
194 .fields
195 .iter()
196 .map(|field| field.ident.clone().unwrap())
197 .collect_vec()
198 })
199 .collect_vec();
200
201 let variant_names = variants
202 .iter()
203 .cloned()
204 .map(|variant| variant.ident)
205 .collect_vec();
206
207 let variant_structs = variants.iter().zip(variant_lifetimes.iter()).map(|(variant, lft)| {
208 let name = &variant.ident;
209 let fields = &variant.fields;
210
211 if empty_variants.contains(&variant.ident) {
212 quote!( #[derive(serde::Serialize, serde::Deserialize)] #struct_attrs struct #name; )
213 } else {
214 quote!( #[derive(serde::Serialize, serde::Deserialize)] #struct_attrs struct #name #lft #fields )
215 }
216 });
217
218 let (convert_from_raw, convert_to_raw): (Vec<_>, Vec<_>) =
219 izip!(variant_fields_names, raw_fields_names)
220 .zip(variant_names)
221 .map(|((variant, _), variant_name)| {
222 if empty_variants.contains(&variant_name) {
223 let from_raw = quote!(
224 Raw :: #variant_name {
225 #(#common_fields_names: #common_fields_renamed),*, ..
226 } => Self {
227 data: #data_enum_name :: #variant_name ,
228 #(#common_fields_names: #common_fields_renamed),*
229 }
230 );
231
232 let to_raw = quote!(
233 #data_enum_name :: #variant_name => Self :: #variant_name {
234 #(#common_fields_names: f. #common_fields_names)*,
235 }
236 );
237
238 (from_raw, to_raw)
239 } else {
240 let from_raw = quote!(
241 Raw :: #variant_name {
242 data: #variant_name {
243 #(#variant),*
244 },
245 #(#common_fields_names: #common_fields_renamed),*
246 } => Self {
247 data: #data_enum_name :: #variant_name {
248 #(#variant),*
249 }, #(#common_fields_names: #common_fields_renamed),*
250 }
251 );
252
253 let to_raw = quote!(
254 #data_enum_name :: #variant_name {
255 #(#variant),*
256 } => Self :: #variant_name {
257 data: #variant_name {
258 #(#variant),*
259 },
260 #(#common_fields_names: f. #common_fields_names),*
261 }
262 );
263
264 (from_raw, to_raw)
265 }
266 })
267 .unzip();
268
269 let convert_impls = quote!(
271 impl #generics From<Raw #generics> for #container_name #generics {
272 fn from(f: Raw #generics) -> Self {
273 match f {
274 #(#convert_from_raw),*
275 }
276 }
277 }
278
279 impl #generics From<#container_name #generics > for Raw #generics {
280 fn from(f: #container_name #generics) -> Self {
281 match f.data {
282 #(#convert_to_raw),*
283 }
284 }
285 }
286 );
287
288 quote!(
290 #visibility use #module_name::{
291 #container_name,
292 #data_enum_name
293 };
294 mod #module_name {
295 use super::*;
296 #public_struct
297 #raw_enum
298 #data_enum
299
300 #(#variant_structs)*
301
302 #convert_impls
303 }
304 )
305}
306
307fn attr_args(attr: TokenStream2) -> HashMap<String, TokenTree> {
308 attr.into_iter()
309 .group_by(|tk| !matches!(tk, TokenTree::Punct(p) if p.as_char() == ','))
310 .into_iter()
311 .filter_map(|(cond, c)| cond.then(|| c))
312 .map(|mut triple| {
313 let ident = triple.next();
314 let eq_sign = triple.next();
315 let value = triple.next();
316
317 if !matches!(eq_sign, Some(TokenTree::Punct(eq_sign)) if eq_sign.as_char() == '=') {
318 panic!(r#"Attribute arguments should be in the form of `key = value`"#)
319 }
320
321 match (ident, value) {
322 (Some(TokenTree::Ident(ident)), Some(value)) => (ident.to_string(), value),
323 _ => panic!(r#"Attribute arguments should be in the form of `key = "value"`"#),
324 }
325 })
326 .collect()
327}
328
329fn type_lifetimes(ty: &Type) -> SmallVec<[Lifetime; 8]> {
331 match ty {
332 Type::Array(a) => type_lifetimes(&*a.elem),
333 Type::Group(g) => type_lifetimes(&*g.elem),
334 Type::ImplTrait(t) => type_param_lifetimes(t.bounds.iter()),
335 Type::Paren(p) => type_lifetimes(&*p.elem),
336 Type::Path(p) => path_lifetimes(&p.path),
337 Type::Reference(r) => {
338 type_lifetimes(&*r.elem).tap_mut(|vec| vec.extend(r.lifetime.clone()))
339 }
340 Type::Slice(s) => type_lifetimes(&*s.elem),
341 Type::TraitObject(t) => type_param_lifetimes(t.bounds.iter()),
342 Type::Tuple(tup) => tup
343 .elems
344 .iter()
345 .flat_map(type_lifetimes)
346 .collect::<SmallVec<_>>(),
347 _ => smallvec![],
348 }
349}
350
351fn path_lifetimes(path: &Path) -> SmallVec<[Lifetime; 8]> {
352 path.segments
353 .iter()
354 .flat_map(|segment| {
355 if let PathArguments::AngleBracketed(ref args) = segment.arguments {
356 args.args
357 .iter()
358 .flat_map(|arg| match arg {
359 GenericArgument::Lifetime(l) => smallvec![l.clone()],
360 GenericArgument::Type(ty) => type_lifetimes(ty),
361 GenericArgument::Constraint(con) => type_param_lifetimes(con.bounds.iter()),
362 _ => smallvec![],
363 })
364 .collect_vec()
365 } else {
366 Vec::new()
367 }
368 })
369 .collect::<SmallVec<_>>()
370}
371
372fn type_param_lifetimes<'a>(
373 it: impl IntoIterator<Item = &'a TypeParamBound>,
374) -> SmallVec<[Lifetime; 8]> {
375 it.into_iter()
376 .flat_map(|bound| match bound {
377 TypeParamBound::Lifetime(lt) => smallvec![lt.clone()],
378 TypeParamBound::Trait(trt) => path_lifetimes(&trt.path),
379 })
380 .collect::<SmallVec<_>>()
381}
382
#[cfg(test)]
mod test {
    use crate::{attr_args, hybrid_tagged_impl, type_lifetimes};
    use proc_macro2::TokenTree;
    use quote::quote;
    use syn::{parse_quote, Lifetime};

    /// Whether `lifetimes` contains a lifetime named `'name`.
    fn mentions(lifetimes: &[Lifetime], name: &str) -> bool {
        lifetimes.iter().any(|lifetime| lifetime.ident == name)
    }

    /// Smoke test: expands a representative enum, checks the main generated items are present
    /// and prints the expansion for manual inspection.
    #[test]
    fn test_hybrid_tagged_impl() {
        let macro_out = hybrid_tagged_impl(
            quote!(tag = "type", fields = {frame: Number, slack: Slack,}, struct_attrs = {
                #[derive(Debug)]
                #[serde(rename = "UPPERCASE")]
            }),
            quote!(
                #[derive(Debug)]
                #[serde(some_other_thing)]
                pub(super) enum Variations<'a> {
                    A {
                        #[field_attribute]
                        task: T,
                        #[serde(borrow)]
                        time: U<'a>,
                    },
                    B {
                        hours: H,
                        intervals: I,
                    },
                    HasFrame {
                        frame: F,
                    },
                    C,
                }
            ),
        );

        let expanded = macro_out.to_string();
        // The helper module and the generated data enum must be emitted.
        assert!(expanded.contains("variations_data"));
        assert!(expanded.contains("VariationsData"));
        println!("{expanded}");
    }

    #[test]
    fn parse_attr_args() {
        let args = attr_args(quote!(tag = "type", fields = { a: u8, b: u16 },));
        assert_eq!(args.len(), 2);
        assert!(matches!(args.get("tag"), Some(TokenTree::Literal(_))));
        // Commas inside the braces must not split the `fields` argument.
        assert!(matches!(args.get("fields"), Some(TokenTree::Group(_))));
    }

    #[test]
    fn extract_lifetimes() {
        let lifetimes = type_lifetimes(&parse_quote!(&'a Str<'b>));
        assert!(mentions(&lifetimes, "a"));
        assert!(mentions(&lifetimes, "b"));

        let lifetimes = type_lifetimes(&parse_quote!(impl Derive + Debug + Struct<'a> + 'b));
        assert!(mentions(&lifetimes, "a"));
        assert!(mentions(&lifetimes, "b"));
    }
}