Skip to main content

rust_web_server/mcp/
mod.rs

//! Model Context Protocol (MCP) server over HTTP (single-endpoint JSON-RPC via `POST`).
2//!
//! [`McpServer`] implements [`Application`] so it can be passed directly to
//! [`Server::run`](crate::server::Server::run). Requests that do not target the
//! MCP endpoint fall through to the built-in [`App`] controller chain (static
//! files, health probes, etc.).
6//!
7//! # Quick start
8//!
9//! ```rust,no_run
10//! use rust_web_server::server::Server;
11//! use rust_web_server::mcp::{McpServer, McpContent, PromptMessage};
12//! # #[cfg(not(feature = "http2"))]
13//! # fn main() {
14//! let mcp = McpServer::new("my-server", "1.0")
15//!     // A tool: callable by the AI, like a function
16//!     .tool(
17//!         "echo",
18//!         "Echo text back",
19//!         r#"{"type":"object","properties":{"text":{"type":"string"}},"required":["text"]}"#,
20//!         |args| {
21//!             let text = rust_web_server::mcp::extract_arg(args, "text")
22//!                 .unwrap_or_else(|| "(nothing)".to_string());
23//!             Ok(McpContent::text(text))
24//!         },
25//!     )
26//!     // A resource: data the AI can read by URI
27//!     .resource(
28//!         "docs://{topic}",
29//!         "Documentation",
30//!         "Return documentation for a topic",
31//!         |uri| Ok(McpContent::text(format!("Documentation for: {uri}"))),
32//!     )
33//!     // A prompt template: reusable message structures
34//!     .prompt(
35//!         "summarize",
36//!         "Summarize the given text",
37//!         |args| {
38//!             let text = rust_web_server::mcp::extract_arg(args, "text")
39//!                 .unwrap_or_else(|| "some text".to_string());
40//!             Ok(vec![PromptMessage::user(format!("Please summarize: {text}"))])
41//!         },
42//!     );
43//!
44//! // let (listener, pool) = Server::setup().unwrap();
45//! // Server::run(listener, pool, mcp);
46//! # }
47//! ```
48//!
49//! # MCP endpoint
50//!
51//! All JSON-RPC messages are sent as `POST /mcp` (override with [`.at()`](McpServer::at)).
52//! The server implements the [MCP 2024-11-05 specification](https://spec.modelcontextprotocol.io).
53//!
54//! # Environment variables
55//!
56//! None — configure the server programmatically via the builder.
57
58mod json_rpc;
59
60#[cfg(test)]
61mod tests;
62
63use std::sync::Arc;
64
65use crate::app::App;
66use crate::application::Application;
67use crate::core::New;
68use crate::header::Header;
69use crate::mime_type::MimeType;
70use crate::range::Range;
71use crate::request::Request;
72use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
73use crate::server::ConnectionInfo;
74
75const PROTOCOL_VERSION: &str = "2024-11-05";
76
77// ── public content types ──────────────────────────────────────────────────────
78
79/// Content returned by tool and resource handlers.
80///
81/// Create with [`McpContent::text`] (plain text or JSON strings) or
82/// [`McpContent::json`] (marks MIME type as `application/json`).
83#[derive(Clone, Debug)]
84pub struct McpContent {
85    /// Always `"text"` in the current MCP spec.
86    pub kind: &'static str,
87    /// The content string.
88    pub text: String,
89    /// Optional MIME type override (default `"text/plain"`).
90    pub mime_type: Option<String>,
91}
92
93impl McpContent {
94    /// Plain-text content.
95    pub fn text(s: impl Into<String>) -> Self {
96        McpContent { kind: "text", text: s.into(), mime_type: None }
97    }
98
99    /// JSON content — sets `mimeType` to `application/json`.
100    pub fn json(s: impl Into<String>) -> Self {
101        McpContent { kind: "text", text: s.into(), mime_type: Some("application/json".to_string()) }
102    }
103
104    fn to_content_json(&self) -> String {
105        let escaped = json_escape(&self.text);
106        format!(r#"{{"type":"{}","text":"{}"}}"#, self.kind, escaped)
107    }
108
109    fn mime(&self) -> &str {
110        self.mime_type.as_deref().unwrap_or("text/plain")
111    }
112}
113
114/// A single message in a prompt response.
115#[derive(Clone, Debug)]
116pub struct PromptMessage {
117    /// `"user"` or `"assistant"`.
118    pub role: &'static str,
119    /// The message content.
120    pub content: McpContent,
121}
122
123impl PromptMessage {
124    /// Build a user-role message.
125    pub fn user(text: impl Into<String>) -> Self {
126        PromptMessage { role: "user", content: McpContent::text(text) }
127    }
128
129    /// Build an assistant-role message.
130    pub fn assistant(text: impl Into<String>) -> Self {
131        PromptMessage { role: "assistant", content: McpContent::text(text) }
132    }
133
134    fn to_json(&self) -> String {
135        format!(
136            r#"{{"role":"{}","content":{}}}"#,
137            self.role,
138            self.content.to_content_json(),
139        )
140    }
141}
142
143/// Argument definition for a prompt template.
144pub struct PromptArgDef {
145    pub name: String,
146    pub description: String,
147    pub required: bool,
148}
149
150impl PromptArgDef {
151    pub fn required(name: impl Into<String>, description: impl Into<String>) -> Self {
152        PromptArgDef { name: name.into(), description: description.into(), required: true }
153    }
154
155    pub fn optional(name: impl Into<String>, description: impl Into<String>) -> Self {
156        PromptArgDef { name: name.into(), description: description.into(), required: false }
157    }
158}
159
160// ── internal handler registrations ───────────────────────────────────────────
161
162type ToolFn     = Arc<dyn Fn(&str) -> Result<McpContent, String>    + Send + Sync>;
163type ResourceFn = Arc<dyn Fn(&str) -> Result<McpContent, String>    + Send + Sync>;
164type PromptFn   = Arc<dyn Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync>;
165
166struct ToolDef {
167    name: String,
168    description: String,
169    input_schema: String,
170    handler: ToolFn,
171}
172
173struct ResourceDef {
174    uri_template: String,
175    name: String,
176    description: String,
177    handler: ResourceFn,
178}
179
180struct PromptDef {
181    name: String,
182    description: String,
183    arguments: Vec<PromptArgDef>,
184    handler: PromptFn,
185}
186
187// ── McpServer ─────────────────────────────────────────────────────────────────
188
189/// An HTTP server that implements the MCP 2024-11-05 protocol.
190///
191/// Register tools, resources, and prompts with the builder methods, then pass
192/// the server to [`Server::run`] (or [`Server::run_tls`]) as an [`Application`].
193/// Requests that do not match the MCP endpoint fall through to the built-in
194/// [`App`] controller chain.
195pub struct McpServer {
196    server_name: String,
197    server_version: String,
198    path: String,
199    tools: Vec<ToolDef>,
200    resources: Vec<ResourceDef>,
201    prompts: Vec<PromptDef>,
202}
203
204impl McpServer {
205    /// Create a new `McpServer`.  The default MCP endpoint is `POST /mcp`.
206    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
207        McpServer {
208            server_name: name.into(),
209            server_version: version.into(),
210            path: "/mcp".to_string(),
211            tools: vec![],
212            resources: vec![],
213            prompts: vec![],
214        }
215    }
216
217    /// Override the HTTP path for the MCP endpoint (default `"/mcp"`).
218    pub fn at(mut self, path: impl Into<String>) -> Self {
219        self.path = path.into();
220        self
221    }
222
223    /// Register a callable tool.
224    ///
225    /// - `name` — tool identifier (snake_case recommended)
226    /// - `description` — human-readable description shown to the AI
227    /// - `input_schema` — JSON Schema object for the tool's arguments
228    /// - `handler` — closure receiving the raw `arguments` JSON string
229    ///
230    /// The handler returns [`McpContent`] on success or an error string.  An
231    /// error is returned to the client as `isError: true` (not a protocol error).
232    pub fn tool<F>(mut self, name: &str, description: &str, input_schema: &str, handler: F) -> Self
233    where
234        F: Fn(&str) -> Result<McpContent, String> + Send + Sync + 'static,
235    {
236        self.tools.push(ToolDef {
237            name: name.to_string(),
238            description: description.to_string(),
239            input_schema: input_schema.to_string(),
240            handler: Arc::new(handler),
241        });
242        self
243    }
244
245    /// Register a readable resource.
246    ///
247    /// `uri_template` uses `{param}` placeholders, e.g. `"user://{id}"`.
248    /// The handler receives the full concrete URI string.
249    pub fn resource<F>(mut self, uri_template: &str, name: &str, description: &str, handler: F) -> Self
250    where
251        F: Fn(&str) -> Result<McpContent, String> + Send + Sync + 'static,
252    {
253        self.resources.push(ResourceDef {
254            uri_template: uri_template.to_string(),
255            name: name.to_string(),
256            description: description.to_string(),
257            handler: Arc::new(handler),
258        });
259        self
260    }
261
262    /// Register a prompt template.
263    ///
264    /// The handler receives the raw `arguments` JSON string and returns a
265    /// list of [`PromptMessage`] values.
266    pub fn prompt<F>(mut self, name: &str, description: &str, handler: F) -> Self
267    where
268        F: Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync + 'static,
269    {
270        self.prompts.push(PromptDef {
271            name: name.to_string(),
272            description: description.to_string(),
273            arguments: vec![],
274            handler: Arc::new(handler),
275        });
276        self
277    }
278
279    /// Register a prompt template with explicit argument definitions.
280    pub fn prompt_with_args<F>(
281        mut self,
282        name: &str,
283        description: &str,
284        args: Vec<PromptArgDef>,
285        handler: F,
286    ) -> Self
287    where
288        F: Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync + 'static,
289    {
290        self.prompts.push(PromptDef {
291            name: name.to_string(),
292            description: description.to_string(),
293            arguments: args,
294            handler: Arc::new(handler),
295        });
296        self
297    }
298
299    // ── request dispatch ──────────────────────────────────────────────────────
300
301    /// Process a raw JSON-RPC body and return an HTTP response.
302    pub fn handle_request(&self, body: &str) -> Response {
303        let method = match json_rpc::extract_str(body, "method") {
304            Some(m) => m,
305            None => return rpc_error(None, json_rpc::INVALID_REQUEST, "Missing method"),
306        };
307
308        let id = json_rpc::extract_id(body);
309
310        // Notifications have no `id` — acknowledge with 202 and no body.
311        if method == "notifications/initialized" || (id.is_none() && method != "ping") {
312            return no_content();
313        }
314
315        let result: Result<String, (i32, String)> = match method.as_str() {
316            "initialize"     => self.do_initialize(),
317            "ping"           => Ok("{}".to_string()),
318            "tools/list"     => self.do_tools_list(),
319            "tools/call"     => self.do_tools_call(body),
320            "resources/list" => self.do_resources_list(),
321            "resources/read" => self.do_resources_read(body),
322            "prompts/list"   => self.do_prompts_list(),
323            "prompts/get"    => self.do_prompts_get(body),
324            _                => Err((json_rpc::METHOD_NOT_FOUND, format!("Unknown method: {method}"))),
325        };
326
327        let id_str = id.as_deref().unwrap_or("null");
328
329        match result {
330            Ok(result_json) => json_response(&format!(
331                r#"{{"jsonrpc":"2.0","result":{result_json},"id":{id_str}}}"#
332            )),
333            Err((code, msg)) => {
334                let escaped = json_escape(&msg);
335                json_response(&format!(
336                    r#"{{"jsonrpc":"2.0","error":{{"code":{code},"message":"{escaped}"}},"id":{id_str}}}"#
337                ))
338            }
339        }
340    }
341
342    // ── method handlers ───────────────────────────────────────────────────────
343
344    fn do_initialize(&self) -> Result<String, (i32, String)> {
345        let caps = format!(
346            r#"{{"tools":{{"listChanged":false}},"resources":{{"subscribe":false,"listChanged":false}},"prompts":{{"listChanged":false}}}}"#
347        );
348        Ok(format!(
349            r#"{{"protocolVersion":"{PROTOCOL_VERSION}","capabilities":{caps},"serverInfo":{{"name":"{}","version":"{}"}}}}"#,
350            json_escape(&self.server_name),
351            json_escape(&self.server_version),
352        ))
353    }
354
355    fn do_tools_list(&self) -> Result<String, (i32, String)> {
356        let items: Vec<String> = self.tools.iter().map(|t| {
357            format!(
358                r#"{{"name":"{}","description":"{}","inputSchema":{}}}"#,
359                json_escape(&t.name),
360                json_escape(&t.description),
361                t.input_schema,
362            )
363        }).collect();
364        Ok(format!(r#"{{"tools":[{}]}}"#, items.join(",")))
365    }
366
367    fn do_tools_call(&self, body: &str) -> Result<String, (i32, String)> {
368        let params = json_rpc::extract_raw(body, "params")
369            .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
370        let name = json_rpc::extract_str(&params, "name")
371            .ok_or((json_rpc::INVALID_PARAMS, "Missing tool name".to_string()))?;
372        let args = json_rpc::extract_raw(&params, "arguments")
373            .unwrap_or_else(|| "{}".to_string());
374
375        let tool = self.tools.iter().find(|t| t.name == name)
376            .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Unknown tool: {name}")))?;
377
378        match (tool.handler)(&args) {
379            Ok(c) => Ok(format!(
380                r#"{{"content":[{}],"isError":false}}"#,
381                c.to_content_json(),
382            )),
383            Err(e) => {
384                let escaped = json_escape(&e);
385                Ok(format!(
386                    r#"{{"content":[{{"type":"text","text":"{escaped}"}}],"isError":true}}"#
387                ))
388            }
389        }
390    }
391
392    fn do_resources_list(&self) -> Result<String, (i32, String)> {
393        let items: Vec<String> = self.resources.iter().map(|r| {
394            format!(
395                r#"{{"uri":"{}","name":"{}","description":"{}","mimeType":"text/plain"}}"#,
396                json_escape(&r.uri_template),
397                json_escape(&r.name),
398                json_escape(&r.description),
399            )
400        }).collect();
401        Ok(format!(r#"{{"resources":[{}]}}"#, items.join(",")))
402    }
403
404    fn do_resources_read(&self, body: &str) -> Result<String, (i32, String)> {
405        let params = json_rpc::extract_raw(body, "params")
406            .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
407        let uri = json_rpc::extract_str(&params, "uri")
408            .ok_or((json_rpc::INVALID_PARAMS, "Missing uri".to_string()))?;
409
410        let resource = self.resources.iter().find(|r| uri_matches(&r.uri_template, &uri))
411            .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Resource not found: {uri}")))?;
412
413        match (resource.handler)(&uri) {
414            Ok(c) => {
415                let text_esc = json_escape(&c.text);
416                let uri_esc  = json_escape(&uri);
417                Ok(format!(
418                    r#"{{"contents":[{{"uri":"{uri_esc}","mimeType":"{}","text":"{text_esc}"}}]}}"#,
419                    c.mime(),
420                ))
421            }
422            Err(e) => Err((json_rpc::INVALID_PARAMS, e)),
423        }
424    }
425
426    fn do_prompts_list(&self) -> Result<String, (i32, String)> {
427        let items: Vec<String> = self.prompts.iter().map(|p| {
428            let arg_defs: Vec<String> = p.arguments.iter().map(|a| {
429                format!(
430                    r#"{{"name":"{}","description":"{}","required":{}}}"#,
431                    json_escape(&a.name),
432                    json_escape(&a.description),
433                    a.required,
434                )
435            }).collect();
436            format!(
437                r#"{{"name":"{}","description":"{}","arguments":[{}]}}"#,
438                json_escape(&p.name),
439                json_escape(&p.description),
440                arg_defs.join(","),
441            )
442        }).collect();
443        Ok(format!(r#"{{"prompts":[{}]}}"#, items.join(",")))
444    }
445
446    fn do_prompts_get(&self, body: &str) -> Result<String, (i32, String)> {
447        let params = json_rpc::extract_raw(body, "params")
448            .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
449        let name = json_rpc::extract_str(&params, "name")
450            .ok_or((json_rpc::INVALID_PARAMS, "Missing prompt name".to_string()))?;
451        let args = json_rpc::extract_raw(&params, "arguments")
452            .unwrap_or_else(|| "{}".to_string());
453
454        let prompt = self.prompts.iter().find(|p| p.name == name)
455            .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Unknown prompt: {name}")))?;
456
457        match (prompt.handler)(&args) {
458            Ok(msgs) => {
459                let msg_jsons: Vec<String> = msgs.iter().map(|m| m.to_json()).collect();
460                Ok(format!(
461                    r#"{{"description":"{}","messages":[{}]}}"#,
462                    json_escape(&prompt.description),
463                    msg_jsons.join(","),
464                ))
465            }
466            Err(e) => Err((json_rpc::INVALID_PARAMS, e)),
467        }
468    }
469}
470
471// ── Application ───────────────────────────────────────────────────────────────
472
473impl Application for McpServer {
474    fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
475        if request.request_uri == self.path {
476            return Ok(match request.method.as_str() {
477                "POST" => {
478                    let body = std::str::from_utf8(&request.body).unwrap_or("");
479                    self.handle_request(body)
480                }
481                "OPTIONS" => {
482                    // CORS preflight for browser-based MCP clients
483                    let mut r = Response::new();
484                    r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
485                    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
486                    r.headers.push(Header {
487                        name: "Allow".to_string(),
488                        value: "POST, OPTIONS".to_string(),
489                    });
490                    r
491                }
492                _ => {
493                    let mut r = Response::new();
494                    r.status_code = *STATUS_CODE_REASON_PHRASE.n405_method_not_allowed.status_code;
495                    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n405_method_not_allowed.reason_phrase.to_string();
496                    r.headers.push(Header {
497                        name: "Allow".to_string(),
498                        value: "POST, OPTIONS".to_string(),
499                    });
500                    r.content_range_list = vec![Range::get_content_range(
501                        b"MCP endpoint only accepts POST".to_vec(),
502                        MimeType::TEXT_PLAIN.to_string(),
503                    )];
504                    r
505                }
506            });
507        }
508
509        // Not an MCP path — fall through to the built-in App.
510        App::new().execute(request, connection)
511    }
512}
513
514// ── public helper ─────────────────────────────────────────────────────────────
515
516/// Extract a string argument from a tool/prompt `arguments` JSON object.
517///
518/// ```rust
519/// use rust_web_server::mcp::extract_arg;
520/// assert_eq!(extract_arg(r#"{"text":"hello"}"#, "text").as_deref(), Some("hello"));
521/// assert_eq!(extract_arg(r#"{}"#, "missing"), None);
522/// ```
523pub fn extract_arg(arguments: &str, name: &str) -> Option<String> {
524    json_rpc::extract_str(arguments, name)
525}
526
527// ── internal helpers ──────────────────────────────────────────────────────────
528
529fn json_response(body: &str) -> Response {
530    let mut r = Response::new();
531    r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
532    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
533    r.content_range_list = vec![Range::get_content_range(
534        body.as_bytes().to_vec(),
535        MimeType::APPLICATION_JSON.to_string(),
536    )];
537    r
538}
539
540fn no_content() -> Response {
541    let mut r = Response::new();
542    r.status_code = *STATUS_CODE_REASON_PHRASE.n202_accepted.status_code;
543    r.reason_phrase = STATUS_CODE_REASON_PHRASE.n202_accepted.reason_phrase.to_string();
544    r
545}
546
547fn rpc_error(id: Option<&str>, code: i32, message: &str) -> Response {
548    let id_str  = id.unwrap_or("null");
549    let escaped = json_escape(message);
550    json_response(&format!(
551        r#"{{"jsonrpc":"2.0","error":{{"code":{code},"message":"{escaped}"}},"id":{id_str}}}"#
552    ))
553}
554
555pub(crate) fn json_escape(s: &str) -> String {
556    let mut out = String::with_capacity(s.len() + 4);
557    for ch in s.chars() {
558        match ch {
559            '"'  => out.push_str("\\\""),
560            '\\' => out.push_str("\\\\"),
561            '\n' => out.push_str("\\n"),
562            '\r' => out.push_str("\\r"),
563            '\t' => out.push_str("\\t"),
564            c if (c as u32) < 0x20 => { let _ = std::fmt::Write::write_fmt(&mut out, format_args!("\\u{:04x}", c as u32)); }
565            c    => out.push(c),
566        }
567    }
568    out
569}
570
571fn uri_matches(template: &str, uri: &str) -> bool {
572    // Template `"user://{id}"` matches any URI starting with `"user://"`.
573    match template.find('{') {
574        Some(pos) => uri.starts_with(&template[..pos]),
575        None      => template == uri,
576    }
577}