turbomcp_client/plugins/registry.rs

1//! Plugin registry for managing plugin lifecycle and execution
2//!
3//! The PluginRegistry manages the registration, ordering, and execution of plugins.
4//! It implements the middleware pattern where plugins are executed in a defined order
5//! for request/response processing.
6
7use crate::plugins::core::{
8    ClientPlugin, PluginContext, PluginError, PluginResult, RequestContext, ResponseContext,
9};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::{debug, error, info, warn};
14
/// Registry for managing client plugins
///
/// The registry maintains an ordered list of plugins and provides methods for:
/// - Plugin registration and validation
/// - Middleware chain execution
/// - Custom method routing
/// - Plugin lifecycle management
///
/// Plugins run in registration order. The `before_request` chain stops at the
/// first plugin error. The `after_response` chain logs plugin errors and keeps
/// running the remaining plugins.
///
/// # Examples
///
/// ```rust,no_run
/// use turbomcp_client::plugins::{PluginRegistry, MetricsPlugin, PluginConfig};
/// use std::sync::Arc;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let mut registry = PluginRegistry::new();
///
/// // Register a metrics plugin
/// let metrics = Arc::new(MetricsPlugin::new(PluginConfig::Metrics));
/// registry.register_plugin(metrics).await?;
///
/// // Execute middleware chain
/// // let mut request_context = RequestContext::new(...);
/// // registry.execute_before_request(&mut request_context).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct PluginRegistry {
    /// Registered plugins in execution order (which is registration order)
    plugins: Vec<Arc<dyn ClientPlugin>>,

    /// Plugin lookup by name for fast access: maps each plugin name to its index
    /// in `plugins`. The indices must be re-synced whenever an entry is removed
    /// from `plugins`, because later entries shift down.
    plugin_map: HashMap<String, usize>,

    /// Client context for plugin initialization. When `None`, plugins are
    /// initialized with a placeholder context built from `"unknown"` strings.
    client_context: Option<PluginContext>,
}
53
54impl Default for PluginRegistry {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl PluginRegistry {
61    /// Create a new empty plugin registry
62    #[must_use]
63    pub fn new() -> Self {
64        Self {
65            plugins: Vec::new(),
66            plugin_map: HashMap::new(),
67            client_context: None,
68        }
69    }
70
71    /// Set the client context for plugin initialization
72    ///
73    /// This should be called once when the client is initialized to provide
74    /// context information to plugins during registration.
75    pub fn set_client_context(&mut self, context: PluginContext) {
76        debug!(
77            "Setting client context: {} v{}",
78            context.client_name, context.client_version
79        );
80        self.client_context = Some(context);
81    }
82
83    /// Register a new plugin
84    ///
85    /// Validates the plugin, checks dependencies, and initializes it.
86    /// Plugins are executed in registration order.
87    ///
88    /// # Arguments
89    ///
90    /// * `plugin` - The plugin to register
91    ///
92    /// # Returns
93    ///
94    /// Returns `Ok(())` if registration succeeds, or `PluginError` if it fails.
95    ///
96    /// # Errors
97    ///
98    /// - Plugin name already registered
99    /// - Plugin dependencies not met
100    /// - Plugin initialization failure
101    pub async fn register_plugin(&mut self, plugin: Arc<dyn ClientPlugin>) -> PluginResult<()> {
102        let plugin_name = plugin.name().to_string();
103
104        info!("Registering plugin: {} v{}", plugin_name, plugin.version());
105
106        // Check for duplicate registration
107        if self.plugin_map.contains_key(&plugin_name) {
108            return Err(PluginError::configuration(format!(
109                "Plugin '{}' is already registered",
110                plugin_name
111            )));
112        }
113
114        // Check dependencies
115        for dependency in plugin.dependencies() {
116            if !self.has_plugin(dependency) {
117                return Err(PluginError::dependency_not_available(dependency));
118            }
119        }
120
121        // Initialize plugin with current context
122        if let Some(context) = &self.client_context {
123            // Update context with current plugin list
124            let mut updated_context = context.clone();
125            updated_context.available_plugins = self.get_plugin_names();
126
127            plugin.initialize(&updated_context).await.map_err(|e| {
128                error!("Failed to initialize plugin '{}': {}", plugin_name, e);
129                e
130            })?;
131        } else {
132            // Create minimal context if none set
133            let context = PluginContext::new(
134                "unknown".to_string(),
135                "unknown".to_string(),
136                HashMap::new(),
137                HashMap::new(),
138                self.get_plugin_names(),
139            );
140            plugin.initialize(&context).await.map_err(|e| {
141                error!("Failed to initialize plugin '{}': {}", plugin_name, e);
142                e
143            })?;
144        }
145
146        // Register the plugin
147        let index = self.plugins.len();
148        self.plugins.push(plugin);
149        self.plugin_map.insert(plugin_name.clone(), index);
150
151        debug!(
152            "Plugin '{}' registered successfully at index {}",
153            plugin_name, index
154        );
155        Ok(())
156    }
157
158    /// Unregister a plugin by name
159    ///
160    /// Removes the plugin from the registry and calls its cleanup method.
161    ///
162    /// # Arguments
163    ///
164    /// * `plugin_name` - Name of the plugin to unregister
165    ///
166    /// # Returns
167    ///
168    /// Returns `Ok(())` if unregistration succeeds, or `PluginError` if it fails.
169    pub async fn unregister_plugin(&mut self, plugin_name: &str) -> PluginResult<()> {
170        info!("Unregistering plugin: {}", plugin_name);
171
172        let index = self.plugin_map.get(plugin_name).copied().ok_or_else(|| {
173            PluginError::configuration(format!("Plugin '{}' not found", plugin_name))
174        })?;
175
176        // Get the plugin and call cleanup
177        let plugin = self.plugins[index].clone();
178        plugin.cleanup().await.map_err(|e| {
179            warn!("Plugin '{}' cleanup failed: {}", plugin_name, e);
180            e
181        })?;
182
183        // Remove from collections
184        self.plugins.remove(index);
185        self.plugin_map.remove(plugin_name);
186
187        // Update indices in the map
188        for (_, plugin_index) in self.plugin_map.iter_mut() {
189            if *plugin_index > index {
190                *plugin_index -= 1;
191            }
192        }
193
194        debug!("Plugin '{}' unregistered successfully", plugin_name);
195        Ok(())
196    }
197
198    /// Check if a plugin is registered
199    #[must_use]
200    pub fn has_plugin(&self, plugin_name: &str) -> bool {
201        self.plugin_map.contains_key(plugin_name)
202    }
203
204    /// Get a plugin by name
205    #[must_use]
206    pub fn get_plugin(&self, plugin_name: &str) -> Option<Arc<dyn ClientPlugin>> {
207        self.plugin_map
208            .get(plugin_name)
209            .and_then(|&index| self.plugins.get(index))
210            .cloned()
211    }
212
213    /// Get all registered plugin names in execution order
214    #[must_use]
215    pub fn get_plugin_names(&self) -> Vec<String> {
216        self.plugins
217            .iter()
218            .map(|plugin| plugin.name().to_string())
219            .collect()
220    }
221
222    /// Get the number of registered plugins
223    #[must_use]
224    pub fn plugin_count(&self) -> usize {
225        self.plugins.len()
226    }
227
228    /// Execute before_request middleware chain
229    ///
230    /// Calls `before_request` on all registered plugins in order.
231    /// If any plugin returns an error, the chain is aborted and the error is returned.
232    ///
233    /// # Arguments
234    ///
235    /// * `context` - Mutable request context that can be modified by plugins
236    ///
237    /// # Returns
238    ///
239    /// Returns `Ok(())` if all plugins succeed, or the first `PluginError` encountered.
240    pub async fn execute_before_request(&self, context: &mut RequestContext) -> PluginResult<()> {
241        debug!(
242            "Executing before_request middleware chain for method: {}",
243            context.method()
244        );
245
246        for (index, plugin) in self.plugins.iter().enumerate() {
247            let plugin_name = plugin.name();
248            debug!(
249                "Calling before_request on plugin '{}' ({})",
250                plugin_name, index
251            );
252
253            plugin.before_request(context).await.map_err(|e| {
254                error!(
255                    "Plugin '{}' before_request failed for method '{}': {}",
256                    plugin_name,
257                    context.method(),
258                    e
259                );
260                e
261            })?;
262        }
263
264        debug!("Before_request middleware chain completed successfully");
265        Ok(())
266    }
267
268    /// Execute after_response middleware chain
269    ///
270    /// Calls `after_response` on all registered plugins in order.
271    /// Unlike before_request, this continues execution even if a plugin fails,
272    /// logging errors but not aborting the chain.
273    ///
274    /// # Arguments
275    ///
276    /// * `context` - Mutable response context that can be modified by plugins
277    ///
278    /// # Returns
279    ///
280    /// Returns `Ok(())` unless all plugins fail, in which case returns the last error.
281    pub async fn execute_after_response(&self, context: &mut ResponseContext) -> PluginResult<()> {
282        debug!(
283            "Executing after_response middleware chain for method: {}",
284            context.method()
285        );
286
287        let mut _last_error = None;
288
289        for (index, plugin) in self.plugins.iter().enumerate() {
290            let plugin_name = plugin.name();
291            debug!(
292                "Calling after_response on plugin '{}' ({})",
293                plugin_name, index
294            );
295
296            if let Err(e) = plugin.after_response(context).await {
297                error!(
298                    "Plugin '{}' after_response failed for method '{}': {}",
299                    plugin_name,
300                    context.method(),
301                    e
302                );
303                _last_error = Some(e);
304                // Continue with other plugins
305            }
306        }
307
308        debug!("After_response middleware chain completed");
309
310        // Return error only if we have one and want to propagate it
311        // For now, we log errors but don't fail the response processing
312        Ok(())
313    }
314
315    /// Handle custom method by routing to appropriate plugin
316    ///
317    /// Attempts to handle the custom method by calling `handle_custom` on each
318    /// plugin in order until one returns `Some(Value)`.
319    ///
320    /// # Arguments
321    ///
322    /// * `method` - The custom method name
323    /// * `params` - Optional parameters for the method
324    ///
325    /// # Returns
326    ///
327    /// Returns `Some(Value)` if a plugin handled the method, `None` if no plugin handled it,
328    /// or `PluginError` if handling failed.
329    pub async fn handle_custom_method(
330        &self,
331        method: &str,
332        params: Option<Value>,
333    ) -> PluginResult<Option<Value>> {
334        debug!("Handling custom method: {}", method);
335
336        for plugin in &self.plugins {
337            let plugin_name = plugin.name();
338            debug!(
339                "Checking if plugin '{}' can handle custom method '{}'",
340                plugin_name, method
341            );
342
343            match plugin.handle_custom(method, params.clone()).await {
344                Ok(Some(result)) => {
345                    info!(
346                        "Plugin '{}' handled custom method '{}'",
347                        plugin_name, method
348                    );
349                    return Ok(Some(result));
350                }
351                Ok(None) => {
352                    // Plugin doesn't handle this method, continue
353                    continue;
354                }
355                Err(e) => {
356                    error!(
357                        "Plugin '{}' failed to handle custom method '{}': {}",
358                        plugin_name, method, e
359                    );
360                    return Err(e);
361                }
362            }
363        }
364
365        debug!("No plugin handled custom method: {}", method);
366        Ok(None)
367    }
368
369    /// Get plugin information for debugging
370    #[must_use]
371    pub fn get_plugin_info(&self) -> Vec<(String, String, Option<String>)> {
372        self.plugins
373            .iter()
374            .map(|plugin| {
375                (
376                    plugin.name().to_string(),
377                    plugin.version().to_string(),
378                    plugin.description().map(|s| s.to_string()),
379                )
380            })
381            .collect()
382    }
383
384    /// Validate plugin dependencies
385    ///
386    /// Checks that all registered plugins have their dependencies satisfied.
387    /// This is useful for debugging plugin configuration issues.
388    pub fn validate_dependencies(&self) -> Result<(), Vec<String>> {
389        let mut errors = Vec::new();
390
391        for plugin in &self.plugins {
392            for dependency in plugin.dependencies() {
393                if !self.has_plugin(dependency) {
394                    errors.push(format!(
395                        "Plugin '{}' depends on '{}' which is not registered",
396                        plugin.name(),
397                        dependency
398                    ));
399                }
400            }
401        }
402
403        if errors.is_empty() {
404            Ok(())
405        } else {
406            Err(errors)
407        }
408    }
409
410    /// Clear all registered plugins
411    ///
412    /// Calls cleanup on all plugins and removes them from the registry.
413    /// This is primarily useful for testing and shutdown scenarios.
414    pub async fn clear(&mut self) -> PluginResult<()> {
415        info!("Clearing all registered plugins");
416
417        let plugins = std::mem::take(&mut self.plugins);
418        self.plugin_map.clear();
419
420        for plugin in plugins {
421            let plugin_name = plugin.name();
422            if let Err(e) = plugin.cleanup().await {
423                warn!("Plugin '{}' cleanup failed: {}", plugin_name, e);
424            }
425        }
426
427        debug!("All plugins cleared successfully");
428        Ok(())
429    }
430}
431
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (PluginContext, HashMap, Arc, Value, ...).
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::Mutex;
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Configurable test plugin that records every hook invocation in `calls`.
    #[derive(Debug)]
    struct MockPlugin {
        name: String,
        calls: Arc<Mutex<Vec<String>>>,
        should_fail_init: bool,
        should_fail_before_request: bool,
    }

    impl MockPlugin {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Arc::new(Mutex::new(Vec::new())),
                should_fail_init: false,
                should_fail_before_request: false,
            }
        }

        /// Make `initialize` return an error.
        fn with_init_failure(mut self) -> Self {
            self.should_fail_init = true;
            self
        }

        /// Make `before_request` return an error.
        fn with_request_failure(mut self) -> Self {
            self.should_fail_before_request = true;
            self
        }

        /// Snapshot of the hooks called so far, in call order.
        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ClientPlugin for MockPlugin {
        fn name(&self) -> &str {
            &self.name
        }

        fn version(&self) -> &str {
            "1.0.0"
        }

        async fn initialize(&self, _context: &PluginContext) -> PluginResult<()> {
            self.calls.lock().unwrap().push("initialize".to_string());
            if self.should_fail_init {
                Err(PluginError::initialization("Mock initialization failure"))
            } else {
                Ok(())
            }
        }

        async fn before_request(&self, context: &mut RequestContext) -> PluginResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("before_request:{}", context.method()));
            if self.should_fail_before_request {
                Err(PluginError::request_processing("Mock request failure"))
            } else {
                Ok(())
            }
        }

        async fn after_response(&self, context: &mut ResponseContext) -> PluginResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("after_response:{}", context.method()));
            Ok(())
        }

        /// Handles (echoes `params` for) any method prefixed with `"<name>."`.
        async fn handle_custom(
            &self,
            method: &str,
            params: Option<Value>,
        ) -> PluginResult<Option<Value>> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("handle_custom:{}", method));
            if method.starts_with(&format!("{}.", self.name)) {
                Ok(params)
            } else {
                Ok(None)
            }
        }
    }

    /// Build a request context for `method` with no params.
    fn request_context(method: &str) -> RequestContext {
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test"),
            method: method.to_string(),
            params: None,
        };
        RequestContext::new(request, HashMap::new())
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let registry = PluginRegistry::new();
        assert_eq!(registry.plugin_count(), 0);
        assert!(registry.get_plugin_names().is_empty());
    }

    #[tokio::test]
    async fn test_plugin_registration() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("test"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        assert_eq!(registry.plugin_count(), 1);
        assert!(registry.has_plugin("test"));
        assert_eq!(registry.get_plugin_names(), vec!["test"]);
        assert!(plugin.get_calls().contains(&"initialize".to_string()));

        let retrieved = registry.get_plugin("test").unwrap();
        assert_eq!(retrieved.name(), "test");
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("duplicate"));
        let plugin2 = Arc::new(MockPlugin::new("duplicate"));

        registry.register_plugin(plugin1).await.unwrap();
        let result = registry.register_plugin(plugin2).await;

        assert!(result.is_err());
        assert_eq!(registry.plugin_count(), 1);
    }

    #[tokio::test]
    async fn test_plugin_initialization_failure() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("failing").with_init_failure());

        let result = registry.register_plugin(plugin).await;

        assert!(result.is_err());
        assert_eq!(registry.plugin_count(), 0);
        assert!(!registry.has_plugin("failing"));
    }

    #[tokio::test]
    async fn test_plugin_unregistration() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("removable"));

        registry.register_plugin(plugin).await.unwrap();
        assert_eq!(registry.plugin_count(), 1);

        registry.unregister_plugin("removable").await.unwrap();
        assert_eq!(registry.plugin_count(), 0);
        assert!(!registry.has_plugin("removable"));
    }

    #[tokio::test]
    async fn test_unregister_unknown_plugin_fails() {
        let mut registry = PluginRegistry::new();
        assert!(registry.unregister_plugin("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_unregister_reindexes_remaining_plugins() {
        let mut registry = PluginRegistry::new();
        for name in ["a", "b", "c"] {
            registry
                .register_plugin(Arc::new(MockPlugin::new(name)))
                .await
                .unwrap();
        }

        // Removing the first plugin shifts the others; name lookups must still
        // resolve to the right plugin.
        registry.unregister_plugin("a").await.unwrap();

        assert_eq!(registry.get_plugin_names(), vec!["b", "c"]);
        assert_eq!(registry.get_plugin("b").unwrap().name(), "b");
        assert_eq!(registry.get_plugin("c").unwrap().name(), "c");
        assert!(registry.get_plugin("a").is_none());
    }

    #[tokio::test]
    async fn test_before_request_middleware() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("first"));
        let plugin2 = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(plugin1.clone()).await.unwrap();
        registry.register_plugin(plugin2.clone()).await.unwrap();

        let mut context = request_context("test/method");
        registry.execute_before_request(&mut context).await.unwrap();

        // Check both plugins were called
        assert!(
            plugin1
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
        assert!(
            plugin2
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
    }

    #[tokio::test]
    async fn test_before_request_error_handling() {
        let mut registry = PluginRegistry::new();
        let good_plugin = Arc::new(MockPlugin::new("good"));
        let bad_plugin = Arc::new(MockPlugin::new("bad").with_request_failure());

        registry.register_plugin(good_plugin.clone()).await.unwrap();
        registry.register_plugin(bad_plugin.clone()).await.unwrap();

        let mut context = request_context("test/method");
        let result = registry.execute_before_request(&mut context).await;

        assert!(result.is_err());
        assert!(
            good_plugin
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
        assert!(
            bad_plugin
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
    }

    #[tokio::test]
    async fn test_before_request_aborts_after_failure() {
        let mut registry = PluginRegistry::new();
        let bad_plugin = Arc::new(MockPlugin::new("bad").with_request_failure());
        let later_plugin = Arc::new(MockPlugin::new("later"));

        registry.register_plugin(bad_plugin.clone()).await.unwrap();
        registry.register_plugin(later_plugin.clone()).await.unwrap();

        let mut context = request_context("test/method");
        assert!(registry.execute_before_request(&mut context).await.is_err());

        assert!(
            bad_plugin
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
        // The chain stops at the first failure, so the later plugin never runs.
        assert!(
            !later_plugin
                .get_calls()
                .iter()
                .any(|call| call.starts_with("before_request"))
        );
    }

    #[tokio::test]
    async fn test_custom_method_handling() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("handler"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("handler.test", Some(json!({"data": "test"})))
            .await
            .unwrap();

        assert_eq!(result, Some(json!({"data": "test"})));
        assert!(
            plugin
                .get_calls()
                .contains(&"handle_custom:handler.test".to_string())
        );
    }

    #[tokio::test]
    async fn test_custom_method_routed_past_non_matching_plugins() {
        let mut registry = PluginRegistry::new();
        let first = Arc::new(MockPlugin::new("first"));
        let second = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(first.clone()).await.unwrap();
        registry.register_plugin(second.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("second.ping", Some(json!(1)))
            .await
            .unwrap();

        assert_eq!(result, Some(json!(1)));
        assert!(
            first
                .get_calls()
                .contains(&"handle_custom:second.ping".to_string())
        );
        assert!(
            second
                .get_calls()
                .contains(&"handle_custom:second.ping".to_string())
        );
    }

    #[tokio::test]
    async fn test_custom_method_not_handled() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("handler"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("other.method", None)
            .await
            .unwrap();

        assert!(result.is_none());
        assert!(
            plugin
                .get_calls()
                .contains(&"handle_custom:other.method".to_string())
        );
    }

    #[tokio::test]
    async fn test_plugin_info() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("info_test"));

        registry.register_plugin(plugin).await.unwrap();

        let info = registry.get_plugin_info();
        assert_eq!(info.len(), 1);
        assert_eq!(info[0].0, "info_test");
        assert_eq!(info[0].1, "1.0.0");
    }

    #[tokio::test]
    async fn test_clear_plugins() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("first"));
        let plugin2 = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(plugin1).await.unwrap();
        registry.register_plugin(plugin2).await.unwrap();
        assert_eq!(registry.plugin_count(), 2);

        registry.clear().await.unwrap();
        assert_eq!(registry.plugin_count(), 0);
        assert!(registry.get_plugin_names().is_empty());
        assert!(!registry.has_plugin("first"));
    }
}