1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput, Data, Fields, Lit};
4
5#[proc_macro_derive(StructuredOutput, attributes(structured_output))]
27pub fn derive_structured_output(input: TokenStream) -> TokenStream {
28 let input = parse_macro_input!(input as DeriveInput);
29
30 let name = &input.ident;
31 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33 let (tool_name, tool_description) = parse_attributes(&input);
35
36 let schema = generate_schema(&input.data);
38
39 let expanded = quote! {
40 impl #impl_generics struct_llm::StructuredOutput for #name #ty_generics #where_clause {
41 fn tool_name() -> &'static str {
42 #tool_name
43 }
44
45 fn tool_description() -> &'static str {
46 #tool_description
47 }
48
49 fn json_schema() -> serde_json::Value {
50 #schema
51 }
52 }
53 };
54
55 TokenStream::from(expanded)
56}
57
58fn parse_attributes(input: &DeriveInput) -> (String, String) {
59 let mut tool_name = None;
60 let mut tool_description = None;
61
62 for attr in &input.attrs {
63 if !attr.path().is_ident("structured_output") {
64 continue;
65 }
66
67 let _ = attr.parse_nested_meta(|meta| {
68 if meta.path.is_ident("name") {
69 let value = meta.value()?;
70 let s: Lit = value.parse()?;
71 if let Lit::Str(lit_str) = s {
72 tool_name = Some(lit_str.value());
73 }
74 } else if meta.path.is_ident("description") {
75 let value = meta.value()?;
76 let s: Lit = value.parse()?;
77 if let Lit::Str(lit_str) = s {
78 tool_description = Some(lit_str.value());
79 }
80 }
81 Ok(())
82 });
83 }
84
85 let tool_name = tool_name.expect("missing #[structured_output(name = \"...\")] attribute");
86 let tool_description = tool_description.expect("missing #[structured_output(description = \"...\")] attribute");
87
88 (tool_name, tool_description)
89}
90
91fn generate_schema(data: &Data) -> proc_macro2::TokenStream {
92 match data {
93 Data::Struct(data_struct) => generate_struct_schema(&data_struct.fields),
94 Data::Enum(_) => {
95 panic!("StructuredOutput can only be derived for structs, not enums");
96 }
97 Data::Union(_) => {
98 panic!("StructuredOutput can only be derived for unions");
99 }
100 }
101}
102
103fn generate_field_schema_tokens(ty: &syn::Type) -> proc_macro2::TokenStream {
105 if let syn::Type::Path(type_path) = ty {
106 if let Some(segment) = type_path.path.segments.last() {
107 let type_name = segment.ident.to_string();
108
109 if type_name == "Vec" {
111 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
112 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
113 if is_primitive_type(inner_ty) {
115 let item_type = infer_json_type(inner_ty);
116 return quote! {
117 {
118 let mut items_schema = serde_json::Map::new();
119 items_schema.insert("type".to_string(), serde_json::Value::String(#item_type.to_string()));
120
121 let mut schema = serde_json::Map::new();
122 schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
123 schema.insert("items".to_string(), serde_json::Value::Object(items_schema));
124 serde_json::Value::Object(schema)
125 }
126 };
127 } else {
128 return quote! {
131 {
132 let inner_schema = <#inner_ty as struct_llm::StructuredOutput>::json_schema();
133 let mut schema = serde_json::Map::new();
134 schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
135 schema.insert("items".to_string(), inner_schema);
136 serde_json::Value::Object(schema)
137 }
138 };
139 }
140 }
141 }
142 return quote! {
144 {
145 let mut schema = serde_json::Map::new();
146 schema.insert("type".to_string(), serde_json::Value::String("array".to_string()));
147 schema.insert("items".to_string(), serde_json::Value::Object(serde_json::Map::new()));
148 serde_json::Value::Object(schema)
149 }
150 };
151 }
152
153 if type_name == "Option" {
155 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
156 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
157 return generate_field_schema_tokens(inner_ty);
159 }
160 }
161 }
162 }
163 }
164
165 let type_str = infer_json_type(ty);
167 quote! {
168 {
169 let mut schema = serde_json::Map::new();
170 schema.insert("type".to_string(), serde_json::Value::String(#type_str.to_string()));
171 serde_json::Value::Object(schema)
172 }
173 }
174}
175
176fn generate_struct_schema(fields: &Fields) -> proc_macro2::TokenStream {
177 let mut field_insertions = Vec::new();
178 let mut required = Vec::new();
179
180 match fields {
181 Fields::Named(fields_named) => {
182 for field in &fields_named.named {
183 let field_name = field.ident.as_ref().unwrap().to_string();
184 let field_schema = generate_field_schema_tokens(&field.ty);
185
186 field_insertions.push(quote! {
187 properties.insert(#field_name.to_string(), #field_schema);
188 });
189
190 if !is_option_type(&field.ty) {
192 required.push(field_name);
193 }
194 }
195 }
196 Fields::Unnamed(_) => {
197 panic!("StructuredOutput does not support tuple structs");
198 }
199 Fields::Unit => {
200 panic!("StructuredOutput does not support unit structs");
201 }
202 }
203
204 quote! {
205 {
206 let mut properties = serde_json::Map::new();
207 #(#field_insertions)*
208
209 let required_fields: Vec<serde_json::Value> = vec![
210 #(serde_json::Value::String(#required.to_string())),*
211 ];
212
213 let mut schema = serde_json::Map::new();
214 schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
215 schema.insert("properties".to_string(), serde_json::Value::Object(properties));
216 schema.insert("required".to_string(), serde_json::Value::Array(required_fields));
217 serde_json::Value::Object(schema)
218 }
219 }
220}
221
222fn is_primitive_type(ty: &syn::Type) -> bool {
224 if let syn::Type::Path(type_path) = ty {
225 if let Some(segment) = type_path.path.segments.last() {
226 let type_name = segment.ident.to_string();
227 matches!(
228 type_name.as_str(),
229 "String" | "str" |
230 "i8" | "i16" | "i32" | "i64" | "i128" |
231 "u8" | "u16" | "u32" | "u64" | "u128" |
232 "isize" | "usize" |
233 "f32" | "f64" |
234 "bool"
235 )
236 } else {
237 false
238 }
239 } else {
240 false
241 }
242}
243
244fn is_option_type(ty: &syn::Type) -> bool {
246 if let syn::Type::Path(type_path) = ty {
247 if let Some(segment) = type_path.path.segments.last() {
248 segment.ident == "Option"
249 } else {
250 false
251 }
252 } else {
253 false
254 }
255}
256
257fn infer_json_type(ty: &syn::Type) -> &'static str {
258 if let syn::Type::Path(type_path) = ty {
260 if let Some(segment) = type_path.path.segments.last() {
261 let type_name = segment.ident.to_string();
262
263 return match type_name.as_str() {
264 "String" | "str" => "string",
265 "i8" | "i16" | "i32" | "i64" | "i128" |
266 "u8" | "u16" | "u32" | "u64" | "u128" |
267 "isize" | "usize" => "integer",
268 "f32" | "f64" => "number",
269 "bool" => "boolean",
270 "Vec" => "array",
271 "HashMap" | "BTreeMap" => "object",
272 _ => {
273 if type_name == "Option" {
275 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
277 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
278 return infer_json_type(inner_ty);
279 }
280 }
281 }
282 "string"
284 }
285 };
286 }
287 }
288
289 "string" }