shebling_codegen/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{
5    parse_macro_input, spanned::Spanned, AngleBracketedGenericArguments, Error, Fields,
6    FieldsNamed, FieldsUnnamed, GenericArgument, ItemEnum, ItemStruct, PathArguments, Type,
7};
8
9/// Creates an implementation for a `new()` function that initializes
10/// the attached struct.
11///
12/// # Example
13/// ```
14/// use shebling_codegen::New;
15///
16/// #[derive(New)]
17/// struct Foo {
18///    // Any type that implements Into<String> can be used as the
19///    // first parameter for Foo::new.
20///    #[new(into)]
21///    foo: String,
22///    // With boxed types, the generated constructor will take
23///    // the unboxed type as a parameter.
24///    bar: Box<u32>,
25/// }
26///
27/// let foo = Foo::new("Foo", 10);
28/// assert_eq!((foo.foo, *foo.bar), ("Foo".into(), 10));
29/// ```
30#[proc_macro_derive(New, attributes(new))]
31pub fn new_derive(input: TokenStream) -> TokenStream {
32    let ItemStruct {
33        ident,
34        fields,
35        generics,
36        ..
37    } = parse_macro_input!(input as ItemStruct);
38
39    if let Fields::Named(FieldsNamed { named, .. }) = fields {
40        let mut param_list = Vec::with_capacity(named.len());
41        let mut assigns = Vec::with_capacity(named.len());
42
43        for field in named {
44            // Get the #[new] attribute if present.
45            let mut new_attrs = field
46                .attrs
47                .into_iter()
48                .filter(|attr| attr.path().is_ident("new"));
49
50            // Check that there is at most one #[new] attribute.
51            let new_attr = new_attrs.next();
52            if let Some(attr) = new_attrs.next() {
53                return error(attr.span(), "Duplicate #[new(..)] attribute.");
54            }
55
56            // Check if we should use Into<..> for the constructor parameter.
57            let mut use_into = false;
58            if let Some(new_attr) = new_attr {
59                if let Err(err) = new_attr.parse_nested_meta(|meta| {
60                    if meta.path.is_ident("into") {
61                        use_into = true;
62                        Ok(())
63                    } else {
64                        Err(meta.error("Invalid #[new(..)] content."))
65                    }
66                }) {
67                    return err.into_compile_error().into();
68                }
69            }
70
71            // Get the type to use in the parameter list and the field assignment value.
72            let field_ident = field.ident.expect("Field should be named.");
73            let field_ty = field.ty;
74            let (mut param_ty, mut field_assign) = if use_into {
75                (quote!(impl Into<#field_ty>), quote!(#field_ident.into()))
76            } else {
77                (quote!(#field_ty), quote!(#field_ident))
78            };
79
80            if let Type::Path(ty) = &field_ty {
81                let sgmt = if ty.path.segments.len() != 1 {
82                    return error(ty.span(), "Expected a single-segmented type path.");
83                } else {
84                    &ty.path.segments[0]
85                };
86
87                // If the type is boxed, use the inner type as the parameter type
88                // and box the parameter when assigning the field.
89                if sgmt.ident == "Box" {
90                    if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
91                        args,
92                        ..
93                    }) = &sgmt.arguments
94                    {
95                        if args.len() != 1 {
96                            return error(
97                                ty.span(),
98                                "Box<..> should have a single generic argument.",
99                            );
100                        }
101
102                        let inner_ty = if let GenericArgument::Type(ty) = &args[0] {
103                            ty
104                        } else {
105                            return error(ty.span(), "Expected a generic type.");
106                        };
107                        param_ty = if use_into {
108                            quote!(impl Into<#inner_ty>)
109                        } else {
110                            quote!(#inner_ty)
111                        };
112
113                        field_assign = quote!(Box::new(#field_assign));
114                    } else {
115                        return error(sgmt.span(), "Invalid Box<..> type.");
116                    }
117                }
118            } else {
119                return error(field_ty.span(), "Expected a type path.");
120            };
121
122            param_list.push(quote!(#field_ident: #param_ty));
123            assigns.push(quote!(#field_ident: #field_assign));
124        }
125
126        // Write the new() implementation.
127        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
128        let new_impl = quote! {
129            impl #impl_generics #ident #ty_generics #where_clause {
130                pub(crate) fn new(#(#param_list),*) -> Self {
131                    Self { #(#assigns),* }
132                }
133            }
134        };
135
136        new_impl.into()
137    } else {
138        error(
139            fields.span(),
140            "Only structs with named fields are supported.",
141        )
142    }
143}
144
145/// Creates a `From<..>` implementation for each variant of the attached
146/// enum.
147///
148/// # Example
149/// ```
150/// use shebling_codegen::struct_enum;
151///
152/// struct Foo {
153///     foo: u32,
154/// }
155///
156/// #[struct_enum]
157/// enum Bar {
158///     Foo(Foo),
159/// }
160///
161/// let bar = Bar::from(Foo { foo: 17 });
162/// assert!(matches!(bar, Bar::Foo(_)));
163/// ```
164#[proc_macro_attribute]
165pub fn struct_enum(_args: TokenStream, input: TokenStream) -> TokenStream {
166    // Parse the input enum.
167    let input @ ItemEnum {
168        ident,
169        variants,
170        generics,
171        ..
172    } = &parse_macro_input!(input as ItemEnum);
173
174    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
175
176    if variants.is_empty() {
177        return error(input.span(), "Unexpected empty enum.");
178    }
179
180    // For each enum variant with an inner struct, generate the From<..> implementation.
181    let mut from_impls = Vec::with_capacity(variants.len());
182    for variant in variants {
183        let variant_ident = &variant.ident;
184        let struct_ty = match &variant.fields {
185            Fields::Unit => None,
186            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
187                if unnamed.len() != 1 {
188                    return error(unnamed.span(), "Expected a single struct identifier.");
189                }
190
191                Some(&unnamed[0].ty)
192            }
193            _ => return error(variant.span(), "Invalid struct of enums variant."),
194        };
195
196        if let Some(struct_ty) = struct_ty {
197            from_impls.push(quote! {
198                impl #impl_generics From<#struct_ty> for #ident #ty_generics #where_clause {
199                    fn from(inner: #struct_ty) -> Self {
200                        #ident::#variant_ident(inner)
201                    }
202                }
203            });
204        }
205    }
206
207    let output = quote! {
208        #input
209        #(#from_impls)*
210    };
211
212    output.into()
213}
214
215fn error(span: Span, message: &str) -> TokenStream {
216    Error::new(span, message).to_compile_error().into()
217}