Skip to main content

walrus_model/
manager.rs

//! `ProviderRegistry` — concurrent-safe named provider registry with model
//! routing and active-provider swapping.
3
use crate::{Provider, ProviderDef, build_provider};
use anyhow::{Result, anyhow, bail};
use async_stream::try_stream;
use compact_str::CompactString;
use futures_core::Stream;
use futures_util::StreamExt;
use std::collections::BTreeMap;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use wcore::model::{Model, Response, StreamChunk, default_context_limit};
13
14/// Manages a set of named providers with an active selection.
15///
16/// All methods that read or mutate the inner state acquire the `RwLock`.
17/// `active()` returns a clone of the current `Provider` — callers do not
18/// hold the lock while performing LLM calls.
19pub struct ProviderRegistry {
20    inner: Arc<RwLock<Inner>>,
21}
22
23struct Inner {
24    /// Provider instances keyed by model name.
25    providers: BTreeMap<CompactString, Provider>,
26    /// Model name of the currently active provider.
27    active: CompactString,
28    /// Shared HTTP client for constructing new providers.
29    client: reqwest::Client,
30}
31
32/// Info about a single provider entry returned by `list()`.
33#[derive(Debug, Clone)]
34pub struct ProviderEntry {
35    /// Provider model name (key).
36    pub name: CompactString,
37    /// Whether this is the active provider.
38    pub active: bool,
39}
40
41impl ProviderRegistry {
42    /// Create an empty manager with the given active model name.
43    ///
44    /// Use `add_provider()` or `add_model()` to populate.
45    pub fn new(active: impl Into<CompactString>) -> Self {
46        Self {
47            inner: Arc::new(RwLock::new(Inner {
48                providers: BTreeMap::new(),
49                active: active.into(),
50                client: reqwest::Client::new(),
51            })),
52        }
53    }
54
55    /// Build a registry from a map of provider definitions and an active model.
56    ///
57    /// Iterates each provider def, building a `Provider` instance per model
58    /// in its `models` list.
59    pub fn from_providers(
60        active: CompactString,
61        providers: &BTreeMap<CompactString, ProviderDef>,
62    ) -> Result<Self> {
63        let registry = Self::new(active);
64        for def in providers.values() {
65            registry.add_def(def)?;
66        }
67        Ok(registry)
68    }
69
70    /// Add a pre-built provider directly (e.g. local models from registry).
71    pub fn add_provider(&self, name: impl Into<CompactString>, provider: Provider) -> Result<()> {
72        let mut inner = self
73            .inner
74            .write()
75            .map_err(|_| anyhow!("provider lock poisoned"))?;
76        inner.providers.insert(name.into(), provider);
77        Ok(())
78    }
79
80    /// Add all models from a provider definition. Builds a `Provider` per model.
81    pub fn add_def(&self, def: &ProviderDef) -> Result<()> {
82        let client = {
83            let inner = self
84                .inner
85                .read()
86                .map_err(|_| anyhow!("provider lock poisoned"))?;
87            inner.client.clone()
88        };
89        for model_name in &def.models {
90            let provider = build_provider(def, model_name, client.clone())?;
91            let mut inner = self
92                .inner
93                .write()
94                .map_err(|_| anyhow!("provider lock poisoned"))?;
95            inner
96                .providers
97                .insert(CompactString::from(model_name.as_str()), provider);
98        }
99        Ok(())
100    }
101
102    /// Get a clone of the active provider.
103    pub fn active(&self) -> Result<Provider> {
104        let inner = self
105            .inner
106            .read()
107            .map_err(|_| anyhow!("provider lock poisoned"))?;
108        Ok(inner.providers[&inner.active].clone())
109    }
110
111    /// Get the model name of the active provider (also its key).
112    pub fn active_model_name(&self) -> Result<CompactString> {
113        let inner = self
114            .inner
115            .read()
116            .map_err(|_| anyhow!("provider lock poisoned"))?;
117        Ok(inner.active.clone())
118    }
119
120    /// Switch to a different provider by model name. Returns an error if the
121    /// name is not found.
122    pub fn switch(&self, model: &str) -> Result<()> {
123        let mut inner = self
124            .inner
125            .write()
126            .map_err(|_| anyhow!("provider lock poisoned"))?;
127        if !inner.providers.contains_key(model) {
128            bail!("provider '{}' not found", model);
129        }
130        inner.active = CompactString::from(model);
131        Ok(())
132    }
133
134    /// Remove a provider by model name. Fails if the provider is currently
135    /// active.
136    pub fn remove(&self, model: &str) -> Result<()> {
137        let mut inner = self
138            .inner
139            .write()
140            .map_err(|_| anyhow!("provider lock poisoned"))?;
141        if inner.active == model {
142            bail!("cannot remove the active provider '{}'", model);
143        }
144        if inner.providers.remove(model).is_none() {
145            bail!("provider '{}' not found", model);
146        }
147        Ok(())
148    }
149
150    /// List all providers with their active status.
151    pub fn list(&self) -> Result<Vec<ProviderEntry>> {
152        let inner = self
153            .inner
154            .read()
155            .map_err(|_| anyhow!("provider lock poisoned"))?;
156        Ok(inner
157            .providers
158            .keys()
159            .map(|name| ProviderEntry {
160                name: name.clone(),
161                active: *name == inner.active,
162            })
163            .collect())
164    }
165
166    /// Look up a provider by model name. Returns a clone so callers don't
167    /// hold the lock during LLM calls.
168    fn provider_for(&self, model: &str) -> Result<Provider> {
169        let inner = self
170            .inner
171            .read()
172            .map_err(|_| anyhow!("provider lock poisoned"))?;
173        inner
174            .providers
175            .get(model)
176            .cloned()
177            .ok_or_else(|| anyhow!("model '{}' not found in registry", model))
178    }
179
180    /// Resolve the context limit for a model.
181    ///
182    /// Uses the static map in `wcore::model::default_context_limit`.
183    pub fn context_limit(&self, model: &str) -> usize {
184        default_context_limit(model)
185    }
186}
187
188impl Model for ProviderRegistry {
189    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
190        let provider = self.provider_for(&request.model)?;
191        provider.send(request).await
192    }
193
194    fn stream(
195        &self,
196        request: wcore::model::Request,
197    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
198        let result = self.provider_for(&request.model);
199        try_stream! {
200            let provider = result?;
201            let mut stream = std::pin::pin!(provider.stream(request));
202            while let Some(chunk) = stream.next().await {
203                yield chunk?;
204            }
205        }
206    }
207
208    fn context_limit(&self, model: &str) -> usize {
209        ProviderRegistry::context_limit(self, model)
210    }
211
212    fn active_model(&self) -> CompactString {
213        self.active_model_name()
214            .unwrap_or_else(|_| CompactString::const_new("unknown"))
215    }
216}
217
218impl std::fmt::Debug for ProviderRegistry {
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        match self.inner.read() {
221            Ok(inner) => f
222                .debug_struct("ProviderRegistry")
223                .field("active", &inner.active)
224                .field("count", &inner.providers.len())
225                .finish(),
226            Err(_) => f
227                .debug_struct("ProviderRegistry")
228                .field("error", &"lock poisoned")
229                .finish(),
230        }
231    }
232}
233
234impl Clone for ProviderRegistry {
235    fn clone(&self) -> Self {
236        Self {
237            inner: Arc::clone(&self.inner),
238        }
239    }
240}