Skip to main content

steer_core/api/
factory.rs

1use crate::api::error::ApiError;
2use crate::api::provider::Provider;
3use crate::api::{
4    claude::AnthropicClient, gemini::GeminiClient, openai::OpenAIClient, xai::XAIClient,
5};
6use crate::auth::{AuthDirective, Credential};
7use crate::config::provider::{ApiFormat, ProviderConfig};
8use std::sync::Arc;
9
10/// Factory function to create a provider instance based on the provider config and credential.
11///
12/// This function dispatches to the correct API client implementation based on the provider's
13/// API format. It also supports base URL overrides for custom providers using compatible
14/// API formats.
15pub fn create_provider(
16    provider_cfg: &ProviderConfig,
17    credential: &Credential,
18) -> Result<Arc<dyn Provider>, ApiError> {
19    match credential {
20        Credential::ApiKey { value } => match &provider_cfg.api_format {
21            ApiFormat::OpenaiResponses => {
22                let client = if let Some(base_url) = &provider_cfg.base_url {
23                    OpenAIClient::with_base_url_mode(
24                        value.clone(),
25                        Some(base_url.to_string()),
26                        crate::api::openai::OpenAIMode::Responses,
27                    )?
28                } else {
29                    OpenAIClient::with_mode(
30                        value.clone(),
31                        crate::api::openai::OpenAIMode::Responses,
32                    )?
33                };
34                Ok(Arc::new(client))
35            }
36            ApiFormat::OpenaiChat => {
37                let client = if let Some(base_url) = &provider_cfg.base_url {
38                    OpenAIClient::with_base_url_mode(
39                        value.clone(),
40                        Some(base_url.to_string()),
41                        crate::api::openai::OpenAIMode::Chat,
42                    )?
43                } else {
44                    OpenAIClient::with_mode(value.clone(), crate::api::openai::OpenAIMode::Chat)?
45                };
46                Ok(Arc::new(client))
47            }
48            ApiFormat::Anthropic => {
49                // TODO: Add base_url support to AnthropicClient
50                if provider_cfg.base_url.is_some() {
51                    return Err(ApiError::Configuration(
52                        "Base URL override not yet supported for Anthropic API format".to_string(),
53                    ));
54                }
55                Ok(Arc::new(AnthropicClient::with_api_key(value)?))
56            }
57            ApiFormat::Google => {
58                // TODO: Add base_url support to GeminiClient
59                if provider_cfg.base_url.is_some() {
60                    return Err(ApiError::Configuration(
61                        "Base URL override not yet supported for Gemini API format".to_string(),
62                    ));
63                }
64                Ok(Arc::new(GeminiClient::new(value)))
65            }
66            ApiFormat::Xai => {
67                let client = if let Some(base_url) = &provider_cfg.base_url {
68                    XAIClient::with_base_url(value.clone(), Some(base_url.to_string()))?
69                } else {
70                    XAIClient::new(value.clone())?
71                };
72                Ok(Arc::new(client))
73            }
74        },
75        Credential::OAuth2(_) => Err(ApiError::Configuration(
76            "OAuth requires an AuthDirective, not a raw credential".to_string(),
77        )),
78    }
79}
80
81/// Factory function that creates a provider using an auth directive.
82pub fn create_provider_with_directive(
83    provider_cfg: &ProviderConfig,
84    directive: &AuthDirective,
85) -> Result<Arc<dyn Provider>, ApiError> {
86    match directive {
87        AuthDirective::OpenAiResponses(openai) => {
88            if provider_cfg.api_format != ApiFormat::OpenaiResponses {
89                return Err(ApiError::Configuration(
90                    "OpenAI OAuth directives require responses API format".to_string(),
91                ));
92            }
93            let base_url = provider_cfg.base_url.as_ref().map(|url| url.to_string());
94            Ok(Arc::new(OpenAIClient::with_directive(
95                openai.clone(),
96                base_url,
97            )?))
98        }
99        AuthDirective::Anthropic(anthropic) => {
100            if provider_cfg.api_format != ApiFormat::Anthropic {
101                return Err(ApiError::Configuration(
102                    "Anthropic OAuth directives require Anthropic API format".to_string(),
103                ));
104            }
105            if provider_cfg.base_url.is_some() {
106                return Err(ApiError::Configuration(
107                    "Base URL override not yet supported for Anthropic API format".to_string(),
108                ));
109            }
110            Ok(Arc::new(AnthropicClient::with_directive(
111                anthropic.clone(),
112            )?))
113        }
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::{self, AuthScheme, ProviderId};

    /// Builds an API-key credential with a dummy value.
    fn api_key_credential() -> Credential {
        Credential::ApiKey {
            value: "test-key".to_string(),
        }
    }

    /// Returns the error message of a failed factory call, panicking if it succeeded.
    ///
    /// `unwrap_err` is avoided because `Arc<dyn Provider>` may not implement `Debug`.
    fn expect_err_message(result: Result<Arc<dyn Provider>, ApiError>) -> String {
        match result {
            Err(e) => e.to_string(),
            Ok(_) => panic!("Expected error"),
        }
    }

    #[test]
    fn test_create_openai_provider() {
        let config = ProviderConfig {
            id: provider::openai(),
            name: "OpenAI".to_string(),
            api_format: ApiFormat::OpenaiResponses,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: None,
        };

        let provider = create_provider(&config, &api_key_credential()).unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_create_openai_chat_provider() {
        let config = ProviderConfig {
            id: provider::openai(),
            name: "OpenAI".to_string(),
            api_format: ApiFormat::OpenaiChat,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: None,
        };

        let provider = create_provider(&config, &api_key_credential()).unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_create_custom_openai_provider() {
        let config = ProviderConfig {
            id: ProviderId("my-provider".to_string()),
            name: "My Provider".to_string(),
            api_format: ApiFormat::OpenaiResponses,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://my-api.example.com".parse().unwrap()),
        };

        let provider = create_provider(&config, &api_key_credential()).unwrap();
        assert_eq!(provider.name(), "openai"); // Still uses OpenAI client
    }

    #[test]
    fn test_anthropic_rejects_base_url_override() {
        let config = ProviderConfig {
            id: provider::anthropic(),
            name: "Anthropic".to_string(),
            api_format: ApiFormat::Anthropic,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://proxy.example.com".parse().unwrap()),
        };

        let err_msg = expect_err_message(create_provider(&config, &api_key_credential()));
        assert!(err_msg.contains("Base URL override not yet supported for Anthropic"));
    }

    #[test]
    fn test_gemini_rejects_base_url_override() {
        let config = ProviderConfig {
            id: ProviderId("gemini-proxy".to_string()),
            name: "Gemini Proxy".to_string(),
            api_format: ApiFormat::Google,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://proxy.example.com".parse().unwrap()),
        };

        let err_msg = expect_err_message(create_provider(&config, &api_key_credential()));
        assert!(err_msg.contains("Base URL override not yet supported for Gemini"));
    }

    #[test]
    fn test_oauth_requires_directive() {
        let config = ProviderConfig {
            id: provider::anthropic(),
            name: "Anthropic".to_string(),
            api_format: ApiFormat::Anthropic,
            auth_schemes: vec![AuthScheme::Oauth2],
            base_url: None,
        };

        let credential = Credential::OAuth2(crate::auth::storage::OAuth2Token {
            access_token: "test-token".to_string(),
            refresh_token: "test-refresh".to_string(),
            expires_at: std::time::SystemTime::now(),
            id_token: None,
        });

        let err_msg = expect_err_message(create_provider(&config, &credential));
        assert!(err_msg.contains("OAuth requires an AuthDirective"));
    }
}