steer_core/api/factory.rs

1use crate::api::error::ApiError;
2use crate::api::provider::Provider;
3use crate::api::{
4    claude::AnthropicClient, gemini::GeminiClient, openai::OpenAIClient, xai::XAIClient,
5};
6use crate::auth::storage::Credential;
7use crate::config::provider::{ApiFormat, ProviderConfig};
8use std::sync::Arc;
9
10/// Factory function to create a provider instance based on the provider config and credential.
11///
12/// This function dispatches to the correct API client implementation based on the provider's
13/// API format. It also supports base URL overrides for custom providers using compatible
14/// API formats.
15pub fn create_provider(
16    provider_cfg: &ProviderConfig,
17    credential: &Credential,
18) -> Result<Arc<dyn Provider>, ApiError> {
19    match credential {
20        Credential::ApiKey { value } => match &provider_cfg.api_format {
21            ApiFormat::OpenaiResponses => {
22                let client = if let Some(base_url) = &provider_cfg.base_url {
23                    OpenAIClient::with_base_url_mode(
24                        value.clone(),
25                        Some(base_url.to_string()),
26                        crate::api::openai::OpenAIMode::Responses,
27                    )
28                } else {
29                    OpenAIClient::with_mode(
30                        value.clone(),
31                        crate::api::openai::OpenAIMode::Responses,
32                    )
33                };
34                Ok(Arc::new(client))
35            }
36            ApiFormat::OpenaiChat => {
37                let client = if let Some(base_url) = &provider_cfg.base_url {
38                    OpenAIClient::with_base_url_mode(
39                        value.clone(),
40                        Some(base_url.to_string()),
41                        crate::api::openai::OpenAIMode::Chat,
42                    )
43                } else {
44                    OpenAIClient::with_mode(value.clone(), crate::api::openai::OpenAIMode::Chat)
45                };
46                Ok(Arc::new(client))
47            }
48            ApiFormat::Anthropic => {
49                // TODO: Add base_url support to AnthropicClient
50                if provider_cfg.base_url.is_some() {
51                    return Err(ApiError::Configuration(
52                        "Base URL override not yet supported for Anthropic API format".to_string(),
53                    ));
54                }
55                Ok(Arc::new(AnthropicClient::with_api_key(value)))
56            }
57            ApiFormat::Google => {
58                // TODO: Add base_url support to GeminiClient
59                if provider_cfg.base_url.is_some() {
60                    return Err(ApiError::Configuration(
61                        "Base URL override not yet supported for Gemini API format".to_string(),
62                    ));
63                }
64                Ok(Arc::new(GeminiClient::new(value)))
65            }
66            ApiFormat::Xai => {
67                let client = if let Some(base_url) = &provider_cfg.base_url {
68                    XAIClient::with_base_url(value.clone(), Some(base_url.to_string()))
69                } else {
70                    XAIClient::new(value.clone())
71                };
72                Ok(Arc::new(client))
73            }
74        },
75        Credential::OAuth2(_) => {
76            // Only Anthropic supports OAuth currently
77            match &provider_cfg.api_format {
78                ApiFormat::Anthropic => {
79                    // OAuth for Anthropic requires the storage, which we don't have here
80                    // This will be handled differently in the refactored code
81                    Err(ApiError::Configuration(
82                        "OAuth support requires auth storage context".to_string(),
83                    ))
84                }
85                _ => Err(ApiError::Configuration(format!(
86                    "OAuth is not supported for {:?} API format",
87                    provider_cfg.api_format
88                ))),
89            }
90        }
91    }
92}
93
94/// Factory function that creates a provider with OAuth support (requires storage).
95///
96/// This is a separate function because OAuth providers need access to the auth storage
97/// to refresh tokens.
98pub fn create_provider_with_storage(
99    provider_cfg: &ProviderConfig,
100    credential: &Credential,
101    storage: Arc<dyn crate::auth::AuthStorage>,
102) -> Result<Arc<dyn Provider>, ApiError> {
103    match credential {
104        Credential::ApiKey { .. } => create_provider(provider_cfg, credential),
105        Credential::OAuth2(_) => match &provider_cfg.api_format {
106            ApiFormat::Anthropic => {
107                if provider_cfg.base_url.is_some() {
108                    return Err(ApiError::Configuration(
109                        "Base URL override not supported with OAuth authentication".to_string(),
110                    ));
111                }
112                Ok(Arc::new(AnthropicClient::with_oauth(storage)))
113            }
114            _ => Err(ApiError::Configuration(format!(
115                "OAuth is not supported for {:?} API format",
116                provider_cfg.api_format
117            ))),
118        },
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::{self, AuthScheme, ProviderId};

    /// Builds an API-key credential with a dummy value.
    fn api_key() -> Credential {
        Credential::ApiKey {
            value: "test-key".to_string(),
        }
    }

    /// Builds an OAuth2 credential with dummy tokens.
    fn oauth_token() -> Credential {
        Credential::OAuth2(crate::auth::storage::OAuth2Token {
            access_token: "test-token".to_string(),
            refresh_token: "test-refresh".to_string(),
            expires_at: std::time::SystemTime::now(),
        })
    }

    /// Returns the error message from `result`, panicking if a provider was created.
    ///
    /// `unwrap_err` is avoided because it would require `Arc<dyn Provider>: Debug`.
    fn expect_err(result: Result<Arc<dyn Provider>, ApiError>) -> String {
        match result {
            Err(e) => e.to_string(),
            Ok(_) => panic!("Expected error"),
        }
    }

    #[test]
    fn test_create_openai_provider() {
        let config = ProviderConfig {
            id: provider::openai(),
            name: "OpenAI".to_string(),
            api_format: ApiFormat::OpenaiResponses,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: None,
        };

        let provider = create_provider(&config, &api_key()).unwrap();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_create_custom_openai_provider() {
        let config = ProviderConfig {
            id: ProviderId("my-provider".to_string()),
            name: "My Provider".to_string(),
            api_format: ApiFormat::OpenaiResponses,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://my-api.example.com".parse().unwrap()),
        };

        let provider = create_provider(&config, &api_key()).unwrap();
        assert_eq!(provider.name(), "openai"); // Still uses OpenAI client
    }

    #[test]
    fn test_oauth_requires_storage() {
        let config = ProviderConfig {
            id: provider::anthropic(),
            name: "Anthropic".to_string(),
            api_format: ApiFormat::Anthropic,
            auth_schemes: vec![AuthScheme::Oauth2],
            base_url: None,
        };

        let err_msg = expect_err(create_provider(&config, &oauth_token()));
        assert!(err_msg.contains("OAuth support requires auth storage"));
    }

    #[test]
    fn test_anthropic_rejects_base_url_override() {
        let config = ProviderConfig {
            id: provider::anthropic(),
            name: "Anthropic".to_string(),
            api_format: ApiFormat::Anthropic,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://proxy.example.com".parse().unwrap()),
        };

        let err_msg = expect_err(create_provider(&config, &api_key()));
        assert!(err_msg.contains("Base URL override not yet supported for Anthropic"));
    }

    #[test]
    fn test_google_rejects_base_url_override() {
        let config = ProviderConfig {
            id: ProviderId("my-gemini".to_string()),
            name: "My Gemini".to_string(),
            api_format: ApiFormat::Google,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: Some("https://proxy.example.com".parse().unwrap()),
        };

        let err_msg = expect_err(create_provider(&config, &api_key()));
        assert!(err_msg.contains("Base URL override not yet supported for Gemini"));
    }

    #[test]
    fn test_oauth_rejected_for_non_anthropic_format() {
        let config = ProviderConfig {
            id: provider::openai(),
            name: "OpenAI".to_string(),
            api_format: ApiFormat::OpenaiChat,
            auth_schemes: vec![AuthScheme::Oauth2],
            base_url: None,
        };

        let err_msg = expect_err(create_provider(&config, &oauth_token()));
        assert!(err_msg.contains("OAuth is not supported for"));
    }
}
187}