// turul_mcp_server/dispatch/mod.rs

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::Value;
use tracing::{debug, error, instrument};
use turul_mcp_protocol::{McpError, McpResult};

use crate::handlers::McpHandler;
use crate::session::SessionContext;

11pub struct McpDispatcher {
13 route_handlers: HashMap<String, Arc<dyn McpHandler>>,
15 pattern_handlers: Vec<(String, Arc<dyn McpHandler>)>,
17 middleware: Vec<Arc<dyn DispatchMiddleware>>,
19 default_handler: Option<Arc<dyn McpHandler>>,
21}
22
23#[async_trait]
25pub trait DispatchMiddleware: Send + Sync {
26 async fn before_dispatch(
28 &self,
29 method: &str,
30 params: Option<&Value>,
31 session: Option<&SessionContext>,
32 ) -> Option<McpResult<Value>>;
33
34 async fn after_dispatch(
36 &self,
37 method: &str,
38 result: &McpResult<Value>,
39 session: Option<&SessionContext>,
40 ) -> McpResult<Value>;
41}
42
43pub struct DispatchContext {
45 pub method: String,
46 pub params: Option<Value>,
47 pub session: Option<SessionContext>,
48 pub metadata: HashMap<String, Value>,
49}
50
51impl McpDispatcher {
52 pub fn new() -> Self {
53 Self {
54 route_handlers: HashMap::new(),
55 pattern_handlers: Vec::new(),
56 middleware: Vec::new(),
57 default_handler: None,
58 }
59 }
60
61 pub fn register_exact_handler(mut self, method: String, handler: Arc<dyn McpHandler>) -> Self {
63 self.route_handlers.insert(method, handler);
64 self
65 }
66
67 pub fn register_pattern_handler(
69 mut self,
70 pattern: String,
71 handler: Arc<dyn McpHandler>,
72 ) -> Self {
73 self.pattern_handlers.push((pattern, handler));
74 self
75 }
76
77 pub fn register_middleware(mut self, middleware: Arc<dyn DispatchMiddleware>) -> Self {
79 self.middleware.push(middleware);
80 self
81 }
82
83 pub fn set_default_handler(mut self, handler: Arc<dyn McpHandler>) -> Self {
85 self.default_handler = Some(handler);
86 self
87 }
88
89 #[instrument(skip(self, params, session))]
91 pub async fn dispatch(
92 &self,
93 method: &str,
94 params: Option<Value>,
95 session: Option<SessionContext>,
96 ) -> McpResult<Value> {
97 debug!("Dispatching request: method={}", method);
98
99 for middleware in &self.middleware {
101 if let Some(result) = middleware
102 .before_dispatch(method, params.as_ref(), session.as_ref())
103 .await
104 {
105 debug!("Request short-circuited by middleware");
106 return result;
107 }
108 }
109
110 let handler = self.find_handler(method)?;
112
113 let mut result = handler.handle_with_session(params, session.clone()).await;
115
116 for middleware in self.middleware.iter().rev() {
118 result = middleware
119 .after_dispatch(method, &result, session.as_ref())
120 .await;
121 }
122
123 result
124 }
125
126 fn find_handler(&self, method: &str) -> McpResult<&Arc<dyn McpHandler>> {
128 if let Some(handler) = self.route_handlers.get(method) {
130 debug!("Found exact handler for method: {}", method);
131 return Ok(handler);
132 }
133
134 for (pattern, handler) in &self.pattern_handlers {
136 if self.matches_pattern(method, pattern) {
137 debug!("Found pattern handler '{}' for method: {}", pattern, method);
138 return Ok(handler);
139 }
140 }
141
142 if let Some(ref handler) = self.default_handler {
144 debug!("Using default handler for method: {}", method);
145 return Ok(handler);
146 }
147
148 error!("No handler found for method: {}", method);
149 Err(McpError::InvalidParameters(format!(
150 "Method not found: {}",
151 method
152 )))
153 }
154
155 fn matches_pattern(&self, method: &str, pattern: &str) -> bool {
157 if let Some(prefix) = pattern.strip_suffix("/*") {
158 method.starts_with(prefix) && method.len() > prefix.len()
159 } else if pattern.contains('*') {
160 false
162 } else {
163 method == pattern
164 }
165 }
166
167 pub fn get_supported_methods(&self) -> Vec<String> {
169 let mut methods = Vec::new();
170
171 methods.extend(self.route_handlers.keys().cloned());
173
174 methods.extend(
176 self.pattern_handlers
177 .iter()
178 .map(|(pattern, _)| pattern.clone()),
179 );
180
181 methods.sort();
182 methods
183 }
184}
185
186impl Default for McpDispatcher {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
/// Middleware that logs every request and response at `debug` level.
///
/// Its `before_dispatch` never short-circuits.
///
/// NOTE(review): its `after_dispatch` re-wraps every error as
/// `McpError::InvalidParameters`, so the original error variant is not preserved.
#[derive(Debug, Clone, Default)]
pub struct LoggingMiddleware;

195#[async_trait]
196impl DispatchMiddleware for LoggingMiddleware {
197 async fn before_dispatch(
198 &self,
199 method: &str,
200 params: Option<&Value>,
201 session: Option<&SessionContext>,
202 ) -> Option<McpResult<Value>> {
203 let none_string = "none".to_string();
204 let session_id = session
205 .as_ref()
206 .map(|s| s.session_id.as_str())
207 .unwrap_or(&none_string);
208 debug!(
209 "Request: method={}, session={}, params={}",
210 method,
211 session_id,
212 params
213 .map(|p| p.to_string())
214 .unwrap_or_else(|| "none".to_string())
215 );
216 None
217 }
218
219 async fn after_dispatch(
220 &self,
221 method: &str,
222 result: &McpResult<Value>,
223 session: Option<&SessionContext>,
224 ) -> McpResult<Value> {
225 let none_string = "none".to_string();
226 let session_id = session
227 .as_ref()
228 .map(|s| s.session_id.as_str())
229 .unwrap_or(&none_string);
230 match result {
231 Ok(value) => {
232 debug!(
233 "Response: method={}, session={}, success=true, result_keys={:?}",
234 method,
235 session_id,
236 value.as_object().map(|o| o.keys().collect::<Vec<_>>())
237 );
238 }
239 Err(error) => {
240 debug!(
241 "Response: method={}, session={}, error={}",
242 method, session_id, error
243 );
244 }
245 }
246 match result {
247 Ok(value) => Ok(value.clone()),
248 Err(error) => Err(McpError::InvalidParameters(error.to_string())),
249 }
250 }
251}
252
/// Placeholder for a rate-limiting middleware.
///
/// NOTE(review): the struct carries no state, and its `DispatchMiddleware` hooks never
/// reject a request — every request is let through. Confirm whether rate limiting is
/// still to be implemented before relying on it.
#[derive(Debug, Clone, Default)]
pub struct RateLimitingMiddleware {}

259#[async_trait]
260impl DispatchMiddleware for RateLimitingMiddleware {
261 async fn before_dispatch(
262 &self,
263 _method: &str,
264 _params: Option<&Value>,
265 _session: Option<&SessionContext>,
266 ) -> Option<McpResult<Value>> {
267 None
269 }
270
271 async fn after_dispatch(
272 &self,
273 _method: &str,
274 result: &McpResult<Value>,
275 _session: Option<&SessionContext>,
276 ) -> McpResult<Value> {
277 match result {
278 Ok(value) => Ok(value.clone()),
279 Err(error) => Err(McpError::InvalidParameters(error.to_string())),
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handlers::McpHandler;

    /// Handler that answers every request with a fixed value.
    struct TestHandler {
        response: Value,
    }

    #[async_trait]
    impl McpHandler for TestHandler {
        async fn handle(&self, _params: Option<Value>) -> McpResult<Value> {
            Ok(self.response.clone())
        }

        fn supported_methods(&self) -> Vec<String> {
            vec!["test".to_string()]
        }
    }

    /// Builds a handler that always answers with the string `response`.
    fn fixed_handler(response: &str) -> Arc<dyn McpHandler> {
        Arc::new(TestHandler {
            response: Value::String(response.to_string()),
        })
    }

    /// Middleware that answers every request in `before_dispatch`.
    ///
    /// Its `after_dispatch` returns a marker value, so a test can detect whether it ran.
    struct ShortCircuitMiddleware;

    #[async_trait]
    impl DispatchMiddleware for ShortCircuitMiddleware {
        async fn before_dispatch(
            &self,
            _method: &str,
            _params: Option<&Value>,
            _session: Option<&SessionContext>,
        ) -> Option<McpResult<Value>> {
            Some(Ok(Value::String("short_circuit".to_string())))
        }

        async fn after_dispatch(
            &self,
            _method: &str,
            _result: &McpResult<Value>,
            _session: Option<&SessionContext>,
        ) -> McpResult<Value> {
            Ok(Value::String("after_dispatch_ran".to_string()))
        }
    }

    /// Middleware that lets requests through and replaces every result with a marker.
    struct RewritingMiddleware;

    #[async_trait]
    impl DispatchMiddleware for RewritingMiddleware {
        async fn before_dispatch(
            &self,
            _method: &str,
            _params: Option<&Value>,
            _session: Option<&SessionContext>,
        ) -> Option<McpResult<Value>> {
            None
        }

        async fn after_dispatch(
            &self,
            _method: &str,
            _result: &McpResult<Value>,
            _session: Option<&SessionContext>,
        ) -> McpResult<Value> {
            Ok(Value::String("rewritten".to_string()))
        }
    }

    #[tokio::test]
    async fn test_exact_routing() {
        let handler = Arc::new(TestHandler {
            response: Value::String("test_response".to_string()),
        });

        let dispatcher =
            McpDispatcher::new().register_exact_handler("test/method".to_string(), handler);

        let result = dispatcher
            .dispatch("test/method", None, None)
            .await
            .unwrap();
        assert_eq!(result, Value::String("test_response".to_string()));
    }

    #[tokio::test]
    async fn test_pattern_routing() {
        let handler = Arc::new(TestHandler {
            response: Value::String("pattern_response".to_string()),
        });

        let dispatcher =
            McpDispatcher::new().register_pattern_handler("tools/*".to_string(), handler);

        let result = dispatcher.dispatch("tools/list", None, None).await.unwrap();
        assert_eq!(result, Value::String("pattern_response".to_string()));
    }

    #[tokio::test]
    async fn test_method_not_found() {
        let dispatcher = McpDispatcher::new();

        let result = dispatcher.dispatch("unknown/method", None, None).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            McpError::InvalidParameters(_)
        ));
    }

    #[tokio::test]
    async fn test_exact_handler_takes_precedence_over_pattern() {
        let dispatcher = McpDispatcher::new()
            .register_pattern_handler("tools/*".to_string(), fixed_handler("pattern"))
            .register_exact_handler("tools/list".to_string(), fixed_handler("exact"));

        let result = dispatcher.dispatch("tools/list", None, None).await.unwrap();
        assert_eq!(result, Value::String("exact".to_string()));
    }

    #[tokio::test]
    async fn test_default_handler_used_when_nothing_matches() {
        let dispatcher = McpDispatcher::new()
            .register_exact_handler("tools/list".to_string(), fixed_handler("exact"))
            .set_default_handler(fixed_handler("default"));

        let result = dispatcher
            .dispatch("unknown/method", None, None)
            .await
            .unwrap();
        assert_eq!(result, Value::String("default".to_string()));
    }

    #[tokio::test]
    async fn test_short_circuit_skips_handler_and_after_hooks() {
        let dispatcher = McpDispatcher::new()
            .register_exact_handler("tools/list".to_string(), fixed_handler("handler"))
            .register_middleware(Arc::new(ShortCircuitMiddleware));

        let result = dispatcher.dispatch("tools/list", None, None).await.unwrap();
        assert_eq!(result, Value::String("short_circuit".to_string()));
    }

    #[tokio::test]
    async fn test_after_dispatch_replaces_handler_result() {
        let dispatcher = McpDispatcher::new()
            .register_exact_handler("tools/list".to_string(), fixed_handler("handler"))
            .register_middleware(Arc::new(RewritingMiddleware));

        let result = dispatcher.dispatch("tools/list", None, None).await.unwrap();
        assert_eq!(result, Value::String("rewritten".to_string()));
    }

    #[test]
    fn test_supported_methods_are_sorted() {
        let dispatcher = McpDispatcher::new()
            .register_exact_handler("zeta/method".to_string(), fixed_handler("exact"))
            .register_pattern_handler("alpha/*".to_string(), fixed_handler("pattern"));

        assert_eq!(
            dispatcher.get_supported_methods(),
            vec!["alpha/*".to_string(), "zeta/method".to_string()]
        );
    }
}