1use crate::plugins::core::{
8 ClientPlugin, PluginContext, PluginError, PluginResult, RequestContext, ResponseContext,
9};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::{debug, error, info, warn};
14
15#[derive(Debug)]
43pub struct PluginRegistry {
44 plugins: Vec<Arc<dyn ClientPlugin>>,
46
47 plugin_map: HashMap<String, usize>,
49
50 client_context: Option<PluginContext>,
52}
53
54impl Default for PluginRegistry {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl PluginRegistry {
61 #[must_use]
63 pub fn new() -> Self {
64 Self {
65 plugins: Vec::new(),
66 plugin_map: HashMap::new(),
67 client_context: None,
68 }
69 }
70
71 pub fn set_client_context(&mut self, context: PluginContext) {
76 debug!(
77 "Setting client context: {} v{}",
78 context.client_name, context.client_version
79 );
80 self.client_context = Some(context);
81 }
82
83 pub async fn register_plugin(&mut self, plugin: Arc<dyn ClientPlugin>) -> PluginResult<()> {
102 let plugin_name = plugin.name().to_string();
103
104 info!("Registering plugin: {} v{}", plugin_name, plugin.version());
105
106 if self.plugin_map.contains_key(&plugin_name) {
108 return Err(PluginError::configuration(format!(
109 "Plugin '{}' is already registered",
110 plugin_name
111 )));
112 }
113
114 for dependency in plugin.dependencies() {
116 if !self.has_plugin(dependency) {
117 return Err(PluginError::dependency_not_available(dependency));
118 }
119 }
120
121 if let Some(context) = &self.client_context {
123 let mut updated_context = context.clone();
125 updated_context.available_plugins = self.get_plugin_names();
126
127 plugin.initialize(&updated_context).await.map_err(|e| {
128 error!("Failed to initialize plugin '{}': {}", plugin_name, e);
129 e
130 })?;
131 } else {
132 let context = PluginContext::new(
134 "unknown".to_string(),
135 "unknown".to_string(),
136 HashMap::new(),
137 HashMap::new(),
138 self.get_plugin_names(),
139 );
140 plugin.initialize(&context).await.map_err(|e| {
141 error!("Failed to initialize plugin '{}': {}", plugin_name, e);
142 e
143 })?;
144 }
145
146 let index = self.plugins.len();
148 self.plugins.push(plugin);
149 self.plugin_map.insert(plugin_name.clone(), index);
150
151 debug!(
152 "Plugin '{}' registered successfully at index {}",
153 plugin_name, index
154 );
155 Ok(())
156 }
157
158 pub async fn unregister_plugin(&mut self, plugin_name: &str) -> PluginResult<()> {
170 info!("Unregistering plugin: {}", plugin_name);
171
172 let index = self.plugin_map.get(plugin_name).copied().ok_or_else(|| {
173 PluginError::configuration(format!("Plugin '{}' not found", plugin_name))
174 })?;
175
176 let plugin = self.plugins[index].clone();
178 plugin.cleanup().await.map_err(|e| {
179 warn!("Plugin '{}' cleanup failed: {}", plugin_name, e);
180 e
181 })?;
182
183 self.plugins.remove(index);
185 self.plugin_map.remove(plugin_name);
186
187 for (_, plugin_index) in self.plugin_map.iter_mut() {
189 if *plugin_index > index {
190 *plugin_index -= 1;
191 }
192 }
193
194 debug!("Plugin '{}' unregistered successfully", plugin_name);
195 Ok(())
196 }
197
198 #[must_use]
200 pub fn has_plugin(&self, plugin_name: &str) -> bool {
201 self.plugin_map.contains_key(plugin_name)
202 }
203
204 #[must_use]
206 pub fn get_plugin(&self, plugin_name: &str) -> Option<Arc<dyn ClientPlugin>> {
207 self.plugin_map
208 .get(plugin_name)
209 .and_then(|&index| self.plugins.get(index))
210 .cloned()
211 }
212
213 #[must_use]
215 pub fn get_plugin_names(&self) -> Vec<String> {
216 self.plugins
217 .iter()
218 .map(|plugin| plugin.name().to_string())
219 .collect()
220 }
221
222 #[must_use]
224 pub fn plugin_count(&self) -> usize {
225 self.plugins.len()
226 }
227
228 pub async fn execute_before_request(&self, context: &mut RequestContext) -> PluginResult<()> {
241 debug!(
242 "Executing before_request middleware chain for method: {}",
243 context.method()
244 );
245
246 for (index, plugin) in self.plugins.iter().enumerate() {
247 let plugin_name = plugin.name();
248 debug!(
249 "Calling before_request on plugin '{}' ({})",
250 plugin_name, index
251 );
252
253 plugin.before_request(context).await.map_err(|e| {
254 error!(
255 "Plugin '{}' before_request failed for method '{}': {}",
256 plugin_name,
257 context.method(),
258 e
259 );
260 e
261 })?;
262 }
263
264 debug!("Before_request middleware chain completed successfully");
265 Ok(())
266 }
267
268 pub async fn execute_after_response(&self, context: &mut ResponseContext) -> PluginResult<()> {
282 debug!(
283 "Executing after_response middleware chain for method: {}",
284 context.method()
285 );
286
287 let mut _last_error = None;
288
289 for (index, plugin) in self.plugins.iter().enumerate() {
290 let plugin_name = plugin.name();
291 debug!(
292 "Calling after_response on plugin '{}' ({})",
293 plugin_name, index
294 );
295
296 if let Err(e) = plugin.after_response(context).await {
297 error!(
298 "Plugin '{}' after_response failed for method '{}': {}",
299 plugin_name,
300 context.method(),
301 e
302 );
303 _last_error = Some(e);
304 }
306 }
307
308 debug!("After_response middleware chain completed");
309
310 Ok(())
313 }
314
315 pub async fn handle_custom_method(
330 &self,
331 method: &str,
332 params: Option<Value>,
333 ) -> PluginResult<Option<Value>> {
334 debug!("Handling custom method: {}", method);
335
336 for plugin in &self.plugins {
337 let plugin_name = plugin.name();
338 debug!(
339 "Checking if plugin '{}' can handle custom method '{}'",
340 plugin_name, method
341 );
342
343 match plugin.handle_custom(method, params.clone()).await {
344 Ok(Some(result)) => {
345 info!(
346 "Plugin '{}' handled custom method '{}'",
347 plugin_name, method
348 );
349 return Ok(Some(result));
350 }
351 Ok(None) => {
352 continue;
354 }
355 Err(e) => {
356 error!(
357 "Plugin '{}' failed to handle custom method '{}': {}",
358 plugin_name, method, e
359 );
360 return Err(e);
361 }
362 }
363 }
364
365 debug!("No plugin handled custom method: {}", method);
366 Ok(None)
367 }
368
369 #[must_use]
371 pub fn get_plugin_info(&self) -> Vec<(String, String, Option<String>)> {
372 self.plugins
373 .iter()
374 .map(|plugin| {
375 (
376 plugin.name().to_string(),
377 plugin.version().to_string(),
378 plugin.description().map(|s| s.to_string()),
379 )
380 })
381 .collect()
382 }
383
384 pub fn validate_dependencies(&self) -> Result<(), Vec<String>> {
389 let mut errors = Vec::new();
390
391 for plugin in &self.plugins {
392 for dependency in plugin.dependencies() {
393 if !self.has_plugin(dependency) {
394 errors.push(format!(
395 "Plugin '{}' depends on '{}' which is not registered",
396 plugin.name(),
397 dependency
398 ));
399 }
400 }
401 }
402
403 if errors.is_empty() {
404 Ok(())
405 } else {
406 Err(errors)
407 }
408 }
409
410 pub async fn clear(&mut self) -> PluginResult<()> {
415 info!("Clearing all registered plugins");
416
417 let plugins = std::mem::take(&mut self.plugins);
418 self.plugin_map.clear();
419
420 for plugin in plugins {
421 let plugin_name = plugin.name();
422 if let Err(e) = plugin.cleanup().await {
423 warn!("Plugin '{}' cleanup failed: {}", plugin_name, e);
424 }
425 }
426
427 debug!("All plugins cleared successfully");
428 Ok(())
429 }
430}
431
#[cfg(test)]
mod tests {
    // `super::*` already brings in `Arc`, `HashMap`, `Value`, `PluginContext`
    // and the other core plugin types, so they are not re-imported here.
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::Mutex;
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Test double that records every hook invocation.
    ///
    /// It can be configured to fail in `initialize` or `before_request`.
    #[derive(Debug)]
    struct MockPlugin {
        name: String,
        calls: Arc<Mutex<Vec<String>>>,
        should_fail_init: bool,
        should_fail_before_request: bool,
    }

    impl MockPlugin {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Arc::new(Mutex::new(Vec::new())),
                should_fail_init: false,
                should_fail_before_request: false,
            }
        }

        fn with_init_failure(mut self) -> Self {
            self.should_fail_init = true;
            self
        }

        fn with_request_failure(mut self) -> Self {
            self.should_fail_before_request = true;
            self
        }

        /// Snapshot of the recorded hook calls, in call order.
        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ClientPlugin for MockPlugin {
        fn name(&self) -> &str {
            &self.name
        }

        fn version(&self) -> &str {
            "1.0.0"
        }

        async fn initialize(&self, _context: &PluginContext) -> PluginResult<()> {
            self.calls.lock().unwrap().push("initialize".to_string());
            if self.should_fail_init {
                Err(PluginError::initialization("Mock initialization failure"))
            } else {
                Ok(())
            }
        }

        async fn before_request(&self, context: &mut RequestContext) -> PluginResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("before_request:{}", context.method()));
            if self.should_fail_before_request {
                Err(PluginError::request_processing("Mock request failure"))
            } else {
                Ok(())
            }
        }

        async fn after_response(&self, context: &mut ResponseContext) -> PluginResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("after_response:{}", context.method()));
            Ok(())
        }

        /// Handles only methods prefixed with `"<name>."`, echoing `params` back.
        async fn handle_custom(
            &self,
            method: &str,
            params: Option<Value>,
        ) -> PluginResult<Option<Value>> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("handle_custom:{}", method));
            if method.starts_with(&format!("{}.", self.name)) {
                Ok(params)
            } else {
                Ok(None)
            }
        }
    }

    /// Builds a request context for `method` with no params and no metadata.
    fn request_context(method: &str) -> RequestContext {
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test"),
            method: method.to_string(),
            params: None,
        };
        RequestContext::new(request, HashMap::new())
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let registry = PluginRegistry::new();
        assert_eq!(registry.plugin_count(), 0);
        assert!(registry.get_plugin_names().is_empty());
    }

    #[tokio::test]
    async fn test_plugin_registration() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("test"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        assert_eq!(registry.plugin_count(), 1);
        assert!(registry.has_plugin("test"));
        assert_eq!(registry.get_plugin_names(), vec!["test"]);

        let retrieved = registry.get_plugin("test").unwrap();
        assert_eq!(retrieved.name(), "test");
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("duplicate"));
        let plugin2 = Arc::new(MockPlugin::new("duplicate"));

        registry.register_plugin(plugin1).await.unwrap();
        let result = registry.register_plugin(plugin2).await;

        assert!(result.is_err());
        assert_eq!(registry.plugin_count(), 1);
    }

    #[tokio::test]
    async fn test_plugin_initialization_failure() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("failing").with_init_failure());

        let result = registry.register_plugin(plugin).await;

        assert!(result.is_err());
        assert_eq!(registry.plugin_count(), 0);
    }

    #[tokio::test]
    async fn test_plugin_unregistration() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("removable"));

        registry.register_plugin(plugin).await.unwrap();
        assert_eq!(registry.plugin_count(), 1);

        registry.unregister_plugin("removable").await.unwrap();
        assert_eq!(registry.plugin_count(), 0);
        assert!(!registry.has_plugin("removable"));
    }

    #[tokio::test]
    async fn test_unregister_unknown_plugin_fails() {
        let mut registry = PluginRegistry::new();
        assert!(registry.unregister_plugin("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_unregister_keeps_remaining_lookups_valid() {
        let mut registry = PluginRegistry::new();
        for name in ["a", "b", "c"] {
            registry
                .register_plugin(Arc::new(MockPlugin::new(name)))
                .await
                .unwrap();
        }

        // Removing the first plugin shifts the others; name lookups must follow.
        registry.unregister_plugin("a").await.unwrap();

        assert_eq!(registry.get_plugin_names(), vec!["b", "c"]);
        assert_eq!(registry.get_plugin("b").unwrap().name(), "b");
        assert_eq!(registry.get_plugin("c").unwrap().name(), "c");
        assert!(registry.get_plugin("a").is_none());
    }

    #[tokio::test]
    async fn test_before_request_middleware() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("first"));
        let plugin2 = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(plugin1.clone()).await.unwrap();
        registry.register_plugin(plugin2.clone()).await.unwrap();

        let mut context = request_context("test/method");
        registry.execute_before_request(&mut context).await.unwrap();

        assert!(
            plugin1
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
        assert!(
            plugin2
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
    }

    #[tokio::test]
    async fn test_before_request_error_handling() {
        let mut registry = PluginRegistry::new();
        let good_plugin = Arc::new(MockPlugin::new("good"));
        let bad_plugin = Arc::new(MockPlugin::new("bad").with_request_failure());

        registry.register_plugin(good_plugin.clone()).await.unwrap();
        registry.register_plugin(bad_plugin.clone()).await.unwrap();

        let mut context = request_context("test/method");
        let result = registry.execute_before_request(&mut context).await;

        assert!(result.is_err());
        assert!(
            good_plugin
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
        assert!(
            bad_plugin
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
    }

    #[tokio::test]
    async fn test_custom_method_handling() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("handler"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("handler.test", Some(json!({"data": "test"})))
            .await
            .unwrap();

        assert_eq!(result, Some(json!({"data": "test"})));
        assert!(
            plugin
                .get_calls()
                .contains(&"handle_custom:handler.test".to_string())
        );
    }

    #[tokio::test]
    async fn test_custom_method_falls_through_to_later_plugin() {
        let mut registry = PluginRegistry::new();
        let alpha = Arc::new(MockPlugin::new("alpha"));
        let beta = Arc::new(MockPlugin::new("beta"));

        registry.register_plugin(alpha.clone()).await.unwrap();
        registry.register_plugin(beta.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("beta.ping", Some(json!(1)))
            .await
            .unwrap();

        assert_eq!(result, Some(json!(1)));
        // `alpha` was consulted first and declined.
        assert!(
            alpha
                .get_calls()
                .contains(&"handle_custom:beta.ping".to_string())
        );
    }

    #[tokio::test]
    async fn test_custom_method_not_handled() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("handler"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("other.method", None)
            .await
            .unwrap();

        assert!(result.is_none());
        assert!(
            plugin
                .get_calls()
                .contains(&"handle_custom:other.method".to_string())
        );
    }

    #[tokio::test]
    async fn test_plugin_info() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("info_test"));

        registry.register_plugin(plugin).await.unwrap();

        let info = registry.get_plugin_info();
        assert_eq!(info.len(), 1);
        assert_eq!(info[0].0, "info_test");
        assert_eq!(info[0].1, "1.0.0");
    }

    #[tokio::test]
    async fn test_validate_dependencies_on_empty_registry() {
        let registry = PluginRegistry::new();
        assert!(registry.validate_dependencies().is_ok());
    }

    #[tokio::test]
    async fn test_clear_plugins() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("first"));
        let plugin2 = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(plugin1).await.unwrap();
        registry.register_plugin(plugin2).await.unwrap();
        assert_eq!(registry.plugin_count(), 2);

        registry.clear().await.unwrap();
        assert_eq!(registry.plugin_count(), 0);
        assert!(registry.get_plugin_names().is_empty());
        assert!(!registry.has_plugin("first"));
    }
}