1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type};
9
10#[proc_macro_derive(SchemaInfo)]
27pub fn derive_schema_info(input: TokenStream) -> TokenStream {
28 let input = parse_macro_input!(input as DeriveInput);
29
30 let name = &input.ident;
31 let generics = &input.generics;
32 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
33
34 let schema_impl = match &input.data {
35 Data::Struct(data_struct) => generate_struct_schema(name, data_struct),
36 Data::Enum(data_enum) => generate_enum_schema(name, data_enum),
37 Data::Union(_) => {
38 return syn::Error::new_spanned(input, "SchemaInfo cannot be derived for unions")
39 .to_compile_error()
40 .into();
41 }
42 };
43
44 let expanded = quote! {
45 impl #impl_generics ::tryparse::schema::SchemaInfo for #name #ty_generics #where_clause {
46 fn schema() -> ::tryparse::schema::Schema {
47 #schema_impl
48 }
49 }
50 };
51
52 TokenStream::from(expanded)
53}
54
55fn generate_struct_schema(name: &syn::Ident, data: &syn::DataStruct) -> proc_macro2::TokenStream {
56 let name_str = name.to_string();
57
58 match &data.fields {
59 Fields::Named(fields) => {
60 let field_defs = fields.named.iter().map(|f| {
61 let field_name = f.ident.as_ref().unwrap().to_string();
62 let field_type = &f.ty;
63
64 quote! {
65 ::tryparse::schema::Field::new(
66 #field_name,
67 <#field_type as ::tryparse::schema::SchemaInfo>::schema()
68 )
69 }
70 });
71
72 quote! {
73 ::tryparse::schema::Schema::Object {
74 name: #name_str.to_string(),
75 fields: vec![#(#field_defs),*],
76 }
77 }
78 }
79 Fields::Unnamed(fields) => {
80 let field_types = fields.unnamed.iter().map(|f| {
82 let ty = &f.ty;
83 quote! {
84 <#ty as ::tryparse::schema::SchemaInfo>::schema()
85 }
86 });
87
88 quote! {
89 ::tryparse::schema::Schema::Tuple(vec![#(#field_types),*])
90 }
91 }
92 Fields::Unit => {
93 quote! {
95 ::tryparse::schema::Schema::Null
96 }
97 }
98 }
99}
100
101fn generate_enum_schema(name: &syn::Ident, data: &syn::DataEnum) -> proc_macro2::TokenStream {
102 let name_str = name.to_string();
103
104 let variant_defs = data.variants.iter().map(|v| {
105 let variant_name = v.ident.to_string();
106
107 let variant_schema = match &v.fields {
108 Fields::Named(fields) => {
109 let field_defs = fields.named.iter().map(|f| {
111 let field_name = f.ident.as_ref().unwrap().to_string();
112 let field_type = &f.ty;
113
114 quote! {
115 ::tryparse::schema::Field::new(
116 #field_name,
117 <#field_type as ::tryparse::schema::SchemaInfo>::schema()
118 )
119 }
120 });
121
122 quote! {
123 ::tryparse::schema::Schema::Object {
124 name: #variant_name.to_string(),
125 fields: vec![#(#field_defs),*],
126 }
127 }
128 }
129 Fields::Unnamed(fields) => {
130 let field_types = fields.unnamed.iter().map(|f| {
132 let ty = &f.ty;
133 quote! {
134 <#ty as ::tryparse::schema::SchemaInfo>::schema()
135 }
136 });
137
138 quote! {
139 ::tryparse::schema::Schema::Tuple(vec![#(#field_types),*])
140 }
141 }
142 Fields::Unit => {
143 quote! {
145 ::tryparse::schema::Schema::Null
146 }
147 }
148 };
149
150 quote! {
151 ::tryparse::schema::Variant::new(#variant_name, #variant_schema)
152 }
153 });
154
155 quote! {
156 ::tryparse::schema::Schema::Union {
157 name: #name_str.to_string(),
158 variants: vec![#(#variant_defs),*],
159 }
160 }
161}
162
163#[proc_macro_derive(LlmDeserialize, attributes(llm))]
187pub fn derive_llm_deserialize(input: TokenStream) -> TokenStream {
188 let input = parse_macro_input!(input as DeriveInput);
189
190 let name = &input.ident;
191 let generics = &input.generics;
192 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
193
194 match &input.data {
195 Data::Struct(data_struct) => {
196 let deserialize_impl = generate_struct_deserialize(name, data_struct);
197
198 let expanded = quote! {
199 impl #impl_generics ::tryparse::deserializer::LlmDeserialize for #name #ty_generics #where_clause {
200 #deserialize_impl
201 }
202 };
203
204 TokenStream::from(expanded)
205 }
206 Data::Enum(data_enum) => {
207 let is_union = has_union_attribute(&input.attrs);
209
210 let deserialize_impl = if is_union {
211 generate_union_deserialize(name, data_enum, &input.attrs)
212 } else {
213 generate_enum_deserialize(name, data_enum, &input.attrs)
214 };
215
216 let expanded = quote! {
217 impl #impl_generics ::tryparse::deserializer::LlmDeserialize for #name #ty_generics #where_clause {
218 #deserialize_impl
219 }
220 };
221
222 TokenStream::from(expanded)
223 }
224 Data::Union(_) => {
225 syn::Error::new_spanned(input, "LlmDeserialize cannot be derived for unions")
226 .to_compile_error()
227 .into()
228 }
229 }
230}
231
232fn generate_struct_deserialize(
233 name: &syn::Ident,
234 data: &syn::DataStruct,
235) -> proc_macro2::TokenStream {
236 match &data.fields {
237 Fields::Named(fields) => {
238 let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
239 let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
240 let field_name_strs: Vec<_> = fields
241 .named
242 .iter()
243 .map(|f| f.ident.as_ref().unwrap().to_string())
244 .collect();
245
246 let is_optional: Vec<_> = field_types.iter().map(|ty| is_option_type(ty)).collect();
248
249 let inner_types: Vec<_> = field_types
251 .iter()
252 .zip(&is_optional)
253 .map(|(ty, opt)| {
254 if *opt {
255 extract_option_inner(ty)
256 } else {
257 (*ty).clone()
258 }
259 })
260 .collect();
261
262 let name_str = name.to_string();
263
264 let field_descriptors: Vec<_> = field_name_strs
266 .iter()
267 .zip(&field_types)
268 .zip(&is_optional)
269 .map(|((name, ty), opt)| {
270 let type_name = quote!(stringify!(#ty)).to_string();
271 quote! {
272 .field(::tryparse::deserializer::FieldDescriptor::new(
273 #name,
274 #type_name,
275 #opt
276 ))
277 }
278 })
279 .collect();
280
281 let field_extractions_strict: Vec<_> = field_names
283 .iter()
284 .zip(&inner_types)
285 .zip(&is_optional)
286 .map(|((field_name, inner_ty), opt)| {
287 let field_name_str = field_name.as_ref().unwrap().to_string();
288 if *opt {
289 quote! {
291 let #field_name = fields.get(#field_name_str)
292 .and_then(|v| v.downcast_ref::<#inner_ty>())
293 .cloned();
294 }
295 } else {
296 quote! {
298 let #field_name = fields.get(#field_name_str)
299 .and_then(|v| v.downcast_ref::<#inner_ty>())
300 .cloned()?;
301 }
302 }
303 })
304 .collect();
305
306 let field_extractions_lenient: Vec<_> = field_names.iter().zip(&inner_types).zip(&is_optional).map(|((field_name, inner_ty), opt)| {
308 let field_name_str = field_name.as_ref().unwrap().to_string();
309 if *opt {
310 quote! {
312 let #field_name = fields.get(#field_name_str)
313 .and_then(|v| v.downcast_ref::<#inner_ty>())
314 .cloned();
315 }
316 } else {
317 quote! {
319 let #field_name = fields.get(#field_name_str)
320 .and_then(|v| v.downcast_ref::<#inner_ty>())
321 .cloned()
322 .ok_or_else(|| ::tryparse::error::ParseError::DeserializeFailed(
323 ::tryparse::error::DeserializeError::missing_field(#field_name_str)
324 ))?;
325 }
326 }
327 }).collect();
328
329 quote! {
330 fn try_deserialize(
331 value: &::tryparse::value::FlexValue,
332 ctx: &mut ::tryparse::deserializer::CoercionContext,
333 ) -> Option<Self> {
334 use std::any::Any;
335
336 let mut deserializer = ::tryparse::deserializer::StructDeserializer::new()
337 #(#field_descriptors)*;
338
339 let fields = deserializer.try_deserialize(
340 value,
341 ctx,
342 #name_str,
343 |field_name, field_value, field_ctx| {
344 match field_name {
346 #(
347 #field_name_strs => {
348 <#inner_types as ::tryparse::deserializer::LlmDeserialize>::try_deserialize(field_value, field_ctx)
350 .map(|v| Box::new(v) as Box<dyn Any>)
351 }
352 )*
353 _ => None
354 }
355 }
356 ).ok()?;
357
358 #(#field_extractions_strict)*
360
361 Some(Self {
362 #(#field_names),*
363 })
364 }
365
366 fn deserialize(
367 value: &::tryparse::value::FlexValue,
368 ctx: &mut ::tryparse::deserializer::CoercionContext,
369 ) -> ::tryparse::error::Result<Self> {
370 use std::any::Any;
371
372 let mut deserializer = ::tryparse::deserializer::StructDeserializer::new()
373 #(#field_descriptors)*;
374
375 let fields = deserializer.deserialize(
376 value,
377 ctx,
378 #name_str,
379 |field_name, field_value, field_ctx, strict| {
380 match field_name {
382 #(
383 #field_name_strs => {
384 if strict {
385 if let Some(v) = <#inner_types as ::tryparse::deserializer::LlmDeserialize>::try_deserialize(field_value, field_ctx) {
387 Ok(Box::new(v) as Box<dyn Any>)
388 } else {
389 Err(::tryparse::error::ParseError::DeserializeFailed(
390 ::tryparse::error::DeserializeError::type_mismatch(
391 stringify!(#inner_types),
392 "value"
393 )
394 ))
395 }
396 } else {
397 let v = <#inner_types as ::tryparse::deserializer::LlmDeserialize>::deserialize(field_value, field_ctx)?;
399 Ok(Box::new(v) as Box<dyn Any>)
400 }
401 }
402 )*
403 _ => Err(::tryparse::error::ParseError::DeserializeFailed(
404 ::tryparse::error::DeserializeError::Custom(
405 format!("Unknown field: {}", field_name)
406 )
407 ))
408 }
409 }
410 )?;
411
412 #(#field_extractions_lenient)*
414
415 Ok(Self {
416 #(#field_names),*
417 })
418 }
419 }
420 }
421 Fields::Unnamed(_) => syn::Error::new_spanned(
422 data.fields.clone(),
423 "LlmDeserialize does not support tuple structs yet",
424 )
425 .to_compile_error(),
426 Fields::Unit => syn::Error::new_spanned(
427 data.fields.clone(),
428 "LlmDeserialize does not support unit structs",
429 )
430 .to_compile_error(),
431 }
432}
433
434fn is_option_type(ty: &Type) -> bool {
436 if let Type::Path(type_path) = ty {
437 if let Some(segment) = type_path.path.segments.last() {
438 return segment.ident == "Option";
439 }
440 }
441 false
442}
443
444fn extract_option_inner(ty: &Type) -> Type {
446 if let Type::Path(type_path) = ty {
447 if let Some(segment) = type_path.path.segments.last() {
448 if segment.ident == "Option" {
449 if let PathArguments::AngleBracketed(args) = &segment.arguments {
450 if let Some(GenericArgument::Type(inner)) = args.args.first() {
451 return inner.clone();
452 }
453 }
454 }
455 }
456 }
457 ty.clone()
459}
460
461fn generate_enum_deserialize(
462 name: &syn::Ident,
463 data: &syn::DataEnum,
464 _attrs: &[syn::Attribute],
465) -> proc_macro2::TokenStream {
466 let name_str = name.to_string();
467
468 let matcher_setup = data.variants.iter().map(|v| {
470 let variant_name = v.ident.to_string();
471 quote! {
472 .variant(::tryparse::deserializer::enum_coercer::EnumVariant::new(#variant_name))
473 }
474 });
475
476 let match_arms = data.variants.iter().map(|v| {
478 let variant_ident = &v.ident;
479 let variant_name = v.ident.to_string();
480
481 match &v.fields {
482 Fields::Unit => {
483 quote! {
485 #variant_name => Ok(Self::#variant_ident),
486 }
487 }
488 Fields::Named(_) | Fields::Unnamed(_) => {
489 quote! {
492 #variant_name => Err(::tryparse::error::ParseError::DeserializeFailed(
493 ::tryparse::error::DeserializeError::Custom(
494 format!("Enum variant '{}' has fields - derive macro only supports unit variants", #variant_name)
495 )
496 )),
497 }
498 }
499 }
500 });
501
502 quote! {
503 fn deserialize(
504 value: &::tryparse::value::FlexValue,
505 _ctx: &mut ::tryparse::deserializer::CoercionContext,
506 ) -> ::tryparse::error::Result<Self> {
507 let matcher = ::tryparse::deserializer::enum_coercer::EnumMatcher::new()
509 #(#matcher_setup)*;
510
511 let matched_variant = ::tryparse::deserializer::enum_coercer::match_enum_variant(
513 value,
514 &matcher
515 )?;
516
517 match matched_variant.as_str() {
519 #(#match_arms)*
520 _ => Err(::tryparse::error::ParseError::DeserializeFailed(
521 ::tryparse::error::DeserializeError::UnknownVariant {
522 enum_name: #name_str.to_string(),
523 variant: matched_variant,
524 }
525 )),
526 }
527 }
528 }
529}
530
531fn has_union_attribute(attrs: &[syn::Attribute]) -> bool {
533 attrs.iter().any(|attr| {
534 if attr.path().is_ident("llm") {
535 if let Ok(meta_list) = attr.meta.require_list() {
537 return meta_list.tokens.to_string().trim() == "union";
539 }
540 }
541 false
542 })
543}
544
545fn generate_union_deserialize(
547 name: &syn::Ident,
548 data: &syn::DataEnum,
549 _attrs: &[syn::Attribute],
550) -> proc_macro2::TokenStream {
551 if data.variants.len() != 2 {
552 return syn::Error::new_spanned(name, "Union enums must have exactly 2 variants")
553 .to_compile_error();
554 }
555
556 let variants: Vec<_> = data.variants.iter().collect();
557 let variant1 = &variants[0];
558 let variant2 = &variants[1];
559
560 let (variant1_ident, variant1_type) = match &variant1.fields {
562 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
563 (&variant1.ident, &fields.unnamed[0].ty)
564 }
565 _ => {
566 return syn::Error::new_spanned(
567 variant1,
568 "Union variants must have exactly one unnamed field",
569 )
570 .to_compile_error();
571 }
572 };
573
574 let (variant2_ident, variant2_type) = match &variant2.fields {
575 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
576 (&variant2.ident, &fields.unnamed[0].ty)
577 }
578 _ => {
579 return syn::Error::new_spanned(
580 variant2,
581 "Union variants must have exactly one unnamed field",
582 )
583 .to_compile_error();
584 }
585 };
586
587 quote! {
588 fn deserialize(
589 value: &::tryparse::value::FlexValue,
590 ctx: &mut ::tryparse::deserializer::CoercionContext,
591 ) -> ::tryparse::error::Result<Self> {
592 use ::tryparse::deserializer::LlmDeserialize;
593
594 if let Some(v1) = <#variant1_type as LlmDeserialize>::try_deserialize(value, ctx) {
596 ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
598 index: 0,
599 candidates: vec![
600 stringify!(#variant1_type).to_string(),
601 stringify!(#variant2_type).to_string(),
602 ],
603 });
604 return Ok(Self::#variant1_ident(v1));
605 }
606
607 if let Some(v2) = <#variant2_type as LlmDeserialize>::try_deserialize(value, ctx) {
608 ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
610 index: 1,
611 candidates: vec![
612 stringify!(#variant1_type).to_string(),
613 stringify!(#variant2_type).to_string(),
614 ],
615 });
616 return Ok(Self::#variant2_ident(v2));
617 }
618
619 struct MatchResult {
621 variant: u8, score: u32,
623 }
624
625 let mut matches = Vec::new();
626
627 let value1 = value.clone();
629 if let Ok(_) = <#variant1_type as LlmDeserialize>::deserialize(&value1, ctx) {
630 let score: u32 = value1.transformations().iter().map(|t| t.penalty()).sum();
631 matches.push(MatchResult { variant: 1, score });
632 }
633
634 let value2 = value.clone();
636 if let Ok(_) = <#variant2_type as LlmDeserialize>::deserialize(&value2, ctx) {
637 let score: u32 = value2.transformations().iter().map(|t| t.penalty()).sum();
638 matches.push(MatchResult { variant: 2, score });
639 }
640
641 if matches.is_empty() {
642 return Err(::tryparse::error::ParseError::DeserializeFailed(
643 ::tryparse::error::DeserializeError::Custom(
644 "No union variant matched".to_string()
645 )
646 ));
647 }
648
649 matches.sort_by_key(|m| m.score);
651
652 let variant_index = (matches[0].variant - 1) as usize;
654 ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
655 index: variant_index,
656 candidates: vec![
657 stringify!(#variant1_type).to_string(),
658 stringify!(#variant2_type).to_string(),
659 ],
660 });
661
662 match matches[0].variant {
664 1 => {
665 let v1 = <#variant1_type as LlmDeserialize>::deserialize(value, ctx)?;
666 Ok(Self::#variant1_ident(v1))
667 }
668 2 => {
669 let v2 = <#variant2_type as LlmDeserialize>::deserialize(value, ctx)?;
670 Ok(Self::#variant2_ident(v2))
671 }
672 _ => unreachable!(),
673 }
674 }
675 }
676}