Skip to main content

synaps_cli/extensions/
providers.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::extensions::runtime::process::RegisteredProviderSpec;
5use crate::extensions::runtime::ExtensionHandler;
6
7pub struct RegisteredProvider {
8    pub plugin_id: String,
9    pub provider_id: String,
10    pub runtime_id: String,
11    pub spec: RegisteredProviderSpec,
12    pub handler: Option<Arc<dyn ExtensionHandler>>,
13}
14
/// Flattened, comparable view of one model offered by a registered provider.
///
/// `ProviderRegistry::summaries` fills this in from the provider's model specs.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RegisteredProviderModelSummary {
    /// Namespaced model id of the form `"<plugin_id>:<provider_id>:<model_id>"`.
    pub runtime_id: String,
    /// Human-readable model name, if the model spec supplied one.
    pub display_name: Option<String>,
    /// Value of the model's `tool_use` capability flag (`false` when absent).
    pub tool_use: bool,
    /// Value of the model's `streaming` capability flag (`false` when absent).
    pub streaming: bool,
    /// Context window copied from the model spec, if declared.
    pub context_window: Option<u64>,
}
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct RegisteredProviderSummary {
26    pub runtime_id: String,
27    pub display_name: String,
28    pub models: Vec<RegisteredProviderModelSummary>,
29}
30
31#[derive(Default)]
32pub struct ProviderRegistry {
33    providers: HashMap<String, RegisteredProvider>,
34}
35
36impl ProviderRegistry {
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    pub fn register(
42        &mut self,
43        plugin_id: &str,
44        spec: RegisteredProviderSpec,
45    ) -> Result<String, String> {
46        self.register_with_handler(plugin_id, spec, None)
47    }
48
49    pub fn register_with_handler(
50        &mut self,
51        plugin_id: &str,
52        spec: RegisteredProviderSpec,
53        handler: Option<Arc<dyn ExtensionHandler>>,
54    ) -> Result<String, String> {
55        let runtime_id = format!("{}:{}", plugin_id, spec.id);
56        if self.providers.contains_key(&runtime_id) {
57            return Err(format!("provider '{}' is already registered", runtime_id));
58        }
59        self.providers.insert(runtime_id.clone(), RegisteredProvider {
60            plugin_id: plugin_id.to_string(),
61            provider_id: spec.id.clone(),
62            runtime_id: runtime_id.clone(),
63            spec,
64            handler,
65        });
66        Ok(runtime_id)
67    }
68
69    pub fn unregister_plugin(&mut self, plugin_id: &str) {
70        self.providers.retain(|_, provider| provider.plugin_id != plugin_id);
71    }
72
73    pub fn get(&self, runtime_id: &str) -> Option<&RegisteredProvider> {
74        self.providers.get(runtime_id)
75    }
76
77    pub fn list(&self) -> Vec<&RegisteredProvider> {
78        let mut providers: Vec<_> = self.providers.values().collect();
79        providers.sort_by(|a, b| a.runtime_id.cmp(&b.runtime_id));
80        providers
81    }
82
83    pub fn len(&self) -> usize {
84        self.providers.len()
85    }
86
87    pub fn is_empty(&self) -> bool {
88        self.providers.is_empty()
89    }
90
91    pub fn parse_model_id(model: &str) -> Option<(&str, &str, &str)> {
92        let mut parts = model.split(':');
93        let plugin_id = parts.next()?;
94        let provider_id = parts.next()?;
95        let model_id = parts.next()?;
96        if parts.next().is_some() || plugin_id.is_empty() || provider_id.is_empty() || model_id.is_empty() {
97            return None;
98        }
99        Some((plugin_id, provider_id, model_id))
100    }
101
102    pub fn model_runtime_id(plugin_id: &str, provider_id: &str, model_id: &str) -> String {
103        format!("{}:{}:{}", plugin_id, provider_id, model_id)
104    }
105
106    pub fn summaries(&self) -> Vec<RegisteredProviderSummary> {
107        self.list()
108            .into_iter()
109            .map(|provider| RegisteredProviderSummary {
110                runtime_id: provider.runtime_id.clone(),
111                display_name: provider.spec.display_name.clone(),
112                models: provider
113                    .spec
114                    .models
115                    .iter()
116                    .map(|model| RegisteredProviderModelSummary {
117                        runtime_id: Self::model_runtime_id(&provider.plugin_id, &provider.provider_id, &model.id),
118                        display_name: model.display_name.clone(),
119                        tool_use: model.capabilities.get("tool_use").and_then(|v| v.as_bool()).unwrap_or(false),
120                        streaming: model.capabilities.get("streaming").and_then(|v| v.as_bool()).unwrap_or(false),
121                        context_window: model.context_window,
122                    })
123                    .collect(),
124            })
125            .collect()
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::runtime::process::RegisteredProviderModelSpec;

    /// Builds a provider spec with the given id and no models.
    fn spec(id: &str) -> RegisteredProviderSpec {
        RegisteredProviderSpec {
            id: id.to_string(),
            display_name: "Local".to_string(),
            description: "Local provider".to_string(),
            models: vec![],
            config_schema: None,
        }
    }

    /// Builds a model spec; `capabilities` is the raw JSON capability object.
    fn model(
        id: &str,
        display_name: Option<&str>,
        capabilities: serde_json::Value,
        context_window: Option<u64>,
    ) -> RegisteredProviderModelSpec {
        RegisteredProviderModelSpec {
            id: id.to_string(),
            display_name: display_name.map(str::to_string),
            capabilities,
            context_window,
        }
    }

    #[test]
    fn summaries_include_model_tool_use_capability_and_context_metadata() {
        let mut provider_spec = spec("local");
        provider_spec.models = vec![model(
            "model-a",
            Some("Model A"),
            serde_json::json!({"tool_use": true, "streaming": true}),
            Some(8192),
        )];
        let mut registry = ProviderRegistry::new();
        registry.register("plugin", provider_spec).unwrap();

        let summaries = registry.summaries();

        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].runtime_id, "plugin:local");
        assert_eq!(summaries[0].display_name, "Local");
        assert_eq!(summaries[0].models, vec![RegisteredProviderModelSummary {
            runtime_id: "plugin:local:model-a".to_string(),
            display_name: Some("Model A".to_string()),
            tool_use: true,
            streaming: true,
            context_window: Some(8192),
        }]);
    }

    #[test]
    fn summaries_default_streaming_to_false_when_capability_absent() {
        let mut provider_spec = spec("local");
        provider_spec.models = vec![model("model-b", None, serde_json::json!({}), None)];
        let mut registry = ProviderRegistry::new();
        registry.register("plugin", provider_spec).unwrap();

        let summaries = registry.summaries();

        assert_eq!(summaries.len(), 1);
        assert!(!summaries[0].models[0].streaming);
        assert!(!summaries[0].models[0].tool_use);
    }

    #[test]
    fn register_namespaces_provider_by_plugin() {
        let mut registry = ProviderRegistry::new();
        let id = registry.register("plugin", spec("local")).unwrap();

        assert_eq!(id, "plugin:local");
        assert!(registry.get("plugin:local").is_some());
    }

    #[test]
    fn duplicate_runtime_provider_id_is_rejected() {
        let mut registry = ProviderRegistry::new();
        registry.register("plugin", spec("local")).unwrap();
        let err = registry.register("plugin", spec("local")).unwrap_err();

        assert!(err.contains("already registered"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn unregister_plugin_removes_its_providers_only() {
        let mut registry = ProviderRegistry::new();
        registry.register("one", spec("local")).unwrap();
        registry.register("two", spec("local")).unwrap();

        registry.unregister_plugin("one");

        assert!(registry.get("one:local").is_none());
        assert!(registry.get("two:local").is_some());
    }

    #[test]
    fn unregister_unknown_plugin_is_a_no_op() {
        let mut registry = ProviderRegistry::new();
        registry.register("plugin", spec("local")).unwrap();

        registry.unregister_plugin("missing");

        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn list_is_sorted_by_runtime_id() {
        let mut registry = ProviderRegistry::new();
        registry.register("beta", spec("local")).unwrap();
        registry.register("alpha", spec("remote")).unwrap();
        registry.register("alpha", spec("local")).unwrap();

        let ids: Vec<String> = registry
            .list()
            .into_iter()
            .map(|provider| provider.runtime_id.clone())
            .collect();

        assert_eq!(ids, vec!["alpha:local", "alpha:remote", "beta:local"]);
    }

    #[test]
    fn len_and_is_empty_track_registrations() {
        let mut registry = ProviderRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register("plugin", spec("local")).unwrap();
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);

        registry.unregister_plugin("plugin");
        assert!(registry.is_empty());
    }

    #[test]
    fn model_ids_use_three_part_namespace() {
        assert_eq!(
            ProviderRegistry::parse_model_id("plugin:local:model-a"),
            Some(("plugin", "local", "model-a"))
        );
        assert_eq!(ProviderRegistry::model_runtime_id("plugin", "local", "model-a"), "plugin:local:model-a");
        assert!(ProviderRegistry::parse_model_id("plugin:local").is_none());
        assert!(ProviderRegistry::parse_model_id("plugin:local:model:extra").is_none());
    }

    #[test]
    fn parse_model_id_rejects_empty_segments() {
        assert!(ProviderRegistry::parse_model_id("").is_none());
        assert!(ProviderRegistry::parse_model_id("::").is_none());
        assert!(ProviderRegistry::parse_model_id(":local:model-a").is_none());
        assert!(ProviderRegistry::parse_model_id("plugin::model-a").is_none());
        assert!(ProviderRegistry::parse_model_id("plugin:local:").is_none());
    }
}