xylem_codegen/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote, ToTokens};
3use syn::parse::{Parse, ParseStream};
4use syn::punctuated::Punctuated;
5use syn::{Error, Result};
6
7mod tests;
8
9#[proc_macro_derive(Xylem, attributes(xylem))]
10pub fn xylem(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
11    match xylem_impl(ts.into()) {
12        Ok(output) => output.output(),
13        Err(err) => err.into_compile_error(),
14    }
15    .into()
16}
17
18fn xylem_impl(ts: TokenStream) -> Result<Output> {
19    let input = syn::parse2::<syn::DeriveInput>(ts)?;
20    let input_ident = &input.ident;
21
22    let mut from_ident = None;
23    let mut schema = Box::new(
24        syn::parse2::<syn::Type>(quote!(crate::Schema))
25            .expect("Failed to parse literal token stream"),
26    );
27    let mut expose_from_type = false;
28    let mut input_serde = Vec::new();
29    let mut derive_list = Vec::new();
30
31    let mut processable = false;
32
33    for attr in &input.attrs {
34        if attr.path.is_ident("xylem") {
35            let attr_list: Punctuated<InputAttr, syn::Token![,]> =
36                attr.parse_args_with(Punctuated::parse_terminated)?;
37            for attr in attr_list {
38                match attr {
39                    InputAttr::Expose(ident) => {
40                        expose_from_type = true;
41                        from_ident = Some(ident);
42                    }
43                    InputAttr::Schema(new_schema) => schema = new_schema,
44                    InputAttr::Derive(macros) => derive_list.extend(macros),
45                    InputAttr::Serde(ts) => input_serde.push(quote!(#[serde(#ts)])),
46                    InputAttr::Process => {
47                        processable = true;
48                    }
49                }
50            }
51        }
52    }
53
54    let preprocess = processable.then(|| quote!(<Self as ::xylem::Processable<#schema>>::preprocess(&mut __xylem_from, __xylem_context)?;));
55    let postprocess = processable.then(|| quote!(<Self as ::xylem::Processable<#schema>>::postprocess(&mut __xylem_ret, __xylem_context)?;));
56
57    let from_ident = from_ident.unwrap_or_else(|| format_ident!("{}Xylem", &input.ident));
58
59    let vis = &input.vis;
60
61    let (generics_decl, _generics_decl_bare, generics_usage, _generics_usage_bare) =
62        if input.generics.params.is_empty() {
63            (quote!(), quote!(), quote!(), quote!())
64        } else {
65            let decl: Vec<_> = input.generics.params.iter().collect();
66            let usage: Vec<_> = input
67                .generics
68                .params
69                .iter()
70                .map(|param| match param {
71                    syn::GenericParam::Type(syn::TypeParam { ident, .. }) => quote!(#ident),
72                    syn::GenericParam::Lifetime(syn::LifetimeDef { lifetime, .. }) => {
73                        quote!(#lifetime)
74                    }
75                    syn::GenericParam::Const(syn::ConstParam { ident, .. }) => quote!(#ident),
76                })
77                .collect();
78            (quote!(<#(#decl),*>), quote!(#(#decl),*), quote!(<#(#usage),*>), quote!(#(#usage),*))
79        };
80    let generics_where = &input.generics.where_clause;
81
82    let derive = (!derive_list.is_empty()).then(|| {
83        quote! {
84            #[derive(#(#derive_list),*)]
85        }
86    });
87
88    let prefix = quote! {
89        #[doc = concat!("See [`", stringify!(#from_ident), "`]")]
90        #[automatically_derived]
91        #derive
92        #(#input_serde)*
93    };
94
95    let (from_decl, convert_expr) = match &input.data {
96        syn::Data::Struct(data) => {
97            let mut field_froms = Vec::new();
98            let mut field_convs = Vec::new();
99
100            for (field_ord, field) in data.fields.iter().enumerate() {
101                let (from, conv) = process_field(
102                    field,
103                    match &field.ident {
104                        Some(field_ident) => quote!(__xylem_from.#field_ident),
105                        None => {
106                            let field_ord = proc_macro2::Literal::usize_unsuffixed(field_ord);
107                            quote!(__xylem_from.#field_ord)
108                        }
109                    },
110                    &schema,
111                )?;
112                if let Some(from) = from {
113                    field_froms.push(from);
114                }
115                field_convs.push(conv);
116            }
117
118            let field_froms_attrs: Vec<_> = field_froms.iter().map(|ff| &ff.attrs).collect();
119            let field_froms_ident: Vec<_> = field_froms.iter().map(|ff| &ff.ident).collect();
120            let field_froms_ty: Vec<_> = field_froms.iter().map(|ff| &ff.ty).collect();
121            let field_convs_ident: Vec<_> = field_convs.iter().map(|fc| &fc.ident).collect();
122            let field_convs_expr: Vec<_> = field_convs.iter().map(|fc| &fc.expr).collect();
123
124            match &data.fields {
125                syn::Fields::Named(_) => (
126                    quote! {
127                        #prefix
128                        #vis struct #from_ident #generics_decl #generics_where {
129                            #(
130                                #field_froms_attrs
131                                #field_froms_ident: #field_froms_ty,
132                            )*
133                        }
134                    },
135                    quote! {
136                        Self {
137                            #(
138                                #field_convs_ident: #field_convs_expr,
139                            )*
140                        }
141                    },
142                ),
143                syn::Fields::Unnamed(_) => (
144                    quote! {
145                        #prefix
146                        #vis struct #from_ident #generics_decl (
147                            #(#field_froms_attrs #field_froms_ty,)*
148                        ) #generics_where;
149                    },
150                    quote! {
151                        Self (
152                            #(#field_convs_expr,)*
153                        )
154                    },
155                ),
156                syn::Fields::Unit => (
157                    quote! {
158                        #prefix
159                        #vis struct #from_ident;
160                    },
161                    quote! {
162                        Self
163                    },
164                ),
165            }
166        }
167        syn::Data::Enum(data) => {
168            let mut variant_froms = Vec::new();
169            let mut variant_matches = Vec::new();
170
171            for variant in &data.variants {
172                let mut field_froms = Vec::new();
173                let mut field_convs = Vec::new();
174
175                for (field_ord, field) in variant.fields.iter().enumerate() {
176                    let (from, conv) = process_field(
177                        field,
178                        match &field.ident {
179                            Some(ident) => quote!(#ident),
180                            None => format_ident!("__field{}", field_ord).to_token_stream(),
181                        },
182                        &schema,
183                    )?;
184                    if let Some(from) = from {
185                        field_froms.push(from);
186                    }
187                    field_convs.push(conv);
188                }
189
190                let field_froms_attrs: Vec<_> = field_froms.iter().map(|ff| &ff.attrs).collect();
191                let field_froms_ident: Vec<_> = field_froms.iter().map(|ff| &ff.ident).collect();
192                let field_froms_ty: Vec<_> = field_froms.iter().map(|ff| &ff.ty).collect();
193
194                let variant_from_ident = &variant.ident;
195                let variant_from_fields = match &variant.fields {
196                    syn::Fields::Named(_) => {
197                        quote! {{
198                            #(
199                                #field_froms_attrs
200                                #field_froms_ident: #field_froms_ty,
201                            )*
202                        }}
203                    }
204                    syn::Fields::Unnamed(_) => {
205                        quote! {(
206                            #(#field_froms_attrs #field_froms_ty),*
207                        )}
208                    }
209                    syn::Fields::Unit => quote!(),
210                };
211
212                let variant_from = quote! {
213                    #[doc = concat!("See [`", stringify!(#input_ident), "::", stringify!(#variant_from_ident), "`]")]
214                    #variant_from_ident #variant_from_fields
215                };
216                variant_froms.push(variant_from);
217
218                let variant_from_fields_pat = match &variant.fields {
219                    syn::Fields::Named(_) => {
220                        quote!({ #(#field_froms_ident),*  })
221                    }
222                    syn::Fields::Unnamed(_) => {
223                        let numbered_fields = (0..variant.fields.len()).map(|field_ord| {
224                            format_ident!("__field{}", field_ord).to_token_stream()
225                        });
226                        quote!((#(#numbered_fields),*))
227                    }
228                    syn::Fields::Unit => quote!(),
229                };
230
231                let variant_to_ident = &variant.ident;
232
233                let field_convs_ident: Vec<_> = field_convs.iter().map(|fc| &fc.ident).collect();
234                let field_convs_expr: Vec<_> = field_convs.iter().map(|fc| &fc.expr).collect();
235                let variant_to_fields_expr = match &variant.fields {
236                    syn::Fields::Named(_) => {
237                        quote!({ #(#field_convs_ident: #field_convs_expr),* })
238                    }
239                    syn::Fields::Unnamed(_) => {
240                        quote!((#(#field_convs_expr),*))
241                    }
242                    syn::Fields::Unit => quote!(),
243                };
244
245                let variant_match = quote! {
246                    #from_ident::#variant_from_ident #variant_from_fields_pat =>
247                        Self::#variant_to_ident #variant_to_fields_expr
248                };
249                variant_matches.push(variant_match);
250            }
251
252            (
253                quote! {
254                    #prefix
255                    #vis enum #from_ident #generics_decl #generics_where {
256                        #(#variant_froms),*
257                    }
258                },
259                quote! {
260                    match __xylem_from {
261                        #(#variant_matches),*
262                    }
263                },
264            )
265        }
266        syn::Data::Union(data) => {
267            return Err(Error::new_spanned(&data.union_token, "Unions are not supported"));
268        }
269    };
270
271    let xylem_impl = quote! {
272        #[automatically_derived]
273        #[allow(clippy::needless_update)]
274        impl #generics_decl ::xylem::Xylem<#schema> for #input_ident #generics_usage {
275            type From = #from_ident #generics_usage;
276            type Args = ::xylem::NoArgs;
277
278            fn convert_impl(
279                mut __xylem_from: Self::From,
280                __xylem_context: &mut <#schema as ::xylem::Schema>::Context,
281                _: &Self::Args,
282            ) -> Result<Self, <#schema as ::xylem::Schema>::Error> {
283                #preprocess
284                let mut __xylem_ret = #convert_expr;
285                #postprocess
286                Ok(__xylem_ret)
287            }
288        }
289    };
290    Ok(Output { from_decl, xylem_impl, expose_from_type })
291}
292
293struct Output {
294    from_decl:        TokenStream,
295    xylem_impl:       TokenStream,
296    expose_from_type: bool,
297}
298
299impl Output {
300    fn output(&self) -> TokenStream {
301        let from_decl = &self.from_decl;
302        let xylem_impl = &self.xylem_impl;
303
304        let inner = quote! {
305            #from_decl
306            #xylem_impl
307        };
308
309        if self.expose_from_type {
310            quote! {
311                #inner
312            }
313        } else {
314            quote! {
315                const _: () = { #inner };
316            }
317        }
318    }
319}
320
321enum InputAttr {
322    /// Exposes the `From` type in the same namespace and visibility as the derive input
323    /// using the specified identifier as the type name.
324    Expose(syn::Ident),
325    /// Specifies the schema that the conversion is defined for.
326    /// The default value is `crate::Schema`.
327    Schema(Box<syn::Type>),
328    /// Adds a serde attribute to the `From` type.
329    Serde(TokenStream),
330    /// Adds a derive macro to the `From` type.
331    Derive(Punctuated<syn::Path, syn::Token![,]>),
332    /// Call [`Processable`].
333    Process,
334}
335
336impl Parse for InputAttr {
337    fn parse(input: ParseStream) -> Result<Self> {
338        let ident: syn::Ident = input.parse()?;
339        if ident == "expose" {
340            let _: syn::Token![=] = input.parse()?;
341            Ok(Self::Expose(input.parse()?))
342        } else if ident == "schema" {
343            let _: syn::Token![=] = input.parse()?;
344            let schema: syn::Type = input.parse()?;
345            Ok(Self::Schema(Box::new(schema)))
346        } else if ident == "serde" {
347            let inner;
348            syn::parenthesized!(inner in input);
349            Ok(Self::Serde(inner.parse()?))
350        } else if ident == "derive" {
351            let inner;
352            syn::parenthesized!(inner in input);
353            Ok(Self::Derive(Punctuated::parse_terminated(&inner)?))
354        } else if ident == "process" {
355            Ok(Self::Process)
356        } else {
357            Err(Error::new_spanned(ident, "Unsupported attribute"))
358        }
359    }
360}
361
362enum FieldAttr {
363    /// Adds a serde attribute to the field.
364    Serde(TokenStream),
365    /// Preserve the field type, without performing any conversion logic.
366    Preserve(Span),
367    /// Use the specified function to convert the field.
368    ///
369    /// # Example
370    /// ```ignore
371    /// #[xylem(transform = path(Type))]
372    /// foo: Bar,
373    /// ```
374    ///
375    /// This expects a function accessible at `path`
376    /// with the signature `fn(Type) -> Result<Bar, S::Error>`.
377    /// For example, `#[xylem(transform = Ok(Bar))]` is equivalent to `#[xylem(preserve)]`.
378    ///
379    /// # Comparison with [`FieldAttr::Default`]
380    /// `transform` differs from `default` in that
381    /// `transform` generates a field in the `From` type and passes it to the function,
382    /// while `default` does not generate a field in the `From` type
383    /// and the argument is a freeform expression.
384    Transform(syn::Path, syn::Type),
385    /// Similar to [`FieldAttr::Transform`], but also accepts the `context` parameter.
386    ///
387    /// The signature is `fn(Type, &mut S::Context) -> Result<Bar, S::Error>`.
388    TransformWithContext(syn::Path, syn::Type),
389    /// Use the specified expression to generate the field value.
390    /// The field does not appear in the `From` type.
391    Default(syn::Expr),
392    /// Pass arguments to the field type.
393    Args(Span, Punctuated<ArgDef, syn::Token![,]>),
394}
395
396impl Parse for FieldAttr {
397    fn parse(input: ParseStream) -> Result<Self> {
398        let ident: syn::Ident = input.parse()?;
399        if ident == "serde" {
400            let inner;
401            syn::parenthesized!(inner in input);
402            Ok(Self::Serde(inner.parse()?))
403        } else if ident == "preserve" {
404            Ok(Self::Preserve(ident.span()))
405        } else if ident == "transform" {
406            let _: syn::Token![=] = input.parse()?;
407            let path: syn::Path = input.parse()?;
408            let inner;
409            syn::parenthesized!(inner in input);
410            let ty: syn::Type = inner.parse()?;
411            Ok(Self::Transform(path, ty))
412        } else if ident == "transform_with_context" {
413            let _: syn::Token![=] = input.parse()?;
414            let path: syn::Path = input.parse()?;
415            let inner;
416            syn::parenthesized!(inner in input);
417            let ty: syn::Type = inner.parse()?;
418            Ok(Self::TransformWithContext(path, ty))
419        } else if ident == "default" {
420            let _: syn::Token![=] = input.parse()?;
421            let expr: syn::Expr = input.parse()?;
422            Ok(Self::Default(expr))
423        } else if ident == "args" {
424            let inner;
425            syn::parenthesized!(inner in input);
426            Ok(Self::Args(ident.span(), Punctuated::parse_terminated(&inner)?))
427        } else {
428            Err(Error::new_spanned(ident, "Unsupported attribute"))
429        }
430    }
431}
432
433struct ArgDef {
434    name: syn::Ident,
435    expr: syn::Expr,
436}
437
438impl Parse for ArgDef {
439    fn parse(input: ParseStream) -> Result<Self> {
440        let name: syn::Ident = input.parse()?;
441        let _: syn::Token![=] = input.parse()?;
442        let expr: syn::Expr = input.parse()?;
443        Ok(Self { name, expr })
444    }
445}
446
447fn process_field(
448    field: &syn::Field,
449    from_expr: TokenStream,
450    schema: &syn::Type,
451) -> Result<(Option<FieldFrom>, FieldConv)> {
452    enum Mode {
453        Standard(Vec<ArgDef>),
454        Default(TokenStream),
455        Transform { ts: TokenStream, ty: Box<syn::Type>, context: bool },
456    }
457
458    let mut mode = Mode::Standard(Vec::new());
459
460    let mut from_attrs = TokenStream::new();
461
462    for attr in &field.attrs {
463        if attr.path.is_ident("xylem") {
464            let attrs: Punctuated<FieldAttr, syn::Token![,]> =
465                attr.parse_args_with(Punctuated::parse_terminated)?;
466            for attr in attrs {
467                match attr {
468                    FieldAttr::Serde(ts) => {
469                        from_attrs.extend(quote!(#[serde(#ts)]));
470                    }
471                    FieldAttr::Preserve(span) => {
472                        if !matches!(mode, Mode::Standard(_)) {
473                            return Err(Error::new(
474                                span,
475                                "Only one of `preserve`, `transform` or `default` can be used.",
476                            ));
477                        }
478                        mode = Mode::Transform {
479                            ts:      quote!(Ok),
480                            ty:      Box::new(field.ty.clone()),
481                            context: false,
482                        };
483                    }
484                    FieldAttr::Transform(path, ty) => {
485                        if !matches!(mode, Mode::Standard(_)) {
486                            return Err(Error::new_spanned(
487                                path,
488                                "Only one of `preserve`, `transform` or `default` can be used.",
489                            ));
490                        }
491                        mode = Mode::Transform {
492                            ts:      quote!(#path),
493                            ty:      Box::new(ty),
494                            context: false,
495                        };
496                    }
497                    FieldAttr::TransformWithContext(path, ty) => {
498                        if !matches!(mode, Mode::Standard(_)) {
499                            return Err(Error::new_spanned(
500                                path,
501                                "Only one of `preserve`, `transform` or `default` can be used.",
502                            ));
503                        }
504                        mode = Mode::Transform {
505                            ts:      quote!(#path),
506                            ty:      Box::new(ty),
507                            context: true,
508                        };
509                    }
510                    FieldAttr::Default(expr) => {
511                        if !matches!(mode, Mode::Standard(_)) {
512                            return Err(Error::new_spanned(
513                                expr,
514                                "Only one of `preserve`, `transform` or `default` can be used.",
515                            ));
516                        }
517                        mode = Mode::Default(quote!(#expr));
518                    }
519                    FieldAttr::Args(span, args) => match &mut mode {
520                        Mode::Standard(arg_defs) => {
521                            arg_defs.extend(args.into_iter());
522                        }
523                        _ => {
524                            return Err(Error::new(
525                                span,
526                                "Cannot use `args` if `preserve`, `transform` or `default` is \
527                                 used.",
528                            ))
529                        }
530                    },
531                }
532            }
533        }
534    }
535
536    Ok(match mode {
537        Mode::Standard(arg_defs) => (
538            Some(FieldFrom {
539                attrs: from_attrs,
540                ident: field.ident.clone(),
541                ty:    {
542                    let ty = &field.ty;
543                    quote!(<#ty as ::xylem::Xylem<#schema>>::From)
544                },
545            }),
546            FieldConv {
547                ident: field.ident.clone(),
548                expr:  {
549                    let ty = &field.ty;
550                    let arg_names = arg_defs.iter().map(|def| &def.name);
551                    let arg_exprs = arg_defs.iter().map(|def| &def.expr);
552
553                    quote! {{
554                        type Args = <#ty as ::xylem::Xylem<#schema>>::Args;
555                        ::xylem::lazy_static! {
556                            static ref __XYLEM_ARGS: Args = Args {
557                                #(#arg_names: #arg_exprs,)*
558                                ..::std::default::Default::default()
559                            };
560                        }
561                        ::xylem::Xylem::<#schema>::convert(
562                            #from_expr,
563                            __xylem_context,
564                            &*__XYLEM_ARGS,
565                        )?
566                    }}
567                },
568            },
569        ),
570        Mode::Default(expr) => (None, FieldConv { ident: field.ident.clone(), expr }),
571        Mode::Transform { ts, ty, context } => {
572            let context = context.then(|| quote!(__xylem_context));
573            (
574                Some(FieldFrom {
575                    attrs: from_attrs,
576                    ident: field.ident.clone(),
577                    ty:    quote!(#ty),
578                }),
579                FieldConv {
580                    ident: field.ident.clone(),
581                    expr:  quote! {
582                        #ts(#from_expr, #context)?
583                    },
584                },
585            )
586        }
587    })
588}
589
590#[derive(Debug)]
591struct FieldFrom {
592    /// The attributes of the field in the `From` type.
593    attrs: TokenStream,
594    /// The name of the field in the `From` type.
595    ident: Option<syn::Ident>,
596    /// The type of the field in the `From` type.
597    ty:    TokenStream,
598}
599
600#[derive(Debug)]
601struct FieldConv {
602    /// The name of the field in the `Self` type.
603    ident: Option<syn::Ident>,
604    /// The expression of the field in the constructor.
605    expr:  TokenStream,
606}