1mod json_rpc;
58
59#[cfg(test)]
60mod tests;
61
62use std::sync::Arc;
63
64use crate::app::App;
65use crate::application::Application;
66use crate::core::New;
67use crate::header::Header;
68use crate::mime_type::MimeType;
69use crate::range::Range;
70use crate::request::Request;
71use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
72use crate::server::ConnectionInfo;
73
/// MCP protocol revision advertised as `protocolVersion` in the `initialize` result.
const PROTOCOL_VERSION: &str = "2024-11-05";
75
76#[derive(Clone, Debug)]
83pub struct McpContent {
84 pub kind: &'static str,
86 pub text: String,
88 pub mime_type: Option<String>,
90}
91
92impl McpContent {
93 pub fn text(s: impl Into<String>) -> Self {
95 McpContent { kind: "text", text: s.into(), mime_type: None }
96 }
97
98 pub fn json(s: impl Into<String>) -> Self {
100 McpContent { kind: "text", text: s.into(), mime_type: Some("application/json".to_string()) }
101 }
102
103 fn to_content_json(&self) -> String {
104 let escaped = json_escape(&self.text);
105 format!(r#"{{"type":"{}","text":"{}"}}"#, self.kind, escaped)
106 }
107
108 fn mime(&self) -> &str {
109 self.mime_type.as_deref().unwrap_or("text/plain")
110 }
111}
112
113#[derive(Clone, Debug)]
115pub struct PromptMessage {
116 pub role: &'static str,
118 pub content: McpContent,
120}
121
122impl PromptMessage {
123 pub fn user(text: impl Into<String>) -> Self {
125 PromptMessage { role: "user", content: McpContent::text(text) }
126 }
127
128 pub fn assistant(text: impl Into<String>) -> Self {
130 PromptMessage { role: "assistant", content: McpContent::text(text) }
131 }
132
133 fn to_json(&self) -> String {
134 format!(
135 r#"{{"role":"{}","content":{}}}"#,
136 self.role,
137 self.content.to_content_json(),
138 )
139 }
140}
141
/// Declaration of one argument a prompt accepts, advertised by `prompts/list`.
#[derive(Clone)]
pub struct PromptArgDef {
    /// Argument name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Emitted as `"required"` in the prompt listing.
    pub required: bool,
}

impl PromptArgDef {
    /// An argument advertised with `"required": true`.
    pub fn required(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self::build(name.into(), description.into(), true)
    }

    /// An argument advertised with `"required": false`.
    pub fn optional(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self::build(name.into(), description.into(), false)
    }

    /// Shared constructor behind `required` and `optional`.
    fn build(name: String, description: String, required: bool) -> Self {
        Self { name, description, required }
    }
}
159
/// Tool handler: receives the raw JSON text of a `tools/call` request's `arguments`.
/// An `Err` is reported to the client as content flagged `isError: true`.
type ToolFn = Arc<dyn Fn(&str) -> Result<McpContent, String> + Send + Sync>;
/// Resource reader: receives the concrete URI requested by `resources/read`.
/// An `Err` becomes a JSON-RPC error response.
type ResourceFn = Arc<dyn Fn(&str) -> Result<McpContent, String> + Send + Sync>;
/// Prompt builder: receives the raw JSON text of a `prompts/get` request's `arguments`.
type PromptFn = Arc<dyn Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync>;
165
/// A registered tool: the metadata listed by `tools/list` plus its handler.
#[derive(Clone)]
struct ToolDef {
    /// Tool name; `tools/call` selects a tool by exact match on this.
    name: String,
    description: String,
    /// JSON Schema in JSON text form, emitted as the tool's `inputSchema`.
    input_schema: String,
    handler: ToolFn,
}
173
/// A registered resource: the metadata listed by `resources/list` plus its reader.
#[derive(Clone)]
struct ResourceDef {
    /// Concrete URI, or a template whose `{...}` placeholders are matched by `uri_matches`.
    uri_template: String,
    name: String,
    description: String,
    handler: ResourceFn,
}
181
/// A registered prompt: the metadata listed by `prompts/list` plus its message builder.
#[derive(Clone)]
struct PromptDef {
    /// Prompt name; `prompts/get` selects a prompt by exact match on this.
    name: String,
    description: String,
    /// Argument declarations advertised by `prompts/list` (may be empty).
    arguments: Vec<PromptArgDef>,
    handler: PromptFn,
}
189
/// An MCP server exposed over HTTP as an [`Application`].
///
/// Configured with `McpServer::new` and chained builder methods. Requests for
/// the configured path are handled as JSON-RPC; other requests are delegated.
#[derive(Clone)]
pub struct McpServer {
    /// Reported as `serverInfo.name` in the `initialize` result.
    server_name: String,
    /// Reported as `serverInfo.version` in the `initialize` result.
    server_version: String,
    /// Request path the MCP endpoint answers on.
    path: String,
    tools: Vec<ToolDef>,
    resources: Vec<ResourceDef>,
    prompts: Vec<PromptDef>,
    /// Application for all other requests; `None` means a default `App` is used.
    fallback: Option<Arc<dyn Application + Send + Sync>>,
    /// Bearer token required on the MCP path; `None` disables the check.
    auth_token: Option<String>,
}
209
210impl McpServer {
211 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
213 McpServer {
214 server_name: name.into(),
215 server_version: version.into(),
216 path: "/mcp".to_string(),
217 tools: vec![],
218 resources: vec![],
219 prompts: vec![],
220 fallback: None,
221 auth_token: None,
222 }
223 }
224
225 pub fn require_bearer(mut self, token: impl Into<String>) -> Self {
250 self.auth_token = Some(token.into());
251 self
252 }
253
254 pub fn wrap(mut self, app: impl Application + Send + Sync + 'static) -> Self {
279 self.fallback = Some(Arc::new(app));
280 self
281 }
282
283 pub fn at(mut self, path: impl Into<String>) -> Self {
285 self.path = path.into();
286 self
287 }
288
289 pub fn tool<F>(mut self, name: &str, description: &str, input_schema: &str, handler: F) -> Self
299 where
300 F: Fn(&str) -> Result<McpContent, String> + Send + Sync + 'static,
301 {
302 self.tools.push(ToolDef {
303 name: name.to_string(),
304 description: description.to_string(),
305 input_schema: input_schema.to_string(),
306 handler: Arc::new(handler),
307 });
308 self
309 }
310
311 pub fn resource<F>(mut self, uri_template: &str, name: &str, description: &str, handler: F) -> Self
316 where
317 F: Fn(&str) -> Result<McpContent, String> + Send + Sync + 'static,
318 {
319 self.resources.push(ResourceDef {
320 uri_template: uri_template.to_string(),
321 name: name.to_string(),
322 description: description.to_string(),
323 handler: Arc::new(handler),
324 });
325 self
326 }
327
328 pub fn prompt<F>(mut self, name: &str, description: &str, handler: F) -> Self
333 where
334 F: Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync + 'static,
335 {
336 self.prompts.push(PromptDef {
337 name: name.to_string(),
338 description: description.to_string(),
339 arguments: vec![],
340 handler: Arc::new(handler),
341 });
342 self
343 }
344
345 pub fn prompt_with_args<F>(
347 mut self,
348 name: &str,
349 description: &str,
350 args: Vec<PromptArgDef>,
351 handler: F,
352 ) -> Self
353 where
354 F: Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync + 'static,
355 {
356 self.prompts.push(PromptDef {
357 name: name.to_string(),
358 description: description.to_string(),
359 arguments: args,
360 handler: Arc::new(handler),
361 });
362 self
363 }
364
365 pub fn handle_request(&self, body: &str) -> Response {
369 let method = match json_rpc::extract_str(body, "method") {
370 Some(m) => m,
371 None => return rpc_error(None, json_rpc::INVALID_REQUEST, "Missing method"),
372 };
373
374 let id = json_rpc::extract_id(body);
375
376 if method == "notifications/initialized" || (id.is_none() && method != "ping") {
378 return no_content();
379 }
380
381 let result: Result<String, (i32, String)> = match method.as_str() {
382 "initialize" => self.do_initialize(),
383 "ping" => Ok("{}".to_string()),
384 "tools/list" => self.do_tools_list(),
385 "tools/call" => self.do_tools_call(body),
386 "resources/list" => self.do_resources_list(),
387 "resources/read" => self.do_resources_read(body),
388 "prompts/list" => self.do_prompts_list(),
389 "prompts/get" => self.do_prompts_get(body),
390 _ => Err((json_rpc::METHOD_NOT_FOUND, format!("Unknown method: {method}"))),
391 };
392
393 let id_str = id.as_deref().unwrap_or("null");
394
395 match result {
396 Ok(result_json) => json_response(&format!(
397 r#"{{"jsonrpc":"2.0","result":{result_json},"id":{id_str}}}"#
398 )),
399 Err((code, msg)) => {
400 let escaped = json_escape(&msg);
401 json_response(&format!(
402 r#"{{"jsonrpc":"2.0","error":{{"code":{code},"message":"{escaped}"}},"id":{id_str}}}"#
403 ))
404 }
405 }
406 }
407
408 fn do_initialize(&self) -> Result<String, (i32, String)> {
411 let caps = format!(
412 r#"{{"tools":{{"listChanged":false}},"resources":{{"subscribe":false,"listChanged":false}},"prompts":{{"listChanged":false}}}}"#
413 );
414 Ok(format!(
415 r#"{{"protocolVersion":"{PROTOCOL_VERSION}","capabilities":{caps},"serverInfo":{{"name":"{}","version":"{}"}}}}"#,
416 json_escape(&self.server_name),
417 json_escape(&self.server_version),
418 ))
419 }
420
421 fn do_tools_list(&self) -> Result<String, (i32, String)> {
422 let items: Vec<String> = self.tools.iter().map(|t| {
423 format!(
424 r#"{{"name":"{}","description":"{}","inputSchema":{}}}"#,
425 json_escape(&t.name),
426 json_escape(&t.description),
427 t.input_schema,
428 )
429 }).collect();
430 Ok(format!(r#"{{"tools":[{}]}}"#, items.join(",")))
431 }
432
433 fn do_tools_call(&self, body: &str) -> Result<String, (i32, String)> {
434 let params = json_rpc::extract_raw(body, "params")
435 .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
436 let name = json_rpc::extract_str(¶ms, "name")
437 .ok_or((json_rpc::INVALID_PARAMS, "Missing tool name".to_string()))?;
438 let args = json_rpc::extract_raw(¶ms, "arguments")
439 .unwrap_or_else(|| "{}".to_string());
440
441 let tool = self.tools.iter().find(|t| t.name == name)
442 .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Unknown tool: {name}")))?;
443
444 match (tool.handler)(&args) {
445 Ok(c) => Ok(format!(
446 r#"{{"content":[{}],"isError":false}}"#,
447 c.to_content_json(),
448 )),
449 Err(e) => {
450 let escaped = json_escape(&e);
451 Ok(format!(
452 r#"{{"content":[{{"type":"text","text":"{escaped}"}}],"isError":true}}"#
453 ))
454 }
455 }
456 }
457
458 fn do_resources_list(&self) -> Result<String, (i32, String)> {
459 let items: Vec<String> = self.resources.iter().map(|r| {
460 format!(
461 r#"{{"uri":"{}","name":"{}","description":"{}","mimeType":"text/plain"}}"#,
462 json_escape(&r.uri_template),
463 json_escape(&r.name),
464 json_escape(&r.description),
465 )
466 }).collect();
467 Ok(format!(r#"{{"resources":[{}]}}"#, items.join(",")))
468 }
469
470 fn do_resources_read(&self, body: &str) -> Result<String, (i32, String)> {
471 let params = json_rpc::extract_raw(body, "params")
472 .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
473 let uri = json_rpc::extract_str(¶ms, "uri")
474 .ok_or((json_rpc::INVALID_PARAMS, "Missing uri".to_string()))?;
475
476 let resource = self.resources.iter().find(|r| uri_matches(&r.uri_template, &uri))
477 .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Resource not found: {uri}")))?;
478
479 match (resource.handler)(&uri) {
480 Ok(c) => {
481 let text_esc = json_escape(&c.text);
482 let uri_esc = json_escape(&uri);
483 Ok(format!(
484 r#"{{"contents":[{{"uri":"{uri_esc}","mimeType":"{}","text":"{text_esc}"}}]}}"#,
485 c.mime(),
486 ))
487 }
488 Err(e) => Err((json_rpc::INVALID_PARAMS, e)),
489 }
490 }
491
492 fn do_prompts_list(&self) -> Result<String, (i32, String)> {
493 let items: Vec<String> = self.prompts.iter().map(|p| {
494 let arg_defs: Vec<String> = p.arguments.iter().map(|a| {
495 format!(
496 r#"{{"name":"{}","description":"{}","required":{}}}"#,
497 json_escape(&a.name),
498 json_escape(&a.description),
499 a.required,
500 )
501 }).collect();
502 format!(
503 r#"{{"name":"{}","description":"{}","arguments":[{}]}}"#,
504 json_escape(&p.name),
505 json_escape(&p.description),
506 arg_defs.join(","),
507 )
508 }).collect();
509 Ok(format!(r#"{{"prompts":[{}]}}"#, items.join(",")))
510 }
511
512 fn do_prompts_get(&self, body: &str) -> Result<String, (i32, String)> {
513 let params = json_rpc::extract_raw(body, "params")
514 .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
515 let name = json_rpc::extract_str(¶ms, "name")
516 .ok_or((json_rpc::INVALID_PARAMS, "Missing prompt name".to_string()))?;
517 let args = json_rpc::extract_raw(¶ms, "arguments")
518 .unwrap_or_else(|| "{}".to_string());
519
520 let prompt = self.prompts.iter().find(|p| p.name == name)
521 .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Unknown prompt: {name}")))?;
522
523 match (prompt.handler)(&args) {
524 Ok(msgs) => {
525 let msg_jsons: Vec<String> = msgs.iter().map(|m| m.to_json()).collect();
526 Ok(format!(
527 r#"{{"description":"{}","messages":[{}]}}"#,
528 json_escape(&prompt.description),
529 msg_jsons.join(","),
530 ))
531 }
532 Err(e) => Err((json_rpc::INVALID_PARAMS, e)),
533 }
534 }
535}
536
537impl Application for McpServer {
540 fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
541 if request.request_uri == self.path {
542 if let Some(expected) = &self.auth_token {
544 let provided = request.headers.iter()
545 .find(|h| h.name.eq_ignore_ascii_case("authorization"))
546 .map(|h| h.value.as_str())
547 .unwrap_or("");
548 let bearer = provided.strip_prefix("Bearer ").unwrap_or("");
549 if bearer != expected.as_str() {
550 return Ok(unauthorized());
551 }
552 }
553
554 return Ok(match request.method.as_str() {
555 "POST" => {
556 let body = std::str::from_utf8(&request.body).unwrap_or("");
557 self.handle_request(body)
558 }
559 "OPTIONS" => {
560 let mut r = Response::new();
562 r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
563 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
564 r.headers.push(Header {
565 name: "Allow".to_string(),
566 value: "POST, OPTIONS".to_string(),
567 });
568 r
569 }
570 _ => {
571 let mut r = Response::new();
572 r.status_code = *STATUS_CODE_REASON_PHRASE.n405_method_not_allowed.status_code;
573 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n405_method_not_allowed.reason_phrase.to_string();
574 r.headers.push(Header {
575 name: "Allow".to_string(),
576 value: "POST, OPTIONS".to_string(),
577 });
578 r.content_range_list = vec![Range::get_content_range(
579 b"MCP endpoint only accepts POST".to_vec(),
580 MimeType::TEXT_PLAIN.to_string(),
581 )];
582 r
583 }
584 });
585 }
586
587 match &self.fallback {
589 Some(app) => app.execute(request, connection),
590 None => App::new().execute(request, connection),
591 }
592 }
593}
594
595pub fn extract_arg(arguments: &str, name: &str) -> Option<String> {
605 json_rpc::extract_str(arguments, name)
606}
607
608fn json_response(body: &str) -> Response {
611 let mut r = Response::new();
612 r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
613 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
614 r.content_range_list = vec![Range::get_content_range(
615 body.as_bytes().to_vec(),
616 MimeType::APPLICATION_JSON.to_string(),
617 )];
618 r
619}
620
621fn no_content() -> Response {
622 let mut r = Response::new();
623 r.status_code = *STATUS_CODE_REASON_PHRASE.n202_accepted.status_code;
624 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n202_accepted.reason_phrase.to_string();
625 r
626}
627
628fn unauthorized() -> Response {
629 let mut r = Response::new();
630 r.status_code = *STATUS_CODE_REASON_PHRASE.n401_unauthorized.status_code;
631 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n401_unauthorized.reason_phrase.to_string();
632 r.headers.push(Header {
633 name: "WWW-Authenticate".to_string(),
634 value: "Bearer".to_string(),
635 });
636 r.content_range_list = vec![Range::get_content_range(
637 b"Unauthorized".to_vec(),
638 MimeType::TEXT_PLAIN.to_string(),
639 )];
640 r
641}
642
643fn rpc_error(id: Option<&str>, code: i32, message: &str) -> Response {
644 let id_str = id.unwrap_or("null");
645 let escaped = json_escape(message);
646 json_response(&format!(
647 r#"{{"jsonrpc":"2.0","error":{{"code":{code},"message":"{escaped}"}},"id":{id_str}}}"#
648 ))
649}
650
/// Escapes `s` for embedding inside a JSON string literal (without the quotes).
///
/// Quote, backslash, `\n`, `\r` and `\t` use their short escapes. Any other
/// character below U+0020 uses the `\u00XX` form (lowercase hex). Everything
/// else, including non-ASCII, is copied unchanged.
pub(crate) fn json_escape(s: &str) -> String {
    use std::fmt::Write as _;

    let mut escaped = String::with_capacity(s.len() + 4);
    for c in s.chars() {
        let short = match c {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(seq) = short {
            escaped.push_str(seq);
        } else if u32::from(c) < 0x20 {
            // JSON forbids raw C0 control characters inside strings.
            let _ = write!(escaped, "\\u{:04x}", u32::from(c));
        } else {
            escaped.push(c);
        }
    }
    escaped
}
666
/// Reports whether `uri` is addressed by the resource `template`.
///
/// A template without `{` must equal `uri` exactly. Otherwise each `{...}`
/// placeholder matches any (possibly empty) run of characters, and the literal
/// text between placeholders must appear in `uri` in order, anchored at both
/// ends. For example, `users://{id}/profile` matches `users://42/profile` but
/// not `users://42/settings`. An unterminated `{` acts as a placeholder running
/// to the end of the template. Placeholder contents (e.g. RFC 6570 operators)
/// are not interpreted.
fn uri_matches(template: &str, uri: &str) -> bool {
    if !template.contains('{') {
        return template == uri;
    }

    // Split the template into the literal text around each placeholder.
    let mut literals: Vec<&str> = Vec::new();
    let mut rest = template;
    while let Some(open) = rest.find('{') {
        literals.push(&rest[..open]);
        rest = match rest[open..].find('}') {
            Some(close) => &rest[open + close + 1..],
            None => "",
        };
    }
    literals.push(rest);

    // At least one `{` was present, so there are at least two literals
    // (possibly empty): the prefix before it and the suffix after the last one.
    let first = literals[0];
    let last = literals[literals.len() - 1];
    let middle = &literals[1..literals.len() - 1];

    let mut remaining = match uri.strip_prefix(first) {
        Some(r) => r,
        None => return false,
    };
    for &lit in middle {
        // Taking the leftmost occurrence leaves the most room for later literals.
        match remaining.find(lit) {
            Some(at) => remaining = &remaining[at + lit.len()..],
            None => return false,
        }
    }
    remaining.ends_with(last)
}