Skip to main content

shape_runtime/
provider_registry.rs

1//! Provider registry for managing named data providers
2//!
3//! This module provides a registry system that allows multiple data providers
4//! to be registered by name and accessed at runtime. This enables:
5//! - Multiple data sources (data, CSV, APIs, future plugins)
6//! - Provider selection in Shape code via module-scoped calls
7//!   (for example, `provider.load({...})`)
8//! - Default provider for convenience
9
10use crate::data::SharedAsyncProvider;
11use crate::plugins::{
12    CapabilityKind, LoadedPlugin, ParsedModuleSchema, ParsedOutputSchema, ParsedQuerySchema,
13    PluginDataSource, PluginLoader, PluginModule,
14};
15use shape_ast::error::{Result, ShapeError};
16use shape_value::ValueWord;
17use shape_wire::WireValue;
18use std::collections::HashMap;
19use std::path::Path;
20use std::sync::{Arc, RwLock};
21
22/// Registry of named data providers
23///
24/// Allows registration of multiple providers and selection by name.
25/// Thread-safe for use in concurrent contexts.
26///
27/// # Example
28///
29/// ```ignore
30/// let mut registry = ProviderRegistry::new();
31///
32/// // Register market data provider
33/// let md_provider = Arc::new(DataFrameAdapter::new(...));
34/// registry.register("data", md_provider);
35///
36/// // Register an additional provider
37/// let alt_provider = Arc::new(AnotherProvider::new(...));
38/// registry.register("alt", alt_provider);
39///
40/// // Set default
41/// registry.set_default("data")?;
42///
43/// // Get provider by name
44/// let provider = registry.get("data")?;
45/// ```
46#[derive(Clone)]
47pub struct ProviderRegistry {
48    /// Map of provider name to provider instance
49    providers: Arc<RwLock<HashMap<String, SharedAsyncProvider>>>,
50    /// Name of the default provider
51    default_provider: Arc<RwLock<Option<String>>>,
52    /// Map of plugin name to plugin data source wrapper
53    extension_sources: Arc<RwLock<HashMap<String, Arc<PluginDataSource>>>>,
54    /// Map of plugin name to plugin module-capability wrapper.
55    extension_modules: Arc<RwLock<HashMap<String, Arc<PluginModule>>>>,
56    /// Metadata for all loaded extension modules (not just data-source modules).
57    loaded_extensions: Arc<RwLock<HashMap<String, LoadedPlugin>>>,
58    /// Plugin loader for dynamic plugins
59    extension_loader: Arc<RwLock<PluginLoader>>,
60    /// Map of language identifier to loaded language runtime
61    language_runtimes:
62        Arc<RwLock<HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>>>>,
63}
64
65impl ProviderRegistry {
66    /// Create a new empty provider registry
67    pub fn new() -> Self {
68        Self {
69            providers: Arc::new(RwLock::new(HashMap::new())),
70            default_provider: Arc::new(RwLock::new(None)),
71            extension_sources: Arc::new(RwLock::new(HashMap::new())),
72            extension_modules: Arc::new(RwLock::new(HashMap::new())),
73            loaded_extensions: Arc::new(RwLock::new(HashMap::new())),
74            extension_loader: Arc::new(RwLock::new(PluginLoader::new())),
75            language_runtimes: Arc::new(RwLock::new(HashMap::new())),
76        }
77    }
78
79    /// Register a provider with a name
80    ///
81    /// # Arguments
82    ///
83    /// * `name` - Provider name (e.g., "data", "api", "warehouse")
84    /// * `provider` - AsyncDataProvider implementation
85    ///
86    /// # Example
87    ///
88    /// ```ignore
89    /// registry.register("data", Arc::new(DataFrameAdapter::new(...)));
90    /// ```
91    pub fn register(&self, name: &str, provider: SharedAsyncProvider) {
92        let mut providers = self.providers.write().unwrap();
93        providers.insert(name.to_string(), provider);
94    }
95
96    /// Get provider by name
97    ///
98    /// # Arguments
99    ///
100    /// * `name` - Provider name to lookup
101    ///
102    /// # Returns
103    ///
104    /// SharedAsyncProvider if found, None otherwise
105    pub fn get(&self, name: &str) -> Option<SharedAsyncProvider> {
106        let providers = self.providers.read().unwrap();
107        providers.get(name).cloned()
108    }
109
110    /// Set default provider
111    ///
112    /// # Arguments
113    ///
114    /// * `name` - Name of provider to use as default
115    ///
116    /// # Errors
117    ///
118    /// Returns error if provider with given name is not registered
119    pub fn set_default(&self, name: &str) -> Result<()> {
120        let providers = self.providers.read().unwrap();
121        if !providers.contains_key(name) {
122            return Err(ShapeError::RuntimeError {
123                message: format!("Cannot set default provider: '{}' is not registered", name),
124                location: None,
125            });
126        }
127        drop(providers);
128
129        let mut default = self.default_provider.write().unwrap();
130        *default = Some(name.to_string());
131        Ok(())
132    }
133
134    /// Get default provider
135    ///
136    /// # Returns
137    ///
138    /// SharedAsyncProvider if a default is set, None otherwise
139    pub fn get_default(&self) -> Option<SharedAsyncProvider> {
140        let default = self.default_provider.read().unwrap();
141        let name = default.as_ref().cloned();
142        drop(default);
143
144        name.and_then(|n| self.get(&n))
145    }
146
147    /// Get default provider name
148    pub fn default_name(&self) -> Option<String> {
149        let default = self.default_provider.read().unwrap();
150        default.clone()
151    }
152
153    /// List all registered provider names
154    ///
155    /// # Returns
156    ///
157    /// Vector of provider names currently registered
158    pub fn list_providers(&self) -> Vec<String> {
159        let providers = self.providers.read().unwrap();
160        providers.keys().cloned().collect()
161    }
162
163    /// Check if a provider is registered
164    pub fn has_provider(&self, name: &str) -> bool {
165        let providers = self.providers.read().unwrap();
166        providers.contains_key(name)
167    }
168
169    /// Unregister a provider
170    ///
171    /// # Arguments
172    ///
173    /// * `name` - Provider name to remove
174    ///
175    /// # Returns
176    ///
177    /// true if provider was removed, false if not found
178    pub fn unregister(&self, name: &str) -> bool {
179        let mut providers = self.providers.write().unwrap();
180        let removed = providers.remove(name).is_some();
181
182        // Clear default if it was the removed provider
183        if removed {
184            let mut default = self.default_provider.write().unwrap();
185            if default.as_ref().map(|s| s == name).unwrap_or(false) {
186                *default = None;
187            }
188        }
189
190        removed
191    }
192
193    /// Clear all providers
194    pub fn clear(&self) {
195        let mut providers = self.providers.write().unwrap();
196        providers.clear();
197
198        let mut default = self.default_provider.write().unwrap();
199        *default = None;
200
201        let mut extension_sources = self.extension_sources.write().unwrap();
202        extension_sources.clear();
203
204        let mut extension_modules = self.extension_modules.write().unwrap();
205        extension_modules.clear();
206
207        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
208        loaded_extensions.clear();
209
210        let mut runtimes = self.language_runtimes.write().unwrap();
211        runtimes.clear();
212    }
213
214    // ========================================================================
215    // Extension Management
216    // ========================================================================
217
218    /// Load an extension module from a shared library
219    ///
220    /// # Arguments
221    ///
222    /// * `path` - Path to the extension shared library (.so, .dll, .dylib)
223    /// * `config` - Configuration value for the extension
224    ///
225    /// # Returns
226    ///
227    /// Information about the loaded extension
228    ///
229    /// # Safety
230    ///
231    /// Loading modules executes arbitrary code. Only load from trusted sources.
232    pub fn load_extension(&self, path: &Path, config: &serde_json::Value) -> Result<LoadedPlugin> {
233        // Load the library and collect declared capabilities.
234        let mut loader = self.extension_loader.write().unwrap();
235        let loaded_info = loader.load(path)?;
236        let name = loaded_info.name.clone();
237
238        // If this module provides a data-source capability, initialize the
239        // PluginDataSource wrapper for runtime query execution.
240        if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
241            let vtable = loader.get_data_source_vtable(&name)?;
242            let source = PluginDataSource::new(name.clone(), vtable, config)?;
243
244            let mut sources = self.extension_sources.write().unwrap();
245            sources.insert(name.clone(), Arc::new(source));
246        } else {
247            // Ensure stale data-source wrappers are removed if a module is reloaded
248            // with a different capability set.
249            let mut sources = self.extension_sources.write().unwrap();
250            sources.remove(&name);
251        }
252
253        // If the plugin exposes a module capability (`shape.module`), bind its
254        // functions so VM module namespaces can dispatch through capability
255        // contracts.
256        if let Ok(module_vtable) = loader.get_module_vtable(&name) {
257            if let Ok(module) = PluginModule::new(name.clone(), module_vtable, config) {
258                let mut modules = self.extension_modules.write().unwrap();
259                modules.insert(name.clone(), Arc::new(module));
260            }
261        }
262
263        // If this plugin provides a language runtime capability, initialize it.
264        if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
265            let vtable = loader.get_language_runtime_vtable(&name)?;
266            let runtime =
267                crate::plugins::language_runtime::PluginLanguageRuntime::new(vtable, config)?;
268            let lang_id = runtime.language_id().to_string();
269            let mut runtimes = self.language_runtimes.write().unwrap();
270            runtimes.insert(lang_id, Arc::new(runtime));
271        }
272
273        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
274        loaded_extensions.insert(name, loaded_info.clone());
275
276        Ok(loaded_info)
277    }
278
279    /// Load an extension, merging claimed TOML section data into its init config.
280    ///
281    /// For each section claimed by the extension, looks it up in the project's
282    /// `extension_sections` and merges the data as JSON into the config.
283    /// Errors if a required section is missing.
284    pub fn load_extension_with_sections(
285        &self,
286        path: &Path,
287        config: &serde_json::Value,
288        extension_sections: &std::collections::HashMap<String, toml::Value>,
289        all_claimed: &mut std::collections::HashSet<String>,
290    ) -> Result<LoadedPlugin> {
291        // First, load the extension normally to get its section claims
292        let mut loader = self.extension_loader.write().unwrap();
293        let loaded_info = loader.load(path)?;
294        let name = loaded_info.name.clone();
295
296        // Collect claimed section names and check for collisions
297        for claim in &loaded_info.claimed_sections {
298            if !all_claimed.insert(claim.name.clone()) {
299                return Err(ShapeError::RuntimeError {
300                    message: format!(
301                        "Section '{}' is claimed by multiple extensions (collision detected when loading '{}')",
302                        claim.name, name
303                    ),
304                    location: None,
305                });
306            }
307        }
308
309        // Build merged config: start with the extension's own config, then
310        // overlay any claimed section data.
311        let mut merged_config = config.clone();
312        if let serde_json::Value::Object(ref mut map) = merged_config {
313            for claim in &loaded_info.claimed_sections {
314                if let Some(section_value) = extension_sections.get(&claim.name) {
315                    let json_value = crate::project::toml_to_json(section_value);
316                    map.insert(claim.name.clone(), json_value);
317                } else if claim.required {
318                    return Err(ShapeError::RuntimeError {
319                        message: format!(
320                            "Extension '{}' requires section '[{}]' in shape.toml, but it is missing",
321                            name, claim.name
322                        ),
323                        location: None,
324                    });
325                }
326            }
327        }
328
329        // Now initialize data source / module capabilities with the merged config.
330        if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
331            let vtable = loader.get_data_source_vtable(&name)?;
332            let source = PluginDataSource::new(name.clone(), vtable, &merged_config)?;
333            let mut sources = self.extension_sources.write().unwrap();
334            sources.insert(name.clone(), Arc::new(source));
335        } else {
336            let mut sources = self.extension_sources.write().unwrap();
337            sources.remove(&name);
338        }
339
340        if let Ok(module_vtable) = loader.get_module_vtable(&name) {
341            if let Ok(module) = PluginModule::new(name.clone(), module_vtable, &merged_config) {
342                let mut modules = self.extension_modules.write().unwrap();
343                modules.insert(name.clone(), Arc::new(module));
344            }
345        }
346
347        if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
348            let vtable = loader.get_language_runtime_vtable(&name)?;
349            let runtime = crate::plugins::language_runtime::PluginLanguageRuntime::new(
350                vtable,
351                &merged_config,
352            )?;
353            let lang_id = runtime.language_id().to_string();
354            let mut runtimes = self.language_runtimes.write().unwrap();
355            runtimes.insert(lang_id, Arc::new(runtime));
356        }
357
358        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
359        loaded_extensions.insert(name, loaded_info.clone());
360
361        Ok(loaded_info)
362    }
363
364    /// Get a language runtime by language identifier (e.g., "python").
365    pub fn get_language_runtime(
366        &self,
367        language_id: &str,
368    ) -> Option<Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
369        let runtimes = self.language_runtimes.read().unwrap();
370        runtimes.get(language_id).cloned()
371    }
372
373    /// Return all loaded language runtimes, keyed by language identifier.
374    pub fn language_runtimes(
375        &self,
376    ) -> std::collections::HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
377        let runtimes = self.language_runtimes.read().unwrap();
378        runtimes.clone()
379    }
380
381    /// Return child-LSP configurations declared by loaded language runtimes.
382    pub fn language_runtime_lsp_configs(
383        &self,
384    ) -> Vec<crate::plugins::language_runtime::RuntimeLspConfig> {
385        let runtimes = self.language_runtimes.read().unwrap();
386        let mut configs = Vec::new();
387
388        for runtime in runtimes.values() {
389            match runtime.lsp_config() {
390                Ok(Some(config)) => configs.push(config),
391                Ok(None) => {}
392                Err(err) => {
393                    tracing::warn!("failed to query language runtime LSP config: {}", err);
394                }
395            }
396        }
397
398        configs.sort_by(|left, right| left.language_id.cmp(&right.language_id));
399        configs
400    }
401
402    /// Get an extension data source by name
403    ///
404    /// # Arguments
405    ///
406    /// * `name` - Extension name
407    ///
408    /// # Returns
409    ///
410    /// The PluginDataSource if found
411    pub fn get_extension(&self, name: &str) -> Option<Arc<PluginDataSource>> {
412        let sources = self.extension_sources.read().unwrap();
413        sources.get(name).cloned()
414    }
415
416    /// Get extension module schema by module namespace name.
417    pub fn get_extension_module_schema(&self, module_name: &str) -> Option<ParsedModuleSchema> {
418        let modules = self.extension_modules.read().unwrap();
419        modules
420            .values()
421            .find(|m| m.schema().module_name == module_name)
422            .map(|m| m.schema().clone())
423    }
424
425    /// Build runtime extension modules from all loaded extension module capabilities.
426    pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
427        let modules = self.extension_modules.read().unwrap();
428        modules.values().map(|m| m.to_module_exports()).collect()
429    }
430
431    /// Invoke a module-capability export by module namespace and function name.
432    pub fn invoke_extension_module_nb(
433        &self,
434        module_name: &str,
435        function: &str,
436        args: &[ValueWord],
437    ) -> Result<ValueWord> {
438        let modules = self.extension_modules.read().unwrap();
439        let module = modules
440            .values()
441            .find(|m| m.schema().module_name == module_name)
442            .ok_or_else(|| ShapeError::RuntimeError {
443                message: format!("Module namespace '{}' is not loaded", module_name),
444                location: None,
445            })?;
446        module.invoke_nb(function, args)
447    }
448
449    /// Invoke a module-capability export by module namespace and function name.
450    pub fn invoke_extension_module_wire(
451        &self,
452        module_name: &str,
453        function: &str,
454        args: &[WireValue],
455    ) -> Result<WireValue> {
456        let modules = self.extension_modules.read().unwrap();
457        let module = modules
458            .values()
459            .find(|m| m.schema().module_name == module_name)
460            .ok_or_else(|| ShapeError::RuntimeError {
461                message: format!("Module namespace '{}' is not loaded", module_name),
462                location: None,
463            })?;
464        module.invoke_wire(function, args)
465    }
466
467    /// Get query schema for an extension (for LSP autocomplete)
468    ///
469    /// # Arguments
470    ///
471    /// * `name` - Extension name
472    ///
473    /// # Returns
474    ///
475    /// The query schema if extension exists
476    pub fn get_extension_query_schema(&self, name: &str) -> Option<ParsedQuerySchema> {
477        let sources = self.extension_sources.read().unwrap();
478        sources.get(name).map(|s| s.get_query_schema().clone())
479    }
480
481    /// Get output schema for an extension (for LSP autocomplete)
482    ///
483    /// # Arguments
484    ///
485    /// * `name` - Extension name
486    ///
487    /// # Returns
488    ///
489    /// The output schema if extension exists
490    pub fn get_extension_output_schema(&self, name: &str) -> Option<ParsedOutputSchema> {
491        let sources = self.extension_sources.read().unwrap();
492        sources.get(name).map(|s| s.get_output_schema().clone())
493    }
494
495    /// List all plugins with their query schemas (for LSP)
496    ///
497    /// # Returns
498    ///
499    /// Vector of (plugin_name, query_schema) pairs
500    pub fn list_extensions_with_schemas(&self) -> Vec<(String, ParsedQuerySchema)> {
501        let sources = self.extension_sources.read().unwrap();
502        sources
503            .iter()
504            .map(|(name, source)| (name.clone(), source.get_query_schema().clone()))
505            .collect()
506    }
507
508    /// List all loaded extension names
509    pub fn list_extensions(&self) -> Vec<String> {
510        let loaded = self.loaded_extensions.read().unwrap();
511        loaded.keys().cloned().collect()
512    }
513
514    /// Check if a plugin is loaded
515    pub fn has_extension(&self, name: &str) -> bool {
516        let loaded = self.loaded_extensions.read().unwrap();
517        loaded.contains_key(name)
518    }
519
520    /// Unload an extension
521    ///
522    /// # Arguments
523    ///
524    /// * `name` - Extension name to unload
525    ///
526    /// # Returns
527    ///
528    /// true if plugin was unloaded, false if not found
529    pub fn unload_extension(&self, name: &str) -> bool {
530        let mut sources = self.extension_sources.write().unwrap();
531        let removed_source = sources.remove(name).is_some();
532        drop(sources);
533
534        let mut modules = self.extension_modules.write().unwrap();
535        let removed_module = modules.remove(name).is_some();
536        drop(modules);
537
538        let mut loaded_extensions = self.loaded_extensions.write().unwrap();
539        let removed_plugin = loaded_extensions.remove(name).is_some();
540        drop(loaded_extensions);
541
542        if removed_plugin {
543            let mut loader = self.extension_loader.write().unwrap();
544            loader.unload(name);
545        }
546
547        removed_plugin || removed_source || removed_module
548    }
549}
550
551impl Default for ProviderRegistry {
552    fn default() -> Self {
553        Self::new()
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::async_provider::NullAsyncProvider;

    /// Wrap a `NullAsyncProvider` as a shared provider handle.
    fn null_provider() -> SharedAsyncProvider {
        Arc::new(NullAsyncProvider) as SharedAsyncProvider
    }

    // Provider management tests

    /// A registered provider can be found by name; unknown names cannot.
    #[test]
    fn test_register_and_get() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());

        assert!(registry.has_provider("test"));
        assert!(!registry.has_provider("nonexistent"));
        assert!(registry.get("test").is_some());
    }

    /// Setting a registered provider as default makes it resolvable.
    #[test]
    fn test_default_provider() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());

        assert!(registry.set_default("test").is_ok());
        assert!(registry.get_default().is_some());
        assert_eq!(registry.default_name(), Some("test".to_string()));
    }

    /// An unregistered name is rejected as default.
    #[test]
    fn test_set_default_nonexistent() {
        let registry = ProviderRegistry::new();

        assert!(registry.set_default("nonexistent").is_err());
    }

    /// All registered names are listed (order is not guaranteed, so sort).
    #[test]
    fn test_list_providers() {
        let registry = ProviderRegistry::new();
        for name in ["test1", "test2"] {
            registry.register(name, null_provider());
        }

        let mut names = registry.list_providers();
        names.sort();
        assert_eq!(names, vec!["test1", "test2"]);
    }

    /// Unregistering the default provider also resets the default.
    #[test]
    fn test_unregister() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());
        registry.set_default("test").unwrap();

        assert!(registry.unregister("test"));
        assert!(!registry.has_provider("test"));
        assert!(registry.get_default().is_none());
    }

    /// `clear` removes every provider and the default.
    #[test]
    fn test_clear() {
        let registry = ProviderRegistry::new();
        registry.register("test1", null_provider());
        registry.register("test2", null_provider());
        registry.set_default("test1").unwrap();

        registry.clear();

        assert_eq!(registry.list_providers().len(), 0);
        assert!(registry.get_default().is_none());
    }

    // Plugin management tests

    /// A fresh registry knows no extensions.
    #[test]
    fn test_plugin_not_loaded_by_default() {
        let registry = ProviderRegistry::new();

        assert!(!registry.has_extension("nonexistent"));
        assert!(registry.get_extension("nonexistent").is_none());
    }

    /// A fresh registry lists no extensions.
    #[test]
    fn test_list_extensions_empty() {
        let registry = ProviderRegistry::new();

        assert!(registry.list_extensions().is_empty());
    }

    /// A fresh registry lists no extension schemas.
    #[test]
    fn test_list_extensions_with_schemas_empty() {
        let registry = ProviderRegistry::new();

        assert!(registry.list_extensions_with_schemas().is_empty());
    }

    /// Query schema lookup for an unknown extension yields `None`.
    #[test]
    fn test_get_extension_query_schema_not_found() {
        let registry = ProviderRegistry::new();

        assert!(registry
            .get_extension_query_schema("nonexistent")
            .is_none());
    }

    /// Output schema lookup for an unknown extension yields `None`.
    #[test]
    fn test_get_extension_output_schema_not_found() {
        let registry = ProviderRegistry::new();

        assert!(registry
            .get_extension_output_schema("nonexistent")
            .is_none());
    }

    /// Unloading a non-existent plugin should return false.
    #[test]
    fn test_unload_plugin_not_loaded() {
        let registry = ProviderRegistry::new();

        assert!(!registry.unload_extension("nonexistent"));
    }

    /// Clear should also clear plugin sources.
    #[test]
    fn test_clear_removes_plugins() {
        let registry = ProviderRegistry::new();

        registry.clear();

        assert!(registry.list_extensions().is_empty());
    }
}