// spec_ai_policy/plugin/mod.rs

//! Plugin system for extending agent capabilities.
//!
//! This module provides a flexible plugin architecture that allows:
//! - Dynamic registration of new model providers
//! - Custom tool implementations
//! - Extension of agent capabilities
//! - Plugin lifecycle management
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use tokio::sync::RwLock;

15/// Plugin metadata
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct PluginMetadata {
18    /// Unique plugin identifier
19    pub id: String,
20    /// Human-readable name
21    pub name: String,
22    /// Plugin version
23    pub version: String,
24    /// Plugin description
25    pub description: String,
26    /// Plugin author
27    pub author: Option<String>,
28    /// Plugin capabilities/tags
29    pub capabilities: Vec<String>,
30}
31
32impl PluginMetadata {
33    /// Create new plugin metadata
34    pub fn new(id: impl Into<String>, name: impl Into<String>, version: impl Into<String>) -> Self {
35        Self {
36            id: id.into(),
37            name: name.into(),
38            version: version.into(),
39            description: String::new(),
40            author: None,
41            capabilities: Vec::new(),
42        }
43    }
44
45    /// Set description
46    pub fn with_description(mut self, description: impl Into<String>) -> Self {
47        self.description = description.into();
48        self
49    }
50
51    /// Set author
52    pub fn with_author(mut self, author: impl Into<String>) -> Self {
53        self.author = Some(author.into());
54        self
55    }
56
57    /// Add capability
58    pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
59        self.capabilities.push(capability.into());
60        self
61    }
62}
63
64/// Plugin lifecycle states
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
66pub enum PluginState {
67    /// Plugin is registered but not initialized
68    Registered,
69    /// Plugin is initialized and ready
70    Active,
71    /// Plugin encountered an error
72    Error,
73    /// Plugin has been shutdown
74    Shutdown,
75}
76
77/// Core plugin trait that all plugins must implement
78#[async_trait]
79pub trait Plugin: Send + Sync {
80    /// Get plugin metadata
81    fn metadata(&self) -> &PluginMetadata;
82
83    /// Initialize the plugin
84    ///
85    /// Called once when the plugin is loaded. Use this to:
86    /// - Validate configuration
87    /// - Initialize resources
88    /// - Register capabilities
89    async fn init(&mut self) -> Result<()> {
90        Ok(())
91    }
92
93    /// Shutdown the plugin
94    ///
95    /// Called when the plugin is being unloaded. Use this to:
96    /// - Clean up resources
97    /// - Save state
98    /// - Disconnect from services
99    async fn shutdown(&mut self) -> Result<()> {
100        Ok(())
101    }
102
103    /// Health check for the plugin
104    async fn health_check(&self) -> Result<bool> {
105        Ok(true)
106    }
107}
108
109/// Plugin registration entry
110struct PluginEntry {
111    plugin: Box<dyn Plugin>,
112    state: PluginState,
113}
114
115/// Plugin registry for managing all plugins
116pub struct PluginRegistry {
117    plugins: Arc<RwLock<HashMap<String, PluginEntry>>>,
118}
119
120impl PluginRegistry {
121    /// Create a new plugin registry
122    pub fn new() -> Self {
123        Self {
124            plugins: Arc::new(RwLock::new(HashMap::new())),
125        }
126    }
127
128    /// Register a new plugin
129    ///
130    /// # Arguments
131    /// * `plugin` - The plugin instance to register
132    ///
133    /// # Returns
134    /// * `Ok(())` if registration succeeds
135    /// * `Err` if a plugin with the same ID already exists
136    pub async fn register(&self, plugin: Box<dyn Plugin>) -> Result<()> {
137        let id = plugin.metadata().id.clone();
138        let mut plugins = self.plugins.write().await;
139
140        if plugins.contains_key(&id) {
141            anyhow::bail!("Plugin with id '{}' already registered", id);
142        }
143
144        plugins.insert(
145            id,
146            PluginEntry {
147                plugin,
148                state: PluginState::Registered,
149            },
150        );
151
152        Ok(())
153    }
154
155    /// Initialize a specific plugin by ID
156    pub async fn init_plugin(&self, id: &str) -> Result<()> {
157        let mut plugins = self.plugins.write().await;
158
159        let entry = plugins
160            .get_mut(id)
161            .ok_or_else(|| anyhow::anyhow!("Plugin '{}' not found", id))?;
162
163        if entry.state != PluginState::Registered {
164            anyhow::bail!("Plugin '{}' is not in Registered state", id);
165        }
166
167        match entry.plugin.init().await {
168            Ok(()) => {
169                entry.state = PluginState::Active;
170                Ok(())
171            }
172            Err(e) => {
173                entry.state = PluginState::Error;
174                Err(e)
175            }
176        }
177    }
178
179    /// Initialize all registered plugins
180    pub async fn init_all(&self) -> Result<Vec<String>> {
181        let plugin_ids: Vec<String> = {
182            let plugins = self.plugins.read().await;
183            plugins
184                .iter()
185                .filter(|(_, entry)| entry.state == PluginState::Registered)
186                .map(|(id, _)| id.clone())
187                .collect()
188        };
189
190        let mut failed = Vec::new();
191
192        for id in &plugin_ids {
193            if let Err(e) = self.init_plugin(id).await {
194                tracing::error!("Failed to initialize plugin '{}': {}", id, e);
195                failed.push(id.clone());
196            }
197        }
198
199        Ok(failed)
200    }
201
202    /// Shutdown a specific plugin by ID
203    pub async fn shutdown_plugin(&self, id: &str) -> Result<()> {
204        let mut plugins = self.plugins.write().await;
205
206        let entry = plugins
207            .get_mut(id)
208            .ok_or_else(|| anyhow::anyhow!("Plugin '{}' not found", id))?;
209
210        if entry.state != PluginState::Active {
211            anyhow::bail!("Plugin '{}' is not active", id);
212        }
213
214        entry.plugin.shutdown().await?;
215        entry.state = PluginState::Shutdown;
216
217        Ok(())
218    }
219
220    /// Shutdown all active plugins
221    pub async fn shutdown_all(&self) -> Result<()> {
222        let plugin_ids: Vec<String> = {
223            let plugins = self.plugins.read().await;
224            plugins
225                .iter()
226                .filter(|(_, entry)| entry.state == PluginState::Active)
227                .map(|(id, _)| id.clone())
228                .collect()
229        };
230
231        for id in &plugin_ids {
232            if let Err(e) = self.shutdown_plugin(id).await {
233                tracing::error!("Failed to shutdown plugin '{}': {}", id, e);
234            }
235        }
236
237        Ok(())
238    }
239
240    /// Unregister a plugin by ID
241    pub async fn unregister(&self, id: &str) -> Result<()> {
242        let mut plugins = self.plugins.write().await;
243
244        let entry = plugins
245            .get(id)
246            .ok_or_else(|| anyhow::anyhow!("Plugin '{}' not found", id))?;
247
248        if entry.state == PluginState::Active {
249            anyhow::bail!("Cannot unregister active plugin '{}'. Shutdown first.", id);
250        }
251
252        plugins.remove(id);
253        Ok(())
254    }
255
256    /// Get plugin metadata by ID
257    pub async fn get_metadata(&self, id: &str) -> Option<PluginMetadata> {
258        let plugins = self.plugins.read().await;
259        plugins.get(id).map(|entry| entry.plugin.metadata().clone())
260    }
261
262    /// Get plugin state by ID
263    pub async fn get_state(&self, id: &str) -> Option<PluginState> {
264        let plugins = self.plugins.read().await;
265        plugins.get(id).map(|entry| entry.state)
266    }
267
268    /// List all plugin IDs
269    pub async fn list_plugin_ids(&self) -> Vec<String> {
270        let plugins = self.plugins.read().await;
271        plugins.keys().cloned().collect()
272    }
273
274    /// List all plugin metadata
275    pub async fn list_plugins(&self) -> Vec<(PluginMetadata, PluginState)> {
276        let plugins = self.plugins.read().await;
277        plugins
278            .values()
279            .map(|entry| (entry.plugin.metadata().clone(), entry.state))
280            .collect()
281    }
282
283    /// Check if a plugin is registered
284    pub async fn has_plugin(&self, id: &str) -> bool {
285        let plugins = self.plugins.read().await;
286        plugins.contains_key(id)
287    }
288
289    /// Get count of registered plugins
290    pub async fn count(&self) -> usize {
291        let plugins = self.plugins.read().await;
292        plugins.len()
293    }
294
295    /// Run health check on all active plugins
296    pub async fn health_check_all(&self) -> HashMap<String, bool> {
297        let plugins = self.plugins.read().await;
298        let mut results = HashMap::new();
299
300        for (id, entry) in plugins.iter() {
301            if entry.state == PluginState::Active {
302                let healthy = entry.plugin.health_check().await.unwrap_or(false);
303                results.insert(id.clone(), healthy);
304            }
305        }
306
307        results
308    }
309}
310
311impl Default for PluginRegistry {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Well-behaved plugin whose lifecycle hooks record that they ran.
    ///
    /// The flags live behind `Arc`s so a test can keep handles to them after
    /// the plugin has been moved into the registry.
    struct TestPlugin {
        metadata: PluginMetadata,
        init_called: Arc<AtomicBool>,
        shutdown_called: Arc<AtomicBool>,
    }

    impl TestPlugin {
        fn new(id: &str) -> Self {
            Self {
                metadata: PluginMetadata::new(id, format!("Test Plugin {}", id), "1.0.0"),
                init_called: Arc::new(AtomicBool::new(false)),
                shutdown_called: Arc::new(AtomicBool::new(false)),
            }
        }

        /// Shared handles to the `(init_called, shutdown_called)` flags.
        fn flags(&self) -> (Arc<AtomicBool>, Arc<AtomicBool>) {
            (
                Arc::clone(&self.init_called),
                Arc::clone(&self.shutdown_called),
            )
        }
    }

    #[async_trait]
    impl Plugin for TestPlugin {
        fn metadata(&self) -> &PluginMetadata {
            &self.metadata
        }

        async fn init(&mut self) -> Result<()> {
            self.init_called.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn shutdown(&mut self) -> Result<()> {
            self.shutdown_called.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Plugin whose `init` hook always fails.
    struct FailingPlugin {
        metadata: PluginMetadata,
    }

    impl FailingPlugin {
        fn new() -> Self {
            Self {
                metadata: PluginMetadata::new("failing", "Failing Plugin", "1.0.0"),
            }
        }
    }

    #[async_trait]
    impl Plugin for FailingPlugin {
        fn metadata(&self) -> &PluginMetadata {
            &self.metadata
        }

        async fn init(&mut self) -> Result<()> {
            anyhow::bail!("Intentional failure")
        }
    }

    /// Plugin that initializes fine but whose health probe returns an error.
    struct BrokenProbePlugin {
        metadata: PluginMetadata,
    }

    #[async_trait]
    impl Plugin for BrokenProbePlugin {
        fn metadata(&self) -> &PluginMetadata {
            &self.metadata
        }

        async fn health_check(&self) -> Result<bool> {
            anyhow::bail!("probe failed")
        }
    }

    /// Builds a registry holding one freshly registered `TestPlugin` per id.
    async fn registry_with(ids: &[&str]) -> PluginRegistry {
        let registry = PluginRegistry::new();
        for id in ids {
            registry
                .register(Box::new(TestPlugin::new(id)))
                .await
                .unwrap();
        }
        registry
    }

    #[test]
    fn test_plugin_metadata() {
        let meta = PluginMetadata::new("test", "Test Plugin", "1.0.0")
            .with_description("A test plugin")
            .with_author("Test Author")
            .with_capability("testing");

        assert_eq!(meta.id, "test");
        assert_eq!(meta.name, "Test Plugin");
        assert_eq!(meta.version, "1.0.0");
        assert_eq!(meta.description, "A test plugin");
        assert_eq!(meta.author, Some("Test Author".to_string()));
        assert_eq!(meta.capabilities, vec!["testing"]);
    }

    #[test]
    fn test_plugin_metadata_defaults() {
        let meta = PluginMetadata::new("id", "Name", "0.1.0");

        assert!(meta.description.is_empty());
        assert_eq!(meta.author, None);
        assert!(meta.capabilities.is_empty());
    }

    #[tokio::test]
    async fn test_register_plugin() {
        let registry = registry_with(&["test1"]).await;

        assert!(registry.has_plugin("test1").await);
        assert_eq!(registry.count().await, 1);
        assert_eq!(
            registry.get_state("test1").await,
            Some(PluginState::Registered)
        );
    }

    #[tokio::test]
    async fn test_register_duplicate_plugin() {
        let registry = registry_with(&["test1"]).await;

        let result = registry.register(Box::new(TestPlugin::new("test1"))).await;

        assert!(result.is_err());
        assert_eq!(registry.count().await, 1);
    }

    #[tokio::test]
    async fn test_init_plugin() {
        let registry = PluginRegistry::new();
        let plugin = TestPlugin::new("test1");
        let (init_called, shutdown_called) = plugin.flags();
        registry.register(Box::new(plugin)).await.unwrap();

        registry.init_plugin("test1").await.unwrap();

        assert_eq!(registry.get_state("test1").await, Some(PluginState::Active));
        assert!(init_called.load(Ordering::SeqCst));
        assert!(!shutdown_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_init_unknown_plugin() {
        let registry = PluginRegistry::new();

        assert!(registry.init_plugin("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_init_plugin_twice_is_rejected() {
        let registry = registry_with(&["test1"]).await;
        registry.init_plugin("test1").await.unwrap();

        assert!(registry.init_plugin("test1").await.is_err());
        assert_eq!(registry.get_state("test1").await, Some(PluginState::Active));
    }

    #[tokio::test]
    async fn test_init_all_plugins() {
        let registry = registry_with(&["test1", "test2", "test3"]).await;

        let failed = registry.init_all().await.unwrap();

        assert!(failed.is_empty());
        for id in ["test1", "test2", "test3"] {
            assert_eq!(registry.get_state(id).await, Some(PluginState::Active));
        }
    }

    #[tokio::test]
    async fn test_init_all_reports_failures() {
        let registry = registry_with(&["test1"]).await;
        registry
            .register(Box::new(FailingPlugin::new()))
            .await
            .unwrap();

        let failed = registry.init_all().await.unwrap();

        assert_eq!(failed, vec!["failing".to_string()]);
        assert_eq!(registry.get_state("test1").await, Some(PluginState::Active));
        assert_eq!(registry.get_state("failing").await, Some(PluginState::Error));
    }

    #[tokio::test]
    async fn test_init_plugin_failure() {
        let registry = PluginRegistry::new();
        registry
            .register(Box::new(FailingPlugin::new()))
            .await
            .unwrap();

        let result = registry.init_plugin("failing").await;

        assert!(result.is_err());
        assert_eq!(registry.get_state("failing").await, Some(PluginState::Error));
    }

    #[tokio::test]
    async fn test_shutdown_plugin() {
        let registry = PluginRegistry::new();
        let plugin = TestPlugin::new("test1");
        let (_, shutdown_called) = plugin.flags();
        registry.register(Box::new(plugin)).await.unwrap();

        registry.init_plugin("test1").await.unwrap();
        registry.shutdown_plugin("test1").await.unwrap();

        assert_eq!(
            registry.get_state("test1").await,
            Some(PluginState::Shutdown)
        );
        assert!(shutdown_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_shutdown_inactive_plugin_is_rejected() {
        let registry = registry_with(&["test1"]).await;

        assert!(registry.shutdown_plugin("test1").await.is_err());
        assert_eq!(
            registry.get_state("test1").await,
            Some(PluginState::Registered)
        );
    }

    #[tokio::test]
    async fn test_shutdown_all_plugins() {
        let registry = registry_with(&["test1", "test2"]).await;

        registry.init_all().await.unwrap();
        registry.shutdown_all().await.unwrap();

        for id in ["test1", "test2"] {
            assert_eq!(registry.get_state(id).await, Some(PluginState::Shutdown));
        }
    }

    #[tokio::test]
    async fn test_unregister_plugin() {
        let registry = registry_with(&["test1"]).await;

        registry.unregister("test1").await.unwrap();

        assert!(!registry.has_plugin("test1").await);
        assert_eq!(registry.count().await, 0);
    }

    #[tokio::test]
    async fn test_unregister_unknown_plugin() {
        let registry = PluginRegistry::new();

        assert!(registry.unregister("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_cannot_unregister_active_plugin() {
        let registry = registry_with(&["test1"]).await;
        registry.init_plugin("test1").await.unwrap();

        let result = registry.unregister("test1").await;

        assert!(result.is_err());
        assert!(registry.has_plugin("test1").await);
    }

    #[tokio::test]
    async fn test_unregister_after_shutdown() {
        let registry = registry_with(&["test1"]).await;
        registry.init_plugin("test1").await.unwrap();
        registry.shutdown_plugin("test1").await.unwrap();

        registry.unregister("test1").await.unwrap();

        assert!(!registry.has_plugin("test1").await);
    }

    #[tokio::test]
    async fn test_list_plugins() {
        let registry = registry_with(&["test1", "test2"]).await;

        let plugins = registry.list_plugins().await;
        assert_eq!(plugins.len(), 2);
        assert!(plugins
            .iter()
            .all(|(_, state)| *state == PluginState::Registered));

        let ids = registry.list_plugin_ids().await;
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"test1".to_string()));
        assert!(ids.contains(&"test2".to_string()));
    }

    #[tokio::test]
    async fn test_get_metadata() {
        let registry = registry_with(&["test1"]).await;

        let metadata = registry.get_metadata("test1").await;
        assert_eq!(metadata.map(|m| m.id), Some("test1".to_string()));
        assert!(registry.get_metadata("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_health_check_all() {
        let registry = registry_with(&["test1", "test2"]).await;
        registry.init_all().await.unwrap();

        let results = registry.health_check_all().await;

        assert_eq!(results.len(), 2);
        assert_eq!(results.get("test1"), Some(&true));
        assert_eq!(results.get("test2"), Some(&true));
    }

    #[tokio::test]
    async fn test_health_check_skips_inactive_plugins() {
        let registry = registry_with(&["test1"]).await;

        assert!(registry.health_check_all().await.is_empty());
    }

    #[tokio::test]
    async fn test_health_check_error_counts_as_unhealthy() {
        let registry = PluginRegistry::new();
        let plugin = BrokenProbePlugin {
            metadata: PluginMetadata::new("broken", "Broken Probe", "1.0.0"),
        };
        registry.register(Box::new(plugin)).await.unwrap();
        registry.init_plugin("broken").await.unwrap();

        let results = registry.health_check_all().await;

        assert_eq!(results.get("broken"), Some(&false));
    }
}