1use crate::{Provider, ProviderConfig, build_provider};
5use anyhow::{Result, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::{
11 collections::BTreeMap,
12 sync::{Arc, RwLock},
13};
14use wcore::model::{Model, Response, StreamChunk, default_context_limit};
15
16pub struct ProviderManager {
22 inner: Arc<RwLock<Inner>>,
23}
24
25struct Inner {
26 providers: BTreeMap<CompactString, (ProviderConfig, Provider)>,
28 active: CompactString,
30 client: reqwest::Client,
32}
33
34#[derive(Debug, Clone)]
36pub struct ProviderEntry {
37 pub name: CompactString,
39 pub active: bool,
41}
42
43impl ProviderManager {
44 pub async fn from_configs(configs: &[ProviderConfig]) -> Result<Self> {
50 if configs.is_empty() {
51 bail!("at least one provider config is required");
52 }
53
54 let client = reqwest::Client::new();
55 let mut providers = BTreeMap::new();
56
57 for config in configs {
58 config.validate()?;
59 let provider = build_provider(config, client.clone()).await?;
60 providers.insert(config.model.clone(), (config.clone(), provider));
61 }
62
63 let active = configs[0].model.clone();
64
65 Ok(Self {
66 inner: Arc::new(RwLock::new(Inner {
67 providers,
68 active,
69 client,
70 })),
71 })
72 }
73
74 pub fn single(config: ProviderConfig, provider: Provider) -> Self {
76 let model = config.model.clone();
77 let mut providers = BTreeMap::new();
78 providers.insert(model.clone(), (config, provider));
79 Self {
80 inner: Arc::new(RwLock::new(Inner {
81 providers,
82 active: model,
83 client: reqwest::Client::new(),
84 })),
85 }
86 }
87
88 pub fn active(&self) -> Provider {
90 let inner = self.inner.read().expect("provider lock poisoned");
91 inner.providers[&inner.active].1.clone()
92 }
93
94 pub fn active_model(&self) -> CompactString {
96 let inner = self.inner.read().expect("provider lock poisoned");
97 inner.active.clone()
98 }
99
100 pub fn active_config(&self) -> ProviderConfig {
102 let inner = self.inner.read().expect("provider lock poisoned");
103 inner.providers[&inner.active].0.clone()
104 }
105
106 pub fn switch(&self, model: &str) -> Result<()> {
109 let mut inner = self.inner.write().expect("provider lock poisoned");
110 if !inner.providers.contains_key(model) {
111 bail!("provider '{}' not found", model);
112 }
113 inner.active = CompactString::from(model);
114 Ok(())
115 }
116
117 pub async fn add(&self, config: &ProviderConfig) -> Result<()> {
120 config.validate()?;
121 let client = {
122 let inner = self.inner.read().expect("provider lock poisoned");
123 inner.client.clone()
124 };
125 let provider = build_provider(config, client).await?;
126 let mut inner = self.inner.write().expect("provider lock poisoned");
127 inner
128 .providers
129 .insert(config.model.clone(), (config.clone(), provider));
130 Ok(())
131 }
132
133 pub fn remove(&self, model: &str) -> Result<()> {
136 let mut inner = self.inner.write().expect("provider lock poisoned");
137 if inner.active == model {
138 bail!("cannot remove the active provider '{}'", model);
139 }
140 if inner.providers.remove(model).is_none() {
141 bail!("provider '{}' not found", model);
142 }
143 Ok(())
144 }
145
146 pub fn list(&self) -> Vec<ProviderEntry> {
148 let inner = self.inner.read().expect("provider lock poisoned");
149 inner
150 .providers
151 .keys()
152 .map(|name| ProviderEntry {
153 name: name.clone(),
154 active: *name == inner.active,
155 })
156 .collect()
157 }
158
159 fn provider_for(&self, model: &str) -> Result<Provider> {
162 let inner = self.inner.read().expect("provider lock poisoned");
163 inner
164 .providers
165 .get(model)
166 .map(|(_, p)| p.clone())
167 .ok_or_else(|| anyhow::anyhow!("model '{}' not found in registry", model))
168 }
169
170 pub fn context_limit(&self, model: &str) -> usize {
174 let inner = self.inner.read().expect("provider lock poisoned");
175 if let Some((_, provider)) = inner.providers.get(model)
176 && let Some(limit) = provider.context_length(model)
177 {
178 return limit;
179 }
180 default_context_limit(model)
181 }
182}
183
184impl Model for ProviderManager {
185 async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
186 let provider = self.provider_for(&request.model)?;
187 provider.send(request).await
188 }
189
190 fn stream(
191 &self,
192 request: wcore::model::Request,
193 ) -> impl Stream<Item = Result<StreamChunk>> + Send {
194 let result = self.provider_for(&request.model);
195 try_stream! {
196 let provider = result?;
197 let mut stream = std::pin::pin!(provider.stream(request));
198 while let Some(chunk) = stream.next().await {
199 yield chunk?;
200 }
201 }
202 }
203
204 fn context_limit(&self, model: &str) -> usize {
205 ProviderManager::context_limit(self, model)
206 }
207
208 fn active_model(&self) -> CompactString {
209 ProviderManager::active_model(self)
210 }
211}
212
213impl std::fmt::Debug for ProviderManager {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 let inner = self.inner.read().expect("provider lock poisoned");
216 f.debug_struct("ProviderManager")
217 .field("active", &inner.active)
218 .field("count", &inner.providers.len())
219 .finish()
220 }
221}
222
223impl Clone for ProviderManager {
224 fn clone(&self) -> Self {
225 Self {
226 inner: Arc::clone(&self.inner),
227 }
228 }
229}