Skip to main content

rustic_ai/providers/mod.rs

1use std::sync::Arc;
2
3use thiserror::Error;
4
5use crate::model::{Model, ModelSettings};
6
7pub mod anthropic;
8pub mod gemini;
9pub mod grok;
10pub mod openai;
11
12pub trait Provider: Send + Sync {
13    fn name(&self) -> &str;
14    fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model>;
15}
16
17#[derive(Debug, Error)]
18pub enum ProviderError {
19    #[error("unknown provider: {0}")]
20    UnknownProvider(String),
21    #[error("missing API key for provider: {0}")]
22    MissingApiKey(String),
23    #[error("invalid model string: {0}")]
24    InvalidModel(String),
25}
26
27pub fn infer_provider(name: &str) -> Result<Box<dyn Provider>, ProviderError> {
28    match name {
29        "openai" => openai::OpenAIProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
30        "grok" => grok::GrokProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
31        "anthropic" => {
32            anthropic::AnthropicProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>)
33        }
34        "gemini" => gemini::GeminiProvider::from_env().map(|p| Box::new(p) as Box<dyn Provider>),
35        other => Err(ProviderError::UnknownProvider(other.to_string())),
36    }
37}
38
39pub fn infer_model(
40    model: impl AsRef<str>,
41    provider_factory: impl Fn(&str) -> Result<Box<dyn Provider>, ProviderError>,
42) -> Result<Arc<dyn Model>, ProviderError> {
43    let model = model.as_ref();
44    let (provider_name, model_name) = match model.split_once(':') {
45        Some((provider, name)) => (provider, name),
46        None => (infer_provider_from_model(model)?, model),
47    };
48
49    let provider = provider_factory(provider_name)?;
50    Ok(provider.model(model_name, None))
51}
52
53fn infer_provider_from_model(model: &str) -> Result<&'static str, ProviderError> {
54    let lowered = model.to_lowercase();
55    if lowered.starts_with("gpt") || lowered.starts_with("o1") || lowered.starts_with("o3") {
56        return Ok("openai");
57    }
58    if lowered.starts_with("claude") {
59        return Ok("anthropic");
60    }
61    if lowered.starts_with("gemini") {
62        return Ok("gemini");
63    }
64    if lowered.starts_with("grok") {
65        return Ok("grok");
66    }
67    Err(ProviderError::InvalidModel(model.to_string()))
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Provider double that returns a model echoing the requested name and never talks to
    /// a real backend.
    struct StubProvider;

    impl Provider for StubProvider {
        fn name(&self) -> &str {
            "stub"
        }

        fn model(&self, model: &str, _settings: Option<ModelSettings>) -> Arc<dyn Model> {
            struct StubModel {
                name: String,
            }
            #[async_trait]
            impl Model for StubModel {
                fn name(&self) -> &str {
                    &self.name
                }

                async fn request(
                    &self,
                    _messages: &[crate::messages::ModelMessage],
                    _settings: Option<&ModelSettings>,
                    _params: &crate::model::ModelRequestParameters,
                ) -> Result<crate::messages::ModelResponse, crate::model::ModelError>
                {
                    Err(crate::model::ModelError::Unsupported(
                        "not implemented".to_string(),
                    ))
                }
            }

            Arc::new(StubModel {
                name: model.to_string(),
            })
        }
    }

    #[test]
    fn infer_provider_from_model_matches_prefixes() {
        assert_eq!(infer_provider_from_model("gpt-4o").unwrap(), "openai");
        assert_eq!(infer_provider_from_model("o1-mini").unwrap(), "openai");
        assert_eq!(infer_provider_from_model("o3-mini").unwrap(), "openai");
        assert_eq!(infer_provider_from_model("claude-3").unwrap(), "anthropic");
        assert_eq!(infer_provider_from_model("gemini-1.5").unwrap(), "gemini");
        assert_eq!(infer_provider_from_model("grok-2").unwrap(), "grok");
    }

    #[test]
    fn infer_provider_from_model_ignores_case() {
        assert_eq!(infer_provider_from_model("GPT-4o").unwrap(), "openai");
        assert_eq!(infer_provider_from_model("Claude-3").unwrap(), "anthropic");
    }

    #[test]
    fn infer_provider_from_model_rejects_unknown() {
        let err = infer_provider_from_model("unknown-model").expect_err("unknown provider");
        assert!(matches!(err, ProviderError::InvalidModel(_)));
    }

    #[test]
    fn infer_model_uses_explicit_provider() {
        let model = infer_model("stub:example", |name| {
            // The text before the first ':' must be routed to the factory verbatim.
            assert_eq!(name, "stub");
            Ok(Box::new(StubProvider))
        })
        .expect("model from stub provider");
        assert_eq!(model.name(), "example");
    }

    #[test]
    fn infer_model_infers_provider_without_prefix() {
        let model = infer_model("gpt-4o-mini", |name| {
            // A bare `gpt…` name must be routed to the OpenAI provider.
            assert_eq!(name, "openai");
            Ok(Box::new(StubProvider))
        })
        .expect("model from inferred provider");
        assert_eq!(model.name(), "gpt-4o-mini");
    }

    #[test]
    fn infer_model_rejects_unrecognised_bare_model_without_calling_factory() {
        let result = infer_model("mystery-model", |_| {
            unreachable!("factory must not run when no provider can be inferred")
        });
        assert!(matches!(result, Err(ProviderError::InvalidModel(_))));
    }

    #[test]
    fn infer_model_propagates_factory_errors() {
        let result = infer_model("stub:example", |name| {
            Err(ProviderError::UnknownProvider(name.to_string()))
        });
        assert!(matches!(result, Err(ProviderError::UnknownProvider(name)) if name == "stub"));
    }
}