1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::extensions::runtime::process::RegisteredProviderSpec;
5use crate::extensions::runtime::ExtensionHandler;
6
7pub struct RegisteredProvider {
8 pub plugin_id: String,
9 pub provider_id: String,
10 pub runtime_id: String,
11 pub spec: RegisteredProviderSpec,
12 pub handler: Option<Arc<dyn ExtensionHandler>>,
13}
14
/// Flattened, comparable view of a single model exposed by a registered provider.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RegisteredProviderModelSummary {
    /// Three-part model id `"<plugin_id>:<provider_id>:<model_id>"`.
    pub runtime_id: String,
    /// Optional human-readable model name from the model spec.
    pub display_name: Option<String>,
    /// Whether the model's `tool_use` capability is the JSON boolean `true`.
    pub tool_use: bool,
    /// Whether the model's `streaming` capability is the JSON boolean `true`.
    pub streaming: bool,
    /// Context window copied from the model spec, if declared (units as defined by the spec).
    pub context_window: Option<u64>,
}
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct RegisteredProviderSummary {
26 pub runtime_id: String,
27 pub display_name: String,
28 pub models: Vec<RegisteredProviderModelSummary>,
29}
30
31#[derive(Default)]
32pub struct ProviderRegistry {
33 providers: HashMap<String, RegisteredProvider>,
34}
35
36impl ProviderRegistry {
37 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn register(
42 &mut self,
43 plugin_id: &str,
44 spec: RegisteredProviderSpec,
45 ) -> Result<String, String> {
46 self.register_with_handler(plugin_id, spec, None)
47 }
48
49 pub fn register_with_handler(
50 &mut self,
51 plugin_id: &str,
52 spec: RegisteredProviderSpec,
53 handler: Option<Arc<dyn ExtensionHandler>>,
54 ) -> Result<String, String> {
55 let runtime_id = format!("{}:{}", plugin_id, spec.id);
56 if self.providers.contains_key(&runtime_id) {
57 return Err(format!("provider '{}' is already registered", runtime_id));
58 }
59 self.providers.insert(runtime_id.clone(), RegisteredProvider {
60 plugin_id: plugin_id.to_string(),
61 provider_id: spec.id.clone(),
62 runtime_id: runtime_id.clone(),
63 spec,
64 handler,
65 });
66 Ok(runtime_id)
67 }
68
69 pub fn unregister_plugin(&mut self, plugin_id: &str) {
70 self.providers.retain(|_, provider| provider.plugin_id != plugin_id);
71 }
72
73 pub fn get(&self, runtime_id: &str) -> Option<&RegisteredProvider> {
74 self.providers.get(runtime_id)
75 }
76
77 pub fn list(&self) -> Vec<&RegisteredProvider> {
78 let mut providers: Vec<_> = self.providers.values().collect();
79 providers.sort_by(|a, b| a.runtime_id.cmp(&b.runtime_id));
80 providers
81 }
82
83 pub fn len(&self) -> usize {
84 self.providers.len()
85 }
86
87 pub fn is_empty(&self) -> bool {
88 self.providers.is_empty()
89 }
90
91 pub fn parse_model_id(model: &str) -> Option<(&str, &str, &str)> {
92 let mut parts = model.split(':');
93 let plugin_id = parts.next()?;
94 let provider_id = parts.next()?;
95 let model_id = parts.next()?;
96 if parts.next().is_some() || plugin_id.is_empty() || provider_id.is_empty() || model_id.is_empty() {
97 return None;
98 }
99 Some((plugin_id, provider_id, model_id))
100 }
101
102 pub fn model_runtime_id(plugin_id: &str, provider_id: &str, model_id: &str) -> String {
103 format!("{}:{}:{}", plugin_id, provider_id, model_id)
104 }
105
106 pub fn summaries(&self) -> Vec<RegisteredProviderSummary> {
107 self.list()
108 .into_iter()
109 .map(|provider| RegisteredProviderSummary {
110 runtime_id: provider.runtime_id.clone(),
111 display_name: provider.spec.display_name.clone(),
112 models: provider
113 .spec
114 .models
115 .iter()
116 .map(|model| RegisteredProviderModelSummary {
117 runtime_id: Self::model_runtime_id(&provider.plugin_id, &provider.provider_id, &model.id),
118 display_name: model.display_name.clone(),
119 tool_use: model.capabilities.get("tool_use").and_then(|v| v.as_bool()).unwrap_or(false),
120 streaming: model.capabilities.get("streaming").and_then(|v| v.as_bool()).unwrap_or(false),
121 context_window: model.context_window,
122 })
123 .collect(),
124 })
125 .collect()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::runtime::process::RegisteredProviderModelSpec;

    /// Provider spec with the given id, fixed display metadata, and no models.
    fn spec(id: &str) -> RegisteredProviderSpec {
        RegisteredProviderSpec {
            id: id.into(),
            display_name: String::from("Local"),
            description: String::from("Local provider"),
            models: Vec::new(),
            config_schema: None,
        }
    }

    /// Fresh registry holding a single provider `plugin:local` that exposes only `model`.
    fn registry_with_model(model: RegisteredProviderModelSpec) -> ProviderRegistry {
        let mut registry = ProviderRegistry::new();
        let provider = RegisteredProviderSpec {
            models: vec![model],
            ..spec("local")
        };
        registry.register("plugin", provider).unwrap();
        registry
    }

    #[test]
    fn summaries_include_model_tool_use_capability_and_context_metadata() {
        let registry = registry_with_model(RegisteredProviderModelSpec {
            id: "model-a".to_string(),
            display_name: Some("Model A".to_string()),
            capabilities: serde_json::json!({"tool_use": true, "streaming": true}),
            context_window: Some(8192),
        });

        let expected = RegisteredProviderModelSummary {
            runtime_id: "plugin:local:model-a".to_string(),
            display_name: Some("Model A".to_string()),
            tool_use: true,
            streaming: true,
            context_window: Some(8192),
        };
        let summaries = registry.summaries();
        assert_eq!(summaries[0].models, vec![expected]);
    }

    #[test]
    fn summaries_default_streaming_to_false_when_capability_absent() {
        let registry = registry_with_model(RegisteredProviderModelSpec {
            id: "model-b".to_string(),
            display_name: None,
            capabilities: serde_json::json!({}),
            context_window: None,
        });

        let summaries = registry.summaries();
        let only_model = &summaries[0].models[0];
        assert!(!only_model.streaming);
        assert!(!only_model.tool_use);
    }

    #[test]
    fn register_namespaces_provider_by_plugin() {
        let mut registry = ProviderRegistry::new();
        let runtime_id = registry.register("plugin", spec("local")).unwrap();

        assert_eq!(runtime_id, "plugin:local");
        assert!(registry.get("plugin:local").is_some());
    }

    #[test]
    fn duplicate_runtime_provider_id_is_rejected() {
        let mut registry = ProviderRegistry::new();
        registry.register("plugin", spec("local")).unwrap();

        let second_attempt = registry.register("plugin", spec("local"));
        assert!(second_attempt.unwrap_err().contains("already registered"));
    }

    #[test]
    fn unregister_plugin_removes_its_providers_only() {
        let mut registry = ProviderRegistry::new();
        for plugin in ["one", "two"] {
            registry.register(plugin, spec("local")).unwrap();
        }

        registry.unregister_plugin("one");

        assert!(registry.get("one:local").is_none());
        assert!(registry.get("two:local").is_some());
    }

    #[test]
    fn model_ids_use_three_part_namespace() {
        let runtime_id = ProviderRegistry::model_runtime_id("plugin", "local", "model-a");
        assert_eq!(runtime_id, "plugin:local:model-a");
        assert_eq!(
            ProviderRegistry::parse_model_id(&runtime_id),
            Some(("plugin", "local", "model-a"))
        );
        for malformed in ["plugin:local", "plugin:local:model:extra"] {
            assert!(ProviderRegistry::parse_model_id(malformed).is_none());
        }
    }
}