1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Lit};
4
5#[proc_macro_derive(StructuredOutput, attributes(structured_output))]
27pub fn derive_structured_output(input: TokenStream) -> TokenStream {
28 let input = parse_macro_input!(input as DeriveInput);
29
30 let name = &input.ident;
31 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33 let (tool_name, tool_description) = parse_attributes(&input);
35
36 let schema = generate_schema(&input.data);
38
39 let expanded = quote! {
40 impl #impl_generics struct_llm::StructuredOutput for #name #ty_generics #where_clause {
41 fn tool_name() -> &'static str {
42 #tool_name
43 }
44
45 fn tool_description() -> &'static str {
46 #tool_description
47 }
48
49 fn json_schema() -> serde_json::Value {
50 #schema
51 }
52 }
53 };
54
55 TokenStream::from(expanded)
56}
57
58fn parse_attributes(input: &DeriveInput) -> (String, String) {
59 let mut tool_name = None;
60 let mut tool_description = None;
61
62 for attr in &input.attrs {
63 if !attr.path().is_ident("structured_output") {
64 continue;
65 }
66
67 let _ = attr.parse_nested_meta(|meta| {
68 if meta.path.is_ident("name") {
69 let value = meta.value()?;
70 let s: Lit = value.parse()?;
71 if let Lit::Str(lit_str) = s {
72 tool_name = Some(lit_str.value());
73 }
74 } else if meta.path.is_ident("description") {
75 let value = meta.value()?;
76 let s: Lit = value.parse()?;
77 if let Lit::Str(lit_str) = s {
78 tool_description = Some(lit_str.value());
79 }
80 }
81 Ok(())
82 });
83 }
84
85 let tool_name = tool_name.expect("missing #[structured_output(name = \"...\")] attribute");
86 let tool_description = tool_description.expect("missing #[structured_output(description = \"...\")] attribute");
87
88 (tool_name, tool_description)
89}
90
91fn generate_schema(data: &Data) -> proc_macro2::TokenStream {
92 match data {
93 Data::Struct(data_struct) => generate_struct_schema(&data_struct.fields),
94 Data::Enum(_) => {
95 panic!("StructuredOutput can only be derived for structs, not enums");
96 }
97 Data::Union(_) => {
98 panic!("StructuredOutput can only be derived for unions");
99 }
100 }
101}
102
103fn generate_field_schema_tokens(ty: &syn::Type) -> proc_macro2::TokenStream {
105 if let syn::Type::Path(type_path) = ty {
106 if let Some(segment) = type_path.path.segments.last() {
107 let type_name = segment.ident.to_string();
108
109 if type_name == "Vec" {
111 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
112 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
113 if is_primitive_type(inner_ty) {
115 let item_type = infer_json_type(inner_ty);
116 return quote! {
117 {
118 let mut items_schema = serde_json::Map::new();
119 items_schema.insert("type".to_string(), serde_json::Value::String(#item_type.to_string()));
120
121 let mut schema = serde_json::Map::new();
122 schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
123 schema.insert("items".to_string(), serde_json::Value::Object(items_schema));
124 serde_json::Value::Object(schema)
125 }
126 };
127 } else {
128 return quote! {
131 {
132 let inner_schema = <#inner_ty as struct_llm::StructuredOutput>::json_schema();
133 let mut schema = serde_json::Map::new();
134 schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
135 schema.insert("items".to_string(), inner_schema);
136 serde_json::Value::Object(schema)
137 }
138 };
139 }
140 }
141 }
142 return quote! {
144 {
145 let mut schema = serde_json::Map::new();
146 schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
147 schema.insert("items".to_string(), serde_json::Value::Object(serde_json::Map::new()));
148 serde_json::Value::Object(schema)
149 }
150 };
151 }
152
153 if type_name == "Option" {
155 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
156 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
157 return generate_field_schema_tokens(inner_ty);
159 }
160 }
161 }
162 }
163 }
164
165 let type_str = infer_json_type(ty);
167 quote! {
168 {
169 let mut schema = serde_json::Map::new();
170 schema.insert("type".to_string(), serde_json::Value::String(#type_str.to_string()));
171 serde_json::Value::Object(schema)
172 }
173 }
174}
175
176fn generate_struct_schema(fields: &Fields) -> proc_macro2::TokenStream {
177 let mut field_insertions = Vec::new();
178 let mut required = Vec::new();
179
180 match fields {
181 Fields::Named(fields_named) => {
182 for field in &fields_named.named {
183 let field_name = field.ident.as_ref().unwrap().to_string();
184 let field_schema = generate_field_schema_tokens(&field.ty);
185
186 field_insertions.push(quote! {
187 properties.insert(#field_name.to_string(), #field_schema);
188 });
189
190 required.push(field_name);
191 }
192 }
193 Fields::Unnamed(_) => {
194 panic!("StructuredOutput does not support tuple structs");
195 }
196 Fields::Unit => {
197 panic!("StructuredOutput does not support unit structs");
198 }
199 }
200
201 quote! {
202 {
203 let mut properties = serde_json::Map::new();
204 #(#field_insertions)*
205
206 let required_fields: Vec<serde_json::Value> = vec![
207 #(serde_json::Value::String(#required.to_string())),*
208 ];
209
210 let mut schema = serde_json::Map::new();
211 schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
212 schema.insert("properties".to_string(), serde_json::Value::Object(properties));
213 schema.insert("required".to_string(), serde_json::Value::Array(required_fields));
214 serde_json::Value::Object(schema)
215 }
216 }
217}
218
219fn is_primitive_type(ty: &syn::Type) -> bool {
221 if let syn::Type::Path(type_path) = ty {
222 if let Some(segment) = type_path.path.segments.last() {
223 let type_name = segment.ident.to_string();
224 matches!(
225 type_name.as_str(),
226 "String" | "str" |
227 "i8" | "i16" | "i32" | "i64" | "i128" |
228 "u8" | "u16" | "u32" | "u64" | "u128" |
229 "isize" | "usize" |
230 "f32" | "f64" |
231 "bool"
232 )
233 } else {
234 false
235 }
236 } else {
237 false
238 }
239}
240
241fn infer_json_type(ty: &syn::Type) -> &'static str {
242 if let syn::Type::Path(type_path) = ty {
244 if let Some(segment) = type_path.path.segments.last() {
245 let type_name = segment.ident.to_string();
246
247 return match type_name.as_str() {
248 "String" | "str" => "string",
249 "i8" | "i16" | "i32" | "i64" | "i128" |
250 "u8" | "u16" | "u32" | "u64" | "u128" |
251 "isize" | "usize" => "integer",
252 "f32" | "f64" => "number",
253 "bool" => "boolean",
254 "Vec" => "array",
255 "HashMap" | "BTreeMap" => "object",
256 _ => {
257 if type_name == "Option" {
259 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
261 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
262 return infer_json_type(inner_ty);
263 }
264 }
265 }
266 "string"
268 }
269 };
270 }
271 }
272
273 "string" }