1use crate::{Provider, ProviderConfig, build_provider};
5use anyhow::{Result, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::collections::BTreeMap;
11use std::sync::{Arc, RwLock};
12use wcore::model::{Model, Response, StreamChunk, default_context_limit};
13
14pub struct ProviderManager {
20 inner: Arc<RwLock<Inner>>,
21}
22
23struct Inner {
24 providers: BTreeMap<CompactString, (ProviderConfig, Provider)>,
26 active: CompactString,
28 client: reqwest::Client,
30}
31
32#[derive(Debug, Clone)]
34pub struct ProviderEntry {
35 pub name: CompactString,
37 pub active: bool,
39}
40
41impl ProviderManager {
42 pub async fn from_configs(configs: &[ProviderConfig]) -> Result<Self> {
48 if configs.is_empty() {
49 bail!("at least one provider config is required");
50 }
51
52 let client = reqwest::Client::new();
53 let mut providers = BTreeMap::new();
54
55 for config in configs {
56 config.validate()?;
57 let provider = build_provider(config, client.clone()).await?;
58 providers.insert(config.model.clone(), (config.clone(), provider));
59 }
60
61 let active = configs[0].model.clone();
62
63 Ok(Self {
64 inner: Arc::new(RwLock::new(Inner {
65 providers,
66 active,
67 client,
68 })),
69 })
70 }
71
72 pub fn single(config: ProviderConfig, provider: Provider) -> Self {
74 let model = config.model.clone();
75 let mut providers = BTreeMap::new();
76 providers.insert(model.clone(), (config, provider));
77 Self {
78 inner: Arc::new(RwLock::new(Inner {
79 providers,
80 active: model,
81 client: reqwest::Client::new(),
82 })),
83 }
84 }
85
86 pub fn active(&self) -> Provider {
88 let inner = self.inner.read().expect("provider lock poisoned");
89 inner.providers[&inner.active].1.clone()
90 }
91
92 pub fn active_model(&self) -> CompactString {
94 let inner = self.inner.read().expect("provider lock poisoned");
95 inner.active.clone()
96 }
97
98 pub fn active_config(&self) -> ProviderConfig {
100 let inner = self.inner.read().expect("provider lock poisoned");
101 inner.providers[&inner.active].0.clone()
102 }
103
104 pub fn switch(&self, model: &str) -> Result<()> {
107 let mut inner = self.inner.write().expect("provider lock poisoned");
108 if !inner.providers.contains_key(model) {
109 bail!("provider '{}' not found", model);
110 }
111 inner.active = CompactString::from(model);
112 Ok(())
113 }
114
115 pub async fn add(&self, config: &ProviderConfig) -> Result<()> {
118 config.validate()?;
119 let client = {
120 let inner = self.inner.read().expect("provider lock poisoned");
121 inner.client.clone()
122 };
123 let provider = build_provider(config, client).await?;
124 let mut inner = self.inner.write().expect("provider lock poisoned");
125 inner
126 .providers
127 .insert(config.model.clone(), (config.clone(), provider));
128 Ok(())
129 }
130
131 pub fn remove(&self, model: &str) -> Result<()> {
134 let mut inner = self.inner.write().expect("provider lock poisoned");
135 if inner.active == model {
136 bail!("cannot remove the active provider '{}'", model);
137 }
138 if inner.providers.remove(model).is_none() {
139 bail!("provider '{}' not found", model);
140 }
141 Ok(())
142 }
143
144 pub fn list(&self) -> Vec<ProviderEntry> {
146 let inner = self.inner.read().expect("provider lock poisoned");
147 inner
148 .providers
149 .keys()
150 .map(|name| ProviderEntry {
151 name: name.clone(),
152 active: *name == inner.active,
153 })
154 .collect()
155 }
156
157 fn provider_for(&self, model: &str) -> Result<Provider> {
160 let inner = self.inner.read().expect("provider lock poisoned");
161 inner
162 .providers
163 .get(model)
164 .map(|(_, p)| p.clone())
165 .ok_or_else(|| anyhow::anyhow!("model '{}' not found in registry", model))
166 }
167
168 pub fn context_limit(&self, model: &str) -> usize {
172 let inner = self.inner.read().expect("provider lock poisoned");
173 if let Some((_, provider)) = inner.providers.get(model)
174 && let Some(limit) = provider.context_length(model)
175 {
176 return limit;
177 }
178 default_context_limit(model)
179 }
180}
181
182impl Model for ProviderManager {
183 async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
184 let provider = self.provider_for(&request.model)?;
185 provider.send(request).await
186 }
187
188 fn stream(
189 &self,
190 request: wcore::model::Request,
191 ) -> impl Stream<Item = Result<StreamChunk>> + Send {
192 let result = self.provider_for(&request.model);
193 try_stream! {
194 let provider = result?;
195 let mut stream = std::pin::pin!(provider.stream(request));
196 while let Some(chunk) = stream.next().await {
197 yield chunk?;
198 }
199 }
200 }
201
202 fn context_limit(&self, model: &str) -> usize {
203 ProviderManager::context_limit(self, model)
204 }
205
206 fn active_model(&self) -> CompactString {
207 ProviderManager::active_model(self)
208 }
209}
210
211impl std::fmt::Debug for ProviderManager {
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 let inner = self.inner.read().expect("provider lock poisoned");
214 f.debug_struct("ProviderManager")
215 .field("active", &inner.active)
216 .field("count", &inner.providers.len())
217 .finish()
218 }
219}
220
221impl Clone for ProviderManager {
222 fn clone(&self) -> Self {
223 Self {
224 inner: Arc::clone(&self.inner),
225 }
226 }
227}