Skip to main content

shape_runtime/
provider_registry.rs

1//! Provider registry for managing named data providers
2//!
3//! This module provides a registry system that allows multiple data providers
4//! to be registered by name and accessed at runtime. This enables:
5//! - Multiple data sources (data, CSV, APIs, future plugins)
6//! - Provider selection in Shape code via module-scoped calls
7//!   (for example, `provider.load({...})`)
8//! - Default provider for convenience
9
10use crate::data::SharedAsyncProvider;
11use crate::plugins::{
12    CapabilityKind, LoadedPlugin, ParsedModuleSchema, ParsedOutputSchema, ParsedQuerySchema,
13    PluginDataSource, PluginLoader, PluginModule,
14};
15use shape_ast::error::{Result, ShapeError};
16use shape_wire::WireValue;
17use std::collections::HashMap;
18use std::path::Path;
19use std::sync::{Arc, RwLock};
20
21/// Registry of named data providers
22///
23/// Allows registration of multiple providers and selection by name.
24/// Thread-safe for use in concurrent contexts.
25///
26/// # Example
27///
28/// ```ignore
29/// let mut registry = ProviderRegistry::new();
30///
31/// // Register market data provider
32/// let md_provider = Arc::new(DataFrameAdapter::new(...));
33/// registry.register("data", md_provider);
34///
35/// // Register an additional provider
36/// let alt_provider = Arc::new(AnotherProvider::new(...));
37/// registry.register("alt", alt_provider);
38///
39/// // Set default
40/// registry.set_default("data")?;
41///
42/// // Get provider by name
43/// let provider = registry.get("data")?;
44/// ```
45#[derive(Clone)]
46pub struct ProviderRegistry {
47    /// Map of provider name to provider instance
48    providers: Arc<RwLock<HashMap<String, SharedAsyncProvider>>>,
49    /// Name of the default provider
50    default_provider: Arc<RwLock<Option<String>>>,
51    /// Map of plugin name to plugin data source wrapper
52    extension_sources: Arc<RwLock<HashMap<String, Arc<PluginDataSource>>>>,
53    /// Map of plugin name to plugin module-capability wrapper.
54    extension_modules: Arc<RwLock<HashMap<String, Arc<PluginModule>>>>,
55    /// Metadata for all loaded extension modules (not just data-source modules).
56    loaded_extensions: Arc<RwLock<HashMap<String, LoadedPlugin>>>,
57    /// Plugin loader for dynamic plugins
58    extension_loader: Arc<RwLock<PluginLoader>>,
59    /// Map of language identifier to loaded language runtime
60    language_runtimes:
61        Arc<RwLock<HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>>>>,
62}
63
64impl ProviderRegistry {
65    /// Create a new empty provider registry
66    pub fn new() -> Self {
67        Self {
68            providers: Arc::new(RwLock::new(HashMap::new())),
69            default_provider: Arc::new(RwLock::new(None)),
70            extension_sources: Arc::new(RwLock::new(HashMap::new())),
71            extension_modules: Arc::new(RwLock::new(HashMap::new())),
72            loaded_extensions: Arc::new(RwLock::new(HashMap::new())),
73            extension_loader: Arc::new(RwLock::new(PluginLoader::new())),
74            language_runtimes: Arc::new(RwLock::new(HashMap::new())),
75        }
76    }
77
78    /// Register a provider with a name
79    ///
80    /// # Arguments
81    ///
82    /// * `name` - Provider name (e.g., "data", "api", "warehouse")
83    /// * `provider` - AsyncDataProvider implementation
84    ///
85    /// # Example
86    ///
87    /// ```ignore
88    /// registry.register("data", Arc::new(DataFrameAdapter::new(...)));
89    /// ```
90    pub fn register(&self, name: &str, provider: SharedAsyncProvider) {
91        let mut providers = self.providers.write().unwrap();
92        providers.insert(name.to_string(), provider);
93    }
94
95    /// Get provider by name
96    ///
97    /// # Arguments
98    ///
99    /// * `name` - Provider name to lookup
100    ///
101    /// # Returns
102    ///
103    /// SharedAsyncProvider if found, None otherwise
104    pub fn get(&self, name: &str) -> Option<SharedAsyncProvider> {
105        let providers = self.providers.read().unwrap();
106        providers.get(name).cloned()
107    }
108
109    /// Set default provider
110    ///
111    /// # Arguments
112    ///
113    /// * `name` - Name of provider to use as default
114    ///
115    /// # Errors
116    ///
117    /// Returns error if provider with given name is not registered
118    pub fn set_default(&self, name: &str) -> Result<()> {
119        let providers = self.providers.read().unwrap();
120        if !providers.contains_key(name) {
121            return Err(ShapeError::RuntimeError {
122                message: format!("Cannot set default provider: '{}' is not registered", name),
123                location: None,
124            });
125        }
126        drop(providers);
127
128        let mut default = self.default_provider.write().unwrap();
129        *default = Some(name.to_string());
130        Ok(())
131    }
132
133    /// Get default provider
134    ///
135    /// # Returns
136    ///
137    /// SharedAsyncProvider if a default is set, None otherwise
138    pub fn get_default(&self) -> Option<SharedAsyncProvider> {
139        let default = self.default_provider.read().unwrap();
140        let name = default.as_ref().cloned();
141        drop(default);
142
143        name.and_then(|n| self.get(&n))
144    }
145
146    /// Get default provider name
147    pub fn default_name(&self) -> Option<String> {
148        let default = self.default_provider.read().unwrap();
149        default.clone()
150    }
151
152    /// List all registered provider names
153    ///
154    /// # Returns
155    ///
156    /// Vector of provider names currently registered
157    pub fn list_providers(&self) -> Vec<String> {
158        let providers = self.providers.read().unwrap();
159        providers.keys().cloned().collect()
160    }
161
162    /// Check if a provider is registered
163    pub fn has_provider(&self, name: &str) -> bool {
164        let providers = self.providers.read().unwrap();
165        providers.contains_key(name)
166    }
167
168    /// Unregister a provider
169    ///
170    /// # Arguments
171    ///
172    /// * `name` - Provider name to remove
173    ///
174    /// # Returns
175    ///
176    /// true if provider was removed, false if not found
177    pub fn unregister(&self, name: &str) -> bool {
178        let mut providers = self.providers.write().unwrap();
179        let removed = providers.remove(name).is_some();
180
181        // Clear default if it was the removed provider
182        if removed {
183            let mut default = self.default_provider.write().unwrap();
184            if default.as_ref().map(|s| s == name).unwrap_or(false) {
185                *default = None;
186            }
187        }
188
189        removed
190    }
191
192    /// Clear all providers
193    pub fn clear(&self) {
194        let mut providers = self.providers.write().unwrap();
195        providers.clear();
196
197        let mut default = self.default_provider.write().unwrap();
198        *default = None;
199
200        let mut extension_sources = self.extension_sources.write().unwrap();
201        extension_sources.clear();
202
203        let mut extension_modules = self.extension_modules.write().unwrap();
204        extension_modules.clear();
205
206        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
207        loaded_extensions.clear();
208
209        let mut runtimes = self.language_runtimes.write().unwrap();
210        runtimes.clear();
211    }
212
213    // ========================================================================
214    // Extension Management
215    // ========================================================================
216
217    /// Load an extension module from a shared library
218    ///
219    /// # Arguments
220    ///
221    /// * `path` - Path to the extension shared library (.so, .dll, .dylib)
222    /// * `config` - Configuration value for the extension
223    ///
224    /// # Returns
225    ///
226    /// Information about the loaded extension
227    ///
228    /// # Safety
229    ///
230    /// Loading modules executes arbitrary code. Only load from trusted sources.
231    pub fn load_extension(&self, path: &Path, config: &serde_json::Value) -> Result<LoadedPlugin> {
232        // Load the library and collect declared capabilities.
233        let mut loader = self.extension_loader.write().unwrap();
234        let loaded_info = loader.load(path)?;
235        let name = loaded_info.name.clone();
236
237        // If this module provides a data-source capability, initialize the
238        // PluginDataSource wrapper for runtime query execution.
239        if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
240            let vtable = loader.get_data_source_vtable(&name)?;
241            let source = PluginDataSource::new(name.clone(), vtable, config)?;
242
243            let mut sources = self.extension_sources.write().unwrap();
244            sources.insert(name.clone(), Arc::new(source));
245        } else {
246            // Ensure stale data-source wrappers are removed if a module is reloaded
247            // with a different capability set.
248            let mut sources = self.extension_sources.write().unwrap();
249            sources.remove(&name);
250        }
251
252        // If the plugin exposes a module capability (`shape.module`), bind its
253        // functions so VM module namespaces can dispatch through capability
254        // contracts.
255        if let Ok(module_vtable) = loader.get_module_vtable(&name) {
256            if let Ok(module) = PluginModule::new(name.clone(), module_vtable, config) {
257                let mut modules = self.extension_modules.write().unwrap();
258                modules.insert(name.clone(), Arc::new(module));
259            }
260        }
261
262        // If this plugin provides a language runtime capability, initialize it.
263        if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
264            let vtable = loader.get_language_runtime_vtable(&name)?;
265            let runtime =
266                crate::plugins::language_runtime::PluginLanguageRuntime::new(vtable, config)?;
267            let lang_id = runtime.language_id().to_string();
268            let mut runtimes = self.language_runtimes.write().unwrap();
269            runtimes.insert(lang_id, Arc::new(runtime));
270        }
271
272        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
273        loaded_extensions.insert(name, loaded_info.clone());
274
275        Ok(loaded_info)
276    }
277
278    /// Load an extension, merging claimed TOML section data into its init config.
279    ///
280    /// For each section claimed by the extension, looks it up in the project's
281    /// `extension_sections` and merges the data as JSON into the config.
282    /// Errors if a required section is missing.
283    pub fn load_extension_with_sections(
284        &self,
285        path: &Path,
286        config: &serde_json::Value,
287        extension_sections: &std::collections::HashMap<String, toml::Value>,
288        all_claimed: &mut std::collections::HashSet<String>,
289    ) -> Result<LoadedPlugin> {
290        // First, load the extension normally to get its section claims
291        let mut loader = self.extension_loader.write().unwrap();
292        let loaded_info = loader.load(path)?;
293        let name = loaded_info.name.clone();
294
295        // Collect claimed section names and check for collisions
296        for claim in &loaded_info.claimed_sections {
297            if !all_claimed.insert(claim.name.clone()) {
298                return Err(ShapeError::RuntimeError {
299                    message: format!(
300                        "Section '{}' is claimed by multiple extensions (collision detected when loading '{}')",
301                        claim.name, name
302                    ),
303                    location: None,
304                });
305            }
306        }
307
308        // Build merged config: start with the extension's own config, then
309        // overlay any claimed section data.
310        let mut merged_config = config.clone();
311        if let serde_json::Value::Object(ref mut map) = merged_config {
312            for claim in &loaded_info.claimed_sections {
313                if let Some(section_value) = extension_sections.get(&claim.name) {
314                    let json_value = crate::project::toml_to_json(section_value);
315                    map.insert(claim.name.clone(), json_value);
316                } else if claim.required {
317                    return Err(ShapeError::RuntimeError {
318                        message: format!(
319                            "Extension '{}' requires section '[{}]' in shape.toml, but it is missing",
320                            name, claim.name
321                        ),
322                        location: None,
323                    });
324                }
325            }
326        }
327
328        // Now initialize data source / module capabilities with the merged config.
329        if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
330            let vtable = loader.get_data_source_vtable(&name)?;
331            let source = PluginDataSource::new(name.clone(), vtable, &merged_config)?;
332            let mut sources = self.extension_sources.write().unwrap();
333            sources.insert(name.clone(), Arc::new(source));
334        } else {
335            let mut sources = self.extension_sources.write().unwrap();
336            sources.remove(&name);
337        }
338
339        if let Ok(module_vtable) = loader.get_module_vtable(&name) {
340            if let Ok(module) = PluginModule::new(name.clone(), module_vtable, &merged_config) {
341                let mut modules = self.extension_modules.write().unwrap();
342                modules.insert(name.clone(), Arc::new(module));
343            }
344        }
345
346        if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
347            let vtable = loader.get_language_runtime_vtable(&name)?;
348            let runtime = crate::plugins::language_runtime::PluginLanguageRuntime::new(
349                vtable,
350                &merged_config,
351            )?;
352            let lang_id = runtime.language_id().to_string();
353            let mut runtimes = self.language_runtimes.write().unwrap();
354            runtimes.insert(lang_id, Arc::new(runtime));
355        }
356
357        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
358        loaded_extensions.insert(name, loaded_info.clone());
359
360        Ok(loaded_info)
361    }
362
363    /// Get a language runtime by language identifier (e.g., "python").
364    pub fn get_language_runtime(
365        &self,
366        language_id: &str,
367    ) -> Option<Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
368        let runtimes = self.language_runtimes.read().unwrap();
369        runtimes.get(language_id).cloned()
370    }
371
372    /// Return all loaded language runtimes, keyed by language identifier.
373    pub fn language_runtimes(
374        &self,
375    ) -> std::collections::HashMap<
376        String,
377        Arc<crate::plugins::language_runtime::PluginLanguageRuntime>,
378    > {
379        let runtimes = self.language_runtimes.read().unwrap();
380        runtimes.clone()
381    }
382
383    /// Return child-LSP configurations declared by loaded language runtimes.
384    pub fn language_runtime_lsp_configs(
385        &self,
386    ) -> Vec<crate::plugins::language_runtime::RuntimeLspConfig> {
387        let runtimes = self.language_runtimes.read().unwrap();
388        let mut configs = Vec::new();
389
390        for runtime in runtimes.values() {
391            match runtime.lsp_config() {
392                Ok(Some(config)) => configs.push(config),
393                Ok(None) => {}
394                Err(err) => {
395                    tracing::warn!("failed to query language runtime LSP config: {}", err);
396                }
397            }
398        }
399
400        configs.sort_by(|left, right| left.language_id.cmp(&right.language_id));
401        configs
402    }
403
404    /// Get an extension data source by name
405    ///
406    /// # Arguments
407    ///
408    /// * `name` - Extension name
409    ///
410    /// # Returns
411    ///
412    /// The PluginDataSource if found
413    pub fn get_extension(&self, name: &str) -> Option<Arc<PluginDataSource>> {
414        let sources = self.extension_sources.read().unwrap();
415        sources.get(name).cloned()
416    }
417
418    /// Get extension module schema by module namespace name.
419    pub fn get_extension_module_schema(&self, module_name: &str) -> Option<ParsedModuleSchema> {
420        let modules = self.extension_modules.read().unwrap();
421        modules
422            .values()
423            .find(|m| m.schema().module_name == module_name)
424            .map(|m| m.schema().clone())
425    }
426
427    /// Build runtime extension modules from all loaded extension module capabilities.
428    ///
429    /// Strict-typing follow-up `plugin-typed-abi`: the `ValueWord`-typed
430    /// dispatch shim was removed with the bulldozer; plugins are inert
431    /// until they declare typed signatures at registration. See the
432    /// 2026-05-06 entry in `docs/defections.md`.
433    pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
434        Vec::new()
435    }
436
437    /// Invoke a module-capability export by module namespace and function name.
438    pub fn invoke_extension_module_wire(
439        &self,
440        module_name: &str,
441        function: &str,
442        args: &[WireValue],
443    ) -> Result<WireValue> {
444        let modules = self.extension_modules.read().unwrap();
445        let module = modules
446            .values()
447            .find(|m| m.schema().module_name == module_name)
448            .ok_or_else(|| ShapeError::RuntimeError {
449                message: format!("Module namespace '{}' is not loaded", module_name),
450                location: None,
451            })?;
452        module.invoke_wire(function, args)
453    }
454
455    /// Get query schema for an extension (for LSP autocomplete)
456    ///
457    /// # Arguments
458    ///
459    /// * `name` - Extension name
460    ///
461    /// # Returns
462    ///
463    /// The query schema if extension exists
464    pub fn get_extension_query_schema(&self, name: &str) -> Option<ParsedQuerySchema> {
465        let sources = self.extension_sources.read().unwrap();
466        sources.get(name).map(|s| s.get_query_schema().clone())
467    }
468
469    /// Get output schema for an extension (for LSP autocomplete)
470    ///
471    /// # Arguments
472    ///
473    /// * `name` - Extension name
474    ///
475    /// # Returns
476    ///
477    /// The output schema if extension exists
478    pub fn get_extension_output_schema(&self, name: &str) -> Option<ParsedOutputSchema> {
479        let sources = self.extension_sources.read().unwrap();
480        sources.get(name).map(|s| s.get_output_schema().clone())
481    }
482
483    /// List all plugins with their query schemas (for LSP)
484    ///
485    /// # Returns
486    ///
487    /// Vector of (plugin_name, query_schema) pairs
488    pub fn list_extensions_with_schemas(&self) -> Vec<(String, ParsedQuerySchema)> {
489        let sources = self.extension_sources.read().unwrap();
490        sources
491            .iter()
492            .map(|(name, source)| (name.clone(), source.get_query_schema().clone()))
493            .collect()
494    }
495
496    /// List all loaded extension names
497    pub fn list_extensions(&self) -> Vec<String> {
498        let loaded = self.loaded_extensions.read().unwrap();
499        loaded.keys().cloned().collect()
500    }
501
502    /// Check if a plugin is loaded
503    pub fn has_extension(&self, name: &str) -> bool {
504        let loaded = self.loaded_extensions.read().unwrap();
505        loaded.contains_key(name)
506    }
507
508    /// Unload an extension
509    ///
510    /// # Arguments
511    ///
512    /// * `name` - Extension name to unload
513    ///
514    /// # Returns
515    ///
516    /// true if plugin was unloaded, false if not found
517    pub fn unload_extension(&self, name: &str) -> bool {
518        let mut sources = self.extension_sources.write().unwrap();
519        let removed_source = sources.remove(name).is_some();
520        drop(sources);
521
522        let mut modules = self.extension_modules.write().unwrap();
523        let removed_module = modules.remove(name).is_some();
524        drop(modules);
525
526        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
527        let removed_plugin = loaded_extensions.remove(name).is_some();
528        drop(loaded_extensions);
529
530        if removed_plugin {
531            let mut loader = self.extension_loader.write().unwrap();
532            loader.unload(name);
533        }
534
535        removed_plugin || removed_source || removed_module
536    }
537}
538
539impl Default for ProviderRegistry {
540    fn default() -> Self {
541        Self::new()
542    }
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::async_provider::NullAsyncProvider;

    /// Fresh no-op provider handle used by the provider tests.
    fn null_provider() -> SharedAsyncProvider {
        Arc::new(NullAsyncProvider) as SharedAsyncProvider
    }

    // Provider management tests

    #[test]
    fn test_register_and_get() {
        let reg = ProviderRegistry::new();
        let handle = null_provider();

        reg.register("test", handle.clone());

        assert!(reg.has_provider("test"));
        assert!(!reg.has_provider("nonexistent"));
        assert!(reg.get("test").is_some());
    }

    #[test]
    fn test_default_provider() {
        let reg = ProviderRegistry::new();
        reg.register("test", null_provider());

        assert!(reg.set_default("test").is_ok());
        assert!(reg.get_default().is_some());
        assert_eq!(reg.default_name(), Some("test".to_string()));
    }

    #[test]
    fn test_set_default_nonexistent() {
        let reg = ProviderRegistry::new();

        // Nothing is registered, so choosing a default must be rejected.
        assert!(reg.set_default("nonexistent").is_err());
    }

    #[test]
    fn test_list_providers() {
        let reg = ProviderRegistry::new();
        let handle = null_provider();

        reg.register("test1", handle.clone());
        reg.register("test2", handle);

        // Map iteration order is unspecified; sort before comparing.
        let mut listed = reg.list_providers();
        listed.sort();
        assert_eq!(listed, vec!["test1", "test2"]);
    }

    #[test]
    fn test_unregister() {
        let reg = ProviderRegistry::new();
        reg.register("test", null_provider());
        reg.set_default("test").unwrap();

        assert!(reg.unregister("test"));
        assert!(!reg.has_provider("test"));
        // Removing the default provider also clears the default.
        assert!(reg.get_default().is_none());
    }

    #[test]
    fn test_clear() {
        let reg = ProviderRegistry::new();
        let handle = null_provider();

        reg.register("test1", handle.clone());
        reg.register("test2", handle);
        reg.set_default("test1").unwrap();

        reg.clear();

        assert_eq!(reg.list_providers().len(), 0);
        assert!(reg.get_default().is_none());
    }

    // Plugin management tests

    #[test]
    fn test_plugin_not_loaded_by_default() {
        let reg = ProviderRegistry::new();

        assert!(!reg.has_extension("nonexistent"));
        assert!(reg.get_extension("nonexistent").is_none());
    }

    #[test]
    fn test_list_extensions_empty() {
        let reg = ProviderRegistry::new();

        assert!(reg.list_extensions().is_empty());
    }

    #[test]
    fn test_list_extensions_with_schemas_empty() {
        let reg = ProviderRegistry::new();

        assert!(reg.list_extensions_with_schemas().is_empty());
    }

    #[test]
    fn test_get_extension_query_schema_not_found() {
        let reg = ProviderRegistry::new();

        assert!(reg.get_extension_query_schema("nonexistent").is_none());
    }

    #[test]
    fn test_get_extension_output_schema_not_found() {
        let reg = ProviderRegistry::new();

        assert!(reg.get_extension_output_schema("nonexistent").is_none());
    }

    #[test]
    fn test_unload_plugin_not_loaded() {
        let reg = ProviderRegistry::new();

        // Unloading a non-existent plugin should return false
        assert!(!reg.unload_extension("nonexistent"));
    }

    #[test]
    fn test_clear_removes_plugins() {
        let reg = ProviderRegistry::new();

        // Clear should also clear plugin sources
        reg.clear();

        assert!(reg.list_extensions().is_empty());
    }
}
680}