shadow_derive/
lib.rs

1extern crate proc_macro;
2extern crate syn;
3#[macro_use]
4extern crate quote;
5
6use proc_macro::TokenStream;
7use proc_macro2::Span;
8use syn::parse::{Parse, ParseStream, Parser};
9use syn::parse_macro_input;
10use syn::DeriveInput;
11use syn::Generics;
12use syn::Ident;
13use syn::Result;
14use syn::{parenthesized, Attribute, Error, Field, LitStr};
15
16#[proc_macro_derive(ShadowState, attributes(shadow, static_shadow_field))]
17pub fn shadow_state(input: TokenStream) -> TokenStream {
18    match parse_macro_input!(input as ParseInput) {
19        ParseInput::Struct(input) => {
20            let shadow_patch = generate_shadow_patch_struct(&input);
21            let shadow_state = generate_shadow_state(&input);
22            let implementation = quote! {
23                #shadow_patch
24
25                #shadow_state
26            };
27            TokenStream::from(implementation)
28        }
29        _ => {
30            todo!()
31        }
32    }
33}
34
35#[proc_macro_derive(ShadowPatch, attributes(static_shadow_field, serde))]
36pub fn shadow_patch(input: TokenStream) -> TokenStream {
37    TokenStream::from(match parse_macro_input!(input as ParseInput) {
38        ParseInput::Struct(input) => generate_shadow_patch_struct(&input),
39        ParseInput::Enum(input) => generate_shadow_patch_enum(&input),
40    })
41}
42
43enum ParseInput {
44    Struct(StructParseInput),
45    Enum(EnumParseInput),
46}
47
48#[derive(Clone)]
49struct EnumParseInput {
50    pub ident: Ident,
51    pub generics: Generics,
52}
53
54#[derive(Clone)]
55struct StructParseInput {
56    pub ident: Ident,
57    pub generics: Generics,
58    pub shadow_fields: Vec<Field>,
59    pub copy_attrs: Vec<Attribute>,
60    pub shadow_name: Option<LitStr>,
61}
62
63impl Parse for ParseInput {
64    fn parse(input: ParseStream) -> Result<Self> {
65        let derive_input = DeriveInput::parse(input)?;
66
67        let mut shadow_name = None;
68        let mut copy_attrs = vec![];
69
70        let attrs_to_copy = ["serde"];
71
72        // Parse valid container attributes
73        for attr in derive_input.attrs {
74            if attr.path.is_ident("shadow") {
75                fn shadow_arg(input: ParseStream) -> Result<LitStr> {
76                    let content;
77                    parenthesized!(content in input);
78                    content.parse()
79                }
80                shadow_name = Some(shadow_arg.parse2(attr.tokens)?);
81            } else if attrs_to_copy
82                .iter()
83                .find(|a| attr.path.is_ident(a))
84                .is_some()
85            {
86                copy_attrs.push(attr);
87            }
88        }
89
90        match derive_input.data {
91            syn::Data::Struct(syn::DataStruct { fields, .. }) => {
92                Ok(Self::Struct(StructParseInput {
93                    ident: derive_input.ident,
94                    generics: derive_input.generics,
95                    shadow_fields: fields.into_iter().collect::<Vec<_>>(),
96                    copy_attrs,
97                    shadow_name,
98                }))
99            }
100            syn::Data::Enum(syn::DataEnum { .. }) => Ok(Self::Enum(EnumParseInput {
101                ident: derive_input.ident,
102                generics: derive_input.generics,
103            })),
104            _ => Err(Error::new(
105                Span::call_site(),
106                "ShadowState & ShadowPatch can only be derived for non-tuple structs & enums",
107            )),
108        }
109    }
110}
111
112fn create_assigners(fields: &Vec<Field>) -> Vec<proc_macro2::TokenStream> {
113    fields
114        .iter()
115        .filter_map(|field| {
116            let field_name = &field.ident.clone().unwrap();
117
118            if field
119                .attrs
120                .iter()
121                .find(|a| a.path.is_ident("static_shadow_field"))
122                .is_some()
123            {
124                None
125            } else {
126                Some(quote! {
127                    if let Some(attribute) = opt.#field_name {
128                        self.#field_name.apply_patch(attribute);
129                    }
130                })
131            }
132        })
133        .collect::<Vec<_>>()
134}
135
136fn create_optional_fields(fields: &Vec<Field>) -> Vec<proc_macro2::TokenStream> {
137    fields
138        .iter()
139        .filter_map(|field| {
140            let type_name = &field.ty;
141            let attrs = field
142                .attrs
143                .iter()
144                .filter(|a| {
145                    !a.path.is_ident("static_shadow_field")
146                })
147                .collect::<Vec<_>>();
148            let field_name = &field.ident.clone().unwrap();
149
150            let type_name_string = quote! {#type_name}.to_string();
151            let type_name_string: String = type_name_string.chars().filter(|&c| c != ' ').collect();
152
153            if field
154                .attrs
155                .iter()
156                .find(|a| a.path.is_ident("static_shadow_field"))
157                .is_some()
158            {
159                None
160            } else {
161                Some(if type_name_string.starts_with("Option<") {
162                    quote! { #(#attrs)* pub #field_name: Option<rustot::shadows::Patch<<#type_name as rustot::shadows::ShadowPatch>::PatchState>> }
163                } else {
164                    quote! { #(#attrs)* pub #field_name: Option<<#type_name as rustot::shadows::ShadowPatch>::PatchState> }
165                })
166            }
167        })
168        .collect::<Vec<_>>()
169}
170
171fn generate_shadow_state(input: &StructParseInput) -> proc_macro2::TokenStream {
172    let StructParseInput {
173        ident,
174        generics,
175        shadow_name,
176        ..
177    } = input;
178
179    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
180
181    let name = match shadow_name {
182        Some(name) => quote! { Some(#name) },
183        None => quote! { None },
184    };
185
186    return quote! {
187        #[automatically_derived]
188        impl #impl_generics rustot::shadows::ShadowState for #ident #ty_generics #where_clause {
189            const NAME: Option<&'static str> = #name;
190            // const MAX_PAYLOAD_SIZE: usize = 512;
191        }
192    };
193}
194
195fn generate_shadow_patch_struct(input: &StructParseInput) -> proc_macro2::TokenStream {
196    let StructParseInput {
197        ident,
198        generics,
199        shadow_fields,
200        copy_attrs,
201        ..
202    } = input;
203
204    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
205
206    let optional_ident = format_ident!("Patch{}", ident);
207
208    let assigners = create_assigners(&shadow_fields);
209    let optional_fields = create_optional_fields(&shadow_fields);
210
211    return quote! {
212        #[automatically_derived]
213        #[derive(Default, Clone, ::serde::Deserialize, ::serde::Serialize)]
214        #(#copy_attrs)*
215        pub struct #optional_ident #generics {
216            #(
217                #optional_fields
218            ),*
219        }
220
221        #[automatically_derived]
222        impl #impl_generics rustot::shadows::ShadowPatch for #ident #ty_generics #where_clause {
223            type PatchState = #optional_ident;
224
225            fn apply_patch(&mut self, opt: Self::PatchState) {
226                #(
227                    #assigners
228                )*
229            }
230        }
231    };
232}
233
234fn generate_shadow_patch_enum(input: &EnumParseInput) -> proc_macro2::TokenStream {
235    let EnumParseInput {
236        ident, generics, ..
237    } = input;
238
239    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
240
241    return quote! {
242        #[automatically_derived]
243        impl #impl_generics rustot::shadows::ShadowPatch for #ident #ty_generics #where_clause {
244            type PatchState = #ident #ty_generics;
245
246            fn apply_patch(&mut self, opt: Self::PatchState) {
247                *self = opt;
248            }
249        }
250    };
251}