1use crate::plugins::core::{PluginResult, RequestContext, ResponseContext};
8use async_trait::async_trait;
9use std::sync::Arc;
10use tracing::{debug, error};
11
12pub type MiddlewareResult<T> = PluginResult<T>;
14
15#[async_trait]
20pub trait RequestMiddleware: Send + Sync + std::fmt::Debug {
21 async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()>;
29
30 fn name(&self) -> &str;
32}
33
34#[async_trait]
39pub trait ResponseMiddleware: Send + Sync + std::fmt::Debug {
40 async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()>;
48
49 fn name(&self) -> &str;
51}
52
53#[derive(Debug)]
70pub struct MiddlewareChain {
71 request_middleware: Vec<Arc<dyn RequestMiddleware>>,
73
74 response_middleware: Vec<Arc<dyn ResponseMiddleware>>,
76}
77
78impl Default for MiddlewareChain {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl MiddlewareChain {
85 pub fn new() -> Self {
87 Self {
88 request_middleware: Vec::new(),
89 response_middleware: Vec::new(),
90 }
91 }
92
93 pub fn add_request_middleware(&mut self, middleware: Arc<dyn RequestMiddleware>) {
100 debug!("Adding request middleware: {}", middleware.name());
101 self.request_middleware.push(middleware);
102 }
103
104 pub fn add_response_middleware(&mut self, middleware: Arc<dyn ResponseMiddleware>) {
111 debug!("Adding response middleware: {}", middleware.name());
112 self.response_middleware.push(middleware);
113 }
114
115 pub async fn execute_request_chain(
127 &self,
128 context: &mut RequestContext,
129 ) -> MiddlewareResult<()> {
130 debug!(
131 "Executing request middleware chain ({} middleware) for method: {}",
132 self.request_middleware.len(),
133 context.method()
134 );
135
136 for (index, middleware) in self.request_middleware.iter().enumerate() {
137 debug!(
138 "Processing request middleware {} of {}: {}",
139 index + 1,
140 self.request_middleware.len(),
141 middleware.name()
142 );
143
144 middleware.process_request(context).await.map_err(|e| {
145 error!(
146 "Request middleware '{}' failed for method '{}': {}",
147 middleware.name(),
148 context.method(),
149 e
150 );
151 e
152 })?;
153 }
154
155 debug!("Request middleware chain completed successfully");
156 Ok(())
157 }
158
159 pub async fn execute_response_chain(
171 &self,
172 context: &mut ResponseContext,
173 ) -> MiddlewareResult<()> {
174 debug!(
175 "Executing response middleware chain ({} middleware) for method: {}",
176 self.response_middleware.len(),
177 context.method()
178 );
179
180 let mut _last_error = None;
181
182 for (index, middleware) in self.response_middleware.iter().enumerate() {
183 debug!(
184 "Processing response middleware {} of {}: {}",
185 index + 1,
186 self.response_middleware.len(),
187 middleware.name()
188 );
189
190 if let Err(e) = middleware.process_response(context).await {
191 error!(
192 "Response middleware '{}' failed for method '{}': {}",
193 middleware.name(),
194 context.method(),
195 e
196 );
197 _last_error = Some(e);
198 }
200 }
201
202 debug!("Response middleware chain completed");
203
204 Ok(())
207 }
208
209 pub fn request_middleware_count(&self) -> usize {
211 self.request_middleware.len()
212 }
213
214 pub fn response_middleware_count(&self) -> usize {
216 self.response_middleware.len()
217 }
218
219 pub fn get_request_middleware_names(&self) -> Vec<String> {
221 self.request_middleware
222 .iter()
223 .map(|m| m.name().to_string())
224 .collect()
225 }
226
227 pub fn get_response_middleware_names(&self) -> Vec<String> {
229 self.response_middleware
230 .iter()
231 .map(|m| m.name().to_string())
232 .collect()
233 }
234
235 pub fn clear(&mut self) {
237 debug!("Clearing all middleware from chain");
238 self.request_middleware.clear();
239 self.response_middleware.clear();
240 }
241}
242
/// Adapter that exposes a plugin's `before_request` hook as a [`RequestMiddleware`].
#[derive(Debug)]
pub struct PluginRequestMiddleware<P> {
    // The wrapped plugin; every middleware call is forwarded to it.
    plugin: P,
}
248
249impl<P> PluginRequestMiddleware<P> {
250 pub fn new(plugin: P) -> Self {
252 Self { plugin }
253 }
254}
255
256#[async_trait]
257impl<P> RequestMiddleware for PluginRequestMiddleware<P>
258where
259 P: crate::plugins::core::ClientPlugin,
260{
261 async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()> {
262 self.plugin.before_request(context).await
263 }
264
265 fn name(&self) -> &str {
266 self.plugin.name()
267 }
268}
269
/// Adapter that exposes a plugin's `after_response` hook as a [`ResponseMiddleware`].
#[derive(Debug)]
pub struct PluginResponseMiddleware<P> {
    // The wrapped plugin; every middleware call is forwarded to it.
    plugin: P,
}
275
276impl<P> PluginResponseMiddleware<P> {
277 pub fn new(plugin: P) -> Self {
279 Self { plugin }
280 }
281}
282
283#[async_trait]
284impl<P> ResponseMiddleware for PluginResponseMiddleware<P>
285where
286 P: crate::plugins::core::ClientPlugin,
287{
288 async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()> {
289 self.plugin.after_response(context).await
290 }
291
292 fn name(&self) -> &str {
293 self.plugin.name()
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::core::{PluginError, RequestContext};
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};
    use turbomcp_core::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Method name carried by every request built in these tests.
    const TEST_METHOD: &str = "test/method";

    /// Shared log of middleware names, used to observe cross-middleware execution order.
    type OrderLog = Arc<Mutex<Vec<String>>>;

    /// Builds a request context for `TEST_METHOD` with no params and empty metadata.
    fn test_request_context() -> RequestContext {
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test"),
            method: TEST_METHOD.to_string(),
            params: None,
        };
        RequestContext::new(request, HashMap::new())
    }

    /// Builds a successful response context wrapping a fresh `test_request_context()`.
    fn test_response_context() -> ResponseContext {
        ResponseContext::new(
            test_request_context(),
            Some(json!({"result": "success"})),
            None,
            std::time::Duration::from_millis(100),
        )
    }

    /// Entry a test request middleware records when invoked for `TEST_METHOD`.
    fn request_call() -> String {
        format!("process_request:{}", TEST_METHOD)
    }

    /// Entry a test response middleware records when invoked for `TEST_METHOD`.
    fn response_call() -> String {
        format!("process_response:{}", TEST_METHOD)
    }

    /// Request middleware that records its invocations and can be told to fail.
    #[derive(Debug)]
    struct TestRequestMiddleware {
        name: String,
        // Per-instance log of `process_request:<method>` entries.
        calls: Arc<Mutex<Vec<String>>>,
        // Optional log shared between instances; receives this middleware's name per call.
        order: Option<OrderLog>,
        should_fail: bool,
    }

    impl TestRequestMiddleware {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Arc::new(Mutex::new(Vec::new())),
                order: None,
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }

        fn with_order_log(mut self, log: OrderLog) -> Self {
            self.order = Some(log);
            self
        }

        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl RequestMiddleware for TestRequestMiddleware {
        async fn process_request(&self, context: &mut RequestContext) -> MiddlewareResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("process_request:{}", context.method()));

            if let Some(order) = &self.order {
                order.lock().unwrap().push(self.name.clone());
            }

            if self.should_fail {
                Err(PluginError::request_processing("Test middleware failure"))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Response middleware that records its invocations and can be told to fail.
    #[derive(Debug)]
    struct TestResponseMiddleware {
        name: String,
        calls: Arc<Mutex<Vec<String>>>,
        should_fail: bool,
    }

    impl TestResponseMiddleware {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Arc::new(Mutex::new(Vec::new())),
                should_fail: false,
            }
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }

        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ResponseMiddleware for TestResponseMiddleware {
        async fn process_response(&self, context: &mut ResponseContext) -> MiddlewareResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("process_response:{}", context.method()));

            if self.should_fail {
                Err(PluginError::response_processing("Test middleware failure"))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[tokio::test]
    async fn test_middleware_chain_creation() {
        let chain = MiddlewareChain::new();
        assert_eq!(chain.request_middleware_count(), 0);
        assert_eq!(chain.response_middleware_count(), 0);

        let default_chain = MiddlewareChain::default();
        assert_eq!(default_chain.request_middleware_count(), 0);
        assert_eq!(default_chain.response_middleware_count(), 0);
    }

    #[tokio::test]
    async fn test_request_middleware_registration() {
        let mut chain = MiddlewareChain::new();
        chain.add_request_middleware(Arc::new(TestRequestMiddleware::new("test")));

        assert_eq!(chain.request_middleware_count(), 1);
        assert_eq!(chain.get_request_middleware_names(), vec!["test"]);
    }

    #[tokio::test]
    async fn test_response_middleware_registration() {
        let mut chain = MiddlewareChain::new();
        chain.add_response_middleware(Arc::new(TestResponseMiddleware::new("test")));

        assert_eq!(chain.response_middleware_count(), 1);
        assert_eq!(chain.get_response_middleware_names(), vec!["test"]);
    }

    #[tokio::test]
    async fn test_request_middleware_execution() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestRequestMiddleware::new("test"));
        chain.add_request_middleware(middleware.clone());

        let mut context = test_request_context();
        chain.execute_request_chain(&mut context).await.unwrap();

        assert_eq!(middleware.get_calls(), vec![request_call()]);
    }

    #[tokio::test]
    async fn test_response_middleware_execution() {
        let mut chain = MiddlewareChain::new();
        let middleware = Arc::new(TestResponseMiddleware::new("test"));
        chain.add_response_middleware(middleware.clone());

        let mut response_context = test_response_context();
        chain
            .execute_response_chain(&mut response_context)
            .await
            .unwrap();

        assert_eq!(middleware.get_calls(), vec![response_call()]);
    }

    /// A failing request middleware aborts the chain: later middleware must not run.
    #[tokio::test]
    async fn test_request_middleware_error_handling() {
        let mut chain = MiddlewareChain::new();
        let good_middleware = Arc::new(TestRequestMiddleware::new("good"));
        let bad_middleware = Arc::new(TestRequestMiddleware::new("bad").with_failure());
        let after_middleware = Arc::new(TestRequestMiddleware::new("after"));

        chain.add_request_middleware(good_middleware.clone());
        chain.add_request_middleware(bad_middleware.clone());
        chain.add_request_middleware(after_middleware.clone());

        let mut context = test_request_context();
        let result = chain.execute_request_chain(&mut context).await;

        assert!(result.is_err());
        assert!(good_middleware.get_calls().contains(&request_call()));
        assert!(bad_middleware.get_calls().contains(&request_call()));
        assert!(
            after_middleware.get_calls().is_empty(),
            "request chain must stop at the first failing middleware"
        );
    }

    /// A failing response middleware is logged and skipped: later middleware still run.
    #[tokio::test]
    async fn test_response_middleware_error_handling() {
        let mut chain = MiddlewareChain::new();
        let good_middleware = Arc::new(TestResponseMiddleware::new("good"));
        let bad_middleware = Arc::new(TestResponseMiddleware::new("bad").with_failure());
        let after_middleware = Arc::new(TestResponseMiddleware::new("after"));

        chain.add_response_middleware(good_middleware.clone());
        chain.add_response_middleware(bad_middleware.clone());
        chain.add_response_middleware(after_middleware.clone());

        let mut response_context = test_response_context();
        let result = chain.execute_response_chain(&mut response_context).await;

        // Response errors are best-effort and never propagated to the caller.
        assert!(result.is_ok());
        assert!(good_middleware.get_calls().contains(&response_call()));
        assert!(bad_middleware.get_calls().contains(&response_call()));
        assert!(
            after_middleware.get_calls().contains(&response_call()),
            "response chain must continue past a failing middleware"
        );
    }

    /// Middleware execute in registration order, observed through a shared log.
    #[tokio::test]
    async fn test_middleware_execution_order() {
        let order: OrderLog = Arc::new(Mutex::new(Vec::new()));

        let mut chain = MiddlewareChain::new();
        let middleware1 =
            Arc::new(TestRequestMiddleware::new("first").with_order_log(order.clone()));
        let middleware2 =
            Arc::new(TestRequestMiddleware::new("second").with_order_log(order.clone()));
        let middleware3 =
            Arc::new(TestRequestMiddleware::new("third").with_order_log(order.clone()));

        chain.add_request_middleware(middleware1.clone());
        chain.add_request_middleware(middleware2.clone());
        chain.add_request_middleware(middleware3.clone());

        let mut context = test_request_context();
        chain.execute_request_chain(&mut context).await.unwrap();

        let recorded = order.lock().unwrap().clone();
        assert_eq!(recorded, vec!["first", "second", "third"]);

        for middleware in [&middleware1, &middleware2, &middleware3] {
            assert_eq!(middleware.get_calls(), vec![request_call()]);
        }

        assert_eq!(
            chain.get_request_middleware_names(),
            vec!["first", "second", "third"]
        );
    }

    #[tokio::test]
    async fn test_chain_clear() {
        let mut chain = MiddlewareChain::new();
        chain.add_request_middleware(Arc::new(TestRequestMiddleware::new("request")));
        chain.add_response_middleware(Arc::new(TestResponseMiddleware::new("response")));

        assert_eq!(chain.request_middleware_count(), 1);
        assert_eq!(chain.response_middleware_count(), 1);

        chain.clear();

        assert_eq!(chain.request_middleware_count(), 0);
        assert_eq!(chain.response_middleware_count(), 0);
    }
}