1use crate::plugins::core::{
8 ClientPlugin, PluginContext, PluginError, PluginResult, RequestContext, ResponseContext,
9};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::{debug, error, info, warn};
14
15#[derive(Debug)]
43pub struct PluginRegistry {
44 plugins: Vec<Arc<dyn ClientPlugin>>,
46
47 plugin_map: HashMap<String, usize>,
49
50 client_context: Option<PluginContext>,
52}
53
54impl Default for PluginRegistry {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl PluginRegistry {
61 pub fn new() -> Self {
63 Self {
64 plugins: Vec::new(),
65 plugin_map: HashMap::new(),
66 client_context: None,
67 }
68 }
69
70 pub fn set_client_context(&mut self, context: PluginContext) {
75 debug!(
76 "Setting client context: {} v{}",
77 context.client_name, context.client_version
78 );
79 self.client_context = Some(context);
80 }
81
82 pub async fn register_plugin(&mut self, plugin: Arc<dyn ClientPlugin>) -> PluginResult<()> {
101 let plugin_name = plugin.name().to_string();
102
103 info!("Registering plugin: {} v{}", plugin_name, plugin.version());
104
105 if self.plugin_map.contains_key(&plugin_name) {
107 return Err(PluginError::configuration(format!(
108 "Plugin '{}' is already registered",
109 plugin_name
110 )));
111 }
112
113 for dependency in plugin.dependencies() {
115 if !self.has_plugin(dependency) {
116 return Err(PluginError::dependency_not_available(dependency));
117 }
118 }
119
120 if let Some(context) = &self.client_context {
122 let mut updated_context = context.clone();
124 updated_context.available_plugins = self.get_plugin_names();
125
126 plugin.initialize(&updated_context).await.map_err(|e| {
127 error!("Failed to initialize plugin '{}': {}", plugin_name, e);
128 e
129 })?;
130 } else {
131 let context = PluginContext::new(
133 "unknown".to_string(),
134 "unknown".to_string(),
135 HashMap::new(),
136 HashMap::new(),
137 self.get_plugin_names(),
138 );
139 plugin.initialize(&context).await.map_err(|e| {
140 error!("Failed to initialize plugin '{}': {}", plugin_name, e);
141 e
142 })?;
143 }
144
145 let index = self.plugins.len();
147 self.plugins.push(plugin);
148 self.plugin_map.insert(plugin_name.clone(), index);
149
150 debug!(
151 "Plugin '{}' registered successfully at index {}",
152 plugin_name, index
153 );
154 Ok(())
155 }
156
157 pub async fn unregister_plugin(&mut self, plugin_name: &str) -> PluginResult<()> {
169 info!("Unregistering plugin: {}", plugin_name);
170
171 let index = self.plugin_map.get(plugin_name).copied().ok_or_else(|| {
172 PluginError::configuration(format!("Plugin '{}' not found", plugin_name))
173 })?;
174
175 let plugin = self.plugins[index].clone();
177 plugin.cleanup().await.map_err(|e| {
178 warn!("Plugin '{}' cleanup failed: {}", plugin_name, e);
179 e
180 })?;
181
182 self.plugins.remove(index);
184 self.plugin_map.remove(plugin_name);
185
186 for (_, plugin_index) in self.plugin_map.iter_mut() {
188 if *plugin_index > index {
189 *plugin_index -= 1;
190 }
191 }
192
193 debug!("Plugin '{}' unregistered successfully", plugin_name);
194 Ok(())
195 }
196
197 pub fn has_plugin(&self, plugin_name: &str) -> bool {
199 self.plugin_map.contains_key(plugin_name)
200 }
201
202 pub fn get_plugin(&self, plugin_name: &str) -> Option<Arc<dyn ClientPlugin>> {
204 self.plugin_map
205 .get(plugin_name)
206 .and_then(|&index| self.plugins.get(index))
207 .cloned()
208 }
209
210 pub fn get_plugin_names(&self) -> Vec<String> {
212 self.plugins
213 .iter()
214 .map(|plugin| plugin.name().to_string())
215 .collect()
216 }
217
218 pub fn plugin_count(&self) -> usize {
220 self.plugins.len()
221 }
222
223 pub async fn execute_before_request(&self, context: &mut RequestContext) -> PluginResult<()> {
236 debug!(
237 "Executing before_request middleware chain for method: {}",
238 context.method()
239 );
240
241 for (index, plugin) in self.plugins.iter().enumerate() {
242 let plugin_name = plugin.name();
243 debug!(
244 "Calling before_request on plugin '{}' ({})",
245 plugin_name, index
246 );
247
248 plugin.before_request(context).await.map_err(|e| {
249 error!(
250 "Plugin '{}' before_request failed for method '{}': {}",
251 plugin_name,
252 context.method(),
253 e
254 );
255 e
256 })?;
257 }
258
259 debug!("Before_request middleware chain completed successfully");
260 Ok(())
261 }
262
263 pub async fn execute_after_response(&self, context: &mut ResponseContext) -> PluginResult<()> {
277 debug!(
278 "Executing after_response middleware chain for method: {}",
279 context.method()
280 );
281
282 let mut _last_error = None;
283
284 for (index, plugin) in self.plugins.iter().enumerate() {
285 let plugin_name = plugin.name();
286 debug!(
287 "Calling after_response on plugin '{}' ({})",
288 plugin_name, index
289 );
290
291 if let Err(e) = plugin.after_response(context).await {
292 error!(
293 "Plugin '{}' after_response failed for method '{}': {}",
294 plugin_name,
295 context.method(),
296 e
297 );
298 _last_error = Some(e);
299 }
301 }
302
303 debug!("After_response middleware chain completed");
304
305 Ok(())
308 }
309
310 pub async fn handle_custom_method(
325 &self,
326 method: &str,
327 params: Option<Value>,
328 ) -> PluginResult<Option<Value>> {
329 debug!("Handling custom method: {}", method);
330
331 for plugin in &self.plugins {
332 let plugin_name = plugin.name();
333 debug!(
334 "Checking if plugin '{}' can handle custom method '{}'",
335 plugin_name, method
336 );
337
338 match plugin.handle_custom(method, params.clone()).await {
339 Ok(Some(result)) => {
340 info!(
341 "Plugin '{}' handled custom method '{}'",
342 plugin_name, method
343 );
344 return Ok(Some(result));
345 }
346 Ok(None) => {
347 continue;
349 }
350 Err(e) => {
351 error!(
352 "Plugin '{}' failed to handle custom method '{}': {}",
353 plugin_name, method, e
354 );
355 return Err(e);
356 }
357 }
358 }
359
360 debug!("No plugin handled custom method: {}", method);
361 Ok(None)
362 }
363
364 pub fn get_plugin_info(&self) -> Vec<(String, String, Option<String>)> {
366 self.plugins
367 .iter()
368 .map(|plugin| {
369 (
370 plugin.name().to_string(),
371 plugin.version().to_string(),
372 plugin.description().map(|s| s.to_string()),
373 )
374 })
375 .collect()
376 }
377
378 pub fn validate_dependencies(&self) -> Result<(), Vec<String>> {
383 let mut errors = Vec::new();
384
385 for plugin in &self.plugins {
386 for dependency in plugin.dependencies() {
387 if !self.has_plugin(dependency) {
388 errors.push(format!(
389 "Plugin '{}' depends on '{}' which is not registered",
390 plugin.name(),
391 dependency
392 ));
393 }
394 }
395 }
396
397 if errors.is_empty() {
398 Ok(())
399 } else {
400 Err(errors)
401 }
402 }
403
404 pub async fn clear(&mut self) -> PluginResult<()> {
409 info!("Clearing all registered plugins");
410
411 let plugins = std::mem::take(&mut self.plugins);
412 self.plugin_map.clear();
413
414 for plugin in plugins {
415 let plugin_name = plugin.name();
416 if let Err(e) = plugin.cleanup().await {
417 warn!("Plugin '{}' cleanup failed: {}", plugin_name, e);
418 }
419 }
420
421 debug!("All plugins cleared successfully");
422 Ok(())
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::core::PluginContext;
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::Mutex;
    use tokio;
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Test double that records every hook invocation and can be configured
    /// to fail on initialization or on `before_request`.
    #[derive(Debug)]
    struct MockPlugin {
        name: String,
        calls: Arc<Mutex<Vec<String>>>,
        should_fail_init: bool,
        should_fail_before_request: bool,
    }

    impl MockPlugin {
        /// A well-behaved plugin with the given name and an empty call log.
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Arc::new(Mutex::new(Vec::new())),
                should_fail_init: false,
                should_fail_before_request: false,
            }
        }

        /// Makes `initialize` return an error.
        fn with_init_failure(self) -> Self {
            Self {
                should_fail_init: true,
                ..self
            }
        }

        /// Makes `before_request` return an error.
        fn with_request_failure(self) -> Self {
            Self {
                should_fail_before_request: true,
                ..self
            }
        }

        /// Snapshot of the call log.
        fn get_calls(&self) -> Vec<String> {
            let log = self.calls.lock().unwrap();
            log.clone()
        }

        /// Appends one entry to the call log.
        fn record(&self, entry: String) {
            self.calls.lock().unwrap().push(entry);
        }

        /// Whether the call log contains exactly this entry.
        fn was_called_with(&self, entry: &str) -> bool {
            self.get_calls().contains(&entry.to_string())
        }
    }

    #[async_trait]
    impl ClientPlugin for MockPlugin {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn version(&self) -> &str {
            "1.0.0"
        }

        async fn initialize(&self, _context: &PluginContext) -> PluginResult<()> {
            self.record("initialize".to_string());
            if self.should_fail_init {
                return Err(PluginError::initialization("Mock initialization failure"));
            }
            Ok(())
        }

        async fn before_request(&self, context: &mut RequestContext) -> PluginResult<()> {
            self.record(format!("before_request:{}", context.method()));
            if self.should_fail_before_request {
                return Err(PluginError::request_processing("Mock request failure"));
            }
            Ok(())
        }

        async fn after_response(&self, context: &mut ResponseContext) -> PluginResult<()> {
            self.record(format!("after_response:{}", context.method()));
            Ok(())
        }

        async fn handle_custom(
            &self,
            method: &str,
            params: Option<Value>,
        ) -> PluginResult<Option<Value>> {
            self.record(format!("handle_custom:{}", method));
            // Only methods namespaced as "<plugin name>." are handled; the
            // params are echoed back as the result.
            let prefix = format!("{}.", self.name);
            Ok(if method.starts_with(&prefix) {
                params
            } else {
                None
            })
        }
    }

    /// Request context for a parameterless "test/method" call with id "test".
    fn request_context() -> RequestContext {
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test"),
            method: "test/method".to_string(),
            params: None,
        };
        RequestContext::new(request, HashMap::new())
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let empty = PluginRegistry::new();
        assert_eq!(empty.plugin_count(), 0);
        assert!(empty.get_plugin_names().is_empty());
    }

    #[tokio::test]
    async fn test_plugin_registration() {
        let mut registry = PluginRegistry::new();
        let mock = Arc::new(MockPlugin::new("test"));

        registry.register_plugin(mock.clone()).await.unwrap();

        assert_eq!(registry.plugin_count(), 1);
        assert!(registry.has_plugin("test"));
        assert_eq!(registry.get_plugin_names(), vec!["test"]);

        let looked_up = registry.get_plugin("test").unwrap();
        assert_eq!(looked_up.name(), "test");
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let mut registry = PluginRegistry::new();
        let original = Arc::new(MockPlugin::new("duplicate"));
        let same_name = Arc::new(MockPlugin::new("duplicate"));

        registry.register_plugin(original).await.unwrap();
        let second_attempt = registry.register_plugin(same_name).await;

        assert!(second_attempt.is_err());
        assert_eq!(registry.plugin_count(), 1);
    }

    #[tokio::test]
    async fn test_plugin_initialization_failure() {
        let mut registry = PluginRegistry::new();
        let broken = Arc::new(MockPlugin::new("failing").with_init_failure());

        let outcome = registry.register_plugin(broken).await;

        assert!(outcome.is_err());
        assert_eq!(registry.plugin_count(), 0);
    }

    #[tokio::test]
    async fn test_plugin_unregistration() {
        let mut registry = PluginRegistry::new();
        let mock = Arc::new(MockPlugin::new("removable"));

        registry.register_plugin(mock).await.unwrap();
        assert_eq!(registry.plugin_count(), 1);

        registry.unregister_plugin("removable").await.unwrap();
        assert_eq!(registry.plugin_count(), 0);
        assert!(!registry.has_plugin("removable"));
    }

    #[tokio::test]
    async fn test_before_request_middleware() {
        let mut registry = PluginRegistry::new();
        let first = Arc::new(MockPlugin::new("first"));
        let second = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(first.clone()).await.unwrap();
        registry.register_plugin(second.clone()).await.unwrap();

        let mut context = request_context();
        registry.execute_before_request(&mut context).await.unwrap();

        // Every plugin in the chain must have seen the request.
        assert!(first.was_called_with("before_request:test/method"));
        assert!(second.was_called_with("before_request:test/method"));
    }

    #[tokio::test]
    async fn test_before_request_error_handling() {
        let mut registry = PluginRegistry::new();
        let good = Arc::new(MockPlugin::new("good"));
        let bad = Arc::new(MockPlugin::new("bad").with_request_failure());

        registry.register_plugin(good.clone()).await.unwrap();
        registry.register_plugin(bad.clone()).await.unwrap();

        let mut context = request_context();
        let outcome = registry.execute_before_request(&mut context).await;

        assert!(outcome.is_err());
        // Both plugins were invoked: the good one ran before the failing one.
        assert!(good.was_called_with("before_request:test/method"));
        assert!(bad.was_called_with("before_request:test/method"));
    }

    #[tokio::test]
    async fn test_custom_method_handling() {
        let mut registry = PluginRegistry::new();
        let handler = Arc::new(MockPlugin::new("handler"));

        registry.register_plugin(handler.clone()).await.unwrap();

        let response = registry
            .handle_custom_method("handler.test", Some(json!({"data": "test"})))
            .await
            .unwrap();

        assert!(response.is_some());
        assert_eq!(response.unwrap(), json!({"data": "test"}));
        assert!(handler.was_called_with("handle_custom:handler.test"));
    }

    #[tokio::test]
    async fn test_custom_method_not_handled() {
        let mut registry = PluginRegistry::new();
        let handler = Arc::new(MockPlugin::new("handler"));

        registry.register_plugin(handler.clone()).await.unwrap();

        let response = registry
            .handle_custom_method("other.method", None)
            .await
            .unwrap();

        assert!(response.is_none());
        assert!(handler.was_called_with("handle_custom:other.method"));
    }

    #[tokio::test]
    async fn test_plugin_info() {
        let mut registry = PluginRegistry::new();
        let mock = Arc::new(MockPlugin::new("info_test"));

        registry.register_plugin(mock).await.unwrap();

        let entries = registry.get_plugin_info();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].0, "info_test");
        assert_eq!(entries[0].1, "1.0.0");
    }

    #[tokio::test]
    async fn test_clear_plugins() {
        let mut registry = PluginRegistry::new();
        let first = Arc::new(MockPlugin::new("first"));
        let second = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(first).await.unwrap();
        registry.register_plugin(second).await.unwrap();
        assert_eq!(registry.plugin_count(), 2);

        registry.clear().await.unwrap();
        assert_eq!(registry.plugin_count(), 0);
        assert!(registry.get_plugin_names().is_empty());
    }
}