1use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::quote;
6use syn::{
7 parse_macro_input, Attribute, Data, DeriveInput, Expr, ExprLit, Fields, GenericArgument, Lit,
8 PathArguments, Type,
9};
10
11#[proc_macro_derive(GeminiTool, attributes(gemini))]
12pub fn gemini_tool(input: TokenStream) -> TokenStream {
13 let input = parse_macro_input!(input as DeriveInput);
14 match expand_gemini_tool(&input) {
15 Ok(tokens) => tokens.into(),
16 Err(err) => err.to_compile_error().into(),
17 }
18}
19
20fn expand_gemini_tool(input: &DeriveInput) -> syn::Result<TokenStream2> {
21 let name = &input.ident;
22 let struct_attrs = parse_gemini_attrs(&input.attrs)?;
23 let struct_doc = extract_doc_comment(&input.attrs);
24
25 let GeminiAttr {
26 name: struct_name,
27 description: struct_description,
28 ..
29 } = struct_attrs;
30 let function_name = struct_name.unwrap_or_else(|| name.to_string());
31 let function_description = struct_description.or(struct_doc);
32
33 let fields = match &input.data {
34 Data::Struct(data) => &data.fields,
35 _ => return Err(syn::Error::new_spanned(input, "GeminiTool 仅支持结构体")),
36 };
37
38 let (property_inserts, required_fields, ordering_fields) = collect_schema_fields(fields)?;
39 let description_expr = build_description_expr(function_description);
40
41 Ok(quote! {
42 impl #name {
43 pub fn as_tool() -> ::rust_genai_types::tool::Tool {
44 let mut properties: ::std::collections::HashMap<String, Box<::rust_genai_types::tool::Schema>> =
45 ::std::collections::HashMap::new();
46 #(#property_inserts)*
47
48 let required: Vec<String> = vec![#(#required_fields),*];
49 let ordering: Vec<String> = vec![#(#ordering_fields),*];
50
51 let schema = ::rust_genai_types::tool::Schema {
52 ty: Some(::rust_genai_types::enums::Type::Object),
53 properties: Some(properties),
54 required: if required.is_empty() { None } else { Some(required) },
55 property_ordering: if ordering.is_empty() { None } else { Some(ordering) },
56 ..Default::default()
57 };
58
59 let declaration = ::rust_genai_types::tool::FunctionDeclaration {
60 name: #function_name.to_string(),
61 description: #description_expr,
62 parameters: Some(schema),
63 parameters_json_schema: None,
64 response: None,
65 response_json_schema: None,
66 behavior: None,
67 };
68
69 ::rust_genai_types::tool::Tool {
70 function_declarations: Some(vec![declaration]),
71 ..Default::default()
72 }
73 }
74
75 pub fn from_call(call: &::rust_genai_types::content::FunctionCall) -> ::rust_genai::Result<Self> {
76 if let Some(name) = &call.name {
77 if name != #function_name {
78 return Err(::rust_genai::Error::InvalidConfig {
79 message: format!("Expected {}, got {}", #function_name, name),
80 });
81 }
82 }
83
84 let args = call.args.as_ref().ok_or_else(|| ::rust_genai::Error::InvalidConfig {
85 message: "Missing args".into(),
86 })?;
87
88 let parsed = ::serde_json::from_value(args.clone())?;
89 Ok(parsed)
90 }
91 }
92 })
93}
94
95fn collect_schema_fields(
96 fields: &Fields,
97) -> syn::Result<(Vec<TokenStream2>, Vec<TokenStream2>, Vec<TokenStream2>)> {
98 let mut property_inserts = Vec::new();
99 let mut required_fields = Vec::new();
100 let mut ordering_fields = Vec::new();
101
102 match fields {
103 Fields::Named(named) => {
104 for field in &named.named {
105 let field_ident = field
106 .ident
107 .as_ref()
108 .ok_or_else(|| syn::Error::new_spanned(field, "GeminiTool 仅支持命名字段"))?;
109 let field_attrs = parse_gemini_attrs(&field.attrs)?;
110 if field_attrs.skip {
111 continue;
112 }
113
114 let field_doc = extract_doc_comment(&field.attrs);
115 let property_name = field_attrs
116 .name
117 .clone()
118 .unwrap_or_else(|| field_ident.to_string());
119
120 let is_optional = is_option_type(&field.ty);
121 let schema_expr =
122 build_schema_expr(&field.ty, is_optional, &field_attrs, field_doc);
123
124 property_inserts.push(quote! {
125 {
126 let schema = #schema_expr;
127 properties.insert(#property_name.to_string(), Box::new(schema));
128 }
129 });
130
131 ordering_fields.push(quote! { #property_name.to_string() });
132
133 if field_attrs.required || (!is_optional && !field_attrs.optional) {
134 required_fields.push(quote! { #property_name.to_string() });
135 }
136 }
137 }
138 _ => {
139 return Err(syn::Error::new_spanned(
140 fields,
141 "GeminiTool 仅支持具名字段结构体",
142 ))
143 }
144 }
145
146 Ok((property_inserts, required_fields, ordering_fields))
147}
148
149fn build_description_expr(function_description: Option<String>) -> TokenStream2 {
150 function_description.map_or_else(
151 || quote!(None),
152 |description| quote!(Some(#description.to_string())),
153 )
154}
155
/// Options collected from `#[gemini(...)]` attributes on the struct or on a single field.
///
/// At struct level only `name` and `description` are consulted; on fields every option applies.
#[derive(Default)]
struct GeminiAttr {
    /// `skip`: leave the field out of the generated schema entirely.
    skip: bool,
    /// `name = "..."` / `rename = "..."`: overrides the function name (struct) or the property
    /// name (field).
    name: Option<String>,
    /// `description = "..."`: takes precedence over the doc comment.
    description: Option<String>,
    /// `enum_values = "a, b"`: allowed values, already split on commas and trimmed.
    enum_values: Option<Vec<String>>,
    /// `required`: always list the property as required, even for `Option<_>` fields.
    required: bool,
    /// `optional`: keep a non-`Option` field out of the required list.
    optional: bool,
}
165
166fn parse_gemini_attrs(attrs: &[Attribute]) -> syn::Result<GeminiAttr> {
167 let mut output = GeminiAttr::default();
168 for attr in attrs {
169 if !attr.path().is_ident("gemini") {
170 continue;
171 }
172 attr.parse_nested_meta(|meta| {
173 if meta.path.is_ident("name") || meta.path.is_ident("rename") {
174 let value: syn::LitStr = meta.value()?.parse()?;
175 output.name = Some(value.value());
176 return Ok(());
177 }
178 if meta.path.is_ident("description") {
179 let value: syn::LitStr = meta.value()?.parse()?;
180 output.description = Some(value.value());
181 return Ok(());
182 }
183 if meta.path.is_ident("enum_values") {
184 let value: syn::LitStr = meta.value()?.parse()?;
185 let values = value
186 .value()
187 .split(',')
188 .map(str::trim)
189 .filter(|v| !v.is_empty())
190 .map(ToString::to_string)
191 .collect::<Vec<_>>();
192 if !values.is_empty() {
193 output.enum_values = Some(values);
194 }
195 return Ok(());
196 }
197 if meta.path.is_ident("required") {
198 output.required = true;
199 return Ok(());
200 }
201 if meta.path.is_ident("optional") {
202 output.optional = true;
203 return Ok(());
204 }
205 if meta.path.is_ident("skip") {
206 output.skip = true;
207 return Ok(());
208 }
209 Ok(())
210 })?;
211 }
212 Ok(output)
213}
214
215fn extract_doc_comment(attrs: &[Attribute]) -> Option<String> {
216 let mut docs = Vec::new();
217 for attr in attrs {
218 if !attr.path().is_ident("doc") {
219 continue;
220 }
221 if let syn::Meta::NameValue(meta) = &attr.meta {
222 if let Expr::Lit(ExprLit {
223 lit: Lit::Str(text),
224 ..
225 }) = &meta.value
226 {
227 docs.push(text.value().trim().to_string());
228 }
229 }
230 }
231 if docs.is_empty() {
232 None
233 } else {
234 Some(docs.join("\n"))
235 }
236}
237
238fn build_schema_expr(
239 ty: &Type,
240 is_optional: bool,
241 attrs: &GeminiAttr,
242 doc: Option<String>,
243) -> TokenStream2 {
244 let base_expr = schema_expr_for_type(ty);
245 let mut statements = Vec::new();
246 statements.push(quote! { let mut schema = #base_expr; });
247
248 if is_optional {
249 statements.push(quote! { schema.nullable = Some(true); });
250 }
251
252 let description = attrs.description.clone().or(doc);
253 if let Some(description) = description {
254 statements.push(quote! { schema.description = Some(#description.to_string()); });
255 }
256
257 if let Some(values) = &attrs.enum_values {
258 let values_tokens = values.iter().map(|v| quote!(#v.to_string()));
259 statements.push(quote! { schema.enum_values = Some(vec![#(#values_tokens),*]); });
260 }
261
262 statements.push(quote! { schema });
263 quote!({ #(#statements)* })
264}
265
266fn schema_expr_for_type(ty: &Type) -> TokenStream2 {
267 if let Some(inner) = option_inner(ty) {
268 return schema_expr_for_type(inner);
269 }
270 if let Some(inner) = vec_inner(ty) {
271 let inner_expr = schema_expr_for_type(inner);
272 return quote! {
273 ::rust_genai_types::tool::Schema {
274 ty: Some(::rust_genai_types::enums::Type::Array),
275 items: Some(Box::new(#inner_expr)),
276 ..Default::default()
277 }
278 };
279 }
280
281 let ty = strip_reference(ty);
282 if is_serde_json_value(ty) {
283 return quote!(::rust_genai_types::tool::Schema::default());
284 }
285
286 if let Some(ident) = last_path_ident(ty) {
287 let schema = match ident.as_str() {
288 "String" | "str" => quote!(::rust_genai_types::tool::Schema::string()),
289 "bool" | "Boolean" => quote!(::rust_genai_types::tool::Schema::boolean()),
290 "f32" | "f64" => quote!(::rust_genai_types::tool::Schema::number()),
291 "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64" | "usize" => {
292 quote!(::rust_genai_types::tool::Schema::integer())
293 }
294 _ => quote!(::rust_genai_types::tool::Schema {
295 ty: Some(::rust_genai_types::enums::Type::Object),
296 ..Default::default()
297 }),
298 };
299 return schema;
300 }
301
302 quote!(::rust_genai_types::tool::Schema {
303 ty: Some(::rust_genai_types::enums::Type::Object),
304 ..Default::default()
305 })
306}
307
308fn is_option_type(ty: &Type) -> bool {
309 option_inner(ty).is_some()
310}
311
312fn option_inner(ty: &Type) -> Option<&Type> {
313 let ty = strip_reference(ty);
314 if let Type::Path(path) = ty {
315 if let Some(segment) = path.path.segments.last() {
316 if segment.ident == "Option" {
317 if let PathArguments::AngleBracketed(args) = &segment.arguments {
318 if let Some(GenericArgument::Type(inner)) = args.args.first() {
319 return Some(inner);
320 }
321 }
322 }
323 }
324 }
325 None
326}
327
328fn vec_inner(ty: &Type) -> Option<&Type> {
329 let ty = strip_reference(ty);
330 if let Type::Path(path) = ty {
331 if let Some(segment) = path.path.segments.last() {
332 if segment.ident == "Vec" {
333 if let PathArguments::AngleBracketed(args) = &segment.arguments {
334 if let Some(GenericArgument::Type(inner)) = args.args.first() {
335 return Some(inner);
336 }
337 }
338 }
339 }
340 }
341 None
342}
343
344fn strip_reference(ty: &Type) -> &Type {
345 if let Type::Reference(reference) = ty {
346 return strip_reference(&reference.elem);
347 }
348 ty
349}
350
351fn is_serde_json_value(ty: &Type) -> bool {
352 if let Type::Path(path) = ty {
353 let segments: Vec<_> = path
354 .path
355 .segments
356 .iter()
357 .map(|s| s.ident.to_string())
358 .collect();
359 return segments.as_slice() == ["serde_json", "Value"] || segments.as_slice() == ["Value"];
360 }
361 false
362}
363
364fn last_path_ident(ty: &Type) -> Option<String> {
365 if let Type::Path(path) = ty {
366 return path.path.segments.last().map(|seg| seg.ident.to_string());
367 }
368 None
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use quote::ToTokens;
    use syn::parse_quote;

    /// Stringifies `tokens` and removes all whitespace, so assertions do not depend on how
    /// `proc_macro2` spaces its output.
    fn compact(tokens: &TokenStream2) -> String {
        tokens.to_string().split_whitespace().collect()
    }

    #[test]
    fn parse_gemini_attrs_reads_values() {
        let attrs: Vec<Attribute> = vec![parse_quote!(
            #[gemini(
                name = "tool_name",
                description = "desc",
                enum_values = "a, b",
                required,
                optional,
                skip
            )]
        )];
        let GeminiAttr {
            name,
            description,
            enum_values,
            required,
            optional,
            skip,
        } = parse_gemini_attrs(&attrs).unwrap();

        assert_eq!(name.as_deref(), Some("tool_name"));
        assert_eq!(description.as_deref(), Some("desc"));
        assert_eq!(enum_values, Some(vec!["a".to_string(), "b".to_string()]));
        assert!(required);
        assert!(optional);
        assert!(skip);
    }

    #[test]
    fn parse_gemini_attrs_ignores_empty_enum_values() {
        let attrs: Vec<Attribute> =
            vec![parse_quote!(#[gemini(rename = "alias", enum_values = " , ")])];
        let parsed = parse_gemini_attrs(&attrs).unwrap();
        assert_eq!(parsed.name.as_deref(), Some("alias"));
        assert_eq!(parsed.enum_values, None);
    }

    #[test]
    fn extract_doc_comment_combines_lines() {
        let attrs: Vec<Attribute> = vec![
            parse_quote!(#[doc = " First line "]),
            parse_quote!(#[doc = "Second line"]),
        ];
        assert_eq!(
            extract_doc_comment(&attrs).as_deref(),
            Some("First line\nSecond line")
        );
    }

    #[test]
    fn expand_gemini_tool_rejects_enum() {
        let input: DeriveInput = parse_quote!(
            enum Bad {
                A,
            }
        );
        let message = expand_gemini_tool(&input).unwrap_err().to_string();
        assert!(message.contains("GeminiTool 仅支持结构体"));
    }

    #[test]
    fn expand_gemini_tool_rejects_tuple_struct() {
        let input: DeriveInput = parse_quote!(
            struct Bad(String);
        );
        let message = expand_gemini_tool(&input).unwrap_err().to_string();
        assert!(message.contains("具名字段"));
    }

    #[test]
    fn schema_helpers_cover_variants() {
        let expand = |ty: Type| compact(&schema_expr_for_type(&ty));

        let optional_list = expand(parse_quote!(Option<Vec<String>>));
        assert!(optional_list.contains("Type::Array"));
        assert!(optional_list.contains("Schema::string"));

        assert!(expand(parse_quote!(i64)).contains("Schema::integer"));
        assert!(expand(parse_quote!(CustomType)).contains("Type::Object"));
    }

    #[test]
    fn build_schema_expr_applies_metadata() {
        let ty: Type = parse_quote!(Option<String>);
        let attrs = GeminiAttr {
            description: Some("desc".to_string()),
            enum_values: Some(vec!["x".to_string(), "y".to_string()]),
            ..Default::default()
        };
        let rendered = compact(&build_schema_expr(&ty, true, &attrs, None));
        let expectations = [
            "nullable=Some(true)",
            "schema.description=Some(\"desc\".to_string())",
            "schema.enum_values=Some",
        ];
        for expected in expectations.iter() {
            assert!(
                rendered.contains(*expected),
                "missing `{}` in {}",
                expected,
                rendered
            );
        }
    }

    #[test]
    fn type_helpers_detect_options_and_vecs() {
        let optional_vec: Type = parse_quote!(&Option<Vec<u32>>);
        assert!(is_option_type(&optional_vec));
        let inner = option_inner(&optional_vec).expect("Option<_> should expose its inner type");
        assert!(inner.to_token_stream().to_string().contains("Vec"));

        let bools: Type = parse_quote!(Vec<bool>);
        assert!(vec_inner(&bools).is_some());
        assert!(last_path_ident(&bools).is_some());

        let double_ref: Type = parse_quote!(&&str);
        assert!(last_path_ident(strip_reference(&double_ref)).is_some());
    }

    #[test]
    fn detects_serde_json_value() {
        let accepted: [Type; 2] = [parse_quote!(serde_json::Value), parse_quote!(Value)];
        for ty in accepted.iter() {
            assert!(is_serde_json_value(ty));
        }
        let rejected: Type = parse_quote!(String);
        assert!(!is_serde_json_value(&rejected));
    }
}