1use proc_macro::{Span, TokenStream};
2use quote::quote;
3use syn::{
4 DeriveInput, Expr, Field, Fields, Ident, Index, Lit, Meta, Type, parse_macro_input,
5 punctuated::Iter,
6};
7
// In this file `unchecked` only changes how the `rkyv`-gated (de)serialization code unwraps
// results (`unwrap_unchecked` instead of `unwrap`), so enabling it without `rkyv` is rejected.
#[cfg(all(feature = "unchecked", not(feature = "rkyv")))]
compile_error!("Feature `unchecked` requires feature `rkyv`.");
10
11fn get_field_kvs(
12 fields: Iter<Field>,
13 is_named: bool,
14) -> Vec<(Option<&Option<Ident>>, &Type, bool, bool)> {
15 fields
16 .map(|field: &Field| {
17 if field.attrs.len() > 1 {
18 panic!("Only 1 attribute per field is supported.")
19 }
20 let (mut is_required, mut skip) = Default::default();
21
22 if let Some(attr) = field.attrs.first() {
23 if attr.path().is_ident("wopt") {
24 let mut n = 0;
25 attr.parse_nested_meta(|a| {
26 if let Some(ident) = a.path.get_ident() {
27 match ident.to_string().as_str() {
28 "required" => is_required = true,
29 "skip" => skip = true,
30 _ => panic!(
31 "Only `required` & `skip` field attributes are supported."
32 ),
33 }
34 }
35 n += 1;
36 Ok(())
37 })
38 .unwrap();
39
40 if n > 1 {
41 panic!("A field has too many `wopt` attr args (max: 1)")
42 }
43 }
44 }
45 if is_named {
46 (Some(&field.ident), &field.ty, is_required, skip)
47 } else {
48 (None, &field.ty, is_required, skip)
49 }
50 })
51 .collect()
52}
53
/// Derives `WithOpt` for a struct `Foo`, generating a companion struct `FooOpt`.
///
/// Fields of `FooOpt`:
/// - each field of `Foo` is copied over, except fields marked `#[wopt(skip)]`, which are omitted;
/// - fields marked `#[wopt(required)]` keep their type;
/// - every other field becomes `Option<T>`.
///
/// Struct-level `#[wopt(...)]` attributes (each wraps exactly one item):
/// - a list such as `#[wopt(derive(Debug, Clone))]`: every single-identifier entry is added
///   to `FooOpt`'s `#[derive(...)]` (the list's own name is not checked);
/// - `#[wopt(bf = "...")]`: the string is passed to `bf2s::bf_to_str`, and the
///   whitespace-separated words of the result become derives;
/// - `#[wopt(id = N)]` (`rkyv` feature only, `N` a `u8` literal `<= 127`): wire ID of `Foo`;
///   `FooOpt` gets `N + 127`.
///
/// Generated methods:
/// - when `FooOpt` has at least one optional field: `Foo::patch`, `FooOpt::is_modified` and
///   `FooOpt::take`;
/// - with `rkyv`: an `ID` constant plus `serialize`/`deserialize` on both types, and
///   `FooOpt` also derives `::enum_unit::EnumUnit`, whose `FooOptUnit` type is used as the
///   field-presence bitmask.
///
/// Unit structs are only accepted with `rkyv`. For them no `...Opt` struct is generated,
/// only `impl Foo { ID, serialize() }`.
///
/// # Panics
///
/// Panics (reported as a compile error) on:
/// - non-struct input;
/// - a unit struct without `rkyv`;
/// - `id` without `rkyv`, or a missing `id` with `rkyv`;
/// - an `id` that is not a `u8` integer literal `<= 127`;
/// - a non-string `bf` value;
/// - any other name-value attribute;
/// - malformed `wopt` attributes, including field-level errors raised by `get_field_kvs`.
#[proc_macro_derive(WithOpt, attributes(id, wopt))]
pub fn wopt_derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = &input.ident;

    // Wire ID from `#[wopt(id = N)]`; `expect`ed further down when `rkyv` is enabled.
    #[cfg(feature = "rkyv")]
    let mut id = None;

    // Only ever set to `true` in the `rkyv` build, hence the `unused_mut` allowance.
    #[allow(unused_mut)]
    let mut is_unit = false;

    let mut is_named = false;

    // One `(name, type, is_required, is_skipped)` tuple per field (see `get_field_kvs`).
    let info: Vec<_> = if let syn::Data::Struct(ref data) = input.data {
        match &data.fields {
            Fields::Named(fields) => {
                is_named = true;
                get_field_kvs(fields.named.iter(), true)
            }
            Fields::Unnamed(fields) => get_field_kvs(fields.unnamed.iter(), false),
            _ => {
                #[cfg(not(feature = "rkyv"))]
                panic!("Unit structs are only supported with the `rkyv` feature.");

                #[cfg(feature = "rkyv")]
                {
                    is_unit = true;
                    vec![]
                }
            }
        }
    } else {
        panic!("Only structs are supported");
    };

    // Derive paths for the generated `...Opt` struct, taken from struct-level
    // `#[wopt(...)]` attributes. Under `rkyv` this block also records the `id`.
    let derives = {
        let mut derives = Vec::new();

        for attr in &input.attrs {
            if attr.path().is_ident("wopt") {
                // Each `#[wopt(...)]` must wrap exactly one meta item.
                let meta = attr.parse_args::<Meta>().unwrap();

                match &meta {
                    // e.g. `#[wopt(derive(Debug, Clone))]`: each single-ident entry is a derive.
                    Meta::List(list) => {
                        list.parse_nested_meta(|a| {
                            if let Some(ident) = a.path.get_ident() {
                                derives.push(quote! { #ident });
                            }
                            Ok(())
                        })
                        .unwrap();
                    }
                    Meta::NameValue(nv) => {
                        if nv.path.is_ident("id") {
                            #[cfg(not(feature = "rkyv"))]
                            panic!("Enable the `rkyv` feature to use the `id` attribute.");

                            #[cfg(feature = "rkyv")]
                            {
                                id = Some(match &nv.value {
                                    Expr::Lit(expr) => match &expr.lit {
                                        Lit::Int(v) => {
                                            let value = v
                                                .base10_parse::<u8>()
                                                .expect("Only `u8` is supported.");
                                            // Capped so that the Opt ID (`value + 127`) still
                                            // fits in a `u8`.
                                            if value > 127 {
                                                panic!("Value too large (max: 127)")
                                            }
                                            value
                                        }
                                        _ => panic!("Expected integer literal."),
                                    },
                                    _ => panic!("Expected literal expression."),
                                });
                                continue;
                            }
                        }
                        // `#[wopt(bf = "...")]`: the string is handed to `bf2s::bf_to_str`
                        // (presumably a Brainfuck interpreter; verify), and each
                        // whitespace-separated word of its output becomes a derive ident.
                        if nv.path.is_ident("bf") {
                            let code = match &nv.value {
                                Expr::Lit(expr) => match &expr.lit {
                                    Lit::Str(s) => s.value(),
                                    _ => panic!("Expected string literal."),
                                },
                                _ => panic!("Expected literal expression."),
                            };

                            let s = bf2s::bf_to_str(&code);
                            derives.extend(s.split_whitespace().map(|p| {
                                let p = Ident::new(p, Span::call_site().into());
                                quote! { #p }
                            }));
                            continue;
                        }
                        panic!("Unsupported attribute.")
                    }
                    // Bare paths such as `#[wopt(Foo)]` are silently ignored.
                    _ => (),
                }
            }
        }
        // Under `rkyv` the Opt struct always derives `EnumUnit`. Its `<OptName>Unit` type
        // is the presence bitmask used by the generated (de)serialization code.
        #[cfg(feature = "rkyv")]
        if !is_unit {
            derives.extend([quote! { ::enum_unit::EnumUnit }]);
        }
        derives
    };

    #[cfg(feature = "rkyv")]
    let id_og = id.expect("Specify the `id` attribute.");
    // The Opt struct's wire ID is the original ID offset by `i8::MAX` (127).
    // NOTE(review): with ids allowed up to 127, an Opt ID (range 127..=254) can equal another
    // struct's plain ID 127 (e.g. Opt of `id = 0` vs a struct with `id = 127`). Confirm
    // whether `id = 127` should be rejected or the offset changed.
    #[cfg(feature = "rkyv")]
    let id_opt = id_og + i8::MAX as u8;

    // Unit structs get no `...Opt` companion, so they keep their own name.
    let opt_name = if is_unit {
        name.clone()
    } else {
        Ident::new(&format!("{}Opt", name), name.span())
    };

    // Name of the presence-bitmask type, presumably generated by the `EnumUnit` derive above.
    #[cfg(feature = "rkyv")]
    let unit = Ident::new(&format!("{}Unit", opt_name), Span::call_site().into());

    // Statement fragments for the original struct's `serialize` body.
    #[cfg(feature = "rkyv")]
    let mut field_serialization = Vec::new();

    // Statement fragments for the original struct's `deserialize` body.
    #[cfg(feature = "rkyv")]
    let mut field_deserialization = Vec::new();

    // Field initializers for the final `Self { ... }` in the original `deserialize`.
    #[cfg(feature = "rkyv")]
    let mut field_deserialization_new = Vec::new();

    // Statement fragments for the Opt struct's `serialize` body.
    #[cfg(feature = "rkyv")]
    let mut field_serialization_opt = Vec::new();

    // Statement fragments for the Opt struct's `deserialize` body.
    #[cfg(feature = "rkyv")]
    let mut field_deserialization_opt = Vec::new();

    // Field declarations of the Opt struct.
    let mut fields = Vec::new();
    // `patch` statements: copy each `Some` optional field into the original struct.
    let mut upts = Vec::new();
    // `is_modified` operands: one `is_some()` check per optional field.
    let mut mods = Vec::new();
    // Field initializers for the Opt struct's `take`.
    let mut take = Vec::new();

    // Method called on the generated `Result`s: `unwrap`, or `unwrap_unchecked` (inside the
    // generated `unsafe` blocks) when the `unchecked` feature is enabled.
    #[cfg(all(feature = "rkyv", not(feature = "unchecked")))]
    let unwrap = Ident::new("unwrap", Span::call_site().into());

    #[cfg(all(feature = "rkyv", feature = "unchecked"))]
    let unwrap = Ident::new("unwrap_unchecked", Span::call_site().into());

    // NOTE(review): every generated reader slices `size_of::<T>()` bytes per field. This
    // assumes each rkyv-serialized value is exactly that long; verify for non-POD types.
    for (i, (field_name_opt, field_type, is_required, is_skipped)) in info.iter().enumerate() {
        // Named field: `field_name_opt` is `Some(&Some(ident))`; tuple fields carry `None`.
        if let Some(field_name) = field_name_opt.cloned().map(|o| o.unwrap()) {
            // The original struct (de)serializes every field, including skipped ones.
            #[cfg(feature = "rkyv")]
            {
                field_serialization.push(quote! {
                    data.extend_from_slice(
                        &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(&self.#field_name, arena.acquire()).#unwrap() },
                    );
                });
                field_deserialization.push(quote! {
                    h = t;
                    t += ::core::mem::size_of::<#field_type>();
                    let #field_name = unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() };
                });
                field_deserialization_new.push(quote! {
                    #field_name
                });
            }

            // Skipped fields do not appear in the Opt struct at all.
            if *is_skipped {
                continue;
            }

            if *is_required {
                // Required field: always serialized, and non-optional in the Opt struct.
                #[cfg(feature = "rkyv")]
                {
                    field_serialization_opt.push(quote! {
                        data.extend_from_slice(
                            &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(&self.#field_name, arena.acquire()).#unwrap() },
                        );
                    });

                    field_deserialization_opt.push(quote! {
                        h = t;
                        t += ::core::mem::size_of::<#field_type>();
                        new.#field_name = unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() };
                    });
                }
                fields.push(quote! { pub #field_name: #field_type });
                take.push(quote! { #field_name: self.#field_name });
            } else {
                // Optional field: serialized only when `Some`, with its bit set in the mask.
                #[cfg(feature = "rkyv")]
                if !is_unit {
                    // Bitmask flag name: the field name in PascalCase (presumably matching the
                    // naming `EnumUnit` uses; verify).
                    let unit_name = Ident::new(
                        &convert_case::Casing::to_case(
                            &field_name.to_string(),
                            convert_case::Case::Pascal,
                        ),
                        Span::call_site().into(),
                    );
                    field_serialization_opt.push(quote! {
                        if let Some(val) = self.#field_name.as_ref() {
                            mask |= #unit::#unit_name;
                            data.extend_from_slice(
                                &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(val, arena.acquire()).#unwrap() },
                            );
                        }
                    });

                    field_deserialization_opt.push(quote! {
                        if mask.contains(#unit::#unit_name) {
                            h = t;
                            t += ::core::mem::size_of::<#field_type>();
                            new.#field_name = Some(unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() });
                        }
                    });
                }
                fields.push(quote! { pub #field_name: Option<#field_type> });
                upts.push(quote! { if let Some(#field_name) = rhs.#field_name {
                    self.#field_name = #field_name
                } });
                mods.push(quote! { self.#field_name.is_some() });
                take.push(quote! { #field_name: self.#field_name.take() });
            }
        } else {
            // Tuple field: addressed by position (`self.0`); `_i` names its temporary binding.
            let index = Index::from(i);
            let var = Ident::new(&format!("_{}", i), Span::call_site().into());

            // The original struct (de)serializes every field, including skipped ones.
            #[cfg(feature = "rkyv")]
            {
                field_serialization.push(quote! {
                    data.extend_from_slice(
                        &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(&self.#index, arena.acquire()).#unwrap() },
                    );
                });
                field_deserialization.push(quote! {
                    h = t;
                    t += ::core::mem::size_of::<#field_type>();
                    let #var = unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() };
                });
                field_deserialization_new.push(quote! {
                    #index: #var
                });
            }

            // Skipped fields do not appear in the Opt struct at all.
            if *is_skipped {
                continue;
            }

            if *is_required {
                // Required field: always serialized, and non-optional in the Opt struct.
                #[cfg(feature = "rkyv")]
                {
                    field_serialization_opt.push(quote! {
                        data.extend_from_slice(
                            &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(&self.#index, arena.acquire()).#unwrap() },
                        );
                    });

                    field_deserialization_opt.push(quote! {
                        h = t;
                        t += ::core::mem::size_of::<#field_type>();
                        new.#index = unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() };
                    });
                };
                fields.push(quote! { pub #field_type });
                take.push(quote! { #index: self.#index });
            } else {
                // Optional field: serialized only when `Some`, with its bit set in the mask.
                #[cfg(feature = "rkyv")]
                if !is_unit {
                    // Bitmask flag name for a positional field: `enum_unit_core::prefix()`
                    // followed by the field index.
                    let unit_name = Ident::new(
                        &format!("{}{}", enum_unit_core::prefix(), i),
                        Span::call_site().into(),
                    );
                    field_serialization_opt.push(quote! {
                        if let Some(val) = self.#index.as_ref() {
                            mask |= #unit::#unit_name;
                            data.extend_from_slice(
                                &unsafe { ::rkyv::api::high::to_bytes_with_alloc::<_, ::rkyv::rancor::Error>(val, arena.acquire()).#unwrap() },
                            );
                        }
                    });

                    field_deserialization_opt.push(quote! {
                        if mask.contains(#unit::#unit_name) {
                            h = t;
                            t += ::core::mem::size_of::<#field_type>();
                            new.#index = Some(unsafe { ::rkyv::from_bytes::<#field_type, ::rkyv::rancor::Error>(&bytes[h..t]).#unwrap() });
                        }
                    });
                }
                fields.push(quote! { pub Option<#field_type> });
                upts.push(quote! { if let Some(#var) = rhs.#index {
                    self.#index = #var
                } });
                mods.push(quote! { self.#index.is_some() });
                take.push(quote! { #index: self.#index.take() });
            }
        };
    }

    #[cfg(feature = "rkyv")]
    let (serde_og, serde_opt) = if is_unit {
        // A unit struct serializes to just its ID byte and has no Opt counterpart.
        let serde = quote! {
            pub const fn serialize() -> [u8; 1] {
                [#id_og]
            }
        };
        (serde, quote! {})
    } else {
        // Original-struct wire format: `[id_og, field bytes...]` (no presence mask).
        // NOTE(review): the generated `deserialize` starts reading fields at offset
        // `size_of::<Unit>()`, although `serialize` writes no mask. The Opt `deserialize`
        // reads its mask from offset 0, i.e. it expects the ID byte already stripped. The two
        // offset conventions look inconsistent; confirm how callers slice the payload.
        let serde_og = quote! {
            pub fn serialize(&self) -> Vec<u8> {
                let mut data = Vec::with_capacity(::core::mem::size_of_val(self));
                let mut arena = ::rkyv::ser::allocator::Arena::default();

                #(#field_serialization)*

                let mut payload = Vec::with_capacity(1 + data.len());
                payload.push(#id_og);
                payload.extend_from_slice(data.as_slice());
                payload
            }

            pub fn deserialize(bytes: &[u8]) -> Self {
                let mut h = 0;
                let mut t = size_of::<#unit>();

                #(#field_deserialization)*

                Self { #(#field_deserialization_new),* }
            }
        };

        // Opt-struct wire format: `[id_opt, mask bits (little-endian), present field bytes...]`.
        // The generated `deserialize` reads the mask from `bytes[..size_of::<Unit>()]`, so it
        // expects `bytes` to start at the mask (ID byte removed). It also requires the Opt
        // struct to implement `Default`.
        let serde_opt = quote! {
            pub fn serialize(&self) -> Vec<u8> {
                let mut data = Vec::with_capacity(::core::mem::size_of_val(self));
                let mut arena = ::rkyv::ser::allocator::Arena::default();
                let mut mask = #unit::empty();

                #(#field_serialization_opt)*

                let mut payload = Vec::with_capacity(1 + ::core::mem::size_of::<#unit>() + data.len());
                payload.push(#id_opt);
                payload.extend_from_slice(mask.bits().to_le_bytes().as_slice());
                payload.extend_from_slice(data.as_slice());
                payload
            }

            pub fn deserialize(bytes: &[u8]) -> Self {
                let mut new = Self::default();

                let mut h = 0;
                let mut t = size_of::<#unit>();

                let mask_bytes = &bytes[..t];
                let mask_bits = <#unit as ::bitflags::Flags>::Bits::from_le_bytes(
                    unsafe { mask_bytes.try_into().#unwrap() }
                );
                let mask = #unit::from_bits_retain(mask_bits);
                #(#field_deserialization_opt)*
                new
            }
        };
        (serde_og, serde_opt)
    };

    // Unit structs: only the ID and the one-byte `serialize` are generated.
    if is_unit {
        #[cfg(not(feature = "rkyv"))]
        return quote! {}.into();

        #[cfg(feature = "rkyv")]
        return quote! {
            impl #name {
                pub const ID: u8 = #id_og;
                #serde_og
            }
        }
        .into();
    }

    // The Opt struct definition. Unit structs returned above, so the `is_unit` arm is not
    // reached in practice.
    let structure = if is_named {
        quote! {
            #[derive(#(#derives),*)]
            pub struct #opt_name {
                #(#fields),*
            }
        }
    } else if is_unit {
        quote! {}
    } else {
        quote! {
            #[derive(#(#derives),*)]
            pub struct #opt_name(#(#fields),*);
        }
    };

    // Helper methods, emitted only when the Opt struct has at least one optional field:
    // `patch` goes on the original struct, `is_modified` and `take` on the Opt struct.
    let (impl_name, impl_opt_name) = if upts.is_empty() || is_unit {
        Default::default()
    } else {
        // Drains `rhs` via `take`, then overwrites each field of `self` whose Opt value is `Some`.
        let patch = quote! {
            pub fn patch(&mut self, rhs: &mut #opt_name) {
                let rhs = rhs.take();
                #(#upts)*
            }
        };
        // True when any optional field is `Some`.
        let is_modified = quote! {
            pub const fn is_modified(&self) -> bool {
                #(#mods)||*
            }
        };
        // Moves the optional fields out (leaving `None`) and copies required fields.
        // NOTE(review): required fields are moved out with `self.field` from `&mut self`,
        // which only compiles when those field types are `Copy`.
        let take = quote! {
            pub const fn take(&mut self) -> Self {
                Self {
                    #(#take),*
                }
            }
        };

        (
            quote! {
                #patch
            },
            quote! {
                #is_modified
                #take
            },
        )
    };

    #[cfg(feature = "rkyv")]
    let impl_name_id = quote! {
        pub const ID: u8 = #id_og;
    };
    #[cfg(not(feature = "rkyv"))]
    let impl_name_id = quote! {};

    // Under `rkyv`, the original struct's impl also receives `serialize`/`deserialize`.
    #[cfg(feature = "rkyv")]
    let impl_name = quote! {
        #impl_name
        #serde_og
    };
    let impl_name = quote! {
        impl #name {
            #impl_name_id
            #impl_name
        }
    };

    #[cfg(feature = "rkyv")]
    let impl_opt_id = quote! {
        pub const ID: u8 = #id_opt;
    };
    #[cfg(not(feature = "rkyv"))]
    let impl_opt_id = quote! {};

    // Under `rkyv`, the Opt struct's impl also receives `serialize`/`deserialize`.
    #[cfg(feature = "rkyv")]
    let impl_opt_name = quote! {
        #impl_opt_name
        #serde_opt
    };
    let impl_opt_name = quote! {
        impl #opt_name {
            #impl_opt_id
            #impl_opt_name
        }
    };

    quote! {
        #structure
        #impl_name
        #impl_opt_name
    }
    .into()
}
531}