Skip to main content

ubt_cli/plugin/
mod.rs

1pub mod declarative;
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use crate::error::{Result, UbtError};
7
8// ── Data Model ──────────────────────────────────────────────────────────
9
/// Detection settings for a plugin, read from the `[detect]` table of its TOML.
#[derive(Debug, Clone)]
pub struct DetectConfig {
    /// Marker file names for this ecosystem (e.g. `go.mod` for the built-in Go
    /// plugin). Presumably matched against a project directory by detection
    /// logic outside this file — TODO confirm.
    pub files: Vec<String>,
}
14
/// One concrete tool flavour of a plugin (e.g. `npm` vs `pnpm` for Node).
#[derive(Debug, Clone)]
pub struct Variant {
    /// Files whose presence indicates this variant (e.g. `pnpm-lock.yaml`).
    /// Presumably consulted by detection code elsewhere — TODO confirm.
    pub detect_files: Vec<String>,
    /// Executable name copied into `ResolvedPlugin::binary` by
    /// `Plugin::resolve_variant`.
    pub binary: String,
}
20
/// How a generic flag maps onto a plugin's native tooling.
#[derive(Debug, Clone, PartialEq)]
pub enum FlagTranslation {
    /// The native flag text to use instead (e.g. `--coverage`).
    Translation(String),
    /// The plugin declares that it has no equivalent for this flag.
    Unsupported,
}
26
/// Where a plugin definition was loaded from.
#[derive(Debug, Clone, PartialEq)]
pub enum PluginSource {
    /// Embedded in the binary via the `BUILTIN_*` `include_str!` constants.
    BuiltIn,
    /// Loaded from disk. Plugins registered by `PluginRegistry::load_dir`
    /// carry the path of their individual TOML file here.
    File(PathBuf),
}
32
/// A parsed plugin definition, before a specific variant has been chosen.
///
/// Use [`Plugin::resolve_variant`] to obtain a [`ResolvedPlugin`] for one of
/// its `variants`.
#[derive(Debug, Clone)]
pub struct Plugin {
    /// Unique plugin name; also the key in `PluginRegistry`.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Optional project homepage URL.
    pub homepage: Option<String>,
    /// Optional installation guidance (a URL in the test fixture).
    pub install_help: Option<String>,
    /// Not read in this file; presumably used to rank plugins during detection
    /// — TODO confirm semantics (higher vs lower wins).
    pub priority: i32,
    /// Name of the variant to use when none is detected; presumably a key of
    /// `variants` — not validated here.
    pub default_variant: String,
    /// Ecosystem-level detection settings.
    pub detect: DetectConfig,
    /// Available variants, keyed by variant name.
    pub variants: HashMap<String, Variant>,
    /// Base command templates keyed by command name. Templates may contain
    /// placeholders such as `{{tool}}` / `{{args}}`; expansion happens elsewhere.
    pub commands: HashMap<String, String>,
    /// Per-variant command overrides: variant name -> (command name -> template).
    /// These replace entries of `commands` in `resolve_variant`.
    pub command_variants: HashMap<String, HashMap<String, String>>,
    /// Flag translations: command name -> (flag name -> translation).
    pub flags: HashMap<String, HashMap<String, FlagTranslation>>,
    /// Commands this plugin does not support, keyed by command name (e.g.
    /// `dep.why`); the value appears to be guidance text for the user.
    pub unsupported: HashMap<String, String>,
}
48
/// A plugin bound to one variant, produced by [`Plugin::resolve_variant`].
#[derive(Debug, Clone)]
pub struct ResolvedPlugin {
    /// Plugin name (copied from the `Plugin`).
    pub name: String,
    /// Plugin description (copied from the `Plugin`).
    pub description: String,
    /// Optional homepage URL (copied from the `Plugin`).
    pub homepage: Option<String>,
    /// Optional install guidance (copied from the `Plugin`).
    pub install_help: Option<String>,
    /// Name of the variant that was selected.
    pub variant_name: String,
    /// Executable of the selected variant.
    pub binary: String,
    /// Base command templates with the variant's overrides applied.
    pub commands: HashMap<String, String>,
    /// Flag translations (copied unchanged; not variant-specific).
    pub flags: HashMap<String, HashMap<String, FlagTranslation>>,
    /// Unsupported-command hints (copied unchanged).
    pub unsupported: HashMap<String, String>,
    /// Where the plugin definition came from.
    pub source: PluginSource,
}
62
63impl Plugin {
64    pub fn resolve_variant(
65        &self,
66        variant_name: &str,
67        source: PluginSource,
68    ) -> Result<ResolvedPlugin> {
69        let variant = self
70            .variants
71            .get(variant_name)
72            .ok_or_else(|| UbtError::PluginLoadError {
73                name: self.name.clone(),
74                detail: format!("variant '{}' not found", variant_name),
75            })?;
76
77        // Start with base commands, then overlay variant-specific overrides
78        let mut commands = self.commands.clone();
79        if let Some(overrides) = self.command_variants.get(variant_name) {
80            for (cmd, mapping) in overrides {
81                commands.insert(cmd.clone(), mapping.clone());
82            }
83        }
84
85        Ok(ResolvedPlugin {
86            name: self.name.clone(),
87            description: self.description.clone(),
88            homepage: self.homepage.clone(),
89            install_help: self.install_help.clone(),
90            variant_name: variant_name.to_string(),
91            binary: variant.binary.clone(),
92            commands,
93            flags: self.flags.clone(),
94            unsupported: self.unsupported.clone(),
95            source,
96        })
97    }
98}
99
100// ── Built-in Plugin Data ────────────────────────────────────────────────
101
// Built-in plugin definitions, embedded at compile time. `include_str!` paths
// are resolved relative to this source file, so the TOML files must exist at
// build time.
const BUILTIN_GO: &str = include_str!("../../plugins/go.toml");
const BUILTIN_NODE: &str = include_str!("../../plugins/node.toml");
const BUILTIN_PYTHON: &str = include_str!("../../plugins/python.toml");
const BUILTIN_RUST: &str = include_str!("../../plugins/rust.toml");
const BUILTIN_JAVA: &str = include_str!("../../plugins/java.toml");
const BUILTIN_DOTNET: &str = include_str!("../../plugins/dotnet.toml");
const BUILTIN_RUBY: &str = include_str!("../../plugins/ruby.toml");
const BUILTIN_PHP: &str = include_str!("../../plugins/php.toml");
const BUILTIN_CPP: &str = include_str!("../../plugins/cpp.toml");

