struct_llm_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Lit};
4
5/// Derive macro for `StructuredOutput` trait.
6///
7/// # Attributes
8///
9/// - `name`: Tool name (required) - e.g., `name = "final_answer"`
10/// - `description`: Tool description (required) - e.g., `description = "Returns the final answer"`
11///
12/// # Example
13///
14/// ```ignore
15/// #[derive(Serialize, Deserialize, StructuredOutput)]
16/// #[structured_output(
17///     name = "create_npc",
18///     description = "Creates an NPC character"
19/// )]
20/// struct NPCData {
21///     name: String,
22///     age: u32,
23///     backstory: String,
24/// }
25/// ```
26#[proc_macro_derive(StructuredOutput, attributes(structured_output))]
27pub fn derive_structured_output(input: TokenStream) -> TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29
30    let name = &input.ident;
31    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33    // Extract attributes
34    let (tool_name, tool_description) = parse_attributes(&input);
35
36    // Generate JSON schema from struct fields
37    let schema = generate_schema(&input.data);
38
39    let expanded = quote! {
40        impl #impl_generics struct_llm::StructuredOutput for #name #ty_generics #where_clause {
41            fn tool_name() -> &'static str {
42                #tool_name
43            }
44
45            fn tool_description() -> &'static str {
46                #tool_description
47            }
48
49            fn json_schema() -> serde_json::Value {
50                serde_json::json!(#schema)
51            }
52        }
53    };
54
55    TokenStream::from(expanded)
56}
57
58fn parse_attributes(input: &DeriveInput) -> (String, String) {
59    let mut tool_name = None;
60    let mut tool_description = None;
61
62    for attr in &input.attrs {
63        if !attr.path().is_ident("structured_output") {
64            continue;
65        }
66
67        let _ = attr.parse_nested_meta(|meta| {
68            if meta.path.is_ident("name") {
69                let value = meta.value()?;
70                let s: Lit = value.parse()?;
71                if let Lit::Str(lit_str) = s {
72                    tool_name = Some(lit_str.value());
73                }
74            } else if meta.path.is_ident("description") {
75                let value = meta.value()?;
76                let s: Lit = value.parse()?;
77                if let Lit::Str(lit_str) = s {
78                    tool_description = Some(lit_str.value());
79                }
80            }
81            Ok(())
82        });
83    }
84
85    let tool_name = tool_name.expect("missing #[structured_output(name = \"...\")] attribute");
86    let tool_description = tool_description.expect("missing #[structured_output(description = \"...\")] attribute");
87
88    (tool_name, tool_description)
89}
90
91fn generate_schema(data: &Data) -> proc_macro2::TokenStream {
92    match data {
93        Data::Struct(data_struct) => generate_struct_schema(&data_struct.fields),
94        Data::Enum(_) => {
95            panic!("StructuredOutput can only be derived for structs, not enums");
96        }
97        Data::Union(_) => {
98            panic!("StructuredOutput can only be derived for unions");
99        }
100    }
101}
102
103fn generate_struct_schema(fields: &Fields) -> proc_macro2::TokenStream {
104    let mut properties = Vec::new();
105    let mut required = Vec::new();
106
107    match fields {
108        Fields::Named(fields_named) => {
109            for field in &fields_named.named {
110                let field_name = field.ident.as_ref().unwrap().to_string();
111                let field_schema = generate_field_schema(&field.ty);
112
113                properties.push(quote! {
114                    #field_name: #field_schema
115                });
116
117                required.push(field_name);
118            }
119        }
120        Fields::Unnamed(_) => {
121            panic!("StructuredOutput does not support tuple structs");
122        }
123        Fields::Unit => {
124            panic!("StructuredOutput does not support unit structs");
125        }
126    }
127
128    let required_fields = required.iter().map(|s| quote! { #s });
129
130    quote! {
131        {
132            "type": "object",
133            "properties": {
134                #(#properties),*
135            },
136            "required": [#(#required_fields),*]
137        }
138    }
139}
140
141fn generate_field_schema(ty: &syn::Type) -> proc_macro2::TokenStream {
142    if let syn::Type::Path(type_path) = ty {
143        if let Some(segment) = type_path.path.segments.last() {
144            let type_name = segment.ident.to_string();
145
146            // Handle Vec with generic argument
147            if type_name == "Vec" {
148                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
149                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
150                        let item_type = infer_json_type(inner_ty);
151                        return quote! {
152                            {
153                                "type": "array",
154                                "items": {
155                                    "type": #item_type
156                                }
157                            }
158                        };
159                    }
160                }
161                // Fallback for Vec without type info
162                return quote! {
163                    {
164                        "type": "array",
165                        "items": {}
166                    }
167                };
168            }
169
170            // Handle Option
171            if type_name == "Option" {
172                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
173                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
174                        // Return the inner type schema (Option makes field non-required)
175                        return generate_field_schema(inner_ty);
176                    }
177                }
178            }
179        }
180    }
181
182    // Default case: simple type
183    let type_str = infer_json_type(ty);
184    quote! {
185        {
186            "type": #type_str
187        }
188    }
189}
190
191fn infer_json_type(ty: &syn::Type) -> &'static str {
192    // Simple type inference - extract the last segment of the path
193    if let syn::Type::Path(type_path) = ty {
194        if let Some(segment) = type_path.path.segments.last() {
195            let type_name = segment.ident.to_string();
196
197            return match type_name.as_str() {
198                "String" | "str" => "string",
199                "i8" | "i16" | "i32" | "i64" | "i128" |
200                "u8" | "u16" | "u32" | "u64" | "u128" |
201                "isize" | "usize" => "integer",
202                "f32" | "f64" => "number",
203                "bool" => "boolean",
204                "Vec" => "array",
205                "HashMap" | "BTreeMap" => "object",
206                _ => {
207                    // Check if it's an Option
208                    if type_name == "Option" {
209                        // For Option types, we need to look at the inner type
210                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
211                            if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
212                                return infer_json_type(inner_ty);
213                            }
214                        }
215                    }
216                    // Default to string for custom types
217                    "string"
218                }
219            };
220        }
221    }
222
223    "string" // Default fallback
224}