rust_mcp_macros/
lib.rs

1extern crate proc_macro;
2
3mod utils;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{
8    parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Error, Expr,
9    ExprLit, Fields, Lit, Meta, Token,
10};
11use utils::{is_option, renamed_field, type_to_json_schema};
12
/// Attributes accepted by the `mcp_tool` procedural macro.
///
/// Values are filled in from the macro's comma-separated `key = value` list by this type's
/// `syn::parse::Parse` implementation.
///
/// `name` and `description` are required and must not be blank. They are stored as `Option`
/// only because they are filled in one attribute at a time while parsing. Parsing fails if
/// either is still missing (or blank) once the whole list has been read.
///
/// # Fields
/// * `name` - The tool's name.
/// * `description` - A description of the tool.
///
/// The following fields exist only with the `2025_03_26` feature. Each is `None` unless the
/// corresponding attribute was given:
/// * `destructive_hint` - Optional boolean for `ToolAnnotations::destructive_hint`.
/// * `idempotent_hint` - Optional boolean for `ToolAnnotations::idempotent_hint`.
/// * `open_world_hint` - Optional boolean for `ToolAnnotations::open_world_hint`.
/// * `read_only_hint` - Optional boolean for `ToolAnnotations::read_only_hint`.
/// * `title` - Optional string for `ToolAnnotations::title`.
struct McpToolMacroAttributes {
    name: Option<String>,
    description: Option<String>,
    #[cfg(feature = "2025_03_26")]
    destructive_hint: Option<bool>,
    #[cfg(feature = "2025_03_26")]
    idempotent_hint: Option<bool>,
    #[cfg(feature = "2025_03_26")]
    open_world_hint: Option<bool>,
    #[cfg(feature = "2025_03_26")]
    read_only_hint: Option<bool>,
    #[cfg(feature = "2025_03_26")]
    title: Option<String>,
}
43
44use syn::parse::ParseStream;
45
46struct ExprList {
47    exprs: Punctuated<Expr, Token![,]>,
48}
49
50impl Parse for ExprList {
51    fn parse(input: ParseStream) -> syn::Result<Self> {
52        Ok(ExprList {
53            exprs: Punctuated::parse_terminated(input)?,
54        })
55    }
56}
57
58impl Parse for McpToolMacroAttributes {
59    /// Parses the macro attributes from a `ParseStream`.
60    ///
61    /// This implementation extracts `name` and `description` from the attribute input,
62    /// ensuring they are provided as string literals and are non-empty.
63    ///
64    /// # Errors
65    /// Returns a `syn::Error` if:
66    /// - The `name` attribute is missing or empty.
67    /// - The `description` attribute is missing or empty.
68    fn parse(attributes: syn::parse::ParseStream) -> syn::Result<Self> {
69        let mut instance = Self {
70            name: None,
71            description: None,
72            #[cfg(feature = "2025_03_26")]
73            destructive_hint: None,
74            #[cfg(feature = "2025_03_26")]
75            idempotent_hint: None,
76            #[cfg(feature = "2025_03_26")]
77            open_world_hint: None,
78            #[cfg(feature = "2025_03_26")]
79            read_only_hint: None,
80            #[cfg(feature = "2025_03_26")]
81            title: None,
82        };
83
84        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(attributes)?;
85        for meta in meta_list {
86            if let Meta::NameValue(meta_name_value) = meta {
87                let ident = meta_name_value.path.get_ident().unwrap();
88                let ident_str = ident.to_string();
89
90                match ident_str.as_str() {
91                    "name" | "description" => {
92                        let value = match &meta_name_value.value {
93                            Expr::Lit(ExprLit {
94                                lit: Lit::Str(lit_str),
95                                ..
96                            }) => lit_str.value(),
97                            Expr::Macro(expr_macro) => {
98                                let mac = &expr_macro.mac;
99                                if mac.path.is_ident("concat") {
100                                    let args: ExprList = syn::parse2(mac.tokens.clone())?;
101                                    let mut result = String::new();
102                                    for expr in args.exprs {
103                                        if let Expr::Lit(ExprLit {
104                                            lit: Lit::Str(lit_str),
105                                            ..
106                                        }) = expr
107                                        {
108                                            result.push_str(&lit_str.value());
109                                        } else {
110                                            return Err(Error::new_spanned(
111                                                expr,
112                                                "Only string literals are allowed inside concat!()",
113                                            ));
114                                        }
115                                    }
116                                    result
117                                } else {
118                                    return Err(Error::new_spanned(
119                                        expr_macro,
120                                        "Only concat!(...) is supported here",
121                                    ));
122                                }
123                            }
124                            _ => {
125                                return Err(Error::new_spanned(
126                                    &meta_name_value.value,
127                                    "Expected a string literal or concat!(...)",
128                                ));
129                            }
130                        };
131                        match ident_str.as_str() {
132                            "name" => instance.name = Some(value),
133                            "description" => instance.description = Some(value),
134                            _ => {}
135                        }
136                    }
137                    "destructive_hint" | "idempotent_hint" | "open_world_hint"
138                    | "read_only_hint" => {
139                        #[cfg(feature = "2025_03_26")]
140                        {
141                            let value = match &meta_name_value.value {
142                                Expr::Lit(ExprLit {
143                                    lit: Lit::Bool(lit_bool),
144                                    ..
145                                }) => lit_bool.value,
146                                _ => {
147                                    return Err(Error::new_spanned(
148                                        &meta_name_value.value,
149                                        "Expected a boolean literal",
150                                    ));
151                                }
152                            };
153
154                            match ident_str.as_str() {
155                                "destructive_hint" => instance.destructive_hint = Some(value),
156                                "idempotent_hint" => instance.idempotent_hint = Some(value),
157                                "open_world_hint" => instance.open_world_hint = Some(value),
158                                "read_only_hint" => instance.read_only_hint = Some(value),
159                                _ => {}
160                            }
161                        }
162                    }
163                    #[cfg(feature = "2025_03_26")]
164                    "title" => {
165                        let value = match &meta_name_value.value {
166                            Expr::Lit(ExprLit {
167                                lit: Lit::Str(lit_str),
168                                ..
169                            }) => lit_str.value(),
170                            _ => {
171                                return Err(Error::new_spanned(
172                                    &meta_name_value.value,
173                                    "Expected a string literal",
174                                ));
175                            }
176                        };
177                        instance.title = Some(value);
178                    }
179                    _ => {}
180                }
181            }
182        }
183
184        // Validate presence and non-emptiness
185        if instance
186            .name
187            .as_ref()
188            .map(|s| s.trim().is_empty())
189            .unwrap_or(true)
190        {
191            return Err(Error::new(
192                attributes.span(),
193                "The 'name' attribute is required and must not be empty.",
194            ));
195        }
196        if instance
197            .description
198            .as_ref()
199            .map(|s| s.trim().is_empty())
200            .unwrap_or(true)
201        {
202            return Err(Error::new(
203                attributes.span(),
204                "The 'description' attribute is required and must not be empty.",
205            ));
206        }
207
208        Ok(instance)
209    }
210}
211
212/// A procedural macro attribute to generate rust_mcp_schema::Tool related utility methods for a struct.
213///
214/// The `mcp_tool` macro generates an implementation for the annotated struct that includes:
215/// - A `tool_name()` method returning the tool's name as a string.
216/// - A `tool()` method returning a `rust_mcp_schema::Tool` instance with the tool's name,
217///   description, and input schema derived from the struct's fields.
218///
219/// # Attributes
220/// * `name` - The name of the tool (required, non-empty string).
221/// * `description` - A description of the tool (required, non-empty string).
222///
223/// # Panics
224/// Panics if the macro is applied to anything other than a struct.
225///
226/// # Example
227/// ```rust
228/// #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool", idempotent_hint=true )]
229/// #[derive(rust_mcp_macros::JsonSchema)]
230/// struct ExampleTool {
231///     field1: String,
232///     field2: i32,
233/// }
234///
235/// assert_eq!(ExampleTool::tool_name() , "example_tool");
236/// let tool : rust_mcp_schema::Tool = ExampleTool::tool();
237/// assert_eq!(tool.name , "example_tool");
238/// assert_eq!(tool.description.unwrap() , "An example tool");
239/// assert_eq!(tool.annotations.unwrap().idempotent_hint.unwrap() , true);
240///
241/// let schema_properties = tool.input_schema.properties.unwrap();
242/// assert_eq!(schema_properties.len() , 2);
243/// assert!(schema_properties.contains_key("field1"));
244/// assert!(schema_properties.contains_key("field2"));
245///
246/// ```
247#[proc_macro_attribute]
248pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream {
249    let input = parse_macro_input!(input as DeriveInput); // Parse the input as a function
250    let input_ident = &input.ident;
251
252    // Conditionally select the path for Tool
253    let base_crate = if cfg!(feature = "sdk") {
254        quote! { rust_mcp_sdk::schema }
255    } else {
256        quote! { rust_mcp_schema }
257    };
258
259    let macro_attributes = parse_macro_input!(attributes as McpToolMacroAttributes);
260
261    let tool_name = macro_attributes.name.unwrap_or_default();
262    let tool_description = macro_attributes.description.unwrap_or_default();
263
264    #[cfg(feature = "2025_03_26")]
265    let some_annotations = macro_attributes.destructive_hint.is_some()
266        || macro_attributes.idempotent_hint.is_some()
267        || macro_attributes.open_world_hint.is_some()
268        || macro_attributes.read_only_hint.is_some()
269        || macro_attributes.title.is_some();
270
271    #[cfg(feature = "2025_03_26")]
272    let annotations = if some_annotations {
273        let destructive_hint = macro_attributes
274            .destructive_hint
275            .map_or(quote! {None}, |v| quote! {Some(#v)});
276
277        let idempotent_hint = macro_attributes
278            .idempotent_hint
279            .map_or(quote! {None}, |v| quote! {Some(#v)});
280        let open_world_hint = macro_attributes
281            .open_world_hint
282            .map_or(quote! {None}, |v| quote! {Some(#v)});
283        let read_only_hint = macro_attributes
284            .read_only_hint
285            .map_or(quote! {None}, |v| quote! {Some(#v)});
286        let title = macro_attributes
287            .title
288            .map_or(quote! {None}, |v| quote! {Some(#v)});
289        quote! {
290            Some(#base_crate::ToolAnnotations {
291                                    destructive_hint: #destructive_hint,
292                                    idempotent_hint: #idempotent_hint,
293                                    open_world_hint: #open_world_hint,
294                                    read_only_hint: #read_only_hint,
295                                    title: #title,
296                                }),
297        }
298    } else {
299        quote! {None}
300    };
301
302    let annotations_token = {
303        #[cfg(feature = "2025_03_26")]
304        {
305            quote! { annotations: #annotations }
306        }
307        #[cfg(not(feature = "2025_03_26"))]
308        {
309            quote! {}
310        }
311    };
312
313    let tool_token = quote! {
314        #base_crate::Tool {
315            name: #tool_name.to_string(),
316            description: Some(#tool_description.to_string()),
317            input_schema: #base_crate::ToolInputSchema::new(required, properties),
318            #annotations_token
319        }
320    };
321
322    let output = quote! {
323        impl #input_ident {
324            /// Returns the name of the tool as a string.
325            pub fn tool_name()->String{
326                #tool_name.to_string()
327            }
328
329            /// Constructs and returns a `rust_mcp_schema::Tool` instance.
330            ///
331            /// The tool includes the name, description, and input schema derived from
332            /// the struct's attributes.
333            pub fn tool()-> #base_crate::Tool
334            {
335                let json_schema = &#input_ident::json_schema();
336
337                let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) {
338                    Some(arr) => arr
339                        .iter()
340                        .filter_map(|item| item.as_str().map(String::from))
341                        .collect(),
342                    None => Vec::new(), // Default to an empty vector if "required" is missing or not an array
343                };
344
345                let properties: Option<
346                    std::collections::HashMap<String, serde_json::Map<String, serde_json::Value>>,
347                > = json_schema
348                    .get("properties")
349                    .and_then(|v| v.as_object()) // Safely extract "properties" as an object.
350                    .map(|properties| {
351                        properties
352                            .iter()
353                            .filter_map(|(key, value)| {
354                                serde_json::to_value(value)
355                                    .ok() // If serialization fails, return None.
356                                    .and_then(|v| {
357                                        if let serde_json::Value::Object(obj) = v {
358                                            Some(obj)
359                                        } else {
360                                            None
361                                        }
362                                    })
363                                    .map(|obj| (key.to_string(), obj)) // Return the (key, value) tuple
364                            })
365                            .collect()
366                    });
367
368                #tool_token
369            }
370        }
371        // Retain the original item (struct definition)
372        #input
373    };
374
375    TokenStream::from(output)
376}
377
378/// Derives a JSON Schema representation for a struct.
379///
380/// This procedural macro generates a `json_schema()` method for the annotated struct, returning a
381/// `serde_json::Map<String, serde_json::Value>` that represents the struct as a JSON Schema object.
382/// The schema includes the struct's fields as properties, with support for basic types, `Option<T>`,
383/// `Vec<T>`, and nested structs that also derive `JsonSchema`.
384///
385/// # Features
386/// - **Basic Types:** Maps `String` to `"string"`, `i32` to `"integer"`, `bool` to `"boolean"`, etc.
387/// - **`Option<T>`:** Adds `"nullable": true` to the schema of the inner type, indicating the field is optional.
388/// - **`Vec<T>`:** Generates an `"array"` schema with an `"items"` field describing the inner type.
389/// - **Nested Structs:** Recursively includes the schema of nested structs (assumed to derive `JsonSchema`),
390///   embedding their `"properties"` and `"required"` fields.
391/// - **Required Fields:** Adds a top-level `"required"` array listing field names not wrapped in `Option`.
392///
393/// # Notes
394/// It’s designed as a straightforward solution to meet the basic needs of this package, supporting
395/// common types and simple nested structures. For more advanced features or robust JSON Schema generation,
396/// consider exploring established crates like
397/// [`schemars`](https://crates.io/crates/schemars) on crates.io
398///
399/// # Limitations
400/// - Supports only structs with named fields (e.g., `struct S { field: Type }`).
401/// - Nested structs must also derive `JsonSchema`, or compilation will fail.
402/// - Unknown types are mapped to `{"type": "unknown"}`.
403/// - Type paths must be in scope (e.g., fully qualified paths like `my_mod::InnerStruct` work if imported).
404///
405/// # Panics
406/// - If the input is not a struct with named fields (e.g., tuple structs or enums).
407///
408/// # Dependencies
409/// Relies on `serde_json` for `Map` and `Value` types.
410///
411#[proc_macro_derive(JsonSchema)]
412pub fn derive_json_schema(input: TokenStream) -> TokenStream {
413    let input = parse_macro_input!(input as DeriveInput);
414    let name = &input.ident;
415
416    let fields = match &input.data {
417        Data::Struct(data) => match &data.fields {
418            Fields::Named(fields) => &fields.named,
419            _ => panic!("JsonSchema derive macro only supports named fields"),
420        },
421        _ => panic!("JsonSchema derive macro only supports structs"),
422    };
423
424    let field_entries = fields.iter().map(|field| {
425        let field_attrs = &field.attrs;
426        let renamed_field = renamed_field(field_attrs);
427        let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
428        let field_type = &field.ty;
429
430        let schema = type_to_json_schema(field_type, field_attrs);
431        quote! {
432            properties.insert(
433                #field_name.to_string(),
434                serde_json::Value::Object(#schema)
435            );
436        }
437    });
438
439    let required_fields = fields.iter().filter_map(|field| {
440        let renamed_field = renamed_field(&field.attrs);
441        let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
442
443        let field_type = &field.ty;
444        if !is_option(field_type) {
445            Some(quote! {
446                required.push(#field_name.to_string());
447            })
448        } else {
449            None
450        }
451    });
452
453    let expanded = quote! {
454        impl #name {
455            pub fn json_schema() -> serde_json::Map<String, serde_json::Value> {
456                let mut schema = serde_json::Map::new();
457                let mut properties = serde_json::Map::new();
458                let mut required = Vec::new();
459
460                #(#field_entries)*
461
462                #(#required_fields)*
463
464                schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
465                schema.insert("properties".to_string(), serde_json::Value::Object(properties));
466                if !required.is_empty() {
467                    schema.insert("required".to_string(), serde_json::Value::Array(
468                        required.into_iter().map(serde_json::Value::String).collect()
469                    ));
470                }
471
472                schema
473            }
474        }
475    };
476    TokenStream::from(expanded)
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_str;

    /// Parses `input` as `mcp_tool` attributes and returns the error message.
    /// Panics if parsing unexpectedly succeeds.
    fn parse_error(input: &str) -> String {
        match parse_str::<McpToolMacroAttributes>(input) {
            Ok(_) => panic!("expected `{}` to be rejected", input),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn test_valid_macro_attributes() {
        let input = r#"name = "test_tool", description = "A test tool.""#;
        let parsed: McpToolMacroAttributes = parse_str(input).unwrap();

        assert_eq!(parsed.name.unwrap(), "test_tool");
        assert_eq!(parsed.description.unwrap(), "A test tool.");
    }

    #[test]
    fn test_concat_values() {
        let input = r#"name = concat!("my", "_tool"), description = concat!("A ", "tool.")"#;
        let parsed: McpToolMacroAttributes = parse_str(input).unwrap();

        assert_eq!(parsed.name.unwrap(), "my_tool");
        assert_eq!(parsed.description.unwrap(), "A tool.");
    }

    #[test]
    fn test_missing_name() {
        assert_eq!(
            parse_error(r#"description = "Only description""#),
            "The 'name' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_missing_description() {
        assert_eq!(
            parse_error(r#"name = "OnlyName""#),
            "The 'description' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_empty_name_field() {
        assert_eq!(
            parse_error(r#"name = "", description = "something""#),
            "The 'name' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_blank_name_field() {
        assert_eq!(
            parse_error(r#"name = "   ", description = "something""#),
            "The 'name' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_empty_description_field() {
        assert_eq!(
            parse_error(r#"name = "my-tool", description = """#),
            "The 'description' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_non_string_value() {
        assert_eq!(
            parse_error(r#"name = 42, description = "something""#),
            "Expected a string literal or concat!(...)"
        );
    }

    #[test]
    fn test_unsupported_macro_value() {
        assert_eq!(
            parse_error(r#"name = format!("x"), description = "something""#),
            "Only concat!(...) is supported here"
        );
    }

    #[test]
    fn test_concat_with_non_literal() {
        assert_eq!(
            parse_error(r#"name = concat!("a", 1), description = "something""#),
            "Only string literals are allowed inside concat!()"
        );
    }
}