turbomcp_client/plugins/registry.rs

1//! Plugin registry for managing plugin lifecycle and execution
2//!
3//! The PluginRegistry manages the registration, ordering, and execution of plugins.
4//! It implements the middleware pattern where plugins are executed in a defined order
5//! for request/response processing.
6
7use crate::plugins::core::{
8    ClientPlugin, PluginContext, PluginError, PluginResult, RequestContext, ResponseContext,
9};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::Arc;
13use tracing::{debug, error, info, warn};
14
15/// Registry for managing client plugins
16///
17/// The registry maintains an ordered list of plugins and provides methods for:
18/// - Plugin registration and validation
19/// - Middleware chain execution
20/// - Custom method routing
21/// - Plugin lifecycle management
22///
23/// # Examples
24///
25/// ```rust,no_run
26/// use turbomcp_client::plugins::{PluginRegistry, MetricsPlugin, PluginConfig};
27/// use std::sync::Arc;
28///
29/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
30/// let mut registry = PluginRegistry::new();
31///
32/// // Register a metrics plugin
33/// let metrics = Arc::new(MetricsPlugin::new(PluginConfig::Metrics));
34/// registry.register_plugin(metrics).await?;
35///
36/// // Execute middleware chain
37/// // let mut request_context = RequestContext::new(...);
38/// // registry.execute_before_request(&mut request_context).await?;
39/// # Ok(())
40/// # }
41/// ```
42#[derive(Debug)]
43pub struct PluginRegistry {
44    /// Registered plugins in execution order
45    plugins: Vec<Arc<dyn ClientPlugin>>,
46
47    /// Plugin lookup by name for fast access
48    plugin_map: HashMap<String, usize>,
49
50    /// Client context for plugin initialization
51    client_context: Option<PluginContext>,
52}
53
54impl Default for PluginRegistry {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl PluginRegistry {
61    /// Create a new empty plugin registry
62    pub fn new() -> Self {
63        Self {
64            plugins: Vec::new(),
65            plugin_map: HashMap::new(),
66            client_context: None,
67        }
68    }
69
70    /// Set the client context for plugin initialization
71    ///
72    /// This should be called once when the client is initialized to provide
73    /// context information to plugins during registration.
74    pub fn set_client_context(&mut self, context: PluginContext) {
75        debug!(
76            "Setting client context: {} v{}",
77            context.client_name, context.client_version
78        );
79        self.client_context = Some(context);
80    }
81
82    /// Register a new plugin
83    ///
84    /// Validates the plugin, checks dependencies, and initializes it.
85    /// Plugins are executed in registration order.
86    ///
87    /// # Arguments
88    ///
89    /// * `plugin` - The plugin to register
90    ///
91    /// # Returns
92    ///
93    /// Returns `Ok(())` if registration succeeds, or `PluginError` if it fails.
94    ///
95    /// # Errors
96    ///
97    /// - Plugin name already registered
98    /// - Plugin dependencies not met
99    /// - Plugin initialization failure
100    pub async fn register_plugin(&mut self, plugin: Arc<dyn ClientPlugin>) -> PluginResult<()> {
101        let plugin_name = plugin.name().to_string();
102
103        info!("Registering plugin: {} v{}", plugin_name, plugin.version());
104
105        // Check for duplicate registration
106        if self.plugin_map.contains_key(&plugin_name) {
107            return Err(PluginError::configuration(format!(
108                "Plugin '{}' is already registered",
109                plugin_name
110            )));
111        }
112
113        // Check dependencies
114        for dependency in plugin.dependencies() {
115            if !self.has_plugin(dependency) {
116                return Err(PluginError::dependency_not_available(dependency));
117            }
118        }
119
120        // Initialize plugin with current context
121        if let Some(context) = &self.client_context {
122            // Update context with current plugin list
123            let mut updated_context = context.clone();
124            updated_context.available_plugins = self.get_plugin_names();
125
126            plugin.initialize(&updated_context).await.map_err(|e| {
127                error!("Failed to initialize plugin '{}': {}", plugin_name, e);
128                e
129            })?;
130        } else {
131            // Create minimal context if none set
132            let context = PluginContext::new(
133                "unknown".to_string(),
134                "unknown".to_string(),
135                HashMap::new(),
136                HashMap::new(),
137                self.get_plugin_names(),
138            );
139            plugin.initialize(&context).await.map_err(|e| {
140                error!("Failed to initialize plugin '{}': {}", plugin_name, e);
141                e
142            })?;
143        }
144
145        // Register the plugin
146        let index = self.plugins.len();
147        self.plugins.push(plugin);
148        self.plugin_map.insert(plugin_name.clone(), index);
149
150        debug!(
151            "Plugin '{}' registered successfully at index {}",
152            plugin_name, index
153        );
154        Ok(())
155    }
156
157    /// Unregister a plugin by name
158    ///
159    /// Removes the plugin from the registry and calls its cleanup method.
160    ///
161    /// # Arguments
162    ///
163    /// * `plugin_name` - Name of the plugin to unregister
164    ///
165    /// # Returns
166    ///
167    /// Returns `Ok(())` if unregistration succeeds, or `PluginError` if it fails.
168    pub async fn unregister_plugin(&mut self, plugin_name: &str) -> PluginResult<()> {
169        info!("Unregistering plugin: {}", plugin_name);
170
171        let index = self.plugin_map.get(plugin_name).copied().ok_or_else(|| {
172            PluginError::configuration(format!("Plugin '{}' not found", plugin_name))
173        })?;
174
175        // Get the plugin and call cleanup
176        let plugin = self.plugins[index].clone();
177        plugin.cleanup().await.map_err(|e| {
178            warn!("Plugin '{}' cleanup failed: {}", plugin_name, e);
179            e
180        })?;
181
182        // Remove from collections
183        self.plugins.remove(index);
184        self.plugin_map.remove(plugin_name);
185
186        // Update indices in the map
187        for (_, plugin_index) in self.plugin_map.iter_mut() {
188            if *plugin_index > index {
189                *plugin_index -= 1;
190            }
191        }
192
193        debug!("Plugin '{}' unregistered successfully", plugin_name);
194        Ok(())
195    }
196
197    /// Check if a plugin is registered
198    pub fn has_plugin(&self, plugin_name: &str) -> bool {
199        self.plugin_map.contains_key(plugin_name)
200    }
201
202    /// Get a plugin by name
203    pub fn get_plugin(&self, plugin_name: &str) -> Option<Arc<dyn ClientPlugin>> {
204        self.plugin_map
205            .get(plugin_name)
206            .and_then(|&index| self.plugins.get(index))
207            .cloned()
208    }
209
210    /// Get all registered plugin names in execution order
211    pub fn get_plugin_names(&self) -> Vec<String> {
212        self.plugins
213            .iter()
214            .map(|plugin| plugin.name().to_string())
215            .collect()
216    }
217
218    /// Get the number of registered plugins
219    pub fn plugin_count(&self) -> usize {
220        self.plugins.len()
221    }
222
223    /// Execute before_request middleware chain
224    ///
225    /// Calls `before_request` on all registered plugins in order.
226    /// If any plugin returns an error, the chain is aborted and the error is returned.
227    ///
228    /// # Arguments
229    ///
230    /// * `context` - Mutable request context that can be modified by plugins
231    ///
232    /// # Returns
233    ///
234    /// Returns `Ok(())` if all plugins succeed, or the first `PluginError` encountered.
235    pub async fn execute_before_request(&self, context: &mut RequestContext) -> PluginResult<()> {
236        debug!(
237            "Executing before_request middleware chain for method: {}",
238            context.method()
239        );
240
241        for (index, plugin) in self.plugins.iter().enumerate() {
242            let plugin_name = plugin.name();
243            debug!(
244                "Calling before_request on plugin '{}' ({})",
245                plugin_name, index
246            );
247
248            plugin.before_request(context).await.map_err(|e| {
249                error!(
250                    "Plugin '{}' before_request failed for method '{}': {}",
251                    plugin_name,
252                    context.method(),
253                    e
254                );
255                e
256            })?;
257        }
258
259        debug!("Before_request middleware chain completed successfully");
260        Ok(())
261    }
262
263    /// Execute after_response middleware chain
264    ///
265    /// Calls `after_response` on all registered plugins in order.
266    /// Unlike before_request, this continues execution even if a plugin fails,
267    /// logging errors but not aborting the chain.
268    ///
269    /// # Arguments
270    ///
271    /// * `context` - Mutable response context that can be modified by plugins
272    ///
273    /// # Returns
274    ///
275    /// Returns `Ok(())` unless all plugins fail, in which case returns the last error.
276    pub async fn execute_after_response(&self, context: &mut ResponseContext) -> PluginResult<()> {
277        debug!(
278            "Executing after_response middleware chain for method: {}",
279            context.method()
280        );
281
282        let mut _last_error = None;
283
284        for (index, plugin) in self.plugins.iter().enumerate() {
285            let plugin_name = plugin.name();
286            debug!(
287                "Calling after_response on plugin '{}' ({})",
288                plugin_name, index
289            );
290
291            if let Err(e) = plugin.after_response(context).await {
292                error!(
293                    "Plugin '{}' after_response failed for method '{}': {}",
294                    plugin_name,
295                    context.method(),
296                    e
297                );
298                _last_error = Some(e);
299                // Continue with other plugins
300            }
301        }
302
303        debug!("After_response middleware chain completed");
304
305        // Return error only if we have one and want to propagate it
306        // For now, we log errors but don't fail the response processing
307        Ok(())
308    }
309
310    /// Handle custom method by routing to appropriate plugin
311    ///
312    /// Attempts to handle the custom method by calling `handle_custom` on each
313    /// plugin in order until one returns `Some(Value)`.
314    ///
315    /// # Arguments
316    ///
317    /// * `method` - The custom method name
318    /// * `params` - Optional parameters for the method
319    ///
320    /// # Returns
321    ///
322    /// Returns `Some(Value)` if a plugin handled the method, `None` if no plugin handled it,
323    /// or `PluginError` if handling failed.
324    pub async fn handle_custom_method(
325        &self,
326        method: &str,
327        params: Option<Value>,
328    ) -> PluginResult<Option<Value>> {
329        debug!("Handling custom method: {}", method);
330
331        for plugin in &self.plugins {
332            let plugin_name = plugin.name();
333            debug!(
334                "Checking if plugin '{}' can handle custom method '{}'",
335                plugin_name, method
336            );
337
338            match plugin.handle_custom(method, params.clone()).await {
339                Ok(Some(result)) => {
340                    info!(
341                        "Plugin '{}' handled custom method '{}'",
342                        plugin_name, method
343                    );
344                    return Ok(Some(result));
345                }
346                Ok(None) => {
347                    // Plugin doesn't handle this method, continue
348                    continue;
349                }
350                Err(e) => {
351                    error!(
352                        "Plugin '{}' failed to handle custom method '{}': {}",
353                        plugin_name, method, e
354                    );
355                    return Err(e);
356                }
357            }
358        }
359
360        debug!("No plugin handled custom method: {}", method);
361        Ok(None)
362    }
363
364    /// Get plugin information for debugging
365    pub fn get_plugin_info(&self) -> Vec<(String, String, Option<String>)> {
366        self.plugins
367            .iter()
368            .map(|plugin| {
369                (
370                    plugin.name().to_string(),
371                    plugin.version().to_string(),
372                    plugin.description().map(|s| s.to_string()),
373                )
374            })
375            .collect()
376    }
377
378    /// Validate plugin dependencies
379    ///
380    /// Checks that all registered plugins have their dependencies satisfied.
381    /// This is useful for debugging plugin configuration issues.
382    pub fn validate_dependencies(&self) -> Result<(), Vec<String>> {
383        let mut errors = Vec::new();
384
385        for plugin in &self.plugins {
386            for dependency in plugin.dependencies() {
387                if !self.has_plugin(dependency) {
388                    errors.push(format!(
389                        "Plugin '{}' depends on '{}' which is not registered",
390                        plugin.name(),
391                        dependency
392                    ));
393                }
394            }
395        }
396
397        if errors.is_empty() {
398            Ok(())
399        } else {
400            Err(errors)
401        }
402    }
403
404    /// Clear all registered plugins
405    ///
406    /// Calls cleanup on all plugins and removes them from the registry.
407    /// This is primarily useful for testing and shutdown scenarios.
408    pub async fn clear(&mut self) -> PluginResult<()> {
409        info!("Clearing all registered plugins");
410
411        let plugins = std::mem::take(&mut self.plugins);
412        self.plugin_map.clear();
413
414        for plugin in plugins {
415            let plugin_name = plugin.name();
416            if let Err(e) = plugin.cleanup().await {
417                warn!("Plugin '{}' cleanup failed: {}", plugin_name, e);
418            }
419        }
420
421        debug!("All plugins cleared successfully");
422        Ok(())
423    }
424}
425
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (`PluginContext`,
    // `HashMap`, `Arc`, `Value`, ...), so they are not re-imported here.
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::Mutex;
    use turbomcp_protocol::MessageId;
    use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};

    /// Build a parameterless JSON-RPC request for `method` with a fixed id.
    fn test_request(method: &str) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            id: MessageId::from("test"),
            method: method.to_string(),
            params: None,
        }
    }

    /// Test plugin that records every hook invocation and can be configured to fail.
    #[derive(Debug)]
    struct MockPlugin {
        name: String,
        calls: Arc<Mutex<Vec<String>>>,
        should_fail_init: bool,
        should_fail_before_request: bool,
    }

    impl MockPlugin {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Arc::new(Mutex::new(Vec::new())),
                should_fail_init: false,
                should_fail_before_request: false,
            }
        }

        fn with_init_failure(mut self) -> Self {
            self.should_fail_init = true;
            self
        }

        fn with_request_failure(mut self) -> Self {
            self.should_fail_before_request = true;
            self
        }

        fn get_calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ClientPlugin for MockPlugin {
        fn name(&self) -> &str {
            &self.name
        }

        fn version(&self) -> &str {
            "1.0.0"
        }

        async fn initialize(&self, _context: &PluginContext) -> PluginResult<()> {
            self.calls.lock().unwrap().push("initialize".to_string());
            if self.should_fail_init {
                Err(PluginError::initialization("Mock initialization failure"))
            } else {
                Ok(())
            }
        }

        async fn before_request(&self, context: &mut RequestContext) -> PluginResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("before_request:{}", context.method()));
            if self.should_fail_before_request {
                Err(PluginError::request_processing("Mock request failure"))
            } else {
                Ok(())
            }
        }

        async fn after_response(&self, context: &mut ResponseContext) -> PluginResult<()> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("after_response:{}", context.method()));
            Ok(())
        }

        /// Handles any method prefixed with `"<plugin name>."`, echoing the params back.
        async fn handle_custom(
            &self,
            method: &str,
            params: Option<Value>,
        ) -> PluginResult<Option<Value>> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("handle_custom:{}", method));
            if method.starts_with(&format!("{}.", self.name)) {
                Ok(params)
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let registry = PluginRegistry::new();
        assert_eq!(registry.plugin_count(), 0);
        assert!(registry.get_plugin_names().is_empty());
    }

    #[tokio::test]
    async fn test_plugin_registration() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("test"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        assert_eq!(registry.plugin_count(), 1);
        assert!(registry.has_plugin("test"));
        assert_eq!(registry.get_plugin_names(), vec!["test"]);
        assert!(plugin.get_calls().contains(&"initialize".to_string()));

        let retrieved = registry.get_plugin("test").unwrap();
        assert_eq!(retrieved.name(), "test");
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("duplicate"));
        let plugin2 = Arc::new(MockPlugin::new("duplicate"));

        registry.register_plugin(plugin1).await.unwrap();
        let result = registry.register_plugin(plugin2).await;

        assert!(result.is_err());
        assert_eq!(registry.plugin_count(), 1);
    }

    #[tokio::test]
    async fn test_plugin_initialization_failure() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("failing").with_init_failure());

        let result = registry.register_plugin(plugin).await;

        assert!(result.is_err());
        assert_eq!(registry.plugin_count(), 0);
        assert!(!registry.has_plugin("failing"));
    }

    #[tokio::test]
    async fn test_plugin_unregistration() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("removable"));

        registry.register_plugin(plugin).await.unwrap();
        assert_eq!(registry.plugin_count(), 1);

        registry.unregister_plugin("removable").await.unwrap();
        assert_eq!(registry.plugin_count(), 0);
        assert!(!registry.has_plugin("removable"));
    }

    #[tokio::test]
    async fn test_unregister_middle_plugin_keeps_lookup_consistent() {
        let mut registry = PluginRegistry::new();
        for name in ["first", "middle", "last"] {
            registry
                .register_plugin(Arc::new(MockPlugin::new(name)))
                .await
                .unwrap();
        }

        registry.unregister_plugin("middle").await.unwrap();

        assert_eq!(registry.plugin_count(), 2);
        assert_eq!(registry.get_plugin_names(), vec!["first", "last"]);
        assert!(registry.get_plugin("middle").is_none());
        // Plugins after the removed slot must still resolve by name.
        assert_eq!(registry.get_plugin("first").unwrap().name(), "first");
        assert_eq!(registry.get_plugin("last").unwrap().name(), "last");
    }

    #[tokio::test]
    async fn test_unregister_unknown_plugin_fails() {
        let mut registry = PluginRegistry::new();
        registry
            .register_plugin(Arc::new(MockPlugin::new("present")))
            .await
            .unwrap();

        assert!(registry.unregister_plugin("absent").await.is_err());
        assert_eq!(registry.plugin_count(), 1);
        assert!(registry.has_plugin("present"));
    }

    #[tokio::test]
    async fn test_before_request_middleware() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("first"));
        let plugin2 = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(plugin1.clone()).await.unwrap();
        registry.register_plugin(plugin2.clone()).await.unwrap();

        let mut context = RequestContext::new(test_request("test/method"), HashMap::new());
        registry.execute_before_request(&mut context).await.unwrap();

        // Check both plugins were called
        assert!(
            plugin1
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
        assert!(
            plugin2
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
    }

    #[tokio::test]
    async fn test_before_request_error_handling() {
        let mut registry = PluginRegistry::new();
        let good_plugin = Arc::new(MockPlugin::new("good"));
        let bad_plugin = Arc::new(MockPlugin::new("bad").with_request_failure());

        registry.register_plugin(good_plugin.clone()).await.unwrap();
        registry.register_plugin(bad_plugin.clone()).await.unwrap();

        let mut context = RequestContext::new(test_request("test/method"), HashMap::new());
        let result = registry.execute_before_request(&mut context).await;

        assert!(result.is_err());
        assert!(
            good_plugin
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
        assert!(
            bad_plugin
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
    }

    #[tokio::test]
    async fn test_before_request_failure_short_circuits_chain() {
        let mut registry = PluginRegistry::new();
        let bad_plugin = Arc::new(MockPlugin::new("bad").with_request_failure());
        let later_plugin = Arc::new(MockPlugin::new("later"));

        registry.register_plugin(bad_plugin.clone()).await.unwrap();
        registry.register_plugin(later_plugin.clone()).await.unwrap();

        let mut context = RequestContext::new(test_request("test/method"), HashMap::new());
        let result = registry.execute_before_request(&mut context).await;

        assert!(result.is_err());
        assert!(
            bad_plugin
                .get_calls()
                .contains(&"before_request:test/method".to_string())
        );
        // The chain aborts at the first failure, so later plugins never see the request.
        assert!(
            !later_plugin
                .get_calls()
                .iter()
                .any(|call| call.starts_with("before_request"))
        );
    }

    #[tokio::test]
    async fn test_custom_method_handling() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("handler"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("handler.test", Some(json!({"data": "test"})))
            .await
            .unwrap();

        assert!(result.is_some());
        assert_eq!(result.unwrap(), json!({"data": "test"}));
        assert!(
            plugin
                .get_calls()
                .contains(&"handle_custom:handler.test".to_string())
        );
    }

    #[tokio::test]
    async fn test_custom_method_routed_past_non_matching_plugins() {
        let mut registry = PluginRegistry::new();
        let first = Arc::new(MockPlugin::new("first"));
        let second = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(first.clone()).await.unwrap();
        registry.register_plugin(second.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("second.ping", Some(json!({"n": 1})))
            .await
            .unwrap();

        assert_eq!(result, Some(json!({"n": 1})));
        // The first plugin is consulted and declines before the second handles it.
        assert!(
            first
                .get_calls()
                .contains(&"handle_custom:second.ping".to_string())
        );
        assert!(
            second
                .get_calls()
                .contains(&"handle_custom:second.ping".to_string())
        );
    }

    #[tokio::test]
    async fn test_custom_method_not_handled() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("handler"));

        registry.register_plugin(plugin.clone()).await.unwrap();

        let result = registry
            .handle_custom_method("other.method", None)
            .await
            .unwrap();

        assert!(result.is_none());
        assert!(
            plugin
                .get_calls()
                .contains(&"handle_custom:other.method".to_string())
        );
    }

    #[tokio::test]
    async fn test_plugin_info() {
        let mut registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin::new("info_test"));

        registry.register_plugin(plugin).await.unwrap();

        let info = registry.get_plugin_info();
        assert_eq!(info.len(), 1);
        assert_eq!(info[0].0, "info_test");
        assert_eq!(info[0].1, "1.0.0");
    }

    #[tokio::test]
    async fn test_clear_plugins() {
        let mut registry = PluginRegistry::new();
        let plugin1 = Arc::new(MockPlugin::new("first"));
        let plugin2 = Arc::new(MockPlugin::new("second"));

        registry.register_plugin(plugin1).await.unwrap();
        registry.register_plugin(plugin2).await.unwrap();
        assert_eq!(registry.plugin_count(), 2);

        registry.clear().await.unwrap();
        assert_eq!(registry.plugin_count(), 0);
        assert!(registry.get_plugin_names().is_empty());
        assert!(!registry.has_plugin("first"));
        assert!(!registry.has_plugin("second"));
    }
}