rust_mcp_macros/
lib.rs

1extern crate proc_macro;
2
3mod utils;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{
8    parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Error, Expr,
9    ExprLit, Fields, Lit, Meta, Token,
10};
11use utils::{is_option, renamed_field, type_to_json_schema};
12
/// Represents the attributes for the `mcp_tool` procedural macro.
///
/// Holds the `name` and `description` values parsed from
/// `#[mcp_tool(name = "...", description = "...")]`. Both attributes are required and must not
/// be empty (or whitespace-only) strings; the `Parse` implementation rejects input that
/// violates this.
///
/// # Fields
/// * `name` - The tool's name. Stored as an `Option` so it can be filled in while the
///   attribute list is being parsed; it is always `Some` once parsing succeeds.
/// * `description` - A description of the tool. Likewise always `Some` after successful parsing.
struct MCPToolMacroAttributes {
    name: Option<String>,
    description: Option<String>,
}
26
27impl Parse for MCPToolMacroAttributes {
28    /// Parses the macro attributes from a `ParseStream`.
29    ///
30    /// This implementation extracts `name` and `description` from the attribute input,
31    /// ensuring they are provided as string literals and are non-empty.
32    ///
33    /// # Errors
34    /// Returns a `syn::Error` if:
35    /// - The `name` attribute is missing or empty.
36    /// - The `description` attribute is missing or empty.
37    fn parse(attributes: syn::parse::ParseStream) -> syn::Result<Self> {
38        let mut name = None;
39        let mut description = None;
40        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(attributes)?;
41        for meta in meta_list {
42            if let Meta::NameValue(meta_name_value) = meta {
43                let ident = meta_name_value.path.get_ident().unwrap();
44                if let Expr::Lit(ExprLit {
45                    lit: Lit::Str(lit_str),
46                    ..
47                }) = meta_name_value.value
48                {
49                    match ident.to_string().as_str() {
50                        "name" => name = Some(lit_str.value()),
51                        "description" => description = Some(lit_str.value()),
52                        _ => {}
53                    }
54                }
55            }
56        }
57        match &name {
58            Some(tool_name) => {
59                if tool_name.trim().is_empty() {
60                    return Err(Error::new(
61                        attributes.span(),
62                        "The 'name' attribute should not be an empty string.",
63                    ));
64                }
65            }
66            None => {
67                return Err(Error::new(
68                    attributes.span(),
69                    "The 'name' attribute is required.",
70                ));
71            }
72        }
73
74        match &description {
75            Some(description) => {
76                if description.trim().is_empty() {
77                    return Err(Error::new(
78                        attributes.span(),
79                        "The 'description' attribute should not be an empty string.",
80                    ));
81                }
82            }
83            None => {
84                return Err(Error::new(
85                    attributes.span(),
86                    "The 'description' attribute is required.",
87                ));
88            }
89        }
90
91        Ok(Self { name, description })
92    }
93}
94
95/// A procedural macro attribute to generate rust_mcp_schema::Tool related utility methods for a struct.
96///
97/// The `mcp_tool` macro generates an implementation for the annotated struct that includes:
98/// - A `tool_name()` method returning the tool's name as a string.
99/// - A `get_tool()` method returning a `rust_mcp_schema::Tool` instance with the tool's name,
100///   description, and input schema derived from the struct's fields.
101///
102/// # Attributes
103/// * `name` - The name of the tool (required, non-empty string).
104/// * `description` - A description of the tool (required, non-empty string).
105///
106/// # Panics
107/// Panics if the macro is applied to anything other than a struct.
108///
109/// # Example
110/// ```rust
111/// #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool")]
112/// #[derive(rust_mcp_macros::JsonSchema)]
113/// struct ExampleTool {
114///     field1: String,
115///     field2: i32,
116/// }
117///
118/// assert_eq!(ExampleTool::tool_name() , "example_tool");
119/// let tool : rust_mcp_schema::Tool = ExampleTool::get_tool();
120/// assert_eq!(tool.name , "example_tool");
121/// assert_eq!(tool.description.unwrap() , "An example tool");
122///
123/// let schema_properties = tool.input_schema.properties.unwrap();
124/// assert_eq!(schema_properties.len() , 2);
125/// assert!(schema_properties.contains_key("field1"));
126/// assert!(schema_properties.contains_key("field2"));
127///
128/// ```
129#[proc_macro_attribute]
130pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream {
131    let input = parse_macro_input!(input as DeriveInput); // Parse the input as a function
132    let input_ident = &input.ident;
133
134    let macro_attributes = parse_macro_input!(attributes as MCPToolMacroAttributes);
135
136    let tool_name = macro_attributes.name.unwrap_or_default();
137    let tool_description = macro_attributes.description.unwrap_or_default();
138
139    let output = quote! {
140        impl #input_ident {
141            /// Returns the name of the tool as a string.
142            pub fn tool_name()->String{
143                #tool_name.to_string()
144            }
145
146            /// Constructs and returns a `rust_mcp_schema::Tool` instance.
147            ///
148            /// The tool includes the name, description, and input schema derived from
149            /// the struct's attributes.
150            pub fn get_tool()-> rust_mcp_schema::Tool
151            {
152                let json_schema = &#input_ident::json_schema();
153
154                let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) {
155                    Some(arr) => arr
156                        .iter()
157                        .filter_map(|item| item.as_str().map(String::from))
158                        .collect(),
159                    None => Vec::new(), // Default to an empty vector if "required" is missing or not an array
160                };
161
162                let properties: Option<
163                    std::collections::HashMap<String, serde_json::Map<String, serde_json::Value>>,
164                > = json_schema
165                    .get("properties")
166                    .and_then(|v| v.as_object()) // Safely extract "properties" as an object.
167                    .map(|properties| {
168                        properties
169                            .iter()
170                            .filter_map(|(key, value)| {
171                                serde_json::to_value(value)
172                                    .ok() // If serialization fails, return None.
173                                    .and_then(|v| {
174                                        if let serde_json::Value::Object(obj) = v {
175                                            Some(obj)
176                                        } else {
177                                            None
178                                        }
179                                    })
180                                    .map(|obj| (key.to_string(), obj)) // Return the (key, value) tuple
181                            })
182                            .collect()
183                    });
184
185                rust_mcp_schema::Tool {
186                    name: #tool_name.to_string(),
187                    description: Some(#tool_description.to_string()),
188                    input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties),
189                }
190            }
191        }
192        // Retain the original item (struct definition)
193        #input
194    };
195
196    TokenStream::from(output)
197}
198
199/// Derives a JSON Schema representation for a struct.
200///
201/// This procedural macro generates a `json_schema()` method for the annotated struct, returning a
202/// `serde_json::Map<String, serde_json::Value>` that represents the struct as a JSON Schema object.
203/// The schema includes the struct's fields as properties, with support for basic types, `Option<T>`,
204/// `Vec<T>`, and nested structs that also derive `JsonSchema`.
205///
206/// # Features
207/// - **Basic Types:** Maps `String` to `"string"`, `i32` to `"integer"`, `bool` to `"boolean"`, etc.
208/// - **`Option<T>`:** Adds `"nullable": true` to the schema of the inner type, indicating the field is optional.
209/// - **`Vec<T>`:** Generates an `"array"` schema with an `"items"` field describing the inner type.
210/// - **Nested Structs:** Recursively includes the schema of nested structs (assumed to derive `JsonSchema`),
211///   embedding their `"properties"` and `"required"` fields.
212/// - **Required Fields:** Adds a top-level `"required"` array listing field names not wrapped in `Option`.
213///
214/// # Notes
215/// It’s designed as a straightforward solution to meet the basic needs of this package, supporting
216/// common types and simple nested structures. For more advanced features or robust JSON Schema generation,
217/// consider exploring established crates like
218/// [`schemars`](https://crates.io/crates/schemars) on crates.io
219///
220/// # Limitations
221/// - Supports only structs with named fields (e.g., `struct S { field: Type }`).
222/// - Nested structs must also derive `JsonSchema`, or compilation will fail.
223/// - Unknown types are mapped to `{"type": "unknown"}`.
224/// - Type paths must be in scope (e.g., fully qualified paths like `my_mod::InnerStruct` work if imported).
225///
226/// # Panics
227/// - If the input is not a struct with named fields (e.g., tuple structs or enums).
228///
229/// # Dependencies
230/// Relies on `serde_json` for `Map` and `Value` types.
231///
232#[proc_macro_derive(JsonSchema)]
233pub fn derive_json_schema(input: TokenStream) -> TokenStream {
234    let input = parse_macro_input!(input as DeriveInput);
235    let name = &input.ident;
236
237    let fields = match &input.data {
238        Data::Struct(data) => match &data.fields {
239            Fields::Named(fields) => &fields.named,
240            _ => panic!("JsonSchema derive macro only supports named fields"),
241        },
242        _ => panic!("JsonSchema derive macro only supports structs"),
243    };
244
245    let field_entries = fields.iter().map(|field| {
246        let field_attrs = &field.attrs;
247        let renamed_field = renamed_field(field_attrs);
248        let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
249        let field_type = &field.ty;
250
251        let schema = type_to_json_schema(field_type, field_attrs);
252        quote! {
253            properties.insert(
254                #field_name.to_string(),
255                serde_json::Value::Object(#schema)
256            );
257        }
258    });
259
260    let required_fields = fields.iter().filter_map(|field| {
261        let renamed_field = renamed_field(&field.attrs);
262        let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
263
264        let field_type = &field.ty;
265        if !is_option(field_type) {
266            Some(quote! {
267                required.push(#field_name.to_string());
268            })
269        } else {
270            None
271        }
272    });
273
274    let expanded = quote! {
275        impl #name {
276            pub fn json_schema() -> serde_json::Map<String, serde_json::Value> {
277                let mut schema = serde_json::Map::new();
278                let mut properties = serde_json::Map::new();
279                let mut required = Vec::new();
280
281                #(#field_entries)*
282
283                #(#required_fields)*
284
285                schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
286                schema.insert("properties".to_string(), serde_json::Value::Object(properties));
287                if !required.is_empty() {
288                    schema.insert("required".to_string(), serde_json::Value::Array(
289                        required.into_iter().map(serde_json::Value::String).collect()
290                    ));
291                }
292
293                schema
294            }
295        }
296    };
297    TokenStream::from(expanded)
298}