Skip to main content

wasmcp_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, parse::Parser, punctuated::Punctuated, ItemMod, Meta, Token};
4
5const WIT_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/wit");
6
7/// Creates an MCP handler with the specified tools, resources, and prompts.
8///
9/// This macro generates all the necessary WebAssembly bindings and handler logic
10/// with zero runtime overhead. It handles WIT file generation automatically,
11/// so you don't need any local WIT files in your project.
12///
13/// # Example
14///
15/// ```rust,ignore
16/// use wasmcp::mcp_handler;
17///
18/// #[mcp_handler(
19///     tools(EchoTool, CalculatorTool),
20///     resources(ConfigResource),
21///     prompts(GreetingPrompt),
22/// )]
23/// mod handler {}
24/// ```
25#[proc_macro_attribute]
26pub fn mcp_handler(args: TokenStream, input: TokenStream) -> TokenStream {
27    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
28    let args = parse_macro_input!(args with parser);
29    let _input_mod = parse_macro_input!(input as ItemMod);
30    
31    // Parse the arguments to extract tools, resources, and prompts
32    let mut tools = Vec::new();
33    let mut resources = Vec::new(); 
34    let mut prompts = Vec::new();
35    
36    for arg in args {
37        match arg {
38            Meta::List(list) => {
39                let name = list.path.get_ident().map(|i| i.to_string());
40                match name.as_deref() {
41                    Some("tools") => {
42                        // Parse the tokens inside tools(...)
43                        let tokens = list.tokens.clone();
44                        let parser = Punctuated::<syn::Path, Token![,]>::parse_terminated;
45                        if let Ok(paths) = parser.parse(tokens.into()) {
46                            for path in paths {
47                                if let Some(ident) = path.get_ident() {
48                                    tools.push(ident.clone());
49                                }
50                            }
51                        }
52                    }
53                    Some("resources") => {
54                        // Parse the tokens inside resources(...)
55                        let tokens = list.tokens.clone();
56                        let parser = Punctuated::<syn::Path, Token![,]>::parse_terminated;
57                        if let Ok(paths) = parser.parse(tokens.into()) {
58                            for path in paths {
59                                if let Some(ident) = path.get_ident() {
60                                    resources.push(ident.clone());
61                                }
62                            }
63                        }
64                    }
65                    Some("prompts") => {
66                        // Parse the tokens inside prompts(...)
67                        let tokens = list.tokens.clone();
68                        let parser = Punctuated::<syn::Path, Token![,]>::parse_terminated;
69                        if let Ok(paths) = parser.parse(tokens.into()) {
70                            for path in paths {
71                                if let Some(ident) = path.get_ident() {
72                                    prompts.push(ident.clone());
73                                }
74                            }
75                        }
76                    }
77                    _ => {}
78                }
79            }
80            _ => {}
81        }
82    }
83    
84    // Generate the bindings and handler implementation
85    let preamble = generate_preamble();
86    let handler_impl = generate_handler_impl(&tools, &resources, &prompts);
87    
88    quote! {
89        #preamble
90        #handler_impl
91    }.into()
92}
93
94fn generate_preamble() -> proc_macro2::TokenStream {
95    // Create a string literal token from the WIT_PATH constant
96    let wit_path = proc_macro2::Literal::string(WIT_PATH);
97    quote! {
98        #[allow(warnings)]
99        mod __wasmcp_bindings {
100            #![allow(missing_docs)]
101            ::wasmcp::wit_bindgen::generate!({
102                world: "mcp-handler",
103                path: #wit_path,
104                runtime_path: "::wit_bindgen_rt",
105                generate_all,
106            });
107            pub use self::exports::wasmcp::mcp::handler;
108        }
109    }
110}
111
112fn generate_handler_impl(
113    tools: &[syn::Ident],
114    resources: &[syn::Ident],  
115    prompts: &[syn::Ident],
116) -> proc_macro2::TokenStream {
117    // Generate tool handling code
118    let tool_list = if tools.is_empty() {
119        quote! { vec![] }
120    } else {
121        quote! {
122            vec![
123                #( 
124                    __wasmcp_bindings::handler::Tool {
125                        name: <#tools as ::wasmcp::ToolHandler>::NAME.to_string(),
126                        description: <#tools as ::wasmcp::ToolHandler>::DESCRIPTION.to_string(),
127                        input_schema: <#tools as ::wasmcp::ToolHandler>::input_schema().to_string(),
128                    }
129                ),*
130            ]
131        }
132    };
133
134    let tool_call = if tools.is_empty() {
135        quote! {
136            __wasmcp_bindings::handler::ToolResult::Error(__wasmcp_bindings::handler::Error {
137                code: -32601,
138                message: format!("Tool not found: {}", name),
139                data: None,
140            })
141        }
142    } else {
143        quote! {
144            match name.as_str() {
145                #(
146                    <#tools as ::wasmcp::ToolHandler>::NAME => {
147                        // Call through ToolHandler trait - it handles both sync and async
148                        match <#tools as ::wasmcp::ToolHandler>::execute(args_json) {
149                            Ok(result) => __wasmcp_bindings::handler::ToolResult::Text(result),
150                            Err(e) => __wasmcp_bindings::handler::ToolResult::Error(__wasmcp_bindings::handler::Error {
151                                code: -32603,
152                                message: e,
153                                data: None,
154                            }),
155                        }
156                    }
157                )*
158                _ => __wasmcp_bindings::handler::ToolResult::Error(__wasmcp_bindings::handler::Error {
159                    code: -32601,
160                    message: format!("Tool not found: {}", name),
161                    data: None,
162                }),
163            }
164        }
165    };
166
167    // Similar for resources
168    let resource_list = if resources.is_empty() {
169        quote! { vec![] }
170    } else {
171        quote! {
172            vec![
173                #(
174                    {
175                        let info = <#resources as ::wasmcp::ResourceHandler>::list();
176                        info.into_iter().map(|r| __wasmcp_bindings::handler::ResourceInfo {
177                            uri: r.uri,
178                            name: r.name,  
179                            description: r.description,
180                            mime_type: r.mime_type,
181                        }).collect::<Vec<_>>()
182                    }
183                ),*
184            ].into_iter().flatten().collect()
185        }
186    };
187
188    let resource_read = if resources.is_empty() {
189        quote! {
190            __wasmcp_bindings::handler::ResourceResult::Error(__wasmcp_bindings::handler::Error {
191                code: -32601,
192                message: format!("Resource not found: {}", uri),
193                data: None,
194            })
195        }
196    } else {
197        quote! {
198            #(
199                if let Ok(contents) = <#resources as ::wasmcp::ResourceHandler>::read(&uri) {
200                    return __wasmcp_bindings::handler::ResourceResult::Contents(__wasmcp_bindings::handler::ResourceContents {
201                        uri: contents.uri,
202                        mime_type: contents.mime_type,
203                        text: contents.text,
204                        blob: contents.blob,
205                    });
206                }
207            )*
208            __wasmcp_bindings::handler::ResourceResult::Error(__wasmcp_bindings::handler::Error {
209                code: -32601,
210                message: format!("Resource not found: {}", uri),
211                data: None,
212            })
213        }
214    };
215
216    // Similar for prompts
217    let prompt_list = if prompts.is_empty() {
218        quote! { vec![] }
219    } else {
220        quote! {
221            vec![
222                #(
223                    {
224                        let prompt = <#prompts as ::wasmcp::PromptHandler>::describe();
225                        __wasmcp_bindings::handler::Prompt {
226                            name: prompt.name,
227                            description: prompt.description,
228                            arguments: prompt.arguments.into_iter().map(|a| __wasmcp_bindings::handler::PromptArgument {
229                                name: a.name,
230                                description: a.description,
231                                required: a.required,
232                            }).collect(),
233                        }
234                    }
235                ),*
236            ]
237        }
238    };
239
240    let prompt_get = if prompts.is_empty() {
241        quote! {
242            __wasmcp_bindings::handler::PromptResult::Error(__wasmcp_bindings::handler::Error {
243                code: -32601,
244                message: format!("Prompt not found: {}", name),
245                data: None,
246            })
247        }
248    } else {
249        quote! {
250            match name.as_str() {
251                #(
252                    <#prompts as ::wasmcp::PromptHandler>::NAME => {
253                        match <#prompts as ::wasmcp::PromptHandler>::get_messages(args_json) {
254                            Ok(messages) => __wasmcp_bindings::handler::PromptResult::Messages(
255                                messages.into_iter().map(|m| __wasmcp_bindings::handler::PromptMessage {
256                                    role: match m.role {
257                                        ::wasmcp::PromptRole::User => "user".to_string(),
258                                        ::wasmcp::PromptRole::Assistant => "assistant".to_string(),
259                                    },
260                                    content: m.content,
261                                }).collect()
262                            ),
263                            Err(e) => __wasmcp_bindings::handler::PromptResult::Error(__wasmcp_bindings::handler::Error {
264                                code: -32603,
265                                message: e,
266                                data: None,
267                            }),
268                        }
269                    }
270                )*
271                _ => __wasmcp_bindings::handler::PromptResult::Error(__wasmcp_bindings::handler::Error {
272                    code: -32601,
273                    message: format!("Prompt not found: {}", name),
274                    data: None,
275                }),
276            }
277        }
278    };
279
280
281    quote! {
282        struct __WasmcpHandler;
283
284        impl __wasmcp_bindings::handler::Guest for __WasmcpHandler {
285            fn list_tools() -> Vec<__wasmcp_bindings::handler::Tool> {
286                #tool_list
287            }
288
289            fn call_tool(name: String, arguments: String) -> __wasmcp_bindings::handler::ToolResult {
290                let args_json: ::serde_json::Value = match ::serde_json::from_str(&arguments) {
291                    Ok(v) => v,
292                    Err(e) => {
293                        return __wasmcp_bindings::handler::ToolResult::Error(__wasmcp_bindings::handler::Error {
294                            code: -32700,
295                            message: format!("Invalid JSON: {}", e),
296                            data: None,
297                        });
298                    }
299                };
300
301                #tool_call
302            }
303
304            fn list_resources() -> Vec<__wasmcp_bindings::handler::ResourceInfo> {
305                #resource_list
306            }
307
308            fn read_resource(uri: String) -> __wasmcp_bindings::handler::ResourceResult {
309                #resource_read
310            }
311
312            fn list_prompts() -> Vec<__wasmcp_bindings::handler::Prompt> {
313                #prompt_list
314            }
315
316            fn get_prompt(name: String, arguments: String) -> __wasmcp_bindings::handler::PromptResult {
317                let args_json: ::serde_json::Value = match ::serde_json::from_str(&arguments) {
318                    Ok(v) => v,
319                    Err(e) => {
320                        return __wasmcp_bindings::handler::PromptResult::Error(__wasmcp_bindings::handler::Error {
321                            code: -32700,
322                            message: format!("Invalid JSON: {}", e),
323                            data: None,
324                        });
325                    }
326                };
327                
328                #prompt_get
329            }
330        }
331
332        __wasmcp_bindings::export!(__WasmcpHandler with_types_in __wasmcp_bindings);
333    }
334}