1mod json_rpc;
59
60#[cfg(test)]
61mod tests;
62
63use std::sync::Arc;
64
65use crate::app::App;
66use crate::application::Application;
67use crate::core::New;
68use crate::header::Header;
69use crate::mime_type::MimeType;
70use crate::range::Range;
71use crate::request::Request;
72use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
73use crate::server::ConnectionInfo;
74
/// MCP protocol revision this server advertises in its `initialize` response.
const PROTOCOL_VERSION: &str = "2024-11-05";
76
/// One piece of content produced by a tool or resource handler.
///
/// The constructors in this module only ever produce text content (`kind == "text"`).
#[derive(Debug, Clone)]
pub struct McpContent {
    /// Content type written to the `type` field of the serialized content object.
    pub kind: &'static str,
    /// Text payload; JSON-escaped when serialized.
    pub text: String,
    /// MIME type reported by `resources/read`; `None` is reported as `text/plain`.
    pub mime_type: Option<String>,
}
92
93impl McpContent {
94 pub fn text(s: impl Into<String>) -> Self {
96 McpContent { kind: "text", text: s.into(), mime_type: None }
97 }
98
99 pub fn json(s: impl Into<String>) -> Self {
101 McpContent { kind: "text", text: s.into(), mime_type: Some("application/json".to_string()) }
102 }
103
104 fn to_content_json(&self) -> String {
105 let escaped = json_escape(&self.text);
106 format!(r#"{{"type":"{}","text":"{}"}}"#, self.kind, escaped)
107 }
108
109 fn mime(&self) -> &str {
110 self.mime_type.as_deref().unwrap_or("text/plain")
111 }
112}
113
114#[derive(Clone, Debug)]
116pub struct PromptMessage {
117 pub role: &'static str,
119 pub content: McpContent,
121}
122
123impl PromptMessage {
124 pub fn user(text: impl Into<String>) -> Self {
126 PromptMessage { role: "user", content: McpContent::text(text) }
127 }
128
129 pub fn assistant(text: impl Into<String>) -> Self {
131 PromptMessage { role: "assistant", content: McpContent::text(text) }
132 }
133
134 fn to_json(&self) -> String {
135 format!(
136 r#"{{"role":"{}","content":{}}}"#,
137 self.role,
138 self.content.to_content_json(),
139 )
140 }
141}
142
/// Declaration of one argument a prompt accepts, advertised by `prompts/list`.
///
/// Derives `Debug` like the other public types of this module (`McpContent`,
/// `PromptMessage`), so it can be logged and inspected.
#[derive(Clone, Debug)]
pub struct PromptArgDef {
    /// Argument name, as it appears in the client's `arguments` object.
    pub name: String,
    /// Human-readable description shown to clients.
    pub description: String,
    /// Whether clients are told the argument must be supplied.
    pub required: bool,
}
150
151impl PromptArgDef {
152 pub fn required(name: impl Into<String>, description: impl Into<String>) -> Self {
153 PromptArgDef { name: name.into(), description: description.into(), required: true }
154 }
155
156 pub fn optional(name: impl Into<String>, description: impl Into<String>) -> Self {
157 PromptArgDef { name: name.into(), description: description.into(), required: false }
158 }
159}
160
161type ToolFn = Arc<dyn Fn(&str) -> Result<McpContent, String> + Send + Sync>;
164type ResourceFn = Arc<dyn Fn(&str) -> Result<McpContent, String> + Send + Sync>;
165type PromptFn = Arc<dyn Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync>;
166
167#[derive(Clone)]
168struct ToolDef {
169 name: String,
170 description: String,
171 input_schema: String,
172 handler: ToolFn,
173}
174
175#[derive(Clone)]
176struct ResourceDef {
177 uri_template: String,
178 name: String,
179 description: String,
180 handler: ResourceFn,
181}
182
183#[derive(Clone)]
184struct PromptDef {
185 name: String,
186 description: String,
187 arguments: Vec<PromptArgDef>,
188 handler: PromptFn,
189}
190
191#[derive(Clone)]
200pub struct McpServer {
201 server_name: String,
202 server_version: String,
203 path: String,
204 tools: Vec<ToolDef>,
205 resources: Vec<ResourceDef>,
206 prompts: Vec<PromptDef>,
207 fallback: Option<Arc<dyn Application + Send + Sync>>,
208 auth_token: Option<String>,
209}
210
211impl McpServer {
212 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
214 McpServer {
215 server_name: name.into(),
216 server_version: version.into(),
217 path: "/mcp".to_string(),
218 tools: vec![],
219 resources: vec![],
220 prompts: vec![],
221 fallback: None,
222 auth_token: None,
223 }
224 }
225
226 pub fn require_bearer(mut self, token: impl Into<String>) -> Self {
251 self.auth_token = Some(token.into());
252 self
253 }
254
255 pub fn wrap(mut self, app: impl Application + Send + Sync + 'static) -> Self {
283 self.fallback = Some(Arc::new(app));
284 self
285 }
286
287 pub fn at(mut self, path: impl Into<String>) -> Self {
289 self.path = path.into();
290 self
291 }
292
293 pub fn tool<F>(mut self, name: &str, description: &str, input_schema: &str, handler: F) -> Self
303 where
304 F: Fn(&str) -> Result<McpContent, String> + Send + Sync + 'static,
305 {
306 self.tools.push(ToolDef {
307 name: name.to_string(),
308 description: description.to_string(),
309 input_schema: input_schema.to_string(),
310 handler: Arc::new(handler),
311 });
312 self
313 }
314
315 pub fn resource<F>(mut self, uri_template: &str, name: &str, description: &str, handler: F) -> Self
320 where
321 F: Fn(&str) -> Result<McpContent, String> + Send + Sync + 'static,
322 {
323 self.resources.push(ResourceDef {
324 uri_template: uri_template.to_string(),
325 name: name.to_string(),
326 description: description.to_string(),
327 handler: Arc::new(handler),
328 });
329 self
330 }
331
332 pub fn prompt<F>(mut self, name: &str, description: &str, handler: F) -> Self
337 where
338 F: Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync + 'static,
339 {
340 self.prompts.push(PromptDef {
341 name: name.to_string(),
342 description: description.to_string(),
343 arguments: vec![],
344 handler: Arc::new(handler),
345 });
346 self
347 }
348
349 pub fn prompt_with_args<F>(
351 mut self,
352 name: &str,
353 description: &str,
354 args: Vec<PromptArgDef>,
355 handler: F,
356 ) -> Self
357 where
358 F: Fn(&str) -> Result<Vec<PromptMessage>, String> + Send + Sync + 'static,
359 {
360 self.prompts.push(PromptDef {
361 name: name.to_string(),
362 description: description.to_string(),
363 arguments: args,
364 handler: Arc::new(handler),
365 });
366 self
367 }
368
369 pub fn handle_request(&self, body: &str) -> Response {
373 let method = match json_rpc::extract_str(body, "method") {
374 Some(m) => m,
375 None => return rpc_error(None, json_rpc::INVALID_REQUEST, "Missing method"),
376 };
377
378 let id = json_rpc::extract_id(body);
379
380 if method == "notifications/initialized" || (id.is_none() && method != "ping") {
382 return no_content();
383 }
384
385 let result: Result<String, (i32, String)> = match method.as_str() {
386 "initialize" => self.do_initialize(),
387 "ping" => Ok("{}".to_string()),
388 "tools/list" => self.do_tools_list(),
389 "tools/call" => self.do_tools_call(body),
390 "resources/list" => self.do_resources_list(),
391 "resources/read" => self.do_resources_read(body),
392 "prompts/list" => self.do_prompts_list(),
393 "prompts/get" => self.do_prompts_get(body),
394 _ => Err((json_rpc::METHOD_NOT_FOUND, format!("Unknown method: {method}"))),
395 };
396
397 let id_str = id.as_deref().unwrap_or("null");
398
399 match result {
400 Ok(result_json) => json_response(&format!(
401 r#"{{"jsonrpc":"2.0","result":{result_json},"id":{id_str}}}"#
402 )),
403 Err((code, msg)) => {
404 let escaped = json_escape(&msg);
405 json_response(&format!(
406 r#"{{"jsonrpc":"2.0","error":{{"code":{code},"message":"{escaped}"}},"id":{id_str}}}"#
407 ))
408 }
409 }
410 }
411
412 fn do_initialize(&self) -> Result<String, (i32, String)> {
415 let caps = format!(
416 r#"{{"tools":{{"listChanged":false}},"resources":{{"subscribe":false,"listChanged":false}},"prompts":{{"listChanged":false}}}}"#
417 );
418 Ok(format!(
419 r#"{{"protocolVersion":"{PROTOCOL_VERSION}","capabilities":{caps},"serverInfo":{{"name":"{}","version":"{}"}}}}"#,
420 json_escape(&self.server_name),
421 json_escape(&self.server_version),
422 ))
423 }
424
425 fn do_tools_list(&self) -> Result<String, (i32, String)> {
426 let items: Vec<String> = self.tools.iter().map(|t| {
427 format!(
428 r#"{{"name":"{}","description":"{}","inputSchema":{}}}"#,
429 json_escape(&t.name),
430 json_escape(&t.description),
431 t.input_schema,
432 )
433 }).collect();
434 Ok(format!(r#"{{"tools":[{}]}}"#, items.join(",")))
435 }
436
437 fn do_tools_call(&self, body: &str) -> Result<String, (i32, String)> {
438 let params = json_rpc::extract_raw(body, "params")
439 .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
440 let name = json_rpc::extract_str(¶ms, "name")
441 .ok_or((json_rpc::INVALID_PARAMS, "Missing tool name".to_string()))?;
442 let args = json_rpc::extract_raw(¶ms, "arguments")
443 .unwrap_or_else(|| "{}".to_string());
444
445 let tool = self.tools.iter().find(|t| t.name == name)
446 .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Unknown tool: {name}")))?;
447
448 match (tool.handler)(&args) {
449 Ok(c) => Ok(format!(
450 r#"{{"content":[{}],"isError":false}}"#,
451 c.to_content_json(),
452 )),
453 Err(e) => {
454 let escaped = json_escape(&e);
455 Ok(format!(
456 r#"{{"content":[{{"type":"text","text":"{escaped}"}}],"isError":true}}"#
457 ))
458 }
459 }
460 }
461
462 fn do_resources_list(&self) -> Result<String, (i32, String)> {
463 let items: Vec<String> = self.resources.iter().map(|r| {
464 format!(
465 r#"{{"uri":"{}","name":"{}","description":"{}","mimeType":"text/plain"}}"#,
466 json_escape(&r.uri_template),
467 json_escape(&r.name),
468 json_escape(&r.description),
469 )
470 }).collect();
471 Ok(format!(r#"{{"resources":[{}]}}"#, items.join(",")))
472 }
473
474 fn do_resources_read(&self, body: &str) -> Result<String, (i32, String)> {
475 let params = json_rpc::extract_raw(body, "params")
476 .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
477 let uri = json_rpc::extract_str(¶ms, "uri")
478 .ok_or((json_rpc::INVALID_PARAMS, "Missing uri".to_string()))?;
479
480 let resource = self.resources.iter().find(|r| uri_matches(&r.uri_template, &uri))
481 .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Resource not found: {uri}")))?;
482
483 match (resource.handler)(&uri) {
484 Ok(c) => {
485 let text_esc = json_escape(&c.text);
486 let uri_esc = json_escape(&uri);
487 Ok(format!(
488 r#"{{"contents":[{{"uri":"{uri_esc}","mimeType":"{}","text":"{text_esc}"}}]}}"#,
489 c.mime(),
490 ))
491 }
492 Err(e) => Err((json_rpc::INVALID_PARAMS, e)),
493 }
494 }
495
496 fn do_prompts_list(&self) -> Result<String, (i32, String)> {
497 let items: Vec<String> = self.prompts.iter().map(|p| {
498 let arg_defs: Vec<String> = p.arguments.iter().map(|a| {
499 format!(
500 r#"{{"name":"{}","description":"{}","required":{}}}"#,
501 json_escape(&a.name),
502 json_escape(&a.description),
503 a.required,
504 )
505 }).collect();
506 format!(
507 r#"{{"name":"{}","description":"{}","arguments":[{}]}}"#,
508 json_escape(&p.name),
509 json_escape(&p.description),
510 arg_defs.join(","),
511 )
512 }).collect();
513 Ok(format!(r#"{{"prompts":[{}]}}"#, items.join(",")))
514 }
515
516 fn do_prompts_get(&self, body: &str) -> Result<String, (i32, String)> {
517 let params = json_rpc::extract_raw(body, "params")
518 .ok_or((json_rpc::INVALID_PARAMS, "Missing params".to_string()))?;
519 let name = json_rpc::extract_str(¶ms, "name")
520 .ok_or((json_rpc::INVALID_PARAMS, "Missing prompt name".to_string()))?;
521 let args = json_rpc::extract_raw(¶ms, "arguments")
522 .unwrap_or_else(|| "{}".to_string());
523
524 let prompt = self.prompts.iter().find(|p| p.name == name)
525 .ok_or_else(|| (json_rpc::INVALID_PARAMS, format!("Unknown prompt: {name}")))?;
526
527 match (prompt.handler)(&args) {
528 Ok(msgs) => {
529 let msg_jsons: Vec<String> = msgs.iter().map(|m| m.to_json()).collect();
530 Ok(format!(
531 r#"{{"description":"{}","messages":[{}]}}"#,
532 json_escape(&prompt.description),
533 msg_jsons.join(","),
534 ))
535 }
536 Err(e) => Err((json_rpc::INVALID_PARAMS, e)),
537 }
538 }
539}
540
541impl Application for McpServer {
544 fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
545 if request.request_uri == self.path {
546 if let Some(expected) = &self.auth_token {
548 let provided = request.headers.iter()
549 .find(|h| h.name.eq_ignore_ascii_case("authorization"))
550 .map(|h| h.value.as_str())
551 .unwrap_or("");
552 let bearer = provided.strip_prefix("Bearer ").unwrap_or("");
553 if bearer != expected.as_str() {
554 return Ok(unauthorized());
555 }
556 }
557
558 return Ok(match request.method.as_str() {
559 "POST" => {
560 let body = std::str::from_utf8(&request.body).unwrap_or("");
561 self.handle_request(body)
562 }
563 "OPTIONS" => {
564 let mut r = Response::new();
566 r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
567 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
568 r.headers.push(Header {
569 name: "Allow".to_string(),
570 value: "POST, OPTIONS".to_string(),
571 });
572 r
573 }
574 _ => {
575 let mut r = Response::new();
576 r.status_code = *STATUS_CODE_REASON_PHRASE.n405_method_not_allowed.status_code;
577 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n405_method_not_allowed.reason_phrase.to_string();
578 r.headers.push(Header {
579 name: "Allow".to_string(),
580 value: "POST, OPTIONS".to_string(),
581 });
582 r.content_range_list = vec![Range::get_content_range(
583 b"MCP endpoint only accepts POST".to_vec(),
584 MimeType::TEXT_PLAIN.to_string(),
585 )];
586 r
587 }
588 });
589 }
590
591 match &self.fallback {
593 Some(app) => app.execute(request, connection),
594 None => App::new().execute(request, connection),
595 }
596 }
597}
598
599pub fn extract_arg(arguments: &str, name: &str) -> Option<String> {
609 json_rpc::extract_str(arguments, name)
610}
611
612fn json_response(body: &str) -> Response {
615 let mut r = Response::new();
616 r.status_code = *STATUS_CODE_REASON_PHRASE.n200_ok.status_code;
617 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n200_ok.reason_phrase.to_string();
618 r.content_range_list = vec![Range::get_content_range(
619 body.as_bytes().to_vec(),
620 MimeType::APPLICATION_JSON.to_string(),
621 )];
622 r
623}
624
625fn no_content() -> Response {
626 let mut r = Response::new();
627 r.status_code = *STATUS_CODE_REASON_PHRASE.n202_accepted.status_code;
628 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n202_accepted.reason_phrase.to_string();
629 r
630}
631
632fn unauthorized() -> Response {
633 let mut r = Response::new();
634 r.status_code = *STATUS_CODE_REASON_PHRASE.n401_unauthorized.status_code;
635 r.reason_phrase = STATUS_CODE_REASON_PHRASE.n401_unauthorized.reason_phrase.to_string();
636 r.headers.push(Header {
637 name: "WWW-Authenticate".to_string(),
638 value: "Bearer".to_string(),
639 });
640 r.content_range_list = vec![Range::get_content_range(
641 b"Unauthorized".to_vec(),
642 MimeType::TEXT_PLAIN.to_string(),
643 )];
644 r
645}
646
647fn rpc_error(id: Option<&str>, code: i32, message: &str) -> Response {
648 let id_str = id.unwrap_or("null");
649 let escaped = json_escape(message);
650 json_response(&format!(
651 r#"{{"jsonrpc":"2.0","error":{{"code":{code},"message":"{escaped}"}},"id":{id_str}}}"#
652 ))
653}
654
/// Escapes `s` for use inside a JSON string literal (the surrounding quotes are not added).
///
/// Quotes, backslashes, `\n`, `\r` and `\t` get their short escapes. Other control
/// characters below U+0020 become `\u00XX` with lowercase hex digits. Everything else,
/// including non-ASCII text, is copied unchanged.
pub(crate) fn json_escape(s: &str) -> String {
    use std::fmt::Write as _;

    let mut escaped = String::with_capacity(s.len() + 4);
    for ch in s.chars() {
        let short = match ch {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(seq) = short {
            escaped.push_str(seq);
        } else if u32::from(ch) < 0x20 {
            // Writing into a `String` cannot fail, so the fmt::Result is ignored.
            let _ = write!(escaped, "\\u{:04x}", u32::from(ch));
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
670
/// Returns whether `uri` matches the resource URI `template`.
///
/// A template without `{` must equal `uri` exactly. Otherwise each `{...}` placeholder
/// matches any run of characters (possibly empty, `/` included), and every literal
/// piece of the template must appear in `uri` in order. In particular, text after the
/// last placeholder must end the URI, so `users/{id}/profile` does not match
/// `users/7/posts`. An unterminated `{` matches everything after it, so `a{b` is a prefix
/// match on `a`.
fn uri_matches(template: &str, uri: &str) -> bool {
    // Split the template into the literal pieces around its placeholders:
    // "a/{x}/b/{y}.json" -> ["a/", "/b/", ".json"].
    let mut literals: Vec<&str> = Vec::new();
    let mut rest = template;
    loop {
        match rest.find('{') {
            None => {
                literals.push(rest);
                break;
            }
            Some(open) => {
                literals.push(&rest[..open]);
                match rest[open..].find('}') {
                    Some(close) => rest = &rest[open + close + 1..],
                    None => {
                        // Unterminated placeholder: it swallows the rest of the URI.
                        literals.push("");
                        break;
                    }
                }
            }
        }
    }

    if literals.len() == 1 {
        return template == uri;
    }

    let first = literals[0];
    let last = literals[literals.len() - 1];
    if !uri.starts_with(first) {
        return false;
    }
    let remaining = &uri[first.len()..];
    if !remaining.ends_with(last) {
        return false;
    }
    let mut remaining = &remaining[..remaining.len() - last.len()];

    // Placeholders sit between consecutive literals, so taking the leftmost occurrence of
    // each middle literal never rules out a match that a later occurrence would allow.
    for literal in &literals[1..literals.len() - 1] {
        match remaining.find(literal) {
            Some(pos) => remaining = &remaining[pos + literal.len()..],
            None => return false,
        }
    }
    true
}
677}