1use crate::data::SharedAsyncProvider;
11use crate::plugins::{
12 CapabilityKind, LoadedPlugin, ParsedModuleSchema, ParsedOutputSchema, ParsedQuerySchema,
13 PluginDataSource, PluginLoader, PluginModule,
14};
15use shape_ast::error::{Result, ShapeError};
16use shape_value::ValueWord;
17use shape_wire::WireValue;
18use std::collections::HashMap;
19use std::path::Path;
20use std::sync::{Arc, RwLock};
21
22#[derive(Clone)]
47pub struct ProviderRegistry {
48 providers: Arc<RwLock<HashMap<String, SharedAsyncProvider>>>,
50 default_provider: Arc<RwLock<Option<String>>>,
52 extension_sources: Arc<RwLock<HashMap<String, Arc<PluginDataSource>>>>,
54 extension_modules: Arc<RwLock<HashMap<String, Arc<PluginModule>>>>,
56 loaded_extensions: Arc<RwLock<HashMap<String, LoadedPlugin>>>,
58 extension_loader: Arc<RwLock<PluginLoader>>,
60 language_runtimes:
62 Arc<RwLock<HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>>>>,
63}
64
65impl ProviderRegistry {
66 pub fn new() -> Self {
68 Self {
69 providers: Arc::new(RwLock::new(HashMap::new())),
70 default_provider: Arc::new(RwLock::new(None)),
71 extension_sources: Arc::new(RwLock::new(HashMap::new())),
72 extension_modules: Arc::new(RwLock::new(HashMap::new())),
73 loaded_extensions: Arc::new(RwLock::new(HashMap::new())),
74 extension_loader: Arc::new(RwLock::new(PluginLoader::new())),
75 language_runtimes: Arc::new(RwLock::new(HashMap::new())),
76 }
77 }
78
79 pub fn register(&self, name: &str, provider: SharedAsyncProvider) {
92 let mut providers = self.providers.write().unwrap();
93 providers.insert(name.to_string(), provider);
94 }
95
96 pub fn get(&self, name: &str) -> Option<SharedAsyncProvider> {
106 let providers = self.providers.read().unwrap();
107 providers.get(name).cloned()
108 }
109
110 pub fn set_default(&self, name: &str) -> Result<()> {
120 let providers = self.providers.read().unwrap();
121 if !providers.contains_key(name) {
122 return Err(ShapeError::RuntimeError {
123 message: format!("Cannot set default provider: '{}' is not registered", name),
124 location: None,
125 });
126 }
127 drop(providers);
128
129 let mut default = self.default_provider.write().unwrap();
130 *default = Some(name.to_string());
131 Ok(())
132 }
133
134 pub fn get_default(&self) -> Option<SharedAsyncProvider> {
140 let default = self.default_provider.read().unwrap();
141 let name = default.as_ref().cloned();
142 drop(default);
143
144 name.and_then(|n| self.get(&n))
145 }
146
147 pub fn default_name(&self) -> Option<String> {
149 let default = self.default_provider.read().unwrap();
150 default.clone()
151 }
152
153 pub fn list_providers(&self) -> Vec<String> {
159 let providers = self.providers.read().unwrap();
160 providers.keys().cloned().collect()
161 }
162
163 pub fn has_provider(&self, name: &str) -> bool {
165 let providers = self.providers.read().unwrap();
166 providers.contains_key(name)
167 }
168
169 pub fn unregister(&self, name: &str) -> bool {
179 let mut providers = self.providers.write().unwrap();
180 let removed = providers.remove(name).is_some();
181
182 if removed {
184 let mut default = self.default_provider.write().unwrap();
185 if default.as_ref().map(|s| s == name).unwrap_or(false) {
186 *default = None;
187 }
188 }
189
190 removed
191 }
192
193 pub fn clear(&self) {
195 let mut providers = self.providers.write().unwrap();
196 providers.clear();
197
198 let mut default = self.default_provider.write().unwrap();
199 *default = None;
200
201 let mut extension_sources = self.extension_sources.write().unwrap();
202 extension_sources.clear();
203
204 let mut extension_modules = self.extension_modules.write().unwrap();
205 extension_modules.clear();
206
207 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
208 loaded_extensions.clear();
209
210 let mut runtimes = self.language_runtimes.write().unwrap();
211 runtimes.clear();
212 }
213
214 pub fn load_extension(&self, path: &Path, config: &serde_json::Value) -> Result<LoadedPlugin> {
233 let mut loader = self.extension_loader.write().unwrap();
235 let loaded_info = loader.load(path)?;
236 let name = loaded_info.name.clone();
237
238 if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
241 let vtable = loader.get_data_source_vtable(&name)?;
242 let source = PluginDataSource::new(name.clone(), vtable, config)?;
243
244 let mut sources = self.extension_sources.write().unwrap();
245 sources.insert(name.clone(), Arc::new(source));
246 } else {
247 let mut sources = self.extension_sources.write().unwrap();
250 sources.remove(&name);
251 }
252
253 if let Ok(module_vtable) = loader.get_module_vtable(&name) {
257 if let Ok(module) = PluginModule::new(name.clone(), module_vtable, config) {
258 let mut modules = self.extension_modules.write().unwrap();
259 modules.insert(name.clone(), Arc::new(module));
260 }
261 }
262
263 if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
265 let vtable = loader.get_language_runtime_vtable(&name)?;
266 let runtime =
267 crate::plugins::language_runtime::PluginLanguageRuntime::new(vtable, config)?;
268 let lang_id = runtime.language_id().to_string();
269 let mut runtimes = self.language_runtimes.write().unwrap();
270 runtimes.insert(lang_id, Arc::new(runtime));
271 }
272
273 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
274 loaded_extensions.insert(name, loaded_info.clone());
275
276 Ok(loaded_info)
277 }
278
279 pub fn load_extension_with_sections(
285 &self,
286 path: &Path,
287 config: &serde_json::Value,
288 extension_sections: &std::collections::HashMap<String, toml::Value>,
289 all_claimed: &mut std::collections::HashSet<String>,
290 ) -> Result<LoadedPlugin> {
291 let mut loader = self.extension_loader.write().unwrap();
293 let loaded_info = loader.load(path)?;
294 let name = loaded_info.name.clone();
295
296 for claim in &loaded_info.claimed_sections {
298 if !all_claimed.insert(claim.name.clone()) {
299 return Err(ShapeError::RuntimeError {
300 message: format!(
301 "Section '{}' is claimed by multiple extensions (collision detected when loading '{}')",
302 claim.name, name
303 ),
304 location: None,
305 });
306 }
307 }
308
309 let mut merged_config = config.clone();
312 if let serde_json::Value::Object(ref mut map) = merged_config {
313 for claim in &loaded_info.claimed_sections {
314 if let Some(section_value) = extension_sections.get(&claim.name) {
315 let json_value = crate::project::toml_to_json(section_value);
316 map.insert(claim.name.clone(), json_value);
317 } else if claim.required {
318 return Err(ShapeError::RuntimeError {
319 message: format!(
320 "Extension '{}' requires section '[{}]' in shape.toml, but it is missing",
321 name, claim.name
322 ),
323 location: None,
324 });
325 }
326 }
327 }
328
329 if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
331 let vtable = loader.get_data_source_vtable(&name)?;
332 let source = PluginDataSource::new(name.clone(), vtable, &merged_config)?;
333 let mut sources = self.extension_sources.write().unwrap();
334 sources.insert(name.clone(), Arc::new(source));
335 } else {
336 let mut sources = self.extension_sources.write().unwrap();
337 sources.remove(&name);
338 }
339
340 if let Ok(module_vtable) = loader.get_module_vtable(&name) {
341 if let Ok(module) = PluginModule::new(name.clone(), module_vtable, &merged_config) {
342 let mut modules = self.extension_modules.write().unwrap();
343 modules.insert(name.clone(), Arc::new(module));
344 }
345 }
346
347 if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
348 let vtable = loader.get_language_runtime_vtable(&name)?;
349 let runtime = crate::plugins::language_runtime::PluginLanguageRuntime::new(
350 vtable,
351 &merged_config,
352 )?;
353 let lang_id = runtime.language_id().to_string();
354 let mut runtimes = self.language_runtimes.write().unwrap();
355 runtimes.insert(lang_id, Arc::new(runtime));
356 }
357
358 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
359 loaded_extensions.insert(name, loaded_info.clone());
360
361 Ok(loaded_info)
362 }
363
364 pub fn get_language_runtime(
366 &self,
367 language_id: &str,
368 ) -> Option<Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
369 let runtimes = self.language_runtimes.read().unwrap();
370 runtimes.get(language_id).cloned()
371 }
372
373 pub fn language_runtimes(
375 &self,
376 ) -> std::collections::HashMap<
377 String,
378 Arc<crate::plugins::language_runtime::PluginLanguageRuntime>,
379 > {
380 let runtimes = self.language_runtimes.read().unwrap();
381 runtimes.clone()
382 }
383
384 pub fn language_runtime_lsp_configs(
386 &self,
387 ) -> Vec<crate::plugins::language_runtime::RuntimeLspConfig> {
388 let runtimes = self.language_runtimes.read().unwrap();
389 let mut configs = Vec::new();
390
391 for runtime in runtimes.values() {
392 match runtime.lsp_config() {
393 Ok(Some(config)) => configs.push(config),
394 Ok(None) => {}
395 Err(err) => {
396 tracing::warn!("failed to query language runtime LSP config: {}", err);
397 }
398 }
399 }
400
401 configs.sort_by(|left, right| left.language_id.cmp(&right.language_id));
402 configs
403 }
404
405 pub fn get_extension(&self, name: &str) -> Option<Arc<PluginDataSource>> {
415 let sources = self.extension_sources.read().unwrap();
416 sources.get(name).cloned()
417 }
418
419 pub fn get_extension_module_schema(&self, module_name: &str) -> Option<ParsedModuleSchema> {
421 let modules = self.extension_modules.read().unwrap();
422 modules
423 .values()
424 .find(|m| m.schema().module_name == module_name)
425 .map(|m| m.schema().clone())
426 }
427
428 pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
430 let modules = self.extension_modules.read().unwrap();
431 modules.values().map(|m| m.to_module_exports()).collect()
432 }
433
434 pub fn invoke_extension_module_nb(
436 &self,
437 module_name: &str,
438 function: &str,
439 args: &[ValueWord],
440 ) -> Result<ValueWord> {
441 let modules = self.extension_modules.read().unwrap();
442 let module = modules
443 .values()
444 .find(|m| m.schema().module_name == module_name)
445 .ok_or_else(|| ShapeError::RuntimeError {
446 message: format!("Module namespace '{}' is not loaded", module_name),
447 location: None,
448 })?;
449 module.invoke_nb(function, args)
450 }
451
452 pub fn invoke_extension_module_wire(
454 &self,
455 module_name: &str,
456 function: &str,
457 args: &[WireValue],
458 ) -> Result<WireValue> {
459 let modules = self.extension_modules.read().unwrap();
460 let module = modules
461 .values()
462 .find(|m| m.schema().module_name == module_name)
463 .ok_or_else(|| ShapeError::RuntimeError {
464 message: format!("Module namespace '{}' is not loaded", module_name),
465 location: None,
466 })?;
467 module.invoke_wire(function, args)
468 }
469
470 pub fn get_extension_query_schema(&self, name: &str) -> Option<ParsedQuerySchema> {
480 let sources = self.extension_sources.read().unwrap();
481 sources.get(name).map(|s| s.get_query_schema().clone())
482 }
483
484 pub fn get_extension_output_schema(&self, name: &str) -> Option<ParsedOutputSchema> {
494 let sources = self.extension_sources.read().unwrap();
495 sources.get(name).map(|s| s.get_output_schema().clone())
496 }
497
498 pub fn list_extensions_with_schemas(&self) -> Vec<(String, ParsedQuerySchema)> {
504 let sources = self.extension_sources.read().unwrap();
505 sources
506 .iter()
507 .map(|(name, source)| (name.clone(), source.get_query_schema().clone()))
508 .collect()
509 }
510
511 pub fn list_extensions(&self) -> Vec<String> {
513 let loaded = self.loaded_extensions.read().unwrap();
514 loaded.keys().cloned().collect()
515 }
516
517 pub fn has_extension(&self, name: &str) -> bool {
519 let loaded = self.loaded_extensions.read().unwrap();
520 loaded.contains_key(name)
521 }
522
523 pub fn unload_extension(&self, name: &str) -> bool {
533 let mut sources = self.extension_sources.write().unwrap();
534 let removed_source = sources.remove(name).is_some();
535 drop(sources);
536
537 let mut modules = self.extension_modules.write().unwrap();
538 let removed_module = modules.remove(name).is_some();
539 drop(modules);
540
541 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
542 let removed_plugin = loaded_extensions.remove(name).is_some();
543 drop(loaded_extensions);
544
545 if removed_plugin {
546 let mut loader = self.extension_loader.write().unwrap();
547 loader.unload(name);
548 }
549
550 removed_plugin || removed_source || removed_module
551 }
552}
553
554impl Default for ProviderRegistry {
555 fn default() -> Self {
556 Self::new()
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::async_provider::NullAsyncProvider;

    /// Shorthand for a provider that does nothing; the tests only care about bookkeeping.
    fn null_provider() -> SharedAsyncProvider {
        Arc::new(NullAsyncProvider) as SharedAsyncProvider
    }

    #[test]
    fn test_register_and_get() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());

        assert!(registry.has_provider("test"));
        assert!(!registry.has_provider("nonexistent"));
        assert!(registry.get("test").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_register_replaces_existing() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());
        registry.register("test", null_provider());

        assert_eq!(registry.list_providers(), vec!["test".to_string()]);
    }

    #[test]
    fn test_default_provider() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());

        assert!(registry.set_default("test").is_ok());
        assert!(registry.get_default().is_some());
        assert_eq!(registry.default_name(), Some("test".to_string()));
    }

    #[test]
    fn test_set_default_nonexistent() {
        let registry = ProviderRegistry::new();
        assert!(registry.set_default("nonexistent").is_err());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_failed_set_default_keeps_previous_default() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());
        registry.set_default("test").unwrap();

        assert!(registry.set_default("missing").is_err());
        assert_eq!(registry.default_name(), Some("test".to_string()));
    }

    #[test]
    fn test_list_providers() {
        let registry = ProviderRegistry::new();
        registry.register("test1", null_provider());
        registry.register("test2", null_provider());

        let mut names = registry.list_providers();
        names.sort();
        assert_eq!(names, vec!["test1", "test2"]);
    }

    #[test]
    fn test_unregister() {
        let registry = ProviderRegistry::new();
        registry.register("test", null_provider());
        registry.set_default("test").unwrap();

        assert!(registry.unregister("test"));
        assert!(!registry.has_provider("test"));
        assert!(registry.get_default().is_none());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_unregister_non_default_keeps_default() {
        let registry = ProviderRegistry::new();
        registry.register("primary", null_provider());
        registry.register("secondary", null_provider());
        registry.set_default("primary").unwrap();

        assert!(registry.unregister("secondary"));
        assert_eq!(registry.default_name(), Some("primary".to_string()));
        assert!(registry.get_default().is_some());
    }

    #[test]
    fn test_unregister_nonexistent() {
        let registry = ProviderRegistry::new();
        assert!(!registry.unregister("nonexistent"));
    }

    #[test]
    fn test_clear() {
        let registry = ProviderRegistry::new();
        registry.register("test1", null_provider());
        registry.register("test2", null_provider());
        registry.set_default("test1").unwrap();

        registry.clear();

        assert_eq!(registry.list_providers().len(), 0);
        assert!(registry.get_default().is_none());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_clones_share_state() {
        // Every field is an `Arc<RwLock<..>>`, so a clone is another handle to the same registry.
        let registry = ProviderRegistry::new();
        let handle = registry.clone();

        handle.register("shared", null_provider());
        handle.set_default("shared").unwrap();

        assert!(registry.has_provider("shared"));
        assert_eq!(registry.default_name(), Some("shared".to_string()));
    }

    #[test]
    fn test_default_is_empty() {
        let registry = ProviderRegistry::default();

        assert!(registry.list_providers().is_empty());
        assert!(registry.default_name().is_none());
        assert!(registry.list_extensions().is_empty());
    }

    #[test]
    fn test_plugin_not_loaded_by_default() {
        let registry = ProviderRegistry::new();

        assert!(!registry.has_extension("nonexistent"));
        assert!(registry.get_extension("nonexistent").is_none());
    }

    #[test]
    fn test_list_extensions_empty() {
        let registry = ProviderRegistry::new();
        assert!(registry.list_extensions().is_empty());
    }

    #[test]
    fn test_list_extensions_with_schemas_empty() {
        let registry = ProviderRegistry::new();
        assert!(registry.list_extensions_with_schemas().is_empty());
    }

    #[test]
    fn test_get_extension_query_schema_not_found() {
        let registry = ProviderRegistry::new();
        assert!(registry.get_extension_query_schema("nonexistent").is_none());
    }

    #[test]
    fn test_get_extension_output_schema_not_found() {
        let registry = ProviderRegistry::new();
        assert!(registry.get_extension_output_schema("nonexistent").is_none());
    }

    #[test]
    fn test_language_runtime_and_module_queries_empty() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_language_runtime("python").is_none());
        assert!(registry.language_runtimes().is_empty());
        assert!(registry.language_runtime_lsp_configs().is_empty());
        assert!(registry.module_exports_from_extensions().is_empty());
        assert!(registry.get_extension_module_schema("missing").is_none());
    }

    #[test]
    fn test_invoke_unloaded_module_errors() {
        let registry = ProviderRegistry::new();

        assert!(registry
            .invoke_extension_module_nb("missing", "f", &[])
            .is_err());
        assert!(registry
            .invoke_extension_module_wire("missing", "f", &[])
            .is_err());
    }

    #[test]
    fn test_unload_plugin_not_loaded() {
        let registry = ProviderRegistry::new();
        assert!(!registry.unload_extension("nonexistent"));
    }

    #[test]
    fn test_clear_removes_plugins() {
        // No real plugin library is available in unit tests, so this only checks that
        // clearing an empty registry leaves every extension view empty.
        let registry = ProviderRegistry::new();

        registry.clear();

        assert!(registry.list_extensions().is_empty());
        assert!(registry.list_extensions_with_schemas().is_empty());
        assert!(registry.language_runtimes().is_empty());
    }
}