1use crate::data::SharedAsyncProvider;
11use crate::plugins::{
12 CapabilityKind, LoadedPlugin, ParsedModuleSchema, ParsedOutputSchema, ParsedQuerySchema,
13 PluginDataSource, PluginLoader, PluginModule,
14};
15use shape_ast::error::{Result, ShapeError};
16use shape_value::ValueWord;
17use shape_wire::WireValue;
18use std::collections::HashMap;
19use std::path::Path;
20use std::sync::{Arc, RwLock};
21
22#[derive(Clone)]
47pub struct ProviderRegistry {
48 providers: Arc<RwLock<HashMap<String, SharedAsyncProvider>>>,
50 default_provider: Arc<RwLock<Option<String>>>,
52 extension_sources: Arc<RwLock<HashMap<String, Arc<PluginDataSource>>>>,
54 extension_modules: Arc<RwLock<HashMap<String, Arc<PluginModule>>>>,
56 loaded_extensions: Arc<RwLock<HashMap<String, LoadedPlugin>>>,
58 extension_loader: Arc<RwLock<PluginLoader>>,
60 language_runtimes:
62 Arc<RwLock<HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>>>>,
63}
64
65impl ProviderRegistry {
66 pub fn new() -> Self {
68 Self {
69 providers: Arc::new(RwLock::new(HashMap::new())),
70 default_provider: Arc::new(RwLock::new(None)),
71 extension_sources: Arc::new(RwLock::new(HashMap::new())),
72 extension_modules: Arc::new(RwLock::new(HashMap::new())),
73 loaded_extensions: Arc::new(RwLock::new(HashMap::new())),
74 extension_loader: Arc::new(RwLock::new(PluginLoader::new())),
75 language_runtimes: Arc::new(RwLock::new(HashMap::new())),
76 }
77 }
78
79 pub fn register(&self, name: &str, provider: SharedAsyncProvider) {
92 let mut providers = self.providers.write().unwrap();
93 providers.insert(name.to_string(), provider);
94 }
95
96 pub fn get(&self, name: &str) -> Option<SharedAsyncProvider> {
106 let providers = self.providers.read().unwrap();
107 providers.get(name).cloned()
108 }
109
110 pub fn set_default(&self, name: &str) -> Result<()> {
120 let providers = self.providers.read().unwrap();
121 if !providers.contains_key(name) {
122 return Err(ShapeError::RuntimeError {
123 message: format!("Cannot set default provider: '{}' is not registered", name),
124 location: None,
125 });
126 }
127 drop(providers);
128
129 let mut default = self.default_provider.write().unwrap();
130 *default = Some(name.to_string());
131 Ok(())
132 }
133
134 pub fn get_default(&self) -> Option<SharedAsyncProvider> {
140 let default = self.default_provider.read().unwrap();
141 let name = default.as_ref().cloned();
142 drop(default);
143
144 name.and_then(|n| self.get(&n))
145 }
146
147 pub fn default_name(&self) -> Option<String> {
149 let default = self.default_provider.read().unwrap();
150 default.clone()
151 }
152
153 pub fn list_providers(&self) -> Vec<String> {
159 let providers = self.providers.read().unwrap();
160 providers.keys().cloned().collect()
161 }
162
163 pub fn has_provider(&self, name: &str) -> bool {
165 let providers = self.providers.read().unwrap();
166 providers.contains_key(name)
167 }
168
169 pub fn unregister(&self, name: &str) -> bool {
179 let mut providers = self.providers.write().unwrap();
180 let removed = providers.remove(name).is_some();
181
182 if removed {
184 let mut default = self.default_provider.write().unwrap();
185 if default.as_ref().map(|s| s == name).unwrap_or(false) {
186 *default = None;
187 }
188 }
189
190 removed
191 }
192
193 pub fn clear(&self) {
195 let mut providers = self.providers.write().unwrap();
196 providers.clear();
197
198 let mut default = self.default_provider.write().unwrap();
199 *default = None;
200
201 let mut extension_sources = self.extension_sources.write().unwrap();
202 extension_sources.clear();
203
204 let mut extension_modules = self.extension_modules.write().unwrap();
205 extension_modules.clear();
206
207 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
208 loaded_extensions.clear();
209
210 let mut runtimes = self.language_runtimes.write().unwrap();
211 runtimes.clear();
212 }
213
214 pub fn load_extension(&self, path: &Path, config: &serde_json::Value) -> Result<LoadedPlugin> {
233 let mut loader = self.extension_loader.write().unwrap();
235 let loaded_info = loader.load(path)?;
236 let name = loaded_info.name.clone();
237
238 if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
241 let vtable = loader.get_data_source_vtable(&name)?;
242 let source = PluginDataSource::new(name.clone(), vtable, config)?;
243
244 let mut sources = self.extension_sources.write().unwrap();
245 sources.insert(name.clone(), Arc::new(source));
246 } else {
247 let mut sources = self.extension_sources.write().unwrap();
250 sources.remove(&name);
251 }
252
253 if let Ok(module_vtable) = loader.get_module_vtable(&name) {
257 if let Ok(module) = PluginModule::new(name.clone(), module_vtable, config) {
258 let mut modules = self.extension_modules.write().unwrap();
259 modules.insert(name.clone(), Arc::new(module));
260 }
261 }
262
263 if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
265 let vtable = loader.get_language_runtime_vtable(&name)?;
266 let runtime =
267 crate::plugins::language_runtime::PluginLanguageRuntime::new(vtable, config)?;
268 let lang_id = runtime.language_id().to_string();
269 let mut runtimes = self.language_runtimes.write().unwrap();
270 runtimes.insert(lang_id, Arc::new(runtime));
271 }
272
273 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
274 loaded_extensions.insert(name, loaded_info.clone());
275
276 Ok(loaded_info)
277 }
278
279 pub fn load_extension_with_sections(
285 &self,
286 path: &Path,
287 config: &serde_json::Value,
288 extension_sections: &std::collections::HashMap<String, toml::Value>,
289 all_claimed: &mut std::collections::HashSet<String>,
290 ) -> Result<LoadedPlugin> {
291 let mut loader = self.extension_loader.write().unwrap();
293 let loaded_info = loader.load(path)?;
294 let name = loaded_info.name.clone();
295
296 for claim in &loaded_info.claimed_sections {
298 if !all_claimed.insert(claim.name.clone()) {
299 return Err(ShapeError::RuntimeError {
300 message: format!(
301 "Section '{}' is claimed by multiple extensions (collision detected when loading '{}')",
302 claim.name, name
303 ),
304 location: None,
305 });
306 }
307 }
308
309 let mut merged_config = config.clone();
312 if let serde_json::Value::Object(ref mut map) = merged_config {
313 for claim in &loaded_info.claimed_sections {
314 if let Some(section_value) = extension_sections.get(&claim.name) {
315 let json_value = crate::project::toml_to_json(section_value);
316 map.insert(claim.name.clone(), json_value);
317 } else if claim.required {
318 return Err(ShapeError::RuntimeError {
319 message: format!(
320 "Extension '{}' requires section '[{}]' in shape.toml, but it is missing",
321 name, claim.name
322 ),
323 location: None,
324 });
325 }
326 }
327 }
328
329 if loaded_info.has_capability_kind(CapabilityKind::DataSource) {
331 let vtable = loader.get_data_source_vtable(&name)?;
332 let source = PluginDataSource::new(name.clone(), vtable, &merged_config)?;
333 let mut sources = self.extension_sources.write().unwrap();
334 sources.insert(name.clone(), Arc::new(source));
335 } else {
336 let mut sources = self.extension_sources.write().unwrap();
337 sources.remove(&name);
338 }
339
340 if let Ok(module_vtable) = loader.get_module_vtable(&name) {
341 if let Ok(module) = PluginModule::new(name.clone(), module_vtable, &merged_config) {
342 let mut modules = self.extension_modules.write().unwrap();
343 modules.insert(name.clone(), Arc::new(module));
344 }
345 }
346
347 if loaded_info.has_capability_kind(CapabilityKind::LanguageRuntime) {
348 let vtable = loader.get_language_runtime_vtable(&name)?;
349 let runtime = crate::plugins::language_runtime::PluginLanguageRuntime::new(
350 vtable,
351 &merged_config,
352 )?;
353 let lang_id = runtime.language_id().to_string();
354 let mut runtimes = self.language_runtimes.write().unwrap();
355 runtimes.insert(lang_id, Arc::new(runtime));
356 }
357
358 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
359 loaded_extensions.insert(name, loaded_info.clone());
360
361 Ok(loaded_info)
362 }
363
364 pub fn get_language_runtime(
366 &self,
367 language_id: &str,
368 ) -> Option<Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
369 let runtimes = self.language_runtimes.read().unwrap();
370 runtimes.get(language_id).cloned()
371 }
372
373 pub fn language_runtimes(
375 &self,
376 ) -> std::collections::HashMap<String, Arc<crate::plugins::language_runtime::PluginLanguageRuntime>> {
377 let runtimes = self.language_runtimes.read().unwrap();
378 runtimes.clone()
379 }
380
381 pub fn language_runtime_lsp_configs(
383 &self,
384 ) -> Vec<crate::plugins::language_runtime::RuntimeLspConfig> {
385 let runtimes = self.language_runtimes.read().unwrap();
386 let mut configs = Vec::new();
387
388 for runtime in runtimes.values() {
389 match runtime.lsp_config() {
390 Ok(Some(config)) => configs.push(config),
391 Ok(None) => {}
392 Err(err) => {
393 tracing::warn!("failed to query language runtime LSP config: {}", err);
394 }
395 }
396 }
397
398 configs.sort_by(|left, right| left.language_id.cmp(&right.language_id));
399 configs
400 }
401
402 pub fn get_extension(&self, name: &str) -> Option<Arc<PluginDataSource>> {
412 let sources = self.extension_sources.read().unwrap();
413 sources.get(name).cloned()
414 }
415
416 pub fn get_extension_module_schema(&self, module_name: &str) -> Option<ParsedModuleSchema> {
418 let modules = self.extension_modules.read().unwrap();
419 modules
420 .values()
421 .find(|m| m.schema().module_name == module_name)
422 .map(|m| m.schema().clone())
423 }
424
425 pub fn module_exports_from_extensions(&self) -> Vec<crate::module_exports::ModuleExports> {
427 let modules = self.extension_modules.read().unwrap();
428 modules.values().map(|m| m.to_module_exports()).collect()
429 }
430
431 pub fn invoke_extension_module_nb(
433 &self,
434 module_name: &str,
435 function: &str,
436 args: &[ValueWord],
437 ) -> Result<ValueWord> {
438 let modules = self.extension_modules.read().unwrap();
439 let module = modules
440 .values()
441 .find(|m| m.schema().module_name == module_name)
442 .ok_or_else(|| ShapeError::RuntimeError {
443 message: format!("Module namespace '{}' is not loaded", module_name),
444 location: None,
445 })?;
446 module.invoke_nb(function, args)
447 }
448
449 pub fn invoke_extension_module_wire(
451 &self,
452 module_name: &str,
453 function: &str,
454 args: &[WireValue],
455 ) -> Result<WireValue> {
456 let modules = self.extension_modules.read().unwrap();
457 let module = modules
458 .values()
459 .find(|m| m.schema().module_name == module_name)
460 .ok_or_else(|| ShapeError::RuntimeError {
461 message: format!("Module namespace '{}' is not loaded", module_name),
462 location: None,
463 })?;
464 module.invoke_wire(function, args)
465 }
466
467 pub fn get_extension_query_schema(&self, name: &str) -> Option<ParsedQuerySchema> {
477 let sources = self.extension_sources.read().unwrap();
478 sources.get(name).map(|s| s.get_query_schema().clone())
479 }
480
481 pub fn get_extension_output_schema(&self, name: &str) -> Option<ParsedOutputSchema> {
491 let sources = self.extension_sources.read().unwrap();
492 sources.get(name).map(|s| s.get_output_schema().clone())
493 }
494
495 pub fn list_extensions_with_schemas(&self) -> Vec<(String, ParsedQuerySchema)> {
501 let sources = self.extension_sources.read().unwrap();
502 sources
503 .iter()
504 .map(|(name, source)| (name.clone(), source.get_query_schema().clone()))
505 .collect()
506 }
507
508 pub fn list_extensions(&self) -> Vec<String> {
510 let loaded = self.loaded_extensions.read().unwrap();
511 loaded.keys().cloned().collect()
512 }
513
514 pub fn has_extension(&self, name: &str) -> bool {
516 let loaded = self.loaded_extensions.read().unwrap();
517 loaded.contains_key(name)
518 }
519
520 pub fn unload_extension(&self, name: &str) -> bool {
530 let mut sources = self.extension_sources.write().unwrap();
531 let removed_source = sources.remove(name).is_some();
532 drop(sources);
533
534 let mut modules = self.extension_modules.write().unwrap();
535 let removed_module = modules.remove(name).is_some();
536 drop(modules);
537
538 let mut loaded_extensions = self.loaded_extensions.write().unwrap();
539 let removed_plugin = loaded_extensions.remove(name).is_some();
540 drop(loaded_extensions);
541
542 if removed_plugin {
543 let mut loader = self.extension_loader.write().unwrap();
544 loader.unload(name);
545 }
546
547 removed_plugin || removed_source || removed_module
548 }
549}
550
551impl Default for ProviderRegistry {
552 fn default() -> Self {
553 Self::new()
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::async_provider::NullAsyncProvider;

    /// A fresh provider handle for tests that only need *some* provider.
    fn null_provider() -> SharedAsyncProvider {
        Arc::new(NullAsyncProvider) as SharedAsyncProvider
    }

    #[test]
    fn test_register_and_get() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("test", provider.clone());

        assert!(registry.has_provider("test"));
        assert!(!registry.has_provider("nonexistent"));
        assert!(registry.get("test").is_some());
    }

    #[test]
    fn test_register_replaces_existing_name() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("dup", provider.clone());
        registry.register("dup", provider);

        assert_eq!(registry.list_providers(), vec!["dup"]);
    }

    #[test]
    fn test_clone_shares_state() {
        let registry = ProviderRegistry::new();
        let handle = registry.clone();

        handle.register("shared", null_provider());

        assert!(registry.has_provider("shared"));
    }

    #[test]
    fn test_default_impl_is_empty() {
        let registry = ProviderRegistry::default();

        assert!(registry.list_providers().is_empty());
        assert!(registry.default_name().is_none());
        assert!(registry.list_extensions().is_empty());
    }

    #[test]
    fn test_default_provider() {
        let registry = ProviderRegistry::new();

        registry.register("test", null_provider());

        assert!(registry.set_default("test").is_ok());
        assert!(registry.get_default().is_some());
        assert_eq!(registry.default_name(), Some("test".to_string()));
    }

    #[test]
    fn test_set_default_nonexistent() {
        let registry = ProviderRegistry::new();
        assert!(registry.set_default("nonexistent").is_err());
        assert!(registry.default_name().is_none());
    }

    #[test]
    fn test_list_providers() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("test1", provider.clone());
        registry.register("test2", provider);

        let mut names = registry.list_providers();
        names.sort();
        assert_eq!(names, vec!["test1", "test2"]);
    }

    #[test]
    fn test_unregister() {
        let registry = ProviderRegistry::new();

        registry.register("test", null_provider());
        registry.set_default("test").unwrap();

        assert!(registry.unregister("test"));
        assert!(!registry.has_provider("test"));
        assert!(registry.get_default().is_none());
    }

    #[test]
    fn test_unregister_non_default_keeps_default() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("a", provider.clone());
        registry.register("b", provider);
        registry.set_default("a").unwrap();

        assert!(registry.unregister("b"));
        assert_eq!(registry.default_name(), Some("a".to_string()));
        assert!(registry.get_default().is_some());
        // A second removal of the same name reports that nothing was removed.
        assert!(!registry.unregister("b"));
    }

    #[test]
    fn test_clear() {
        let registry = ProviderRegistry::new();
        let provider = null_provider();

        registry.register("test1", provider.clone());
        registry.register("test2", provider);
        registry.set_default("test1").unwrap();

        registry.clear();

        assert_eq!(registry.list_providers().len(), 0);
        assert!(registry.get_default().is_none());
    }

    #[test]
    fn test_plugin_not_loaded_by_default() {
        let registry = ProviderRegistry::new();

        assert!(!registry.has_extension("nonexistent"));
        assert!(registry.get_extension("nonexistent").is_none());
    }

    #[test]
    fn test_list_extensions_empty() {
        let registry = ProviderRegistry::new();

        let plugins = registry.list_extensions();
        assert!(plugins.is_empty());
    }

    #[test]
    fn test_list_extensions_with_schemas_empty() {
        let registry = ProviderRegistry::new();

        let schemas = registry.list_extensions_with_schemas();
        assert!(schemas.is_empty());
    }

    #[test]
    fn test_get_extension_query_schema_not_found() {
        let registry = ProviderRegistry::new();

        let schema = registry.get_extension_query_schema("nonexistent");
        assert!(schema.is_none());
    }

    #[test]
    fn test_get_extension_output_schema_not_found() {
        let registry = ProviderRegistry::new();

        let schema = registry.get_extension_output_schema("nonexistent");
        assert!(schema.is_none());
    }

    #[test]
    fn test_module_queries_empty() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_extension_module_schema("missing").is_none());
        assert!(registry.module_exports_from_extensions().is_empty());
    }

    #[test]
    fn test_invoke_unloaded_module_errors() {
        let registry = ProviderRegistry::new();

        assert!(registry
            .invoke_extension_module_nb("missing", "f", &[])
            .is_err());
        assert!(registry
            .invoke_extension_module_wire("missing", "f", &[])
            .is_err());
    }

    #[test]
    fn test_language_runtime_queries_empty() {
        let registry = ProviderRegistry::new();

        assert!(registry.get_language_runtime("python").is_none());
        assert!(registry.language_runtimes().is_empty());
        assert!(registry.language_runtime_lsp_configs().is_empty());
    }

    #[test]
    fn test_unload_plugin_not_loaded() {
        let registry = ProviderRegistry::new();

        assert!(!registry.unload_extension("nonexistent"));
    }

    #[test]
    fn test_clear_removes_plugins() {
        let registry = ProviderRegistry::new();

        registry.clear();

        assert!(registry.list_extensions().is_empty());
        assert!(registry.language_runtimes().is_empty());
    }
}