Skip to main content

walrus_model/
manager.rs

1//! `ProviderManager` — concurrent-safe named provider registry with model
2//! routing and active-provider swapping.
3
4use crate::{Provider, ProviderConfig, build_provider};
5use anyhow::{Result, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::{
11    collections::BTreeMap,
12    sync::{Arc, RwLock},
13};
14use wcore::model::{Model, Response, StreamChunk, default_context_limit};
15
16/// Manages a set of named providers with an active selection.
17///
18/// All methods that read or mutate the inner state acquire the `RwLock`.
19/// `active()` returns a clone of the current `Provider` — callers do not
20/// hold the lock while performing LLM calls.
21pub struct ProviderManager {
22    inner: Arc<RwLock<Inner>>,
23}
24
25struct Inner {
26    /// Provider instances keyed by model name.
27    providers: BTreeMap<CompactString, (ProviderConfig, Provider)>,
28    /// Model name of the currently active provider.
29    active: CompactString,
30    /// Shared HTTP client for constructing new providers.
31    client: reqwest::Client,
32}
33
34/// Info about a single provider entry returned by `list()`.
35#[derive(Debug, Clone)]
36pub struct ProviderEntry {
37    /// Provider model name (key).
38    pub name: CompactString,
39    /// Whether this is the active provider.
40    pub active: bool,
41}
42
43impl ProviderManager {
44    /// Create a new manager from a list of provider configs.
45    ///
46    /// The first element becomes the active provider.
47    /// Returns an error if the slice is empty, any config fails validation, or
48    /// any provider fails to build.
49    pub async fn from_configs(configs: &[ProviderConfig]) -> Result<Self> {
50        if configs.is_empty() {
51            bail!("at least one provider config is required");
52        }
53
54        let client = reqwest::Client::new();
55        let mut providers = BTreeMap::new();
56
57        for config in configs {
58            config.validate()?;
59            let provider = build_provider(config, client.clone()).await?;
60            providers.insert(config.model.clone(), (config.clone(), provider));
61        }
62
63        let active = configs[0].model.clone();
64
65        Ok(Self {
66            inner: Arc::new(RwLock::new(Inner {
67                providers,
68                active,
69                client,
70            })),
71        })
72    }
73
74    /// Create a manager with a single provider.
75    pub fn single(config: ProviderConfig, provider: Provider) -> Self {
76        let model = config.model.clone();
77        let mut providers = BTreeMap::new();
78        providers.insert(model.clone(), (config, provider));
79        Self {
80            inner: Arc::new(RwLock::new(Inner {
81                providers,
82                active: model,
83                client: reqwest::Client::new(),
84            })),
85        }
86    }
87
88    /// Get a clone of the active provider.
89    pub fn active(&self) -> Provider {
90        let inner = self.inner.read().expect("provider lock poisoned");
91        inner.providers[&inner.active].1.clone()
92    }
93
94    /// Get the model name of the active provider (also its key).
95    pub fn active_model(&self) -> CompactString {
96        let inner = self.inner.read().expect("provider lock poisoned");
97        inner.active.clone()
98    }
99
100    /// Get a clone of the active provider's config.
101    pub fn active_config(&self) -> ProviderConfig {
102        let inner = self.inner.read().expect("provider lock poisoned");
103        inner.providers[&inner.active].0.clone()
104    }
105
106    /// Switch to a different provider by model name. Returns an error if the
107    /// name is not found.
108    pub fn switch(&self, model: &str) -> Result<()> {
109        let mut inner = self.inner.write().expect("provider lock poisoned");
110        if !inner.providers.contains_key(model) {
111            bail!("provider '{}' not found", model);
112        }
113        inner.active = CompactString::from(model);
114        Ok(())
115    }
116
117    /// Add a new provider. Validates config first. Replaces any existing
118    /// provider with the same model name.
119    pub async fn add(&self, config: &ProviderConfig) -> Result<()> {
120        config.validate()?;
121        let client = {
122            let inner = self.inner.read().expect("provider lock poisoned");
123            inner.client.clone()
124        };
125        let provider = build_provider(config, client).await?;
126        let mut inner = self.inner.write().expect("provider lock poisoned");
127        inner
128            .providers
129            .insert(config.model.clone(), (config.clone(), provider));
130        Ok(())
131    }
132
133    /// Remove a provider by model name. Fails if the provider is currently
134    /// active.
135    pub fn remove(&self, model: &str) -> Result<()> {
136        let mut inner = self.inner.write().expect("provider lock poisoned");
137        if inner.active == model {
138            bail!("cannot remove the active provider '{}'", model);
139        }
140        if inner.providers.remove(model).is_none() {
141            bail!("provider '{}' not found", model);
142        }
143        Ok(())
144    }
145
146    /// List all providers with their active status.
147    pub fn list(&self) -> Vec<ProviderEntry> {
148        let inner = self.inner.read().expect("provider lock poisoned");
149        inner
150            .providers
151            .keys()
152            .map(|name| ProviderEntry {
153                name: name.clone(),
154                active: *name == inner.active,
155            })
156            .collect()
157    }
158
159    /// Look up a provider by model name. Returns a clone so callers don't
160    /// hold the lock during LLM calls.
161    fn provider_for(&self, model: &str) -> Result<Provider> {
162        let inner = self.inner.read().expect("provider lock poisoned");
163        inner
164            .providers
165            .get(model)
166            .map(|(_, p)| p.clone())
167            .ok_or_else(|| anyhow::anyhow!("model '{}' not found in registry", model))
168    }
169
170    /// Resolve the context limit for a model.
171    ///
172    /// Resolution chain: provider reports limit → static map → 8192 default.
173    pub fn context_limit(&self, model: &str) -> usize {
174        let inner = self.inner.read().expect("provider lock poisoned");
175        if let Some((_, provider)) = inner.providers.get(model)
176            && let Some(limit) = provider.context_length(model)
177        {
178            return limit;
179        }
180        default_context_limit(model)
181    }
182}
183
184impl Model for ProviderManager {
185    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
186        let provider = self.provider_for(&request.model)?;
187        provider.send(request).await
188    }
189
190    fn stream(
191        &self,
192        request: wcore::model::Request,
193    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
194        let result = self.provider_for(&request.model);
195        try_stream! {
196            let provider = result?;
197            let mut stream = std::pin::pin!(provider.stream(request));
198            while let Some(chunk) = stream.next().await {
199                yield chunk?;
200            }
201        }
202    }
203
204    fn context_limit(&self, model: &str) -> usize {
205        ProviderManager::context_limit(self, model)
206    }
207
208    fn active_model(&self) -> CompactString {
209        ProviderManager::active_model(self)
210    }
211}
212
213impl std::fmt::Debug for ProviderManager {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        let inner = self.inner.read().expect("provider lock poisoned");
216        f.debug_struct("ProviderManager")
217            .field("active", &inner.active)
218            .field("count", &inner.providers.len())
219            .finish()
220    }
221}
222
223impl Clone for ProviderManager {
224    fn clone(&self) -> Self {
225        Self {
226            inner: Arc::clone(&self.inner),
227        }
228    }
229}