Skip to main content

smooth_plugin/
registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use smooth_operator::tool::ToolRegistry;
5
6use crate::plugin::Plugin;
7
8/// Registry that holds all loaded plugins.
9///
10/// Provides lifecycle management (init/shutdown), plugin lookup, and aggregation
11/// of routes and tools from all registered plugins.
12pub struct PluginRegistry {
13    plugins: HashMap<String, Arc<dyn Plugin>>,
14    /// Insertion order for deterministic iteration.
15    order: Vec<String>,
16}
17
18impl PluginRegistry {
19    /// Create an empty registry.
20    pub fn new() -> Self {
21        Self {
22            plugins: HashMap::new(),
23            order: vec![],
24        }
25    }
26
27    /// Register a plugin. Returns an error if a plugin with the same id already exists.
28    ///
29    /// # Errors
30    /// Returns an error if a plugin with a duplicate id is registered.
31    pub fn register(&mut self, plugin: Arc<dyn Plugin>) -> anyhow::Result<()> {
32        let id = plugin.id().to_string();
33        if self.plugins.contains_key(&id) {
34            anyhow::bail!("duplicate plugin id: {id}");
35        }
36        self.order.push(id.clone());
37        self.plugins.insert(id, plugin);
38        Ok(())
39    }
40
41    /// Get a plugin by id.
42    pub fn get(&self, id: &str) -> Option<&Arc<dyn Plugin>> {
43        self.plugins.get(id)
44    }
45
46    /// List all registered plugins in insertion order.
47    pub fn list(&self) -> Vec<&Arc<dyn Plugin>> {
48        self.order.iter().filter_map(|id| self.plugins.get(id)).collect()
49    }
50
51    /// Initialize all plugins in registration order.
52    ///
53    /// # Errors
54    /// Returns the first plugin initialization error encountered.
55    pub async fn init_all(&self) -> anyhow::Result<()> {
56        for id in &self.order {
57            if let Some(plugin) = self.plugins.get(id) {
58                plugin.init().await?;
59            }
60        }
61        Ok(())
62    }
63
64    /// Shutdown all plugins in reverse registration order.
65    ///
66    /// # Errors
67    /// Returns the first plugin shutdown error encountered.
68    pub async fn shutdown_all(&self) -> anyhow::Result<()> {
69        for id in self.order.iter().rev() {
70            if let Some(plugin) = self.plugins.get(id) {
71                plugin.shutdown().await?;
72            }
73        }
74        Ok(())
75    }
76
77    /// Collect all plugin routes into a single merged axum Router.
78    /// Each plugin's routes are nested under `/{plugin_id}`.
79    pub fn collect_routes(&self) -> axum::Router {
80        let mut router = axum::Router::new();
81        for id in &self.order {
82            if let Some(plugin) = self.plugins.get(id) {
83                if let Some(plugin_router) = plugin.routes() {
84                    router = router.nest(&format!("/{id}"), plugin_router);
85                }
86            }
87        }
88        router
89    }
90
91    /// Collect all plugin tools into a single `ToolRegistry`.
92    pub fn collect_tools(&self) -> ToolRegistry {
93        let mut registry = ToolRegistry::new();
94        for id in &self.order {
95            if let Some(plugin) = self.plugins.get(id) {
96                for tool in plugin.tools() {
97                    registry.register(tool);
98                }
99            }
100        }
101        registry
102    }
103}
104
105impl Default for PluginRegistry {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use axum::routing::get;
    use smooth_operator::tool::{Tool, ToolSchema};

    use super::*;
    use crate::command::PluginCommandBuilder;
    use crate::plugin::Plugin;

    /// A minimal test plugin with configurable behavior.
    struct TestPlugin {
        id: String,
        name: String,
        version: String,
        provide_routes: bool,
        provide_tools: bool,
        fail_init: bool,
        fail_shutdown: bool,
    }

    impl TestPlugin {
        fn new(id: &str) -> Self {
            Self {
                id: id.to_string(),
                name: format!("Test Plugin {id}"),
                version: "1.0.0".to_string(),
                provide_routes: false,
                provide_tools: false,
                fail_init: false,
                fail_shutdown: false,
            }
        }

        fn with_routes(mut self) -> Self {
            self.provide_routes = true;
            self
        }

        fn with_tools(mut self) -> Self {
            self.provide_tools = true;
            self
        }

        fn with_fail_init(mut self) -> Self {
            self.fail_init = true;
            self
        }

        fn with_fail_shutdown(mut self) -> Self {
            self.fail_shutdown = true;
            self
        }
    }

    #[async_trait]
    impl Plugin for TestPlugin {
        fn id(&self) -> &str {
            &self.id
        }
        fn name(&self) -> &str {
            &self.name
        }
        fn version(&self) -> &str {
            &self.version
        }

        async fn init(&self) -> anyhow::Result<()> {
            if self.fail_init {
                anyhow::bail!("init failed for {}", self.id);
            }
            Ok(())
        }

        async fn shutdown(&self) -> anyhow::Result<()> {
            if self.fail_shutdown {
                anyhow::bail!("shutdown failed for {}", self.id);
            }
            Ok(())
        }

        fn commands(&self) -> Vec<crate::command::PluginCommand> {
            vec![PluginCommandBuilder::new("test-cmd").description("A test command").build()]
        }

        fn routes(&self) -> Option<axum::Router> {
            if self.provide_routes {
                Some(axum::Router::new().route("/health", get(|| async { "ok" })))
            } else {
                None
            }
        }

        fn tools(&self) -> Vec<Box<dyn Tool>> {
            if self.provide_tools {
                vec![Box::new(DummyTool {
                    name: format!("{}-tool", self.id),
                })]
            } else {
                vec![]
            }
        }
    }

    struct DummyTool {
        name: String,
    }

    #[async_trait]
    impl Tool for DummyTool {
        fn schema(&self) -> ToolSchema {
            ToolSchema {
                name: self.name.clone(),
                description: "A dummy tool".into(),
                parameters: serde_json::json!({"type": "object"}),
            }
        }

        async fn execute(&self, _arguments: serde_json::Value) -> anyhow::Result<String> {
            Ok("done".into())
        }
    }

    // --- Plugin trait default implementation tests ---

    struct MinimalPlugin;

    #[async_trait]
    impl Plugin for MinimalPlugin {
        fn id(&self) -> &str {
            "minimal"
        }
        fn name(&self) -> &str {
            "Minimal"
        }
        fn version(&self) -> &str {
            "0.1.0"
        }
    }

    #[tokio::test]
    async fn plugin_default_init_succeeds() {
        let p = MinimalPlugin;
        assert!(p.init().await.is_ok());
    }

    #[tokio::test]
    async fn plugin_default_shutdown_succeeds() {
        let p = MinimalPlugin;
        assert!(p.shutdown().await.is_ok());
    }

    #[test]
    fn plugin_default_commands_empty() {
        let p = MinimalPlugin;
        assert!(p.commands().is_empty());
    }

    #[test]
    fn plugin_default_routes_none() {
        let p = MinimalPlugin;
        assert!(p.routes().is_none());
    }

    #[test]
    fn plugin_default_tools_empty() {
        let p = MinimalPlugin;
        assert!(p.tools().is_empty());
    }

    // --- PluginCommand builder tests ---

    #[test]
    fn command_builder_basic() {
        let cmd = PluginCommandBuilder::new("deploy").description("Deploy something").build();
        assert_eq!(cmd.name, "deploy");
        assert_eq!(cmd.description, "Deploy something");
        assert!(cmd.handler.is_none());
        assert!(cmd.subcommands.is_empty());
    }

    #[tokio::test]
    async fn command_builder_with_handler() {
        let cmd = PluginCommandBuilder::new("greet")
            .description("Say hello")
            .handler(|args| async move {
                assert_eq!(args, vec!["world".to_string()]);
                Ok(())
            })
            .build();

        let handler = cmd.handler.as_ref().expect("handler should exist");
        let result = handler(vec!["world".into()]).await;
        assert!(result.is_ok());
    }

    #[test]
    fn command_builder_with_subcommands() {
        let sub = PluginCommandBuilder::new("sub").description("A subcommand").build();
        let cmd = PluginCommandBuilder::new("parent").subcommand(sub).build();
        assert_eq!(cmd.subcommands.len(), 1);
        assert_eq!(cmd.subcommands[0].name, "sub");
    }

    // --- PluginRegistry tests ---

    #[test]
    fn registry_register_and_get() {
        let mut reg = PluginRegistry::new();
        let plugin = Arc::new(TestPlugin::new("alpha"));
        reg.register(plugin).expect("register should succeed");

        let found = reg.get("alpha").expect("plugin should be registered");
        assert_eq!(found.id(), "alpha");
    }

    #[test]
    fn registry_duplicate_id_rejected() {
        let mut reg = PluginRegistry::new();
        reg.register(Arc::new(TestPlugin::new("dup"))).expect("first register ok");

        let err = reg
            .register(Arc::new(TestPlugin::new("dup")))
            .expect_err("duplicate id should be rejected");
        assert!(err.to_string().contains("duplicate plugin id"));

        // A rejected registration must not leave a second entry behind.
        assert_eq!(reg.list().len(), 1);
        assert!(reg.get("dup").is_some());
    }

    #[test]
    fn registry_list_returns_insertion_order() {
        let mut reg = PluginRegistry::new();
        reg.register(Arc::new(TestPlugin::new("b"))).expect("ok");
        reg.register(Arc::new(TestPlugin::new("a"))).expect("ok");
        reg.register(Arc::new(TestPlugin::new("c"))).expect("ok");

        let ids: Vec<&str> = reg.list().iter().map(|p| p.id()).collect();
        assert_eq!(ids, vec!["b", "a", "c"]);
    }

    #[tokio::test]
    async fn registry_init_all_succeeds() {
        let mut reg = PluginRegistry::new();
        reg.register(Arc::new(TestPlugin::new("one"))).expect("ok");
        reg.register(Arc::new(TestPlugin::new("two"))).expect("ok");
        assert!(reg.init_all().await.is_ok());
    }

    #[tokio::test]
    async fn registry_init_all_propagates_error() {
        let mut reg = PluginRegistry::new();
        reg.register(Arc::new(TestPlugin::new("good"))).expect("ok");
        reg.register(Arc::new(TestPlugin::new("bad").with_fail_init())).expect("ok");

        let err = reg.init_all().await.expect_err("init_all should fail");
        assert!(err.to_string().contains("init failed"));
    }

    #[tokio::test]
    async fn registry_shutdown_all_succeeds() {
        let mut reg = PluginRegistry::new();
        reg.register(Arc::new(TestPlugin::new("one"))).expect("ok");
        reg.register(Arc::new(TestPlugin::new("two"))).expect("ok");
        assert!(reg.shutdown_all().await.is_ok());
    }

    #[tokio::test]
    async fn registry_shutdown_all_propagates_error() {
        let mut reg = PluginRegistry::new();
        reg.register(Arc::new(TestPlugin::new("good"))).expect("ok");
        reg.register(Arc::new(TestPlugin::new("bad").with_fail_shutdown())).expect("ok");

        let err = reg.shutdown_all().await.expect_err("shutdown_all should fail");
        assert!(err.to_string().contains("shutdown failed"));
    }

    #[test]
    fn registry_collect_routes() {
        let mut reg = PluginRegistry::new();
        reg.register(Arc::new(TestPlugin::new("api1").with_routes())).expect("ok");
        reg.register(Arc::new(TestPlugin::new("api2").with_routes())).expect("ok");
        reg.register(Arc::new(TestPlugin::new("no-routes"))).expect("ok");

        // Smoke test: nesting each plugin's router under its id must not panic.
        let _router = reg.collect_routes();
    }

    #[test]
    fn registry_collect_tools() {
        let mut reg = PluginRegistry::new();
        reg.register(Arc::new(TestPlugin::new("t1").with_tools())).expect("ok");
        reg.register(Arc::new(TestPlugin::new("t2").with_tools())).expect("ok");
        reg.register(Arc::new(TestPlugin::new("no-tools"))).expect("ok");

        let tool_reg = reg.collect_tools();
        let schemas = tool_reg.schemas();
        assert_eq!(schemas.len(), 2);

        let names: Vec<&str> = schemas.iter().map(|s| s.name.as_str()).collect();
        assert!(names.contains(&"t1-tool"));
        assert!(names.contains(&"t2-tool"));
    }

    #[test]
    fn registry_get_nonexistent() {
        let reg = PluginRegistry::new();
        assert!(reg.get("missing").is_none());
    }
}