/// All built-in plugin TOML sources, parsed by `PluginRegistry::new`.
///
/// The registry is keyed by plugin name, so if two entries declared the same
/// name, the later one in this list would replace the earlier one.
const BUILTIN_PLUGINS: &[&str] = &[
    BUILTIN_GO,
    BUILTIN_NODE,
    BUILTIN_PYTHON,
    BUILTIN_RUST,
    BUILTIN_JAVA,
    BUILTIN_DOTNET,
    BUILTIN_RUBY,
    BUILTIN_PHP,
    BUILTIN_CPP,
];
123
124// ── Plugin Registry ─────────────────────────────────────────────────────
125
/// Name-indexed collection of plugin definitions together with their origin.
///
/// Inserting a plugin whose name is already present replaces the previous
/// entry, which is how user, `UBT_PLUGIN_PATH` and project-local plugins
/// override built-ins.
#[derive(Debug)]
pub struct PluginRegistry {
    // plugin name -> (definition, where it was loaded from)
    plugins: HashMap<String, (Plugin, PluginSource)>,
}
130
131impl PluginRegistry {
132    /// Create a new registry loaded with built-in plugins.
133    pub fn new() -> Result<Self> {
134        let mut registry = Self {
135            plugins: HashMap::new(),
136        };
137
138        for toml_str in BUILTIN_PLUGINS {
139            let plugin = declarative::parse_plugin_toml(toml_str)?;
140            registry
141                .plugins
142                .insert(plugin.name.clone(), (plugin, PluginSource::BuiltIn));
143        }
144
145        Ok(registry)
146    }
147
148    /// Load plugins from a directory. Later entries override earlier ones by name.
149    pub fn load_dir(&mut self, dir: &Path, source: PluginSource) -> Result<()> {
150        if !dir.is_dir() {
151            return Ok(());
152        }
153        let mut entries: Vec<_> = std::fs::read_dir(dir)?
154            .filter_map(|e| e.ok())
155            .filter(|e| {
156                e.path()
157                    .extension()
158                    .map(|ext| ext == "toml")
159                    .unwrap_or(false)
160            })
161            .collect();
162        entries.sort_by_key(|e| e.file_name());
163
164        for entry in entries {
165            let content = std::fs::read_to_string(entry.path())?;
166            let plugin = declarative::parse_plugin_toml(&content).map_err(|_| {
167                UbtError::PluginLoadError {
168                    name: entry.path().display().to_string(),
169                    detail: "failed to parse plugin TOML".into(),
170                }
171            })?;
172            let file_source = match &source {
173                PluginSource::BuiltIn => PluginSource::BuiltIn,
174                PluginSource::File(_) => PluginSource::File(entry.path()),
175            };
176            self.plugins
177                .insert(plugin.name.clone(), (plugin, file_source));
178        }
179        Ok(())
180    }
181
182    /// Load all plugin sources in priority order (later overrides earlier):
183    /// 1. Built-in (already loaded in `new()`)
184    /// 2. User plugins: ~/.config/ubt/plugins/
185    /// 3. UBT_PLUGIN_PATH dirs
186    /// 4. Project-local: .ubt/plugins/
187    pub fn load_all(&mut self, project_root: Option<&Path>) -> Result<()> {
188        // User plugins
189        if let Some(config_dir) = dirs::config_dir() {
190            let user_dir = config_dir.join("ubt").join("plugins");
191            self.load_dir(&user_dir, PluginSource::File(user_dir.clone()))?;
192        }
193
194        // UBT_PLUGIN_PATH
195        if let Ok(plugin_path) = std::env::var("UBT_PLUGIN_PATH") {
196            for dir in plugin_path.split(':') {
197                let path = PathBuf::from(dir);
198                self.load_dir(&path, PluginSource::File(path.clone()))?;
199            }
200        }
201
202        // Project-local plugins
203        if let Some(root) = project_root {
204            let local_dir = root.join(".ubt").join("plugins");
205            self.load_dir(&local_dir, PluginSource::File(local_dir.clone()))?;
206        }
207
208        Ok(())
209    }
210
211    /// Get a plugin by name.
212    pub fn get(&self, name: &str) -> Option<&(Plugin, PluginSource)> {
213        self.plugins.get(name)
214    }
215
216    /// Iterate over all plugins.
217    pub fn iter(&self) -> impl Iterator<Item = (&String, &(Plugin, PluginSource))> {
218        self.plugins.iter()
219    }
220
221    /// Get all plugin names.
222    pub fn names(&self) -> Vec<&String> {
223        self.plugins.keys().collect()
224    }
225}
226
227// ── Tests ──────────────────────────────────────────────────────────────
228
#[cfg(test)]
mod tests {
    use super::*;

