1use crate::{Provider, ProviderDef, build_provider};
5use anyhow::{Result, anyhow, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::collections::BTreeMap;
11use std::sync::{Arc, RwLock};
12use wcore::model::{Model, Response, StreamChunk, default_context_limit};
13
14pub struct ProviderRegistry {
20 inner: Arc<RwLock<Inner>>,
21}
22
23struct Inner {
24 providers: BTreeMap<CompactString, Provider>,
26 active: CompactString,
28 client: reqwest::Client,
30}
31
32#[derive(Debug, Clone)]
34pub struct ProviderEntry {
35 pub name: CompactString,
37 pub active: bool,
39}
40
41impl ProviderRegistry {
42 pub fn new(active: impl Into<CompactString>) -> Self {
46 Self {
47 inner: Arc::new(RwLock::new(Inner {
48 providers: BTreeMap::new(),
49 active: active.into(),
50 client: reqwest::Client::new(),
51 })),
52 }
53 }
54
55 pub fn from_providers(
60 active: CompactString,
61 providers: &BTreeMap<CompactString, ProviderDef>,
62 ) -> Result<Self> {
63 let registry = Self::new(active);
64 for def in providers.values() {
65 registry.add_def(def)?;
66 }
67 Ok(registry)
68 }
69
70 pub fn add_provider(&self, name: impl Into<CompactString>, provider: Provider) -> Result<()> {
72 let mut inner = self
73 .inner
74 .write()
75 .map_err(|_| anyhow!("provider lock poisoned"))?;
76 inner.providers.insert(name.into(), provider);
77 Ok(())
78 }
79
80 pub fn add_def(&self, def: &ProviderDef) -> Result<()> {
82 let client = {
83 let inner = self
84 .inner
85 .read()
86 .map_err(|_| anyhow!("provider lock poisoned"))?;
87 inner.client.clone()
88 };
89 for model_name in &def.models {
90 let provider = build_provider(def, model_name, client.clone())?;
91 let mut inner = self
92 .inner
93 .write()
94 .map_err(|_| anyhow!("provider lock poisoned"))?;
95 inner
96 .providers
97 .insert(CompactString::from(model_name.as_str()), provider);
98 }
99 Ok(())
100 }
101
102 pub fn active(&self) -> Result<Provider> {
104 let inner = self
105 .inner
106 .read()
107 .map_err(|_| anyhow!("provider lock poisoned"))?;
108 Ok(inner.providers[&inner.active].clone())
109 }
110
111 pub fn active_model_name(&self) -> Result<CompactString> {
113 let inner = self
114 .inner
115 .read()
116 .map_err(|_| anyhow!("provider lock poisoned"))?;
117 Ok(inner.active.clone())
118 }
119
120 pub fn switch(&self, model: &str) -> Result<()> {
123 let mut inner = self
124 .inner
125 .write()
126 .map_err(|_| anyhow!("provider lock poisoned"))?;
127 if !inner.providers.contains_key(model) {
128 bail!("provider '{}' not found", model);
129 }
130 inner.active = CompactString::from(model);
131 Ok(())
132 }
133
134 pub fn remove(&self, model: &str) -> Result<()> {
137 let mut inner = self
138 .inner
139 .write()
140 .map_err(|_| anyhow!("provider lock poisoned"))?;
141 if inner.active == model {
142 bail!("cannot remove the active provider '{}'", model);
143 }
144 if inner.providers.remove(model).is_none() {
145 bail!("provider '{}' not found", model);
146 }
147 Ok(())
148 }
149
150 pub fn list(&self) -> Result<Vec<ProviderEntry>> {
152 let inner = self
153 .inner
154 .read()
155 .map_err(|_| anyhow!("provider lock poisoned"))?;
156 Ok(inner
157 .providers
158 .keys()
159 .map(|name| ProviderEntry {
160 name: name.clone(),
161 active: *name == inner.active,
162 })
163 .collect())
164 }
165
166 fn provider_for(&self, model: &str) -> Result<Provider> {
169 let inner = self
170 .inner
171 .read()
172 .map_err(|_| anyhow!("provider lock poisoned"))?;
173 inner
174 .providers
175 .get(model)
176 .cloned()
177 .ok_or_else(|| anyhow!("model '{}' not found in registry", model))
178 }
179
180 pub fn context_limit(&self, model: &str) -> usize {
184 default_context_limit(model)
185 }
186}
187
188impl Model for ProviderRegistry {
189 async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
190 let provider = self.provider_for(&request.model)?;
191 provider.send(request).await
192 }
193
194 fn stream(
195 &self,
196 request: wcore::model::Request,
197 ) -> impl Stream<Item = Result<StreamChunk>> + Send {
198 let result = self.provider_for(&request.model);
199 try_stream! {
200 let provider = result?;
201 let mut stream = std::pin::pin!(provider.stream(request));
202 while let Some(chunk) = stream.next().await {
203 yield chunk?;
204 }
205 }
206 }
207
208 fn context_limit(&self, model: &str) -> usize {
209 ProviderRegistry::context_limit(self, model)
210 }
211
212 fn active_model(&self) -> CompactString {
213 self.active_model_name()
214 .unwrap_or_else(|_| CompactString::const_new("unknown"))
215 }
216}
217
218impl std::fmt::Debug for ProviderRegistry {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 match self.inner.read() {
221 Ok(inner) => f
222 .debug_struct("ProviderRegistry")
223 .field("active", &inner.active)
224 .field("count", &inner.providers.len())
225 .finish(),
226 Err(_) => f
227 .debug_struct("ProviderRegistry")
228 .field("error", &"lock poisoned")
229 .finish(),
230 }
231 }
232}
233
234impl Clone for ProviderRegistry {
235 fn clone(&self) -> Self {
236 Self {
237 inner: Arc::clone(&self.inner),
238 }
239 }
240}