turbomcp_client/plugins/
middleware.rs

1//! Middleware pattern implementation for plugin system
2//!
3//! Provides middleware abstractions and chain execution patterns for
4//! request/response processing. This module focuses on the middleware
5//! pattern specifically, allowing plugins to be composed as middleware.
6
7use crate::plugins::core::{PluginResult, RequestContext, ResponseContext};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tracing::{debug, error};
11
12/// Result type for middleware operations
13pub type MiddlewareResult<T> = PluginResult<T>;
14
15/// Trait for request middleware
16///
17/// Request middleware can modify the request before it's sent to the server.
18/// They are executed in the order they are registered.
19#[async_trait]
20pub trait RequestMiddleware: Send + Sync + std::fmt::Debug {
21    /// Process the request context
22    ///
23    /// # Arguments
24    /// * `context` - Mutable request context that can be modified
25    ///
26    /// # Returns
27    /// Returns `Ok(())` to continue processing, or `PluginError` to abort.
28    async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()>;
29
30    /// Get middleware name for debugging
31    fn name(&self) -> &str;
32}
33
34/// Trait for response middleware
35///
36/// Response middleware process responses after they're received from the server.
37/// They are executed in the order they are registered.
38#[async_trait]
39pub trait ResponseMiddleware: Send + Sync + std::fmt::Debug {
40    /// Process the response context
41    ///
42    /// # Arguments
43    /// * `context` - Mutable response context that can be modified
44    ///
45    /// # Returns
46    /// Returns `Ok(())` if processing succeeds, or `PluginError` if it fails.
47    async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()>;
48
49    /// Get middleware name for debugging
50    fn name(&self) -> &str;
51}
52
53/// Chain of middleware for sequential execution
54///
55/// The MiddlewareChain manages the execution of multiple middleware
56/// components in a defined order. It provides error handling and
57/// short-circuiting behavior.
58///
59/// # Examples
60///
61/// ```rust,no_run
62/// use turbomcp_client::plugins::middleware::{MiddlewareChain, RequestMiddleware};
63/// use std::sync::Arc;
64///
65/// let mut chain = MiddlewareChain::new();
66/// // chain.add_request_middleware(Arc::new(some_middleware));
67/// // chain.add_response_middleware(Arc::new(other_middleware));
68/// ```
69#[derive(Debug)]
70pub struct MiddlewareChain {
71    /// Request middleware in execution order
72    request_middleware: Vec<Arc<dyn RequestMiddleware>>,
73
74    /// Response middleware in execution order
75    response_middleware: Vec<Arc<dyn ResponseMiddleware>>,
76}
77
78impl Default for MiddlewareChain {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl MiddlewareChain {
85    /// Create a new empty middleware chain
86    pub fn new() -> Self {
87        Self {
88            request_middleware: Vec::new(),
89            response_middleware: Vec::new(),
90        }
91    }
92
93    /// Add request middleware to the chain
94    ///
95    /// Middleware will be executed in the order they are added.
96    ///
97    /// # Arguments
98    /// * `middleware` - The request middleware to add
99    pub fn add_request_middleware(&mut self, middleware: Arc<dyn RequestMiddleware>) {
100        debug!("Adding request middleware: {}", middleware.name());
101        self.request_middleware.push(middleware);
102    }
103
104    /// Add response middleware to the chain
105    ///
106    /// Middleware will be executed in the order they are added.
107    ///
108    /// # Arguments
109    /// * `middleware` - The response middleware to add
110    pub fn add_response_middleware(&mut self, middleware: Arc<dyn ResponseMiddleware>) {
111        debug!("Adding response middleware: {}", middleware.name());
112        self.response_middleware.push(middleware);
113    }
114
115    /// Execute the request middleware chain
116    ///
117    /// Processes the request context through all registered request middleware
118    /// in order. If any middleware returns an error, processing is aborted
119    /// and the error is returned.
120    ///
121    /// # Arguments
122    /// * `context` - Mutable request context
123    ///
124    /// # Returns
125    /// Returns `Ok(())` if all middleware succeed, or the first error encountered.
126    pub async fn execute_request_chain(
127        &self,
128        context: &mut RequestContext,
129    ) -> MiddlewareResult<()> {
130        debug!(
131            "Executing request middleware chain ({} middleware) for method: {}",
132            self.request_middleware.len(),
133            context.method()
134        );
135
136        for (index, middleware) in self.request_middleware.iter().enumerate() {
137            debug!(
138                "Processing request middleware {} of {}: {}",
139                index + 1,
140                self.request_middleware.len(),
141                middleware.name()
142            );
143
144            middleware.process_request(context).await.map_err(|e| {
145                error!(
146                    "Request middleware '{}' failed for method '{}': {}",
147                    middleware.name(),
148                    context.method(),
149                    e
150                );
151                e
152            })?;
153        }
154
155        debug!("Request middleware chain completed successfully");
156        Ok(())
157    }
158
159    /// Execute the response middleware chain
160    ///
161    /// Processes the response context through all registered response middleware
162    /// in order. Unlike request middleware, this continues execution even if
163    /// a middleware fails, logging errors but not aborting the chain.
164    ///
165    /// # Arguments
166    /// * `context` - Mutable response context
167    ///
168    /// # Returns
169    /// Returns `Ok(())` unless all middleware fail, in which case returns the last error.
170    pub async fn execute_response_chain(
171        &self,
172        context: &mut ResponseContext,
173    ) -> MiddlewareResult<()> {
174        debug!(
175            "Executing response middleware chain ({} middleware) for method: {}",
176            self.response_middleware.len(),
177            context.method()
178        );
179
180        let mut _last_error = None;
181
182        for (index, middleware) in self.response_middleware.iter().enumerate() {
183            debug!(
184                "Processing response middleware {} of {}: {}",
185                index + 1,
186                self.response_middleware.len(),
187                middleware.name()
188            );
189
190            if let Err(e) = middleware.process_response(context).await {
191                error!(
192                    "Response middleware '{}' failed for method '{}': {}",
193                    middleware.name(),
194                    context.method(),
195                    e
196                );
197                _last_error = Some(e);
198                // Continue with other middleware
199            }
200        }
201
202        debug!("Response middleware chain completed");
203
204        // For now, we don't propagate response middleware errors
205        // as they shouldn't break the response processing
206        Ok(())
207    }
208
209    /// Get the number of request middleware
210    pub fn request_middleware_count(&self) -> usize {
211        self.request_middleware.len()
212    }
213
214    /// Get the number of response middleware
215    pub fn response_middleware_count(&self) -> usize {
216        self.response_middleware.len()
217    }
218
219    /// Get names of all request middleware
220    pub fn get_request_middleware_names(&self) -> Vec<String> {
221        self.request_middleware
222            .iter()
223            .map(|m| m.name().to_string())
224            .collect()
225    }
226
227    /// Get names of all response middleware
228    pub fn get_response_middleware_names(&self) -> Vec<String> {
229        self.response_middleware
230            .iter()
231            .map(|m| m.name().to_string())
232            .collect()
233    }
234
235    /// Clear all middleware
236    pub fn clear(&mut self) {
237        debug!("Clearing all middleware from chain");
238        self.request_middleware.clear();
239        self.response_middleware.clear();
240    }
241}
242
243/// Adapter to use a ClientPlugin as RequestMiddleware
244#[derive(Debug)]
245pub struct PluginRequestMiddleware<P> {
246    plugin: P,
247}
248
249impl<P> PluginRequestMiddleware<P> {
250    /// Create a new plugin request middleware adapter
251    pub fn new(plugin: P) -> Self {
252        Self { plugin }
253    }
254}
255
256#[async_trait]
257impl<P> RequestMiddleware for PluginRequestMiddleware<P>
258where
259    P: crate::plugins::core::ClientPlugin,
260{
261    async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()> {
262        self.plugin.before_request(context).await
263    }
264
265    fn name(&self) -> &str {
266        self.plugin.name()
267    }
268}
269
270/// Adapter to use a ClientPlugin as ResponseMiddleware
271#[derive(Debug)]
272pub struct PluginResponseMiddleware<P> {
273    plugin: P,
274}
275
276impl<P> PluginResponseMiddleware<P> {
277    /// Create a new plugin response middleware adapter
278    pub fn new(plugin: P) -> Self {
279        Self { plugin }
280    }
281}
282
283#[async_trait]
284impl<P> ResponseMiddleware for PluginResponseMiddleware<P>
285where
286    P: crate::plugins::core::ClientPlugin,
287{
288    async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()> {
289        self.plugin.after_response(context).await
290    }
291
292    fn name(&self) -> &str {
293        self.plugin.name()
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::core::{PluginError, RequestContext};
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use turbomcp_core::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Log of middleware invocations.
    ///
    /// Entries have the form `"<middleware name>:<hook>:<method>"`. Sharing one
    /// log between several middleware records their relative execution order.
    type CallLog = Arc<Mutex<Vec<String>>>;

    /// JSON-RPC method used by every test request.
    const METHOD: &str = "test/method";

    fn new_log() -> CallLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    fn read_log(log: &CallLog) -> Vec<String> {
        log.lock().unwrap().clone()
    }

    /// Expected log entry for middleware `name` running `hook` on [`METHOD`].
    fn entry(name: &str, hook: &str) -> String {
        format!("{}:{}:{}", name, hook, METHOD)
    }

    fn request_context() -> RequestContext {
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test"),
            method: METHOD.to_string(),
            params: None,
        };
        RequestContext::new(request, HashMap::new())
    }

    fn response_context() -> ResponseContext {
        ResponseContext::new(
            request_context(),
            Some(json!({"result": "success"})),
            None,
            std::time::Duration::from_millis(100),
        )
    }

    // Test middleware implementations

    #[derive(Debug)]
    struct TestRequestMiddleware {
        name: String,
        calls: CallLog,
        should_fail: bool,
    }

    impl TestRequestMiddleware {
        fn new(name: &str) -> Self {
            Self::with_log(name, new_log())
        }

        /// Record invocations into `calls`, which may be shared with other middleware.
        fn with_log(name: &str, calls: CallLog) -> Self {
            Self {
                name: name.to_string(),
                calls,
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }

        fn get_calls(&self) -> Vec<String> {
            read_log(&self.calls)
        }
    }

    #[async_trait]
    impl RequestMiddleware for TestRequestMiddleware {
        async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("{}:process_request:{}", self.name, context.method()));

            if self.should_fail {
                Err(PluginError::request_processing("Test middleware failure"))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[derive(Debug)]
    struct TestResponseMiddleware {
        name: String,
        calls: CallLog,
        should_fail: bool,
    }

    impl TestResponseMiddleware {
        fn new(name: &str) -> Self {
            Self::with_log(name, new_log())
        }

        /// Record invocations into `calls`, which may be shared with other middleware.
        fn with_log(name: &str, calls: CallLog) -> Self {
            Self {
                name: name.to_string(),
                calls,
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }

        fn get_calls(&self) -> Vec<String> {
            read_log(&self.calls)
        }
    }

    #[async_trait]
    impl ResponseMiddleware for TestResponseMiddleware {
        async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("{}:process_response:{}", self.name, context.method()));

            if self.should_fail {
                Err(PluginError::response_processing("Test middleware failure"))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_middleware_chain_creation() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.request_middleware_count(), 0);
        assert_eq!(chain.response_middleware_count(), 0);

        let default_chain = MiddlewareChain::default();
        assert_eq!(default_chain.request_middleware_count(), 0);
        assert_eq!(default_chain.response_middleware_count(), 0);
    }

    #[test]
    fn test_request_middleware_registration() {
        let mut chain = MiddlewareChain::new();
        chain.add_request_middleware(Arc::new(TestRequestMiddleware::new("test")));

        assert_eq!(chain.request_middleware_count(), 1);
        assert_eq!(chain.get_request_middleware_names(), vec!["test"]);
    }

    #[test]
    fn test_response_middleware_registration() {
        let mut chain = MiddlewareChain::new();
        chain.add_response_middleware(Arc::new(TestResponseMiddleware::new("test")));

        assert_eq!(chain.response_middleware_count(), 1);
        assert_eq!(chain.get_response_middleware_names(), vec!["test"]);
    }

    #[tokio::test]
    async fn test_request_middleware_execution() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestRequestMiddleware::new("test"));
        chain.add_request_middleware(middleware.clone());

        let mut context = request_context();
        chain.execute_request_chain(&mut context).await.unwrap();

        assert_eq!(middleware.get_calls(), vec![entry("test", "process_request")]);
    }

    #[tokio::test]
    async fn test_response_middleware_execution() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestResponseMiddleware::new("test"));
        chain.add_response_middleware(middleware.clone());

        let mut context = response_context();
        chain.execute_response_chain(&mut context).await.unwrap();

        assert_eq!(middleware.get_calls(), vec![entry("test", "process_response")]);
    }

    #[tokio::test]
    async fn test_request_middleware_error_handling() {
        let log = new_log();
        let mut chain = MiddlewareChain::new();
        chain.add_request_middleware(Arc::new(TestRequestMiddleware::with_log(
            "good",
            log.clone(),
        )));
        chain.add_request_middleware(Arc::new(
            TestRequestMiddleware::with_log("bad", log.clone()).with_failure(),
        ));
        chain.add_request_middleware(Arc::new(TestRequestMiddleware::with_log(
            "skipped",
            log.clone(),
        )));

        let mut context = request_context();
        let result = chain.execute_request_chain(&mut context).await;

        assert!(result.is_err());
        // The failing middleware aborts the chain, so "skipped" never runs.
        assert_eq!(
            read_log(&log),
            vec![entry("good", "process_request"), entry("bad", "process_request")]
        );
    }

    #[tokio::test]
    async fn test_response_middleware_error_handling() {
        let log = new_log();
        let mut chain = MiddlewareChain::new();
        chain.add_response_middleware(Arc::new(TestResponseMiddleware::with_log(
            "good",
            log.clone(),
        )));
        chain.add_response_middleware(Arc::new(
            TestResponseMiddleware::with_log("bad", log.clone()).with_failure(),
        ));
        chain.add_response_middleware(Arc::new(TestResponseMiddleware::with_log(
            "after",
            log.clone(),
        )));

        let mut context = response_context();

        // Response middleware continues even with errors, and the chain still succeeds.
        let result = chain.execute_response_chain(&mut context).await;
        assert!(result.is_ok());
        assert_eq!(
            read_log(&log),
            vec![
                entry("good", "process_response"),
                entry("bad", "process_response"),
                entry("after", "process_response"),
            ]
        );
    }

    #[tokio::test]
    async fn test_middleware_execution_order() {
        let log = new_log();
        let mut chain = MiddlewareChain::new();
        for name in ["first", "second", "third"] {
            chain.add_request_middleware(Arc::new(TestRequestMiddleware::with_log(
                name,
                log.clone(),
            )));
        }

        let mut context = request_context();
        chain.execute_request_chain(&mut context).await.unwrap();

        // A shared log proves the middleware actually ran in registration order.
        assert_eq!(
            read_log(&log),
            vec![
                entry("first", "process_request"),
                entry("second", "process_request"),
                entry("third", "process_request"),
            ]
        );
        assert_eq!(
            chain.get_request_middleware_names(),
            vec!["first", "second", "third"]
        );
    }

    #[test]
    fn test_chain_clear() {
        let mut chain = MiddlewareChain::new();
        chain.add_request_middleware(Arc::new(TestRequestMiddleware::new("request")));
        chain.add_response_middleware(Arc::new(TestResponseMiddleware::new("response")));

        assert_eq!(chain.request_middleware_count(), 1);
        assert_eq!(chain.response_middleware_count(), 1);

        chain.clear();

        assert_eq!(chain.request_middleware_count(), 0);
        assert_eq!(chain.response_middleware_count(), 0);
    }
}