    // Hand-built Node-like plugin with two variants (npm, pnpm), a pnpm-only
    // override for `exec`, flag translations for `test`, and one unsupported
    // command. Shared fixture for the `resolve_variant` tests.
    fn make_test_plugin() -> Plugin {
        let mut variants = HashMap::new();
        variants.insert(
            "npm".to_string(),
            Variant {
                detect_files: vec!["package-lock.json".to_string()],
                binary: "npm".to_string(),
            },
        );
        variants.insert(
            "pnpm".to_string(),
            Variant {
                detect_files: vec!["pnpm-lock.yaml".to_string()],
                binary: "pnpm".to_string(),
            },
        );

        let mut commands = HashMap::new();
        commands.insert("test".to_string(), "{{tool}} test".to_string());
        commands.insert("build".to_string(), "{{tool}} run build".to_string());
        commands.insert("exec".to_string(), "npx {{args}}".to_string());

        // Only pnpm overrides `exec`; npm keeps the base template.
        let mut pnpm_overrides = HashMap::new();
        pnpm_overrides.insert("exec".to_string(), "pnpm dlx {{args}}".to_string());
        let mut command_variants = HashMap::new();
        command_variants.insert("pnpm".to_string(), pnpm_overrides);

        let mut test_flags = HashMap::new();
        test_flags.insert(
            "coverage".to_string(),
            FlagTranslation::Translation("--coverage".to_string()),
        );
        test_flags.insert(
            "watch".to_string(),
            FlagTranslation::Translation("--watchAll".to_string()),
        );
        let mut flags = HashMap::new();
        flags.insert("test".to_string(), test_flags);

        let mut unsupported = HashMap::new();
        unsupported.insert(
            "dep.why".to_string(),
            "Use 'npm explain' directly".to_string(),
        );

        Plugin {
            name: "node".to_string(),
            description: "Node.js projects".to_string(),
            homepage: Some("https://nodejs.org".to_string()),
            install_help: Some("https://nodejs.org/en/download/".to_string()),
            priority: 0,
            default_variant: "npm".to_string(),
            detect: DetectConfig {
                files: vec!["package.json".to_string()],
            },
            variants,
            commands,
            command_variants,
            flags,
            unsupported,
        }
    }

    // Overridden commands take the variant template; others keep the base one.
    #[test]
    fn resolve_variant_merges_overrides() {
        let plugin = make_test_plugin();
        let resolved = plugin
            .resolve_variant("pnpm", PluginSource::BuiltIn)
            .unwrap();
        assert_eq!(resolved.commands["exec"], "pnpm dlx {{args}}");
        assert_eq!(resolved.commands["test"], "{{tool}} test");
    }

