rust_mcp_macros/
lib.rs

1extern crate proc_macro;
2
3mod utils;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{
8    parse::Parse, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Error, Expr,
9    ExprLit, Fields, GenericArgument, Lit, Meta, PathArguments, Token, Type,
10};
11use utils::{is_option, renamed_field, type_to_json_schema};
12
/// Attributes accepted by the `mcp_tool` procedural macro.
///
/// Instances are produced by this type's `Parse` implementation. `name` and `description` are
/// `Option`s only while parsing is in progress; a successful parse rejects a missing value or
/// one that is empty / whitespace-only, so both are `Some` afterwards.
///
/// # Fields
/// * `name` - The tool's name (required). Accepts a string literal or a `concat!(...)` of
///   string literals.
/// * `description` - The tool's description (required). Same accepted forms as `name`.
/// * Available only with the `2025_06_18` feature:
///   * `meta` - Optional metadata, stored as the raw text of a JSON object.
///   * `title` - Optional human-readable title for the tool.
/// * Available with the `2025_03_26` feature and later (`2025_03_26` or `2025_06_18`):
///   * `destructive_hint` - Optional boolean for `ToolAnnotations::destructive_hint`.
///   * `idempotent_hint` - Optional boolean for `ToolAnnotations::idempotent_hint`.
///   * `open_world_hint` - Optional boolean for `ToolAnnotations::open_world_hint`.
///   * `read_only_hint` - Optional boolean for `ToolAnnotations::read_only_hint`.
struct McpToolMacroAttributes {
    name: Option<String>,
    description: Option<String>,
    // Raw JSON text (validated as an object while parsing) rather than a parsed map, so it can
    // be re-embedded verbatim into the generated code.
    #[cfg(feature = "2025_06_18")]
    meta: Option<String>,
    #[cfg(feature = "2025_06_18")]
    title: Option<String>,
    #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
    destructive_hint: Option<bool>,
    #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
    idempotent_hint: Option<bool>,
    #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
    open_world_hint: Option<bool>,
    #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
    read_only_hint: Option<bool>,
}
45
46use syn::parse::ParseStream;
47
48use crate::utils::{generate_enum_parse, is_enum};
49
50struct ExprList {
51    exprs: Punctuated<Expr, Token![,]>,
52}
53
54impl Parse for ExprList {
55    fn parse(input: ParseStream) -> syn::Result<Self> {
56        Ok(ExprList {
57            exprs: Punctuated::parse_terminated(input)?,
58        })
59    }
60}
61
62impl Parse for McpToolMacroAttributes {
63    /// Parses the macro attributes from a `ParseStream`.
64    ///
65    /// This implementation extracts `name`, `description`, `meta`, and `title` from the attribute input.
66    /// The `name` and `description` must be provided as string literals and be non-empty.
67    /// The `meta` attribute must be a valid JSON object provided as a string literal, and `title` must be a string literal.
68    ///
69    /// # Errors
70    /// Returns a `syn::Error` if:
71    /// - The `name` attribute is missing or empty.
72    /// - The `description` attribute is missing or empty.
73    /// - The `meta` attribute is provided but is not a valid JSON object.
74    /// - The `title` attribute is provided but is not a string literal.
75    fn parse(attributes: syn::parse::ParseStream) -> syn::Result<Self> {
76        let mut instance = Self {
77            name: None,
78            description: None,
79            #[cfg(feature = "2025_06_18")]
80            meta: None,
81            #[cfg(feature = "2025_06_18")]
82            title: None,
83            #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
84            destructive_hint: None,
85            #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
86            idempotent_hint: None,
87            #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
88            open_world_hint: None,
89            #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
90            read_only_hint: None,
91        };
92
93        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(attributes)?;
94        for meta in meta_list {
95            if let Meta::NameValue(meta_name_value) = meta {
96                let ident = meta_name_value.path.get_ident().unwrap();
97                let ident_str = ident.to_string();
98
99                match ident_str.as_str() {
100                    "name" | "description" => {
101                        let value = match &meta_name_value.value {
102                            Expr::Lit(ExprLit {
103                                lit: Lit::Str(lit_str),
104                                ..
105                            }) => lit_str.value(),
106                            Expr::Macro(expr_macro) => {
107                                let mac = &expr_macro.mac;
108                                if mac.path.is_ident("concat") {
109                                    let args: ExprList = syn::parse2(mac.tokens.clone())?;
110                                    let mut result = String::new();
111                                    for expr in args.exprs {
112                                        if let Expr::Lit(ExprLit {
113                                            lit: Lit::Str(lit_str),
114                                            ..
115                                        }) = expr
116                                        {
117                                            result.push_str(&lit_str.value());
118                                        } else {
119                                            return Err(Error::new_spanned(
120                                                expr,
121                                                "Only string literals are allowed inside concat!()",
122                                            ));
123                                        }
124                                    }
125                                    result
126                                } else {
127                                    return Err(Error::new_spanned(
128                                        expr_macro,
129                                        "Only concat!(...) is supported here",
130                                    ));
131                                }
132                            }
133                            _ => {
134                                return Err(Error::new_spanned(
135                                    &meta_name_value.value,
136                                    "Expected a string literal or concat!(...)",
137                                ));
138                            }
139                        };
140                        match ident_str.as_str() {
141                            "name" => instance.name = Some(value),
142                            "description" => instance.description = Some(value),
143                            _ => {}
144                        }
145                    }
146                    #[cfg(feature = "2025_06_18")]
147                    "meta" => {
148                        let value = match &meta_name_value.value {
149                            Expr::Lit(ExprLit {
150                                lit: Lit::Str(lit_str),
151                                ..
152                            }) => lit_str.value(),
153                            _ => {
154                                return Err(Error::new_spanned(
155                                    &meta_name_value.value,
156                                    "Expected a JSON object as a string literal",
157                                ));
158                            }
159                        };
160                        // Validate that the string is a valid JSON object
161                        let parsed: serde_json::Value =
162                            serde_json::from_str(&value).map_err(|e| {
163                                Error::new_spanned(
164                                    &meta_name_value.value,
165                                    format!("Expected a valid JSON object: {e}"),
166                                )
167                            })?;
168                        if !parsed.is_object() {
169                            return Err(Error::new_spanned(
170                                &meta_name_value.value,
171                                "Expected a JSON object",
172                            ));
173                        }
174                        instance.meta = Some(value);
175                    }
176                    #[cfg(feature = "2025_06_18")]
177                    "title" => {
178                        let value = match &meta_name_value.value {
179                            Expr::Lit(ExprLit {
180                                lit: Lit::Str(lit_str),
181                                ..
182                            }) => lit_str.value(),
183                            _ => {
184                                return Err(Error::new_spanned(
185                                    &meta_name_value.value,
186                                    "Expected a string literal",
187                                ));
188                            }
189                        };
190                        instance.title = Some(value);
191                    }
192                    "destructive_hint" | "idempotent_hint" | "open_world_hint"
193                    | "read_only_hint" => {
194                        #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
195                        {
196                            let value = match &meta_name_value.value {
197                                Expr::Lit(ExprLit {
198                                    lit: Lit::Bool(lit_bool),
199                                    ..
200                                }) => lit_bool.value,
201                                _ => {
202                                    return Err(Error::new_spanned(
203                                        &meta_name_value.value,
204                                        "Expected a boolean literal",
205                                    ));
206                                }
207                            };
208
209                            match ident_str.as_str() {
210                                "destructive_hint" => instance.destructive_hint = Some(value),
211                                "idempotent_hint" => instance.idempotent_hint = Some(value),
212                                "open_world_hint" => instance.open_world_hint = Some(value),
213                                "read_only_hint" => instance.read_only_hint = Some(value),
214                                _ => {}
215                            }
216                        }
217                    }
218                    _ => {}
219                }
220            }
221        }
222
223        // Validate presence and non-emptiness
224        if instance
225            .name
226            .as_ref()
227            .map(|s| s.trim().is_empty())
228            .unwrap_or(true)
229        {
230            return Err(Error::new(
231                attributes.span(),
232                "The 'name' attribute is required and must not be empty.",
233            ));
234        }
235        if instance
236            .description
237            .as_ref()
238            .map(|s| s.trim().is_empty())
239            .unwrap_or(true)
240        {
241            return Err(Error::new(
242                attributes.span(),
243                "The 'description' attribute is required and must not be empty.",
244            ));
245        }
246
247        Ok(instance)
248    }
249}
250
251struct McpElicitationAttributes {
252    message: Option<String>,
253}
254
255impl Parse for McpElicitationAttributes {
256    fn parse(attributes: syn::parse::ParseStream) -> syn::Result<Self> {
257        let mut instance = Self { message: None };
258        let meta_list: Punctuated<Meta, Token![,]> = Punctuated::parse_terminated(attributes)?;
259        for meta in meta_list {
260            if let Meta::NameValue(meta_name_value) = meta {
261                let ident = meta_name_value.path.get_ident().unwrap();
262                let ident_str = ident.to_string();
263                if ident_str.as_str() == "message" {
264                    let value = match &meta_name_value.value {
265                        Expr::Lit(ExprLit {
266                            lit: Lit::Str(lit_str),
267                            ..
268                        }) => lit_str.value(),
269                        Expr::Macro(expr_macro) => {
270                            let mac = &expr_macro.mac;
271                            if mac.path.is_ident("concat") {
272                                let args: ExprList = syn::parse2(mac.tokens.clone())?;
273                                let mut result = String::new();
274                                for expr in args.exprs {
275                                    if let Expr::Lit(ExprLit {
276                                        lit: Lit::Str(lit_str),
277                                        ..
278                                    }) = expr
279                                    {
280                                        result.push_str(&lit_str.value());
281                                    } else {
282                                        return Err(Error::new_spanned(
283                                            expr,
284                                            "Only string literals are allowed inside concat!()",
285                                        ));
286                                    }
287                                }
288                                result
289                            } else {
290                                return Err(Error::new_spanned(
291                                    expr_macro,
292                                    "Only concat!(...) is supported here",
293                                ));
294                            }
295                        }
296                        _ => {
297                            return Err(Error::new_spanned(
298                                &meta_name_value.value,
299                                "Expected a string literal or concat!(...)",
300                            ));
301                        }
302                    };
303                    instance.message = Some(value)
304                }
305            }
306        }
307        Ok(instance)
308    }
309}
310
311/// A procedural macro attribute to generate rust_mcp_schema::Tool related utility methods for a struct.
312///
313/// The `mcp_tool` macro generates an implementation for the annotated struct that includes:
314/// - A `tool_name()` method returning the tool's name as a string.
315/// - A `tool()` method returning a `rust_mcp_schema::Tool` instance with the tool's name,
316///   description, input schema, meta, and title derived from the struct's fields and attributes.
317///
318/// # Attributes
319/// * `name` - The name of the tool (required, non-empty string).
320/// * `description` - A description of the tool (required, non-empty string).
321/// * `meta` - Optional JSON object as a string literal for metadata.
322/// * `title` - Optional string for the tool's title.
323///
324/// # Panics
325/// Panics if the macro is applied to anything other than a struct.
326///
327/// # Example
328/// ```rust,ignore
329/// # #[cfg(not(feature = "sdk"))]
330/// # {
331/// #[rust_mcp_macros::mcp_tool(
332///     name = "example_tool",
333///     description = "An example tool",
334///     meta = "{\"version\": \"1.0\"}",
335///     title = "Example Tool"
336/// )]
337/// #[derive(rust_mcp_macros::JsonSchema)]
338/// struct ExampleTool {
339///     field1: String,
340///     field2: i32,
341/// }
342///
343/// assert_eq!(ExampleTool::tool_name(), "example_tool");
344/// let tool: rust_mcp_schema::Tool = ExampleTool::tool();
345/// assert_eq!(tool.name, "example_tool");
346/// assert_eq!(tool.description.unwrap(), "An example tool");
347/// assert_eq!(tool.meta.as_ref().unwrap().get("version").unwrap(), "1.0");
348/// assert_eq!(tool.title.unwrap(), "Example Tool");
349///
350/// let schema_properties = tool.input_schema.properties.unwrap();
351/// assert_eq!(schema_properties.len(), 2);
352/// assert!(schema_properties.contains_key("field1"));
353/// assert!(schema_properties.contains_key("field2"));
354/// }
355/// ```
356#[proc_macro_attribute]
357pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream {
358    let input = parse_macro_input!(input as DeriveInput);
359    let input_ident = &input.ident;
360
361    // Conditionally select the path for Tool
362    let base_crate = if cfg!(feature = "sdk") {
363        quote! { rust_mcp_sdk::schema }
364    } else {
365        quote! { rust_mcp_schema }
366    };
367
368    let macro_attributes = parse_macro_input!(attributes as McpToolMacroAttributes);
369
370    let tool_name = macro_attributes.name.unwrap_or_default();
371    let tool_description = macro_attributes.description.unwrap_or_default();
372
373    #[cfg(not(feature = "2025_06_18"))]
374    let meta = quote! {};
375    #[cfg(feature = "2025_06_18")]
376    let meta = macro_attributes.meta.map_or(quote! { meta: None, }, |m| {
377        quote! { meta: Some(serde_json::from_str(#m).expect("Failed to parse meta JSON")), }
378    });
379
380    #[cfg(not(feature = "2025_06_18"))]
381    let title = quote! {};
382    #[cfg(feature = "2025_06_18")]
383    let title = macro_attributes.title.map_or(
384        quote! { title: None, },
385        |t| quote! { title: Some(#t.to_string()), },
386    );
387
388    #[cfg(not(feature = "2025_06_18"))]
389    let output_schema = quote! {};
390    #[cfg(feature = "2025_06_18")]
391    let output_schema = quote! { output_schema: None,};
392
393    #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
394    let some_annotations = macro_attributes.destructive_hint.is_some()
395        || macro_attributes.idempotent_hint.is_some()
396        || macro_attributes.open_world_hint.is_some()
397        || macro_attributes.read_only_hint.is_some();
398
399    #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
400    let annotations = if some_annotations {
401        let destructive_hint = macro_attributes
402            .destructive_hint
403            .map_or(quote! {None}, |v| quote! {Some(#v)});
404
405        let idempotent_hint = macro_attributes
406            .idempotent_hint
407            .map_or(quote! {None}, |v| quote! {Some(#v)});
408        let open_world_hint = macro_attributes
409            .open_world_hint
410            .map_or(quote! {None}, |v| quote! {Some(#v)});
411        let read_only_hint = macro_attributes
412            .read_only_hint
413            .map_or(quote! {None}, |v| quote! {Some(#v)});
414        quote! {
415            Some(#base_crate::ToolAnnotations {
416                destructive_hint: #destructive_hint,
417                idempotent_hint: #idempotent_hint,
418                open_world_hint: #open_world_hint,
419                read_only_hint: #read_only_hint,
420                title: None,
421            })
422        }
423    } else {
424        quote! { None }
425    };
426
427    let annotations_token = {
428        #[cfg(any(feature = "2025_03_26", feature = "2025_06_18"))]
429        {
430            quote! { annotations: #annotations, }
431        }
432        #[cfg(not(any(feature = "2025_03_26", feature = "2025_06_18")))]
433        {
434            quote! {}
435        }
436    };
437
438    let tool_token = quote! {
439        #base_crate::Tool {
440            name: #tool_name.to_string(),
441            description: Some(#tool_description.to_string()),
442            #output_schema
443            #title
444            #meta
445            #annotations_token
446            input_schema: #base_crate::ToolInputSchema::new(required, properties)
447        }
448    };
449
450    let output = quote! {
451        impl #input_ident {
452            /// Returns the name of the tool as a String.
453            pub fn tool_name() -> String {
454                #tool_name.to_string()
455            }
456
457            /// Constructs and returns a `rust_mcp_schema::Tool` instance.
458            ///
459            /// The tool includes the name, description, input schema, meta, and title derived from
460            /// the struct's attributes.
461            pub fn tool() -> #base_crate::Tool {
462                let json_schema = &#input_ident::json_schema();
463
464                let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) {
465                    Some(arr) => arr
466                        .iter()
467                        .filter_map(|item| item.as_str().map(String::from))
468                        .collect(),
469                    None => Vec::new(),
470                };
471
472                let properties: Option<
473                    std::collections::HashMap<String, serde_json::Map<String, serde_json::Value>>,
474                > = json_schema
475                    .get("properties")
476                    .and_then(|v| v.as_object()) // Safely extract "properties" as an object.
477                    .map(|properties| {
478                        properties
479                            .iter()
480                            .filter_map(|(key, value)| {
481                                serde_json::to_value(value)
482                                    .ok() // If serialization fails, return None.
483                                    .and_then(|v| {
484                                        if let serde_json::Value::Object(obj) = v {
485                                            Some(obj)
486                                        } else {
487                                            None
488                                        }
489                                    })
490                                    .map(|obj| (key.to_string(), obj)) // Return the (key, value) tuple
491                            })
492                            .collect()
493                    });
494
495                #tool_token
496            }
497        }
498        // Retain the original item (struct definition)
499        #input
500    };
501
502    TokenStream::from(output)
503}
504
505#[proc_macro_attribute]
506pub fn mcp_elicit(attributes: TokenStream, input: TokenStream) -> TokenStream {
507    let input = parse_macro_input!(input as DeriveInput);
508    let input_ident = &input.ident;
509
510    // Conditionally select the path
511    let base_crate = if cfg!(feature = "sdk") {
512        quote! { rust_mcp_sdk::schema }
513    } else {
514        quote! { rust_mcp_schema }
515    };
516
517    let macro_attributes = parse_macro_input!(attributes as McpElicitationAttributes);
518    let message = macro_attributes.message.unwrap_or_default();
519
520    // Generate field assignments for from_content_map()
521    let field_assignments = match &input.data {
522        Data::Struct(data) => match &data.fields {
523            Fields::Named(fields) => {
524                let assignments = fields.named.iter().map(|field| {
525                      let field_attrs = &field.attrs;
526                      let field_ident = &field.ident;
527                      let renamed_field = renamed_field(field_attrs);
528                      let field_name = renamed_field.unwrap_or_else(|| field_ident.as_ref().unwrap().to_string());
529                      let field_type = &field.ty;
530
531                      let type_check = if is_option(field_type) {
532                          // Extract inner type for Option<T>
533                          let inner_type = match field_type {
534                              Type::Path(type_path) => {
535                                  let segment = type_path.path.segments.last().unwrap();
536                                  if segment.ident == "Option" {
537                                      match &segment.arguments {
538                                          PathArguments::AngleBracketed(args) => {
539                                              match args.args.first().unwrap() {
540                                                  GenericArgument::Type(ty) => ty,
541                                                  _ => panic!("Expected type argument in Option<T>"),
542                                              }
543                                          }
544                                          _ => panic!("Invalid Option type"),
545                                      }
546                                  } else {
547                                      panic!("Expected Option type");
548                                  }
549                              }
550                              _ => panic!("Expected Option type"),
551                          };
552                          // Determine the match arm based on the inner type at compile time
553                          let (inner_type_ident, match_pattern, conversion) = match inner_type {
554                              Type::Path(type_path) if type_path.path.is_ident("String") => (
555                                  quote! { String },
556                                  quote! { #base_crate::ElicitResultContentValue::String(s) },
557                                  quote! { s.clone() }
558                              ),
559                              Type::Path(type_path) if type_path.path.is_ident("bool") => (
560                                  quote! { bool },
561                                  quote! { #base_crate::ElicitResultContentValue::Boolean(b) },
562                                  quote! { *b }
563                              ),
564                              Type::Path(type_path) if type_path.path.is_ident("i32") => (
565                                  quote! { i32 },
566                                  quote! { #base_crate::ElicitResultContentValue::Integer(i) },
567                                  quote! {
568                                      (*i).try_into().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!(
569                                          "Invalid number for field '{}': value {} does not fit in i32",
570                                          #field_name, *i
571                                      )))?
572                                  }
573                              ),
574                              Type::Path(type_path) if type_path.path.is_ident("i64") => (
575                                  quote! { i64 },
576                                  quote! { #base_crate::ElicitResultContentValue::Integer(i) },
577                                  quote! { *i }
578                              ),
579                              _ if is_enum(inner_type, &input) => {
580                                  let enum_parse = generate_enum_parse(inner_type, &field_name, &base_crate);
581                                  (
582                                      quote! { #inner_type },
583                                      quote! { #base_crate::ElicitResultContentValue::String(s) },
584                                      quote! { #enum_parse }
585                                  )
586                              }
587                              _ => panic!("Unsupported inner type for Option field: {}", quote! { #inner_type }),
588                          };
589                          let inner_type_str = quote! { stringify!(#inner_type_ident) };
590                          quote! {
591                              let #field_ident: Option<#inner_type_ident> = match content.as_ref().and_then(|map| map.get(#field_name)) {
592                                  Some(value) => {
593                                      match value {
594                                          #match_pattern => Some(#conversion),
595                                          _ => {
596                                              return Err(#base_crate::RpcError::parse_error().with_message(format!(
597                                                  "Type mismatch for field '{}': expected {}, found {}",
598                                                  #field_name, #inner_type_str,
599                                                  match value {
600                                                      #base_crate::ElicitResultContentValue::Boolean(_) => "boolean",
601                                                      #base_crate::ElicitResultContentValue::String(_) => "string",
602                                                      #base_crate::ElicitResultContentValue::Integer(_) => "integer",
603                                                  }
604                                              )));
605                                          }
606                                      }
607                                  }
608                                  None => None,
609                              };
610                          }
611                      } else {
612                          // Determine the match arm based on the field type at compile time
613                          let (field_type_ident, match_pattern, conversion) = match field_type {
614                              Type::Path(type_path) if type_path.path.is_ident("String") => (
615                                  quote! { String },
616                                  quote! { #base_crate::ElicitResultContentValue::String(s) },
617                                  quote! { s.clone() }
618                              ),
619                              Type::Path(type_path) if type_path.path.is_ident("bool") => (
620                                  quote! { bool },
621                                  quote! { #base_crate::ElicitResultContentValue::Boolean(b) },
622                                  quote! { *b }
623                              ),
624                              Type::Path(type_path) if type_path.path.is_ident("i32") => (
625                                  quote! { i32 },
626                                  quote! { #base_crate::ElicitResultContentValue::Integer(i) },
627                                  quote! {
628                                      (*i).try_into().map_err(|_| #base_crate::RpcError::parse_error().with_message(format!(
629                                          "Invalid number for field '{}': value {} does not fit in i32",
630                                          #field_name, *i
631                                      )))?
632                                  }
633                              ),
634                              Type::Path(type_path) if type_path.path.is_ident("i64") => (
635                                  quote! { i64 },
636                                  quote! { #base_crate::ElicitResultContentValue::Integer(i) },
637                                  quote! { *i }
638                              ),
639                              _ if is_enum(field_type, &input) => {
640                                  let enum_parse = generate_enum_parse(field_type, &field_name, &base_crate);
641                                  (
642                                      quote! { #field_type },
643                                      quote! { #base_crate::ElicitResultContentValue::String(s) },
644                                      quote! { #enum_parse }
645                                  )
646                              }
647                              _ => panic!("Unsupported field type: {}", quote! { #field_type }),
648                          };
649                          let type_str = quote! { stringify!(#field_type_ident) };
650                          quote! {
651                              let #field_ident: #field_type_ident = match content.as_ref().and_then(|map| map.get(#field_name)) {
652                                  Some(value) => {
653                                      match value {
654                                          #match_pattern => #conversion,
655                                          _ => {
656                                              return Err(#base_crate::RpcError::parse_error().with_message(format!(
657                                                  "Type mismatch for field '{}': expected {}, found {}",
658                                                  #field_name, #type_str,
659                                                  match value {
660                                                      #base_crate::ElicitResultContentValue::Boolean(_) => "boolean",
661                                                      #base_crate::ElicitResultContentValue::String(_) => "string",
662                                                      #base_crate::ElicitResultContentValue::Integer(_) => "integer",
663                                                  }
664                                              )));
665                                          }
666                                      }
667                                  }
668                                  None => {
669                                      return Err(#base_crate::RpcError::parse_error().with_message(format!(
670                                          "Missing required field: {}",
671                                          #field_name
672                                      )));
673                                  }
674                              };
675                          }
676                      };
677
678                      type_check
679                  });
680
681                let field_idents = fields.named.iter().map(|field| &field.ident);
682
683                quote! {
684                    #(#assignments)*
685
686                    Ok(Self {
687                        #(#field_idents,)*
688                    })
689                }
690            }
691            _ => panic!("mcp_elicit macro only supports structs with named fields"),
692        },
693        _ => panic!("mcp_elicit macro only supports structs"),
694    };
695
696    let output = quote! {
697        impl #input_ident {
698
699            /// Returns the elicitation message defined in the `#[mcp_elicit(message = "...")]` attribute.
700            ///
701            /// This message is used to prompt the user or system for input when eliciting data for the struct.
702            /// If no message is provided in the attribute, an empty string is returned.
703            ///
704            /// # Returns
705            /// A `String` containing the elicitation message.
706            pub fn message()->String{
707                #message.to_string()
708            }
709
710            /// This method returns a `ElicitRequestedSchema` by retrieves the
711            /// struct's JSON schema (via the `JsonSchema` derive) and converting int into
712            /// a `ElicitRequestedSchema`. It extracts the `required` fields and
713            /// `properties` from the schema, mapping them to a `HashMap` of `PrimitiveSchemaDefinition` objects.
714            ///
715            /// # Returns
716            /// An `ElicitRequestedSchema` representing the schema of the struct.
717            ///
718            /// # Panics
719            /// Panics if the schema's properties cannot be converted to `PrimitiveSchemaDefinition` or if the schema
720            /// is malformed.
721            pub fn requested_schema() -> #base_crate::ElicitRequestedSchema {
722                let json_schema = &#input_ident::json_schema();
723
724                let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) {
725                    Some(arr) => arr
726                        .iter()
727                        .filter_map(|item| item.as_str().map(String::from))
728                        .collect(),
729                    None => Vec::new(),
730                };
731
732                let properties: Option<std::collections::HashMap<String, _>> = json_schema
733                    .get("properties")
734                    .and_then(|v| v.as_object()) // Safely extract "properties" as an object.
735                    .map(|properties| {
736                        properties
737                            .iter()
738                            .filter_map(|(key, value)| {
739                                serde_json::to_value(value)
740                                    .ok() // If serialization fails, return None.
741                                    .and_then(|v| {
742                                        if let serde_json::Value::Object(obj) = v {
743                                            Some(obj)
744                                        } else {
745                                            None
746                                        }
747                                    })
748                                    .map(|obj| (key.to_string(), #base_crate::PrimitiveSchemaDefinition::try_from(&obj)))
749                            })
750                            .collect()
751                    });
752
753                let properties = properties
754                    .map(|map| {
755                        map.into_iter()
756                            .map(|(k, v)| v.map(|ok_v| (k, ok_v))) // flip Result inside tuple
757                            .collect::<Result<std::collections::HashMap<_, _>, _>>() // collect only if all Ok
758                    })
759                    .transpose()
760                    .unwrap();
761
762                let properties =
763                    properties.expect("Was not able to create a ElicitRequestedSchema");
764
765                let requested_schema = #base_crate::ElicitRequestedSchema::new(properties, required);
766                requested_schema
767            }
768
769            /// Converts a map of field names and `ElicitResultContentValue` into an instance of the struct.
770            ///
771            /// This method parses the provided content map, matching field names to struct fields and converting
772            /// `ElicitResultContentValue` variants into the appropriate Rust types (e.g., `String`, `bool`, `i32`,
773            /// `i64`, or simple enums). It supports both required and optional fields (`Option<T>`).
774            ///
775            /// # Parameters
776            /// - `content`: An optional `HashMap` mapping field names to `ElicitResultContentValue` values.
777            ///
778            /// # Returns
779            /// - `Ok(Self)` if the map is successfully parsed into the struct.
780            /// - `Err(RpcError)` if:
781            ///   - A required field is missing.
782            ///   - A value’s type does not match the expected field type.
783            ///   - An integer value cannot be converted (e.g., `i64` to `i32` out of bounds).
784            ///   - An enum value is invalid (e.g., string value does not match a enum variant name).
785            ///
786            /// # Errors
787            /// Returns `RpcError` with messages like:
788            /// - `"Missing required field: {}"`
789            /// - `"Type mismatch for field '{}': expected {}, found {}"`
790            /// - `"Invalid number for field '{}': value {} does not fit in i32"`
791            /// - `"Invalid enum value for field '{}': expected 'Yes' or 'No', found '{}'"`.
792            pub fn from_content_map(content: ::std::option::Option<::std::collections::HashMap<::std::string::String, #base_crate::ElicitResultContentValue>>) -> Result<Self, #base_crate::RpcError> {
793                #field_assignments
794            }
795        }
796        #input
797    };
798
799    TokenStream::from(output)
800}
801
802/// Derives a JSON Schema representation for a struct.
803///
804/// This procedural macro generates a `json_schema()` method for the annotated struct, returning a
805/// `serde_json::Map<String, serde_json::Value>` that represents the struct as a JSON Schema object.
806/// The schema includes the struct's fields as properties, with support for basic types, `Option<T>`,
807/// `Vec<T>`, and nested structs that also derive `JsonSchema`.
808///
809/// # Features
810/// - **Basic Types:** Maps `String` to `"string"`, `i32` to `"integer"`, `bool` to `"boolean"`, etc.
811/// - **`Option<T>`:** Adds `"nullable": true` to the schema of the inner type, indicating the field is optional.
812/// - **`Vec<T>`:** Generates an `"array"` schema with an `"items"` field describing the inner type.
813/// - **Nested Structs:** Recursively includes the schema of nested structs (assumed to derive `JsonSchema`),
814///   embedding their `"properties"` and `"required"` fields.
815/// - **Required Fields:** Adds a top-level `"required"` array listing field names not wrapped in `Option`.
816///
817/// # Notes
818/// It’s designed as a straightforward solution to meet the basic needs of this package, supporting
819/// common types and simple nested structures. For more advanced features or robust JSON Schema generation,
820/// consider exploring established crates like
821/// [`schemars`](https://crates.io/crates/schemars) on crates.io
822///
823/// # Limitations
824/// - Supports only structs with named fields (e.g., `struct S { field: Type }`).
825/// - Nested structs must also derive `JsonSchema`, or compilation will fail.
826/// - Unknown types are mapped to `{"type": "unknown"}`.
827/// - Type paths must be in scope (e.g., fully qualified paths like `my_mod::InnerStruct` work if imported).
828///
829/// # Panics
830/// - If the input is not a struct with named fields (e.g., tuple structs or enums).
831///
832/// # Dependencies
833/// Relies on `serde_json` for `Map` and `Value` types.
834///
835#[proc_macro_derive(JsonSchema, attributes(json_schema))]
836pub fn derive_json_schema(input: TokenStream) -> TokenStream {
837    let input = syn::parse_macro_input!(input as DeriveInput);
838    let name = &input.ident;
839
840    let schema_body = match &input.data {
841        Data::Struct(data) => match &data.fields {
842            Fields::Named(fields) => {
843                let field_entries = fields.named.iter().map(|field| {
844                    let field_attrs = &field.attrs;
845                    let renamed_field = renamed_field(field_attrs);
846                    let field_name =
847                        renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
848                    let field_type = &field.ty;
849
850                    let schema = type_to_json_schema(field_type, field_attrs);
851                    quote! {
852                        properties.insert(
853                            #field_name.to_string(),
854                            serde_json::Value::Object(#schema)
855                        );
856                    }
857                });
858
859                let required_fields = fields.named.iter().filter_map(|field| {
860                    let renamed_field = renamed_field(&field.attrs);
861                    let field_name =
862                        renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
863
864                    let field_type = &field.ty;
865                    if !is_option(field_type) {
866                        Some(quote! {
867                            required.push(#field_name.to_string());
868                        })
869                    } else {
870                        None
871                    }
872                });
873
874                quote! {
875                    let mut schema = serde_json::Map::new();
876                    let mut properties = serde_json::Map::new();
877                    let mut required = Vec::new();
878
879                    #(#field_entries)*
880
881                    #(#required_fields)*
882
883                    schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
884                    schema.insert("properties".to_string(), serde_json::Value::Object(properties));
885                    if !required.is_empty() {
886                        schema.insert("required".to_string(), serde_json::Value::Array(
887                            required.into_iter().map(serde_json::Value::String).collect()
888                        ));
889                    }
890
891                    schema
892                }
893            }
894            _ => panic!("JsonSchema derive macro only supports named fields for structs"),
895        },
896        Data::Enum(data) => {
897            let variant_schemas = data.variants.iter().map(|variant| {
898                let variant_attrs = &variant.attrs;
899                let variant_name = variant.ident.to_string();
900                let renamed_variant = renamed_field(variant_attrs).unwrap_or(variant_name.clone());
901
902                // Parse variant-level json_schema attributes
903                let mut title: Option<String> = None;
904                let mut description: Option<String> = None;
905                for attr in variant_attrs {
906                    if attr.path().is_ident("json_schema") {
907                        let _ = attr.parse_nested_meta(|meta| {
908                            if meta.path.is_ident("title") {
909                                title = Some(meta.value()?.parse::<syn::LitStr>()?.value());
910                            } else if meta.path.is_ident("description") {
911                                description = Some(meta.value()?.parse::<syn::LitStr>()?.value());
912                            }
913                            Ok(())
914                        });
915                    }
916                }
917
918                let title_quote = title.as_ref().map(|t| {
919                    quote! { map.insert("title".to_string(), serde_json::Value::String(#t.to_string())); }
920                });
921                let description_quote = description.as_ref().map(|desc| {
922                    quote! { map.insert("description".to_string(), serde_json::Value::String(#desc.to_string())); }
923                });
924
925                match &variant.fields {
926                    Fields::Unit => {
927                        // Unit variant: use "enum" with the variant name
928                        quote! {
929                            {
930                                let mut map = serde_json::Map::new();
931                                map.insert("enum".to_string(), serde_json::Value::Array(vec![
932                                    serde_json::Value::String(#renamed_variant.to_string())
933                                ]));
934                                #title_quote
935                                #description_quote
936                                serde_json::Value::Object(map)
937                            }
938                        }
939                    }
940                    Fields::Unnamed(fields) => {
941                        // Newtype or tuple variant
942                        if fields.unnamed.len() == 1 {
943                            // Newtype variant: use the inner type's schema
944                            let field = &fields.unnamed[0];
945                            let field_type = &field.ty;
946                            let field_attrs = &field.attrs;
947                            let schema = type_to_json_schema(field_type, field_attrs);
948                            quote! {
949                                {
950                                    let mut map = #schema;
951                                    #title_quote
952                                    #description_quote
953                                    serde_json::Value::Object(map)
954                                }
955                            }
956                        } else {
957                            // Tuple variant: array with items
958                            let field_schemas = fields.unnamed.iter().map(|field| {
959                                let field_type = &field.ty;
960                                let field_attrs = &field.attrs;
961                                let schema = type_to_json_schema(field_type, field_attrs);
962                                quote! { serde_json::Value::Object(#schema) }
963                            });
964                            quote! {
965                                {
966                                    let mut map = serde_json::Map::new();
967                                    map.insert("type".to_string(), serde_json::Value::String("array".to_string()));
968                                    map.insert("items".to_string(), serde_json::Value::Array(vec![#(#field_schemas),*]));
969                                    map.insert("additionalItems".to_string(), serde_json::Value::Bool(false));
970                                    #title_quote
971                                    #description_quote
972                                    serde_json::Value::Object(map)
973                                }
974                            }
975                        }
976                    }
977                    Fields::Named(fields) => {
978                        // Struct variant: object with properties and required fields
979                        let field_entries = fields.named.iter().map(|field| {
980                            let field_attrs = &field.attrs;
981                            let renamed_field = renamed_field(field_attrs);
982                            let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
983                            let field_type = &field.ty;
984
985                            let schema = type_to_json_schema(field_type, field_attrs);
986                            quote! {
987                                properties.insert(
988                                    #field_name.to_string(),
989                                    serde_json::Value::Object(#schema)
990                                );
991                            }
992                        });
993
994                        let required_fields = fields.named.iter().filter_map(|field| {
995                            let renamed_field = renamed_field(&field.attrs);
996                            let field_name = renamed_field.unwrap_or(field.ident.as_ref().unwrap().to_string());
997
998                            let field_type = &field.ty;
999                            if !is_option(field_type) {
1000                                Some(quote! {
1001                                    required.push(#field_name.to_string());
1002                                })
1003                            } else {
1004                                None
1005                            }
1006                        });
1007
1008                        quote! {
1009                            {
1010                                let mut map = serde_json::Map::new();
1011                                let mut properties = serde_json::Map::new();
1012                                let mut required = Vec::new();
1013
1014                                #(#field_entries)*
1015
1016                                #(#required_fields)*
1017
1018                                map.insert("type".to_string(), serde_json::Value::String("object".to_string()));
1019                                map.insert("properties".to_string(), serde_json::Value::Object(properties));
1020                                if !required.is_empty() {
1021                                    map.insert("required".to_string(), serde_json::Value::Array(
1022                                        required.into_iter().map(serde_json::Value::String).collect()
1023                                    ));
1024                                }
1025                                #title_quote
1026                                #description_quote
1027                                serde_json::Value::Object(map)
1028                            }
1029                        }
1030                    }
1031                }
1032            });
1033
1034            quote! {
1035                let mut schema = serde_json::Map::new();
1036                schema.insert("oneOf".to_string(), serde_json::Value::Array(vec![
1037                    #(#variant_schemas),*
1038                ]));
1039                schema
1040            }
1041        }
1042        _ => panic!("JsonSchema derive macro only supports structs and enums"),
1043    };
1044
1045    let expanded = quote! {
1046        impl #name {
1047            pub fn json_schema() -> serde_json::Map<String, serde_json::Value> {
1048                #schema_body
1049            }
1050        }
1051    };
1052    TokenStream::from(expanded)
1053}
1054
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_str;

    /// The required attributes alone parse successfully under every feature set.
    #[test]
    fn test_valid_required_attributes() {
        let input = r#"name = "test_tool", description = "A test tool.""#;
        let parsed: McpToolMacroAttributes = parse_str(input).unwrap();

        assert_eq!(parsed.name.unwrap(), "test_tool");
        assert_eq!(parsed.description.unwrap(), "A test tool.");
    }

    // `meta` and `title` are declared on `McpToolMacroAttributes` only under the `2025_06_18`
    // feature, so this test would not compile without it.
    #[cfg(feature = "2025_06_18")]
    #[test]
    fn test_valid_macro_attributes() {
        let input = r#"name = "test_tool", description = "A test tool.", meta = "{\"version\": \"1.0\"}", title = "Test Tool""#;
        let parsed: McpToolMacroAttributes = parse_str(input).unwrap();

        assert_eq!(parsed.name.unwrap(), "test_tool");
        assert_eq!(parsed.description.unwrap(), "A test tool.");
        assert_eq!(parsed.meta.unwrap(), "{\"version\": \"1.0\"}");
        assert_eq!(parsed.title.unwrap(), "Test Tool");
    }

    #[test]
    fn test_missing_name() {
        let input = r#"description = "Only description""#;
        let result: Result<McpToolMacroAttributes, Error> = parse_str(input);
        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "The 'name' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_missing_description() {
        let input = r#"name = "OnlyName""#;
        let result: Result<McpToolMacroAttributes, Error> = parse_str(input);
        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "The 'description' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_empty_name_field() {
        let input = r#"name = "", description = "something""#;
        let result: Result<McpToolMacroAttributes, Error> = parse_str(input);
        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "The 'name' attribute is required and must not be empty."
        );
    }

    #[test]
    fn test_empty_description_field() {
        let input = r#"name = "my-tool", description = """#;
        let result: Result<McpToolMacroAttributes, Error> = parse_str(input);
        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "The 'description' attribute is required and must not be empty."
        );
    }

    // The `meta` attribute belongs to the `2025_06_18` feature (see the field gating on
    // `McpToolMacroAttributes`); its validation is only exercised when that feature is on.
    #[cfg(feature = "2025_06_18")]
    #[test]
    fn test_invalid_meta() {
        let input =
            r#"name = "test_tool", description = "A test tool.", meta = "not_a_json_object""#;
        let result: Result<McpToolMacroAttributes, Error> = parse_str(input);
        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("Expected a valid JSON object"));
    }

    #[cfg(feature = "2025_06_18")]
    #[test]
    fn test_non_object_meta() {
        let input = r#"name = "test_tool", description = "A test tool.", meta = "[1, 2, 3]""#;
        let result: Result<McpToolMacroAttributes, Error> = parse_str(input);
        assert!(result.is_err());
        assert_eq!(result.err().unwrap().to_string(), "Expected a JSON object");
    }
}