Skip to main content

walrus_model/
manager.rs

1//! `ProviderManager` — concurrent-safe named provider registry with model
2//! routing and active-provider swapping.
3
4use crate::{Provider, ProviderDef, build_provider};
5use anyhow::{Result, anyhow, bail};
6use async_stream::try_stream;
7use compact_str::CompactString;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use std::collections::BTreeMap;
11use std::sync::{Arc, RwLock};
12use wcore::model::{Model, Response, StreamChunk, default_context_limit};
13
14/// Manages a set of named providers with an active selection.
15///
16/// All methods that read or mutate the inner state acquire the `RwLock`.
17/// `active()` returns a clone of the current `Provider` — callers do not
18/// hold the lock while performing LLM calls.
19pub struct ProviderManager {
20    inner: Arc<RwLock<Inner>>,
21}
22
23struct Inner {
24    /// Provider instances keyed by model name.
25    providers: BTreeMap<CompactString, Provider>,
26    /// Model name of the currently active provider.
27    active: CompactString,
28    /// Shared HTTP client for constructing new providers.
29    client: reqwest::Client,
30}
31
32/// Info about a single provider entry returned by `list()`.
33#[derive(Debug, Clone)]
34pub struct ProviderEntry {
35    /// Provider model name (key).
36    pub name: CompactString,
37    /// Whether this is the active provider.
38    pub active: bool,
39}
40
41impl ProviderManager {
42    /// Create an empty manager with the given active model name.
43    ///
44    /// Use `add_provider()` or `add_model()` to populate.
45    pub fn new(active: impl Into<CompactString>) -> Self {
46        Self {
47            inner: Arc::new(RwLock::new(Inner {
48                providers: BTreeMap::new(),
49                active: active.into(),
50                client: reqwest::Client::new(),
51            })),
52        }
53    }
54
55    /// Build a manager from a map of provider definitions and an active model.
56    ///
57    /// Iterates each provider def, building a `Provider` instance per model
58    /// in its `models` list.
59    pub async fn from_providers(
60        active: CompactString,
61        providers: &BTreeMap<CompactString, ProviderDef>,
62    ) -> Result<Self> {
63        let manager = Self::new(active);
64        for def in providers.values() {
65            manager.add_def(def).await?;
66        }
67        Ok(manager)
68    }
69
70    /// Add a pre-built provider directly (e.g. local models from registry).
71    pub fn add_provider(&self, name: impl Into<CompactString>, provider: Provider) -> Result<()> {
72        let mut inner = self
73            .inner
74            .write()
75            .map_err(|_| anyhow!("provider lock poisoned"))?;
76        inner.providers.insert(name.into(), provider);
77        Ok(())
78    }
79
80    /// Add all models from a provider definition. Builds a `Provider` per model.
81    pub async fn add_def(&self, def: &ProviderDef) -> Result<()> {
82        let client = {
83            let inner = self
84                .inner
85                .read()
86                .map_err(|_| anyhow!("provider lock poisoned"))?;
87            inner.client.clone()
88        };
89        for model_name in &def.models {
90            let provider = build_provider(def, model_name, client.clone()).await?;
91            let mut inner = self
92                .inner
93                .write()
94                .map_err(|_| anyhow!("provider lock poisoned"))?;
95            inner.providers.insert(model_name.clone(), provider);
96        }
97        Ok(())
98    }
99
100    /// Get a clone of the active provider.
101    pub fn active(&self) -> Result<Provider> {
102        let inner = self
103            .inner
104            .read()
105            .map_err(|_| anyhow!("provider lock poisoned"))?;
106        Ok(inner.providers[&inner.active].clone())
107    }
108
109    /// Get the model name of the active provider (also its key).
110    pub fn active_model_name(&self) -> Result<CompactString> {
111        let inner = self
112            .inner
113            .read()
114            .map_err(|_| anyhow!("provider lock poisoned"))?;
115        Ok(inner.active.clone())
116    }
117
118    /// Switch to a different provider by model name. Returns an error if the
119    /// name is not found.
120    pub fn switch(&self, model: &str) -> Result<()> {
121        let mut inner = self
122            .inner
123            .write()
124            .map_err(|_| anyhow!("provider lock poisoned"))?;
125        if !inner.providers.contains_key(model) {
126            bail!("provider '{}' not found", model);
127        }
128        inner.active = CompactString::from(model);
129        Ok(())
130    }
131
132    /// Remove a provider by model name. Fails if the provider is currently
133    /// active.
134    pub fn remove(&self, model: &str) -> Result<()> {
135        let mut inner = self
136            .inner
137            .write()
138            .map_err(|_| anyhow!("provider lock poisoned"))?;
139        if inner.active == model {
140            bail!("cannot remove the active provider '{}'", model);
141        }
142        if inner.providers.remove(model).is_none() {
143            bail!("provider '{}' not found", model);
144        }
145        Ok(())
146    }
147
148    /// List all providers with their active status.
149    pub fn list(&self) -> Result<Vec<ProviderEntry>> {
150        let inner = self
151            .inner
152            .read()
153            .map_err(|_| anyhow!("provider lock poisoned"))?;
154        Ok(inner
155            .providers
156            .keys()
157            .map(|name| ProviderEntry {
158                name: name.clone(),
159                active: *name == inner.active,
160            })
161            .collect())
162    }
163
164    /// Look up a provider by model name. Returns a clone so callers don't
165    /// hold the lock during LLM calls.
166    fn provider_for(&self, model: &str) -> Result<Provider> {
167        let inner = self
168            .inner
169            .read()
170            .map_err(|_| anyhow!("provider lock poisoned"))?;
171        inner
172            .providers
173            .get(model)
174            .cloned()
175            .ok_or_else(|| anyhow!("model '{}' not found in registry", model))
176    }
177
178    /// Resolve the context limit for a model.
179    ///
180    /// Uses the static map in `wcore::model::default_context_limit`.
181    pub fn context_limit(&self, model: &str) -> usize {
182        default_context_limit(model)
183    }
184}
185
186impl Model for ProviderManager {
187    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
188        let provider = self.provider_for(&request.model)?;
189        provider.send(request).await
190    }
191
192    fn stream(
193        &self,
194        request: wcore::model::Request,
195    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
196        let result = self.provider_for(&request.model);
197        try_stream! {
198            let provider = result?;
199            let mut stream = std::pin::pin!(provider.stream(request));
200            while let Some(chunk) = stream.next().await {
201                yield chunk?;
202            }
203        }
204    }
205
206    fn context_limit(&self, model: &str) -> usize {
207        ProviderManager::context_limit(self, model)
208    }
209
210    fn active_model(&self) -> CompactString {
211        self.active_model_name()
212            .unwrap_or_else(|_| CompactString::const_new("unknown"))
213    }
214}
215
216impl std::fmt::Debug for ProviderManager {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        match self.inner.read() {
219            Ok(inner) => f
220                .debug_struct("ProviderManager")
221                .field("active", &inner.active)
222                .field("count", &inner.providers.len())
223                .finish(),
224            Err(_) => f
225                .debug_struct("ProviderManager")
226                .field("error", &"lock poisoned")
227                .finish(),
228        }
229    }
230}
231
232impl Clone for ProviderManager {
233    fn clone(&self) -> Self {
234        Self {
235            inner: Arc::clone(&self.inner),
236        }
237    }
238}