1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{Data, DeriveInput, Fields, GenericArgument, PathArguments, Type, parse_macro_input};
9
10#[proc_macro_derive(SerializeFields)]
40pub fn serialize_fields_derive(input: TokenStream) -> TokenStream {
41 let input = parse_macro_input!(input as DeriveInput);
42
43 let struct_name = &input.ident;
44 let selector_name = format!("{}SerializeFieldSelector", struct_name);
45 let selector_ident = syn::Ident::new(&selector_name, struct_name.span());
46
47 let fields = match &input.data {
49 Data::Struct(data) => match &data.fields {
50 Fields::Named(fields) => &fields.named,
51 _ => panic!("SerializeFields only supports structs with named fields"),
52 },
53 _ => panic!("SerializeFields only supports structs"),
54 };
55
56 let mut selector_fields = Vec::new();
58 let mut enable_match_arms = Vec::new();
59 let mut new_field_inits = Vec::new();
60 let mut serialize_fields = Vec::new();
61
62 let field_enum_name = format!("{}Field", struct_name);
64 let field_enum_ident = syn::Ident::new(&field_enum_name, struct_name.span());
65 let mut enum_variants = Vec::new();
66 let mut enable_enum_match_arms = Vec::new();
67 let mut as_dot_path_arms = Vec::new();
68 let mut deserialize_match_arms = Vec::new();
69 #[cfg(feature = "schemars")]
70 let mut schema_simple_fields: Vec<String> = Vec::new();
71 #[cfg(feature = "schemars")]
72 let mut schema_nested_prefixes: Vec<(String, String)> = Vec::new(); for field in fields {
75 let field_ident = field.ident.as_ref().unwrap();
76
77 let field_name_str = strip_raw_prefix(&field_ident.to_string());
79
80 let field_ident = field_ident;
82
83 let (is_nested, nested_type) = analyze_field_type(&field.ty);
85
86 let variant_name = to_pascal_case(&field_name_str);
88 let variant_ident = syn::Ident::new(&variant_name, field_ident.span());
89
90 if is_nested {
91 let nested_selector_type = syn::Ident::new(
92 &format!("{}SerializeFieldSelector", nested_type),
93 field_ident.span(),
94 );
95 let nested_field_enum = syn::Ident::new(
96 &format!("{}Field", nested_type),
97 field_ident.span(),
98 );
99
100 selector_fields.push(quote! {
101 #[serde(skip_serializing_if = "Option::is_none")]
102 pub #field_ident: Option<#nested_selector_type>
103 });
104
105 enable_match_arms.push(quote! {
106 #field_name_str => {
107 match &mut self.#field_ident {
108 Some(nested) => nested.enable(&field_hierarchy[1..]),
109 None => {
110 let mut new_nested = #nested_selector_type::new();
111 new_nested.enable(&field_hierarchy[1..]);
112 self.#field_ident = Some(new_nested);
113 }
114 }
115 }
116 });
117
118 serialize_fields.push(quote! {
119 if let Some(ref nested_selector) = field_selector.#field_ident {
120 state.serialize_field(#field_name_str, &SerializeFields(&data.#field_ident, nested_selector))?;
121 }
122 });
123
124 enum_variants.push(quote! {
126 #variant_ident(#nested_field_enum)
127 });
128
129 enable_enum_match_arms.push(quote! {
130 #field_enum_ident::#variant_ident(nested) => {
131 match &mut self.#field_ident {
132 Some(selector) => {
133 selector.enable_enum(nested);
134 }
135 None => {
136 let mut new_nested = #nested_selector_type::new();
137 new_nested.enable_enum(nested);
138 self.#field_ident = Some(new_nested);
139 }
140 }
141 }
142 });
143
144 as_dot_path_arms.push(quote! {
145 #field_enum_ident::#variant_ident(nested) => {
146 format!("{}.{}", #field_name_str, nested.as_dot_path())
147 }
148 });
149
150 deserialize_match_arms.push(quote! {
151 s if s.starts_with(concat!(#field_name_str, ".")) => {
152 let rest = &s[#field_name_str.len() + 1..];
153 Ok(#field_enum_ident::#variant_ident(rest.parse()?))
154 }
155 });
156
157 #[cfg(feature = "schemars")]
158 schema_nested_prefixes.push((field_name_str.clone(), nested_type.clone()));
159 } else {
160 selector_fields.push(quote! {
161 #[serde(skip_serializing_if = "Option::is_none")]
162 pub #field_ident: Option<()>
163 });
164
165 enable_match_arms.push(quote! {
166 #field_name_str => self.#field_ident = Some(())
167 });
168
169 serialize_fields.push(quote! {
170 if field_selector.#field_ident.is_some() {
171 state.serialize_field(#field_name_str, &data.#field_ident)?;
172 }
173 });
174
175 enum_variants.push(quote! {
177 #variant_ident
178 });
179
180 enable_enum_match_arms.push(quote! {
181 #field_enum_ident::#variant_ident => self.#field_ident = Some(())
182 });
183
184 as_dot_path_arms.push(quote! {
185 #field_enum_ident::#variant_ident => #field_name_str.to_string()
186 });
187
188 deserialize_match_arms.push(quote! {
189 #field_name_str => Ok(#field_enum_ident::#variant_ident)
190 });
191
192 #[cfg(feature = "schemars")]
193 schema_simple_fields.push(field_name_str.clone());
194 }
195
196 new_field_inits.push(quote! {
197 #field_ident: None
198 });
199 }
200
201 #[cfg(feature = "schemars")]
203 let schema_nested_enum_types: Vec<_> = schema_nested_prefixes
204 .iter()
205 .map(|(_, nested_type)| {
206 let ident = syn::Ident::new(&format!("{}Field", nested_type), struct_name.span());
207 quote! { #ident }
208 })
209 .collect();
210 #[cfg(feature = "schemars")]
211 let schema_nested_prefix_strs: Vec<_> = schema_nested_prefixes
212 .iter()
213 .map(|(prefix, _)| prefix.clone())
214 .collect();
215
216 let count_enabled_fields = fields
218 .iter()
219 .map(|field: &syn::Field| {
220 let field_ident = field.ident.as_ref().unwrap();
221 quote! {
222 + if field_selector.#field_ident.is_some() { 1 } else { 0 }
223 }
224 })
225 .collect::<Vec<_>>();
226
227 #[cfg(feature = "schemars")]
229 let schemars_impl = quote! {
230 impl ::schemars::JsonSchema for #field_enum_ident {
231 fn schema_name() -> ::std::borrow::Cow<'static, str> {
232 ::std::borrow::Cow::Borrowed(stringify!(#field_enum_ident))
233 }
234
235 fn json_schema(generator: &mut ::schemars::SchemaGenerator) -> ::schemars::Schema {
236 let mut all_values: Vec<String> = Vec::new();
238
239 #(all_values.push(#schema_simple_fields.to_string());)*
241
242 #(
244 let nested_schema = <#schema_nested_enum_types as ::schemars::JsonSchema>::json_schema(generator);
246 if let Some(obj) = nested_schema.as_object() {
247 if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
248 for val in enum_values {
249 if let Some(s) = val.as_str() {
250 all_values.push(format!("{}.{}", #schema_nested_prefix_strs, s));
251 }
252 }
253 }
254 }
255 )*
256
257 ::schemars::json_schema!({
258 "type": "string",
259 "enum": all_values,
260 "description": concat!("Field selector for ", stringify!(#struct_name), " - serializes as dot notation (e.g., \"field.nested\")")
261 })
262 }
263 }
264 };
265
266 #[cfg(not(feature = "schemars"))]
267 let schemars_impl = quote! {};
268
269 let expanded = quote! {
271 #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
274 pub enum #field_enum_ident {
275 #(#enum_variants,)*
276 }
277
278 impl #field_enum_ident {
279 pub fn as_dot_path(&self) -> String {
281 match self {
282 #(#as_dot_path_arms,)*
283 }
284 }
285 }
286
287 impl ::std::fmt::Debug for #field_enum_ident {
288 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
289 write!(f, "{}", self.as_dot_path())
290 }
291 }
292
293 impl ::std::fmt::Display for #field_enum_ident {
294 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
295 write!(f, "{}", self.as_dot_path())
296 }
297 }
298
299 impl ::std::str::FromStr for #field_enum_ident {
300 type Err = String;
301
302 fn from_str(s: &str) -> Result<Self, Self::Err> {
303 match s {
304 #(#deserialize_match_arms,)*
305 _ => Err(format!("Unknown field: {}", s)),
306 }
307 }
308 }
309
310 impl ::serde::Serialize for #field_enum_ident {
311 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
312 where
313 S: ::serde::Serializer,
314 {
315 serializer.serialize_str(&self.as_dot_path())
316 }
317 }
318
319 impl<'de> ::serde::Deserialize<'de> for #field_enum_ident {
320 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
321 where
322 D: ::serde::Deserializer<'de>,
323 {
324 let s = String::deserialize(deserializer)?;
325 s.parse().map_err(::serde::de::Error::custom)
326 }
327 }
328
329 #schemars_impl
330
331 #[derive(Debug, Clone, PartialEq, Eq, Hash, ::serde::Serialize)]
332 pub struct #selector_ident {
333 #(#selector_fields,)*
334 }
335
336 impl #selector_ident {
337 pub fn new() -> Self {
338 #selector_ident {
339 #(#new_field_inits,)*
340 }
341 }
342
343 pub fn enable_dot_hierarchy(&mut self, field: &str) {
344 let split: Vec<&str> = field.split('.').collect();
345 self.enable(&split);
346 }
347
348 pub fn enable(&mut self, field_hierarchy: &[&str]) {
349 if field_hierarchy.is_empty() {
350 return;
351 }
352
353 match field_hierarchy[0] {
354 #(#enable_match_arms,)*
355 _ => {}
356 }
357 }
358
359 pub fn enable_enum(&mut self, field: #field_enum_ident) {
361 match field {
362 #(#enable_enum_match_arms,)*
363 }
364 }
365 }
366
367 impl Default for #selector_ident {
368 fn default() -> Self {
369 Self::new()
370 }
371 }
372
373 impl ::serialize_fields::FieldSelector for #selector_ident {
374 fn new() -> Self {
375 Self::new()
376 }
377
378 fn enable_dot_hierarchy(&mut self, field: &str) {
379 self.enable_dot_hierarchy(field)
380 }
381
382 fn enable(&mut self, field_hierarchy: &[&str]) {
383 self.enable(field_hierarchy)
384 }
385 }
386
387 impl ::serialize_fields::SerializeFieldsTrait for #struct_name {
388 type FieldSelector = #selector_ident;
389
390 fn serialize_fields(&self) -> Self::FieldSelector {
391 #selector_ident::new()
392 }
393
394 fn serialize<__S>(
395 &self,
396 field_selector: &Self::FieldSelector,
397 __serializer: __S,
398 ) -> Result<__S::Ok, __S::Error>
399 where
400 __S: ::serde::Serializer,
401 {
402 use ::serde::ser::SerializeStruct;
403 use ::serialize_fields::SerializeFields;
404
405 let data = self;
406
407 let field_count = 0 #(#count_enabled_fields)*;
409
410 let mut state = __serializer.serialize_struct(stringify!(#struct_name), field_count)?;
411
412 #(#serialize_fields)*
413
414 state.end()
415 }
416 }
417 };
418
419 TokenStream::from(expanded)
420}
421
/// Returns `s` without a leading raw-identifier marker `r#`.
///
/// A field declared as `r#type` is therefore addressed and serialized as `type`.
/// Strings without the marker are returned unchanged, as an owned copy.
fn strip_raw_prefix(s: &str) -> String {
    s.strip_prefix("r#").unwrap_or(s).to_string()
}
430
/// Converts a `snake_case` name into `PascalCase` for use as an enum variant name.
///
/// For every `_`-separated segment, the first character is upper-cased using full
/// Unicode case mapping, so one character may expand to several. The rest of the
/// segment is appended unchanged. Empty segments contribute nothing; these come from
/// leading, trailing or repeated underscores. For example, `"_a__b"` becomes `"AB"`.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut tail = segment.chars();
        if let Some(head) = tail.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(tail.as_str());
        }
    }
    pascal
}
443
444fn analyze_field_type(ty: &Type) -> (bool, String) {
446 match ty {
447 Type::Path(type_path) => {
448 let last_segment = type_path.path.segments.last().unwrap();
449 let type_name = last_segment.ident.to_string();
450
451 match type_name.as_str() {
452 "u8" | "u16" | "u32" | "u64" | "u128" | "usize" | "i8" | "i16" | "i32" | "i64"
454 | "i128" | "isize" | "f32" | "f64" | "bool" | "char" | "String" => {
455 (false, String::new())
456 }
457
458 "PathBuf" | "SystemTime" | "Duration" => (false, String::new()),
460
461 "Option" | "Vec" | "HashMap" | "BTreeMap" | "HashSet" | "BTreeSet" => {
463 if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
464 if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
465 return analyze_field_type(inner_ty);
466 }
467 }
468 (false, String::new())
469 }
470
471 "Result" | "Box" | "Rc" | "Arc" => (false, String::new()),
473
474 _ => (true, type_name),
476 }
477 }
478 Type::Array(type_array) => {
479 analyze_field_type(&type_array.elem)
481 }
482 Type::Tuple(_type_tuple) => {
483 (false, String::new())
485 }
486 _ => (false, String::new()),
487 }
488}