tryparse_derive/
lib.rs

1//! Derive macros for tryparse
2//!
3//! This crate provides derive macros for automatically generating
4//! schema information and deserialization logic from Rust types.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type};
9
10/// Derives the `SchemaInfo` trait for a struct or enum.
11///
12/// # Example
13///
14/// ```ignore
15/// use tryparse::SchemaInfo;
16///
17/// #[derive(SchemaInfo)]
18/// struct User {
19///     name: String,
20///     age: u32,
21/// }
22///
23/// let schema = User::schema();
24/// // Schema::Object { name: "User", fields: [...] }
25/// ```
26#[proc_macro_derive(SchemaInfo)]
27pub fn derive_schema_info(input: TokenStream) -> TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29
30    let name = &input.ident;
31    let generics = &input.generics;
32    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
33
34    let schema_impl = match &input.data {
35        Data::Struct(data_struct) => generate_struct_schema(name, data_struct),
36        Data::Enum(data_enum) => generate_enum_schema(name, data_enum),
37        Data::Union(_) => {
38            return syn::Error::new_spanned(input, "SchemaInfo cannot be derived for unions")
39                .to_compile_error()
40                .into();
41        }
42    };
43
44    let expanded = quote! {
45        impl #impl_generics ::tryparse::schema::SchemaInfo for #name #ty_generics #where_clause {
46            fn schema() -> ::tryparse::schema::Schema {
47                #schema_impl
48            }
49        }
50    };
51
52    TokenStream::from(expanded)
53}
54
55fn generate_struct_schema(name: &syn::Ident, data: &syn::DataStruct) -> proc_macro2::TokenStream {
56    let name_str = name.to_string();
57
58    match &data.fields {
59        Fields::Named(fields) => {
60            let field_defs = fields.named.iter().map(|f| {
61                let field_name = f.ident.as_ref().unwrap().to_string();
62                let field_type = &f.ty;
63
64                quote! {
65                    ::tryparse::schema::Field::new(
66                        #field_name,
67                        <#field_type as ::tryparse::schema::SchemaInfo>::schema()
68                    )
69                }
70            });
71
72            quote! {
73                ::tryparse::schema::Schema::Object {
74                    name: #name_str.to_string(),
75                    fields: vec![#(#field_defs),*],
76                }
77            }
78        }
79        Fields::Unnamed(fields) => {
80            // Treat tuple structs as tuples
81            let field_types = fields.unnamed.iter().map(|f| {
82                let ty = &f.ty;
83                quote! {
84                    <#ty as ::tryparse::schema::SchemaInfo>::schema()
85                }
86            });
87
88            quote! {
89                ::tryparse::schema::Schema::Tuple(vec![#(#field_types),*])
90            }
91        }
92        Fields::Unit => {
93            // Unit struct is like null
94            quote! {
95                ::tryparse::schema::Schema::Null
96            }
97        }
98    }
99}
100
101fn generate_enum_schema(name: &syn::Ident, data: &syn::DataEnum) -> proc_macro2::TokenStream {
102    let name_str = name.to_string();
103
104    let variant_defs = data.variants.iter().map(|v| {
105        let variant_name = v.ident.to_string();
106
107        let variant_schema = match &v.fields {
108            Fields::Named(fields) => {
109                // Variant with named fields - treat as object
110                let field_defs = fields.named.iter().map(|f| {
111                    let field_name = f.ident.as_ref().unwrap().to_string();
112                    let field_type = &f.ty;
113
114                    quote! {
115                        ::tryparse::schema::Field::new(
116                            #field_name,
117                            <#field_type as ::tryparse::schema::SchemaInfo>::schema()
118                        )
119                    }
120                });
121
122                quote! {
123                    ::tryparse::schema::Schema::Object {
124                        name: #variant_name.to_string(),
125                        fields: vec![#(#field_defs),*],
126                    }
127                }
128            }
129            Fields::Unnamed(fields) => {
130                // Variant with unnamed fields - treat as tuple
131                let field_types = fields.unnamed.iter().map(|f| {
132                    let ty = &f.ty;
133                    quote! {
134                        <#ty as ::tryparse::schema::SchemaInfo>::schema()
135                    }
136                });
137
138                quote! {
139                    ::tryparse::schema::Schema::Tuple(vec![#(#field_types),*])
140                }
141            }
142            Fields::Unit => {
143                // Unit variant - like null
144                quote! {
145                    ::tryparse::schema::Schema::Null
146                }
147            }
148        };
149
150        quote! {
151            ::tryparse::schema::Variant::new(#variant_name, #variant_schema)
152        }
153    });
154
155    quote! {
156        ::tryparse::schema::Schema::Union {
157            name: #name_str.to_string(),
158            variants: vec![#(#variant_defs),*],
159        }
160    }
161}
162
163/// Derives the `LlmDeserialize` trait for a struct.
164///
165/// This macro generates a custom deserialization implementation using BAML's
166/// algorithms for fuzzy field matching and type coercion.
167///
168/// # Example
169///
170/// ```ignore
171/// use tryparse::deserializer::LlmDeserialize;
172///
173/// #[derive(LlmDeserialize)]
174/// struct User {
175///     name: String,
176///     age: u32,
177///     email: Option<String>, // Optional field
178/// }
179///
180/// // The macro generates an implementation that:
181/// // - Handles fuzzy field matching (userName → user_name)
182/// // - Supports optional fields with defaults
183/// // - Detects circular references
184/// // - Tracks transformations
185/// ```
186#[proc_macro_derive(LlmDeserialize, attributes(llm))]
187pub fn derive_llm_deserialize(input: TokenStream) -> TokenStream {
188    let input = parse_macro_input!(input as DeriveInput);
189
190    let name = &input.ident;
191    let generics = &input.generics;
192    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
193
194    match &input.data {
195        Data::Struct(data_struct) => {
196            let deserialize_impl = generate_struct_deserialize(name, data_struct);
197
198            let expanded = quote! {
199                impl #impl_generics ::tryparse::deserializer::LlmDeserialize for #name #ty_generics #where_clause {
200                    #deserialize_impl
201                }
202            };
203
204            TokenStream::from(expanded)
205        }
206        Data::Enum(data_enum) => {
207            // Check if this is a union enum (has #[llm(union)] attribute)
208            let is_union = has_union_attribute(&input.attrs);
209
210            let deserialize_impl = if is_union {
211                generate_union_deserialize(name, data_enum, &input.attrs)
212            } else {
213                generate_enum_deserialize(name, data_enum, &input.attrs)
214            };
215
216            let expanded = quote! {
217                impl #impl_generics ::tryparse::deserializer::LlmDeserialize for #name #ty_generics #where_clause {
218                    #deserialize_impl
219                }
220            };
221
222            TokenStream::from(expanded)
223        }
224        Data::Union(_) => {
225            syn::Error::new_spanned(input, "LlmDeserialize cannot be derived for unions")
226                .to_compile_error()
227                .into()
228        }
229    }
230}
231
232fn generate_struct_deserialize(
233    name: &syn::Ident,
234    data: &syn::DataStruct,
235) -> proc_macro2::TokenStream {
236    match &data.fields {
237        Fields::Named(fields) => {
238            let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
239            let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
240            let field_name_strs: Vec<_> = fields
241                .named
242                .iter()
243                .map(|f| f.ident.as_ref().unwrap().to_string())
244                .collect();
245
246            // Check if each field is Option<T>
247            let is_optional: Vec<_> = field_types.iter().map(|ty| is_option_type(ty)).collect();
248
249            // Extract inner type for Option<T> fields
250            let inner_types: Vec<_> = field_types
251                .iter()
252                .zip(&is_optional)
253                .map(|(ty, opt)| {
254                    if *opt {
255                        extract_option_inner(ty)
256                    } else {
257                        (*ty).clone()
258                    }
259                })
260                .collect();
261
262            let name_str = name.to_string();
263
264            // Generate field descriptor setup (collect to Vec for reuse)
265            let field_descriptors: Vec<_> = field_name_strs
266                .iter()
267                .zip(&field_types)
268                .zip(&is_optional)
269                .map(|((name, ty), opt)| {
270                    let type_name = quote!(stringify!(#ty)).to_string();
271                    quote! {
272                        .field(::tryparse::deserializer::FieldDescriptor::new(
273                            #name,
274                            #type_name,
275                            #opt
276                        ))
277                    }
278                })
279                .collect();
280
281            // Generate field extraction for try_deserialize (returns Option)
282            let field_extractions_strict: Vec<_> = field_names
283                .iter()
284                .zip(&inner_types)
285                .zip(&is_optional)
286                .map(|((field_name, inner_ty), opt)| {
287                    let field_name_str = field_name.as_ref().unwrap().to_string();
288                    if *opt {
289                        // Optional field
290                        quote! {
291                            let #field_name = fields.get(#field_name_str)
292                                .and_then(|v| v.downcast_ref::<#inner_ty>())
293                                .cloned();
294                        }
295                    } else {
296                        // Required field - return None if missing
297                        quote! {
298                            let #field_name = fields.get(#field_name_str)
299                                .and_then(|v| v.downcast_ref::<#inner_ty>())
300                                .cloned()?;
301                        }
302                    }
303                })
304                .collect();
305
306            // Generate field extraction for deserialize (returns Result)
307            let field_extractions_lenient: Vec<_> = field_names.iter().zip(&inner_types).zip(&is_optional).map(|((field_name, inner_ty), opt)| {
308                let field_name_str = field_name.as_ref().unwrap().to_string();
309                if *opt {
310                    // Optional field
311                    quote! {
312                        let #field_name = fields.get(#field_name_str)
313                            .and_then(|v| v.downcast_ref::<#inner_ty>())
314                            .cloned();
315                    }
316                } else {
317                    // Required field
318                    quote! {
319                        let #field_name = fields.get(#field_name_str)
320                            .and_then(|v| v.downcast_ref::<#inner_ty>())
321                            .cloned()
322                            .ok_or_else(|| ::tryparse::error::ParseError::DeserializeFailed(
323                                ::tryparse::error::DeserializeError::missing_field(#field_name_str)
324                            ))?;
325                    }
326                }
327            }).collect();
328
329            quote! {
330                fn try_deserialize(
331                    value: &::tryparse::value::FlexValue,
332                    ctx: &mut ::tryparse::deserializer::CoercionContext,
333                ) -> Option<Self> {
334                    use std::any::Any;
335
336                    let mut deserializer = ::tryparse::deserializer::StructDeserializer::new()
337                        #(#field_descriptors)*;
338
339                    let fields = deserializer.try_deserialize(
340                        value,
341                        ctx,
342                        #name_str,
343                        |field_name, field_value, field_ctx| {
344                            // Dispatch to the appropriate field type's LlmDeserialize impl (strict mode only)
345                            match field_name {
346                                #(
347                                    #field_name_strs => {
348                                        // Try strict deserialization
349                                        <#inner_types as ::tryparse::deserializer::LlmDeserialize>::try_deserialize(field_value, field_ctx)
350                                            .map(|v| Box::new(v) as Box<dyn Any>)
351                                    }
352                                )*
353                                _ => None
354                            }
355                        }
356                    ).ok()?;
357
358                    // Extract fields from Box<dyn Any> (strict mode - return None on failure)
359                    #(#field_extractions_strict)*
360
361                    Some(Self {
362                        #(#field_names),*
363                    })
364                }
365
366                fn deserialize(
367                    value: &::tryparse::value::FlexValue,
368                    ctx: &mut ::tryparse::deserializer::CoercionContext,
369                ) -> ::tryparse::error::Result<Self> {
370                    use std::any::Any;
371
372                    let mut deserializer = ::tryparse::deserializer::StructDeserializer::new()
373                        #(#field_descriptors)*;
374
375                    let fields = deserializer.deserialize(
376                        value,
377                        ctx,
378                        #name_str,
379                        |field_name, field_value, field_ctx, strict| {
380                            // Dispatch to the appropriate field type's LlmDeserialize impl
381                            match field_name {
382                                #(
383                                    #field_name_strs => {
384                                        if strict {
385                                            // Try strict deserialization
386                                            if let Some(v) = <#inner_types as ::tryparse::deserializer::LlmDeserialize>::try_deserialize(field_value, field_ctx) {
387                                                Ok(Box::new(v) as Box<dyn Any>)
388                                            } else {
389                                                Err(::tryparse::error::ParseError::DeserializeFailed(
390                                                    ::tryparse::error::DeserializeError::type_mismatch(
391                                                        stringify!(#inner_types),
392                                                        "value"
393                                                    )
394                                                ))
395                                            }
396                                        } else {
397                                            // Lenient deserialization
398                                            let v = <#inner_types as ::tryparse::deserializer::LlmDeserialize>::deserialize(field_value, field_ctx)?;
399                                            Ok(Box::new(v) as Box<dyn Any>)
400                                        }
401                                    }
402                                )*
403                                _ => Err(::tryparse::error::ParseError::DeserializeFailed(
404                                    ::tryparse::error::DeserializeError::Custom(
405                                        format!("Unknown field: {}", field_name)
406                                    )
407                                ))
408                            }
409                        }
410                    )?;
411
412                    // Extract fields from Box<dyn Any> (lenient mode - return error on failure)
413                    #(#field_extractions_lenient)*
414
415                    Ok(Self {
416                        #(#field_names),*
417                    })
418                }
419            }
420        }
421        Fields::Unnamed(_) => syn::Error::new_spanned(
422            data.fields.clone(),
423            "LlmDeserialize does not support tuple structs yet",
424        )
425        .to_compile_error(),
426        Fields::Unit => syn::Error::new_spanned(
427            data.fields.clone(),
428            "LlmDeserialize does not support unit structs",
429        )
430        .to_compile_error(),
431    }
432}
433
434/// Check if a type is Option<T>
435fn is_option_type(ty: &Type) -> bool {
436    if let Type::Path(type_path) = ty {
437        if let Some(segment) = type_path.path.segments.last() {
438            return segment.ident == "Option";
439        }
440    }
441    false
442}
443
444/// Extract the inner type T from Option<T>
445fn extract_option_inner(ty: &Type) -> Type {
446    if let Type::Path(type_path) = ty {
447        if let Some(segment) = type_path.path.segments.last() {
448            if segment.ident == "Option" {
449                if let PathArguments::AngleBracketed(args) = &segment.arguments {
450                    if let Some(GenericArgument::Type(inner)) = args.args.first() {
451                        return inner.clone();
452                    }
453                }
454            }
455        }
456    }
457    // Fallback: return the original type
458    ty.clone()
459}
460
461fn generate_enum_deserialize(
462    name: &syn::Ident,
463    data: &syn::DataEnum,
464    _attrs: &[syn::Attribute],
465) -> proc_macro2::TokenStream {
466    let name_str = name.to_string();
467
468    // Build EnumMatcher setup with all variants
469    let matcher_setup = data.variants.iter().map(|v| {
470        let variant_name = v.ident.to_string();
471        quote! {
472            .variant(::tryparse::deserializer::enum_coercer::EnumVariant::new(#variant_name))
473        }
474    });
475
476    // Build match arms for each variant
477    let match_arms = data.variants.iter().map(|v| {
478        let variant_ident = &v.ident;
479        let variant_name = v.ident.to_string();
480
481        match &v.fields {
482            Fields::Unit => {
483                // Simple unit variant (e.g., Status::Active)
484                quote! {
485                    #variant_name => Ok(Self::#variant_ident),
486                }
487            }
488            Fields::Named(_) | Fields::Unnamed(_) => {
489                // Complex variants with fields - not yet supported in derive macro
490                // Users can implement LlmDeserialize manually for these cases
491                quote! {
492                    #variant_name => Err(::tryparse::error::ParseError::DeserializeFailed(
493                        ::tryparse::error::DeserializeError::Custom(
494                            format!("Enum variant '{}' has fields - derive macro only supports unit variants", #variant_name)
495                        )
496                    )),
497                }
498            }
499        }
500    });
501
502    quote! {
503        fn deserialize(
504            value: &::tryparse::value::FlexValue,
505            _ctx: &mut ::tryparse::deserializer::CoercionContext,
506        ) -> ::tryparse::error::Result<Self> {
507            // Build matcher with all enum variants
508            let matcher = ::tryparse::deserializer::enum_coercer::EnumMatcher::new()
509                #(#matcher_setup)*;
510
511            // Use BAML's fuzzy matching to find the best variant
512            let matched_variant = ::tryparse::deserializer::enum_coercer::match_enum_variant(
513                value,
514                &matcher
515            )?;
516
517            // Construct the matched variant
518            match matched_variant.as_str() {
519                #(#match_arms)*
520                _ => Err(::tryparse::error::ParseError::DeserializeFailed(
521                    ::tryparse::error::DeserializeError::UnknownVariant {
522                        enum_name: #name_str.to_string(),
523                        variant: matched_variant,
524                    }
525                )),
526            }
527        }
528    }
529}
530
531/// Check if enum has #[llm(union)] attribute.
532fn has_union_attribute(attrs: &[syn::Attribute]) -> bool {
533    attrs.iter().any(|attr| {
534        if attr.path().is_ident("llm") {
535            // Parse as #[llm(union)]
536            if let Ok(meta_list) = attr.meta.require_list() {
537                // Check if any nested item is "union"
538                return meta_list.tokens.to_string().trim() == "union";
539            }
540        }
541        false
542    })
543}
544
545/// Generate union deserialization code for enums with #[llm(union)].
546fn generate_union_deserialize(
547    name: &syn::Ident,
548    data: &syn::DataEnum,
549    _attrs: &[syn::Attribute],
550) -> proc_macro2::TokenStream {
551    if data.variants.len() != 2 {
552        return syn::Error::new_spanned(name, "Union enums must have exactly 2 variants")
553            .to_compile_error();
554    }
555
556    let variants: Vec<_> = data.variants.iter().collect();
557    let variant1 = &variants[0];
558    let variant2 = &variants[1];
559
560    // Extract variant types
561    let (variant1_ident, variant1_type) = match &variant1.fields {
562        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
563            (&variant1.ident, &fields.unnamed[0].ty)
564        }
565        _ => {
566            return syn::Error::new_spanned(
567                variant1,
568                "Union variants must have exactly one unnamed field",
569            )
570            .to_compile_error();
571        }
572    };
573
574    let (variant2_ident, variant2_type) = match &variant2.fields {
575        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
576            (&variant2.ident, &fields.unnamed[0].ty)
577        }
578        _ => {
579            return syn::Error::new_spanned(
580                variant2,
581                "Union variants must have exactly one unnamed field",
582            )
583            .to_compile_error();
584        }
585    };
586
587    quote! {
588        fn deserialize(
589            value: &::tryparse::value::FlexValue,
590            ctx: &mut ::tryparse::deserializer::CoercionContext,
591        ) -> ::tryparse::error::Result<Self> {
592            use ::tryparse::deserializer::LlmDeserialize;
593
594            // BAML ALGORITHM: Try strict matching first (try_cast)
595            if let Some(v1) = <#variant1_type as LlmDeserialize>::try_deserialize(value, ctx) {
596                // Add UnionMatch transformation for strict match
597                ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
598                    index: 0,
599                    candidates: vec![
600                        stringify!(#variant1_type).to_string(),
601                        stringify!(#variant2_type).to_string(),
602                    ],
603                });
604                return Ok(Self::#variant1_ident(v1));
605            }
606
607            if let Some(v2) = <#variant2_type as LlmDeserialize>::try_deserialize(value, ctx) {
608                // Add UnionMatch transformation for strict match
609                ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
610                    index: 1,
611                    candidates: vec![
612                        stringify!(#variant1_type).to_string(),
613                        stringify!(#variant2_type).to_string(),
614                    ],
615                });
616                return Ok(Self::#variant2_ident(v2));
617            }
618
619            // BAML ALGORITHM: Try lenient matching with scoring (coerce)
620            struct MatchResult {
621                variant: u8,  // 1 or 2
622                score: u32,
623            }
624
625            let mut matches = Vec::new();
626
627            // Try variant 1 with separate FlexValue to track transformations
628            let value1 = value.clone();
629            if let Ok(_) = <#variant1_type as LlmDeserialize>::deserialize(&value1, ctx) {
630                let score: u32 = value1.transformations().iter().map(|t| t.penalty()).sum();
631                matches.push(MatchResult { variant: 1, score });
632            }
633
634            // Try variant 2 with separate FlexValue to track transformations
635            let value2 = value.clone();
636            if let Ok(_) = <#variant2_type as LlmDeserialize>::deserialize(&value2, ctx) {
637                let score: u32 = value2.transformations().iter().map(|t| t.penalty()).sum();
638                matches.push(MatchResult { variant: 2, score });
639            }
640
641            if matches.is_empty() {
642                return Err(::tryparse::error::ParseError::DeserializeFailed(
643                    ::tryparse::error::DeserializeError::Custom(
644                        "No union variant matched".to_string()
645                    )
646                ));
647            }
648
649            // Sort by score (lower is better)
650            matches.sort_by_key(|m| m.score);
651
652            // Add UnionMatch transformation to track which variant was selected
653            let variant_index = (matches[0].variant - 1) as usize;
654            ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
655                index: variant_index,
656                candidates: vec![
657                    stringify!(#variant1_type).to_string(),
658                    stringify!(#variant2_type).to_string(),
659                ],
660            });
661
662            // Deserialize the best match
663            match matches[0].variant {
664                1 => {
665                    let v1 = <#variant1_type as LlmDeserialize>::deserialize(value, ctx)?;
666                    Ok(Self::#variant1_ident(v1))
667                }
668                2 => {
669                    let v2 = <#variant2_type as LlmDeserialize>::deserialize(value, ctx)?;
670                    Ok(Self::#variant2_ident(v2))
671                }
672                _ => unreachable!(),
673            }
674        }
675    }
676}