turbomcp_server/routing/
mod.rs

1//! Request routing and handler dispatch system
2//!
3//! This module provides a comprehensive routing system for MCP protocol requests,
4//! supporting all standard MCP methods with enterprise features like RBAC,
5//! JSON Schema validation, timeout management, and bidirectional communication.
6
7mod bidirectional;
8mod config;
9mod handlers;
10mod traits;
11mod utils;
12mod validation;
13#[cfg(feature = "websocket")]
14mod websocket_dispatcher_adapter;
15
16// Re-export public types to maintain API compatibility
17pub use bidirectional::BidirectionalRouter;
18pub use config::RouterConfig;
19pub use traits::{Route, RouteHandler, RouteMetadata, ServerRequestDispatcher};
20#[cfg(feature = "websocket")]
21pub use websocket_dispatcher_adapter::WebSocketDispatcherAdapter;
22
23use dashmap::DashMap;
24use futures::stream::{self, StreamExt};
25use std::collections::HashMap;
26use std::sync::Arc;
27use tracing::warn;
28use turbomcp_protocol::RequestContext;
29use turbomcp_protocol::{
30    jsonrpc::{JsonRpcRequest, JsonRpcResponse},
31    types::{
32        CreateMessageRequest, ElicitRequest, ElicitResult, ListRootsResult, PingRequest, PingResult,
33    },
34};
35
36use crate::capabilities::ServerToClientAdapter;
37use crate::metrics::ServerMetrics;
38use crate::registry::HandlerRegistry;
39use crate::{ServerError, ServerResult};
40
41use handlers::{HandlerContext, ProtocolHandlers};
42use turbomcp_protocol::context::capabilities::ServerToClientRequests;
43use utils::{error_response, method_not_found_response};
44use validation::{validate_request, validate_response};
45
46/// Request router for dispatching MCP requests to appropriate handlers
47pub struct RequestRouter {
48    /// Handler registry
49    registry: Arc<HandlerRegistry>,
50    /// Route configuration
51    config: RouterConfig,
52    /// Server configuration (for protocol responses)
53    server_config: crate::config::ServerConfig,
54    /// Custom route handlers
55    custom_routes: HashMap<String, Arc<dyn RouteHandler>>,
56    /// Resource subscription counters by URI (reserved for future functionality)
57    #[allow(dead_code)]
58    resource_subscriptions: DashMap<String, usize>,
59    /// Bidirectional communication router
60    bidirectional: BidirectionalRouter,
61    /// Protocol handlers
62    handlers: ProtocolHandlers,
63    /// Server-to-client requests adapter for tool-initiated requests (sampling, elicitation, roots)
64    /// This is injected into RequestContext so tools can make server-initiated requests
65    server_to_client: Arc<dyn ServerToClientRequests>,
66}
67
68impl std::fmt::Debug for RequestRouter {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("RequestRouter")
71            .field("config", &self.config)
72            .field("custom_routes_count", &self.custom_routes.len())
73            .finish()
74    }
75}
76
77impl RequestRouter {
78    /// Create a new request router
79    #[must_use]
80    pub fn new(
81        registry: Arc<HandlerRegistry>,
82        _metrics: Arc<ServerMetrics>,
83        server_config: crate::config::ServerConfig,
84    ) -> Self {
85        // Timeout management is now handled by middleware
86        let config = RouterConfig::default();
87
88        let handler_context = HandlerContext::new(Arc::clone(&registry), server_config.clone());
89
90        let bidirectional = BidirectionalRouter::new();
91
92        // Create the server-to-client adapter that bridges bidirectional router
93        // to the ServerToClientRequests trait (type-safe, zero-cost abstraction)
94        let server_to_client: Arc<dyn ServerToClientRequests> =
95            Arc::new(ServerToClientAdapter::new(bidirectional.clone()));
96
97        Self {
98            registry,
99            config,
100            server_config,
101            custom_routes: HashMap::new(),
102            resource_subscriptions: DashMap::new(),
103            bidirectional,
104            handlers: ProtocolHandlers::new(handler_context),
105            server_to_client,
106        }
107    }
108
109    /// Create a router with configuration
110    #[must_use]
111    pub fn with_config(
112        registry: Arc<HandlerRegistry>,
113        config: RouterConfig,
114        _metrics: Arc<ServerMetrics>,
115        server_config: crate::config::ServerConfig,
116    ) -> Self {
117        // Timeout management is now handled by middleware
118
119        let handler_context = HandlerContext::new(Arc::clone(&registry), server_config.clone());
120
121        let bidirectional = BidirectionalRouter::new();
122
123        // Create the server-to-client adapter that bridges bidirectional router
124        // to the ServerToClientRequests trait (type-safe, zero-cost abstraction)
125        let server_to_client: Arc<dyn ServerToClientRequests> =
126            Arc::new(ServerToClientAdapter::new(bidirectional.clone()));
127
128        Self {
129            registry,
130            config,
131            server_config,
132            custom_routes: HashMap::new(),
133            resource_subscriptions: DashMap::new(),
134            bidirectional,
135            handlers: ProtocolHandlers::new(handler_context),
136            server_to_client,
137        }
138    }
139
140    // Timeout configuration now handled by middleware - no longer needed
141
142    /// Set the server request dispatcher for bidirectional communication
143    ///
144    /// CRITICAL: This also refreshes the server_to_client adapter so it sees the new dispatcher.
145    /// Without this refresh, the adapter would still point to the old (empty) bidirectional router.
146    pub fn set_server_request_dispatcher<D>(&mut self, dispatcher: D)
147    where
148        D: ServerRequestDispatcher + 'static,
149    {
150        self.bidirectional.set_dispatcher(dispatcher);
151
152        // CRITICAL FIX: Recreate the adapter so it sees the new dispatcher
153        // The adapter was created with a clone of bidirectional BEFORE the dispatcher was set.
154        // Since BidirectionalRouter::set_dispatcher() replaces the Option rather than mutating
155        // through it, the adapter's clone still has dispatcher: None.
156        // By recreating it here, we ensure it gets a fresh clone that includes the dispatcher.
157        self.server_to_client = Arc::new(ServerToClientAdapter::new(self.bidirectional.clone()));
158    }
159
160    /// Get the server request dispatcher
161    pub fn get_server_request_dispatcher(&self) -> Option<&Arc<dyn ServerRequestDispatcher>> {
162        self.bidirectional.get_dispatcher()
163    }
164
165    /// Check if bidirectional routing is enabled and supported
166    pub fn supports_bidirectional(&self) -> bool {
167        self.config.enable_bidirectional && self.bidirectional.supports_bidirectional()
168    }
169
170    /// Add a custom route handler
171    ///
172    /// # Errors
173    ///
174    /// Returns [`ServerError::Routing`] if a route for the same method already exists.
175    pub fn add_route<H>(&mut self, handler: H) -> ServerResult<()>
176    where
177        H: RouteHandler + 'static,
178    {
179        let metadata = handler.metadata();
180        let handler_arc: Arc<dyn RouteHandler> = Arc::new(handler);
181
182        for method in &metadata.methods {
183            if self.custom_routes.contains_key(method) {
184                return Err(ServerError::routing_with_method(
185                    format!("Route for method '{method}' already exists"),
186                    method.clone(),
187                ));
188            }
189            self.custom_routes
190                .insert(method.clone(), Arc::clone(&handler_arc));
191        }
192
193        Ok(())
194    }
195
196    /// Create a properly configured RequestContext for this router
197    ///
198    /// This factory method creates a RequestContext with all necessary capabilities
199    /// pre-configured, including server-to-client communication for bidirectional
200    /// features (sampling, elicitation, roots).
201    ///
202    /// **Design Pattern**: Explicit Factory
203    /// - Context is valid from creation (no broken intermediate state)
204    /// - Router provides factory but doesn't modify contexts
205    /// - Follows Single Responsibility Principle
206    ///
207    /// **HTTP Header Propagation**: Pass headers from HTTP/WebSocket transports
208    /// to include them in context metadata as "http_headers".
209    ///
210    /// # Arguments
211    ///
212    /// * `headers` - Optional HTTP headers from the transport layer
213    /// * `transport` - Optional transport type ("http", "websocket", etc.). Defaults to "http" if headers are provided.
214    ///
215    /// # Example
216    /// ```rust,ignore
217    /// // HTTP transport
218    /// let ctx = router.create_context(Some(headers), None);
219    ///
220    /// // WebSocket transport
221    /// let ctx = router.create_context(Some(headers), Some("websocket"));
222    /// let response = router.route(request, ctx).await;
223    /// ```
224    #[must_use]
225    pub fn create_context(
226        &self,
227        headers: Option<HashMap<String, String>>,
228        transport: Option<&str>,
229    ) -> RequestContext {
230        let mut ctx =
231            RequestContext::new().with_server_to_client(Arc::clone(&self.server_to_client));
232
233        // Add HTTP headers to context if provided
234        if let Some(headers) = headers
235            && let Ok(headers_json) = serde_json::to_value(&headers)
236        {
237            ctx = ctx.with_metadata("http_headers", headers_json);
238            // Set transport type (default to "http" if not specified)
239            let transport_type = transport.unwrap_or("http");
240            ctx = ctx.with_metadata("transport", transport_type);
241        }
242
243        ctx
244    }
245
246    /// Route a JSON-RPC request to the appropriate handler
247    ///
248    /// **IMPORTANT**: The context should be created using `create_context()` to ensure
249    /// it has all necessary capabilities configured. This method does NOT modify the
250    /// context - it only routes the request.
251    ///
252    /// # Design Pattern
253    /// This follows the Single Responsibility Principle:
254    /// - `create_context()`: Creates properly configured contexts
255    /// - `route()`: Routes requests to handlers
256    ///
257    /// Previously, `route()` was modifying the context (adding server_to_client),
258    /// which violated SRP and created invalid intermediate states.
259    pub async fn route(&self, request: JsonRpcRequest, ctx: RequestContext) -> JsonRpcResponse {
260        // Validate request if enabled
261        if self.config.validate_requests
262            && let Err(e) = validate_request(&request)
263        {
264            return error_response(&request, e);
265        }
266
267        // Handle the request
268        let result = match request.method.as_str() {
269            // Core protocol methods
270            "initialize" => self.handlers.handle_initialize(request, ctx).await,
271
272            // Tool methods
273            "tools/list" => self.handlers.handle_list_tools(request, ctx).await,
274            "tools/call" => self.handlers.handle_call_tool(request, ctx).await,
275
276            // Prompt methods
277            "prompts/list" => self.handlers.handle_list_prompts(request, ctx).await,
278            "prompts/get" => self.handlers.handle_get_prompt(request, ctx).await,
279
280            // Resource methods
281            "resources/list" => self.handlers.handle_list_resources(request, ctx).await,
282            "resources/read" => self.handlers.handle_read_resource(request, ctx).await,
283            "resources/subscribe" => self.handlers.handle_subscribe_resource(request, ctx).await,
284            "resources/unsubscribe" => {
285                self.handlers
286                    .handle_unsubscribe_resource(request, ctx)
287                    .await
288            }
289
290            // Logging methods
291            "logging/setLevel" => self.handlers.handle_set_log_level(request, ctx).await,
292
293            // Sampling methods
294            "sampling/createMessage" => self.handlers.handle_create_message(request, ctx).await,
295
296            // Roots methods
297            "roots/list" => self.handlers.handle_list_roots(request, ctx).await,
298
299            // Enhanced MCP features (MCP 2025-06-18 protocol methods)
300            "elicitation/create" => self.handlers.handle_elicitation(request, ctx).await,
301            "completion/complete" => self.handlers.handle_completion(request, ctx).await,
302            "resources/templates/list" => {
303                self.handlers
304                    .handle_list_resource_templates(request, ctx)
305                    .await
306            }
307            "ping" => self.handlers.handle_ping(request, ctx).await,
308
309            // Custom routes
310            method => {
311                if let Some(handler) = self.custom_routes.get(method) {
312                    let request_clone = request.clone();
313                    handler
314                        .handle(request, ctx)
315                        .await
316                        .unwrap_or_else(|e| error_response(&request_clone, e))
317                } else {
318                    method_not_found_response(&request)
319                }
320            }
321        };
322
323        // Validate response if enabled
324        if self.config.validate_responses
325            && let Err(e) = validate_response(&result)
326        {
327            warn!("Response validation failed: {}", e);
328        }
329
330        result
331    }
332
333    /// Handle batch requests
334    pub async fn route_batch(
335        &self,
336        requests: Vec<JsonRpcRequest>,
337        ctx: RequestContext,
338    ) -> Vec<JsonRpcResponse> {
339        // Note: Server capabilities are injected in route() for each request
340        let max_in_flight = self.config.max_concurrent_requests.max(1);
341        stream::iter(requests.into_iter())
342            .map(|req| {
343                let ctx_cloned = ctx.clone();
344                async move { self.route(req, ctx_cloned).await }
345            })
346            .buffer_unordered(max_in_flight)
347            .collect()
348            .await
349    }
350
351    /// Send an elicitation request to the client (server-initiated)
352    ///
353    /// # Errors
354    ///
355    /// Returns [`ServerError::Transport`] if:
356    /// - The bidirectional dispatcher is not configured
357    /// - The client request fails
358    /// - The client does not respond
359    pub async fn send_elicitation_to_client(
360        &self,
361        request: ElicitRequest,
362        ctx: RequestContext,
363    ) -> ServerResult<ElicitResult> {
364        self.bidirectional
365            .send_elicitation_to_client(request, ctx)
366            .await
367    }
368
369    /// Send a ping request to the client (server-initiated)
370    ///
371    /// # Errors
372    ///
373    /// Returns [`ServerError::Transport`] if:
374    /// - The bidirectional dispatcher is not configured
375    /// - The client request fails
376    /// - The client does not respond
377    pub async fn send_ping_to_client(
378        &self,
379        request: PingRequest,
380        ctx: RequestContext,
381    ) -> ServerResult<PingResult> {
382        self.bidirectional.send_ping_to_client(request, ctx).await
383    }
384
385    /// Send a create message request to the client (server-initiated)
386    ///
387    /// # Errors
388    ///
389    /// Returns [`ServerError::Transport`] if:
390    /// - The bidirectional dispatcher is not configured
391    /// - The client request fails
392    /// - The client does not support sampling
393    pub async fn send_create_message_to_client(
394        &self,
395        request: CreateMessageRequest,
396        ctx: RequestContext,
397    ) -> ServerResult<turbomcp_protocol::types::CreateMessageResult> {
398        self.bidirectional
399            .send_create_message_to_client(request, ctx)
400            .await
401    }
402
403    /// Send a list roots request to the client (server-initiated)
404    ///
405    /// # Errors
406    ///
407    /// Returns [`ServerError::Transport`] if:
408    /// - The bidirectional dispatcher is not configured
409    /// - The client request fails
410    /// - The client does not support roots
411    pub async fn send_list_roots_to_client(
412        &self,
413        request: turbomcp_protocol::types::ListRootsRequest,
414        ctx: RequestContext,
415    ) -> ServerResult<ListRootsResult> {
416        self.bidirectional
417            .send_list_roots_to_client(request, ctx)
418            .await
419    }
420}
421
422impl Clone for RequestRouter {
423    fn clone(&self) -> Self {
424        Self {
425            registry: Arc::clone(&self.registry),
426            config: self.config.clone(),
427            server_config: self.server_config.clone(),
428            custom_routes: self.custom_routes.clone(),
429            resource_subscriptions: DashMap::new(),
430            bidirectional: self.bidirectional.clone(),
431            handlers: ProtocolHandlers::new(HandlerContext::new(
432                Arc::clone(&self.registry),
433                self.server_config.clone(),
434            )),
435            server_to_client: Arc::clone(&self.server_to_client),
436        }
437    }
438}
439
// Design Note: ServerCapabilities trait implementation
//
// RequestRouter currently uses BidirectionalRouter for server-initiated requests
// (sampling, elicitation, roots) instead of directly implementing the ServerCapabilities
// trait from turbomcp_protocol::context::capabilities.
//
// Current Pattern:
// - RequestRouter contains a BidirectionalRouter, which handles server-to-client requests
// - BidirectionalRouter uses the ServerRequestDispatcher trait for transport-agnostic dispatch
// - Tools reach these capabilities through a ServerToClientAdapter (exposed as
//   `Arc<dyn ServerToClientRequests>`), which `create_context()` injects into every
//   RequestContext
// - This pattern provides better separation of concerns and testability
//
// Alternative (not implemented):
// - RequestRouter could implement the ServerCapabilities trait directly
// - This would allow passing the router as &dyn ServerCapabilities to tools
// - The current pattern is preferred because it keeps routing and bidirectional concerns separate
//
// See: crates/turbomcp-server/src/routing/bidirectional.rs for the current implementation
457
458/// Router alias for convenience
459pub type Router = RequestRouter;
460
461// ===================================================================
462// JsonRpcHandler Implementation - For HTTP Transport Integration
463// ===================================================================
464
465#[async_trait::async_trait]
466impl turbomcp_protocol::JsonRpcHandler for RequestRouter {
467    /// Handle a JSON-RPC request via the HTTP transport
468    ///
469    /// This implementation enables `RequestRouter` to be used directly with
470    /// the HTTP transport layer (`run_server`), supporting the builder pattern
471    /// for programmatic server construction.
472    ///
473    /// # Architecture
474    ///
475    /// - Parses raw JSON into `JsonRpcRequest`
476    /// - Creates default `RequestContext` (no auth/session for HTTP)
477    /// - Routes through the existing `route()` method
478    /// - Serializes `JsonRpcResponse` back to JSON
479    ///
480    /// This provides the same request handling as the macro pattern but
481    /// allows runtime handler registration via `ServerBuilder`.
482    async fn handle_request(&self, req_value: serde_json::Value) -> serde_json::Value {
483        // Parse the request
484        let req: JsonRpcRequest = match serde_json::from_value(req_value) {
485            Ok(r) => r,
486            Err(e) => {
487                return serde_json::json!({
488                    "jsonrpc": "2.0",
489                    "error": {
490                        "code": -32700,
491                        "message": format!("Parse error: {}", e)
492                    },
493                    "id": null
494                });
495            }
496        };
497
498        // Create properly configured context with server-to-client capabilities
499        // Note: For authenticated HTTP requests, middleware should add auth info via with_* methods
500        // For HTTP requests with headers, use the HTTP-specific entry point that passes headers
501        let ctx = self.create_context(None, None);
502
503        // Route the request through the standard routing system
504        let response = self.route(req, ctx).await;
505
506        // Serialize response
507        match serde_json::to_value(&response) {
508            Ok(v) => v,
509            Err(e) => {
510                serde_json::json!({
511                    "jsonrpc": "2.0",
512                    "error": {
513                        "code": -32603,
514                        "message": format!("Internal error: failed to serialize response: {}", e)
515                    },
516                    "id": response.id
517                })
518            }
519        }
520    }
521}
522
523// Comprehensive tests in separate file (tokio/axum pattern)
524#[cfg(test)]
525mod tests;