// walrus_model/provider.rs
//! Provider implementation.
//!
//! Unified `Provider` enum with enum dispatch over concrete backends.
//! `build_provider()` selects the wire protocol from the provider config's
//! `effective_standard()` (OpenAI-compatible or Anthropic), with per-protocol
//! default endpoints when no base URL is configured.

use crate::{
    config::{ApiStandard, ProviderConfig},
    remote::{
        claude::{self, Claude},
        openai::{self, OpenAI},
    },
};
use anyhow::Result;
use async_stream::try_stream;
use compact_str::CompactString;
use futures_core::Stream;
use futures_util::StreamExt;
use wcore::model::{Model, Response, StreamChunk};
21/// Unified LLM provider enum.
22///
23/// The gateway constructs the appropriate variant based on `ApiStandard`
24/// from the provider config. The runtime is monomorphized on `Provider`.
25#[derive(Clone)]
26pub enum Provider {
27    /// OpenAI-compatible API (covers OpenAI, DeepSeek, Grok, Qwen, Kimi, Ollama).
28    OpenAI(OpenAI),
29    /// Anthropic Messages API.
30    Claude(Claude),
31    /// Local inference via mistralrs.
32    #[cfg(feature = "local")]
33    Local(crate::local::Local),
34}
36impl Provider {
37    /// Query the context length for a given model ID.
38    ///
39    /// Local providers delegate to mistralrs; remote providers return None
40    /// (callers fall back to the static map in `wcore::model::default_context_limit`).
41    pub fn context_length(&self, _model: &str) -> Option<usize> {
42        match self {
43            Self::OpenAI(_) | Self::Claude(_) => None,
44            #[cfg(feature = "local")]
45            Self::Local(p) => p.context_length(_model),
46        }
47    }
48
49    /// Wait until the provider is ready.
50    ///
51    /// No-op for remote providers. For local providers, blocks until the
52    /// model finishes loading.
53    pub async fn wait_until_ready(&mut self) -> Result<()> {
54        match self {
55            Self::OpenAI(_) | Self::Claude(_) => Ok(()),
56            #[cfg(feature = "local")]
57            Self::Local(p) => p.wait_until_ready().await,
58        }
59    }
60}
62/// Construct a remote `Provider` from config and a shared HTTP client.
63///
64/// Uses `effective_standard()` to pick the API protocol (OpenAI or Anthropic).
65/// Local models are not handled here — they use the built-in registry.
66pub async fn build_provider(config: &ProviderConfig, client: reqwest::Client) -> Result<Provider> {
67    let api_key = config.api_key.as_deref().unwrap_or("");
68    let model = config.model.as_str();
69
70    match config.effective_standard() {
71        ApiStandard::Anthropic => {
72            let url = config.base_url.as_deref().unwrap_or(claude::ENDPOINT);
73            Ok(Provider::Claude(Claude::custom(
74                client, api_key, url, model,
75            )?))
76        }
77        ApiStandard::OpenAI => {
78            let url = config
79                .base_url
80                .as_deref()
81                .unwrap_or(openai::endpoint::OPENAI);
82            let provider = if api_key.is_empty() {
83                OpenAI::no_auth(client, url, model)
84            } else {
85                OpenAI::custom(client, api_key, url, model)?
86            };
87            Ok(Provider::OpenAI(provider))
88        }
89    }
90}
92impl Model for Provider {
93    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
94        match self {
95            Self::OpenAI(p) => p.send(request).await,
96            Self::Claude(p) => p.send(request).await,
97            #[cfg(feature = "local")]
98            Self::Local(p) => p.send(request).await,
99        }
100    }
101
102    fn stream(
103        &self,
104        request: wcore::model::Request,
105    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
106        let this = self.clone();
107        try_stream! {
108            match this {
109                Provider::OpenAI(p) => {
110                    let mut stream = std::pin::pin!(p.stream(request));
111                    while let Some(chunk) = stream.next().await {
112                        yield chunk?;
113                    }
114                }
115                Provider::Claude(p) => {
116                    let mut stream = std::pin::pin!(p.stream(request));
117                    while let Some(chunk) = stream.next().await {
118                        yield chunk?;
119                    }
120                }
121                #[cfg(feature = "local")]
122                Provider::Local(p) => {
123                    let mut stream = std::pin::pin!(p.stream(request));
124                    while let Some(chunk) = stream.next().await {
125                        yield chunk?;
126                    }
127                }
128            }
129        }
130    }
131
132    fn context_limit(&self, model: &str) -> usize {
133        self.context_length(model)
134            .unwrap_or_else(|| wcore::model::default_context_limit(model))
135    }
136
137    fn active_model(&self) -> CompactString {
138        match self {
139            Self::OpenAI(p) => p.active_model(),
140            Self::Claude(p) => p.active_model(),
141            #[cfg(feature = "local")]
142            Self::Local(p) => p.active_model(),
143        }
144    }
145}