1mod bidirectional;
8mod config;
9mod handlers;
10mod traits;
11mod utils;
12mod validation;
13
14pub use bidirectional::BidirectionalRouter;
16pub use config::RouterConfig;
17pub use traits::{Route, RouteHandler, RouteMetadata, ServerRequestDispatcher};
18
19use dashmap::DashMap;
20use futures::stream::{self, StreamExt};
21use std::collections::HashMap;
22use std::sync::Arc;
23use tracing::warn;
24use turbomcp_protocol::RequestContext;
25use turbomcp_protocol::{
26 jsonrpc::{JsonRpcRequest, JsonRpcResponse},
27 types::{
28 CreateMessageRequest, ElicitRequest, ElicitResult, ListRootsResult, PingRequest, PingResult,
29 },
30};
31
32use crate::capabilities::ServerToClientAdapter;
33use crate::metrics::ServerMetrics;
34use crate::registry::HandlerRegistry;
35use crate::{ServerError, ServerResult};
36
37use handlers::{HandlerContext, ProtocolHandlers};
38use turbomcp_protocol::context::capabilities::ServerToClientRequests;
39use utils::{error_response, method_not_found_response};
40use validation::{validate_request, validate_response};
41
42pub struct RequestRouter {
44 registry: Arc<HandlerRegistry>,
46 config: RouterConfig,
48 custom_routes: HashMap<String, Arc<dyn RouteHandler>>,
50 #[allow(dead_code)]
52 resource_subscriptions: DashMap<String, usize>,
53 bidirectional: BidirectionalRouter,
55 handlers: ProtocolHandlers,
57 server_to_client: Arc<dyn ServerToClientRequests>,
60}
61
62impl std::fmt::Debug for RequestRouter {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("RequestRouter")
65 .field("config", &self.config)
66 .field("custom_routes_count", &self.custom_routes.len())
67 .finish()
68 }
69}
70
71impl RequestRouter {
72 #[must_use]
74 pub fn new(registry: Arc<HandlerRegistry>, _metrics: Arc<ServerMetrics>) -> Self {
75 let config = RouterConfig::default();
77
78 let handler_context = HandlerContext::new(Arc::clone(®istry));
79
80 let bidirectional = BidirectionalRouter::new();
81
82 let server_to_client: Arc<dyn ServerToClientRequests> =
85 Arc::new(ServerToClientAdapter::new(bidirectional.clone()));
86
87 Self {
88 registry,
89 config,
90 custom_routes: HashMap::new(),
91 resource_subscriptions: DashMap::new(),
92 bidirectional,
93 handlers: ProtocolHandlers::new(handler_context),
94 server_to_client,
95 }
96 }
97
98 #[must_use]
100 pub fn with_config(
101 registry: Arc<HandlerRegistry>,
102 config: RouterConfig,
103 _metrics: Arc<ServerMetrics>,
104 ) -> Self {
105 let handler_context = HandlerContext::new(Arc::clone(®istry));
108
109 let bidirectional = BidirectionalRouter::new();
110
111 let server_to_client: Arc<dyn ServerToClientRequests> =
114 Arc::new(ServerToClientAdapter::new(bidirectional.clone()));
115
116 Self {
117 registry,
118 config,
119 custom_routes: HashMap::new(),
120 resource_subscriptions: DashMap::new(),
121 bidirectional,
122 handlers: ProtocolHandlers::new(handler_context),
123 server_to_client,
124 }
125 }
126
127 pub fn set_server_request_dispatcher<D>(&mut self, dispatcher: D)
131 where
132 D: ServerRequestDispatcher + 'static,
133 {
134 self.bidirectional.set_dispatcher(dispatcher);
135 }
136
137 pub fn get_server_request_dispatcher(&self) -> Option<&Arc<dyn ServerRequestDispatcher>> {
139 self.bidirectional.get_dispatcher()
140 }
141
142 pub fn supports_bidirectional(&self) -> bool {
144 self.config.enable_bidirectional && self.bidirectional.supports_bidirectional()
145 }
146
147 pub fn add_route<H>(&mut self, handler: H) -> ServerResult<()>
153 where
154 H: RouteHandler + 'static,
155 {
156 let metadata = handler.metadata();
157 let handler_arc: Arc<dyn RouteHandler> = Arc::new(handler);
158
159 for method in &metadata.methods {
160 if self.custom_routes.contains_key(method) {
161 return Err(ServerError::routing_with_method(
162 format!("Route for method '{method}' already exists"),
163 method.clone(),
164 ));
165 }
166 self.custom_routes
167 .insert(method.clone(), Arc::clone(&handler_arc));
168 }
169
170 Ok(())
171 }
172
173 pub async fn route(&self, request: JsonRpcRequest, ctx: RequestContext) -> JsonRpcResponse {
175 let ctx = ctx.with_server_to_client(Arc::clone(&self.server_to_client));
178
179 if self.config.validate_requests
181 && let Err(e) = validate_request(&request)
182 {
183 return error_response(&request, e);
184 }
185
186 let result = match request.method.as_str() {
188 "initialize" => self.handlers.handle_initialize(request, ctx).await,
190
191 "tools/list" => self.handlers.handle_list_tools(request, ctx).await,
193 "tools/call" => self.handlers.handle_call_tool(request, ctx).await,
194
195 "prompts/list" => self.handlers.handle_list_prompts(request, ctx).await,
197 "prompts/get" => self.handlers.handle_get_prompt(request, ctx).await,
198
199 "resources/list" => self.handlers.handle_list_resources(request, ctx).await,
201 "resources/read" => self.handlers.handle_read_resource(request, ctx).await,
202 "resources/subscribe" => self.handlers.handle_subscribe_resource(request, ctx).await,
203 "resources/unsubscribe" => {
204 self.handlers
205 .handle_unsubscribe_resource(request, ctx)
206 .await
207 }
208
209 "logging/setLevel" => self.handlers.handle_set_log_level(request, ctx).await,
211
212 "sampling/createMessage" => self.handlers.handle_create_message(request, ctx).await,
214
215 "roots/list" => self.handlers.handle_list_roots(request, ctx).await,
217
218 "elicitation/create" => self.handlers.handle_elicitation(request, ctx).await,
220 "completion/complete" => self.handlers.handle_completion(request, ctx).await,
221 "resources/templates/list" => {
222 self.handlers
223 .handle_list_resource_templates(request, ctx)
224 .await
225 }
226 "ping" => self.handlers.handle_ping(request, ctx).await,
227
228 method => {
230 if let Some(handler) = self.custom_routes.get(method) {
231 let request_clone = request.clone();
232 handler
233 .handle(request, ctx)
234 .await
235 .unwrap_or_else(|e| error_response(&request_clone, e))
236 } else {
237 method_not_found_response(&request)
238 }
239 }
240 };
241
242 if self.config.validate_responses
244 && let Err(e) = validate_response(&result)
245 {
246 warn!("Response validation failed: {}", e);
247 }
248
249 result
250 }
251
252 pub async fn route_batch(
254 &self,
255 requests: Vec<JsonRpcRequest>,
256 ctx: RequestContext,
257 ) -> Vec<JsonRpcResponse> {
258 let max_in_flight = self.config.max_concurrent_requests.max(1);
260 stream::iter(requests.into_iter())
261 .map(|req| {
262 let ctx_cloned = ctx.clone();
263 async move { self.route(req, ctx_cloned).await }
264 })
265 .buffer_unordered(max_in_flight)
266 .collect()
267 .await
268 }
269
270 pub async fn send_elicitation_to_client(
279 &self,
280 request: ElicitRequest,
281 ctx: RequestContext,
282 ) -> ServerResult<ElicitResult> {
283 self.bidirectional
284 .send_elicitation_to_client(request, ctx)
285 .await
286 }
287
288 pub async fn send_ping_to_client(
297 &self,
298 request: PingRequest,
299 ctx: RequestContext,
300 ) -> ServerResult<PingResult> {
301 self.bidirectional.send_ping_to_client(request, ctx).await
302 }
303
304 pub async fn send_create_message_to_client(
313 &self,
314 request: CreateMessageRequest,
315 ctx: RequestContext,
316 ) -> ServerResult<turbomcp_protocol::types::CreateMessageResult> {
317 self.bidirectional
318 .send_create_message_to_client(request, ctx)
319 .await
320 }
321
322 pub async fn send_list_roots_to_client(
331 &self,
332 request: turbomcp_protocol::types::ListRootsRequest,
333 ctx: RequestContext,
334 ) -> ServerResult<ListRootsResult> {
335 self.bidirectional
336 .send_list_roots_to_client(request, ctx)
337 .await
338 }
339}
340
341impl Clone for RequestRouter {
342 fn clone(&self) -> Self {
343 Self {
344 registry: Arc::clone(&self.registry),
345 config: self.config.clone(),
346 custom_routes: self.custom_routes.clone(),
347 resource_subscriptions: DashMap::new(),
348 bidirectional: self.bidirectional.clone(),
349 handlers: ProtocolHandlers::new(HandlerContext::new(Arc::clone(&self.registry))),
350 server_to_client: Arc::clone(&self.server_to_client),
351 }
352 }
353}
354
355pub type Router = RequestRouter;
375
376#[async_trait::async_trait]
381impl turbomcp_protocol::JsonRpcHandler for RequestRouter {
382 async fn handle_request(&self, req_value: serde_json::Value) -> serde_json::Value {
398 let req: JsonRpcRequest = match serde_json::from_value(req_value) {
400 Ok(r) => r,
401 Err(e) => {
402 return serde_json::json!({
403 "jsonrpc": "2.0",
404 "error": {
405 "code": -32700,
406 "message": format!("Parse error: {}", e)
407 },
408 "id": null
409 });
410 }
411 };
412
413 let ctx = RequestContext::default();
416
417 let response = self.route(req, ctx).await;
419
420 match serde_json::to_value(&response) {
422 Ok(v) => v,
423 Err(e) => {
424 serde_json::json!({
425 "jsonrpc": "2.0",
426 "error": {
427 "code": -32603,
428 "message": format!("Internal error: failed to serialize response: {}", e)
429 },
430 "id": response.id
431 })
432 }
433 }
434 }
435}
436
437#[cfg(test)]
439mod tests;