wopt/
lib.rs

1use proc_macro::{Span, TokenStream};
2use quote::quote;
3use syn::{
4    DeriveInput, Expr, Field, Fields, Ident, Index, Lit, Meta, Type, parse_macro_input,
5    punctuated::Iter,
6};
7
8#[cfg(all(not(feature = "rkyv"), feature = "unchecked"))]
9compile_error!("Feature `unchecked` requires feature `rkyv`.");
10
11fn get_field_kvs(
12    fields: Iter<Field>,
13    is_named: bool,
14) -> Vec<(Option<&Option<Ident>>, &Type, bool, bool)> {
15    fields
16        .map(|field: &Field| {
17            if field.attrs.len() > 1 {
18                panic!("Only 1 attribute per field is supported.")
19            }
20            let (mut is_required, mut skip) = Default::default();
21
22            if let Some(attr) = field.attrs.first() {
23                if attr.path().is_ident("wopt") {
24                    let mut n = 0;
25                    attr.parse_nested_meta(|a| {
26                        if let Some(ident) = a.path.get_ident() {
27                            match ident.to_string().as_str() {
28                                "required" => is_required = true,
29                                "skip" => skip = true,
30                                _ => panic!(
31                                    "Only `required` & `skip` field attributes are supported."
32                                ),
33                            }
34                        }
35                        n += 1;
36                        Ok(())
37                    })
38                    .unwrap();
39
40                    if n > 1 {
41                        panic!("A field has too many `wopt` attr args (max: 1)")
42                    }
43                }
44            }
45            if is_named {
46                (Some(&field.ident), &field.ty, is_required, skip)
47            } else {
48                (None, &field.ty, is_required, skip)
49            }
50        })
51        .collect()
52}
53
54#[proc_macro_derive(WithOpt, attributes(id, wopt))]
55pub fn wopt_derive(input: TokenStream) -> TokenStream {
56    // parse the input tokens into a syntax tree
57    let input = parse_macro_input!(input as DeriveInput);
58
59    // get the struct name
60    let name = &input.ident;
61
62    // identity of this optional struct
63    #[cfg(feature = "rkyv")]
64    let mut id = None;
65
66    #[allow(unused_mut)]
67    let mut is_unit = false;
68
69    // the type of struct
70    let mut is_named = false;
71
72    // match on the fields of the struct
73    let info: Vec<_> = if let syn::Data::Struct(ref data) = input.data {
74        match &data.fields {
75            Fields::Named(fields) => {
76                is_named = true;
77                get_field_kvs(fields.named.iter(), true)
78            }
79            Fields::Unnamed(fields) => get_field_kvs(fields.unnamed.iter(), false),
80            _ => {
81                #[cfg(not(feature = "rkyv"))]
82                panic!("Unit structs are only supported with the `rkyv` feature.");
83
84                #[cfg(feature = "rkyv")]
85                {
86                    is_unit = true;
87                    vec![]
88                }
89            }
90        }
91    } else {
92        panic!("Only structs are supported");
93    };
94
95    // process any `#[wopt(...)]` attributes
96    let derives = {
97        let mut derives = Vec::new();
98
99        for attr in &input.attrs {
100            if attr.path().is_ident("wopt") {
101                let meta = attr.parse_args::<Meta>().unwrap();
102
103                match &meta {
104                    Meta::List(list) => {
105                        list.parse_nested_meta(|a| {
106                            if let Some(ident) = a.path.get_ident() {
107                                derives.push(quote! { #ident });
108                            }
109                            Ok(())
110                        })
111                        .unwrap();
112                    }
113                    Meta::NameValue(nv) => {
114                        if nv.path.is_ident("id") {
115                            #[cfg(not(feature = "rkyv"))]
116                            panic!("Enable the `rkyv` feature to use the `id` attribute.");
117
118                            #[cfg(feature = "rkyv")]
119                            {
120                                id = Some(match &nv.value {
121                                    Expr::Lit(expr) => match &expr.lit {
122                                        Lit::Int(v) => {
123                                            let value = v
124                                                .base10_parse::<u8>()
125                                                .expect("Only `u8` is supported.");
126                                            if value > 127 {
127                                                panic!("Value too large (max: 127)")
128                                            }
129                                            value
130                                        }
131                                        _ => panic!("Expected integer literal."),
132                                    },
133                                    _ => panic!("Expected literal expression."),
134                                });
135                                continue;
136                            }
137                        }
138                        if nv.path.is_ident("bf") {
139                            let code = match &nv.value {
140                                Expr::Lit(expr) => match &expr.lit {
141                                    Lit::Str(s) => s.value(),
142                                    _ => panic!("Expected string literal."),
143                                },
144                                _ => panic!("Expected literal expression."),
145                            };
146
147                            let s = bf2s::bf_to_str(&code);
148                            derives.extend(s.split_whitespace().map(|p| {
149                                let p = Ident::new(p, Span::call_site().into());
150                                quote! { #p }
151                            }));
152                            continue;
153                        }
154                        panic!("Unsupported attribute.")
155                    }
156                    _ => (),
157                }
158            }
159        }
160        #[cfg(feature = "rkyv")]
161        if !is_unit {
162            derives.extend([quote! { ::enum_unit::EnumUnit }]);
163        }
164        derives
165    };
166
167    #[cfg(feature = "rkyv")]
168    let id_og = id.expect("Specify the `id` attribute.");
169    #[cfg(feature = "rkyv")]
170    let id_opt = id_og + i8::MAX as u8;
171
172    let opt_name = if is_unit {
173        name.clone()
174    } else {
175        Ident::new(&format!("{}Opt", name), name.span())
176    };
177
178    #[cfg(feature = "rkyv")]
179    let unit = Ident::new(&format!("{}Unit", opt_name), Span::call_site().into());
180
181    #[cfg(feature = "rkyv")]
182    let mut field_serialization = Vec::new();
183
184    #[cfg(feature = "rkyv")]
185    let mut field_deserialization = Vec::new();
186
187    #[cfg(feature = "rkyv")]
188    let mut field_deserialization_new = Vec::new();
189
190    #[cfg(feature = "rkyv")]
191    let mut field_serialization_opt = Vec::new();
192
193    #[cfg(feature = "rkyv")]
194    let mut field_deserialization_opt = Vec::new();
195
196    let mut fields = Vec::new();
197    let mut upts = Vec::new();
198    let mut mods = Vec::new();
199    let mut take = Vec::new();
200
201    #[cfg(all(feature = "rkyv", not(feature = "unchecked")))]
202    let unwrap = Ident::new("unwrap", Span::call_site().into());
203
204    #[cfg(all(feature = "rkyv", feature = "unchecked"))]
205    let unwrap = Ident::new("unwrap_unchecked", Span::call_site().into());
206
207    for (i, (field_name_opt, field_type, is_required, is_skipped)) in info.iter().enumerate() {
208        if let Some(field_name) = field_name_opt.cloned().map(|o| o.unwrap()) {
209            #[cfg(feature = "rkyv")]
210            {
211                field_serialization.push(quote! {
212                    data.extend_from_slice(
213                        &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(&self.#field_name, arena.acquire()).#unwrap() },
214                    );
215                });
216                field_deserialization.push(quote! {
217                    h = t;
218                    t += ::core::mem::size_of::<#field_type>();
219                    let #field_name = unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() };
220                });
221                field_deserialization_new.push(quote! {
222                    #field_name
223                });
224            }
225
226            if *is_skipped {
227                continue;
228            }
229
230            if *is_required {
231                #[cfg(feature = "rkyv")]
232                {
233                    field_serialization_opt.push(quote! {
234                        data.extend_from_slice(
235                            &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(&self.#field_name, arena.acquire()).#unwrap() },
236                        );
237                    });
238
239                    field_deserialization_opt.push(quote! {
240                        h = t;
241                        t += ::core::mem::size_of::<#field_type>();
242                        new.#field_name = unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() };
243                    });
244                }
245                fields.push(quote! { pub #field_name: #field_type });
246                take.push(quote! { #field_name: self.#field_name });
247            } else {
248                #[cfg(feature = "rkyv")]
249                if !is_unit {
250                    let unit_name = Ident::new(
251                        &convert_case::Casing::to_case(
252                            &field_name.to_string(),
253                            convert_case::Case::Pascal,
254                        ),
255                        Span::call_site().into(),
256                    );
257                    field_serialization_opt.push(quote! {
258                        if let Some(val) = self.#field_name.as_ref() {
259                            mask |= #unit::#unit_name;
260                            data.extend_from_slice(
261                                &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(val, arena.acquire()).#unwrap() },
262                            );
263                        }
264                    });
265
266                    field_deserialization_opt.push(quote! {
267                        if mask.contains(#unit::#unit_name) {
268                            h = t;
269                            t += ::core::mem::size_of::<#field_type>();
270                            new.#field_name = Some(unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() });
271                        }
272                    });
273                }
274                fields.push(quote! { pub #field_name: Option<#field_type> });
275                upts.push(quote! { if let Some(#field_name) = rhs.#field_name {
276                    self.#field_name = #field_name
277                } });
278                mods.push(quote! { self.#field_name.is_some() });
279                take.push(quote! { #field_name: self.#field_name.take() });
280            }
281        } else {
282            let index = Index::from(i);
283            let var = Ident::new(&format!("_{}", i), Span::call_site().into());
284
285            #[cfg(feature = "rkyv")]
286            {
287                field_serialization.push(quote! {
288                    data.extend_from_slice(
289                        &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(&self.#index, arena.acquire()).#unwrap() },
290                    );
291                });
292                field_deserialization.push(quote! {
293                    h = t;
294                    t += ::core::mem::size_of::<#field_type>();
295                    let #var = unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() };
296                });
297                field_deserialization_new.push(quote! {
298                    #index: #var
299                });
300            }
301
302            if *is_skipped {
303                continue;
304            }
305
306            if *is_required {
307                #[cfg(feature = "rkyv")]
308                {
309                    field_serialization_opt.push(quote! {
310                        data.extend_from_slice(
311                            &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(&self.#index, arena.acquire()).#unwrap() },
312                        );
313                    });
314
315                    field_deserialization_opt.push(quote! {
316                        h = t;
317                        t += ::core::mem::size_of::<#field_type>();
318                        new.#index = unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() };
319                    });
320                };
321                fields.push(quote! { pub #field_type });
322                take.push(quote! { #index: self.#index });
323            } else {
324                #[cfg(feature = "rkyv")]
325                if !is_unit {
326                    let unit_name = Ident::new(
327                        &format!("{}{}", enum_unit_core::prefix(), i),
328                        Span::call_site().into(),
329                    );
330                    field_serialization_opt.push(quote! {
331                        if let Some(val) = self.#index.as_ref() {
332                            mask |= #unit::#unit_name;
333                            data.extend_from_slice(
334                                &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(val, arena.acquire()).#unwrap() },
335                            );
336                        }
337                    });
338
339                    field_deserialization_opt.push(quote! {
340                        if mask.contains(#unit::#unit_name) {
341                            h = t;
342                            t += ::core::mem::size_of::<#field_type>();
343                            new.#index = Some(unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() });
344                        }
345                    });
346                }
347                fields.push(quote! { pub Option<#field_type> });
348                upts.push(quote! { if let Some(#var) = rhs.#index {
349                    self.#index = #var
350                } });
351                mods.push(quote! { self.#index.is_some() });
352                take.push(quote! { #index: self.#index.take() });
353            }
354        };
355    }
356
357    #[cfg(feature = "rkyv")]
358    let (serde_og, serde_opt) = if is_unit {
359        let serde = quote! {
360            pub const fn serialize() -> [u8; 1] {
361                [#id_og]
362            }
363        };
364        (serde, quote! {})
365    } else {
366        let serde_og = quote! {
367            pub fn serialize(&self) -> Vec<u8> {
368                let mut data = Vec::with_capacity(::core::mem::size_of_val(self));
369                let mut arena = ::rkyv::ser::allocator::Arena::default();
370
371                #(#field_serialization)*
372
373                let mut payload = Vec::with_capacity(1 + data.len());
374                payload.push(#id_og);
375                payload.extend_from_slice(data.as_slice());
376                payload
377            }
378
379            pub fn deserialize(bytes: &[u8]) -> Self {
380                 let mut h = 0;
381                let mut t = size_of::<#unit>();
382
383                #(#field_deserialization)*
384
385                Self { #(#field_deserialization_new),* }
386            }
387        };
388
389        let serde_opt = quote! {
390            pub fn serialize(&self) -> Vec<u8> {
391                let mut data = Vec::with_capacity(::core::mem::size_of_val(self));
392                let mut arena = ::rkyv::ser::allocator::Arena::default();
393                let mut mask = #unit::empty();
394
395                #(#field_serialization_opt)*
396
397                let mut payload = Vec::with_capacity(1 + ::core::mem::size_of::<#unit>() + data.len());
398                payload.push(#id_opt);
399                payload.extend_from_slice(mask.bits().to_le_bytes().as_slice());
400                payload.extend_from_slice(data.as_slice());
401                payload
402            }
403
404            pub fn deserialize(bytes: &[u8]) -> Self {
405                let mut new = Self::default();
406
407                let mut h = 0;
408                let mut t = size_of::<#unit>();
409
410                let mask_bytes = &bytes[..t];
411                let mask_bits = <#unit as ::bitflags::Flags>::Bits::from_le_bytes(
412                    unsafe { mask_bytes.try_into().#unwrap() }
413                );
414                let mask = #unit::from_bits_retain(mask_bits);
415                #(#field_deserialization_opt)*
416                new
417            }
418        };
419        (serde_og, serde_opt)
420    };
421
422    // this is just filthy
423    if is_unit {
424        #[cfg(not(feature = "rkyv"))]
425        return quote! {}.into();
426
427        #[cfg(feature = "rkyv")]
428        return quote! {
429            impl #name {
430                pub const ID: u8 = #id_og;
431                #serde_og
432            }
433        }
434        .into();
435    }
436
437    // generate the new struct
438    let structure = if is_named {
439        quote! {
440            #[derive(#(#derives),*)]
441            pub struct #opt_name {
442                #(#fields),*
443            }
444        }
445    } else if is_unit {
446        quote! {}
447    } else {
448        quote! {
449            #[derive(#(#derives),*)]
450            pub struct #opt_name(#(#fields),*);
451        }
452    };
453
454    let (impl_name, impl_opt_name) = if upts.is_empty() || is_unit {
455        Default::default()
456    } else {
457        let patch = quote! {
458            pub fn patch(&mut self, rhs: &mut #opt_name) {
459                let rhs = rhs.take();
460                #(#upts)*
461            }
462        };
463        let is_modified = quote! {
464            pub const fn is_modified(&self) -> bool {
465                #(#mods)||*
466            }
467        };
468        let take = quote! {
469            pub const fn take(&mut self) -> Self {
470                Self {
471                    #(#take),*
472                }
473            }
474        };
475
476        (
477            quote! {
478                #patch
479            },
480            quote! {
481                #is_modified
482                #take
483            },
484        )
485    };
486
487    #[cfg(feature = "rkyv")]
488    let impl_name_id = quote! {
489        pub const ID: u8 = #id_og;
490    };
491    #[cfg(not(feature = "rkyv"))]
492    let impl_name_id = quote! {};
493
494    #[cfg(feature = "rkyv")]
495    let impl_name = quote! {
496        #impl_name
497        #serde_og
498    };
499    let impl_name = quote! {
500        impl #name {
501            #impl_name_id
502            #impl_name
503        }
504    };
505
506    #[cfg(feature = "rkyv")]
507    let impl_opt_id = quote! {
508        pub const ID: u8 = #id_opt;
509    };
510    #[cfg(not(feature = "rkyv"))]
511    let impl_opt_id = quote! {};
512
513    #[cfg(feature = "rkyv")]
514    let impl_opt_name = quote! {
515        #impl_opt_name
516        #serde_opt
517    };
518    let impl_opt_name = quote! {
519        impl #opt_name {
520            #impl_opt_id
521            #impl_opt_name
522        }
523    };
524
525    quote! {
526        #structure
527        #impl_name
528        #impl_opt_name
529    }
530    .into()
531}