1use crate::data::SharedAsyncProvider;
11use crate::plugins::{
12 CapabilityKind, LoadedPlugin, ParsedModuleSchema, ParsedOutputSchema, ParsedQuerySchema,
13 PluginDataSource, PluginLoader, PluginModule,
14};
15use shape_ast::error::{Result, ShapeError};
16use shape_value::ValueWord;
17use shape_wire::WireValue;
18use std::collections::HashMap;
19use std::path::Path;
20use std::sync::{Arc, RwLock};
21
22#[derive(Clone)]
47pub struct ProviderRegistry {
48 providers: Arc<RwLock<HashMap<String, SharedAsyncProvider>>>,
50 default_provider: Arc<RwLock<Option<String>>>,
52 extension_sources: Arc<RwLock<HashMap<String, Arc<PluginDataSource>>>>,
54 extension_modules: Arc<RwLock<HashMap<String, Arc<PluginModule>>>>,
56 loaded_extensions: Arc<RwLock<HashMap<String, LoadedPlugin>>>,
58 extension_loader: Arc<RwLock<PluginLoader>>,
60 language_runtimes:
62 Arc<RwLock<HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>>>>,
63}
64
65impl ProviderRegistry {
66 pub fn new() -> Self {
68 Self {
69 providers: Arc::new(RwLock::new(HashMap::new())),
70 default_provider: Arc::new(RwLock::new(None)),
71 extension_sources: Arc::new(RwLock::new(HashMap::new())),
72 extension_modules: Arc::new(RwLock::new(HashMap::new())),
73 loaded_extensions: Arc::new(RwLock::new(HashMap::new())),
74 extension_loader: Arc::new(RwLock::new(PluginLoader::new())),
75 language_runtimes: Arc::new(RwLock::new(HashMap::new())),
76 }
77 }
78
79 pub fn register(&self, name: &str, provider: SharedAsyncProvider) {
92 let mut providers = self.providers.write().unwrap();
93 providers.insert(name.to_string(), provider);
94 }
95
96 pub fn get(&self, name: &str) -> Option<SharedAsyncProvider> {
106 let providers = self.providers.read().unwrap();
107 providers.get(name).cloned()
108 }
109
110 pub fn set_default(&self, name: &str) -> Result<()> {
120 let providers = self.providers.read().unwrap();
121 if !providers.contains_key(name) {
122 return Err(ShapeError::RuntimeError {
123 message: format!("Cannot set default provider: '{}' is not registered", name),
124 location: None,
125 });
126 }
127 drop(providers);
128
129 let mut default = self.default_provider.write().unwrap();
130 *default = Some(name.to_string());
131 Ok(())
132 }
133
134 pub fn get_default(&self) -> Option<SharedAsyncProvider> {
140 let default = self.default_provider.read().unwrap();
141 let name = default.as_ref().cloned();
142 drop(default);
143
144 name.and_then(|n| self.get(&n))
145 }
146
147 pub fn default_name(&self) -> Option<String> {
149 let default = self.default_provider.read().unwrap();
150 default.clone()
151 }
152
153 pub fn list_providers(&self) -> Vec<String> {
159 let providers = self.providers.read().unwrap();
160 providers.keys().cloned().collect()
161 }
162
163 pub fn has_provider(&self, name: &str) -> bool {
165 let providers = self.providers.read().unwrap();
166 providers.contains_key(name)
167 }
168
169 pub fn unregister(&self, name: &str) -> bool {
179 let mut providers = self.providers.write().unwrap();
180 let removed = providers.remove(name).is_some();
181
182 if removed {
184 let mut default = self.default_provider.write().unwrap();
185 if default.as_ref().map(|s| s == name).unwrap_or(false) {
186 *default = None;
187 }
188 }
189
190 removed
191 }
192
193 pub fn clear(&self) {
195 let mut providers = self.providers.write().unwrap();
196 providers.clear();
197
198 let mut default = self.default_provider.write().unwrap();
199 *default = None;
200
201 let mut extension_sources = self.extension_sources.write().unwrap();
202 extension_sources.clear();
203
204 let mut extension_modules = self.extension_modules.write().unwrap();
205 extension_modules.clear();
206
207 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
208 loaded_extensions.clear();
209
210 let mut runtimes = self.language_runtimes.write().unwrap();
211 runtimes.clear();
212 }
213
214 pub fn load_extension(&self, path: &Path, config: &serde_json::Value) -> Result<LoadedPlugin> {
233 let mut loader = self.extension_loader.write().unwrap();
235 let loaded_info = loader.load(path)?;
236 let name = loaded_info.name.clone();
237
238 if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
241 let vtable = loader.get_data_source_vtable(&name)?;
242 let source = PluginDataSource::new(name.clone(), vtable, config)?;
243
244 let mut sources = self.extension_sources.write().unwrap();
245 sources.insert(name.clone(), Arc::new(source));
246 } else {
247 let mut sources = self.extension_sources.write().unwrap();
250 sources.remove(&name);
251 }
252
253 if let Ok(module_vtable) = loader.get_module_vtable(&name) {
257 if let Ok(module) = PluginModule::new(name.clone(), module_vtable, config) {
258 let mut modules = self.extension_modules.write().unwrap();
259 modules.insert(name.clone(), Arc::new(module));
260 }
261 }
262
263 if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
265 let vtable = loader.get_language_runtime_vtable(&name)?;
266 let runtime =
267 crate::plugins::language_runtime::PluginLanguageRuntime::new(vtable, config)?;
268 let lang_id = runtime.language_id().to_string();
269 let mut runtimes = self.language_runtimes.write().unwrap();
270 runtimes.insert(lang_id, Arc::new(runtime));
271 }
272
273 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
274 loaded_extensions.insert(name, loaded_info.clone());
275
276 Ok(loaded_info)
277 }
278
279 pub fn load_extension_with_sections(
285 &self,
286 path: &Path,
287 config: &serde_json::Value,
288 extension_sections: &std::collections::HashMap<String, toml::Value>,
289 all_claimed: &mut std::collections::HashSet<String>,
290 ) -> Result<LoadedPlugin> {
291 let mut loader = self.extension_loader.write().unwrap();
293 let loaded_info = loader.load(path)?;
294 let name = loaded_info.name.clone();
295
296 for claim in &loaded_info.claimed_sections {
298 if !all_claimed.insert(claim.name.clone()) {
299 return Err(ShapeError::RuntimeError {
300 message: format!(
301 "Section '{}' is claimed by multiple extensions (collision detected when loading '{}')",
302 claim.name, name
303 ),
304 location: None,
305 });
306 }
307 }
308
309 let mut merged_config = config.clone();
312 if let serde_json::Value::Object(ref mut map) = merged_config {
313 for claim in &loaded_info.claimed_sections {
314 if let Some(section_value) = extension_sections.get(&claim.name) {
315 let json_value = crate::project::toml_to_json(section_value);
316 map.insert(claim.name.clone(), json_value);
317 } else if claim.required {
318 return Err(ShapeError::RuntimeError {
319 message: format!(
320 "Extension '{}' requires section '[{}]' in shape.toml, but it is missing",
321 name, claim.name
322 ),
323 location: None,
324 });
325 }
326 }
327 }
328
329 if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
331 let vtable = loader.get_data_source_vtable(&name)?;
332 let source = PluginDataSource::new(name.clone(), vtable, &merged_config)?;
333 let mut sources = self.extension_sources.write().unwrap();
334 sources.insert(name.clone(), Arc::new(source));
335 } else {
336 let mut sources = self.extension_sources.write().unwrap();
337 sources.remove(&name);
338 }
339
340 if let Ok(module_vtable) = loader.get_module_vtable(&name) {
341 if let Ok(module) = PluginModule::new(name.clone(), module_vtable, &merged_config) {
342 let mut modules = self.extension_modules.write().unwrap();
343 modules.insert(name.clone(), Arc::new(module));
344 }
345 }
346
347 if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
348 let vtable = loader.get_language_runtime_vtable(&name)?;
349 let runtime = crate::plugins::language_runtime::PluginLanguageRuntime::new(
350 vtable,
351 &merged_config,
352 )?;
353 let lang_id = runtime.language_id().to_string();
354 let mut runtimes = self.language_runtimes.write().unwrap();
355 runtimes.insert(lang_id, Arc::new(runtime));
356 }
357
358 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
359 loaded_extensions.insert(name, loaded_info.clone());
360
361 Ok(loaded_info)
362 }
363
364 pub fn get_language_runtime(
366 &self,
367 language_id: &str,
368 ) -> Option<Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
369 let runtimes = self.language_runtimes.read().unwrap();
370 runtimes.get(language_id).cloned()
371 }
372
373 pub fn language_runtime_lsp_configs(
375 &self,
376 ) -> Vec<crate::plugins::language_runtime::RuntimeLspConfig> {
377 let runtimes = self.language_runtimes.read().unwrap();
378 let mut configs = Vec::new();
379
380 for runtime in runtimes.values() {
381 match runtime.lsp_config() {
382 Ok(Some(config)) => configs.push(config),
383 Ok(None) => {}
384 Err(err) => {
385 tracing::warn!("failed to query language runtime LSP config: {}", err);
386 }
387 }
388 }
389
390 configs.sort_by(|left, right| left.language_id.cmp(&right.language_id));
391 configs
392 }
393
394 pub fn get_extension(&self, name: &str) -> Option<Arc<PluginDataSource>> {
404 let sources = self.extension_sources.read().unwrap();
405 sources.get(name).cloned()
406 }
407
408 pub fn get_extension_module_schema(&self, module_name: &str) -> Option<ParsedModuleSchema> {
410 let modules = self.extension_modules.read().unwrap();
411 modules
412 .values()
413 .find(|m| m.schema().module_name == module_name)
414 .map(|m| m.schema().clone())
415 }
416
417 pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
419 let modules = self.extension_modules.read().unwrap();
420 modules.values().map(|m| m.to_module_exports()).collect()
421 }
422
423 pub fn invoke_extension_module_nb(
425 &self,
426 module_name: &str,
427 function: &str,
428 args: &[ValueWord],
429 ) -> Result<ValueWord> {
430 let modules = self.extension_modules.read().unwrap();
431 let module = modules
432 .values()
433 .find(|m| m.schema().module_name == module_name)
434 .ok_or_else(|| ShapeError::RuntimeError {
435 message: format!("Module namespace '{}' is not loaded", module_name),
436 location: None,
437 })?;
438 module.invoke_nb(function, args)
439 }
440
441 pub fn invoke_extension_module_wire(
443 &self,
444 module_name: &str,
445 function: &str,
446 args: &[WireValue],
447 ) -> Result<WireValue> {
448 let modules = self.extension_modules.read().unwrap();
449 let module = modules
450 .values()
451 .find(|m| m.schema().module_name == module_name)
452 .ok_or_else(|| ShapeError::RuntimeError {
453 message: format!("Module namespace '{}' is not loaded", module_name),
454 location: None,
455 })?;
456 module.invoke_wire(function, args)
457 }
458
459 pub fn get_extension_query_schema(&self, name: &str) -> Option<ParsedQuerySchema> {
469 let sources = self.extension_sources.read().unwrap();
470 sources.get(name).map(|s| s.get_query_schema().clone())
471 }
472
473 pub fn get_extension_output_schema(&self, name: &str) -> Option<ParsedOutputSchema> {
483 let sources = self.extension_sources.read().unwrap();
484 sources.get(name).map(|s| s.get_output_schema().clone())
485 }
486
487 pub fn list_extensions_with_schemas(&self) -> Vec<(String, ParsedQuerySchema)> {
493 let sources = self.extension_sources.read().unwrap();
494 sources
495 .iter()
496 .map(|(name, source)| (name.clone(), source.get_query_schema().clone()))
497 .collect()
498 }
499
500 pub fn list_extensions(&self) -> Vec<String> {
502 let loaded = self.loaded_extensions.read().unwrap();
503 loaded.keys().cloned().collect()
504 }
505
506 pub fn has_extension(&self, name: &str) -> bool {
508 let loaded = self.loaded_extensions.read().unwrap();
509 loaded.contains_key(name)
510 }
511
512 pub fn unload_extension(&self, name: &str) -> bool {
522 let mut sources = self.extension_sources.write().unwrap();
523 let removed_source = sources.remove(name).is_some();
524 drop(sources);
525
526 let mut modules = self.extension_modules.write().unwrap();
527 let removed_module = modules.remove(name).is_some();
528 drop(modules);
529
530 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
531 let removed_plugin = loaded_extensions.remove(name).is_some();
532 drop(loaded_extensions);
533
534 if removed_plugin {
535 let mut loader = self.extension_loader.write().unwrap();
536 loader.unload(name);
537 }
538
539 removed_plugin || removed_source || removed_module
540 }
541}
542
543impl Default for ProviderRegistry {
544 fn default() -> Self {
545 Self::new()
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::async_provider::NullAsyncProvider;

    /// Convenience constructor for a provider handle usable in tests.
    fn null_provider() -> SharedAsyncProvider {
        Arc::new(NullAsyncProvider) as SharedAsyncProvider
    }

    #[test]
    fn test_register_and_get() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("test", provider.clone());

        assert!(registry.has_provider("test"));
        assert!(!registry.has_provider("nonexistent"));
        assert!(registry.get("test").is_some());
    }

    #[test]
    fn test_default_provider() {
        let registry = ProviderRegistry::new();

        registry.register("test", null_provider());

        assert!(registry.set_default("test").is_ok());
        assert!(registry.get_default().is_some());
        assert_eq!(registry.default_name(), Some("test".to_string()));
    }

    #[test]
    fn test_get_default_unset() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_default().is_none());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_set_default_nonexistent() {
        let registry = ProviderRegistry::new();
        assert!(registry.set_default("nonexistent").is_err());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_list_providers() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("test1", provider.clone());
        registry.register("test2", provider);

        let mut names = registry.list_providers();
        names.sort();
        assert_eq!(names, vec!["test1", "test2"]);
    }

    #[test]
    fn test_register_replaces_existing_name() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("dup", provider.clone());
        registry.register("dup", provider);

        assert_eq!(registry.list_providers(), vec!["dup".to_string()]);
    }

    #[test]
    fn test_unregister() {
        let registry = ProviderRegistry::new();

        registry.register("test", null_provider());
        registry.set_default("test").unwrap();

        assert!(registry.unregister("test"));
        assert!(!registry.has_provider("test"));
        assert!(registry.get_default().is_none());
    }

    #[test]
    fn test_unregister_missing_returns_false() {
        let registry = ProviderRegistry::new();
        assert!(!registry.unregister("missing"));
    }

    #[test]
    fn test_unregister_non_default_keeps_default() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("keep", provider.clone());
        registry.register("drop", provider);
        registry.set_default("keep").unwrap();

        assert!(registry.unregister("drop"));
        assert_eq!(registry.default_name(), Some("keep".to_string()));
        assert!(registry.get_default().is_some());
    }

    #[test]
    fn test_clear() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("test1", provider.clone());
        registry.register("test2", provider);
        registry.set_default("test1").unwrap();

        registry.clear();

        assert_eq!(registry.list_providers().len(), 0);
        assert!(registry.get_default().is_none());
    }

    #[test]
    fn test_clones_share_state() {
        let registry = ProviderRegistry::new();
        let handle = registry.clone();

        handle.register("shared", null_provider());

        assert!(registry.has_provider("shared"));
    }

    #[test]
    fn test_default_is_empty() {
        let registry = ProviderRegistry::default();

        assert!(registry.list_providers().is_empty());
        assert!(registry.list_extensions().is_empty());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_plugin_not_loaded_by_default() {
        let registry = ProviderRegistry::new();

        assert!(!registry.has_extension("nonexistent"));
        assert!(registry.get_extension("nonexistent").is_none());
    }

    #[test]
    fn test_list_extensions_empty() {
        let registry = ProviderRegistry::new();

        let plugins = registry.list_extensions();
        assert!(plugins.is_empty());
    }

    #[test]
    fn test_list_extensions_with_schemas_empty() {
        let registry = ProviderRegistry::new();

        let schemas = registry.list_extensions_with_schemas();
        assert!(schemas.is_empty());
    }

    #[test]
    fn test_get_extension_query_schema_not_found() {
        let registry = ProviderRegistry::new();

        let schema = registry.get_extension_query_schema("nonexistent");
        assert!(schema.is_none());
    }

    #[test]
    fn test_get_extension_output_schema_not_found() {
        let registry = ProviderRegistry::new();

        let schema = registry.get_extension_output_schema("nonexistent");
        assert!(schema.is_none());
    }

    #[test]
    fn test_language_runtime_missing() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_language_runtime("python").is_none());
        assert!(registry.language_runtime_lsp_configs().is_empty());
    }

    #[test]
    fn test_unloaded_module_namespace() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_extension_module_schema("missing").is_none());
        assert!(registry.module_exports_from_extensions().is_empty());
        assert!(registry
            .invoke_extension_module_nb("missing", "f", &[])
            .is_err());
        assert!(registry
            .invoke_extension_module_wire("missing", "f", &[])
            .is_err());
    }

    #[test]
    fn test_unload_plugin_not_loaded() {
        let registry = ProviderRegistry::new();

        assert!(!registry.unload_extension("nonexistent"));
    }

    #[test]
    fn test_clear_removes_plugins() {
        let registry = ProviderRegistry::new();

        registry.clear();

        assert!(registry.list_extensions().is_empty());
    }
}