turbomcp_server/routing/
mod.rs

1//! Request routing and handler dispatch system
2//!
3//! This module provides a comprehensive routing system for MCP protocol requests,
4//! supporting all standard MCP methods with enterprise features like RBAC,
5//! JSON Schema validation, timeout management, and bidirectional communication.
6
7mod bidirectional;
8mod config;
9mod handlers;
10mod traits;
11mod utils;
12mod validation;
13#[cfg(feature = "websocket")]
14mod websocket_dispatcher_adapter;
15
16// Re-export public types to maintain API compatibility
17pub use bidirectional::BidirectionalRouter;
18pub use config::RouterConfig;
19pub use traits::{Route, RouteHandler, RouteMetadata, ServerRequestDispatcher};
20#[cfg(feature = "websocket")]
21pub use websocket_dispatcher_adapter::WebSocketDispatcherAdapter;
22
23use dashmap::DashMap;
24use futures::stream::{self, StreamExt};
25use std::collections::HashMap;
26use std::sync::Arc;
27use tracing::warn;
28use turbomcp_protocol::RequestContext;
29
30#[cfg(feature = "multi-tenancy")]
31use crate::tenant_context::TenantContextExt;
32use turbomcp_protocol::{
33    jsonrpc::{JsonRpcRequest, JsonRpcResponse},
34    types::{
35        CreateMessageRequest, ElicitRequest, ElicitResult, ListRootsResult, PingRequest, PingResult,
36    },
37};
38
39use crate::capabilities::ServerToClientAdapter;
40use crate::metrics::ServerMetrics;
41use crate::registry::HandlerRegistry;
42use crate::{ServerError, ServerResult};
43
44use handlers::{HandlerContext, ProtocolHandlers};
45use turbomcp_protocol::context::capabilities::ServerToClientRequests;
46use utils::{error_response, method_not_found_response};
47use validation::{validate_request, validate_response};
48
49/// Request router for dispatching MCP requests to appropriate handlers
50pub struct RequestRouter {
51    /// Handler registry
52    registry: Arc<HandlerRegistry>,
53    /// Route configuration
54    config: RouterConfig,
55    /// Server configuration (for protocol responses)
56    server_config: crate::config::ServerConfig,
57    /// Custom route handlers
58    custom_routes: HashMap<String, Arc<dyn RouteHandler>>,
59    /// Resource subscription counters by URI (reserved for future functionality)
60    #[allow(dead_code)]
61    resource_subscriptions: DashMap<String, usize>,
62    /// Bidirectional communication router
63    bidirectional: BidirectionalRouter,
64    /// Protocol handlers
65    handlers: ProtocolHandlers,
66    /// Server-to-client requests adapter for tool-initiated requests (sampling, elicitation, roots)
67    /// This is injected into RequestContext so tools can make server-initiated requests
68    server_to_client: Arc<dyn ServerToClientRequests>,
69    /// Task storage for MCP Tasks API (SEP-1686)
70    #[cfg(feature = "mcp-tasks")]
71    task_storage: Option<Arc<crate::task_storage::TaskStorage>>,
72}
73
74impl std::fmt::Debug for RequestRouter {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("RequestRouter")
77            .field("config", &self.config)
78            .field("custom_routes_count", &self.custom_routes.len())
79            .finish()
80    }
81}
82
83impl RequestRouter {
84    /// Create a new request router
85    #[must_use]
86    pub fn new(
87        registry: Arc<HandlerRegistry>,
88        _metrics: Arc<ServerMetrics>,
89        server_config: crate::config::ServerConfig,
90        #[cfg(feature = "mcp-tasks")] task_storage: Option<Arc<crate::task_storage::TaskStorage>>,
91    ) -> Self {
92        // Timeout management is now handled by middleware
93        let config = RouterConfig::default();
94
95        let handler_context = HandlerContext::new(
96            Arc::clone(&registry),
97            server_config.clone(),
98            #[cfg(feature = "mcp-tasks")]
99            task_storage.clone().unwrap_or_else(|| {
100                // Fallback: create empty storage if none provided
101                use tokio::time::Duration;
102                Arc::new(crate::task_storage::TaskStorage::new(Duration::from_secs(
103                    60,
104                )))
105            }),
106        );
107
108        let bidirectional = BidirectionalRouter::new();
109
110        // Create the server-to-client adapter that bridges bidirectional router
111        // to the ServerToClientRequests trait (type-safe, zero-cost abstraction)
112        let server_to_client: Arc<dyn ServerToClientRequests> =
113            Arc::new(ServerToClientAdapter::new(bidirectional.clone()));
114
115        Self {
116            registry,
117            config,
118            server_config,
119            custom_routes: HashMap::new(),
120            resource_subscriptions: DashMap::new(),
121            bidirectional,
122            handlers: ProtocolHandlers::new(handler_context),
123            server_to_client,
124            #[cfg(feature = "mcp-tasks")]
125            task_storage,
126        }
127    }
128
129    /// Create a router with configuration
130    #[must_use]
131    pub fn with_config(
132        registry: Arc<HandlerRegistry>,
133        config: RouterConfig,
134        _metrics: Arc<ServerMetrics>,
135        server_config: crate::config::ServerConfig,
136        #[cfg(feature = "mcp-tasks")] task_storage: Option<Arc<crate::task_storage::TaskStorage>>,
137    ) -> Self {
138        // Timeout management is now handled by middleware
139
140        let handler_context = HandlerContext::new(
141            Arc::clone(&registry),
142            server_config.clone(),
143            #[cfg(feature = "mcp-tasks")]
144            task_storage.clone().unwrap_or_else(|| {
145                // Fallback: create empty storage if none provided
146                use tokio::time::Duration;
147                Arc::new(crate::task_storage::TaskStorage::new(Duration::from_secs(
148                    60,
149                )))
150            }),
151        );
152
153        let bidirectional = BidirectionalRouter::new();
154
155        // Create the server-to-client adapter that bridges bidirectional router
156        // to the ServerToClientRequests trait (type-safe, zero-cost abstraction)
157        let server_to_client: Arc<dyn ServerToClientRequests> =
158            Arc::new(ServerToClientAdapter::new(bidirectional.clone()));
159
160        Self {
161            registry,
162            config,
163            server_config,
164            custom_routes: HashMap::new(),
165            resource_subscriptions: DashMap::new(),
166            bidirectional,
167            handlers: ProtocolHandlers::new(handler_context),
168            server_to_client,
169            #[cfg(feature = "mcp-tasks")]
170            task_storage,
171        }
172    }
173
174    // Timeout configuration now handled by middleware - no longer needed
175
176    /// Set the server request dispatcher for bidirectional communication
177    ///
178    /// CRITICAL: This also refreshes the server_to_client adapter so it sees the new dispatcher.
179    /// Without this refresh, the adapter would still point to the old (empty) bidirectional router.
180    pub fn set_server_request_dispatcher<D>(&mut self, dispatcher: D)
181    where
182        D: ServerRequestDispatcher + 'static,
183    {
184        self.bidirectional.set_dispatcher(dispatcher);
185
186        // CRITICAL FIX: Recreate the adapter so it sees the new dispatcher
187        // The adapter was created with a clone of bidirectional BEFORE the dispatcher was set.
188        // Since BidirectionalRouter::set_dispatcher() replaces the Option rather than mutating
189        // through it, the adapter's clone still has dispatcher: None.
190        // By recreating it here, we ensure it gets a fresh clone that includes the dispatcher.
191        self.server_to_client = Arc::new(ServerToClientAdapter::new(self.bidirectional.clone()));
192    }
193
194    /// Get the server request dispatcher
195    pub fn get_server_request_dispatcher(&self) -> Option<&Arc<dyn ServerRequestDispatcher>> {
196        self.bidirectional.get_dispatcher()
197    }
198
199    /// Check if bidirectional routing is enabled and supported
200    pub fn supports_bidirectional(&self) -> bool {
201        self.config.enable_bidirectional && self.bidirectional.supports_bidirectional()
202    }
203
204    /// Add a custom route handler
205    ///
206    /// # Errors
207    ///
208    /// Returns [`ServerError::Routing`] if a route for the same method already exists.
209    pub fn add_route<H>(&mut self, handler: H) -> ServerResult<()>
210    where
211        H: RouteHandler + 'static,
212    {
213        let metadata = handler.metadata();
214        let handler_arc: Arc<dyn RouteHandler> = Arc::new(handler);
215
216        for method in &metadata.methods {
217            if self.custom_routes.contains_key(method) {
218                return Err(ServerError::routing_with_method(
219                    format!("Route for method '{method}' already exists"),
220                    method.clone(),
221                ));
222            }
223            self.custom_routes
224                .insert(method.clone(), Arc::clone(&handler_arc));
225        }
226
227        Ok(())
228    }
229
230    /// Create a properly configured RequestContext for this router
231    ///
232    /// This factory method creates a RequestContext with all necessary capabilities
233    /// pre-configured, including server-to-client communication for bidirectional
234    /// features (sampling, elicitation, roots).
235    ///
236    /// **Design Pattern**: Explicit Factory
237    /// - Context is valid from creation (no broken intermediate state)
238    /// - Router provides factory but doesn't modify contexts
239    /// - Follows Single Responsibility Principle
240    ///
241    /// **HTTP Header Propagation**: Pass headers from HTTP/WebSocket transports
242    /// to include them in context metadata as "http_headers".
243    ///
244    /// # Arguments
245    ///
246    /// * `headers` - Optional HTTP headers from the transport layer
247    /// * `transport` - Optional transport type ("http", "websocket", etc.). Defaults to "http" if headers are provided.
248    /// * `tenant_id` - Optional tenant identifier for multi-tenant deployments. Extracted by `TenantExtractionLayer` middleware.
249    ///
250    /// # Example
251    /// ```rust,ignore
252    /// // Single-tenant HTTP (tenant_id = None)
253    /// let ctx = router.create_context(Some(headers), None, None);
254    ///
255    /// // Multi-tenant WebSocket (with tenant ID from middleware)
256    /// let ctx = router.create_context(Some(headers), Some("websocket"), Some("acme-corp".to_string()));
257    /// let response = router.route(request, ctx).await;
258    /// ```
259    #[must_use]
260    pub fn create_context(
261        &self,
262        headers: Option<HashMap<String, String>>,
263        transport: Option<&str>,
264        tenant_id: Option<String>,
265    ) -> RequestContext {
266        let mut ctx =
267            RequestContext::new().with_server_to_client(Arc::clone(&self.server_to_client));
268
269        // Set tenant ID if provided (multi-tenant deployments)
270        #[cfg(feature = "multi-tenancy")]
271        if let Some(tenant_id) = tenant_id {
272            ctx = ctx.with_tenant(tenant_id);
273        }
274        #[cfg(not(feature = "multi-tenancy"))]
275        let _ = tenant_id; // Suppress unused warning when feature disabled
276
277        // Add HTTP headers to context if provided
278        if let Some(headers) = headers
279            && let Ok(headers_json) = serde_json::to_value(&headers)
280        {
281            ctx = ctx.with_metadata("http_headers", headers_json);
282            // Set transport type (default to "http" if not specified)
283            let transport_type = transport.unwrap_or("http");
284            ctx = ctx.with_metadata("transport", transport_type);
285        }
286
287        ctx
288    }
289
290    /// Route a JSON-RPC request to the appropriate handler
291    ///
292    /// **IMPORTANT**: The context should be created using `create_context()` to ensure
293    /// it has all necessary capabilities configured. This method does NOT modify the
294    /// context - it only routes the request.
295    ///
296    /// # Design Pattern
297    /// This follows the Single Responsibility Principle:
298    /// - `create_context()`: Creates properly configured contexts
299    /// - `route()`: Routes requests to handlers
300    ///
301    /// Previously, `route()` was modifying the context (adding server_to_client),
302    /// which violated SRP and created invalid intermediate states.
303    pub async fn route(&self, request: JsonRpcRequest, ctx: RequestContext) -> JsonRpcResponse {
304        // Validate request if enabled
305        if self.config.validate_requests
306            && let Err(e) = validate_request(&request)
307        {
308            return error_response(&request, e);
309        }
310
311        // Handle the request
312        let result = match request.method.as_str() {
313            // Core protocol methods
314            "initialize" => self.handlers.handle_initialize(request, ctx).await,
315
316            // Tool methods
317            "tools/list" => self.handlers.handle_list_tools(request, ctx).await,
318            "tools/call" => self.handlers.handle_call_tool(request, ctx).await,
319
320            // Prompt methods
321            "prompts/list" => self.handlers.handle_list_prompts(request, ctx).await,
322            "prompts/get" => self.handlers.handle_get_prompt(request, ctx).await,
323
324            // Resource methods
325            "resources/list" => self.handlers.handle_list_resources(request, ctx).await,
326            "resources/read" => self.handlers.handle_read_resource(request, ctx).await,
327            "resources/subscribe" => self.handlers.handle_subscribe_resource(request, ctx).await,
328            "resources/unsubscribe" => {
329                self.handlers
330                    .handle_unsubscribe_resource(request, ctx)
331                    .await
332            }
333
334            // Logging methods
335            "logging/setLevel" => self.handlers.handle_set_log_level(request, ctx).await,
336
337            // Sampling methods
338            "sampling/createMessage" => self.handlers.handle_create_message(request, ctx).await,
339
340            // Roots methods
341            "roots/list" => self.handlers.handle_list_roots(request, ctx).await,
342
343            // Enhanced MCP features (MCP 2025-11-25 protocol methods)
344            "elicitation/create" => self.handlers.handle_elicitation(request, ctx).await,
345            "completion/complete" => self.handlers.handle_completion(request, ctx).await,
346            "resources/templates/list" => {
347                self.handlers
348                    .handle_list_resource_templates(request, ctx)
349                    .await
350            }
351            "ping" => self.handlers.handle_ping(request, ctx).await,
352
353            // Tasks API (MCP 2025-11-25 draft - SEP-1686)
354            #[cfg(feature = "mcp-tasks")]
355            "tasks/get" => self.handlers.handle_get_task(request, ctx).await,
356            #[cfg(feature = "mcp-tasks")]
357            "tasks/result" => self.handlers.handle_task_result(request, ctx).await,
358            #[cfg(feature = "mcp-tasks")]
359            "tasks/list" => self.handlers.handle_list_tasks(request, ctx).await,
360            #[cfg(feature = "mcp-tasks")]
361            "tasks/cancel" => self.handlers.handle_cancel_task(request, ctx).await,
362
363            // Custom routes
364            method => {
365                if let Some(handler) = self.custom_routes.get(method) {
366                    let request_clone = request.clone();
367                    handler
368                        .handle(request, ctx)
369                        .await
370                        .unwrap_or_else(|e| error_response(&request_clone, e))
371                } else {
372                    method_not_found_response(&request)
373                }
374            }
375        };
376
377        // Validate response if enabled
378        if self.config.validate_responses
379            && let Err(e) = validate_response(&result)
380        {
381            warn!("Response validation failed: {}", e);
382        }
383
384        result
385    }
386
387    /// Handle batch requests
388    pub async fn route_batch(
389        &self,
390        requests: Vec<JsonRpcRequest>,
391        ctx: RequestContext,
392    ) -> Vec<JsonRpcResponse> {
393        // Note: Server capabilities are injected in route() for each request
394        let max_in_flight = self.config.max_concurrent_requests.max(1);
395        stream::iter(requests.into_iter())
396            .map(|req| {
397                let ctx_cloned = ctx.clone();
398                async move { self.route(req, ctx_cloned).await }
399            })
400            .buffer_unordered(max_in_flight)
401            .collect()
402            .await
403    }
404
405    /// Send an elicitation request to the client (server-initiated)
406    ///
407    /// # Errors
408    ///
409    /// Returns [`ServerError::Transport`] if:
410    /// - The bidirectional dispatcher is not configured
411    /// - The client request fails
412    /// - The client does not respond
413    pub async fn send_elicitation_to_client(
414        &self,
415        request: ElicitRequest,
416        ctx: RequestContext,
417    ) -> ServerResult<ElicitResult> {
418        self.bidirectional
419            .send_elicitation_to_client(request, ctx)
420            .await
421    }
422
423    /// Send a ping request to the client (server-initiated)
424    ///
425    /// # Errors
426    ///
427    /// Returns [`ServerError::Transport`] if:
428    /// - The bidirectional dispatcher is not configured
429    /// - The client request fails
430    /// - The client does not respond
431    pub async fn send_ping_to_client(
432        &self,
433        request: PingRequest,
434        ctx: RequestContext,
435    ) -> ServerResult<PingResult> {
436        self.bidirectional.send_ping_to_client(request, ctx).await
437    }
438
439    /// Send a create message request to the client (server-initiated)
440    ///
441    /// # Errors
442    ///
443    /// Returns [`ServerError::Transport`] if:
444    /// - The bidirectional dispatcher is not configured
445    /// - The client request fails
446    /// - The client does not support sampling
447    pub async fn send_create_message_to_client(
448        &self,
449        request: CreateMessageRequest,
450        ctx: RequestContext,
451    ) -> ServerResult<turbomcp_protocol::types::CreateMessageResult> {
452        self.bidirectional
453            .send_create_message_to_client(request, ctx)
454            .await
455    }
456
457    /// Send a list roots request to the client (server-initiated)
458    ///
459    /// # Errors
460    ///
461    /// Returns [`ServerError::Transport`] if:
462    /// - The bidirectional dispatcher is not configured
463    /// - The client request fails
464    /// - The client does not support roots
465    pub async fn send_list_roots_to_client(
466        &self,
467        request: turbomcp_protocol::types::ListRootsRequest,
468        ctx: RequestContext,
469    ) -> ServerResult<ListRootsResult> {
470        self.bidirectional
471            .send_list_roots_to_client(request, ctx)
472            .await
473    }
474
475    /// Get task storage (exposed for testing)
476    #[cfg(feature = "mcp-tasks")]
477    #[doc(hidden)]
478    pub fn get_task_storage(&self) -> Option<Arc<crate::task_storage::TaskStorage>> {
479        self.task_storage.clone()
480    }
481}
482
483impl Clone for RequestRouter {
484    fn clone(&self) -> Self {
485        Self {
486            registry: Arc::clone(&self.registry),
487            config: self.config.clone(),
488            server_config: self.server_config.clone(),
489            custom_routes: self.custom_routes.clone(),
490            resource_subscriptions: DashMap::new(),
491            bidirectional: self.bidirectional.clone(),
492            handlers: ProtocolHandlers::new(HandlerContext::new(
493                Arc::clone(&self.registry),
494                self.server_config.clone(),
495                #[cfg(feature = "mcp-tasks")]
496                self.task_storage.clone().unwrap_or_else(|| {
497                    // Fallback: create empty storage if none provided
498                    use tokio::time::Duration;
499                    Arc::new(crate::task_storage::TaskStorage::new(Duration::from_secs(
500                        60,
501                    )))
502                }),
503            )),
504            server_to_client: Arc::clone(&self.server_to_client),
505            #[cfg(feature = "mcp-tasks")]
506            task_storage: self.task_storage.clone(),
507        }
508    }
509}
510
// Design Note: ServerCapabilities trait implementation
//
// RequestRouter uses BidirectionalRouter for server-initiated requests
// (sampling, elicitation, roots) instead of directly implementing the ServerCapabilities
// trait from turbomcp_protocol::context::capabilities.
//
// Current Pattern:
// - RequestRouter contains a BidirectionalRouter, which handles server-to-client requests
// - BidirectionalRouter uses the ServerRequestDispatcher trait for transport-agnostic dispatch
// - Tools reach it through a ServerToClientAdapter (an implementation of
//   ServerToClientRequests) that create_context() injects into every RequestContext
// - This pattern provides better separation of concerns and testability
//
// Alternative (not implemented):
// - RequestRouter could implement the ServerCapabilities trait directly
// - This would allow passing the router as &dyn ServerCapabilities to tools
// - The current pattern is preferred because it keeps routing and bidirectional concerns separate
//
// See: crates/turbomcp-server/src/routing/bidirectional.rs for the current implementation
528
529/// Router alias for convenience
530pub type Router = RequestRouter;
531
532// ===================================================================
533// JsonRpcHandler Implementation - For HTTP Transport Integration
534// ===================================================================
535
536#[async_trait::async_trait]
537impl turbomcp_protocol::JsonRpcHandler for RequestRouter {
538    /// Handle a JSON-RPC request via the HTTP transport
539    ///
540    /// This implementation enables `RequestRouter` to be used directly with
541    /// the HTTP transport layer (`run_server`), supporting the builder pattern
542    /// for programmatic server construction.
543    ///
544    /// # Architecture
545    ///
546    /// - Parses raw JSON into `JsonRpcRequest`
547    /// - Creates default `RequestContext` (no auth/session for HTTP)
548    /// - Routes through the existing `route()` method
549    /// - Serializes `JsonRpcResponse` back to JSON
550    ///
551    /// This provides the same request handling as the macro pattern but
552    /// allows runtime handler registration via `ServerBuilder`.
553    async fn handle_request(&self, req_value: serde_json::Value) -> serde_json::Value {
554        // Parse the request
555        let req: JsonRpcRequest = match serde_json::from_value(req_value) {
556            Ok(r) => r,
557            Err(e) => {
558                return serde_json::json!({
559                    "jsonrpc": "2.0",
560                    "error": {
561                        "code": -32700,
562                        "message": format!("Parse error: {}", e)
563                    },
564                    "id": null
565                });
566            }
567        };
568
569        // Create properly configured context with server-to-client capabilities
570        // Note: For authenticated HTTP requests, middleware should add auth info via with_* methods
571        // For HTTP requests with headers, use the HTTP-specific entry point that passes headers
572        let ctx = self.create_context(None, None, None);
573
574        // Route the request through the standard routing system
575        let response = self.route(req, ctx).await;
576
577        // Serialize response
578        match serde_json::to_value(&response) {
579            Ok(v) => v,
580            Err(e) => {
581                serde_json::json!({
582                    "jsonrpc": "2.0",
583                    "error": {
584                        "code": -32603,
585                        "message": format!("Internal error: failed to serialize response: {}", e)
586                    },
587                    "id": response.id
588                })
589            }
590        }
591    }
592}
593
594// Comprehensive tests in separate file (tokio/axum pattern)
595#[cfg(test)]
596mod tests;