rust_mcp_macros/
lib.rs

1extern crate proc_macro;
2
3mod utils;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{
8    parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Error, Expr,
9    ExprLit, Fields, Lit, Meta, Token,
10};
11use utils::{is_option, renamed_field, type_to_json_schema};
12
/// Attributes accepted by the `mcp_tool` procedural macro.
///
/// Both `name` and `description` are mandatory in the macro invocation. They are
/// stored as `Option`s only because they are filled in one by one while the
/// attribute list is parsed. A value successfully produced by
/// `McpToolMacroAttributes::parse` always has both fields set to `Some` string that
/// is neither empty nor whitespace-only.
///
/// # Fields
/// * `name` - The tool's name, from `name = "..."` or `name = concat!(...)`.
/// * `description` - The tool's description, from `description = "..."` or
///   `description = concat!(...)`.
struct McpToolMacroAttributes {
    name: Option<String>,
    description: Option<String>,
}
26
27use syn::parse::ParseStream;
28
29struct ExprList {
30    exprs: Punctuated<Expr, Token![,]>,
31}
32
33impl Parse for ExprList {
34    fn parse(input: ParseStream) -> syn::Result<Self> {
35        Ok(ExprList {
36            exprs: Punctuated::parse_terminated(input)?,
37        })
38    }
39}
40
41impl Parse for McpToolMacroAttributes {
42    /// Parses the macro attributes from a `ParseStream`.
43    ///
44    /// This implementation extracts `name` and `description` from the attribute input,
45    /// ensuring they are provided as string literals and are non-empty.
46    ///
47    /// # Errors
48    /// Returns a `syn::Error` if:
49    /// - The `name` attribute is missing or empty.
50    /// - The `description` attribute is missing or empty.
51    fn parse(attributes: syn::parse::ParseStream) -> syn::Result<Self> {
52        let mut name = None;
53        let mut description = None;
54        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(attributes)?;
55        for meta in meta_list {
56            if let Meta::NameValue(meta_name_value) = meta {
57                let ident = meta_name_value.path.get_ident().unwrap();
58                let ident_str = ident.to_string();
59
60                let value = match &meta_name_value.value {
61                    Expr::Lit(ExprLit {
62                        lit: Lit::Str(lit_str),
63                        ..
64                    }) => lit_str.value(),
65
66                    Expr::Macro(expr_macro) => {
67                        let mac = &expr_macro.mac;
68                        if mac.path.is_ident("concat") {
69                            let args: ExprList = syn::parse2(mac.tokens.clone())?;
70                            let mut result = String::new();
71
72                            for expr in args.exprs {
73                                if let Expr::Lit(ExprLit {
74                                    lit: Lit::Str(lit_str),
75                                    ..
76                                }) = expr
77                                {
78                                    result.push_str(&lit_str.value());
79                                } else {
80                                    return Err(Error::new_spanned(
81                                        expr,
82                                        "Only string literals are allowed inside concat!()",
83                                    ));
84                                }
85                            }
86
87                            result
88                        } else {
89                            return Err(Error::new_spanned(
90                                expr_macro,
91                                "Only concat!(...) is supported here",
92                            ));
93                        }
94                    }
95
96                    _ => {
97                        return Err(Error::new_spanned(
98                            &meta_name_value.value,
99                            "Expected a string literal or concat!(...)",
100                        ));
101                    }
102                };
103
104                match ident_str.as_str() {
105                    "name" => name = Some(value),
106                    "description" => description = Some(value),
107                    _ => {}
108                }
109            }
110        }
111
112        // Validate presence and non-emptiness
113        if name.as_ref().map(|s| s.trim().is_empty()).unwrap_or(true) {
114            return Err(Error::new(
115                attributes.span(),
116                "The 'name' attribute is required and must not be empty.",
117            ));
118        }
119
120        if description
121            .as_ref()
122            .map(|s| s.trim().is_empty())
123            .unwrap_or(true)
124        {
125            return Err(Error::new(
126                attributes.span(),
127                "The 'description' attribute is required and must not be empty.",
128            ));
129        }
130
131        Ok(Self { name, description })
132    }
133}
134
135/// A procedural macro attribute to generate rust_mcp_schema::Tool related utility methods for a struct.
136///
137/// The `mcp_tool` macro generates an implementation for the annotated struct that includes:
138/// - A `tool_name()` method returning the tool's name as a string.
139/// - A `tool()` method returning a `rust_mcp_schema::Tool` instance with the tool's name,
140///   description, and input schema derived from the struct's fields.
141///
142/// # Attributes
143/// * `name` - The name of the tool (required, non-empty string).
144/// * `description` - A description of the tool (required, non-empty string).
145///
146/// # Panics
147/// Panics if the macro is applied to anything other than a struct.
148///
149/// # Example
150/// ```rust
151/// #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool")]
152/// #[derive(rust_mcp_macros::JsonSchema)]
153/// struct ExampleTool {
154///     field1: String,
155///     field2: i32,
156/// }
157///
158/// assert_eq!(ExampleTool::tool_name() , "example_tool");
159/// let tool : rust_mcp_schema::Tool = ExampleTool::tool();
160/// assert_eq!(tool.name , "example_tool");
161/// assert_eq!(tool.description.unwrap() , "An example tool");
162///
163/// let schema_properties = tool.input_schema.properties.unwrap();
164/// assert_eq!(schema_properties.len() , 2);
165/// assert!(schema_properties.contains_key("field1"));
166/// assert!(schema_properties.contains_key("field2"));
167///
168/// ```
169#[proc_macro_attribute]
170pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream {
171    let input = parse_macro_input!(input as DeriveInput); // Parse the input as a function
172    let input_ident = &input.ident;
173
174    let macro_attributes = parse_macro_input!(attributes as McpToolMacroAttributes);
175
176    let tool_name = macro_attributes.name.unwrap_or_default();
177    let tool_description = macro_attributes.description.unwrap_or_default();
178
179    let output = quote! {
180        impl #input_ident {
181            /// Returns the name of the tool as a string.
182            pub fn tool_name()->String{
183                #tool_name.to_string()
184            }
185
186            /// Constructs and returns a `rust_mcp_schema::Tool` instance.
187            ///
188            /// The tool includes the name, description, and input schema derived from
189            /// the struct's attributes.
190            pub fn tool()-> rust_mcp_schema::Tool
191            {
192                let json_schema = &#input_ident::json_schema();
193
194                let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) {
195                    Some(arr) => arr
196                        .iter()
197                        .filter_map(|item| item.as_str().map(String::from))
198                        .collect(),
199                    None => Vec::new(), // Default to an empty vector if "required" is missing or not an array
200                };
201
202                let properties: Option<
203                    std::collections::HashMap<String, serde_json::Map<String, serde_json::Value>>,
204                > = json_schema
205                    .get("properties")
206                    .and_then(|v| v.as_object()) // Safely extract "properties" as an object.
207                    .map(|properties| {
208                        properties
209                            .iter()
210                            .filter_map(|(key, value)| {
211                                serde_json::to_value(value)
212                                    .ok() // If serialization fails, return None.
213                                    .and_then(|v| {
214                                        if let serde_json::Value::Object(obj) = v {
215                                            Some(obj)
216                                        } else {
217                                            None
218                                        }
219                                    })
220                                    .map(|obj| (key.to_string(), obj)) // Return the (key, value) tuple
221                            })
222                            .collect()
223                    });
224
225                rust_mcp_schema::Tool {
226                    name: #tool_name.to_string(),
227                    description: Some(#tool_description.to_string()),
228                    input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties),
229                }
230            }
231
232            #[deprecated(since = "0.2.0", note = "Use `tool()` instead.")]
233            pub fn get_tool()-> rust_mcp_schema::Tool
234            {
235                let json_schema = &#input_ident::json_schema();
236
237                let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) {
238                    Some(arr) => arr
239                        .iter()
240                        .filter_map(|item| item.as_str().map(String::from))
241                        .collect(),
242                    None => Vec::new(), // Default to an empty vector if "required" is missing or not an array
243                };
244
245                let properties: Option<
246                    std::collections::HashMap<String, serde_json::Map<String, serde_json::Value>>,
247                > = json_schema
248                    .get("properties")
249                    .and_then(|v| v.as_object()) // Safely extract "properties" as an object.
250                    .map(|properties| {
251                        properties
252                            .iter()
253                            .filter_map(|(key, value)| {
254                                serde_json::to_value(value)
255                                    .ok() // If serialization fails, return None.
256                                    .and_then(|v| {
257                                        if let serde_json::Value::Object(obj) = v {
258                                            Some(obj)
259                                        } else {
260                                            None
261                                        }
262                                    })
263                                    .map(|obj| (key.to_string(), obj)) // Return the (key, value) tuple
264                            })
265                            .collect()
266                    });
267
268                rust_mcp_schema::Tool {
269                    name: #tool_name.to_string(),
270                    description: Some(#tool_description.to_string()),
271                    input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties),
272                }
273            }
274        }
275        // Retain the original item (struct definition)
276        #input
277    };
278
279    TokenStream::from(output)
280}
281
282/// Derives a JSON Schema representation for a struct.
283///
284/// This procedural macro generates a `json_schema()` method for the annotated struct, returning a
285/// `serde_json::Map<String, serde_json::Value>` that represents the struct as a JSON Schema object.
286/// The schema includes the struct's fields as properties, with support for basic types, `Option<T>`,
287/// `Vec<T>`, and nested structs that also derive `JsonSchema`.
288///
289/// # Features
290/// - **Basic Types:** Maps `String` to `"string"`, `i32` to `"integer"`, `bool` to `"boolean"`, etc.
291/// - **`Option<T>`:** Adds `"nullable": true` to the schema of the inner type, indicating the field is optional.
292/// - **`Vec<T>`:** Generates an `"array"` schema with an `"items"` field describing the inner type.
293/// - **Nested Structs:** Recursively includes the schema of nested structs (assumed to derive `JsonSchema`),
294///   embedding their `"properties"` and `"required"` fields.
295/// - **Required Fields:** Adds a top-level `"required"` array listing field names not wrapped in `Option`.
296///
297/// # Notes
298/// It’s designed as a straightforward solution to meet the basic needs of this package, supporting
299/// common types and simple nested structures. For more advanced features or robust JSON Schema generation,
300/// consider exploring established crates like
301/// [`schemars`](https://crates.io/crates/schemars) on crates.io
302///
303/// # Limitations
304/// - Supports only structs with named fields (e.g., `struct S { field: Type }`).
305/// - Nested structs must also derive `JsonSchema`, or compilation will fail.
306/// - Unknown types are mapped to `{"type": "unknown"}`.
307/// - Type paths must be in scope (e.g., fully qualified paths like `my_mod::InnerStruct` work if imported).
308///
309/// # Panics
310/// - If the input is not a struct with named fields (e.g., tuple structs or enums).
311///
312/// # Dependencies
313/// Relies on `serde_json` for `Map` and `Value` types.
314///
315#[proc_macro_derive(JsonSchema)]
316pub fn derive_json_schema(input: TokenStream) -> TokenStream {
317    let input = parse_macro_input!(input as DeriveInput);
318    let name = &input.ident;
319
320    let fields = match &input.data {
321        Data::Struct(data) => match &data.fields {
322            Fields::Named(fields) => &fields.named,
323            _ => panic!("JsonSchema derive macro only supports named fields"),
324        },
325        _ => panic!("JsonSchema derive macro only supports structs"),
326    };
327
328    let field_entries = fields.iter().map(|field| {
329        let field_attrs = &field.attrs;
330        let renamed_field = renamed_field(field_attrs);
331        let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
332        let field_type = &field.ty;
333
334        let schema = type_to_json_schema(field_type, field_attrs);
335        quote! {
336            properties.insert(
337                #field_name.to_string(),
338                serde_json::Value::Object(#schema)
339            );
340        }
341    });
342
343    let required_fields = fields.iter().filter_map(|field| {
344        let renamed_field = renamed_field(&field.attrs);
345        let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
346
347        let field_type = &field.ty;
348        if !is_option(field_type) {
349            Some(quote! {
350                required.push(#field_name.to_string());
351            })
352        } else {
353            None
354        }
355    });
356
357    let expanded = quote! {
358        impl #name {
359            pub fn json_schema() -> serde_json::Map<String, serde_json::Value> {
360                let mut schema = serde_json::Map::new();
361                let mut properties = serde_json::Map::new();
362                let mut required = Vec::new();
363
364                #(#field_entries)*
365
366                #(#required_fields)*
367
368                schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
369                schema.insert("properties".to_string(), serde_json::Value::Object(properties));
370                if !required.is_empty() {
371                    schema.insert("required".to_string(), serde_json::Value::Array(
372                        required.into_iter().map(serde_json::Value::String).collect()
373                    ));
374                }
375
376                schema
377            }
378        }
379    };
380    TokenStream::from(expanded)
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_str;

    /// Parses `input` as `mcp_tool` attributes and returns the error message.
    /// Fails the test if parsing unexpectedly succeeds.
    fn parse_error_message(input: &str) -> String {
        match parse_str::<McpToolMacroAttributes>(input) {
            Ok(_) => panic!("expected `{}` to be rejected", input),
            Err(err) => err.to_string(),
        }
    }

    #[test]
    fn test_valid_macro_attributes() {
        let input = r#"name = "test_tool", description = "A test tool.""#;
        let parsed: McpToolMacroAttributes = parse_str(input).unwrap();

        assert_eq!(parsed.name.unwrap(), "test_tool");
        assert_eq!(parsed.description.unwrap(), "A test tool.");
    }

    #[test]
    fn test_missing_name() {
        assert_eq!(
            parse_error_message(r#"description = "Only description""#),
            "The 'name' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_missing_description() {
        assert_eq!(
            parse_error_message(r#"name = "OnlyName""#),
            "The 'description' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_empty_name_field() {
        assert_eq!(
            parse_error_message(r#"name = "", description = "something""#),
            "The 'name' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_whitespace_only_name_field() {
        assert_eq!(
            parse_error_message(r#"name = "   ", description = "something""#),
            "The 'name' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_empty_description_field() {
        assert_eq!(
            parse_error_message(r#"name = "my-tool", description = """#),
            "The 'description' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_concat_values() {
        let input =
            r#"name = concat!("test", "_tool"), description = concat!("A ", "test ", "tool.")"#;
        let parsed: McpToolMacroAttributes = parse_str(input).unwrap();

        assert_eq!(parsed.name.unwrap(), "test_tool");
        assert_eq!(parsed.description.unwrap(), "A test tool.");
    }

    #[test]
    fn test_concat_rejects_non_string_literal() {
        assert_eq!(
            parse_error_message(r#"name = concat!("tool", 1), description = "d""#),
            "Only string literals are allowed inside concat!()"
        );
    }

    #[test]
    fn test_non_concat_macro_rejected() {
        assert_eq!(
            parse_error_message(r#"name = format!("tool"), description = "d""#),
            "Only concat!(...) is supported here"
        );
    }

    #[test]
    fn test_non_string_value_rejected() {
        assert_eq!(
            parse_error_message(r#"name = 42, description = "d""#),
            "Expected a string literal or concat!(...)"
        );
    }
}