Skip to main content

rust_genai_macros/
lib.rs

1//! Procedural macros for the Rust Gemini SDK.
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{
7    parse_macro_input, Attribute, Data, DeriveInput, Expr, ExprLit, Fields, GenericArgument, Lit,
8    PathArguments, Type,
9};
10
11#[proc_macro_derive(GeminiTool, attributes(gemini))]
12pub fn gemini_tool(input: TokenStream) -> TokenStream {
13    let input = parse_macro_input!(input as DeriveInput);
14    match expand_gemini_tool(&input) {
15        Ok(tokens) => tokens.into(),
16        Err(err) => err.to_compile_error().into(),
17    }
18}
19
20fn expand_gemini_tool(input: &DeriveInput) -> syn::Result<TokenStream2> {
21    let name = &input.ident;
22    let struct_attrs = parse_gemini_attrs(&input.attrs)?;
23    let struct_doc = extract_doc_comment(&input.attrs);
24
25    let GeminiAttr {
26        name: struct_name,
27        description: struct_description,
28        ..
29    } = struct_attrs;
30    let function_name = struct_name.unwrap_or_else(|| name.to_string());
31    let function_description = struct_description.or(struct_doc);
32
33    let fields = match &input.data {
34        Data::Struct(data) => &data.fields,
35        _ => return Err(syn::Error::new_spanned(input, "GeminiTool 仅支持结构体")),
36    };
37
38    let (property_inserts, required_fields, ordering_fields) = collect_schema_fields(fields)?;
39    let description_expr = build_description_expr(function_description);
40
41    Ok(quote! {
42        impl #name {
43            pub fn as_tool() -> ::rust_genai_types::tool::Tool {
44                let mut properties: ::std::collections::HashMap<String, Box<::rust_genai_types::tool::Schema>> =
45                    ::std::collections::HashMap::new();
46                #(#property_inserts)*
47
48                let required: Vec<String> = vec![#(#required_fields),*];
49                let ordering: Vec<String> = vec![#(#ordering_fields),*];
50
51                let schema = ::rust_genai_types::tool::Schema {
52                    ty: Some(::rust_genai_types::enums::Type::Object),
53                    properties: Some(properties),
54                    required: if required.is_empty() { None } else { Some(required) },
55                    property_ordering: if ordering.is_empty() { None } else { Some(ordering) },
56                    ..Default::default()
57                };
58
59                let declaration = ::rust_genai_types::tool::FunctionDeclaration {
60                    name: #function_name.to_string(),
61                    description: #description_expr,
62                    parameters: Some(schema),
63                    parameters_json_schema: None,
64                    response: None,
65                    response_json_schema: None,
66                    behavior: None,
67                };
68
69                ::rust_genai_types::tool::Tool {
70                    function_declarations: Some(vec![declaration]),
71                    ..Default::default()
72                }
73            }
74
75            pub fn from_call(call: &::rust_genai_types::content::FunctionCall) -> ::rust_genai::Result<Self> {
76                if let Some(name) = &call.name {
77                    if name != #function_name {
78                        return Err(::rust_genai::Error::InvalidConfig {
79                            message: format!("Expected {}, got {}", #function_name, name),
80                        });
81                    }
82                }
83
84                let args = call.args.as_ref().ok_or_else(|| ::rust_genai::Error::InvalidConfig {
85                    message: "Missing args".into(),
86                })?;
87
88                let parsed = ::serde_json::from_value(args.clone())?;
89                Ok(parsed)
90            }
91        }
92    })
93}
94
95fn collect_schema_fields(
96    fields: &Fields,
97) -> syn::Result<(Vec<TokenStream2>, Vec<TokenStream2>, Vec<TokenStream2>)> {
98    let mut property_inserts = Vec::new();
99    let mut required_fields = Vec::new();
100    let mut ordering_fields = Vec::new();
101
102    match fields {
103        Fields::Named(named) => {
104            for field in &named.named {
105                let field_ident = field
106                    .ident
107                    .as_ref()
108                    .ok_or_else(|| syn::Error::new_spanned(field, "GeminiTool 仅支持命名字段"))?;
109                let field_attrs = parse_gemini_attrs(&field.attrs)?;
110                if field_attrs.skip {
111                    continue;
112                }
113
114                let field_doc = extract_doc_comment(&field.attrs);
115                let property_name = field_attrs
116                    .name
117                    .clone()
118                    .unwrap_or_else(|| field_ident.to_string());
119
120                let is_optional = is_option_type(&field.ty);
121                let schema_expr =
122                    build_schema_expr(&field.ty, is_optional, &field_attrs, field_doc);
123
124                property_inserts.push(quote! {
125                    {
126                        let schema = #schema_expr;
127                        properties.insert(#property_name.to_string(), Box::new(schema));
128                    }
129                });
130
131                ordering_fields.push(quote! { #property_name.to_string() });
132
133                if field_attrs.required || (!is_optional && !field_attrs.optional) {
134                    required_fields.push(quote! { #property_name.to_string() });
135                }
136            }
137        }
138        _ => {
139            return Err(syn::Error::new_spanned(
140                fields,
141                "GeminiTool 仅支持具名字段结构体",
142            ))
143        }
144    }
145
146    Ok((property_inserts, required_fields, ordering_fields))
147}
148
149fn build_description_expr(function_description: Option<String>) -> TokenStream2 {
150    function_description.map_or_else(
151        || quote!(None),
152        |description| quote!(Some(#description.to_string())),
153    )
154}
155
/// Options collected from the `#[gemini(...)]` attributes on one item
/// (either the struct itself or a single field).
#[derive(Default)]
struct GeminiAttr {
    /// Value of `name = "..."` / `rename = "..."`; replaces the struct or
    /// field identifier in the generated declaration.
    name: Option<String>,
    /// Value of `description = "..."`; takes precedence over doc comments.
    description: Option<String>,
    /// Entries of `enum_values = "a, b"` after splitting on commas and
    /// trimming; `None` when the key is absent or yields no entries.
    enum_values: Option<Vec<String>>,
    /// `skip`: leave the field out of the generated schema.
    skip: bool,
    /// `required`: list the field as required regardless of its type.
    required: bool,
    /// `optional`: do not list a non-`Option` field as required.
    optional: bool,
}
165
166fn parse_gemini_attrs(attrs: &[Attribute]) -> syn::Result<GeminiAttr> {
167    let mut output = GeminiAttr::default();
168    for attr in attrs {
169        if !attr.path().is_ident("gemini") {
170            continue;
171        }
172        attr.parse_nested_meta(|meta| {
173            if meta.path.is_ident("name") || meta.path.is_ident("rename") {
174                let value: syn::LitStr = meta.value()?.parse()?;
175                output.name = Some(value.value());
176                return Ok(());
177            }
178            if meta.path.is_ident("description") {
179                let value: syn::LitStr = meta.value()?.parse()?;
180                output.description = Some(value.value());
181                return Ok(());
182            }
183            if meta.path.is_ident("enum_values") {
184                let value: syn::LitStr = meta.value()?.parse()?;
185                let values = value
186                    .value()
187                    .split(',')
188                    .map(str::trim)
189                    .filter(|v| !v.is_empty())
190                    .map(ToString::to_string)
191                    .collect::<Vec<_>>();
192                if !values.is_empty() {
193                    output.enum_values = Some(values);
194                }
195                return Ok(());
196            }
197            if meta.path.is_ident("required") {
198                output.required = true;
199                return Ok(());
200            }
201            if meta.path.is_ident("optional") {
202                output.optional = true;
203                return Ok(());
204            }
205            if meta.path.is_ident("skip") {
206                output.skip = true;
207                return Ok(());
208            }
209            Ok(())
210        })?;
211    }
212    Ok(output)
213}
214
215fn extract_doc_comment(attrs: &[Attribute]) -> Option<String> {
216    let mut docs = Vec::new();
217    for attr in attrs {
218        if !attr.path().is_ident("doc") {
219            continue;
220        }
221        if let syn::Meta::NameValue(meta) = &attr.meta {
222            if let Expr::Lit(ExprLit {
223                lit: Lit::Str(text),
224                ..
225            }) = &meta.value
226            {
227                docs.push(text.value().trim().to_string());
228            }
229        }
230    }
231    if docs.is_empty() {
232        None
233    } else {
234        Some(docs.join("\n"))
235    }
236}
237
238fn build_schema_expr(
239    ty: &Type,
240    is_optional: bool,
241    attrs: &GeminiAttr,
242    doc: Option<String>,
243) -> TokenStream2 {
244    let base_expr = schema_expr_for_type(ty);
245    let mut statements = Vec::new();
246    statements.push(quote! { let mut schema = #base_expr; });
247
248    if is_optional {
249        statements.push(quote! { schema.nullable = Some(true); });
250    }
251
252    let description = attrs.description.clone().or(doc);
253    if let Some(description) = description {
254        statements.push(quote! { schema.description = Some(#description.to_string()); });
255    }
256
257    if let Some(values) = &attrs.enum_values {
258        let values_tokens = values.iter().map(|v| quote!(#v.to_string()));
259        statements.push(quote! { schema.enum_values = Some(vec![#(#values_tokens),*]); });
260    }
261
262    statements.push(quote! { schema });
263    quote!({ #(#statements)* })
264}
265
266fn schema_expr_for_type(ty: &Type) -> TokenStream2 {
267    if let Some(inner) = option_inner(ty) {
268        return schema_expr_for_type(inner);
269    }
270    if let Some(inner) = vec_inner(ty) {
271        let inner_expr = schema_expr_for_type(inner);
272        return quote! {
273            ::rust_genai_types::tool::Schema {
274                ty: Some(::rust_genai_types::enums::Type::Array),
275                items: Some(Box::new(#inner_expr)),
276                ..Default::default()
277            }
278        };
279    }
280
281    let ty = strip_reference(ty);
282    if is_serde_json_value(ty) {
283        return quote!(::rust_genai_types::tool::Schema::default());
284    }
285
286    if let Some(ident) = last_path_ident(ty) {
287        let schema = match ident.as_str() {
288            "String" | "str" => quote!(::rust_genai_types::tool::Schema::string()),
289            "bool" | "Boolean" => quote!(::rust_genai_types::tool::Schema::boolean()),
290            "f32" | "f64" => quote!(::rust_genai_types::tool::Schema::number()),
291            "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64" | "usize" => {
292                quote!(::rust_genai_types::tool::Schema::integer())
293            }
294            _ => quote!(::rust_genai_types::tool::Schema {
295                ty: Some(::rust_genai_types::enums::Type::Object),
296                ..Default::default()
297            }),
298        };
299        return schema;
300    }
301
302    quote!(::rust_genai_types::tool::Schema {
303        ty: Some(::rust_genai_types::enums::Type::Object),
304        ..Default::default()
305    })
306}
307
308fn is_option_type(ty: &Type) -> bool {
309    option_inner(ty).is_some()
310}
311
312fn option_inner(ty: &Type) -> Option<&Type> {
313    let ty = strip_reference(ty);
314    if let Type::Path(path) = ty {
315        if let Some(segment) = path.path.segments.last() {
316            if segment.ident == "Option" {
317                if let PathArguments::AngleBracketed(args) = &segment.arguments {
318                    if let Some(GenericArgument::Type(inner)) = args.args.first() {
319                        return Some(inner);
320                    }
321                }
322            }
323        }
324    }
325    None
326}
327
328fn vec_inner(ty: &Type) -> Option<&Type> {
329    let ty = strip_reference(ty);
330    if let Type::Path(path) = ty {
331        if let Some(segment) = path.path.segments.last() {
332            if segment.ident == "Vec" {
333                if let PathArguments::AngleBracketed(args) = &segment.arguments {
334                    if let Some(GenericArgument::Type(inner)) = args.args.first() {
335                        return Some(inner);
336                    }
337                }
338            }
339        }
340    }
341    None
342}
343
344fn strip_reference(ty: &Type) -> &Type {
345    if let Type::Reference(reference) = ty {
346        return strip_reference(&reference.elem);
347    }
348    ty
349}
350
351fn is_serde_json_value(ty: &Type) -> bool {
352    if let Type::Path(path) = ty {
353        let segments: Vec<_> = path
354            .path
355            .segments
356            .iter()
357            .map(|s| s.ident.to_string())
358            .collect();
359        return segments.as_slice() == ["serde_json", "Value"] || segments.as_slice() == ["Value"];
360    }
361    false
362}
363
364fn last_path_ident(ty: &Type) -> Option<String> {
365    if let Type::Path(path) = ty {
366        return path.path.segments.last().map(|seg| seg.ident.to_string());
367    }
368    None
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use quote::ToTokens;
    use syn::parse_quote;

    /// Renders `tokens` with all whitespace removed so assertions do not
    /// depend on how the token stream happens to be spaced.
    fn normalize_tokens(tokens: &TokenStream2) -> String {
        let rendered = tokens.to_string();
        rendered.chars().filter(|c| !c.is_whitespace()).collect()
    }

    /// Every supported key is picked up from a single attribute.
    #[test]
    fn parse_gemini_attrs_reads_values() {
        let attr: Attribute = parse_quote!(
            #[gemini(
                name = "tool_name",
                description = "desc",
                enum_values = "a, b",
                required,
                optional,
                skip
            )]
        );
        let parsed = parse_gemini_attrs(&[attr]).unwrap();

        assert_eq!(parsed.name.as_deref(), Some("tool_name"));
        assert_eq!(parsed.description.as_deref(), Some("desc"));
        assert_eq!(
            parsed.enum_values,
            Some(vec!["a".to_string(), "b".to_string()])
        );
        assert!(parsed.required);
        assert!(parsed.optional);
        assert!(parsed.skip);
    }

    /// `rename` is an alias for `name`, and a list holding only separators
    /// leaves `enum_values` unset.
    #[test]
    fn parse_gemini_attrs_ignores_empty_enum_values() {
        let attr: Attribute = parse_quote!(#[gemini(rename = "alias", enum_values = " , ")]);
        let parsed = parse_gemini_attrs(&[attr]).unwrap();

        assert_eq!(parsed.name.as_deref(), Some("alias"));
        assert!(parsed.enum_values.is_none());
    }

    /// Doc lines are trimmed individually and joined with a newline.
    #[test]
    fn extract_doc_comment_combines_lines() {
        let first: Attribute = parse_quote!(#[doc = " First line "]);
        let second: Attribute = parse_quote!(#[doc = "Second line"]);

        let docs = extract_doc_comment(&[first, second]).unwrap();
        assert_eq!(docs, "First line\nSecond line");
    }

    /// Deriving on an enum is reported as an error.
    #[test]
    fn expand_gemini_tool_rejects_enum() {
        let input: DeriveInput = parse_quote!(
            enum Bad {
                A,
            }
        );
        let message = expand_gemini_tool(&input).unwrap_err().to_string();
        assert!(message.contains("GeminiTool 仅支持结构体"));
    }

    /// Deriving on a tuple struct is reported as an error.
    #[test]
    fn expand_gemini_tool_rejects_tuple_struct() {
        let input: DeriveInput = parse_quote!(
            struct Bad(String);
        );
        let message = expand_gemini_tool(&input).unwrap_err().to_string();
        assert!(message.contains("具名字段"));
    }

    /// Nested containers, integers and unknown types each pick the expected
    /// schema constructor.
    #[test]
    fn schema_helpers_cover_variants() {
        let optional_list: Type = parse_quote!(Option<Vec<String>>);
        let rendered = normalize_tokens(&schema_expr_for_type(&optional_list));
        assert!(rendered.contains("Type::Array"));
        assert!(rendered.contains("Schema::string"));

        let integer: Type = parse_quote!(i64);
        let rendered = normalize_tokens(&schema_expr_for_type(&integer));
        assert!(rendered.contains("Schema::integer"));

        let custom: Type = parse_quote!(CustomType);
        let rendered = normalize_tokens(&schema_expr_for_type(&custom));
        assert!(rendered.contains("Type::Object"));
    }

    /// Nullability, description and enum values all end up in the generated
    /// block.
    #[test]
    fn build_schema_expr_applies_metadata() {
        let field_ty: Type = parse_quote!(Option<String>);
        let field_attrs = GeminiAttr {
            description: Some("desc".to_string()),
            enum_values: Some(vec!["x".to_string(), "y".to_string()]),
            ..Default::default()
        };

        let rendered = normalize_tokens(&build_schema_expr(&field_ty, true, &field_attrs, None));
        assert!(rendered.contains("nullable=Some(true)"));
        assert!(rendered.contains("schema.description=Some(\"desc\".to_string())"));
        assert!(rendered.contains("schema.enum_values=Some"));
    }

    /// The type helpers look through references and find generic arguments.
    #[test]
    fn type_helpers_detect_options_and_vecs() {
        let referenced_option: Type = parse_quote!(&Option<Vec<u32>>);
        assert!(is_option_type(&referenced_option));
        let inner = option_inner(&referenced_option).unwrap();
        assert!(inner.to_token_stream().to_string().contains("Vec"));

        let list: Type = parse_quote!(Vec<bool>);
        assert!(vec_inner(&list).is_some());
        assert!(last_path_ident(&list).is_some());

        let double_reference: Type = parse_quote!(&&str);
        let stripped = strip_reference(&double_reference);
        assert!(last_path_ident(stripped).is_some());
    }

    /// Both spellings of the JSON value type are recognised; other types are
    /// not.
    #[test]
    fn detects_serde_json_value() {
        let qualified: Type = parse_quote!(serde_json::Value);
        assert!(is_serde_json_value(&qualified));

        let bare: Type = parse_quote!(Value);
        assert!(is_serde_json_value(&bare));

        let unrelated: Type = parse_quote!(String);
        assert!(!is_serde_json_value(&unrelated));
    }
}