1use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type};
9
10#[proc_macro_derive(LlmDeserialize, attributes(llm))]
52pub fn derive_llm_deserialize(input: TokenStream) -> TokenStream {
53 let input = parse_macro_input!(input as DeriveInput);
54
55 let name = &input.ident;
56 let generics = &input.generics;
57 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
58
59 match &input.data {
60 Data::Struct(data_struct) => {
61 let deserialize_impl = generate_struct_deserialize(name, data_struct);
62
63 let expanded = quote! {
64 impl #impl_generics ::tryparse::deserializer::LlmDeserialize for #name #ty_generics #where_clause {
65 #deserialize_impl
66 }
67 };
68
69 TokenStream::from(expanded)
70 }
71 Data::Enum(data_enum) => {
72 let is_union = has_union_attribute(&input.attrs);
74
75 let deserialize_impl = if is_union {
76 generate_union_deserialize(name, data_enum, &input.attrs)
77 } else {
78 generate_enum_deserialize(name, data_enum, &input.attrs)
79 };
80
81 let expanded = quote! {
82 impl #impl_generics ::tryparse::deserializer::LlmDeserialize for #name #ty_generics #where_clause {
83 #deserialize_impl
84 }
85 };
86
87 TokenStream::from(expanded)
88 }
89 Data::Union(_) => {
90 syn::Error::new_spanned(input, "LlmDeserialize cannot be derived for unions")
91 .to_compile_error()
92 .into()
93 }
94 }
95}
96
97fn generate_struct_deserialize(
98 name: &syn::Ident,
99 data: &syn::DataStruct,
100) -> proc_macro2::TokenStream {
101 match &data.fields {
102 Fields::Named(fields) => {
103 let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
104 let field_types: Vec<_> = fields.named.iter().map(|f| &f.ty).collect();
105 let field_name_strs: Vec<_> = fields
106 .named
107 .iter()
108 .map(|f| f.ident.as_ref().unwrap().to_string())
109 .collect();
110
111 let is_optional: Vec<_> = field_types.iter().map(|ty| is_option_type(ty)).collect();
113
114 let inner_types: Vec<_> = field_types
116 .iter()
117 .zip(&is_optional)
118 .map(|(ty, opt)| {
119 if *opt {
120 extract_option_inner(ty)
121 } else {
122 (*ty).clone()
123 }
124 })
125 .collect();
126
127 let name_str = name.to_string();
128
129 let field_descriptors: Vec<_> = field_name_strs
131 .iter()
132 .zip(&field_types)
133 .zip(&is_optional)
134 .map(|((name, ty), opt)| {
135 let type_name = quote!(stringify!(#ty)).to_string();
136 quote! {
137 .field(::tryparse::deserializer::FieldDescriptor::new(
138 #name,
139 #type_name,
140 #opt
141 ))
142 }
143 })
144 .collect();
145
146 let field_extractions_strict: Vec<_> = field_names
148 .iter()
149 .zip(&inner_types)
150 .zip(&is_optional)
151 .map(|((field_name, inner_ty), opt)| {
152 let field_name_str = field_name.as_ref().unwrap().to_string();
153 if *opt {
154 quote! {
156 let #field_name = fields.get(#field_name_str)
157 .and_then(|v| v.downcast_ref::<#inner_ty>())
158 .cloned();
159 }
160 } else {
161 quote! {
163 let #field_name = fields.get(#field_name_str)
164 .and_then(|v| v.downcast_ref::<#inner_ty>())
165 .cloned()?;
166 }
167 }
168 })
169 .collect();
170
171 let field_extractions_lenient: Vec<_> = field_names.iter().zip(&inner_types).zip(&is_optional).map(|((field_name, inner_ty), opt)| {
173 let field_name_str = field_name.as_ref().unwrap().to_string();
174 if *opt {
175 quote! {
177 let #field_name = fields.get(#field_name_str)
178 .and_then(|v| v.downcast_ref::<#inner_ty>())
179 .cloned();
180 }
181 } else {
182 quote! {
184 let #field_name = fields.get(#field_name_str)
185 .and_then(|v| v.downcast_ref::<#inner_ty>())
186 .cloned()
187 .ok_or_else(|| ::tryparse::error::ParseError::DeserializeFailed(
188 ::tryparse::error::DeserializeError::missing_field(#field_name_str)
189 ))?;
190 }
191 }
192 }).collect();
193
194 quote! {
195 fn try_deserialize(
196 value: &::tryparse::value::FlexValue,
197 ctx: &mut ::tryparse::deserializer::CoercionContext,
198 ) -> Option<Self> {
199 use std::any::Any;
200
201 let mut deserializer = ::tryparse::deserializer::StructDeserializer::new()
202 #(#field_descriptors)*;
203
204 let fields = deserializer.try_deserialize(
205 value,
206 ctx,
207 #name_str,
208 |field_name, field_value, field_ctx| {
209 match field_name {
211 #(
212 #field_name_strs => {
213 <#inner_types as ::tryparse::deserializer::LlmDeserialize>::try_deserialize(field_value, field_ctx)
215 .map(|v| Box::new(v) as Box<dyn Any>)
216 }
217 )*
218 _ => None
219 }
220 }
221 ).ok()?;
222
223 #(#field_extractions_strict)*
225
226 Some(Self {
227 #(#field_names),*
228 })
229 }
230
231 fn deserialize(
232 value: &::tryparse::value::FlexValue,
233 ctx: &mut ::tryparse::deserializer::CoercionContext,
234 ) -> ::tryparse::error::Result<Self> {
235 use std::any::Any;
236
237 let mut deserializer = ::tryparse::deserializer::StructDeserializer::new()
238 #(#field_descriptors)*;
239
240 let fields = deserializer.deserialize(
241 value,
242 ctx,
243 #name_str,
244 |field_name, field_value, field_ctx, strict| {
245 match field_name {
247 #(
248 #field_name_strs => {
249 if strict {
250 if let Some(v) = <#inner_types as ::tryparse::deserializer::LlmDeserialize>::try_deserialize(field_value, field_ctx) {
252 Ok(Box::new(v) as Box<dyn Any>)
253 } else {
254 Err(::tryparse::error::ParseError::DeserializeFailed(
255 ::tryparse::error::DeserializeError::type_mismatch(
256 stringify!(#inner_types),
257 "value"
258 )
259 ))
260 }
261 } else {
262 let v = <#inner_types as ::tryparse::deserializer::LlmDeserialize>::deserialize(field_value, field_ctx)?;
264 Ok(Box::new(v) as Box<dyn Any>)
265 }
266 }
267 )*
268 _ => Err(::tryparse::error::ParseError::DeserializeFailed(
269 ::tryparse::error::DeserializeError::Custom(
270 format!("Unknown field: {}", field_name)
271 )
272 ))
273 }
274 }
275 )?;
276
277 #(#field_extractions_lenient)*
279
280 Ok(Self {
281 #(#field_names),*
282 })
283 }
284 }
285 }
286 Fields::Unnamed(_) => syn::Error::new_spanned(
287 data.fields.clone(),
288 "LlmDeserialize does not support tuple structs yet",
289 )
290 .to_compile_error(),
291 Fields::Unit => syn::Error::new_spanned(
292 data.fields.clone(),
293 "LlmDeserialize does not support unit structs",
294 )
295 .to_compile_error(),
296 }
297}
298
299fn is_option_type(ty: &Type) -> bool {
301 if let Type::Path(type_path) = ty {
302 if let Some(segment) = type_path.path.segments.last() {
303 return segment.ident == "Option";
304 }
305 }
306 false
307}
308
309fn extract_option_inner(ty: &Type) -> Type {
311 if let Type::Path(type_path) = ty {
312 if let Some(segment) = type_path.path.segments.last() {
313 if segment.ident == "Option" {
314 if let PathArguments::AngleBracketed(args) = &segment.arguments {
315 if let Some(GenericArgument::Type(inner)) = args.args.first() {
316 return inner.clone();
317 }
318 }
319 }
320 }
321 }
322 ty.clone()
324}
325
326fn generate_enum_deserialize(
327 name: &syn::Ident,
328 data: &syn::DataEnum,
329 _attrs: &[syn::Attribute],
330) -> proc_macro2::TokenStream {
331 let name_str = name.to_string();
332
333 let matcher_setup = data.variants.iter().map(|v| {
335 let variant_name = v.ident.to_string();
336 quote! {
337 .variant(::tryparse::deserializer::enum_coercer::EnumVariant::new(#variant_name))
338 }
339 });
340
341 let match_arms = data.variants.iter().map(|v| {
343 let variant_ident = &v.ident;
344 let variant_name = v.ident.to_string();
345
346 match &v.fields {
347 Fields::Unit => {
348 quote! {
350 #variant_name => Ok(Self::#variant_ident),
351 }
352 }
353 Fields::Named(_) | Fields::Unnamed(_) => {
354 quote! {
357 #variant_name => Err(::tryparse::error::ParseError::DeserializeFailed(
358 ::tryparse::error::DeserializeError::Custom(
359 format!("Enum variant '{}' has fields - derive macro only supports unit variants", #variant_name)
360 )
361 )),
362 }
363 }
364 }
365 });
366
367 quote! {
368 fn deserialize(
369 value: &::tryparse::value::FlexValue,
370 _ctx: &mut ::tryparse::deserializer::CoercionContext,
371 ) -> ::tryparse::error::Result<Self> {
372 let matcher = ::tryparse::deserializer::enum_coercer::EnumMatcher::new()
374 #(#matcher_setup)*;
375
376 let matched_variant = ::tryparse::deserializer::enum_coercer::match_enum_variant(
378 value,
379 &matcher
380 )?;
381
382 match matched_variant.as_str() {
384 #(#match_arms)*
385 _ => Err(::tryparse::error::ParseError::DeserializeFailed(
386 ::tryparse::error::DeserializeError::UnknownVariant {
387 enum_name: #name_str.to_string(),
388 variant: matched_variant,
389 }
390 )),
391 }
392 }
393 }
394}
395
396fn has_union_attribute(attrs: &[syn::Attribute]) -> bool {
398 attrs.iter().any(|attr| {
399 if attr.path().is_ident("llm") {
400 if let Ok(meta_list) = attr.meta.require_list() {
402 return meta_list.tokens.to_string().trim() == "union";
404 }
405 }
406 false
407 })
408}
409
410fn generate_union_deserialize(
412 name: &syn::Ident,
413 data: &syn::DataEnum,
414 _attrs: &[syn::Attribute],
415) -> proc_macro2::TokenStream {
416 if data.variants.len() != 2 {
417 return syn::Error::new_spanned(name, "Union enums must have exactly 2 variants")
418 .to_compile_error();
419 }
420
421 let variants: Vec<_> = data.variants.iter().collect();
422 let variant1 = &variants[0];
423 let variant2 = &variants[1];
424
425 let (variant1_ident, variant1_type) = match &variant1.fields {
427 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
428 (&variant1.ident, &fields.unnamed[0].ty)
429 }
430 _ => {
431 return syn::Error::new_spanned(
432 variant1,
433 "Union variants must have exactly one unnamed field",
434 )
435 .to_compile_error();
436 }
437 };
438
439 let (variant2_ident, variant2_type) = match &variant2.fields {
440 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
441 (&variant2.ident, &fields.unnamed[0].ty)
442 }
443 _ => {
444 return syn::Error::new_spanned(
445 variant2,
446 "Union variants must have exactly one unnamed field",
447 )
448 .to_compile_error();
449 }
450 };
451
452 quote! {
453 fn deserialize(
454 value: &::tryparse::value::FlexValue,
455 ctx: &mut ::tryparse::deserializer::CoercionContext,
456 ) -> ::tryparse::error::Result<Self> {
457 use ::tryparse::deserializer::LlmDeserialize;
458
459 if let Some(v1) = <#variant1_type as LlmDeserialize>::try_deserialize(value, ctx) {
461 ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
463 index: 0,
464 candidates: vec![
465 stringify!(#variant1_type).to_string(),
466 stringify!(#variant2_type).to_string(),
467 ],
468 });
469 return Ok(Self::#variant1_ident(v1));
470 }
471
472 if let Some(v2) = <#variant2_type as LlmDeserialize>::try_deserialize(value, ctx) {
473 ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
475 index: 1,
476 candidates: vec![
477 stringify!(#variant1_type).to_string(),
478 stringify!(#variant2_type).to_string(),
479 ],
480 });
481 return Ok(Self::#variant2_ident(v2));
482 }
483
484 struct MatchResult {
486 variant: u8, score: u32,
488 }
489
490 let mut matches = Vec::new();
491
492 let value1 = value.clone();
494 if let Ok(_) = <#variant1_type as LlmDeserialize>::deserialize(&value1, ctx) {
495 let score: u32 = value1.transformations().iter().map(|t| t.penalty()).sum();
496 matches.push(MatchResult { variant: 1, score });
497 }
498
499 let value2 = value.clone();
501 if let Ok(_) = <#variant2_type as LlmDeserialize>::deserialize(&value2, ctx) {
502 let score: u32 = value2.transformations().iter().map(|t| t.penalty()).sum();
503 matches.push(MatchResult { variant: 2, score });
504 }
505
506 if matches.is_empty() {
507 return Err(::tryparse::error::ParseError::DeserializeFailed(
508 ::tryparse::error::DeserializeError::Custom(
509 "No union variant matched".to_string()
510 )
511 ));
512 }
513
514 matches.sort_by_key(|m| m.score);
516
517 let variant_index = (matches[0].variant - 1) as usize;
519 ctx.add_transformation(::tryparse::value::Transformation::UnionMatch {
520 index: variant_index,
521 candidates: vec![
522 stringify!(#variant1_type).to_string(),
523 stringify!(#variant2_type).to_string(),
524 ],
525 });
526
527 match matches[0].variant {
529 1 => {
530 let v1 = <#variant1_type as LlmDeserialize>::deserialize(value, ctx)?;
531 Ok(Self::#variant1_ident(v1))
532 }
533 2 => {
534 let v2 = <#variant2_type as LlmDeserialize>::deserialize(value, ctx)?;
535 Ok(Self::#variant2_ident(v2))
536 }
537 _ => unreachable!(),
538 }
539 }
540 }
541}