Skip to main content

walrus_model/
manager.rs

//! `ProviderManager` — concurrent-safe named provider registry with model
//! routing and active-provider swapping.
3
4use crate::{Provider, ProviderConfig, build_provider};
5use anyhow::{Result, anyhow, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::collections::BTreeMap;
11use std::sync::{Arc, RwLock};
12use wcore::model::{Model, Response, StreamChunk, default_context_limit};
13
14/// Manages a set of named providers with an active selection.
15///
16/// All methods that read or mutate the inner state acquire the `RwLock`.
17/// `active()` returns a clone of the current `Provider` — callers do not
18/// hold the lock while performing LLM calls.
19pub struct ProviderManager {
20    inner: Arc<RwLock<Inner>>,
21}
22
23struct Inner {
24    /// Provider instances keyed by model name.
25    providers: BTreeMap<CompactString, Provider>,
26    /// Model name of the currently active provider.
27    active: CompactString,
28    /// Shared HTTP client for constructing new providers.
29    client: reqwest::Client,
30}
31
32/// Info about a single provider entry returned by `list()`.
33#[derive(Debug, Clone)]
34pub struct ProviderEntry {
35    /// Provider model name (key).
36    pub name: CompactString,
37    /// Whether this is the active provider.
38    pub active: bool,
39}
40
41impl ProviderManager {
42    /// Create an empty manager with the given active model name.
43    ///
44    /// Use `add_provider()` or `add_config()` to populate.
45    pub fn new(active: impl Into<CompactString>) -> Self {
46        Self {
47            inner: Arc::new(RwLock::new(Inner {
48                providers: BTreeMap::new(),
49                active: active.into(),
50                client: reqwest::Client::new(),
51            })),
52        }
53    }
54
55    /// Create a manager from a list of remote provider configs.
56    ///
57    /// The first element becomes the active provider.
58    /// Returns an error if the slice is empty, any config fails validation, or
59    /// any provider fails to build.
60    pub async fn from_configs(configs: &[ProviderConfig]) -> Result<Self> {
61        if configs.is_empty() {
62            bail!("at least one provider config is required");
63        }
64        let manager = Self::new(configs[0].model.clone());
65        for config in configs {
66            manager.add_config(config).await?;
67        }
68        Ok(manager)
69    }
70
71    /// Add a pre-built provider directly (e.g. local models from registry).
72    pub fn add_provider(&self, name: impl Into<CompactString>, provider: Provider) -> Result<()> {
73        let mut inner = self
74            .inner
75            .write()
76            .map_err(|_| anyhow!("provider lock poisoned"))?;
77        inner.providers.insert(name.into(), provider);
78        Ok(())
79    }
80
81    /// Add a remote provider from config. Validates and builds it.
82    pub async fn add_config(&self, config: &ProviderConfig) -> Result<()> {
83        config.validate()?;
84        let client = {
85            let inner = self
86                .inner
87                .read()
88                .map_err(|_| anyhow!("provider lock poisoned"))?;
89            inner.client.clone()
90        };
91        let provider = build_provider(config, client).await?;
92        let mut inner = self
93            .inner
94            .write()
95            .map_err(|_| anyhow!("provider lock poisoned"))?;
96        inner.providers.insert(config.model.clone(), provider);
97        Ok(())
98    }
99
100    /// Get a clone of the active provider.
101    pub fn active(&self) -> Result<Provider> {
102        let inner = self
103            .inner
104            .read()
105            .map_err(|_| anyhow!("provider lock poisoned"))?;
106        Ok(inner.providers[&inner.active].clone())
107    }
108
109    /// Get the model name of the active provider (also its key).
110    pub fn active_model_name(&self) -> Result<CompactString> {
111        let inner = self
112            .inner
113            .read()
114            .map_err(|_| anyhow!("provider lock poisoned"))?;
115        Ok(inner.active.clone())
116    }
117
118    /// Switch to a different provider by model name. Returns an error if the
119    /// name is not found.
120    pub fn switch(&self, model: &str) -> Result<()> {
121        let mut inner = self
122            .inner
123            .write()
124            .map_err(|_| anyhow!("provider lock poisoned"))?;
125        if !inner.providers.contains_key(model) {
126            bail!("provider '{}' not found", model);
127        }
128        inner.active = CompactString::from(model);
129        Ok(())
130    }
131
132    /// Remove a provider by model name. Fails if the provider is currently
133    /// active.
134    pub fn remove(&self, model: &str) -> Result<()> {
135        let mut inner = self
136            .inner
137            .write()
138            .map_err(|_| anyhow!("provider lock poisoned"))?;
139        if inner.active == model {
140            bail!("cannot remove the active provider '{}'", model);
141        }
142        if inner.providers.remove(model).is_none() {
143            bail!("provider '{}' not found", model);
144        }
145        Ok(())
146    }
147
148    /// List all providers with their active status.
149    pub fn list(&self) -> Result<Vec<ProviderEntry>> {
150        let inner = self
151            .inner
152            .read()
153            .map_err(|_| anyhow!("provider lock poisoned"))?;
154        Ok(inner
155            .providers
156            .keys()
157            .map(|name| ProviderEntry {
158                name: name.clone(),
159                active: *name == inner.active,
160            })
161            .collect())
162    }
163
164    /// Look up a provider by model name. Returns a clone so callers don't
165    /// hold the lock during LLM calls.
166    fn provider_for(&self, model: &str) -> Result<Provider> {
167        let inner = self
168            .inner
169            .read()
170            .map_err(|_| anyhow!("provider lock poisoned"))?;
171        inner
172            .providers
173            .get(model)
174            .cloned()
175            .ok_or_else(|| anyhow!("model '{}' not found in registry", model))
176    }
177
178    /// Wait until the active provider is ready.
179    ///
180    /// No-op for remote providers. For local providers, blocks until the
181    /// model finishes loading.
182    pub async fn wait_until_ready(&self) -> Result<()> {
183        let mut provider = self.active()?;
184        provider.wait_until_ready().await
185    }
186
187    /// Resolve the context limit for a model.
188    ///
189    /// Resolution chain: provider reports limit → static map → 8192 default.
190    /// Falls back to the static default if the lock is poisoned.
191    pub fn context_limit(&self, model: &str) -> usize {
192        let Ok(inner) = self.inner.read() else {
193            return default_context_limit(model);
194        };
195        if let Some(provider) = inner.providers.get(model)
196            && let Some(limit) = provider.context_length(model)
197        {
198            return limit;
199        }
200        default_context_limit(model)
201    }
202}
203
204impl Model for ProviderManager {
205    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
206        let provider = self.provider_for(&request.model)?;
207        provider.send(request).await
208    }
209
210    fn stream(
211        &self,
212        request: wcore::model::Request,
213    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
214        let result = self.provider_for(&request.model);
215        try_stream! {
216            let provider = result?;
217            let mut stream = std::pin::pin!(provider.stream(request));
218            while let Some(chunk) = stream.next().await {
219                yield chunk?;
220            }
221        }
222    }
223
224    fn context_limit(&self, model: &str) -> usize {
225        ProviderManager::context_limit(self, model)
226    }
227
228    fn active_model(&self) -> CompactString {
229        self.active_model_name()
230            .unwrap_or_else(|_| CompactString::const_new("unknown"))
231    }
232}
233
234impl std::fmt::Debug for ProviderManager {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        match self.inner.read() {
237            Ok(inner) => f
238                .debug_struct("ProviderManager")
239                .field("active", &inner.active)
240                .field("count", &inner.providers.len())
241                .finish(),
242            Err(_) => f
243                .debug_struct("ProviderManager")
244                .field("error", &"lock poisoned")
245                .finish(),
246        }
247    }
248}
249
250impl Clone for ProviderManager {
251    fn clone(&self) -> Self {
252        Self {
253            inner: Arc::clone(&self.inner),
254        }
255    }
256}