spec_ai_plugin/
loader.rs

1//! Plugin discovery and loading
2
3use crate::abi::{PluginModuleRef, PluginToolRef, PLUGIN_API_VERSION};
4use crate::error::PluginError;
5use abi_stable::library::RootModule;
6use anyhow::Result;
7use std::path::{Path, PathBuf};
8use tracing::{debug, error, info, warn};
9
/// Counters describing the outcome of a plugin-loading pass.
///
/// As filled in by [`PluginLoader::load_directory`], every candidate library
/// increments `total` and then exactly one of `loaded` or `failed`, so
/// `total == loaded + failed` once the pass has finished.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers and tests can check a
/// whole result with a single `assert_eq!` instead of field by field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct LoadStats {
    /// Total plugin files found
    pub total: usize,
    /// Successfully loaded plugins
    pub loaded: usize,
    /// Failed to load plugins
    pub failed: usize,
    /// Total tools loaded across all plugins
    pub tools_loaded: usize,
}
22
23/// A loaded plugin with its metadata
24pub struct LoadedPlugin {
25    /// Path to the plugin library
26    pub path: PathBuf,
27    /// Plugin name
28    pub name: String,
29    /// Tools provided by this plugin
30    pub tools: Vec<PluginToolRef>,
31}
32
33impl std::fmt::Debug for LoadedPlugin {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("LoadedPlugin")
36            .field("path", &self.path)
37            .field("name", &self.name)
38            .field("tools_count", &self.tools.len())
39            .finish()
40    }
41}
42
43/// Plugin loader that discovers and loads plugin libraries
44pub struct PluginLoader {
45    plugins: Vec<LoadedPlugin>,
46}
47
48impl PluginLoader {
49    /// Create a new empty plugin loader
50    pub fn new() -> Self {
51        Self {
52            plugins: Vec::new(),
53        }
54    }
55
56    /// Load all plugins from a directory
57    ///
58    /// Scans the directory for dynamic library files (.dylib on macOS, .so on Linux,
59    /// .dll on Windows) and attempts to load each one as a plugin.
60    ///
61    /// # Arguments
62    /// * `dir` - Directory to scan for plugins
63    ///
64    /// # Returns
65    /// Statistics about the loading process
66    pub fn load_directory(&mut self, dir: &Path) -> Result<LoadStats> {
67        let mut stats = LoadStats::default();
68
69        if !dir.exists() {
70            info!("Plugin directory does not exist: {}", dir.display());
71            return Ok(stats);
72        }
73
74        if !dir.is_dir() {
75            return Err(PluginError::NotADirectory(dir.to_path_buf()).into());
76        }
77
78        info!("Scanning plugin directory: {}", dir.display());
79
80        for entry in walkdir::WalkDir::new(dir)
81            .max_depth(1)
82            .into_iter()
83            .filter_map(|e| e.ok())
84        {
85            let path = entry.path();
86
87            if !Self::is_plugin_library(path) {
88                continue;
89            }
90
91            stats.total += 1;
92
93            match self.load_plugin(path) {
94                Ok(tool_count) => {
95                    stats.loaded += 1;
96                    stats.tools_loaded += tool_count;
97                    info!(
98                        "Loaded plugin: {} ({} tools)",
99                        path.display(),
100                        tool_count
101                    );
102                }
103                Err(e) => {
104                    stats.failed += 1;
105                    error!("Failed to load plugin {}: {}", path.display(), e);
106                }
107            }
108        }
109
110        Ok(stats)
111    }
112
113    /// Load a single plugin from a file
114    fn load_plugin(&mut self, path: &Path) -> Result<usize> {
115        debug!("Loading plugin from: {}", path.display());
116
117        // Load the root module using abi_stable
118        let module = PluginModuleRef::load_from_file(path).map_err(|e| PluginError::LoadFailed {
119            path: path.to_path_buf(),
120            message: e.to_string(),
121        })?;
122
123        // Check API version compatibility
124        let plugin_version = (module.api_version())();
125        if plugin_version != PLUGIN_API_VERSION {
126            return Err(PluginError::VersionMismatch {
127                expected: PLUGIN_API_VERSION,
128                found: plugin_version,
129                path: path.to_path_buf(),
130            }
131            .into());
132        }
133
134        let plugin_name = (module.plugin_name())().to_string();
135        debug!("Plugin '{}' passed version check", plugin_name);
136
137        // Check for duplicate plugin names
138        if self.plugins.iter().any(|p| p.name == plugin_name) {
139            return Err(PluginError::DuplicatePlugin(plugin_name).into());
140        }
141
142        // Get tools from the plugin
143        let tool_refs = (module.get_tools())();
144        let tool_count = tool_refs.len();
145
146        // Collect tool refs into a Vec
147        let tools: Vec<PluginToolRef> = tool_refs.into_iter().collect();
148
149        // Call initialize on each tool if it has one
150        for tool in &tools {
151            if let Some(init) = tool.initialize {
152                let context = "{}"; // Empty context for now
153                if !init(context.into()) {
154                    warn!(
155                        "Tool '{}' initialization failed",
156                        (tool.info)().name.as_str()
157                    );
158                }
159            }
160        }
161
162        self.plugins.push(LoadedPlugin {
163            path: path.to_path_buf(),
164            name: plugin_name,
165            tools,
166        });
167
168        Ok(tool_count)
169    }
170
171    /// Check if a path is a plugin library based on extension
172    fn is_plugin_library(path: &Path) -> bool {
173        if !path.is_file() {
174            return false;
175        }
176
177        let Some(ext) = path.extension() else {
178            return false;
179        };
180
181        #[cfg(target_os = "macos")]
182        let expected = "dylib";
183
184        #[cfg(target_os = "linux")]
185        let expected = "so";
186
187        #[cfg(target_os = "windows")]
188        let expected = "dll";
189
190        #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
191        let expected = "so"; // Default to .so for unknown platforms
192
193        ext == expected
194    }
195
196    /// Get all loaded plugins
197    pub fn plugins(&self) -> &[LoadedPlugin] {
198        &self.plugins
199    }
200
201    /// Get all tools from all loaded plugins as an iterator
202    pub fn all_tools(&self) -> impl Iterator<Item = (PluginToolRef, &str)> {
203        self.plugins.iter().flat_map(|p| {
204            p.tools.iter().map(move |t| (*t, p.name.as_str()))
205        })
206    }
207
208    /// Get the number of loaded plugins
209    pub fn plugin_count(&self) -> usize {
210        self.plugins.len()
211    }
212
213    /// Get the total number of tools across all plugins
214    pub fn tool_count(&self) -> usize {
215        self.plugins.iter().map(|p| p.tools.len()).sum()
216    }
217}
218
219impl Default for PluginLoader {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
/// Expand a leading tilde (`~`) in a path to the user's home directory
///
/// Only a first path component that is exactly `~` is expanded, so both `~`
/// and `~/sub/dir` are handled while `~user/...` is left untouched. The match
/// is done on path components rather than on the string form, so it also
/// works for paths that are not valid UTF-8 and for the platform's native
/// separator.
///
/// The path is returned unchanged when it does not start with `~` or when the
/// home directory cannot be determined.
pub fn expand_tilde(path: &Path) -> PathBuf {
    let Ok(rest) = path.strip_prefix("~") else {
        return path.to_path_buf();
    };
    let Some(home) = dirs_home() else {
        return path.to_path_buf();
    };

    if rest.as_os_str().is_empty() {
        // Bare `~` (or `~/`): return the home directory itself rather than
        // joining an empty path, which would append a trailing separator.
        home
    } else {
        home.join(rest)
    }
}

/// Get the user's home directory
///
/// Reads `USERPROFILE` on Windows and `HOME` elsewhere. Returns `None` when
/// the variable is unset or empty; an empty value would otherwise turn
/// `~/x` into the relative path `x`. The value is read as an `OsString`, so a
/// home directory that is not valid UTF-8 is still returned.
fn dirs_home() -> Option<PathBuf> {
    let variable = if cfg!(target_os = "windows") {
        "USERPROFILE"
    } else {
        "HOME"
    };

    std::env::var_os(variable)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Dynamic-library extension for the platform the tests run on.
    fn platform_extension() -> &'static str {
        if cfg!(target_os = "macos") {
            "dylib"
        } else if cfg!(target_os = "windows") {
            "dll"
        } else {
            "so"
        }
    }

    /// Create an empty file in the system temp directory with the given
    /// extension. The process id keeps the name unique across test runs.
    fn temp_file(extension: &str) -> PathBuf {
        let name = format!(
            "spec_ai_plugin_loader_test_{}.{}",
            std::process::id(),
            extension
        );
        let path = std::env::temp_dir().join(name);
        std::fs::write(&path, b"").expect("temp dir should be writable");
        path
    }

    #[test]
    fn test_is_plugin_library() {
        // is_plugin_library requires the path to be an existing file, so
        // paths that do not exist are rejected whatever their extension.
        assert!(!PluginLoader::is_plugin_library(Path::new(
            "/tmp/nonexistent/libplugin.dylib"
        )));
        assert!(!PluginLoader::is_plugin_library(Path::new(
            "/tmp/test/plugin.txt"
        )));
        assert!(!PluginLoader::is_plugin_library(Path::new(
            "/tmp/test/plugin"
        )));
    }

    #[test]
    fn test_is_plugin_library_existing_files() {
        let library = temp_file(platform_extension());
        let text = temp_file("txt");

        let library_accepted = PluginLoader::is_plugin_library(&library);
        let text_accepted = PluginLoader::is_plugin_library(&text);
        let dir_accepted = PluginLoader::is_plugin_library(&std::env::temp_dir());

        // Clean up before asserting so a failed assertion leaves no files behind.
        let _ = std::fs::remove_file(&library);
        let _ = std::fs::remove_file(&text);

        assert!(library_accepted);
        assert!(!text_accepted);
        assert!(!dir_accepted);
    }

    #[test]
    fn test_expand_tilde() {
        let expanded = expand_tilde(Path::new("~/test"));
        match dirs_home() {
            Some(home) => assert_eq!(expanded, home.join("test")),
            // Without a home directory the path must come back untouched.
            None => assert_eq!(expanded, Path::new("~/test")),
        }

        // A tilde that is only part of the first component is not expanded
        let other_user = expand_tilde(Path::new("~user/test"));
        assert_eq!(other_user, Path::new("~user/test"));

        // Non-tilde paths should be unchanged
        let absolute = expand_tilde(Path::new("/absolute/path"));
        assert_eq!(absolute, Path::new("/absolute/path"));
    }

    #[test]
    fn test_load_stats_default() {
        let stats = LoadStats::default();
        assert_eq!(stats.total, 0);
        assert_eq!(stats.loaded, 0);
        assert_eq!(stats.failed, 0);
        assert_eq!(stats.tools_loaded, 0);
    }
}