Skip to main content

rustic_ai/providers/
mod.rs

1use std::sync::Arc;
2
3use thiserror::Error;
4
5use crate::model::{Model, ModelSettings};
6
7pub mod anthropic;
8pub mod gemini;
9pub mod grok;
10pub mod openai;
11
12pub trait Provider: Send + Sync {
13    fn name(&self) -> &str;
14    fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model>;
15}
16
17#[derive(Debug, Error)]
18pub enum ProviderError {
19    #[error("unknown provider: {0}")]
20    UnknownProvider(String),
21    #[error("missing API key for provider: {0}")]
22    MissingApiKey(String),
23    #[error("invalid model string: {0}")]
24    InvalidModel(String),
25}
26
27pub fn infer_provider(name: &str) -> Result<Box<dyn Provider>, ProviderError> {
28    match name {
29        "openai" => openai::OpenAIProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
30        "grok" => grok::GrokProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
31        "anthropic" => {
32            anthropic::AnthropicProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>)
33        }
34        "gemini" => gemini::GeminiProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
35        other => Err(ProviderError::UnknownProvider(other.to_string())),
36    }
37}
38
39pub fn infer_model(
40    model: impl AsRef<str>,
41    provider_factory: impl Fn(&str) -> Result<Box<dyn Provider>, ProviderError>,
42) -> Result<Arc<dyn Model>, ProviderError> {
43    let model = model.as_ref();
44    let (provider_name, model_name) = match model.split_once(':') {
45        Some((provider, name)) => (provider, name),
46        None => (infer_provider_from_model(model)?, model),
47    };
48
49    let provider = provider_factory(provider_name)?;
50    Ok(provider.model(model_name, None))
51}
52
53fn infer_provider_from_model(model: &str) -> Result<&'static str, ProviderError> {
54    let lowered = model.to_lowercase();
55    if lowered.starts_with("gpt") || lowered.starts_with("o1") || lowered.starts_with("o3") {
56        return Ok("openai");
57    }
58    if lowered.starts_with("claude") {
59        return Ok("anthropic");
60    }
61    if lowered.starts_with("gemini") {
62        return Ok("gemini");
63    }
64    if lowered.starts_with("grok") {
65        return Ok("grok");
66    }
67    Err(ProviderError::InvalidModel(model.to_string()))
68}