    // An unknown variant name yields an error whose message says "not found".
    #[test]
    fn resolve_variant_unknown_returns_error() {
        let plugin = make_test_plugin();
        let result = plugin.resolve_variant("nonexistent", PluginSource::BuiltIn);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    // Flag translations are copied into the resolved plugin.
    #[test]
    fn resolve_variant_carries_flags() {
        let plugin = make_test_plugin();
        let resolved = plugin
            .resolve_variant("npm", PluginSource::BuiltIn)
            .unwrap();
        assert_eq!(
            resolved.flags["test"]["coverage"],
            FlagTranslation::Translation("--coverage".to_string())
        );
    }

    // Unsupported-command entries are copied into the resolved plugin.
    #[test]
    fn resolve_variant_carries_unsupported() {
        let plugin = make_test_plugin();
        let resolved = plugin
            .resolve_variant("npm", PluginSource::BuiltIn)
            .unwrap();
        assert!(resolved.unsupported.contains_key("dep.why"));
    }

    // ── Registry tests ──────────────────────────────────────────────────

    // Every embedded built-in TOML parses and registers under its name.
    #[test]
    fn registry_loads_builtin_plugins() {
        let registry = PluginRegistry::new().unwrap();
        assert!(registry.get("go").is_some());
        assert!(registry.get("node").is_some());
        assert!(registry.get("python").is_some());
        assert!(registry.get("rust").is_some());
        assert!(registry.get("java").is_some());
        assert!(registry.get("dotnet").is_some());
        assert!(registry.get("ruby").is_some());
        assert!(registry.get("php").is_some());
        assert!(registry.get("cpp").is_some());
    }

    // Pins the content of the shipped go.toml and its BuiltIn source tag.
    #[test]
    fn registry_builtin_go_has_correct_detect() {
        let registry = PluginRegistry::new().unwrap();
        let (plugin, source) = registry.get("go").unwrap();
        assert_eq!(plugin.detect.files, vec!["go.mod"]);
        assert_eq!(*source, PluginSource::BuiltIn);
    }

    // Pins the variant set declared in the shipped node.toml.
    #[test]
    fn registry_builtin_node_has_variants() {
        let registry = PluginRegistry::new().unwrap();
        let (plugin, _) = registry.get("node").unwrap();
        assert_eq!(plugin.variants.len(), 5);
        assert!(plugin.variants.contains_key("npm"));
        assert!(plugin.variants.contains_key("pnpm"));
        assert!(plugin.variants.contains_key("yarn"));
        assert!(plugin.variants.contains_key("bun"));
        assert!(plugin.variants.contains_key("deno"));
    }

    // A minimal TOML file in a directory is picked up by `load_dir`.
    #[test]
    fn registry_load_dir_adds_plugins() {
        let dir = tempfile::TempDir::new().unwrap();
        let toml_content = r#"
[plugin]
name = "custom"
[detect]
files = ["custom.txt"]
[variants.default]
binary = "custom"
"#;
        std::fs::write(dir.path().join("custom.toml"), toml_content).unwrap();

        let mut registry = PluginRegistry::new().unwrap();
        registry
            .load_dir(dir.path(), PluginSource::File(dir.path().to_path_buf()))
            .unwrap();

        assert!(registry.get("custom").is_some());
    }

    // A file plugin with a built-in's name replaces it and is tagged as File.
    #[test]
    fn registry_load_dir_overrides_builtin() {
        let dir = tempfile::TempDir::new().unwrap();
        let toml_content = r#"
[plugin]
name = "go"
description = "Custom Go"
[detect]
files = ["go.mod"]
[variants.go]
binary = "go"
"#;
        std::fs::write(dir.path().join("go.toml"), toml_content).unwrap();

        let mut registry = PluginRegistry::new().unwrap();
        registry
            .load_dir(dir.path(), PluginSource::File(dir.path().to_path_buf()))
            .unwrap();

        let (plugin, source) = registry.get("go").unwrap();
        assert_eq!(plugin.description, "Custom Go");
        assert!(matches!(source, PluginSource::File(_)));
    }

    // A missing directory is treated as "no plugins", not an error.
    #[test]
    fn registry_load_dir_nonexistent_is_ok() {
        let mut registry = PluginRegistry::new().unwrap();
        let result = registry.load_dir(
            Path::new("/nonexistent/path"),
            PluginSource::File(PathBuf::from("/nonexistent")),
        );
        assert!(result.is_ok());
    }

    // At least the nine built-ins are listed.
    #[test]
    fn registry_names_returns_all() {
        let registry = PluginRegistry::new().unwrap();
        let names = registry.names();
        assert!(names.len() >= 9);
    }
}