turul_mcp_server/dispatch/mod.rs

1use async_trait::async_trait;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tracing::{debug, error, instrument};
6
7use crate::handlers::McpHandler;
8use crate::session::SessionContext;
9use turul_mcp_protocol::{McpError, McpResult};
10
11/// Advanced MCP request dispatcher with routing, middleware, and error handling
12pub struct McpDispatcher {
13    /// Route handlers mapped by method pattern
14    route_handlers: HashMap<String, Arc<dyn McpHandler>>,
15    /// Wildcard handlers for method patterns (e.g., "tools/*")
16    pattern_handlers: Vec<(String, Arc<dyn McpHandler>)>,
17    /// Middleware stack for request processing
18    middleware: Vec<Arc<dyn DispatchMiddleware>>,
19    /// Default fallback handler
20    default_handler: Option<Arc<dyn McpHandler>>,
21}
22
23/// Middleware trait for request processing pipeline
24#[async_trait]
25pub trait DispatchMiddleware: Send + Sync {
26    /// Process request before routing (return None to continue, Some(Value) to short-circuit)
27    async fn before_dispatch(
28        &self,
29        method: &str,
30        params: Option<&Value>,
31        session: Option<&SessionContext>,
32    ) -> Option<McpResult<Value>>;
33
34    /// Process response after routing
35    async fn after_dispatch(
36        &self,
37        method: &str,
38        result: &McpResult<Value>,
39        session: Option<&SessionContext>,
40    ) -> McpResult<Value>;
41}
42
43/// Request context for dispatch processing
44pub struct DispatchContext {
45    pub method: String,
46    pub params: Option<Value>,
47    pub session: Option<SessionContext>,
48    pub metadata: HashMap<String, Value>,
49}
50
51impl McpDispatcher {
52    pub fn new() -> Self {
53        Self {
54            route_handlers: HashMap::new(),
55            pattern_handlers: Vec::new(),
56            middleware: Vec::new(),
57            default_handler: None,
58        }
59    }
60
61    /// Register a handler for exact method matching
62    pub fn register_exact_handler(mut self, method: String, handler: Arc<dyn McpHandler>) -> Self {
63        self.route_handlers.insert(method, handler);
64        self
65    }
66
67    /// Register a handler for pattern matching (e.g., "tools/*")
68    pub fn register_pattern_handler(
69        mut self,
70        pattern: String,
71        handler: Arc<dyn McpHandler>,
72    ) -> Self {
73        self.pattern_handlers.push((pattern, handler));
74        self
75    }
76
77    /// Register middleware
78    pub fn register_middleware(mut self, middleware: Arc<dyn DispatchMiddleware>) -> Self {
79        self.middleware.push(middleware);
80        self
81    }
82
83    /// Set default fallback handler
84    pub fn set_default_handler(mut self, handler: Arc<dyn McpHandler>) -> Self {
85        self.default_handler = Some(handler);
86        self
87    }
88
89    /// Dispatch a request through the routing and middleware pipeline
90    #[instrument(skip(self, params, session))]
91    pub async fn dispatch(
92        &self,
93        method: &str,
94        params: Option<Value>,
95        session: Option<SessionContext>,
96    ) -> McpResult<Value> {
97        debug!("Dispatching request: method={}", method);
98
99        // Run before-dispatch middleware
100        for middleware in &self.middleware {
101            if let Some(result) = middleware
102                .before_dispatch(method, params.as_ref(), session.as_ref())
103                .await
104            {
105                debug!("Request short-circuited by middleware");
106                return result;
107            }
108        }
109
110        // Find appropriate handler
111        let handler = self.find_handler(method)?;
112
113        // Execute handler
114        let mut result = handler.handle_with_session(params, session.clone()).await;
115
116        // Run after-dispatch middleware (in reverse order)
117        for middleware in self.middleware.iter().rev() {
118            result = middleware
119                .after_dispatch(method, &result, session.as_ref())
120                .await;
121        }
122
123        result
124    }
125
126    /// Find the appropriate handler for a method
127    fn find_handler(&self, method: &str) -> McpResult<&Arc<dyn McpHandler>> {
128        // Try exact match first
129        if let Some(handler) = self.route_handlers.get(method) {
130            debug!("Found exact handler for method: {}", method);
131            return Ok(handler);
132        }
133
134        // Try pattern matching
135        for (pattern, handler) in &self.pattern_handlers {
136            if self.matches_pattern(method, pattern) {
137                debug!("Found pattern handler '{}' for method: {}", pattern, method);
138                return Ok(handler);
139            }
140        }
141
142        // Try default handler
143        if let Some(ref handler) = self.default_handler {
144            debug!("Using default handler for method: {}", method);
145            return Ok(handler);
146        }
147
148        error!("No handler found for method: {}", method);
149        Err(McpError::InvalidParameters(format!(
150            "Method not found: {}",
151            method
152        )))
153    }
154
155    /// Check if a method matches a pattern (supports wildcards)
156    fn matches_pattern(&self, method: &str, pattern: &str) -> bool {
157        if let Some(prefix) = pattern.strip_suffix("/*") {
158            method.starts_with(prefix) && method.len() > prefix.len()
159        } else if pattern.contains('*') {
160            // More sophisticated glob matching could be implemented here
161            false
162        } else {
163            method == pattern
164        }
165    }
166
167    /// Get all registered methods and patterns
168    pub fn get_supported_methods(&self) -> Vec<String> {
169        let mut methods = Vec::new();
170
171        // Add exact methods
172        methods.extend(self.route_handlers.keys().cloned());
173
174        // Add patterns
175        methods.extend(
176            self.pattern_handlers
177                .iter()
178                .map(|(pattern, _)| pattern.clone()),
179        );
180
181        methods.sort();
182        methods
183    }
184}
185
186impl Default for McpDispatcher {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
/// Middleware that logs each request and its outcome at `debug` level.
///
/// Stateless, so it is cheap to copy and can be shared freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoggingMiddleware;
194
195#[async_trait]
196impl DispatchMiddleware for LoggingMiddleware {
197    async fn before_dispatch(
198        &self,
199        method: &str,
200        params: Option<&Value>,
201        session: Option<&SessionContext>,
202    ) -> Option<McpResult<Value>> {
203        let none_string = "none".to_string();
204        let session_id = session
205            .as_ref()
206            .map(|s| s.session_id.as_str())
207            .unwrap_or(&none_string);
208        debug!(
209            "Request: method={}, session={}, params={}",
210            method,
211            session_id,
212            params
213                .map(|p| p.to_string())
214                .unwrap_or_else(|| "none".to_string())
215        );
216        None
217    }
218
219    async fn after_dispatch(
220        &self,
221        method: &str,
222        result: &McpResult<Value>,
223        session: Option<&SessionContext>,
224    ) -> McpResult<Value> {
225        let none_string = "none".to_string();
226        let session_id = session
227            .as_ref()
228            .map(|s| s.session_id.as_str())
229            .unwrap_or(&none_string);
230        match result {
231            Ok(value) => {
232                debug!(
233                    "Response: method={}, session={}, success=true, result_keys={:?}",
234                    method,
235                    session_id,
236                    value.as_object().map(|o| o.keys().collect::<Vec<_>>())
237                );
238            }
239            Err(error) => {
240                debug!(
241                    "Response: method={}, session={}, error={}",
242                    method, session_id, error
243                );
244            }
245        }
246        match result {
247            Ok(value) => Ok(value.clone()),
248            Err(error) => Err(McpError::InvalidParameters(error.to_string())),
249        }
250    }
251}
252
/// Rate limiting middleware.
///
/// Currently a placeholder with no configuration or state: it lets every
/// request through.
#[derive(Debug, Clone, Default)]
pub struct RateLimitingMiddleware {
    // Rate limiting could be implemented here
    // For now, this is a placeholder
}
258
259#[async_trait]
260impl DispatchMiddleware for RateLimitingMiddleware {
261    async fn before_dispatch(
262        &self,
263        _method: &str,
264        _params: Option<&Value>,
265        _session: Option<&SessionContext>,
266    ) -> Option<McpResult<Value>> {
267        // Rate limiting logic would go here
268        None
269    }
270
271    async fn after_dispatch(
272        &self,
273        _method: &str,
274        result: &McpResult<Value>,
275        _session: Option<&SessionContext>,
276    ) -> McpResult<Value> {
277        match result {
278            Ok(value) => Ok(value.clone()),
279            Err(error) => Err(McpError::InvalidParameters(error.to_string())),
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handlers::McpHandler;

    /// Handler that always answers with a fixed value.
    struct TestHandler {
        response: Value,
    }

    #[async_trait]
    impl McpHandler for TestHandler {
        async fn handle(&self, _params: Option<Value>) -> McpResult<Value> {
            Ok(self.response.clone())
        }

        fn supported_methods(&self) -> Vec<String> {
            vec!["test".to_string()]
        }
    }

    /// Builds a shared `TestHandler` answering with the string `response`.
    fn string_handler(response: &str) -> Arc<dyn McpHandler> {
        Arc::new(TestHandler {
            response: Value::String(response.to_string()),
        })
    }

    /// Middleware that appends `suffix` to successful string results.
    struct SuffixMiddleware {
        suffix: &'static str,
    }

    #[async_trait]
    impl DispatchMiddleware for SuffixMiddleware {
        async fn before_dispatch(
            &self,
            _method: &str,
            _params: Option<&Value>,
            _session: Option<&SessionContext>,
        ) -> Option<McpResult<Value>> {
            None
        }

        async fn after_dispatch(
            &self,
            _method: &str,
            result: &McpResult<Value>,
            _session: Option<&SessionContext>,
        ) -> McpResult<Value> {
            match result {
                Ok(Value::String(text)) => Ok(Value::String(format!("{}{}", text, self.suffix))),
                Ok(other) => Ok(other.clone()),
                Err(error) => Err(McpError::InvalidParameters(error.to_string())),
            }
        }
    }

    /// Middleware that answers every request itself in `before_dispatch`.
    struct ShortCircuitMiddleware;

    #[async_trait]
    impl DispatchMiddleware for ShortCircuitMiddleware {
        async fn before_dispatch(
            &self,
            _method: &str,
            _params: Option<&Value>,
            _session: Option<&SessionContext>,
        ) -> Option<McpResult<Value>> {
            Some(Ok(Value::String("short_circuit".to_string())))
        }

        async fn after_dispatch(
            &self,
            _method: &str,
            _result: &McpResult<Value>,
            _session: Option<&SessionContext>,
        ) -> McpResult<Value> {
            // Must never be observed for a short-circuited request.
            Ok(Value::String("after_dispatch_ran".to_string()))
        }
    }

    #[tokio::test]
    async fn test_exact_routing() {
        let handler = Arc::new(TestHandler {
            response: Value::String("test_response".to_string()),
        });

        let dispatcher =
            McpDispatcher::new().register_exact_handler("test/method".to_string(), handler);

        let result = dispatcher
            .dispatch("test/method", None, None)
            .await
            .unwrap();
        assert_eq!(result, Value::String("test_response".to_string()));
    }

    #[tokio::test]
    async fn test_pattern_routing() {
        let handler = Arc::new(TestHandler {
            response: Value::String("pattern_response".to_string()),
        });

        let dispatcher =
            McpDispatcher::new().register_pattern_handler("tools/*".to_string(), handler);

        let result = dispatcher.dispatch("tools/list", None, None).await.unwrap();
        assert_eq!(result, Value::String("pattern_response".to_string()));
    }

    #[tokio::test]
    async fn test_pattern_requires_non_empty_tail() {
        let dispatcher = McpDispatcher::new()
            .register_pattern_handler("tools/*".to_string(), string_handler("pattern"));

        assert!(dispatcher.dispatch("tools", None, None).await.is_err());
    }

    #[tokio::test]
    async fn test_exact_route_beats_pattern() {
        let dispatcher = McpDispatcher::new()
            .register_pattern_handler("tools/*".to_string(), string_handler("pattern"))
            .register_exact_handler("tools/list".to_string(), string_handler("exact"));

        let result = dispatcher.dispatch("tools/list", None, None).await.unwrap();
        assert_eq!(result, Value::String("exact".to_string()));
    }

    #[tokio::test]
    async fn test_default_handler_fallback() {
        let dispatcher = McpDispatcher::new().set_default_handler(string_handler("fallback"));

        let result = dispatcher
            .dispatch("anything/else", None, None)
            .await
            .unwrap();
        assert_eq!(result, Value::String("fallback".to_string()));
    }

    #[tokio::test]
    async fn test_middleware_short_circuit_skips_routing_and_after_hooks() {
        // No routes at all: a successful result proves routing was skipped.
        let dispatcher = McpDispatcher::new().register_middleware(Arc::new(ShortCircuitMiddleware));

        let result = dispatcher
            .dispatch("unknown/method", None, None)
            .await
            .unwrap();
        assert_eq!(result, Value::String("short_circuit".to_string()));
    }

    #[tokio::test]
    async fn test_after_dispatch_runs_in_reverse_order() {
        let dispatcher = McpDispatcher::new()
            .register_exact_handler("m".to_string(), string_handler("x"))
            .register_middleware(Arc::new(SuffixMiddleware { suffix: "a" }))
            .register_middleware(Arc::new(SuffixMiddleware { suffix: "b" }));

        let result = dispatcher.dispatch("m", None, None).await.unwrap();
        // "b" was registered last, so its after-hook runs first.
        assert_eq!(result, Value::String("xba".to_string()));
    }

    #[tokio::test]
    async fn test_supported_methods_sorted() {
        let dispatcher = McpDispatcher::new()
            .register_exact_handler("b/x".to_string(), string_handler("exact"))
            .register_pattern_handler("a/*".to_string(), string_handler("pattern"));

        assert_eq!(
            dispatcher.get_supported_methods(),
            vec!["a/*".to_string(), "b/x".to_string()]
        );
    }

    #[tokio::test]
    async fn test_method_not_found() {
        let dispatcher = McpDispatcher::new();

        let result = dispatcher.dispatch("unknown/method", None, None).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            McpError::InvalidParameters(_)
        ));
    }
}