rust_mcp_macros/
utils.rs

1use quote::quote;
2use syn::{punctuated::Punctuated, token, Attribute, Path, PathArguments, Type};
3
4// Check if a type is an Option<T>
5pub fn is_option(ty: &Type) -> bool {
6    if let Type::Path(type_path) = ty {
7        if type_path.path.segments.len() == 1 {
8            let segment = &type_path.path.segments[0];
9            return segment.ident == "Option"
10                && matches!(segment.arguments, PathArguments::AngleBracketed(_));
11        }
12    }
13    false
14}
15
16// Check if a type is a Vec<T>
17#[allow(unused)]
18pub fn is_vec(ty: &Type) -> bool {
19    if let Type::Path(type_path) = ty {
20        if type_path.path.segments.len() == 1 {
21            let segment = &type_path.path.segments[0];
22            return segment.ident == "Vec"
23                && matches!(segment.arguments, PathArguments::AngleBracketed(_));
24        }
25    }
26    false
27}
28
29// Extract the inner type from Vec<T> or Option<T>
30#[allow(unused)]
31pub fn inner_type(ty: &Type) -> Option<&Type> {
32    if let Type::Path(type_path) = ty {
33        if type_path.path.segments.len() == 1 {
34            let segment = &type_path.path.segments[0];
35            if matches!(segment.arguments, PathArguments::AngleBracketed(_)) {
36                if let PathArguments::AngleBracketed(args) = &segment.arguments {
37                    if args.args.len() == 1 {
38                        if let syn::GenericArgument::Type(inner_ty) = &args.args[0] {
39                            return Some(inner_ty);
40                        }
41                    }
42                }
43            }
44        }
45    }
46    None
47}
48
49fn doc_comment(attrs: &[Attribute]) -> Option<String> {
50    let mut docs = Vec::new();
51    for attr in attrs {
52        if attr.path().is_ident("doc") {
53            if let syn::Meta::NameValue(meta) = &attr.meta {
54                // Match value as Expr::Lit, then extract Lit::Str
55                if let syn::Expr::Lit(expr_lit) = &meta.value {
56                    if let syn::Lit::Str(lit_str) = &expr_lit.lit {
57                        docs.push(lit_str.value().trim().to_string());
58                    }
59                }
60            }
61        }
62    }
63    if docs.is_empty() {
64        None
65    } else {
66        Some(docs.join("\n"))
67    }
68}
69
70pub fn might_be_struct(ty: &Type) -> bool {
71    if let Type::Path(type_path) = ty {
72        if type_path.path.segments.len() == 1 {
73            let ident = type_path.path.segments[0].ident.to_string();
74            let common_types = vec![
75                "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64",
76                "bool", "char", "str", "String", "Vec", "Option",
77            ];
78            return !common_types.contains(&ident.as_str())
79                && type_path.path.segments[0].arguments.is_empty();
80        }
81    }
82    false
83}
84
85pub fn type_to_json_schema(ty: &Type, attrs: &[Attribute]) -> proc_macro2::TokenStream {
86    let number_types = [
87        "i8", "i16", "i32", "i64", "i128", "u8", "u16", "u32", "u64", "u128", "f32", "f64",
88    ];
89    let doc_comment = doc_comment(attrs);
90    let description = doc_comment.as_ref().map(|desc| {
91        quote! {
92            map.insert("description".to_string(), serde_json::Value::String(#desc.to_string()));
93        }
94    });
95    match ty {
96        Type::Path(type_path) => {
97            if type_path.path.segments.len() == 1 {
98                let segment = &type_path.path.segments[0];
99                let ident = &segment.ident;
100
101                // Handle Option<T>
102                if ident == "Option" {
103                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
104                        if args.args.len() == 1 {
105                            if let syn::GenericArgument::Type(inner_ty) = &args.args[0] {
106                                let inner_schema = type_to_json_schema(inner_ty, attrs);
107                                return quote! {
108                                    {
109                                        let mut map = serde_json::Map::new();
110                                        let inner_map = #inner_schema;
111                                        for (k, v) in inner_map {
112                                            map.insert(k, v);
113                                        }
114                                        map.insert("nullable".to_string(), serde_json::Value::Bool(true));
115                                        #description
116                                        map
117                                    }
118                                };
119                            }
120                        }
121                    }
122                }
123                // Handle Vec<T>
124                else if ident == "Vec" {
125                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
126                        if args.args.len() == 1 {
127                            if let syn::GenericArgument::Type(inner_ty) = &args.args[0] {
128                                let inner_schema = type_to_json_schema(inner_ty, &[]);
129                                return quote! {
130                                    {
131                                        let mut map = serde_json::Map::new();
132                                        map.insert("type".to_string(), serde_json::Value::String("array".to_string()));
133                                        map.insert("items".to_string(), serde_json::Value::Object(#inner_schema));
134                                        #description
135                                        map
136                                    }
137                                };
138                            }
139                        }
140                    }
141                }
142                // Handle nested structs
143                else if might_be_struct(ty) {
144                    let path = &type_path.path;
145                    return quote! {
146                        {
147                            let inner_schema = #path::json_schema();
148                            inner_schema
149                        }
150                    };
151                }
152                // Handle basic types
153                else if ident == "String" {
154                    return quote! {
155                        {
156                            let mut map = serde_json::Map::new();
157                            map.insert("type".to_string(), serde_json::Value::String("string".to_string()));
158                            #description
159                            map
160                        }
161                    };
162                } else if number_types.iter().any(|t| ident == t) {
163                    return quote! {
164                        {
165                            let mut map = serde_json::Map::new();
166                            map.insert("type".to_string(), serde_json::Value::String("number".to_string()));
167                            #description
168                            map
169                        }
170                    };
171                } else if ident == "bool" {
172                    return quote! {
173                        {
174                            let mut map = serde_json::Map::new();
175                            map.insert("type".to_string(), serde_json::Value::String("boolean".to_string()));
176                            #description
177                            map
178                        }
179                    };
180                }
181            }
182            // Fallback for unknown types
183            quote! {
184                {
185                    let mut map = serde_json::Map::new();
186                    map.insert("type".to_string(), serde_json::Value::String("unknown".to_string()));
187                    #description
188                    map
189                }
190            }
191        }
192        _ => quote! {
193            {
194                let mut map = serde_json::Map::new();
195                map.insert("type".to_string(), serde_json::Value::String("unknown".to_string()));
196                #description
197                map
198            }
199        },
200    }
201}
202
203#[allow(unused)]
204pub fn has_derive(attrs: &[Attribute], trait_name: &str) -> bool {
205    attrs.iter().any(|attr| {
206        if attr.path().is_ident("derive") {
207            // Parse the derive arguments as a comma-separated list of paths
208            let parsed = attr.parse_args_with(Punctuated::<Path, token::Comma>::parse_terminated);
209            if let Ok(derive_paths) = parsed {
210                let derived = derive_paths.iter().any(|path| path.is_ident(trait_name));
211                return derived;
212            }
213        }
214        false
215    })
216}
217
218pub fn renamed_field(attrs: &[Attribute]) -> Option<String> {
219    let mut renamed = None;
220
221    for attr in attrs {
222        if attr.path().is_ident("serde") {
223            // Ignore other serde meta items (e.g., skip_serializing_if)
224            let _ = attr.parse_nested_meta(|meta| {
225                if meta.path.is_ident("rename") {
226                    if let Ok(lit) = meta.value() {
227                        if let Ok(syn::Lit::Str(lit_str)) = lit.parse() {
228                            renamed = Some(lit_str.value());
229                        }
230                    }
231                }
232                Ok(())
233            });
234        }
235    }
236
237    renamed
238}
239
// Unit tests for the type-inspection and schema-generation helpers in this file.
// Generated token streams are checked by substring matching on their rendered text.
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;
    use syn::parse_quote;

    // Stringifies a token stream and strips all whitespace, so assertions don't depend on
    // how tokens are spaced when printed.
    fn render(ts: proc_macro2::TokenStream) -> String {
        ts.to_string().replace(char::is_whitespace, "")
    }

    #[test]
    fn test_is_option() {
        let ty: Type = parse_quote!(Option<String>);
        assert!(is_option(&ty));

        let ty: Type = parse_quote!(Vec<String>);
        assert!(!is_option(&ty));
    }

    #[test]
    fn test_is_vec() {
        let ty: Type = parse_quote!(Vec<i32>);
        assert!(is_vec(&ty));

        let ty: Type = parse_quote!(Option<i32>);
        assert!(!is_vec(&ty));
    }

    #[test]
    fn test_inner_type() {
        let ty: Type = parse_quote!(Option<String>);
        let inner = inner_type(&ty);
        assert!(inner.is_some());
        let inner = inner.unwrap();
        assert_eq!(quote!(#inner).to_string(), quote!(String).to_string());

        let ty: Type = parse_quote!(Vec<i32>);
        let inner = inner_type(&ty);
        assert!(inner.is_some());
        let inner = inner.unwrap();
        assert_eq!(quote!(#inner).to_string(), quote!(i32).to_string());

        // A type without generic arguments has no inner type.
        let ty: Type = parse_quote!(i32);
        assert!(inner_type(&ty).is_none());
    }

    #[test]
    fn test_might_be_struct() {
        let ty: Type = parse_quote!(MyStruct);
        assert!(might_be_struct(&ty));

        let ty: Type = parse_quote!(String);
        assert!(!might_be_struct(&ty));
    }

    #[test]
    fn test_type_to_json_schema_string() {
        let ty: Type = parse_quote!(String);
        let attrs: Vec<Attribute> = vec![];
        let tokens = type_to_json_schema(&ty, &attrs);
        let output = tokens.to_string();
        assert!(output.contains("\"string\""));
    }

    #[test]
    fn test_type_to_json_schema_option() {
        let ty: Type = parse_quote!(Option<i32>);
        let attrs: Vec<Attribute> = vec![];
        let tokens = type_to_json_schema(&ty, &attrs);
        let output = tokens.to_string();
        assert!(output.contains("\"nullable\""));
    }

    #[test]
    fn test_type_to_json_schema_vec() {
        let ty: Type = parse_quote!(Vec<String>);
        let attrs: Vec<Attribute> = vec![];
        let tokens = type_to_json_schema(&ty, &attrs);
        let output = tokens.to_string();
        assert!(output.contains("\"array\""));
    }

    #[test]
    fn test_has_derive() {
        let attr: Attribute = parse_quote!(#[derive(Clone, Debug)]);
        assert!(has_derive(&[attr.clone()], "Debug"));
        assert!(!has_derive(&[attr], "Serialize"));
    }

    #[test]
    fn test_renamed_field() {
        let attr: Attribute = parse_quote!(#[serde(rename = "renamed")]);
        assert_eq!(renamed_field(&[attr]), Some("renamed".to_string()));

        let attr: Attribute = parse_quote!(#[serde(skip_serializing_if = "Option::is_none")]);
        assert_eq!(renamed_field(&[attr]), None);
    }

    #[test]
    fn test_get_doc_comment_single_line() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[doc = "This is a test comment."])];
        let result = super::doc_comment(&attrs);
        assert_eq!(result, Some("This is a test comment.".to_string()));
    }

    // Consecutive doc attributes are joined with newlines, in order.
    #[test]
    fn test_get_doc_comment_multi_line() {
        let attrs: Vec<Attribute> = vec![
            parse_quote!(#[doc = "Line one."]),
            parse_quote!(#[doc = "Line two."]),
            parse_quote!(#[doc = "Line three."]),
        ];
        let result = super::doc_comment(&attrs);
        assert_eq!(
            result,
            Some("Line one.\nLine two.\nLine three.".to_string())
        );
    }

    #[test]
    fn test_get_doc_comment_no_doc() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[allow(dead_code)])];
        let result = super::doc_comment(&attrs);
        assert_eq!(result, None);
    }

    #[test]
    fn test_get_doc_comment_trim_whitespace() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[doc = "  Trimmed line.  "])];
        let result = super::doc_comment(&attrs);
        assert_eq!(result, Some("Trimmed line.".to_string()));
    }

    #[test]
    fn test_renamed_field_basic() {
        let attrs = vec![parse_quote!(#[serde(rename = "new_name")])];
        let result = renamed_field(&attrs);
        assert_eq!(result, Some("new_name".to_string()));
    }

    #[test]
    fn test_renamed_field_without_rename() {
        let attrs = vec![parse_quote!(#[serde(default)])];
        let result = renamed_field(&attrs);
        assert_eq!(result, None);
    }

    // The rename may live in any of several `#[serde(...)]` attributes.
    #[test]
    fn test_renamed_field_with_multiple_attrs() {
        let attrs = vec![
            parse_quote!(#[serde(default)]),
            parse_quote!(#[serde(rename = "actual_name")]),
        ];
        let result = renamed_field(&attrs);
        assert_eq!(result, Some("actual_name".to_string()));
    }

    // Attributes not named `serde` are ignored entirely.
    #[test]
    fn test_renamed_field_irrelevant_attribute() {
        let attrs = vec![parse_quote!(#[some_other_attr(value = "irrelevant")])];
        let result = renamed_field(&attrs);
        assert_eq!(result, None);
    }

    #[test]
    fn test_renamed_field_ignores_other_serde_keys() {
        let attrs = vec![parse_quote!(#[serde(skip_serializing_if = "Option::is_none")])];
        let result = renamed_field(&attrs);
        assert_eq!(result, None);
    }

    #[test]
    fn test_has_derive_positive() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[derive(Debug, Clone)])];
        assert!(has_derive(&attrs, "Debug"));
        assert!(has_derive(&attrs, "Clone"));
    }

    #[test]
    fn test_has_derive_negative() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[derive(Serialize, Deserialize)])];
        assert!(!has_derive(&attrs, "Debug"));
    }

    #[test]
    fn test_has_derive_no_derive_attr() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[allow(dead_code)])];
        assert!(!has_derive(&attrs, "Debug"));
    }

    // Derives split across several `#[derive]` attributes are all considered.
    #[test]
    fn test_has_derive_multiple_attrs() {
        let attrs: Vec<Attribute> = vec![
            parse_quote!(#[allow(unused)]),
            parse_quote!(#[derive(PartialEq)]),
            parse_quote!(#[derive(Eq)]),
        ];
        assert!(has_derive(&attrs, "PartialEq"));
        assert!(has_derive(&attrs, "Eq"));
        assert!(!has_derive(&attrs, "Clone"));
    }

    #[test]
    fn test_has_derive_empty_attrs() {
        let attrs: Vec<Attribute> = vec![];
        assert!(!has_derive(&attrs, "Debug"));
    }

    #[test]
    fn test_might_be_struct_with_custom_type() {
        let ty: syn::Type = parse_quote!(MyStruct);
        assert!(might_be_struct(&ty));
    }

    // Each name in the exclusion list must be rejected, even when written without generics.
    #[test]
    fn test_might_be_struct_with_primitive_type() {
        let primitives = [
            "i32", "u64", "bool", "f32", "String", "Option", "Vec", "char", "str",
        ];
        for ty_str in &primitives {
            let ty: syn::Type = syn::parse_str(ty_str).unwrap();
            assert!(
                !might_be_struct(&ty),
                "Expected '{}' to be not a struct",
                ty_str
            );
        }
    }

    #[test]
    fn test_might_be_struct_with_namespaced_type() {
        let ty: syn::Type = parse_quote!(std::collections::HashMap<String, i32>);
        assert!(!might_be_struct(&ty)); // segments.len() > 1
    }

    #[test]
    fn test_might_be_struct_with_generic_arguments() {
        let ty: syn::Type = parse_quote!(MyStruct<T>);
        assert!(!might_be_struct(&ty)); // has type arguments
    }

    // `()` parses as a tuple type (not a path), so it can never be treated as a struct.
    #[test]
    fn test_might_be_struct_with_empty_type_path() {
        let ty: syn::Type = parse_quote!(());
        assert!(!might_be_struct(&ty));
    }

    #[test]
    fn test_json_schema_string() {
        let ty: syn::Type = parse_quote!(String);
        let tokens = type_to_json_schema(&ty, &[]);
        let output = render(tokens);
        assert!(output
            .contains("\"type\".to_string(),serde_json::Value::String(\"string\".to_string())"));
    }

    #[test]
    fn test_json_schema_number() {
        let ty: syn::Type = parse_quote!(i32);
        let tokens = type_to_json_schema(&ty, &[]);
        let output = render(tokens);
        assert!(output
            .contains("\"type\".to_string(),serde_json::Value::String(\"number\".to_string())"));
    }

    #[test]
    fn test_json_schema_boolean() {
        let ty: syn::Type = parse_quote!(bool);
        let tokens = type_to_json_schema(&ty, &[]);
        let output = render(tokens);
        assert!(output
            .contains("\"type\".to_string(),serde_json::Value::String(\"boolean\".to_string())"));
    }

    #[test]
    fn test_json_schema_vec_of_string() {
        let ty: syn::Type = parse_quote!(Vec<String>);
        let tokens = type_to_json_schema(&ty, &[]);
        let output = render(tokens);
        assert!(output
            .contains("\"type\".to_string(),serde_json::Value::String(\"array\".to_string())"));
        assert!(output.contains("\"items\".to_string(),serde_json::Value::Object"));
    }

    // An optional field carries both the nullable flag and the inner type's schema.
    #[test]
    fn test_json_schema_option_of_number() {
        let ty: syn::Type = parse_quote!(Option<u64>);
        let tokens = type_to_json_schema(&ty, &[]);
        let output = render(tokens);
        assert!(output.contains("\"nullable\".to_string(),serde_json::Value::Bool(true)"));
        assert!(output
            .contains("\"type\".to_string(),serde_json::Value::String(\"number\".to_string())"));
    }

    // Unknown single-segment types are delegated to their own `json_schema()`.
    #[test]
    fn test_json_schema_custom_struct() {
        let ty: syn::Type = parse_quote!(MyStruct);
        let tokens = type_to_json_schema(&ty, &[]);
        let output = render(tokens);
        assert!(output.contains("MyStruct::json_schema()"));
    }

    // `render` strips all whitespace, so the expected description text has no spaces.
    #[test]
    fn test_json_schema_with_doc_comment() {
        let ty: syn::Type = parse_quote!(String);
        let attrs: Vec<Attribute> = vec![parse_quote!(#[doc = "A user name."])];
        let tokens = type_to_json_schema(&ty, &attrs);
        let output = render(tokens);
        assert!(output.contains(
            "\"description\".to_string(),serde_json::Value::String(\"Ausername.\".to_string())"
        ));
    }

    // Tuple types are not `Type::Path`, so they hit the "unknown" fallback.
    #[test]
    fn test_json_schema_fallback_unknown() {
        let ty: syn::Type = parse_quote!((i32, i32));
        let tokens = type_to_json_schema(&ty, &[]);
        let output = render(tokens);
        assert!(output
            .contains("\"type\".to_string(),serde_json::Value::String(\"unknown\".to_string())"));
    }
}