// serde_xpath_derive/lib.rs

use proc_macro::TokenStream;
use quote::quote;
use syn::Attribute;
use syn::Data;
use syn::DeriveInput;
use syn::Fields;
use syn::LitStr;
use syn::Meta;
use syn::parse_macro_input;

10#[proc_macro_derive(Deserialize, attributes(xpath, serde))]
11pub fn derive_deserialize(input: TokenStream) -> TokenStream {
12    let input = parse_macro_input!(input as DeriveInput);
13    let name = &input.ident;
14
15    // Get the root xpath from struct attributes
16    let root_xpath = get_xpath_attr(&input.attrs).unwrap_or_default();
17
18    let fields = match &input.data {
19        Data::Struct(data) => match &data.fields {
20            Fields::Named(fields) => &fields.named,
21            _ => panic!("Only named fields are supported"),
22        },
23        _ => panic!("Only structs are supported"),
24    };
25
26    let mut field_descriptors = Vec::new();
27    let mut field_deserializations = Vec::new();
28
29    for field in fields {
30        let field_name = field.ident.as_ref().unwrap();
31        let field_name_str = field_name.to_string();
32
33        let (xpath, is_text) = get_field_xpath_attr(&field.attrs);
34        let xpath = xpath.unwrap_or_default();
35
36        let has_serde_default = has_serde_default_attr(&field.attrs);
37        let field_type = &field.ty;
38
39        // Determine field kind based on type and attributes
40        let (kind, is_optional, is_vec) = determine_field_kind(
41            field_type,
42            is_text,
43            has_serde_default,
44            &xpath,
45        );
46
47        let kind_tokens = match kind.as_str() {
48            "Attribute" => {
49                quote! { serde_xpath::__private::FieldKind::Attribute }
50            }
51            "Text" => quote! { serde_xpath::__private::FieldKind::Text },
52            "Sequence" => quote! { serde_xpath::__private::FieldKind::Sequence },
53            "Optional" => quote! { serde_xpath::__private::FieldKind::Optional },
54            "OptionalSequence" => {
55                quote! { serde_xpath::__private::FieldKind::OptionalSequence }
56            }
57            _ => quote! { serde_xpath::__private::FieldKind::Element },
58        };
59
60        field_descriptors.push(quote! {
61            serde_xpath::__private::FieldDescriptor {
62                name: #field_name_str,
63                xpath: #xpath,
64                kind: #kind_tokens,
65            }
66        });
67
68        // Generate field deserialization code
69        let deser_code = if is_vec {
70            // Vec field - use sequence deserialization
71            let inner_type = extract_inner_type(field_type, "Vec");
72            if let Some(inner) = inner_type {
73                if is_simple_type(&inner) {
74                    quote! {
75                        let #field_name: #field_type = deser.deserialize_field(#field_name_str)?;
76                    }
77                } else {
78                    // Nested struct in Vec
79                    generate_vec_nested_deser(
80                        field_name,
81                        &field_name_str,
82                        &xpath,
83                        &inner,
84                    )
85                }
86            } else {
87                quote! {
88                    let #field_name: #field_type = deser.deserialize_field(#field_name_str)?;
89                }
90            }
91        } else if is_optional {
92            // Option field
93            let inner_type = extract_inner_type(field_type, "Option");
94            if let Some(inner) = inner_type {
95                if is_simple_type(&inner) {
96                    generate_optional_simple_deser(
97                        field_name,
98                        &field_name_str,
99                        &xpath,
100                        &inner,
101                        is_text,
102                    )
103                } else {
104                    // Nested struct in Option
105                    generate_optional_nested_deser(
106                        field_name,
107                        &field_name_str,
108                        &xpath,
109                        &inner,
110                    )
111                }
112            } else {
113                quote! {
114                    let #field_name: #field_type = deser.deserialize_field(#field_name_str)?;
115                }
116            }
117        } else if is_text || xpath.starts_with("/@") {
118            // Simple text or attribute
119            quote! {
120                let #field_name: #field_type = deser.deserialize_field(#field_name_str)?;
121            }
122        } else if is_simple_type(field_type) {
123            quote! {
124                let #field_name: #field_type = deser.deserialize_field(#field_name_str)?;
125            }
126        } else {
127            // Nested struct
128            generate_nested_struct_deser(
129                field_name,
130                &field_name_str,
131                &xpath,
132                field_type,
133            )
134        };
135
136        field_deserializations.push(deser_code);
137    }
138
139    let field_names: Vec<_> =
140        fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
141
142    let expanded = quote! {
143        impl<'de> serde::Deserialize<'de> for #name {
144            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
145            where
146                D: serde::Deserializer<'de>,
147            {
148                // This path is used when deserializing from a non-XPath deserializer
149                // For now, return an error suggesting to use serde_xpath::from_str
150                Err(serde::de::Error::custom(
151                    "use serde_xpath::from_str to deserialize this type"
152                ))
153            }
154        }
155
156        impl serde_xpath::FromXml for #name {
157            fn from_xml(xml: &str) -> std::result::Result<Self, serde_xpath::Error> {
158                Self::__deserialize_from_xml(xml)
159            }
160        }
161
162        impl #name {
163            const __XPATH_DESCRIPTOR: serde_xpath::__private::StructDescriptor =
164                serde_xpath::__private::StructDescriptor {
165                    name: stringify!(#name),
166                    root_xpath: #root_xpath,
167                    fields: &[
168                        #(#field_descriptors),*
169                    ],
170                };
171
172            #[doc(hidden)]
173            pub fn __deserialize_from_xml(
174                xml: &str,
175            ) -> std::result::Result<Self, serde_xpath::Error> {
176                serde_xpath::from_str_with_descriptor(
177                    xml,
178                    &Self::__XPATH_DESCRIPTOR,
179                    |deser| {
180                        #(#field_deserializations)*
181
182                        Ok(#name {
183                            #(#field_names),*
184                        })
185                    },
186                )
187            }
188
189            #[doc(hidden)]
190            pub fn __deserialize_from_node<'a, 'input>(
191                node: roxmltree::Node<'a, 'input>,
192                parent_descriptor: &'static serde_xpath::__private::StructDescriptor,
193            ) -> std::result::Result<Self, serde_xpath::Error> {
194                serde_xpath::__private::deserialize_struct_from_node(
195                    node,
196                    &Self::__XPATH_DESCRIPTOR,
197                    |deser| {
198                        #(#field_deserializations)*
199
200                        Ok(#name {
201                            #(#field_names),*
202                        })
203                    },
204                )
205            }
206        }
207    };
208
209    TokenStream::from(expanded)
210}
211
212fn get_xpath_attr(attrs: &[Attribute]) -> Option<String> {
213    for attr in attrs {
214        if attr.path().is_ident("xpath")
215            && let Meta::List(meta_list) = &attr.meta
216        {
217            let tokens = meta_list.tokens.to_string();
218            // Parse the string literal
219            if let Some(s) =
220                tokens.strip_prefix('"').and_then(|s| s.strip_suffix('"'))
221            {
222                return Some(s.to_string());
223            }
224        }
225    }
226    None
227}
228
229fn get_field_xpath_attr(attrs: &[Attribute]) -> (Option<String>, bool) {
230    for attr in attrs {
231        if attr.path().is_ident("xpath")
232            && let Meta::List(meta_list) = &attr.meta
233        {
234            let tokens = meta_list.tokens.to_string();
235            // Check if it contains serde_xpath::Text
236            let is_text =
237                tokens.contains("serde_xpath::Text") || tokens.contains("Text");
238
239            // Parse the xpath string (first argument)
240            let parts: Vec<&str> = tokens.split(',').collect();
241            if let Some(first) = parts.first() {
242                let first = first.trim();
243                if let Some(s) =
244                    first.strip_prefix('"').and_then(|s| s.strip_suffix('"'))
245                {
246                    return (Some(s.to_string()), is_text);
247                }
248            }
249        }
250    }
251    (None, false)
252}
253
254fn has_serde_default_attr(attrs: &[Attribute]) -> bool {
255    for attr in attrs {
256        if attr.path().is_ident("serde")
257            && let Meta::List(meta_list) = &attr.meta
258        {
259            let tokens = meta_list.tokens.to_string();
260            if tokens.contains("default") {
261                return true;
262            }
263        }
264    }
265    false
266}
267
268fn determine_field_kind(
269    ty: &syn::Type,
270    is_text: bool,
271    has_default: bool,
272    xpath: &str,
273) -> (String, bool, bool) {
274    let type_str = quote!(#ty).to_string();
275
276    let is_vec = type_str.starts_with("Vec <") || type_str.starts_with("Vec<");
277    let is_option =
278        type_str.starts_with("Option <") || type_str.starts_with("Option<");
279
280    // Check if this is an attribute xpath (ends with /@attr or is just /@attr)
281    let is_attribute = xpath.contains("/@");
282
283    if is_text {
284        return ("Text".to_string(), is_option, is_vec);
285    }
286
287    if is_attribute {
288        return ("Attribute".to_string(), is_option, is_vec);
289    }
290
291    if is_vec {
292        if is_option || has_default {
293            return ("OptionalSequence".to_string(), false, true);
294        }
295        return ("Sequence".to_string(), false, true);
296    }
297
298    if is_option || has_default {
299        return ("Optional".to_string(), true, false);
300    }
301
302    ("Element".to_string(), false, false)
303}
304
305fn extract_inner_type(ty: &syn::Type, wrapper: &str) -> Option<syn::Type> {
306    if let syn::Type::Path(type_path) = ty
307        && let Some(segment) = type_path.path.segments.last()
308        && segment.ident == wrapper
309        && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
310        && let Some(syn::GenericArgument::Type(inner)) = args.args.first()
311    {
312        return Some(inner.clone());
313    }
314    None
315}
316
317fn is_simple_type(ty: &syn::Type) -> bool {
318    let type_str = quote!(#ty).to_string();
319    matches!(
320        type_str.as_str(),
321        "String"
322            | "str"
323            | "i8"
324            | "i16"
325            | "i32"
326            | "i64"
327            | "u8"
328            | "u16"
329            | "u32"
330            | "u64"
331            | "f32"
332            | "f64"
333            | "bool"
334            | "char"
335            | "& str"
336            | "& 'static str"
337    )
338}
339
340fn generate_nested_struct_deser(
341    field_name: &syn::Ident,
342    field_name_str: &str,
343    xpath: &str,
344    field_type: &syn::Type,
345) -> proc_macro2::TokenStream {
346    quote! {
347        let #field_name: #field_type = {
348            let xpath = serde_xpath::xpath::XPath::parse(#xpath)
349                .map_err(|e| serde_xpath::Error::XPath(e))?;
350            let result = xpath.evaluate_single(deser.node())
351                .ok_or_else(|| serde_xpath::Error::MissingField(#field_name_str.to_string()))?;
352            let node = result.as_node()
353                .ok_or_else(|| serde_xpath::Error::XPath(format!("expected element for field '{}'", #field_name_str)))?;
354            #field_type::__deserialize_from_node(node, &Self::__XPATH_DESCRIPTOR)?
355        };
356    }
357}
358
359fn generate_optional_nested_deser(
360    field_name: &syn::Ident,
361    _field_name_str: &str,
362    xpath: &str,
363    inner_type: &syn::Type,
364) -> proc_macro2::TokenStream {
365    quote! {
366        let #field_name: Option<#inner_type> = {
367            let xpath = serde_xpath::xpath::XPath::parse(#xpath)
368                .map_err(|e| serde_xpath::Error::XPath(e))?;
369            match xpath.evaluate_single(deser.node()) {
370                Some(result) => {
371                    match result.as_node() {
372                        Some(node) => Some(#inner_type::__deserialize_from_node(node, &Self::__XPATH_DESCRIPTOR)?),
373                        None => None,
374                    }
375                }
376                None => None,
377            }
378        };
379    }
380}
381
382fn generate_optional_simple_deser(
383    field_name: &syn::Ident,
384    _field_name_str: &str,
385    xpath: &str,
386    inner_type: &syn::Type,
387    is_text: bool,
388) -> proc_macro2::TokenStream {
389    if is_text {
390        quote! {
391            let #field_name: Option<#inner_type> = {
392                let xpath = serde_xpath::xpath::XPath::parse(#xpath)
393                    .map_err(|e| serde_xpath::Error::XPath(e))?;
394                match xpath.evaluate_single(deser.node()) {
395                    Some(result) => {
396                        result.text().map(|s| s.to_string())
397                    }
398                    None => None,
399                }
400            };
401        }
402    } else {
403        quote! {
404            let #field_name: Option<#inner_type> = {
405                let xpath = serde_xpath::xpath::XPath::parse(#xpath)
406                    .map_err(|e| serde_xpath::Error::XPath(e))?;
407                match xpath.evaluate_single(deser.node()) {
408                    Some(result) => {
409                        match result.as_str() {
410                            Some(s) => Some(s.to_string()),
411                            None => result.text().map(|s| s.to_string()),
412                        }
413                    }
414                    None => None,
415                }
416            };
417        }
418    }
419}
420
421fn generate_vec_nested_deser(
422    field_name: &syn::Ident,
423    _field_name_str: &str,
424    xpath: &str,
425    inner_type: &syn::Type,
426) -> proc_macro2::TokenStream {
427    quote! {
428        let #field_name: Vec<#inner_type> = {
429            let xpath = serde_xpath::xpath::XPath::parse(#xpath)
430                .map_err(|e| serde_xpath::Error::XPath(e))?;
431            let results = xpath.evaluate_all(deser.node());
432            let mut items = Vec::new();
433            for result in results {
434                if let Some(node) = result.as_node() {
435                    items.push(#inner_type::__deserialize_from_node(node, &Self::__XPATH_DESCRIPTOR)?);
436                }
437            }
438            items
439        };
440    }
441}