1use axum::{
4 extract::State,
5 http::StatusCode,
6 routing::{get, post},
7 Json, Router,
8};
9use serde::{Deserialize, Serialize};
10use std::{error::Error, fmt, sync::Arc};
11use tokitai::mcp;
12use tokitai_core::serde_types;
13use tower_http::{
14 cors::{Any, CorsLayer},
15 trace::TraceLayer,
16};
17use tracing::{info, warn};
18
/// Errors reported by the MCP server types in this module.
#[derive(Debug)]
pub enum ServerError {
    /// Carries the name of a tool that could not be found.
    ToolNotFound(String),
    /// Carries the message of a failed tool execution.
    ToolExecutionError(String),
    /// Carries a description of invalid tool arguments.
    InvalidArguments(String),
    /// Binding the listener failed, or the serve loop ended with an error.
    ServerStartupError(String),
}

impl fmt::Display for ServerError {
    /// Renders a human-readable `"<category>: <detail>"` message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ToolNotFound(tool) => write!(f, "Tool not found: {}", tool),
            Self::ToolExecutionError(detail) => write!(f, "Tool execution error: {}", detail),
            Self::InvalidArguments(detail) => write!(f, "Invalid arguments: {}", detail),
            Self::ServerStartupError(detail) => write!(f, "Server startup error: {}", detail),
        }
    }
}

impl Error for ServerError {}

impl From<Box<dyn Error>> for ServerError {
    /// Wraps any boxed error as [`ServerError::ToolExecutionError`], keeping only its message.
    fn from(source: Box<dyn Error>) -> Self {
        let message = source.to_string();
        Self::ToolExecutionError(message)
    }
}
50
/// Network and middleware settings for an MCP HTTP server.
///
/// Defaults: `127.0.0.1:8080`, with CORS and tracing enabled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpServerConfig {
    /// Host name or IP literal to bind to. IPv6 literals may be given with or without brackets.
    pub host: String,
    /// TCP port to bind to.
    pub port: u16,
    /// When `true`, the servers install a permissive CORS layer (any origin, method and header).
    pub cors_enabled: bool,
    /// When `true`, the servers may install a global `tracing` subscriber and add HTTP request tracing.
    pub tracing_enabled: bool,
}

impl Default for McpServerConfig {
    fn default() -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port: 8080,
            cors_enabled: true,
            tracing_enabled: true,
        }
    }
}

impl McpServerConfig {
    /// Creates a configuration for `host:port`. CORS and tracing keep their defaults (enabled).
    pub fn new(host: impl Into<String>, port: u16) -> Self {
        Self {
            host: host.into(),
            port,
            ..Default::default()
        }
    }

    /// Enables or disables the CORS layer.
    pub fn with_cors(mut self, enabled: bool) -> Self {
        self.cors_enabled = enabled;
        self
    }

    /// Enables or disables subscriber installation and HTTP request tracing.
    pub fn with_tracing(mut self, enabled: bool) -> Self {
        self.tracing_enabled = enabled;
        self
    }

