tryparse_derive/
lib.rs

1//! Derive macros for tryparse
2//!
3//! This crate provides the `LlmDeserialize` derive macro for automatically
4//! generating fuzzy deserialization logic from Rust types.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type};
9
10/// Derives the `LlmDeserialize` trait for structs and enums.
11///
12/// This macro generates a custom deserialization implementation using BAML's
13/// algorithms for fuzzy field matching and type coercion.
14///
15/// # Features
16///
17/// - **Fuzzy field matching**: Handles different naming conventions (userName ↔ user_name)
18/// - **Fuzzy enum matching**: Case-insensitive, substring, and edit-distance matching for variants
19/// - **Union types**: Score-based variant selection with `#[llm(union)]`
20/// - **Optional fields**: Automatic handling of `Option<T>` fields
21/// - **Transformation tracking**: Records all coercions applied during parsing
22///
23/// # Example
24///
25/// ```ignore
26/// use tryparse::deserializer::LlmDeserialize;
27///
28/// #[derive(LlmDeserialize)]
29/// struct User {
30///     name: String,
31///     age: u32,
32///     email: Option<String>, // Optional field
33/// }
34///
35/// // Handles messy input like:
36/// // {"userName": "Alice", "age": "30"}  // camelCase + string number
37/// ```
38///
39/// # Union Types
40///
41/// ```ignore
42/// #[derive(LlmDeserialize)]
43/// #[llm(union)]
44/// enum Value {
45///     Number(i64),
46///     Text(String),
47/// }
48///
49/// // Automatically picks the best matching variant
50/// ```
51#[proc_macro_derive(LlmDeserialize, attributes(llm))]
52pub fn derive_llm_deserialize(input: TokenStream) -> TokenStream {
53    let input = parse_macro_input!(input as DeriveInput);
54
55    let name = &input.ident;
56    let generics = &input.generics;
57    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
58
59    match &input.data {
60        Data::Struct(data_struct) => {
61            let deserialize_impl = generate_struct_deserialize(name, data_struct);
62
63            let expanded = quote! {
64                impl #impl_generics ::tryparse::deserializer::LlmDeserialize for #name #ty_generics #where_clause {
65                    #deserialize_impl
66                }
67            };
68
69            TokenStream::from(expanded)
70        }
71        Data::Enum(data_enum) => {
72            // Check if this is a union enum (has #[llm(union)] attribute)
73            let is_union = has_union_attribute(&input.attrs);
74
75            let deserialize_impl = if is_union {
76                generate_union_deserialize(name, data_enum, &input.attrs)
77            } else {
78                generate_enum_deserialize(name, data_enum, &input.attrs)
79            };
80
81            let expanded = quote! {
82                impl #impl_generics ::tryparse::deserializer::LlmDeserialize for #name #ty_generics #where_clause {
83                    #deserialize_impl
84                }
85            };
86
87            TokenStream::from(expanded)
88        }
89        Data::Union(_) => {
90            syn::Error::new_spanned(input, "LlmDeserialize cannot be derived for unions")
91                .to_compile_error()
92                .into()
93        }
94    }
95}
96
97fn generate_struct_deserialize(
98    name: &syn::Ident,
99    data: &syn::DataStruct,
100) -> proc_macro2::TokenStream {
101    match &data.fields {
102        Fields::Named(fields) => {
103            let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
104            let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
105            let field_name_strs: Vec<_> = fields
106                .named
107                .iter()
108                .map(|f| f.ident.as_ref().unwrap().to_string())
109                .collect();
110
111            // Check if each field is Option<T>
112            let is_optional: Vec<_> = field_types.iter().map(|ty| is_option_type(ty)).collect();
113
114            // Extract inner type for Option<T> fields
115            let inner_types: Vec<_> = field_types
116                .iter()
117                .zip(&is_optional)
118                .map(|(ty, opt)| {
119                    if *opt {
120                        extract_option_inner(ty)
121                    } else {
122                        (*ty).clone()
123                    }
124                })
125                .collect();
126
127            let name_str = name.to_string();
128
129            // Generate field descriptor setup (collect to Vec for reuse)
130            let field_descriptors: Vec<_> = field_name_strs
131                .iter()
132                .zip(&field_types)
133                .zip(&is_optional)
134                .map(|((name, ty), opt)| {
135                    let type_name = quote!(stringify!(#ty)).to_string();
136                    quote! {
137                        .field(::tryparse::deserializer::FieldDescriptor::new(
138                            #name,
139                            #type_name,
140                            #opt
141                        ))
142                    }
143                })
144                .collect();
145
146            // Generate field extraction for try_deserialize (returns Option)
147            let field_extractions_strict: Vec<_> = field_names
148                .iter()
149                .zip(&inner_types)
150                .zip(&is_optional)
151                .map(|((field_name, inner_ty), opt)| {
152                    let field_name_str = field_name.as_ref().unwrap().to_string();
153                    if *opt {
154                        // Optional field
155                        quote! {
156                            let #field_name = fields.get(#field_name_str)
157                                .and_then(|v| v.downcast_ref::<#inner_ty>())
158                                .cloned();
159                        }
160                    } else {
161                        // Required field - return None if missing
162                        quote! {
163                            let #field_name = fields.get(#field_name_str)
164                                .and_then(|v| v.downcast_ref::<#inner_ty>())
165                                .cloned()?;
166                        }
167                    }
168                })
169                .collect();
170
171            // Generate field extraction for deserialize (returns Result)
172            let field_extractions_lenient: Vec<_> = field_names.iter().zip(&inner_types).zip(&is_optional).map(|((field_name, inner_ty), opt)| {
173                let field_name_str = field_name.as_ref().unwrap().to_string();
174                if *opt {
175                    // Optional field
176                    quote! {
177                        let #field_name = fields.get(#field_name_str)
178                            .and_then(|v| v.downcast_ref::<#inner_ty>())
179                            .cloned();
180                    }
181                } else {
182                    // Required field
183                    quote! {
184                        let #field_name = fields.get(#field_name_str)
185                            .and_then(|v| v.downcast_ref::<#inner_ty>())
186                            .cloned()
187                            .ok_or_else(|| ::tryparse::error::ParseError::DeserializeFailed(
188                                ::tryparse::error::DeserializeError::missing_field(#field_name_str)
189                            ))?;
190                    }
191                }
192            }).collect();
193
194            quote! {
195                fn try_deserialize(
196                    value: &::tryparse::value::FlexValue,
197                    ctx: &mut ::tryparse::deserializer::CoercionContext,
198                ) -> Option<Self> {
199                    use std::any::Any;
200
201                    let mut deserializer = ::tryparse::deserializer::StructDeserializer::new()
202                        #(#field_descriptors)*;
203
204                    let fields = deserializer.try_deserialize(
205                        value,
206                        ctx,
207                        #name_str,
208                        |field_name, field_value, field_ctx| {
209                            // Dispatch to the appropriate field type's LlmDeserialize impl (strict mode only)
210                            match field_name {
211                                #(
212                                    #field_name_strs => {
213                                        // Try strict deserialization
214                                        <#inner_types as ::tryparse::deserializer::LlmDeserialize>::try_deserialize(field_value, field_ctx)
215                                            .map(|v| Box::new(v) as Box<dyn Any>)
216                                    }
217                                )*
218                                _ => None
219                            }
220                        }
221                    ).ok()?;
222
223                    // Extract fields from Box<dyn Any> (strict mode - return None on failure)
224                    #(#field_extractions_strict)*
225
226                    Some(Self {
227                        #(#field_names),*
228                    })
229                }
230
231                fn deserialize(
232                    value: &::tryparse::value::FlexValue,
233                    ctx: &mut ::tryparse::deserializer::CoercionContext,
234                ) -> ::tryparse::error::Result<Self> {
235                    use std::any::Any;
236
237                    let mut deserializer = ::tryparse::deserializer::StructDeserializer::new()
238                        #(#field_descriptors)*;
239
240                    let fields = deserializer.deserialize(
241                        value,
242                        ctx,
243                        #name_str,
244                        |field_name, field_value, field_ctx, strict| {
245                            // Dispatch to the appropriate field type's LlmDeserialize impl
246                            match field_name {
247                                #(
248                                    #field_name_strs => {
249                                        if strict {
250                                            // Try strict deserialization
251                                            if let Some(v) = <#inner_types as ::tryparse::deserializer::LlmDeserialize>::try_deserialize(field_value, field_ctx) {
252                                                Ok(Box::new(v) as Box<dyn Any>)
253                                            } else {
254                                                Err(::tryparse::error::ParseError::DeserializeFailed(
255                                                    ::tryparse::error::DeserializeError::type_mismatch(
256                                                        stringify!(#inner_types),
257                                                        "value"
258                                                    )
259                                                ))
260                                            }
261                                        } else {
262                                            // Lenient deserialization
263                                            let v = <#inner_types as ::tryparse::deserializer::LlmDeserialize>::deserialize(field_value, field_ctx)?;
264                                            Ok(Box::new(v) as Box<dyn Any>)
265                                        }
266                                    }
267                                )*
268                                _ => Err(::tryparse::error::ParseError::DeserializeFailed(
269                                    ::tryparse::error::DeserializeError::Custom(
270                                        format!("Unknown field: {}", field_name)
271                                    )
272                                ))
273                            }
274                        }
275                    )?;
276
277                    // Extract fields from Box<dyn Any> (lenient mode - return error on failure)
278                    #(#field_extractions_lenient)*
279
280                    Ok(Self {
281                        #(#field_names),*
282                    })
283                }
284            }
285        }
286        Fields::Unnamed(_) => syn::Error::new_spanned(
287            data.fields.clone(),
288            "LlmDeserialize does not support tuple structs yet",
289        )
290        .to_compile_error(),
291        Fields::Unit => syn::Error::new_spanned(
292            data.fields.clone(),
293            "LlmDeserialize does not support unit structs",
294        )
295        .to_compile_error(),
296    }
297}
298
299/// Check if a type is Option<T>
300fn is_option_type(ty: &Type) -> bool {
301    if let Type::Path(type_path) = ty {
302        if let Some(segment) = type_path.path.segments.last() {
303            return segment.ident == "Option";
304        }
305    }
306    false
307}
308
309/// Extract the inner type T from Option<T>
310fn extract_option_inner(ty: &Type) -> Type {
311    if let Type::Path(type_path) = ty {
312        if let Some(segment) = type_path.path.segments.last() {
313            if segment.ident == "Option" {
314                if let PathArguments::AngleBracketed(args) = &segment.arguments {
315                    if let Some(GenericArgument::Type(inner)) = args.args.first() {
316                        return inner.clone();
317                    }
318                }
319            }
320        }
321    }
322    // Fallback: return the original type
323    ty.clone()
324}
325
326fn generate_enum_deserialize(
327    name: &syn::Ident,
328    data: &syn::DataEnum,
329    _attrs: &[syn::Attribute],
330) -> proc_macro2::TokenStream {
331    let name_str = name.to_string();
332
333    // Build EnumMatcher setup with all variants
334    let matcher_setup = data.variants.iter().map(|v| {
335        let variant_name = v.ident.to_string();
336        quote! {
337            .variant(::tryparse::deserializer::enum_coercer::EnumVariant::new(#variant_name))
338        }
339    });
340
341    // Build match arms for each variant
342    let match_arms = data.variants.iter().map(|v| {
343        let variant_ident = &v.ident;
344        let variant_name = v.ident.to_string();
345
346        match &v.fields {
347            Fields::Unit => {
348                // Simple unit variant (e.g., Status::Active)
349                quote! {
350                    #variant_name => Ok(Self::#variant_ident),
351                }
352            }
353            Fields::Named(_) | Fields::Unnamed(_) => {
354                // Complex variants with fields - not yet supported in derive macro
355                // Users can implement LlmDeserialize manually for these cases
356                quote! {
357                    #variant_name => Err(::tryparse::error::ParseError::DeserializeFailed(
358                        ::tryparse::error::DeserializeError::Custom(
359                            format!("Enum variant '{}' has fields - derive macro only supports unit variants", #variant_name)
360                        )
361                    )),
362                }
363            }
364        }
365    });
366
367    quote! {
368        fn deserialize(
369            value: &::tryparse::value::FlexValue,
370            _ctx: &mut ::tryparse::deserializer::CoercionContext,
371        ) -> ::tryparse::error::Result<Self> {
372            // Build matcher with all enum variants
373            let matcher = ::tryparse::deserializer::enum_coercer::EnumMatcher::new()
374                #(#matcher_setup)*;
375
376            // Use BAML's fuzzy matching to find the best variant
377            let matched_variant = ::tryparse::deserializer::enum_coercer::match_enum_variant(
378                value,
379                &matcher
380            )?;
381
382            // Construct the matched variant
383            match matched_variant.as_str() {
384                #(#match_arms)*
385                _ => Err(::tryparse::error::ParseError::DeserializeFailed(
386                    ::tryparse::error::DeserializeError::UnknownVariant {
387                        enum_name: #name_str.to_string(),
388                        variant: matched_variant,
389                    }
390                )),
391            }
392        }
393    }
394}
395
396/// Check if enum has #[llm(union)] attribute.
397fn has_union_attribute(attrs: &[syn::Attribute]) -> bool {
398    attrs.iter().any(|attr| {
399        if attr.path().is_ident("llm") {
400            // Parse as #[llm(union)]
401            if let Ok(meta_list) = attr.meta.require_list() {
402                // Check if any nested item is "union"
403                return meta_list.tokens.to_string().trim() == "union";
404            }
405        }
406        false
407    })
408}
409
410/// Generate union deserialization code for enums with #[llm(union)].
411fn generate_union_deserialize(
412    name: &syn::Ident,
413    data: &syn::DataEnum,
414    _attrs: &[syn::Attribute],
415) -> proc_macro2::TokenStream {
416    if data.variants.len() != 2 {
417        return syn::Error::new_spanned(name, "Union enums must have exactly 2 variants")
418            .to_compile_error();
419    }
420
421    let variants: Vec<_> = data.variants.iter().collect();
422    let variant1 = &variants[0];
423    let variant2 = &variants[1];
424
425    // Extract variant types
426    let (variant1_ident, variant1_type) = match &variant1.fields {
427        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
428            (&variant1.ident, &fields.unnamed[0].ty)
429        }
430        _ => {
431            return syn::Error::new_spanned(
432                variant1,
433                "Union variants must have exactly one unnamed field",
434            )
435            .to_compile_error();
436        }
437    };
438
439    let (variant2_ident, variant2_type) = match &variant2.fields {
440        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
441            (&variant2.ident, &fields.unnamed[0].ty)
442        }
443        _ => {
444            return syn::Error::new_spanned(
445                variant2,
446                "Union variants must have exactly one unnamed field",
447            )
448            .to_compile_error();
449        }
450    };
451
452    quote! {
453        fn deserialize(
454            value: &::tryparse::value::FlexValue,
455            ctx: &mut ::tryparse::deserializer::CoercionContext,
456        ) -> ::tryparse::error::Result<Self> {
457            use ::tryparse::deserializer::LlmDeserialize;
458
459            // BAML ALGORITHM: Try strict matching first (try_cast)
460            if let Some(v1) = <#variant1_type as LlmDeserialize>::try_deserialize(value, ctx) {
461                // Add UnionMatch transformation for strict match
462                ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
463                    index: 0,
464                    candidates: vec![
465                        stringify!(#variant1_type).to_string(),
466                        stringify!(#variant2_type).to_string(),
467                    ],
468                });
469                return Ok(Self::#variant1_ident(v1));
470            }
471
472            if let Some(v2) = <#variant2_type as LlmDeserialize>::try_deserialize(value, ctx) {
473                // Add UnionMatch transformation for strict match
474                ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
475                    index: 1,
476                    candidates: vec![
477                        stringify!(#variant1_type).to_string(),
478                        stringify!(#variant2_type).to_string(),
479                    ],
480                });
481                return Ok(Self::#variant2_ident(v2));
482            }
483
484            // BAML ALGORITHM: Try lenient matching with scoring (coerce)
485            struct MatchResult {
486                variant: u8,  // 1 or 2
487                score: u32,
488            }
489
490            let mut matches = Vec::new();
491
492            // Try variant 1 with separate FlexValue to track transformations
493            let value1 = value.clone();
494            if let Ok(_) = <#variant1_type as LlmDeserialize>::deserialize(&value1, ctx) {
495                let score: u32 = value1.transformations().iter().map(|t| t.penalty()).sum();
496                matches.push(MatchResult { variant: 1, score });
497            }
498
499            // Try variant 2 with separate FlexValue to track transformations
500            let value2 = value.clone();
501            if let Ok(_) = <#variant2_type as LlmDeserialize>::deserialize(&value2, ctx) {
502                let score: u32 = value2.transformations().iter().map(|t| t.penalty()).sum();
503                matches.push(MatchResult { variant: 2, score });
504            }
505
506            if matches.is_empty() {
507                return Err(::tryparse::error::ParseError::DeserializeFailed(
508                    ::tryparse::error::DeserializeError::Custom(
509                        "No union variant matched".to_string()
510                    )
511                ));
512            }
513
514            // Sort by score (lower is better)
515            matches.sort_by_key(|m| m.score);
516
517            // Add UnionMatch transformation to track which variant was selected
518            let variant_index = (matches[0].variant - 1) as usize;
519            ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
520                index: variant_index,
521                candidates: vec![
522                    stringify!(#variant1_type).to_string(),
523                    stringify!(#variant2_type).to_string(),
524                ],
525            });
526
527            // Deserialize the best match
528            match matches[0].variant {
529                1 => {
530                    let v1 = <#variant1_type as LlmDeserialize>::deserialize(value, ctx)?;
531                    Ok(Self::#variant1_ident(v1))
532                }
533                2 => {
534                    let v2 = <#variant2_type as LlmDeserialize>::deserialize(value, ctx)?;
535                    Ok(Self::#variant2_ident(v2))
536                }
537                _ => unreachable!(),
538            }
539        }
540    }
541}