Skip to main content

shape_runtime/
provider_registry.rs

1//! Provider registry for managing named data providers
2//!
3//! This module provides a registry system that allows multiple data providers
4//! to be registered by name and accessed at runtime. This enables:
5//! - Multiple data sources (data, CSV, APIs, future plugins)
6//! - Provider selection in Shape code via module-scoped calls
7//!   (for example, `provider.load({...})`)
8//! - Default provider for convenience
9
10use crate::data::SharedAsyncProvider;
11use crate::plugins::{
12    CapabilityKind, LoadedPlugin, ParsedModuleSchema, ParsedOutputSchema, ParsedQuerySchema,
13    PluginDataSource, PluginLoader, PluginModule,
14};
15use shape_ast::error::{Result, ShapeError};
16use shape_value::ValueWord;
17use shape_wire::WireValue;
18use std::collections::HashMap;
19use std::path::Path;
20use std::sync::{Arc, RwLock};
21
22/// Registry of named data providers
23///
24/// Allows registration of multiple providers and selection by name.
25/// Thread-safe for use in concurrent contexts.
26///
27/// # Example
28///
29/// ```ignore
30/// let mut registry = ProviderRegistry::new();
31///
32/// // Register market data provider
33/// let md_provider = Arc::new(DataFrameAdapter::new(...));
34/// registry.register("data", md_provider);
35///
36/// // Register an additional provider
37/// let alt_provider = Arc::new(AnotherProvider::new(...));
38/// registry.register("alt", alt_provider);
39///
40/// // Set default
41/// registry.set_default("data")?;
42///
43/// // Get provider by name
44/// let provider = registry.get("data")?;
45/// ```
46#[derive(Clone)]
47pub struct ProviderRegistry {
48    /// Map of provider name to provider instance
49    providers: Arc<RwLock<HashMap<String, SharedAsyncProvider>>>,
50    /// Name of the default provider
51    default_provider: Arc<RwLock<Option<String>>>,
52    /// Map of plugin name to plugin data source wrapper
53    extension_sources: Arc<RwLock<HashMap<String, Arc<PluginDataSource>>>>,
54    /// Map of plugin name to plugin module-capability wrapper.
55    extension_modules: Arc<RwLock<HashMap<String, Arc<PluginModule>>>>,
56    /// Metadata for all loaded extension modules (not just data-source modules).
57    loaded_extensions: Arc<RwLock<HashMap<String, LoadedPlugin>>>,
58    /// Plugin loader for dynamic plugins
59    extension_loader: Arc<RwLock<PluginLoader>>,
60    /// Map of language identifier to loaded language runtime
61    language_runtimes:
62        Arc<RwLock<HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>>>>,
63}
64
65impl ProviderRegistry {
66    /// Create a new empty provider registry
67    pub fn new() -> Self {
68        Self {
69            providers: Arc::new(RwLock::new(HashMap::new())),
70            default_provider: Arc::new(RwLock::new(None)),
71            extension_sources: Arc::new(RwLock::new(HashMap::new())),
72            extension_modules: Arc::new(RwLock::new(HashMap::new())),
73            loaded_extensions: Arc::new(RwLock::new(HashMap::new())),
74            extension_loader: Arc::new(RwLock::new(PluginLoader::new())),
75            language_runtimes: Arc::new(RwLock::new(HashMap::new())),
76        }
77    }
78
79    /// Register a provider with a name
80    ///
81    /// # Arguments
82    ///
83    /// * `name` - Provider name (e.g., "data", "api", "warehouse")
84    /// * `provider` - AsyncDataProvider implementation
85    ///
86    /// # Example
87    ///
88    /// ```ignore
89    /// registry.register("data", Arc::new(DataFrameAdapter::new(...)));
90    /// ```
91    pub fn register(&self, name: &str, provider: SharedAsyncProvider) {
92        let mut providers = self.providers.write().unwrap();
93        providers.insert(name.to_string(), provider);
94    }
95
96    /// Get provider by name
97    ///
98    /// # Arguments
99    ///
100    /// * `name` - Provider name to lookup
101    ///
102    /// # Returns
103    ///
104    /// SharedAsyncProvider if found, None otherwise
105    pub fn get(&self, name: &str) -> Option<SharedAsyncProvider> {
106        let providers = self.providers.read().unwrap();
107        providers.get(name).cloned()
108    }
109
110    /// Set default provider
111    ///
112    /// # Arguments
113    ///
114    /// * `name` - Name of provider to use as default
115    ///
116    /// # Errors
117    ///
118    /// Returns error if provider with given name is not registered
119    pub fn set_default(&self, name: &str) -> Result<()> {
120        let providers = self.providers.read().unwrap();
121        if !providers.contains_key(name) {
122            return Err(ShapeError::RuntimeError {
123                message: format!("Cannot set default provider: '{}' is not registered", name),
124                location: None,
125            });
126        }
127        drop(providers);
128
129        let mut default = self.default_provider.write().unwrap();
130        *default = Some(name.to_string());
131        Ok(())
132    }
133
134    /// Get default provider
135    ///
136    /// # Returns
137    ///
138    /// SharedAsyncProvider if a default is set, None otherwise
139    pub fn get_default(&self) -> Option<SharedAsyncProvider> {
140        let default = self.default_provider.read().unwrap();
141        let name = default.as_ref().cloned();
142        drop(default);
143
144        name.and_then(|n| self.get(&n))
145    }
146
147    /// Get default provider name
148    pub fn default_name(&self) -> Option<String> {
149        let default = self.default_provider.read().unwrap();
150        default.clone()
151    }
152
153    /// List all registered provider names
154    ///
155    /// # Returns
156    ///
157    /// Vector of provider names currently registered
158    pub fn list_providers(&self) -> Vec<String> {
159        let providers = self.providers.read().unwrap();
160        providers.keys().cloned().collect()
161    }
162
163    /// Check if a provider is registered
164    pub fn has_provider(&self, name: &str) -> bool {
165        let providers = self.providers.read().unwrap();
166        providers.contains_key(name)
167    }
168
169    /// Unregister a provider
170    ///
171    /// # Arguments
172    ///
173    /// * `name` - Provider name to remove
174    ///
175    /// # Returns
176    ///
177    /// true if provider was removed, false if not found
178    pub fn unregister(&self, name: &str) -> bool {
179        let mut providers = self.providers.write().unwrap();
180        let removed = providers.remove(name).is_some();
181
182        // Clear default if it was the removed provider
183        if removed {
184            let mut default = self.default_provider.write().unwrap();
185            if default.as_ref().map(|s| s == name).unwrap_or(false) {
186                *default = None;
187            }
188        }
189
190        removed
191    }
192
193    /// Clear all providers
194    pub fn clear(&self) {
195        let mut providers = self.providers.write().unwrap();
196        providers.clear();
197
198        let mut default = self.default_provider.write().unwrap();
199        *default = None;
200
201        let mut extension_sources = self.extension_sources.write().unwrap();
202        extension_sources.clear();
203
204        let mut extension_modules = self.extension_modules.write().unwrap();
205        extension_modules.clear();
206
207        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
208        loaded_extensions.clear();
209
210        let mut runtimes = self.language_runtimes.write().unwrap();
211        runtimes.clear();
212    }
213
214    // ========================================================================
215    // Extension Management
216    // ========================================================================
217
218    /// Load an extension module from a shared library
219    ///
220    /// # Arguments
221    ///
222    /// * `path` - Path to the extension shared library (.so, .dll, .dylib)
223    /// * `config` - Configuration value for the extension
224    ///
225    /// # Returns
226    ///
227    /// Information about the loaded extension
228    ///
229    /// # Safety
230    ///
231    /// Loading modules executes arbitrary code. Only load from trusted sources.
232    pub fn load_extension(&self, path: &Path, config: &serde_json::Value) -> Result<LoadedPlugin> {
233        // Load the library and collect declared capabilities.
234        let mut loader = self.extension_loader.write().unwrap();
235        let loaded_info = loader.load(path)?;
236        let name = loaded_info.name.clone();
237
238        // If this module provides a data-source capability, initialize the
239        // PluginDataSource wrapper for runtime query execution.
240        if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
241            let vtable = loader.get_data_source_vtable(&name)?;
242            let source = PluginDataSource::new(name.clone(), vtable, config)?;
243
244            let mut sources = self.extension_sources.write().unwrap();
245            sources.insert(name.clone(), Arc::new(source));
246        } else {
247            // Ensure stale data-source wrappers are removed if a module is reloaded
248            // with a different capability set.
249            let mut sources = self.extension_sources.write().unwrap();
250            sources.remove(&name);
251        }
252
253        // If the plugin exposes a module capability (`shape.module`), bind its
254        // functions so VM module namespaces can dispatch through capability
255        // contracts.
256        if let Ok(module_vtable) = loader.get_module_vtable(&name) {
257            if let Ok(module) = PluginModule::new(name.clone(), module_vtable, config) {
258                let mut modules = self.extension_modules.write().unwrap();
259                modules.insert(name.clone(), Arc::new(module));
260            }
261        }
262
263        // If this plugin provides a language runtime capability, initialize it.
264        if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
265            let vtable = loader.get_language_runtime_vtable(&name)?;
266            let runtime =
267                crate::plugins::language_runtime::PluginLanguageRuntime::new(vtable, config)?;
268            let lang_id = runtime.language_id().to_string();
269            let mut runtimes = self.language_runtimes.write().unwrap();
270            runtimes.insert(lang_id, Arc::new(runtime));
271        }
272
273        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
274        loaded_extensions.insert(name, loaded_info.clone());
275
276        Ok(loaded_info)
277    }
278
279    /// Load an extension, merging claimed TOML section data into its init config.
280    ///
281    /// For each section claimed by the extension, looks it up in the project's
282    /// `extension_sections` and merges the data as JSON into the config.
283    /// Errors if a required section is missing.
284    pub fn load_extension_with_sections(
285        &self,
286        path: &Path,
287        config: &serde_json::Value,
288        extension_sections: &std::collections::HashMap<String, toml::Value>,
289        all_claimed: &mut std::collections::HashSet<String>,
290    ) -> Result<LoadedPlugin> {
291        // First, load the extension normally to get its section claims
292        let mut loader = self.extension_loader.write().unwrap();
293        let loaded_info = loader.load(path)?;
294        let name = loaded_info.name.clone();
295
296        // Collect claimed section names and check for collisions
297        for claim in &loaded_info.claimed_sections {
298            if !all_claimed.insert(claim.name.clone()) {
299                return Err(ShapeError::RuntimeError {
300                    message: format!(
301                        "Section '{}' is claimed by multiple extensions (collision detected when loading '{}')",
302                        claim.name, name
303                    ),
304                    location: None,
305                });
306            }
307        }
308
309        // Build merged config: start with the extension's own config, then
310        // overlay any claimed section data.
311        let mut merged_config = config.clone();
312        if let serde_json::Value::Object(ref mut map) = merged_config {
313            for claim in &loaded_info.claimed_sections {
314                if let Some(section_value) = extension_sections.get(&claim.name) {
315                    let json_value = crate::project::toml_to_json(section_value);
316                    map.insert(claim.name.clone(), json_value);
317                } else if claim.required {
318                    return Err(ShapeError::RuntimeError {
319                        message: format!(
320                            "Extension '{}' requires section '[{}]' in shape.toml, but it is missing",
321                            name, claim.name
322                        ),
323                        location: None,
324                    });
325                }
326            }
327        }
328
329        // Now initialize data source / module capabilities with the merged config.
330        if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
331            let vtable = loader.get_data_source_vtable(&name)?;
332            let source = PluginDataSource::new(name.clone(), vtable, &merged_config)?;
333            let mut sources = self.extension_sources.write().unwrap();
334            sources.insert(name.clone(), Arc::new(source));
335        } else {
336            let mut sources = self.extension_sources.write().unwrap();
337            sources.remove(&name);
338        }
339
340        if let Ok(module_vtable) = loader.get_module_vtable(&name) {
341            if let Ok(module) = PluginModule::new(name.clone(), module_vtable, &merged_config) {
342                let mut modules = self.extension_modules.write().unwrap();
343                modules.insert(name.clone(), Arc::new(module));
344            }
345        }
346
347        if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
348            let vtable = loader.get_language_runtime_vtable(&name)?;
349            let runtime = crate::plugins::language_runtime::PluginLanguageRuntime::new(
350                vtable,
351                &merged_config,
352            )?;
353            let lang_id = runtime.language_id().to_string();
354            let mut runtimes = self.language_runtimes.write().unwrap();
355            runtimes.insert(lang_id, Arc::new(runtime));
356        }
357
358        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
359        loaded_extensions.insert(name, loaded_info.clone());
360
361        Ok(loaded_info)
362    }
363
364    /// Get a language runtime by language identifier (e.g., "python").
365    pub fn get_language_runtime(
366        &self,
367        language_id: &str,
368    ) -> Option<Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
369        let runtimes = self.language_runtimes.read().unwrap();
370        runtimes.get(language_id).cloned()
371    }
372
373    /// Return all loaded language runtimes, keyed by language identifier.
374    pub fn language_runtimes(
375        &self,
376    ) -> std::collections::HashMap<
377        String,
378        Arc<crate::plugins::language_runtime::PluginLanguageRuntime>,
379    > {
380        let runtimes = self.language_runtimes.read().unwrap();
381        runtimes.clone()
382    }
383
384    /// Return child-LSP configurations declared by loaded language runtimes.
385    pub fn language_runtime_lsp_configs(
386        &self,
387    ) -> Vec<crate::plugins::language_runtime::RuntimeLspConfig> {
388        let runtimes = self.language_runtimes.read().unwrap();
389        let mut configs = Vec::new();
390
391        for runtime in runtimes.values() {
392            match runtime.lsp_config() {
393                Ok(Some(config)) => configs.push(config),
394                Ok(None) => {}
395                Err(err) => {
396                    tracing::warn!("failed to query language runtime LSP config: {}", err);
397                }
398            }
399        }
400
401        configs.sort_by(|left, right| left.language_id.cmp(&right.language_id));
402        configs
403    }
404
405    /// Get an extension data source by name
406    ///
407    /// # Arguments
408    ///
409    /// * `name` - Extension name
410    ///
411    /// # Returns
412    ///
413    /// The PluginDataSource if found
414    pub fn get_extension(&self, name: &str) -> Option<Arc<PluginDataSource>> {
415        let sources = self.extension_sources.read().unwrap();
416        sources.get(name).cloned()
417    }
418
419    /// Get extension module schema by module namespace name.
420    pub fn get_extension_module_schema(&self, module_name: &str) -> Option<ParsedModuleSchema> {
421        let modules = self.extension_modules.read().unwrap();
422        modules
423            .values()
424            .find(|m| m.schema().module_name == module_name)
425            .map(|m| m.schema().clone())
426    }
427
428    /// Build runtime extension modules from all loaded extension module capabilities.
429    pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
430        let modules = self.extension_modules.read().unwrap();
431        modules.values().map(|m| m.to_module_exports()).collect()
432    }
433
434    /// Invoke a module-capability export by module namespace and function name.
435    pub fn invoke_extension_module_nb(
436        &self,
437        module_name: &str,
438        function: &str,
439        args: &[ValueWord],
440    ) -> Result<ValueWord> {
441        let modules = self.extension_modules.read().unwrap();
442        let module = modules
443            .values()
444            .find(|m| m.schema().module_name == module_name)
445            .ok_or_else(|| ShapeError::RuntimeError {
446                message: format!("Module namespace '{}' is not loaded", module_name),
447                location: None,
448            })?;
449        module.invoke_nb(function, args)
450    }
451
452    /// Invoke a module-capability export by module namespace and function name.
453    pub fn invoke_extension_module_wire(
454        &self,
455        module_name: &str,
456        function: &str,
457        args: &[WireValue],
458    ) -> Result<WireValue> {
459        let modules = self.extension_modules.read().unwrap();
460        let module = modules
461            .values()
462            .find(|m| m.schema().module_name == module_name)
463            .ok_or_else(|| ShapeError::RuntimeError {
464                message: format!("Module namespace '{}' is not loaded", module_name),
465                location: None,
466            })?;
467        module.invoke_wire(function, args)
468    }
469
470    /// Get query schema for an extension (for LSP autocomplete)
471    ///
472    /// # Arguments
473    ///
474    /// * `name` - Extension name
475    ///
476    /// # Returns
477    ///
478    /// The query schema if extension exists
479    pub fn get_extension_query_schema(&self, name: &str) -> Option<ParsedQuerySchema> {
480        let sources = self.extension_sources.read().unwrap();
481        sources.get(name).map(|s| s.get_query_schema().clone())
482    }
483
484    /// Get output schema for an extension (for LSP autocomplete)
485    ///
486    /// # Arguments
487    ///
488    /// * `name` - Extension name
489    ///
490    /// # Returns
491    ///
492    /// The output schema if extension exists
493    pub fn get_extension_output_schema(&self, name: &str) -> Option<ParsedOutputSchema> {
494        let sources = self.extension_sources.read().unwrap();
495        sources.get(name).map(|s| s.get_output_schema().clone())
496    }
497
498    /// List all plugins with their query schemas (for LSP)
499    ///
500    /// # Returns
501    ///
502    /// Vector of (plugin_name, query_schema) pairs
503    pub fn list_extensions_with_schemas(&self) -> Vec<(String, ParsedQuerySchema)> {
504        let sources = self.extension_sources.read().unwrap();
505        sources
506            .iter()
507            .map(|(name, source)| (name.clone(), source.get_query_schema().clone()))
508            .collect()
509    }
510
511    /// List all loaded extension names
512    pub fn list_extensions(&self) -> Vec<String> {
513        let loaded = self.loaded_extensions.read().unwrap();
514        loaded.keys().cloned().collect()
515    }
516
517    /// Check if a plugin is loaded
518    pub fn has_extension(&self, name: &str) -> bool {
519        let loaded = self.loaded_extensions.read().unwrap();
520        loaded.contains_key(name)
521    }
522
523    /// Unload an extension
524    ///
525    /// # Arguments
526    ///
527    /// * `name` - Extension name to unload
528    ///
529    /// # Returns
530    ///
531    /// true if plugin was unloaded, false if not found
532    pub fn unload_extension(&self, name: &str) -> bool {
533        let mut sources = self.extension_sources.write().unwrap();
534        let removed_source = sources.remove(name).is_some();
535        drop(sources);
536
537        let mut modules = self.extension_modules.write().unwrap();
538        let removed_module = modules.remove(name).is_some();
539        drop(modules);
540
541        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
542        let removed_plugin = loaded_extensions.remove(name).is_some();
543        drop(loaded_extensions);
544
545        if removed_plugin {
546            let mut loader = self.extension_loader.write().unwrap();
547            loader.unload(name);
548        }
549
550        removed_plugin || removed_source || removed_module
551    }
552}
553
554impl Default for ProviderRegistry {
555    fn default() -> Self {
556        Self::new()
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::async_provider::NullAsyncProvider;

    /// A fresh `NullAsyncProvider` behind the registry's shared-provider type.
    fn null_provider() -> SharedAsyncProvider {
        Arc::new(NullAsyncProvider) as SharedAsyncProvider
    }

    #[test]
    fn test_register_and_get() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());

        assert!(registry.has_provider("test"));
        assert!(!registry.has_provider("nonexistent"));
        assert!(registry.get("test").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_register_replaces_existing_name() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());
        registry.register("test", null_provider());

        assert_eq!(registry.list_providers(), vec!["test".to_string()]);
    }

    #[test]
    fn test_default_provider() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());

        assert!(registry.set_default("test").is_ok());
        assert!(registry.get_default().is_some());
        assert_eq!(registry.default_name(), Some("test".to_string()));
    }

    #[test]
    fn test_set_default_nonexistent() {
        let registry = ProviderRegistry::new();
        assert!(registry.set_default("nonexistent").is_err());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_list_providers() {
        let registry = ProviderRegistry::new();
        registry.register("test1", null_provider());
        registry.register("test2", null_provider());

        let mut names = registry.list_providers();
        names.sort();
        assert_eq!(names, vec!["test1", "test2"]);
    }

    #[test]
    fn test_unregister() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());
        registry.set_default("test").unwrap();

        assert!(registry.unregister("test"));
        assert!(!registry.has_provider("test"));
        assert!(registry.get_default().is_none());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_unregister_other_provider_keeps_default() {
        let registry = ProviderRegistry::new();
        registry.register("primary", null_provider());
        registry.register("secondary", null_provider());
        registry.set_default("primary").unwrap();

        assert!(registry.unregister("secondary"));
        assert_eq!(registry.default_name(), Some("primary".to_string()));
        assert!(registry.get_default().is_some());
    }

    #[test]
    fn test_unregister_missing_returns_false() {
        let registry = ProviderRegistry::new();
        assert!(!registry.unregister("nonexistent"));
    }

    #[test]
    fn test_clear() {
        let registry = ProviderRegistry::new();
        registry.register("test1", null_provider());
        registry.register("test2", null_provider());
        registry.set_default("test1").unwrap();

        registry.clear();

        assert_eq!(registry.list_providers().len(), 0);
        assert!(!registry.has_provider("test1"));
        assert!(registry.get_default().is_none());
    }

    #[test]
    fn test_clones_share_state() {
        let registry = ProviderRegistry::new();
        let clone = registry.clone();

        clone.register("shared", null_provider());
        clone.set_default("shared").unwrap();

        assert!(registry.has_provider("shared"));
        assert_eq!(registry.default_name(), Some("shared".to_string()));
    }

    // Plugin management tests

    #[test]
    fn test_plugin_not_loaded_by_default() {
        let registry = ProviderRegistry::new();

        assert!(!registry.has_extension("nonexistent"));
        assert!(registry.get_extension("nonexistent").is_none());
    }

    #[test]
    fn test_list_extensions_empty() {
        let registry = ProviderRegistry::new();
        assert!(registry.list_extensions().is_empty());
    }

    #[test]
    fn test_list_extensions_with_schemas_empty() {
        let registry = ProviderRegistry::new();
        assert!(registry.list_extensions_with_schemas().is_empty());
    }

    #[test]
    fn test_get_extension_query_schema_not_found() {
        let registry = ProviderRegistry::new();
        assert!(registry.get_extension_query_schema("nonexistent").is_none());
    }

    #[test]
    fn test_get_extension_output_schema_not_found() {
        let registry = ProviderRegistry::new();
        assert!(registry.get_extension_output_schema("nonexistent").is_none());
    }

    #[test]
    fn test_module_lookups_on_empty_registry() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_extension_module_schema("missing").is_none());
        assert!(registry.module_exports_from_extensions().is_empty());
        assert!(registry
            .invoke_extension_module_wire("missing", "f", &[])
            .is_err());
        assert!(registry
            .invoke_extension_module_nb("missing", "f", &[])
            .is_err());
    }

    #[test]
    fn test_language_runtimes_empty_by_default() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_language_runtime("python").is_none());
        assert!(registry.language_runtimes().is_empty());
        assert!(registry.language_runtime_lsp_configs().is_empty());
    }

    #[test]
    fn test_unload_plugin_not_loaded() {
        let registry = ProviderRegistry::new();

        // Unloading a non-existent plugin should return false
        assert!(!registry.unload_extension("nonexistent"));
    }

    #[test]
    fn test_clear_removes_plugins() {
        let registry = ProviderRegistry::new();

        // Clear should also clear plugin sources
        registry.clear();

        assert!(registry.list_extensions().is_empty());
        assert!(registry.language_runtimes().is_empty());
    }
}