    /// Returns the `host:port` string used to bind the listener.
    ///
    /// A host containing `:` that is not already bracketed is treated as an IPv6 literal and
    /// wrapped in brackets (`::1` → `[::1]:8080`). Without the brackets the port cannot be told
    /// apart from the address, and the string is neither a valid `SocketAddr` nor a valid URL
    /// authority. Host names and IPv4 addresses cannot contain `:`, so they are unaffected.
    pub fn address(&self) -> String {
        if self.host.contains(':') && !self.host.starts_with('[') {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
102
103#[derive(Debug, Deserialize)]
105pub struct ToolCallRequest {
106 pub name: String,
107 #[serde(default)]
108 pub arguments: serde_json::Value,
109}
110
111#[derive(Debug, Serialize)]
113pub struct ToolCallResponse {
114 pub success: bool,
115 #[serde(skip_serializing_if = "Option::is_none")]
116 pub result: Option<serde_json::Value>,
117 #[serde(skip_serializing_if = "Option::is_none")]
118 pub error: Option<String>,
119}
120
121impl ToolCallResponse {
122 pub fn success(result: serde_json::Value) -> Self {
123 Self {
124 success: true,
125 result: Some(result),
126 error: None,
127 }
128 }
129
130 pub fn error(message: impl Into<String>) -> Self {
131 Self {
132 success: false,
133 result: None,
134 error: Some(message.into()),
135 }
136 }
137}
138
139struct ToolRegistry {
141 tools: Vec<mcp::McpTool>,
142}
143
144impl ToolRegistry {
145 fn new(tools: Vec<mcp::McpTool>) -> Self {
146 Self { tools }
147 }
148
149 fn find(&self, name: &str) -> Option<&mcp::McpTool> {
150 self.tools.iter().find(|t| t.name == name)
151 }
152}
153
154struct AppState {
156 registry: ToolRegistry,
157}
158
159pub struct McpServerBuilder<T> {
191 config: McpServerConfig,
192 tool_provider: T,
193}
194
195impl<T> McpServerBuilder<T>
196where
197 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
198{
199 pub fn with_tool(tool: T) -> Self {
201 Self {
202 config: McpServerConfig::default(),
203 tool_provider: tool,
204 }
205 }
206
207 pub fn with_config(config: McpServerConfig, tool: T) -> Self {
209 Self {
210 config,
211 tool_provider: tool,
212 }
213 }
214
215 pub fn with_port(mut self, port: u16) -> Self {
217 self.config.port = port;
218 self
219 }
220
221 pub fn with_host(mut self, host: impl Into<String>) -> Self {
223 self.config.host = host.into();
224 self
225 }
226
227 pub fn with_cors(mut self, enabled: bool) -> Self {
229 self.config.cors_enabled = enabled;
230 self
231 }
232
233 pub fn with_tracing(mut self, enabled: bool) -> Self {
235 self.config.tracing_enabled = enabled;
236 self
237 }
238
239 pub fn build(self) -> McpServerWithProvider<T>
241 where
242 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
243 {
244 let tools = get_tools_from_provider(&self.tool_provider);
247 McpServerWithProvider {
248 config: self.config,
249 tool_provider: Arc::new(self.tool_provider),
250 tools,
251 }
252 }
253}
254
255fn get_tools_from_provider<T>(provider: &T) -> Vec<mcp::McpTool>
258where
259 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
260{
261 let static_tools = T::tool_definitions();
263 if !static_tools.is_empty() {
264 return mcp::to_mcp_tools(static_tools);
265 }
266
267 get_tools_from_provider_runtime(provider)
271}
272
273fn get_tools_from_provider_runtime<T>(provider: &T) -> Vec<mcp::McpTool>
286where
287 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
288{
289 use std::any::Any;
290 if let Some(multi) = (provider as &dyn Any).downcast_ref::<MultiToolProvider>() {
291 return multi.tool_definitions().to_vec();
292 }
293
294 Vec::new()
295}
296
297pub struct McpServerWithProvider<T> {
299 config: McpServerConfig,
300 tool_provider: Arc<T>,
301 tools: Vec<mcp::McpTool>,
302}
303
304impl<T> McpServerWithProvider<T>
305where
306 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
307{
308 pub fn new(config: McpServerConfig, tool_provider: T) -> Self {
310 let tools = get_tools_from_provider(&tool_provider);
311 Self {
312 config,
313 tool_provider: Arc::new(tool_provider),
314 tools,
315 }
316 }
317
318 pub async fn run(&self) -> Result<(), ServerError> {
320 self.run_with_address(&self.config.address()).await
321 }
322
323 pub async fn run_with_address(&self, addr: &str) -> Result<(), ServerError> {
335 if self.config.tracing_enabled && !tracing::dispatcher::has_been_set() {
337 tracing_subscriber::fmt()
338 .with_env_filter(
339 tracing_subscriber::EnvFilter::from_default_env()
340 .add_directive("tokitai_mcp_server=info".parse().unwrap()),
341 )
342 .init();
343 }
344
345 let state = Arc::new(AppStateWithProvider {
346 registry: ToolRegistry::new(self.tools.clone()),
347 tool_provider: self.tool_provider.clone(), });
349
350 let mut app = Router::new()
352 .route("/tools", get(list_tools_handler_with_provider))
353 .route("/call", post(call_tool_handler_with_provider))
354 .route("/health", get(health_handler))
355 .with_state(state);
356
357 if self.config.cors_enabled {
359 let cors = CorsLayer::new()
360 .allow_origin(Any)
361 .allow_methods(Any)
362 .allow_headers(Any);
363 app = app.layer(cors);
364 }
365
366 if self.config.tracing_enabled {
368 app = app.layer(TraceLayer::new_for_http());
369 }
370
371 info!("Starting MCP server on http://{}", addr);
372 info!("Endpoints:");
373 info!(" GET /tools - List available tools");
374 info!(" POST /call - Call a tool");
375 info!(" GET /health - Health check");
376
377 let listener = tokio::net::TcpListener::bind(addr)
378 .await
379 .map_err(|e| ServerError::ServerStartupError(e.to_string()))?;
380
381 axum::serve(listener, app)
382 .await
383 .map_err(|e| ServerError::ServerStartupError(e.to_string()))?;
384
385 Ok(())
386 }
387
388 pub fn config(&self) -> &McpServerConfig {
390 &self.config
391 }
392
393 pub fn tools(&self) -> &[mcp::McpTool] {
395 &self.tools
396 }
397
398 pub fn tool_provider(&self) -> &T {
400 &self.tool_provider
401 }
402}
403
404struct AppStateWithProvider<T> {
406 registry: ToolRegistry,
407 tool_provider: Arc<T>,
408}
409
410pub struct McpServer {
429 config: McpServerConfig,
430 tools: Vec<mcp::McpTool>,
431}
432
433impl Default for McpServer {
434 fn default() -> Self {
435 Self::new()
436 }
437}
438
439impl McpServer {
440 pub fn new() -> Self {
442 Self {
443 config: McpServerConfig::default(),
444 tools: Vec::new(),
445 }
446 }
447
448 pub fn with_config(config: McpServerConfig) -> Self {
450 Self {
451 config,
452 tools: Vec::new(),
453 }
454 }
455
456 pub fn from_tools(tools: Vec<mcp::McpTool>) -> Self {
458 Self {
459 config: McpServerConfig::default(),
460 tools,
461 }
462 }
463
464 pub async fn run(&self) -> Result<(), ServerError> {
466 self.run_with_address(&self.config.address()).await
467 }
468
469 pub async fn run_with_address(&self, addr: &str) -> Result<(), ServerError> {
481 if self.config.tracing_enabled && !tracing::dispatcher::has_been_set() {
483 tracing_subscriber::fmt()
484 .with_env_filter(
485 tracing_subscriber::EnvFilter::from_default_env()
486 .add_directive("tokitai_mcp_server=info".parse().unwrap()),
487 )
488 .init();
489 }
490
491 let state = Arc::new(AppState {
492 registry: ToolRegistry::new(self.tools.clone()),
493 });
494
495 let mut app = Router::new()
497 .route("/tools", get(list_tools_handler))
498 .route("/call", post(call_tool_handler))
499 .route("/health", get(health_handler))
500 .with_state(state);
501
502 if self.config.cors_enabled {
504 let cors = CorsLayer::new()
505 .allow_origin(Any)
506 .allow_methods(Any)
507 .allow_headers(Any);
508 app = app.layer(cors);
509 }
510
511 if self.config.tracing_enabled {
513 app = app.layer(TraceLayer::new_for_http());
514 }
515
516 info!("Starting MCP server on http://{}", addr);
517 info!("Endpoints:");
518 info!(" GET /tools - List available tools");
519 info!(" POST /call - Call a tool");
520 info!(" GET /health - Health check");
521
522 let listener = tokio::net::TcpListener::bind(addr)
523 .await
524 .map_err(|e| ServerError::ServerStartupError(e.to_string()))?;
525
526 axum::serve(listener, app)
527 .await
528 .map_err(|e| ServerError::ServerStartupError(e.to_string()))?;
529
530 Ok(())
531 }
532
533 pub fn config(&self) -> &McpServerConfig {
535 &self.config
536 }
537
538 pub fn tools(&self) -> &[mcp::McpTool] {
540 &self.tools
541 }
542}
543
544async fn list_tools_handler(State(state): State<Arc<AppState>>) -> Json<Vec<mcp::McpTool>> {
550 Json(state.registry.tools.clone())
551}
552
553async fn call_tool_handler(
559 State(_state): State<Arc<AppState>>,
560 Json(request): Json<ToolCallRequest>,
561) -> Result<Json<ToolCallResponse>, StatusCode> {
562 info!("Tool call request (read-only mode): name={}", request.name);
563 Err(StatusCode::NOT_IMPLEMENTED)
565}
566
567async fn list_tools_handler_with_provider<T>(
569 State(state): State<Arc<AppStateWithProvider<T>>>,
570) -> Json<Vec<mcp::McpTool>>
571where
572 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
573{
574 Json(state.registry.tools.clone())
575}
576
577async fn call_tool_handler_with_provider<T>(
579 State(state): State<Arc<AppStateWithProvider<T>>>,
580 Json(request): Json<ToolCallRequest>,
581) -> Result<Json<ToolCallResponse>, StatusCode>
582where
583 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
584{
585 info!(
586 "Tool call request: name={}, arguments={:?}",
587 request.name, request.arguments
588 );
589
590 let tool = state.registry.find(&request.name).ok_or_else(|| {
592 warn!("Tool not found: {}", request.name);
593 StatusCode::NOT_FOUND
594 })?;
595
596 info!("Found tool: {} - {}", tool.name, tool.description);
597
598 match state
600 .tool_provider
601 .call_tool(&request.name, &request.arguments)
602 {
603 Ok(result) => {
604 info!("Tool executed successfully: {}", request.name);
605 Ok(Json(ToolCallResponse::success(result)))
606 }
607 Err(e) => {
608 warn!("Tool execution failed: {} - {}", request.name, e);
609 Ok(Json(ToolCallResponse::error(format!("{}", e))))
610 }
611 }
612}
613
/// `GET /health`: liveness probe that always answers with the plain-text body `OK`.
async fn health_handler() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
618
619pub struct MultiToolProvider {
641 providers: Vec<Box<dyn ToolCallerDyn>>,
642 tool_defs: Vec<mcp::McpTool>,
643}
644
645pub trait ToolCallerDyn: Send + Sync {
680 fn call_tool(
691 &self,
692 name: &str,
693 args: &serde_json::Value,
694 ) -> Result<serde_json::Value, tokitai_core::ToolError>;
695}
696
697impl<T> ToolCallerDyn for T
698where
699 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
700{
701 fn call_tool(
702 &self,
703 name: &str,
704 args: &serde_json::Value,
705 ) -> Result<serde_json::Value, tokitai_core::ToolError> {
706 self.call_tool(name, args)
707 }
708}
709
710impl MultiToolProvider {
711 pub fn new() -> Self {
713 Self {
714 providers: Vec::new(),
715 tool_defs: Vec::new(),
716 }
717 }
718
719 pub fn add<T>(&mut self, tool: T)
729 where
730 T: tokitai_core::ToolProvider + tokitai_core::ToolCaller + Send + Sync + 'static,
731 {
732 for def in T::tool_definitions() {
734 let schema: serde_json::Value =
736 serde_json::from_str(&def.input_schema).unwrap_or_else(|_| serde_json::json!({}));
737
738 let mcp_tool = mcp::McpTool {
739 name: def.name.clone(),
740 description: def.description.clone(),
741 input_schema: schema,
742 };
743 self.tool_defs.push(mcp_tool);
744 }
745
746 self.providers.push(Box::new(tool));
748 }
749
750 pub fn tool_definitions(&self) -> &[mcp::McpTool] {
752 &self.tool_defs
753 }
754}
755
756impl Default for MultiToolProvider {
757 fn default() -> Self {
758 Self::new()
759 }
760}
761
762impl MultiToolProvider {
763 pub fn clone_definitions(&self) -> Self {
798 if !self.tool_defs.is_empty() {
799 tracing::debug!(
800 "Cloning MultiToolProvider definitions ({} tools). \
801 Note: The cloned instance has no tool implementations - \
802 only metadata (names, descriptions, schemas).",
803 self.tool_defs.len()
804 );
805 }
806 Self {
807 providers: Vec::new(),
808 tool_defs: self.tool_defs.clone(),
809 }
810 }
811}
812
813impl tokitai_core::ToolProvider for MultiToolProvider {
814 fn tool_definitions() -> &'static [tokitai_core::ToolDefinition] {
815 &[]
819 }
820}
821
822impl tokitai_core::ToolCaller for MultiToolProvider {
823 fn call_tool(
824 &self,
825 name: &str,
826 args: &serde_types::Value,
827 ) -> Result<serde_types::Value, tokitai_core::ToolError> {
828 for provider in &self.providers {
830 match provider.call_tool(name, args) {
833 Ok(result) => return Ok(result),
834 Err(e) => {
835 if matches!(e.kind, tokitai_core::ToolErrorKind::NotFound) {
837 continue;
838 }
839 return Err(e);
841 }
842 }
843 }
844
845 Err(tokitai_core::ToolError::not_found(format!(
847 "Tool '{}' not found in any provider",
848 name
849 )))
850 }
851}