struct_llm_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Lit};
4
5/// Derive macro for `StructuredOutput` trait.
6///
7/// # Attributes
8///
9/// - `name`: Tool name (required) - e.g., `name = "final_answer"`
10/// - `description`: Tool description (required) - e.g., `description = "Returns the final answer"`
11///
12/// # Example
13///
14/// ```ignore
15/// #[derive(Serialize, Deserialize, StructuredOutput)]
16/// #[structured_output(
17///     name = "create_npc",
18///     description = "Creates an NPC character"
19/// )]
20/// struct NPCData {
21///     name: String,
22///     age: u32,
23///     backstory: String,
24/// }
25/// ```
26#[proc_macro_derive(StructuredOutput, attributes(structured_output))]
27pub fn derive_structured_output(input: TokenStream) -> TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29
30    let name = &input.ident;
31    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33    // Extract attributes
34    let (tool_name, tool_description) = parse_attributes(&input);
35
36    // Generate JSON schema from struct fields
37    let schema = generate_schema(&input.data);
38
39    let expanded = quote! {
40        impl #impl_generics struct_llm::StructuredOutput for #name #ty_generics #where_clause {
41            fn tool_name() -> &'static str {
42                #tool_name
43            }
44
45            fn tool_description() -> &'static str {
46                #tool_description
47            }
48
49            fn json_schema() -> serde_json::Value {
50                #schema
51            }
52        }
53    };
54
55    TokenStream::from(expanded)
56}
57
58fn parse_attributes(input: &DeriveInput) -> (String, String) {
59    let mut tool_name = None;
60    let mut tool_description = None;
61
62    for attr in &input.attrs {
63        if !attr.path().is_ident("structured_output") {
64            continue;
65        }
66
67        let _ = attr.parse_nested_meta(|meta| {
68            if meta.path.is_ident("name") {
69                let value = meta.value()?;
70                let s: Lit = value.parse()?;
71                if let Lit::Str(lit_str) = s {
72                    tool_name = Some(lit_str.value());
73                }
74            } else if meta.path.is_ident("description") {
75                let value = meta.value()?;
76                let s: Lit = value.parse()?;
77                if let Lit::Str(lit_str) = s {
78                    tool_description = Some(lit_str.value());
79                }
80            }
81            Ok(())
82        });
83    }
84
85    let tool_name = tool_name.expect("missing #[structured_output(name = \"...\")] attribute");
86    let tool_description = tool_description.expect("missing #[structured_output(description = \"...\")] attribute");
87
88    (tool_name, tool_description)
89}
90
91fn generate_schema(data: &Data) -> proc_macro2::TokenStream {
92    match data {
93        Data::Struct(data_struct) => generate_struct_schema(&data_struct.fields),
94        Data::Enum(_) => {
95            panic!("StructuredOutput can only be derived for structs, not enums");
96        }
97        Data::Union(_) => {
98            panic!("StructuredOutput can only be derived for unions");
99        }
100    }
101}
102
103/// Generates a TokenStream that produces a serde_json::Value at runtime
104fn generate_field_schema_tokens(ty: &syn::Type) -> proc_macro2::TokenStream {
105    if let syn::Type::Path(type_path) = ty {
106        if let Some(segment) = type_path.path.segments.last() {
107            let type_name = segment.ident.to_string();
108
109            // Handle Vec with generic argument
110            if type_name == "Vec" {
111                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
112                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
113                        // Check if inner type is a primitive
114                        if is_primitive_type(inner_ty) {
115                            let item_type = infer_json_type(inner_ty);
116                            return quote! {
117                                {
118                                    let mut items_schema = serde_json::Map::new();
119                                    items_schema.insert("type".to_string(), serde_json::Value::String(#item_type.to_string()));
120
121                                    let mut schema = serde_json::Map::new();
122                                    schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
123                                    schema.insert("items".to_string(), serde_json::Value::Object(items_schema));
124                                    serde_json::Value::Object(schema)
125                                }
126                            };
127                        } else {
128                            // For custom types (structs), call their json_schema() at runtime
129                            // This requires the inner type to implement StructuredOutput
130                            return quote! {
131                                {
132                                    let inner_schema = <#inner_ty as struct_llm::StructuredOutput>::json_schema();
133                                    let mut schema = serde_json::Map::new();
134                                    schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
135                                    schema.insert("items".to_string(), inner_schema);
136                                    serde_json::Value::Object(schema)
137                                }
138                            };
139                        }
140                    }
141                }
142                // Fallback for Vec without type info
143                return quote! {
144                    {
145                        let mut schema = serde_json::Map::new();
146                        schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
147                        schema.insert("items".to_string(), serde_json::Value::Object(serde_json::Map::new()));
148                        serde_json::Value::Object(schema)
149                    }
150                };
151            }
152
153            // Handle Option
154            if type_name == "Option" {
155                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
156                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
157                        // Return the inner type schema (Option makes field non-required)
158                        return generate_field_schema_tokens(inner_ty);
159                    }
160                }
161            }
162        }
163    }
164
165    // Default case: simple type
166    let type_str = infer_json_type(ty);
167    quote! {
168        {
169            let mut schema = serde_json::Map::new();
170            schema.insert("type".to_string(), serde_json::Value::String(#type_str.to_string()));
171            serde_json::Value::Object(schema)
172        }
173    }
174}
175
176fn generate_struct_schema(fields: &Fields) -> proc_macro2::TokenStream {
177    let mut field_insertions = Vec::new();
178    let mut required = Vec::new();
179
180    match fields {
181        Fields::Named(fields_named) => {
182            for field in &fields_named.named {
183                let field_name = field.ident.as_ref().unwrap().to_string();
184                let field_schema = generate_field_schema_tokens(&field.ty);
185
186                field_insertions.push(quote! {
187                    properties.insert(#field_name.to_string(), #field_schema);
188                });
189
190                // Only add to required if NOT an Option type
191                if !is_option_type(&field.ty) {
192                    required.push(field_name);
193                }
194            }
195        }
196        Fields::Unnamed(_) => {
197            panic!("StructuredOutput does not support tuple structs");
198        }
199        Fields::Unit => {
200            panic!("StructuredOutput does not support unit structs");
201        }
202    }
203
204    quote! {
205        {
206            let mut properties = serde_json::Map::new();
207            #(#field_insertions)*
208
209            let required_fields: Vec<serde_json::Value> = vec![
210                #(serde_json::Value::String(#required.to_string())),*
211            ];
212
213            let mut schema = serde_json::Map::new();
214            schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
215            schema.insert("properties".to_string(), serde_json::Value::Object(properties));
216            schema.insert("required".to_string(), serde_json::Value::Array(required_fields));
217            serde_json::Value::Object(schema)
218        }
219    }
220}
221
222/// Check if a type is a known primitive that maps directly to a JSON type
223fn is_primitive_type(ty: &syn::Type) -> bool {
224    if let syn::Type::Path(type_path) = ty {
225        if let Some(segment) = type_path.path.segments.last() {
226            let type_name = segment.ident.to_string();
227            matches!(
228                type_name.as_str(),
229                "String" | "str" |
230                "i8" | "i16" | "i32" | "i64" | "i128" |
231                "u8" | "u16" | "u32" | "u64" | "u128" |
232                "isize" | "usize" |
233                "f32" | "f64" |
234                "bool"
235            )
236        } else {
237            false
238        }
239    } else {
240        false
241    }
242}
243
244/// Check if a type is Option<T>
245fn is_option_type(ty: &syn::Type) -> bool {
246    if let syn::Type::Path(type_path) = ty {
247        if let Some(segment) = type_path.path.segments.last() {
248            segment.ident == "Option"
249        } else {
250            false
251        }
252    } else {
253        false
254    }
255}
256
257fn infer_json_type(ty: &syn::Type) -> &'static str {
258    // Simple type inference - extract the last segment of the path
259    if let syn::Type::Path(type_path) = ty {
260        if let Some(segment) = type_path.path.segments.last() {
261            let type_name = segment.ident.to_string();
262
263            return match type_name.as_str() {
264                "String" | "str" => "string",
265                "i8" | "i16" | "i32" | "i64" | "i128" |
266                "u8" | "u16" | "u32" | "u64" | "u128" |
267                "isize" | "usize" => "integer",
268                "f32" | "f64" => "number",
269                "bool" => "boolean",
270                "Vec" => "array",
271                "HashMap" | "BTreeMap" => "object",
272                _ => {
273                    // Check if it's an Option
274                    if type_name == "Option" {
275                        // For Option types, we need to look at the inner type
276                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
277                            if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
278                                return infer_json_type(inner_ty);
279                            }
280                        }
281                    }
282                    // Default to string for custom types
283                    "string"
284                }
285            };
286        }
287    }
288
289    "string" // Default fallback
290}