Skip to main content

shape_runtime/plugins/
loader.rs

1//! Plugin Loader
2//!
3//! Handles dynamic loading of plugin shared libraries using libloading.
4
5use std::collections::HashMap;
6use std::ffi::CStr;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9
10use libloading::{Library, Symbol};
11
12use shape_abi_v1::{
13    ABI_VERSION, CAPABILITY_DATA_SOURCE, CAPABILITY_LANGUAGE_RUNTIME, CAPABILITY_MODULE,
14    CAPABILITY_OUTPUT_SINK, CapabilityKind, CapabilityManifest, DataSourceVTable, GetAbiVersionFn,
15    GetCapabilityManifestFn, GetCapabilityVTableFn, GetClaimedSectionsFn, GetPluginInfoFn,
16    LanguageRuntimeVTable, ModuleVTable, OutputSinkVTable, PluginType, SectionsManifest,
17};
18
19use shape_ast::error::{Result, ShapeError};
20
/// A TOML section that a loaded plugin claims ownership of.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ClaimedSection {
    /// Name of the claimed section, for example `"native-dependencies"`.
    pub name: String,
    /// When `true` the section is mandatory and its absence is an error.
    pub required: bool,
}
29
30/// One declared capability exposed by a loaded plugin.
31#[derive(Debug, Clone, PartialEq, Eq)]
32pub struct PluginCapability {
33    /// Capability family.
34    pub kind: CapabilityKind,
35    /// Contract name (e.g., `shape.datasource`).
36    pub contract: String,
37    /// Contract version (e.g., `1`).
38    pub version: String,
39    /// Reserved capability flags.
40    pub flags: u64,
41}
42
43/// Information about a loaded plugin
44#[derive(Debug, Clone)]
45pub struct LoadedPlugin {
46    /// Plugin name
47    pub name: String,
48    /// Plugin version
49    pub version: String,
50    /// Plugin type
51    pub plugin_type: PluginType,
52    /// Plugin description
53    pub description: String,
54    /// Self-declared capability contracts.
55    pub capabilities: Vec<PluginCapability>,
56    /// TOML sections claimed by this plugin.
57    pub claimed_sections: Vec<ClaimedSection>,
58}
59
60impl LoadedPlugin {
61    /// Returns true if the plugin declares at least one capability with `kind`.
62    pub fn has_capability_kind(&self, kind: CapabilityKind) -> bool {
63        self.capabilities.iter().any(|cap| cap.kind == kind)
64    }
65
66    /// Returns the names of all claimed sections.
67    pub fn claimed_section_names(&self) -> Vec<&str> {
68        self.claimed_sections
69            .iter()
70            .map(|s| s.name.as_str())
71            .collect()
72    }
73}
74
75/// Plugin Loader
76///
77/// Manages dynamic loading and unloading of Shape plugins.
78/// Keeps loaded libraries in memory to prevent unloading while in use.
79pub struct PluginLoader {
80    /// Loaded libraries (kept alive to prevent unloading)
81    loaded_libraries: HashMap<String, Library>,
82}
83
84impl PluginLoader {
85    /// Create a new plugin loader
86    pub fn new() -> Self {
87        Self {
88            loaded_libraries: HashMap::new(),
89        }
90    }
91
92    /// Load a plugin from a shared library file
93    ///
94    /// # Arguments
95    /// * `path` - Path to the shared library (.so, .dll, .dylib)
96    ///
97    /// # Returns
98    /// Information about the loaded plugin
99    ///
100    /// # Safety
101    /// Loading plugins executes arbitrary code. Only load from trusted sources.
102    pub fn load(&mut self, path: &Path) -> Result<LoadedPlugin> {
103        // Load the library
104        let lib =
105            load_library_with_python_fallback(path).map_err(|e| ShapeError::RuntimeError {
106                message: format!("Failed to load plugin library '{}': {}", path.display(), e),
107                location: None,
108            })?;
109
110        // Check ABI version if available
111        if let Ok(get_version) = unsafe { lib.get::<GetAbiVersionFn>(b"shape_abi_version") } {
112            let version = unsafe { get_version() };
113            if version != ABI_VERSION {
114                return Err(ShapeError::RuntimeError {
115                    message: format!(
116                        "Plugin ABI version mismatch: expected {}, got {}",
117                        ABI_VERSION, version
118                    ),
119                    location: None,
120                });
121            }
122        }
123
124        // Get plugin info
125        let get_info: Symbol<GetPluginInfoFn> = unsafe {
126            lib.get(b"shape_plugin_info")
127                .map_err(|e| ShapeError::RuntimeError {
128                    message: format!("Plugin missing 'shape_plugin_info' export: {}", e),
129                    location: None,
130                })?
131        };
132
133        let info_ptr = unsafe { get_info() };
134        if info_ptr.is_null() {
135            return Err(ShapeError::RuntimeError {
136                message: "Plugin returned null PluginInfo".to_string(),
137                location: None,
138            });
139        }
140
141        let info = unsafe { &*info_ptr };
142
143        // Extract info strings
144        let name = read_c_string(info.name, "PluginInfo.name")?;
145        let version = read_c_string(info.version, "PluginInfo.version")?;
146        let description = read_c_string(info.description, "PluginInfo.description")?;
147
148        let capabilities = self.load_capabilities(&lib)?;
149
150        // Load optional section claims
151        let claimed_sections = if let Ok(get_sections) =
152            unsafe { lib.get::<GetClaimedSectionsFn>(b"shape_claimed_sections") }
153        {
154            let manifest_ptr = unsafe { get_sections() };
155            if manifest_ptr.is_null() {
156                vec![]
157            } else {
158                let manifest = unsafe { &*manifest_ptr };
159                parse_sections_manifest(manifest)?
160            }
161        } else {
162            vec![] // Optional — no section claims
163        };
164
165        // Store the library
166        self.loaded_libraries.insert(name.clone(), lib);
167
168        Ok(LoadedPlugin {
169            name,
170            version,
171            plugin_type: info.plugin_type,
172            description,
173            capabilities,
174            claimed_sections,
175        })
176    }
177
178    fn load_capabilities(&self, lib: &Library) -> Result<Vec<PluginCapability>> {
179        let get_manifest =
180            unsafe { lib.get::<GetCapabilityManifestFn>(b"shape_capability_manifest") }.map_err(
181                |e| ShapeError::RuntimeError {
182                    message: format!(
183                        "Plugin missing required 'shape_capability_manifest' export: {}",
184                        e
185                    ),
186                    location: None,
187                },
188            )?;
189
190        let manifest_ptr = unsafe { get_manifest() };
191        if manifest_ptr.is_null() {
192            return Err(ShapeError::RuntimeError {
193                message: "Plugin returned null CapabilityManifest".to_string(),
194                location: None,
195            });
196        }
197        let manifest = unsafe { &*manifest_ptr };
198        parse_capability_manifest(manifest)
199    }
200
201    /// Get the data source vtable for a loaded plugin
202    ///
203    /// # Arguments
204    /// * `name` - Name of the loaded plugin
205    ///
206    /// # Returns
207    /// The DataSourceVTable if plugin exists and is a data source
208    pub fn get_data_source_vtable(&self, name: &str) -> Result<&'static DataSourceVTable> {
209        let lib = self
210            .loaded_libraries
211            .get(name)
212            .ok_or_else(|| ShapeError::RuntimeError {
213                message: format!("Plugin '{}' not loaded", name),
214                location: None,
215            })?;
216
217        if let Some(vtable_ptr) = try_capability_vtable(lib, CAPABILITY_DATA_SOURCE)? {
218            // SAFETY: vtable pointer is provided by the loaded module and expected static.
219            return Ok(unsafe { &*(vtable_ptr as *const DataSourceVTable) });
220        }
221
222        Err(ShapeError::RuntimeError {
223            message: format!(
224                "Plugin '{}' does not provide capability vtable for '{}'",
225                name, CAPABILITY_DATA_SOURCE
226            ),
227            location: None,
228        })
229    }
230
231    /// Get the output sink vtable for a loaded plugin
232    ///
233    /// # Arguments
234    /// * `name` - Name of the loaded plugin
235    ///
236    /// # Returns
237    /// The OutputSinkVTable if plugin exists and is an output sink
238    pub fn get_output_sink_vtable(&self, name: &str) -> Result<&'static OutputSinkVTable> {
239        let lib = self
240            .loaded_libraries
241            .get(name)
242            .ok_or_else(|| ShapeError::RuntimeError {
243                message: format!("Plugin '{}' not loaded", name),
244                location: None,
245            })?;
246
247        if let Some(vtable_ptr) = try_capability_vtable(lib, CAPABILITY_OUTPUT_SINK)? {
248            // SAFETY: vtable pointer is provided by the loaded module and expected static.
249            return Ok(unsafe { &*(vtable_ptr as *const OutputSinkVTable) });
250        }
251
252        Err(ShapeError::RuntimeError {
253            message: format!(
254                "Plugin '{}' does not provide capability vtable for '{}'",
255                name, CAPABILITY_OUTPUT_SINK
256            ),
257            location: None,
258        })
259    }
260
261    /// Get the base module vtable for a loaded plugin.
262    pub fn get_module_vtable(&self, name: &str) -> Result<&'static ModuleVTable> {
263        let lib = self
264            .loaded_libraries
265            .get(name)
266            .ok_or_else(|| ShapeError::RuntimeError {
267                message: format!("Plugin '{}' not loaded", name),
268                location: None,
269            })?;
270
271        if let Some(vtable_ptr) = try_capability_vtable(lib, CAPABILITY_MODULE)? {
272            // SAFETY: vtable pointer is provided by the loaded module and expected static.
273            return Ok(unsafe { &*(vtable_ptr as *const ModuleVTable) });
274        }
275
276        Err(ShapeError::RuntimeError {
277            message: format!(
278                "Plugin '{}' does not provide capability vtable for '{}'",
279                name, CAPABILITY_MODULE
280            ),
281            location: None,
282        })
283    }
284
285    /// Get the language runtime vtable for a loaded plugin.
286    pub fn get_language_runtime_vtable(
287        &self,
288        name: &str,
289    ) -> Result<&'static LanguageRuntimeVTable> {
290        let lib = self
291            .loaded_libraries
292            .get(name)
293            .ok_or_else(|| ShapeError::RuntimeError {
294                message: format!("Plugin '{}' not loaded", name),
295                location: None,
296            })?;
297
298        if let Some(vtable_ptr) = try_capability_vtable(lib, CAPABILITY_LANGUAGE_RUNTIME)? {
299            return Ok(unsafe { &*(vtable_ptr as *const LanguageRuntimeVTable) });
300        }
301
302        Err(ShapeError::RuntimeError {
303            message: format!(
304                "Plugin '{}' does not provide capability vtable for '{}'",
305                name, CAPABILITY_LANGUAGE_RUNTIME
306            ),
307            location: None,
308        })
309    }
310
311    /// Unload a plugin
312    ///
313    /// Note: The library is actually unloaded when dropped. This removes it
314    /// from the loader's tracking.
315    pub fn unload(&mut self, name: &str) -> bool {
316        self.loaded_libraries.remove(name).is_some()
317    }
318
319    /// List all loaded plugins
320    pub fn loaded_plugins(&self) -> Vec<&str> {
321        self.loaded_libraries.keys().map(|s| s.as_str()).collect()
322    }
323
324    /// Check if a plugin is loaded
325    pub fn is_loaded(&self, name: &str) -> bool {
326        self.loaded_libraries.contains_key(name)
327    }
328
329    /// Load a data source plugin and return a ready-to-use wrapper
330    ///
331    /// This is a convenience method that combines loading the library,
332    /// getting the vtable, and creating the PluginDataSource wrapper.
333    ///
334    /// # Arguments
335    /// * `path` - Path to the shared library
336    /// * `config` - Configuration value for the plugin
337    ///
338    /// # Returns
339    /// Ready-to-use PluginDataSource wrapper
340    pub fn load_data_source(
341        &mut self,
342        path: &Path,
343        config: &serde_json::Value,
344    ) -> Result<super::PluginDataSource> {
345        // Load the library and get info
346        let info = self.load(path)?;
347        let name = info.name.clone();
348
349        if !info.has_capability_kind(CapabilityKind::DataSource) {
350            return Err(ShapeError::RuntimeError {
351                message: format!(
352                    "Plugin '{}' does not declare data source capability",
353                    info.name
354                ),
355                location: None,
356            });
357        }
358
359        // Get the vtable
360        let vtable = self.get_data_source_vtable(&name)?;
361
362        // Create and return the wrapper
363        super::PluginDataSource::new(name, vtable, config)
364    }
365}
366
367fn load_library_with_python_fallback(path: &Path) -> std::result::Result<Library, String> {
368    let initial = unsafe { Library::new(path) };
369    let initial_error = match initial {
370        Ok(lib) => return Ok(lib),
371        Err(err) => err,
372    };
373    let initial_msg = initial_error.to_string();
374
375    if !should_try_python_fallback(&initial_msg) {
376        return Err(initial_msg);
377    }
378
379    if !preload_python_shared_library() {
380        return Err(initial_msg);
381    }
382
383    match unsafe { Library::new(path) } {
384        Ok(lib) => Ok(lib),
385        Err(retry_err) => Err(format!(
386            "{} (retry after python preload failed: {})",
387            initial_msg, retry_err
388        )),
389    }
390}
391
/// Decide whether a library-load error looks like an unresolved Python runtime.
///
/// Matches `libpython` and `python.framework` case-insensitively anywhere in the message.
fn should_try_python_fallback(error_message: &str) -> bool {
    let haystack = error_message.to_ascii_lowercase();
    ["libpython", "python.framework"]
        .into_iter()
        .any(|needle| haystack.contains(needle))
}
396
397fn preload_python_shared_library() -> bool {
398    let candidates = discover_python_shared_library_candidates();
399    for candidate in candidates {
400        match unsafe { Library::new(&candidate) } {
401            Ok(lib) => {
402                tracing::info!(
403                    "preloaded python runtime library for extension loading fallback: {}",
404                    candidate.display()
405                );
406                // Keep the library loaded for process lifetime.
407                std::mem::forget(lib);
408                return true;
409            }
410            Err(err) => {
411                tracing::debug!(
412                    "failed to preload python runtime candidate '{}': {}",
413                    candidate.display(),
414                    err
415                );
416            }
417        }
418    }
419    false
420}
421
/// Ask a Python interpreter where its shared runtime library lives.
///
/// The interpreter is taken from `PYO3_PYTHON` and defaults to `python3` when the
/// variable is unset, not valid UTF-8, or blank. The interpreter runs a small script
/// that prints candidate paths built from `sysconfig` (`LIBDIR`, `LDLIBRARY`) and from
/// `sys.base_prefix` / `sys.prefix`. Each path is resolved with `realpath`,
/// de-duplicated, and kept only if it exists.
///
/// Returns the candidates in the order printed. Returns an empty list if the
/// interpreter cannot be spawned or exits unsuccessfully.
fn discover_python_shared_library_candidates() -> Vec<PathBuf> {
    let python = std::env::var("PYO3_PYTHON")
        .ok()
        .filter(|value| !value.trim().is_empty())
        .unwrap_or_else(|| "python3".to_string());
    // `dict.fromkeys` de-duplicates the two prefixes while keeping a fixed order.
    // Iterating a `set` of strings instead would make the probe order (and therefore
    // which library gets preloaded first) vary between runs due to hash randomization.
    let script = r#"import os, sys, sysconfig
cands = []
libdir = sysconfig.get_config_var("LIBDIR")
ldlibrary = sysconfig.get_config_var("LDLIBRARY")
if libdir and ldlibrary:
    cands.append(os.path.join(libdir, ldlibrary))
if libdir:
    for name in ("libpython3.so", "libpython3.so.1.0", "libpython3.dylib"):
        cands.append(os.path.join(libdir, name))
for base in dict.fromkeys((sys.base_prefix, sys.prefix)):
    if not base:
        continue
    for rel in ("lib", "lib64"):
        d = os.path.join(base, rel)
        if ldlibrary:
            cands.append(os.path.join(d, ldlibrary))
seen = set()
for cand in cands:
    if not cand:
        continue
    real = os.path.realpath(cand)
    if real in seen:
        continue
    seen.add(real)
    if os.path.exists(real):
        print(real)
"#;

    let output = Command::new(&python).arg("-c").arg(script).output();
    let Ok(output) = output else {
        return Vec::new();
    };
    if !output.status.success() {
        return Vec::new();
    }

    String::from_utf8_lossy(&output.stdout)
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(PathBuf::from)
        .collect()
}
467
468impl Drop for PluginLoader {
469    fn drop(&mut self) {
470        // Language runtime extensions (e.g. Python/pyo3) may register process-level
471        // atexit handlers that reference code inside the loaded .so. If we dlclose
472        // the library before those handlers run at process exit, the process segfaults.
473        // Intentionally leak language runtime libraries so they remain mapped.
474        for (_name, lib) in self.loaded_libraries.drain() {
475            if let Ok(get_manifest) =
476                unsafe { lib.get::<GetCapabilityManifestFn>(b"shape_capability_manifest") }
477            {
478                let manifest_ptr = unsafe { get_manifest() };
479                if !manifest_ptr.is_null() {
480                    let manifest = unsafe { &*manifest_ptr };
481                    if let Ok(caps) = parse_capability_manifest(manifest) {
482                        if caps
483                            .iter()
484                            .any(|c| c.kind == CapabilityKind::LanguageRuntime)
485                        {
486                            // Leak: keep the library mapped for the process lifetime.
487                            std::mem::forget(lib);
488                            continue;
489                        }
490                    }
491                }
492            }
493            // Non-language-runtime libraries are dropped normally (dlclose).
494            drop(lib);
495        }
496    }
497}
498
499impl Default for PluginLoader {
500    fn default() -> Self {
501        Self::new()
502    }
503}
504
505fn try_capability_vtable(lib: &Library, contract: &str) -> Result<Option<*const std::ffi::c_void>> {
506    let get_vtable_fn = unsafe { lib.get::<GetCapabilityVTableFn>(b"shape_capability_vtable") };
507    let Ok(get_vtable_fn) = get_vtable_fn else {
508        return Ok(None);
509    };
510
511    let vtable_ptr = unsafe { get_vtable_fn(contract.as_ptr(), contract.len()) };
512    if vtable_ptr.is_null() {
513        return Ok(None);
514    }
515    Ok(Some(vtable_ptr))
516}
517
518fn parse_capability_manifest(manifest: &CapabilityManifest) -> Result<Vec<PluginCapability>> {
519    if manifest.capabilities_len == 0 {
520        return Err(ShapeError::RuntimeError {
521            message: "CapabilityManifest must contain at least one capability".to_string(),
522            location: None,
523        });
524    }
525    if manifest.capabilities.is_null() {
526        return Err(ShapeError::RuntimeError {
527            message: "CapabilityManifest.capabilities is null".to_string(),
528            location: None,
529        });
530    }
531
532    let caps =
533        unsafe { std::slice::from_raw_parts(manifest.capabilities, manifest.capabilities_len) };
534    let mut parsed = Vec::with_capacity(caps.len());
535    for cap in caps {
536        parsed.push(PluginCapability {
537            kind: cap.kind,
538            contract: read_c_string(cap.contract, "CapabilityDescriptor.contract")?,
539            version: read_c_string(cap.version, "CapabilityDescriptor.version")?,
540            flags: cap.flags,
541        });
542    }
543    Ok(parsed)
544}
545
546pub fn parse_sections_manifest(manifest: &SectionsManifest) -> Result<Vec<ClaimedSection>> {
547    if manifest.sections_len == 0 {
548        return Ok(vec![]);
549    }
550    if manifest.sections.is_null() {
551        return Err(ShapeError::RuntimeError {
552            message: "SectionsManifest.sections is null but sections_len > 0".to_string(),
553            location: None,
554        });
555    }
556
557    let claims = unsafe { std::slice::from_raw_parts(manifest.sections, manifest.sections_len) };
558    let mut parsed = Vec::with_capacity(claims.len());
559    for claim in claims {
560        parsed.push(ClaimedSection {
561            name: read_c_string(claim.name, "SectionClaim.name")?,
562            required: claim.required,
563        });
564    }
565    Ok(parsed)
566}
567
568fn read_c_string(ptr: *const std::ffi::c_char, field: &str) -> Result<String> {
569    if ptr.is_null() {
570        return Err(ShapeError::RuntimeError {
571            message: format!("{} is null", field),
572            location: None,
573        });
574    }
575
576    Ok(unsafe { CStr::from_ptr(ptr) }.to_string_lossy().to_string())
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use shape_abi_v1::{CAPABILITY_MODULE, CapabilityDescriptor};

    #[test]
    fn test_plugin_loader_new() {
        let loader = PluginLoader::new();
        assert!(loader.loaded_plugins().is_empty());
    }

    #[test]
    fn test_is_loaded_false() {
        let loader = PluginLoader::new();
        assert!(!loader.is_loaded("nonexistent"));
    }

    #[test]
    fn test_should_try_python_fallback_matches_libpython_errors() {
        assert!(should_try_python_fallback(
            "libpython3.13.so.1.0: cannot open shared object file"
        ));
        assert!(should_try_python_fallback(
            "Library not loaded: @rpath/Python.framework/Versions/3.12/Python"
        ));
        assert!(!should_try_python_fallback(
            "undefined symbol: sqlite3_open"
        ));
    }

    #[test]
    fn test_parse_capability_manifest() {
        static CAPS: [CapabilityDescriptor; 2] = [
            CapabilityDescriptor {
                kind: CapabilityKind::DataSource,
                contract: c"shape.datasource".as_ptr(),
                version: c"1".as_ptr(),
                flags: 0,
            },
            CapabilityDescriptor {
                kind: CapabilityKind::Compute,
                contract: c"shape.compute".as_ptr(),
                version: c"1".as_ptr(),
                flags: 42,
            },
        ];
        static MANIFEST: CapabilityManifest = CapabilityManifest {
            capabilities: CAPS.as_ptr(),
            capabilities_len: CAPS.len(),
        };

        let parsed = parse_capability_manifest(&MANIFEST).expect("manifest should parse");
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].contract, "shape.datasource");
        assert_eq!(parsed[1].kind, CapabilityKind::Compute);
        assert_eq!(parsed[1].flags, 42);
    }

    #[test]
    fn test_parse_capability_manifest_rejects_empty() {
        static MANIFEST: CapabilityManifest = CapabilityManifest {
            capabilities: std::ptr::null(),
            capabilities_len: 0,
        };
        let result = parse_capability_manifest(&MANIFEST);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_capability_manifest_rejects_null_pointer_with_len() {
        let manifest = CapabilityManifest {
            capabilities: std::ptr::null(),
            capabilities_len: 1,
        };
        assert!(parse_capability_manifest(&manifest).is_err());
    }

    #[test]
    fn test_parse_capability_manifest_rejects_null_contract() {
        let caps = [CapabilityDescriptor {
            kind: CapabilityKind::DataSource,
            contract: std::ptr::null(),
            version: c"1".as_ptr(),
            flags: 0,
        }];
        let manifest = CapabilityManifest {
            capabilities: caps.as_ptr(),
            capabilities_len: caps.len(),
        };
        assert!(parse_capability_manifest(&manifest).is_err());
    }

    #[test]
    fn test_module_contract_constant_is_expected() {
        assert_eq!(CAPABILITY_MODULE, "shape.module");
    }

    #[test]
    fn test_parse_sections_manifest_valid() {
        use shape_abi_v1::SectionClaim as AbiSectionClaim;

        static CLAIMS: [AbiSectionClaim; 2] = [
            AbiSectionClaim {
                name: c"native-dependencies".as_ptr(),
                required: false,
            },
            AbiSectionClaim {
                name: c"custom-config".as_ptr(),
                required: true,
            },
        ];
        static MANIFEST: SectionsManifest = SectionsManifest {
            sections: CLAIMS.as_ptr(),
            sections_len: CLAIMS.len(),
        };

        let parsed = parse_sections_manifest(&MANIFEST).expect("should parse");
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].name, "native-dependencies");
        assert!(!parsed[0].required);
        assert_eq!(parsed[1].name, "custom-config");
        assert!(parsed[1].required);
    }

    #[test]
    fn test_parse_sections_manifest_empty() {
        static MANIFEST: SectionsManifest = SectionsManifest {
            sections: std::ptr::null(),
            sections_len: 0,
        };
        let parsed = parse_sections_manifest(&MANIFEST).expect("empty should parse");
        assert!(parsed.is_empty());
    }

    #[test]
    fn test_parse_sections_manifest_rejects_null_pointer_with_len() {
        let manifest = SectionsManifest {
            sections: std::ptr::null(),
            sections_len: 1,
        };
        assert!(parse_sections_manifest(&manifest).is_err());
    }

    #[test]
    fn test_read_c_string_reads_value() {
        let value = read_c_string(c"shape.module".as_ptr(), "field")
            .expect("non-null string should read");
        assert_eq!(value, "shape.module");
    }

    #[test]
    fn test_read_c_string_rejects_null_with_field_name() {
        let err = read_c_string(std::ptr::null(), "PluginInfo.name").unwrap_err();
        assert!(matches!(
            err,
            ShapeError::RuntimeError { ref message, .. } if message == "PluginInfo.name is null"
        ));
    }
}